mmagic.models.editors.pggan.pggan
Module Contents
Classes
ProgressiveGrowingGAN | Progressive Growing Unconditional GAN.
- class mmagic.models.editors.pggan.pggan.ProgressiveGrowingGAN(generator, discriminator, data_preprocessor, nkimgs_per_scale, noise_size=None, interp_real=None, transition_kimgs: int = 600, prev_stage: int = 0, ema_config: Optional[Dict] = None) [source]
Bases: mmagic.models.base_models.BaseGAN
Progressive Growing Unconditional GAN.
This model implements the progressive growing training schedule proposed in Progressive Growing of GANs for Improved Quality, Stability, and Variation (ICLR 2018).
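To make the idea of a progressive growing schedule concrete, here is a minimal sketch in plain Python. It is not mmagic's implementation: the function name stage_at is hypothetical, and it assumes that nkimgs_per_scale maps each resolution to the number of thousands of images trained at that resolution, and that transition_kimgs is the length of a linear fade-in at the start of every stage after the first.

```python
def stage_at(shown_kimgs, nkimgs_per_scale, transition_kimgs=600):
    """Return (scale, transition_weight) after `shown_kimgs` thousand images.

    Assumptions (illustrative only): each value in `nkimgs_per_scale` is the
    budget of one stage, and new layers fade in linearly over the first
    `transition_kimgs` of every stage except the lowest resolution.
    """
    scales = sorted(nkimgs_per_scale, key=int)  # e.g. ['4', '8', '16']
    start = 0.0
    for idx, scale in enumerate(scales):
        end = start + nkimgs_per_scale[scale]
        is_last = idx == len(scales) - 1
        if shown_kimgs < end or is_last:
            if idx == 0:
                # The first stage has no lower-resolution branch to blend with.
                return int(scale), 1.0
            weight = (shown_kimgs - start) / transition_kimgs
            return int(scale), min(1.0, max(0.0, weight))
        start = end


schedule = {'4': 600, '8': 1200, '16': 1200}  # hypothetical budgets
print(stage_at(900, schedule))   # halfway through the fade-in of the 8x8 stage
```

A transition weight below 1.0 means the newly added high-resolution layers are still being blended with the upsampled output of the previous stage.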
We highly recommend using GrowScaleImgDataset to reduce the computational load of data pre-processing.
Notes for using PGGAN:
In the official implementation, Tero uses a gradient penalty with norm_mode="HWC".
We do not implement minibatch_repeats, which is used in the official TensorFlow implementation.
Notes for resuming progressive growing GANs: Users should specify prev_stage in train_cfg. Otherwise, the model may reset the optimizer status, which degrades performance. For example, if your model is resumed from the 256 stage, you should set train_cfg=dict(prev_stage=256).
- Parameters
generator (dict) – Config for generator.
discriminator (dict) – Config for discriminator.
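As a rough illustration of the resuming note, the sketch below builds a model config that continues from the 256 stage. The class signature lists prev_stage as a constructor argument, so it is passed at the model level here; the train_cfg=dict(prev_stage=256) form from the note is kept in a comment. The generator, discriminator, and preprocessor type names, their keyword arguments, the noise size, and the k-image budgets are all assumptions for illustration and are not taken from this page.

```python
# Hypothetical config sketch; every concrete value below is an assumption.
model = dict(
    type='ProgressiveGrowingGAN',
    data_preprocessor=dict(type='DataPreprocessor'),       # assumed type name
    noise_size=512,                                        # assumed latent size
    generator=dict(type='PGGANGenerator', out_scale=256),  # assumed name/kwarg
    discriminator=dict(type='PGGANDiscriminator', in_scale=256),  # assumed
    nkimgs_per_scale={'4': 600, '8': 1200, '16': 1200, '32': 1200,
                      '64': 1200, '128': 1200, '256': 1200},  # made-up budgets
    transition_kimgs=600,
    # Resume from the 256 stage so the optimizer status is not reset.
    # The docstring note expresses the same setting as
    # train_cfg=dict(prev_stage=256).
    prev_stage=256,
)
```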
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) -> mmagic.utils.typing.SampleList [source]
Sample images from noise using the generator.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – Not used in ProgressiveGrowingGAN. Defaults to None.
- Returns
A list of DataSample containing the generated results.
- Return type
SampleList
- train_discriminator(inputs: torch.Tensor, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor] [source]
Train discriminator.
- Parameters
inputs (Tensor) – Inputs from current resolution training.
data_samples (List[DataSample]) – Data samples from dataloader. Not used in the generator’s training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) -> Tuple[torch.Tensor, dict] [source]
Compute the discriminator loss. PGGAN uses the WGAN-GP loss and a discriminator shift loss to train the discriminator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
disc_pred_real (Tensor) – Discriminator’s prediction of the real images.
fake_data (Tensor) – Generated images, used to calculate gradient penalty.
real_data (Tensor) – Real images, used to calculate gradient penalty.
- Returns
Loss value and a dict of log variables.
- Return type
Tuple[Tensor, dict]
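The combination of a WGAN-GP loss with a discriminator shift term can be sketched with plain Python numbers. This is only an illustration of the arithmetic, not the library's code: critic_loss is a hypothetical helper, the gradient norms at the interpolated samples are assumed to be precomputed, the shift term is assumed to be applied to the real predictions only, and the weights 10.0 and 0.001 are assumed defaults.

```python
from statistics import fmean


def critic_loss(pred_fake, pred_real, grad_norms,
                gp_weight=10.0, shift_weight=0.001):
    """Toy WGAN-GP critic loss with a shift (drift) penalty.

    pred_fake / pred_real: critic scores for fake and real samples.
    grad_norms: per-sample L2 norms of the critic gradient at points
        interpolated between real and fake data (assumed precomputed).
    """
    # Wasserstein term: push fake scores down and real scores up.
    loss_wgan = fmean(pred_fake) - fmean(pred_real)
    # Gradient penalty: keep the gradient norm close to 1.
    loss_gp = gp_weight * fmean([(n - 1.0) ** 2 for n in grad_norms])
    # Shift penalty: keep real scores from drifting far from zero.
    loss_shift = shift_weight * fmean([p ** 2 for p in pred_real])
    total = loss_wgan + loss_gp + loss_shift
    log_vars = dict(loss_wgan=loss_wgan, loss_gp=loss_gp,
                    loss_shift=loss_shift)
    return total, log_vars


total, logs = critic_loss([1.0, 3.0], [2.0, 4.0], [1.0, 2.0])
```

In a real training loop the gradient norms would come from automatic differentiation of the discriminator output with respect to the interpolated images.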
- train_generator(inputs: torch.Tensor, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor] [source]
Train generator.
- Parameters
inputs (Tensor) – Inputs from current resolution training.
data_samples (List[DataSample]) – Data samples from dataloader. Not used in the generator’s training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- gen_loss(disc_pred_fake: torch.Tensor) -> Tuple[torch.Tensor, dict] [source]
Generator loss for PGGAN. PGGAN uses the WGAN loss to train the generator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
recon_imgs (Tensor) – Reconstructed images.
- Returns
Loss value and a dict of log variables.
- Return type
Tuple[Tensor, dict]
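For reference, the WGAN generator objective is simply the negated mean critic score on generated samples. The sketch below shows that arithmetic on plain Python floats; wgan_generator_loss is a hypothetical name and the log key is an assumption, not the library's actual output.

```python
from statistics import fmean


def wgan_generator_loss(pred_fake):
    """Toy WGAN generator loss: maximise the critic score of fake samples."""
    loss = -fmean(pred_fake)
    return loss, dict(loss_gen=loss)  # log key chosen for illustration


loss, logs = wgan_generator_loss([1.0, 3.0])
```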
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) [source]
Train step function.
This function implements the standard training iteration for asynchronous adversarial training. Namely, in each iteration, we first update the discriminator and then compute the generator loss with the newly updated discriminator.
For distributed training, we use the reducer from DDP to synchronize the necessary parameters in the current computational graph.
- Parameters
data_batch (dict) – Input data from dataloader.
optimizer (dict) – Dict containing the optimizers for the generator and discriminator.
ddp_reducer (Reducer | None, optional) – Reducer from DDP. It is used to prepare for backward() in DDP. Defaults to None.
running_status (dict | None, optional) – Contains necessary basic information for training, e.g., iteration number. Defaults to None.
- Returns
Contains ‘log_vars’, ‘num_samples’, and ‘results’.
- Return type
dict
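The update order described for train_step (discriminator first, then the generator against the freshly updated discriminator) can be illustrated with a deliberately tiny toy. Everything here is an assumption made for illustration: the critic is a single scalar weight w with D(x) = w * x, the generator is a single scalar output g, and toy_train_step is a hypothetical function unrelated to the real method's signature.

```python
def toy_train_step(state, real=1.0, lr=0.1):
    """One alternating WGAN-style update on scalar toy parameters.

    state: dict with 'w' (critic weight, D(x) = w * x) and 'g' (generator
    output). Both are updated in place by plain gradient descent.
    """
    # 1) Critic update: minimise D(fake) - D(real) with respect to w.
    fake = state['g']
    loss_disc = state['w'] * fake - state['w'] * real
    state['w'] -= lr * (fake - real)        # d(loss_disc)/dw = fake - real

    # 2) Generator update: minimise -D(fake) using the *updated* critic.
    loss_gen = -state['w'] * state['g']
    state['g'] -= lr * (-state['w'])        # d(loss_gen)/dg = -w

    return dict(loss_disc=loss_disc, loss_gen=loss_gen)


state = {'w': 0.0, 'g': 0.0}
logs = toy_train_step(state)
```

Starting from w = 0, the generator only moves in this step because it sees the critic after its update; with the order reversed, g would stay at 0.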