
mmagic.models.editors.wgan_gp.wgan_gp

Module Contents

Classes

WGANGP

Implementation of Improved Training of Wasserstein GANs.

class mmagic.models.editors.wgan_gp.wgan_gp.WGANGP(*args, **kwargs)

Bases: mmagic.models.base_models.BaseGAN

Implementation of Improved Training of Wasserstein GANs.

Paper link: https://arxiv.org/pdf/1704.00028

The detailed architecture can be found in WGANGPGenerator and WGANGPDiscriminator.
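For reference, the objective from the linked paper can be written as below. Here P_r and P_g are the real and generated distributions, D is the critic (discriminator), and λ is the gradient-penalty weight (the paper's default is λ = 10). This page does not state how mmagic names or weights these terms internally.

```latex
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
    - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
    + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]

L_G = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
```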

disc_loss(real_data: torch.Tensor, fake_data: torch.Tensor, disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) → Tuple

Get the discriminator loss. WGAN-GP uses the WGAN loss and a gradient penalty to train the discriminator.

Parameters
  • real_data (Tensor) – Real input data.

  • fake_data (Tensor) – Fake input data.

  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
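The discriminator loss described above can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the mmagic implementation: the helper names `gradient_penalty` and `wgan_gp_disc_loss`, the explicit `disc` argument (the documented method takes no critic argument), the `gp_weight` default of 10 (the paper's value) and the log-variable keys are all hypothetical. It mirrors the documented arguments `real_data`, `fake_data`, `disc_pred_fake` and `disc_pred_real`, and returns a loss plus a dict of log variables.

```python
import torch
import torch.nn as nn


def gradient_penalty(disc: nn.Module, real_data: torch.Tensor,
                     fake_data: torch.Tensor) -> torch.Tensor:
    """Mean of (||grad_x D(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    batch = real_data.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    eps = torch.rand(batch, *([1] * (real_data.dim() - 1)),
                     device=real_data.device)
    # Detach the inputs so x_hat is a fresh leaf and the penalty trains only the critic.
    x_hat = (eps * real_data.detach()
             + (1 - eps) * fake_data.detach()).requires_grad_(True)
    # Summing the scores yields each sample's gradient w.r.t. its own x_hat;
    # create_graph=True keeps the penalty differentiable w.r.t. critic weights.
    grads = torch.autograd.grad(disc(x_hat).sum(), x_hat,
                                create_graph=True)[0]
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


def wgan_gp_disc_loss(disc: nn.Module, real_data: torch.Tensor,
                      fake_data: torch.Tensor, disc_pred_fake: torch.Tensor,
                      disc_pred_real: torch.Tensor, gp_weight: float = 10.0):
    """Critic loss E[D(fake)] - E[D(real)] + gp_weight * GP (hypothetical helper)."""
    loss_gp = gradient_penalty(disc, real_data, fake_data)
    loss_wgan = disc_pred_fake.mean() - disc_pred_real.mean()
    loss = loss_wgan + gp_weight * loss_gp
    # Hypothetical log-variable names, chosen for this sketch only.
    log_vars = {
        'loss_disc_fake': disc_pred_fake.mean().detach(),
        'loss_disc_real': -disc_pred_real.mean().detach(),
        'loss_gp': loss_gp.detach(),
    }
    return loss, log_vars
```

With a linear critic the input gradient is simply its weight vector, so a unit-norm weight gives a penalty of zero and a weight of norm 2 gives a penalty of 1, which makes the helper easy to sanity-check.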

gen_loss(disc_pred_fake: torch.Tensor) → Tuple

Get the generator loss. WGAN-GP uses the WGAN loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

tuple[Tensor, dict]
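The generator loss reduces to a single line. Below is a minimal sketch taking the documented `disc_pred_fake` argument; the function name `wgan_gen_loss` and the log key `loss_gen` are hypothetical.

```python
import torch


def wgan_gen_loss(disc_pred_fake: torch.Tensor):
    """WGAN generator loss -E[D(G(z))] (hypothetical helper)."""
    # The generator wants the critic to score its samples highly, so it
    # minimizes the negated mean critic score on fake images.
    loss = -disc_pred_fake.mean()
    # 'loss_gen' is a hypothetical log key chosen for this sketch.
    return loss, {'loss_gen': loss.detach()}
```

No gradient penalty appears here: in WGAN-GP the penalty regularizes only the critic.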

train_discriminator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from the dataloader.

  • data_samples (DataSample) – Data samples from the dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
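One discriminator update can be sketched with a plain `torch.optim` optimizer standing in for mmengine's `OptimWrapper`. The function `train_discriminator_step`, its arguments `noise_dim` and `gp_weight`, and the log keys are hypothetical. The real batch is passed in directly as a tensor rather than read from `inputs` or `data_samples`, and Gaussian noise as the generator input is an assumption of this sketch.

```python
import torch
import torch.nn as nn


def train_discriminator_step(gen: nn.Module, disc: nn.Module,
                             real: torch.Tensor,
                             disc_optim: torch.optim.Optimizer,
                             noise_dim: int, gp_weight: float = 10.0) -> dict:
    """One critic update with the WGAN-GP loss (hypothetical helper)."""
    batch = real.size(0)
    # Fakes are generated without tracking gradients: only the critic learns here.
    with torch.no_grad():
        fake = gen(torch.randn(batch, noise_dim, device=real.device))
    loss_wgan = disc(fake).mean() - disc(real).mean()

    # Gradient penalty on random interpolates between real and fake samples.
    eps = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real.detach() + (1 - eps) * fake).requires_grad_(True)
    grads = torch.autograd.grad(disc(x_hat).sum(), x_hat, create_graph=True)[0]
    loss_gp = ((grads.reshape(batch, -1).norm(2, dim=1) - 1) ** 2).mean()

    loss = loss_wgan + gp_weight * loss_gp
    disc_optim.zero_grad()
    loss.backward()
    disc_optim.step()
    # Hypothetical log keys; detached so the dict holds values only for logging.
    return {'loss_disc': loss.detach(), 'loss_gp': loss_gp.detach()}
```

The paper alternates several such critic updates (n_critic = 5 in its default setting) with each generator update.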

train_generator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from the dataloader.

  • data_samples (DataSample) – Data samples from the dataloader. Not used in the generator’s training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
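A generator update can be sketched under similar assumptions: a hypothetical `train_generator_step` function, a plain `torch.optim` optimizer in place of `OptimWrapper`, and Gaussian noise as input. Like the documented method, it needs no real data.

```python
import torch
import torch.nn as nn


def train_generator_step(gen: nn.Module, disc: nn.Module, batch_size: int,
                         gen_optim: torch.optim.Optimizer,
                         noise_dim: int) -> dict:
    """One generator update with the WGAN loss -E[D(G(z))] (hypothetical helper)."""
    noise = torch.randn(batch_size, noise_dim)
    # Gradients flow through the critic into the generator, but only the
    # generator's parameters are registered in gen_optim, so only they change.
    loss = -disc(gen(noise)).mean()
    gen_optim.zero_grad()
    loss.backward()
    gen_optim.step()
    # Hypothetical log key chosen for this sketch.
    return {'loss_gen': loss.detach()}
```

`loss.backward()` also fills the critic's `.grad` buffers. This is harmless as long as the critic optimizer calls `zero_grad()` before its own backward pass.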
