mmagic.models.editors.sagan.sagan

Module Contents

Classes

SAGAN

Implementation of Self-Attention Generative Adversarial Networks.

Attributes

ModelType

TrainInput

mmagic.models.editors.sagan.sagan.ModelType
mmagic.models.editors.sagan.sagan.TrainInput
class mmagic.models.editors.sagan.sagan.SAGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = 128, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None)

Bases: mmagic.models.base_models.BaseConditionalGAN

Implementation of Self-Attention Generative Adversarial Networks (SAGAN, https://arxiv.org/abs/1805.08318), Spectral Normalization for Generative Adversarial Networks (SNGAN), and cGANs with Projection Discriminator (Proj-GAN).

The detailed architecture is described in SNGANGenerator and ProjDiscriminator.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor. Defaults to None.

  • generator_steps (int) – Number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – Number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to 128.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
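The parameters above map onto a config dict roughly as follows. This is a minimal sketch, not a tested training config: the `type` names for the generator and discriminator are taken from the classes mentioned above, `DataPreprocessor` and the value `num_classes=10` are assumptions, and the nested dicts omit the architecture arguments a real config would need.

```python
# Hypothetical config sketch built from plain dicts, so it runs without
# mmagic installed. Nested generator/discriminator arguments are omitted
# on purpose rather than guessed.
model = dict(
    type='SAGAN',
    generator=dict(type='SNGANGenerator'),         # architecture args omitted
    discriminator=dict(type='ProjDiscriminator'),  # architecture args omitted
    data_preprocessor=dict(type='DataPreprocessor'),  # assumed type name
    generator_steps=1,       # documented default
    discriminator_steps=1,   # documented default
    noise_size=128,          # documented default
    num_classes=10,          # assumed value for illustration
    ema_config=None,         # documented default: no EMA generator
)
```

The top-level keys mirror the constructor signature one to one, so any key left out falls back to the default listed above.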

disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) → Tuple[torch.Tensor, dict]

Get disc loss. SAGAN, SNGAN and Proj-GAN use hinge loss to train the discriminator.

Parameters
  • disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

  • disc_pred_real (Tensor) – Discriminator’s prediction of the real images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
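As an illustration of the hinge loss named above, the sketch below evaluates the standard discriminator hinge objective, mean(max(0, 1 + D(G(z)))) + mean(max(0, 1 - D(x))), on plain Python floats so that it runs without torch. The function name and the log-variable keys are hypothetical, and whether SAGAN weights or logs the two terms exactly this way is an assumption.

```python
from typing import Dict, Sequence, Tuple


def hinge_disc_loss(
    disc_pred_fake: Sequence[float],
    disc_pred_real: Sequence[float],
) -> Tuple[float, Dict[str, float]]:
    """Standard hinge loss for a GAN discriminator, on plain floats."""
    # Penalize fake scores above -1: mean(max(0, 1 + D(G(z)))).
    loss_fake = sum(max(0.0, 1.0 + s) for s in disc_pred_fake) / len(disc_pred_fake)
    # Penalize real scores below +1: mean(max(0, 1 - D(x))).
    loss_real = sum(max(0.0, 1.0 - s) for s in disc_pred_real) / len(disc_pred_real)
    loss = loss_fake + loss_real
    # Key names are illustrative, not taken from the library.
    log_vars = {
        'loss_disc_fake': loss_fake,
        'loss_disc_real': loss_real,
        'loss': loss,
    }
    return loss, log_vars


# With tensors, the same terms would be F.relu(1 + fake).mean() and
# F.relu(1 - real).mean().
loss, log_vars = hinge_disc_loss([-2.0, 0.5], [1.5, 0.0])
print(loss)  # 1.25
```

Scores already on the correct side of the margin (fake below -1, real above +1) contribute nothing, which is why only the second entry of each list affects the result here.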

gen_loss(disc_pred_fake: torch.Tensor) → Tuple[torch.Tensor, dict]

Get gen loss. SAGAN, SNGAN and Proj-GAN use hinge loss to train the generator.

Parameters

disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.

Returns

Loss value and a dict of log variables.

Return type

Tuple[Tensor, dict]
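For the generator, the standard hinge objective reduces to the negated mean of the discriminator's scores on fake images, -mean(D(G(z))). The sketch below shows that arithmetic on plain Python floats; the function name and log-variable key are hypothetical, and it is an assumption that SAGAN uses exactly this form.

```python
from typing import Dict, Sequence, Tuple


def hinge_gen_loss(disc_pred_fake: Sequence[float]) -> Tuple[float, Dict[str, float]]:
    """Standard hinge-GAN generator loss, on plain floats."""
    # The generator is rewarded when the discriminator scores fakes highly,
    # so the loss is the negated mean score: -mean(D(G(z))).
    loss = -sum(disc_pred_fake) / len(disc_pred_fake)
    # Key name is illustrative, not taken from the library.
    return loss, {'loss_gen': loss}


# With a tensor, the same quantity would be -disc_pred_fake.mean().
loss, log_vars = hinge_gen_loss([-2.0, 0.5])
print(loss)  # 0.75
```

Unlike the discriminator loss, there is no margin or clipping on the generator side, so every fake score contributes to the gradient.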

train_discriminator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train discriminator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (DataSample) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_generator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]

Train generator.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (DataSample) – Data samples from dataloader. Not used in generator training.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]
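How train_discriminator and train_generator interleave depends on the generator_steps and discriminator_steps arguments of SAGAN. The sketch below models one plausible schedule in plain Python. It assumes the discriminator is updated on every iteration and the generator is updated generator_steps times once every discriminator_steps iterations, and it ignores gradient accumulation. The function name is hypothetical and the actual scheduling lives in the base class, which this page does not document.

```python
from typing import List


def update_schedule(
    num_iters: int,
    generator_steps: int = 1,
    discriminator_steps: int = 1,
) -> List[str]:
    """Return the order of 'D' and 'G' updates under the assumed schedule."""
    events: List[str] = []
    for curr_iter in range(num_iters):
        # Assumed: the discriminator is trained on every iteration.
        events.append('D')
        # Assumed: after every `discriminator_steps` discriminator updates,
        # the generator is trained `generator_steps` times in a row.
        if (curr_iter + 1) % discriminator_steps == 0:
            events.extend(['G'] * generator_steps)
    return events


schedule = update_schedule(4, generator_steps=1, discriminator_steps=2)
print(schedule)  # ['D', 'D', 'G', 'D', 'D', 'G']
```

With both arguments left at their default of 1, the schedule simply alternates one discriminator update with one generator update.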
