mmagic.models.editors.dcgan.dcgan
Module Contents
Classes
DCGAN: Implementation of Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.
- class mmagic.models.editors.dcgan.dcgan.DCGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None) [source]
Bases: mmagic.models.base_models.BaseGAN
Implementation of Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (DCGAN).
- Paper link: https://arxiv.org/abs/1511.06434
Detailed architecture can be found in DCGANGenerator and DCGANDiscriminator.
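MMagic models are typically built from a config dict through the model registry. The fragment below is a sketch of how the constructor arguments in the `DCGAN` signature might be filled in for 64x64 images. It is modeled on MMagic's DCGAN configs, but the keys inside the `generator` and `discriminator` dicts (`output_scale`, `base_channels`, `input_scale`, `out_channels`) are assumptions to check against the `DCGANGenerator` and `DCGANDiscriminator` signatures of your installed version. `noise_size=100` matches the noise dimension used in the DCGAN paper.

```python
# Config sketch (assumed keys inside generator/discriminator; verify locally).
model = dict(
    type='DCGAN',
    noise_size=100,  # length of the latent vector z fed to the generator
    data_preprocessor=dict(type='DataPreprocessor'),
    generator=dict(type='DCGANGenerator', output_scale=64, base_channels=1024),
    discriminator=dict(
        type='DCGANDiscriminator',
        input_scale=64,
        output_scale=4,
        out_channels=1),  # one logit per image for the vanilla GAN loss
    generator_steps=1,      # generator updates per training iteration
    discriminator_steps=1)  # discriminator updates per training iteration
```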
- disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) -> Tuple [source]
Get the discriminator loss. DCGAN uses the vanilla GAN loss to train the discriminator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
disc_pred_real (Tensor) – Discriminator’s prediction of the real images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
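The vanilla GAN discriminator loss is binary cross-entropy that labels real images 1 and fake images 0. The pure-Python sketch below assumes the discriminator outputs raw logits (the input that `torch.nn.functional.binary_cross_entropy_with_logits` expects) and averages over the batch. The helper `bce_with_logits` and the log-variable keys are illustrative assumptions, not necessarily the names MMagic uses.

```python
import math
from typing import Dict, List, Tuple


def bce_with_logits(logits: List[float], target: float) -> float:
    """Mean binary cross-entropy on raw logits against a constant label.

    Uses max(x, 0) - x * t + log(1 + exp(-|x|)), which equals
    -[t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))] without overflow.
    """
    return sum(max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
               for x in logits) / len(logits)


def disc_loss(disc_pred_fake: List[float],
              disc_pred_real: List[float]) -> Tuple[float, Dict[str, float]]:
    """Vanilla GAN discriminator loss: fakes -> label 0, reals -> label 1."""
    log_vars = {
        'loss_disc_fake': bce_with_logits(disc_pred_fake, 0.0),
        'loss_disc_real': bce_with_logits(disc_pred_real, 1.0),
    }
    loss = log_vars['loss_disc_fake'] + log_vars['loss_disc_real']
    log_vars['loss'] = loss
    return loss, log_vars


# An undecided discriminator (all logits 0) pays log(2) per term.
loss, log_vars = disc_loss([0.0, 0.0], [0.0, 0.0])
```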
- gen_loss(disc_pred_fake: torch.Tensor) -> Tuple [source]
Get the generator loss. DCGAN uses the vanilla GAN loss to train the generator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
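For the generator, the vanilla loss is usually applied in its non-saturating form: the discriminator's predictions on fake images are scored against the label 1, so the generator minimizes -log D(G(z)). A minimal pure-Python sketch, again assuming raw logits and batch averaging; `bce_with_logits` is a hypothetical helper included so the block runs on its own.

```python
import math
from typing import Dict, List, Tuple


def bce_with_logits(logits: List[float], target: float) -> float:
    """Mean binary cross-entropy on raw logits against a constant label."""
    return sum(max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
               for x in logits) / len(logits)


def gen_loss(disc_pred_fake: List[float]) -> Tuple[float, Dict[str, float]]:
    """Non-saturating vanilla GAN generator loss: score fakes against label 1."""
    loss = bce_with_logits(disc_pred_fake, 1.0)
    return loss, {'loss_gen': loss}


# The loss is small when the discriminator believes the fakes are real
# (large positive logits) and large when it confidently rejects them.
easy, _ = gen_loss([10.0])
hard, _ = gen_loss([-10.0])
```

Minimizing -log D(G(z)) rather than log(1 - D(G(z))) gives the generator stronger gradients early in training, when the discriminator rejects fakes easily.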
- train_discriminator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor] [source]
Train the discriminator.
- Parameters
inputs (dict) – Inputs from the dataloader.
data_samples (DataSample) – Data samples from the dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
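A discriminator update in a GAN typically generates fakes with the generator frozen, scores fakes and reals, computes the discriminator loss, and updates only the discriminator's parameters before returning scalar losses for logging. The toy below reproduces that flow on a hypothetical 1-D problem with a linear generator G(z) = a*z + c and a linear-logit discriminator D(x) = w*x + b, using hand-written gradients instead of autograd and a plain gradient step instead of `optimizer_wrapper`. The function names, toy models, and learning rate are illustrative assumptions, not MMagic's implementation.

```python
import math
from typing import Dict, List


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bce_with_logits(logits: List[float], target: float) -> float:
    """Mean binary cross-entropy on raw logits against a constant label."""
    return sum(max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
               for x in logits) / len(logits)


def train_discriminator_step(disc: Dict[str, float], gen: Dict[str, float],
                             real: List[float], noise: List[float],
                             lr: float = 0.05) -> Dict[str, float]:
    """One toy discriminator update; `disc` holds w, b and `gen` holds a, c."""
    # Generate fakes; the generator is only read here, never updated.
    fake = [gen['a'] * z + gen['c'] for z in noise]
    # Discriminator logits on fakes and reals.
    pred_fake = [disc['w'] * x + disc['b'] for x in fake]
    pred_real = [disc['w'] * x + disc['b'] for x in real]
    # Vanilla GAN discriminator loss: fakes -> label 0, reals -> label 1.
    loss_fake = bce_with_logits(pred_fake, 0.0)
    loss_real = bce_with_logits(pred_real, 1.0)
    # dBCE/dlogit = sigmoid(logit) - label; dlogit/dw = x, dlogit/db = 1.
    err_fake = [sigmoid(p) - 0.0 for p in pred_fake]
    err_real = [sigmoid(p) - 1.0 for p in pred_real]
    grad_w = (sum(e * x for e, x in zip(err_fake, fake)) / len(fake)
              + sum(e * x for e, x in zip(err_real, real)) / len(real))
    grad_b = sum(err_fake) / len(fake) + sum(err_real) / len(real)
    # Plain gradient step on the discriminator only (stand-in for an optimizer).
    disc['w'] -= lr * grad_w
    disc['b'] -= lr * grad_b
    # Scalar losses for logging, computed before the update.
    return {'loss_disc_fake': loss_fake, 'loss_disc_real': loss_real,
            'loss': loss_fake + loss_real}


disc = {'w': 0.0, 'b': 0.0}
gen = {'a': 0.5, 'c': -1.0}
real = [2.0, 2.5, 3.0]
noise = [-1.0, 0.0, 1.0]
history = [train_discriminator_step(disc, gen, real, noise)['loss'] for _ in range(5)]
```

Because the toy loss is convex in (w, b) and the step size is small, each update lowers the discriminator loss while the generator parameters stay untouched.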
- train_generator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor] [source]
Train the generator.
- Parameters
inputs (dict) – Inputs from the dataloader.
data_samples (DataSample) – Data samples from the dataloader. Not used in the generator’s training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
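A generator update typically samples noise, generates fakes, scores them with the current discriminator, and updates only the generator using the generator loss. No real images are involved, which is consistent with `data_samples` being unused here. The toy below illustrates that flow on a hypothetical 1-D setup with a linear generator G(z) = a*z + c and a fixed linear-logit discriminator D(x) = w*x + b; the function names, gradients written by hand, and learning rate are assumptions for illustration only.

```python
import math
from typing import Dict, List


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bce_with_logits(logits: List[float], target: float) -> float:
    """Mean binary cross-entropy on raw logits against a constant label."""
    return sum(max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
               for x in logits) / len(logits)


def train_generator_step(disc: Dict[str, float], gen: Dict[str, float],
                         noise: List[float], lr: float = 0.05) -> Dict[str, float]:
    """One toy generator update; no real samples are needed."""
    # Generate fakes; this time gradients flow back into the generator.
    fake = [gen['a'] * z + gen['c'] for z in noise]
    # Score the fakes with the current discriminator (read-only here).
    pred_fake = [disc['w'] * x + disc['b'] for x in fake]
    # Non-saturating generator loss: fakes are scored against label 1.
    loss_gen = bce_with_logits(pred_fake, 1.0)
    # Chain rule: dL/dlogit = sigmoid(logit) - 1, dlogit/dx = w, dx/da = z, dx/dc = 1.
    err = [(sigmoid(p) - 1.0) * disc['w'] for p in pred_fake]
    grad_a = sum(e * z for e, z in zip(err, noise)) / len(noise)
    grad_c = sum(err) / len(noise)
    # Plain gradient step on the generator only (stand-in for an optimizer).
    gen['a'] -= lr * grad_a
    gen['c'] -= lr * grad_c
    return {'loss_gen': loss_gen}


disc = {'w': 1.0, 'b': 0.0}
gen = {'a': 0.5, 'c': -1.0}
noise = [-1.0, 0.0, 1.0]
history = [train_generator_step(disc, gen, noise)['loss_gen'] for _ in range(5)]
```

With w > 0 the discriminator assigns higher "real" logits to larger x, so each step shifts the generator's offset c upward and the generator loss falls, while the discriminator parameters stay fixed.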