mmagic.models.editors.wgan_gp.wgan_gp
Module Contents
Classes
WGANGP: Implementation of Improved Training of Wasserstein GANs.
- class mmagic.models.editors.wgan_gp.wgan_gp.WGANGP(*args, **kwargs)
Bases: mmagic.models.base_models.BaseGAN
Implementation of Improved Training of Wasserstein GANs.
Paper link: https://arxiv.org/pdf/1704.00028
Detailed architecture can be found in WGANGPGenerator and WGANGPDiscriminator.
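For reference, the objectives proposed in the linked paper can be written as follows. Here $\tilde{x}$ is a generated sample, $x$ a real sample, $\hat{x}$ a point sampled uniformly on the straight line between them, and $\lambda$ the penalty weight (10 in the paper). This is the paper's formulation; the loss weights actually used by this class are set by its configuration and are not stated on this page.

```latex
\begin{aligned}
L_D &= \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
     - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
     + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}
       \Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big] \\
L_G &= -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big] \\
\hat{x} &= \epsilon x + (1 - \epsilon)\tilde{x}, \qquad \epsilon \sim U[0, 1]
\end{aligned}
```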
- disc_loss(real_data: torch.Tensor, fake_data: torch.Tensor, disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) -> Tuple
Get the discriminator loss. WGAN-GP uses the WGAN loss plus a gradient penalty to train the discriminator.
- Parameters
real_data (Tensor) – Real input data.
fake_data (Tensor) – Fake input data.
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
disc_pred_real (Tensor) – Discriminator’s prediction of the real images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
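The discriminator loss described above can be sketched in plain PyTorch as below. This is a minimal illustration of the paper's objective, not the library's code: the functions `gradient_penalty` and `wgan_gp_disc_loss`, the explicit `critic` argument, the `gp_weight=10.0` default and the log-variable names are all hypothetical choices made for this example.

```python
import torch
from torch import nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor,
                     fake: torch.Tensor) -> torch.Tensor:
    """Mean of (||grad_x D(x_hat)||_2 - 1)^2 over the batch (hypothetical helper)."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)),
                     device=real.device, dtype=real.dtype)
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    pred = critic(x_hat)
    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights.
    (grads,) = torch.autograd.grad(outputs=pred.sum(), inputs=x_hat,
                                   create_graph=True)
    grad_norm = torch.linalg.vector_norm(grads.reshape(batch_size, -1), ord=2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


def wgan_gp_disc_loss(critic, real_data, fake_data, disc_pred_fake,
                      disc_pred_real, gp_weight: float = 10.0):
    """Hypothetical stand-in for WGANGP.disc_loss; returns (loss, log_vars)."""
    loss_fake = disc_pred_fake.mean()    # critic should score fakes low
    loss_real = -disc_pred_real.mean()   # and real samples high
    loss_gp = gradient_penalty(critic, real_data, fake_data)
    loss = loss_fake + loss_real + gp_weight * loss_gp
    log_vars = {
        "loss_disc_fake": loss_fake.detach(),
        "loss_disc_real": loss_real.detach(),
        "loss_gp": loss_gp.detach(),
    }
    return loss, log_vars
```

For a linear critic the input gradient is just the weight vector, so the penalty reduces to `(||w|| - 1)^2`, which makes the helper easy to sanity-check.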
- gen_loss(disc_pred_fake: torch.Tensor) -> Tuple
Get the generator loss. WGAN-GP uses the WGAN loss to train the generator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
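The generator side needs no gradient penalty; following the paper, it simply maximises the critic's mean score on generated samples. A minimal sketch, where `wgan_gen_loss` and the `"loss_gen"` key are hypothetical names used only for illustration:

```python
import torch


def wgan_gen_loss(disc_pred_fake: torch.Tensor):
    """Hypothetical stand-in for WGANGP.gen_loss; returns (loss, log_vars)."""
    # The generator wants the critic to score its samples high, so it
    # minimises the negated mean prediction.
    loss = -disc_pred_fake.mean()
    return loss, {"loss_gen": loss.detach()}


loss, log_vars = wgan_gen_loss(torch.tensor([1.0, 2.0, 3.0]))
print(loss)  # tensor(-2.)
```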
- train_discriminator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor]
Train discriminator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (DataSample) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
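One discriminator update can be sketched with toy MLPs as below. This is an assumption-laden illustration rather than the class's implementation: a plain `torch.optim.Adam` stands in for mmengine's OptimWrapper, random tensors stand in for dataloader inputs, and `train_discriminator_step`, the model sizes and the log keys are hypothetical. The Adam settings (lr 1e-4, betas 0.0/0.9) and the penalty weight of 10 are the defaults from the paper's Algorithm 1, used here only as illustrative values.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim, data_dim, batch = 8, 4, 16
generator = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, data_dim))
critic = nn.Sequential(nn.Linear(data_dim, 32), nn.ReLU(), nn.Linear(32, 1))
# Plain Adam stands in for mmengine's OptimWrapper (assumption).
disc_optim = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))


def train_discriminator_step(real_data: torch.Tensor) -> dict:
    """Hypothetical single critic update; returns detached log values."""
    noise = torch.randn(real_data.size(0), latent_dim)
    with torch.no_grad():  # the generator is not updated in this step
        fake_data = generator(noise)
    disc_pred_fake = critic(fake_data)
    disc_pred_real = critic(real_data)
    # Gradient penalty on random interpolates between real and fake samples.
    eps = torch.rand(real_data.size(0), 1)
    x_hat = (eps * real_data + (1 - eps) * fake_data).requires_grad_(True)
    (grads,) = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)
    loss_gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    loss = disc_pred_fake.mean() - disc_pred_real.mean() + 10.0 * loss_gp
    disc_optim.zero_grad()
    loss.backward()
    disc_optim.step()
    return {"loss_disc": loss.detach(), "loss_gp": loss_gp.detach()}
```

In the paper the critic is updated several times (n_critic = 5) per generator update; how this class schedules the two steps is not described on this page.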
- train_generator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor]
Train generator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (DataSample) – Data samples from dataloader. Not used in the generator's training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
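A matching generator update can be sketched the same way. Again this is a hypothetical illustration: plain Adam replaces OptimWrapper, `train_generator_step` and the toy models are invented for the example, and no real data is used, consistent with `data_samples` being unused in the generator's training. Freezing the critic with `requires_grad_(False)` is one common way to keep its gradients untouched during the generator step; whether this class does the same is not stated here.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim, data_dim, batch = 8, 4, 16
generator = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, data_dim))
critic = nn.Sequential(nn.Linear(data_dim, 32), nn.ReLU(), nn.Linear(32, 1))
# Plain Adam stands in for mmengine's OptimWrapper (assumption).
gen_optim = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))


def train_generator_step(batch_size: int) -> dict:
    """Hypothetical single generator update; no real data is needed."""
    critic.requires_grad_(False)  # freeze the critic so only G receives grads
    noise = torch.randn(batch_size, latent_dim)
    fake_data = generator(noise)
    disc_pred_fake = critic(fake_data)
    loss = -disc_pred_fake.mean()  # WGAN generator loss, no gradient penalty
    gen_optim.zero_grad()
    loss.backward()
    gen_optim.step()
    critic.requires_grad_(True)  # unfreeze for the next critic step
    return {"loss_gen": loss.detach()}


log_vars = train_generator_step(batch)
```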