mmagic.models.editors.wgan_gp
Package Contents
Classes
- WGANGPDiscriminator: Discriminator for WGANGP.
- WGANGPGenerator: Generator for WGANGP.
- WGANGP: Implementation of Improved Training of Wasserstein GANs.
- class mmagic.models.editors.wgan_gp.WGANGPDiscriminator(in_channel, in_scale, conv_module_cfg=None, init_cfg=None)[source]¶
Bases:
mmengine.model.BaseModule
Discriminator for WGANGP.
Implementation details for the WGANGP discriminator are the same as training configuration (a) described in the PGGAN paper: PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION (https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf)
- Adopt the convolution architecture specified in appendix A.2;
- Add layer normalization to all conv3x3 and conv4x4 layers;
- Use LeakyReLU in the discriminator except for the final output layer;
- Initialize all weights using He’s initializer.
- Parameters
in_channel (int) – The channel number of the input image.
in_scale (int) – The scale of the input image.
conv_module_cfg (dict, optional) – Config for the convolution module used in this discriminator. Defaults to None.
init_cfg (dict, optional) – Initialization config dict.
- _default_channels_per_scale¶
- _default_conv_module_cfg¶
- _default_upsample_cfg¶
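A minimal usage sketch for the discriminator follows; the channel/scale values and the assumption that the forward pass accepts a single image batch at the configured scale are illustrative, not taken from this page.

```python
# Minimal sketch, assuming the discriminator's forward takes an image batch
# whose spatial size matches in_scale and returns a per-sample critic score.
import torch
from mmagic.models.editors.wgan_gp import WGANGPDiscriminator

disc = WGANGPDiscriminator(in_channel=3, in_scale=128)
imgs = torch.randn(4, 3, 128, 128)  # 4 RGB images at the configured scale
scores = disc(imgs)                 # unbounded Wasserstein critic scores
```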
- class mmagic.models.editors.wgan_gp.WGANGPGenerator(noise_size, out_scale, conv_module_cfg=None, upsample_cfg=None, init_cfg=None)[source]¶
Bases:
mmengine.model.BaseModule
Generator for WGANGP.
Implementation details for the WGANGP generator are the same as training configuration (a) described in the PGGAN paper: PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION (https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf)
- Adopt the convolution architecture specified in appendix A.2;
- Use batch normalization in the generator except for the final output layer;
- Use ReLU in the generator except for the final output layer;
- Use Tanh in the last layer;
- Initialize all weights using He’s initializer.
- Parameters
noise_size (int) – Size of the input noise vector.
out_scale (int) – Output scale for the generated image.
conv_module_cfg (dict, optional) – Config for the convolution module used in this generator. Defaults to None.
upsample_cfg (dict, optional) – Config for the upsampling operation. Defaults to None.
init_cfg (dict, optional) – Initialization config dict.
- _default_channels_per_scale¶
- _default_conv_module_cfg¶
- _default_upsample_cfg¶
- forward(noise, num_batches=0, return_noise=False)[source]¶
Forward function.
- Parameters
noise (torch.Tensor | callable | None) – You can directly give a batch of noise through a torch.Tensor, or offer a callable function to sample a batch of noise data. Otherwise, None indicates to use the default noise sampler.
num_batches (int, optional) – The number of samples to generate (batch size). Defaults to 0.
return_noise (bool, optional) – If True, noise_batch will be returned in a dict together with fake_img. Defaults to False.
- Returns
If not return_noise, only the output image will be returned. Otherwise, a dict containing fake_img and noise_batch will be returned.
- Return type
torch.Tensor | dict
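A brief sketch of calling the generator’s forward is given below; the noise_size and out_scale values are illustrative assumptions, while the dict keys follow the documented return behavior.

```python
# Minimal sketch of the documented forward signature; parameter values
# here are illustrative, not defaults.
import torch
from mmagic.models.editors.wgan_gp import WGANGPGenerator

gen = WGANGPGenerator(noise_size=128, out_scale=128)

# Let the default noise sampler draw the latent codes.
fake_imgs = gen(None, num_batches=4)

# Also retrieve the sampled noise alongside the images.
out = gen(None, num_batches=4, return_noise=True)
fake_imgs, noise_batch = out['fake_img'], out['noise_batch']
```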
- class mmagic.models.editors.wgan_gp.WGANGP(*args, **kwargs)[source]¶
Bases:
mmagic.models.base_models.BaseGAN
Implementation of Improved Training of Wasserstein GANs.
Paper link: https://arxiv.org/pdf/1704.00028
Detailed architecture can be found in WGANGPGenerator and WGANGPDiscriminator.
- disc_loss(real_data: torch.Tensor, fake_data: torch.Tensor, disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor) Tuple [source]¶
Get disc loss. WGAN-GP uses the WGAN loss and a gradient penalty to train the discriminator.
- Parameters
real_data (Tensor) – Real input data.
fake_data (Tensor) – Fake input data.
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
disc_pred_real (Tensor) – Discriminator’s prediction of the real images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
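For reference, the sketch below writes out the generic WGAN-GP critic objective that these arguments feed into. It is an illustration, not mmagic’s internal code; the extra disc argument and the gp_weight value are assumptions needed to compute the gradient penalty.

```python
# Generic WGAN-GP critic objective (illustration, not mmagic's implementation):
# loss_D = E[D(fake)] - E[D(real)] + gp_weight * E[(||grad D(x_hat)||_2 - 1)^2]
import torch

def wgan_gp_disc_loss(disc, real_data, fake_data,
                      disc_pred_fake, disc_pred_real, gp_weight=10.0):
    # Gradient penalty on samples interpolated between real and fake data.
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
    interp = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    grads = torch.autograd.grad(
        outputs=disc(interp).sum(), inputs=interp, create_graph=True)[0]
    gradient_penalty = ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()

    loss = (disc_pred_fake.mean() - disc_pred_real.mean()
            + gp_weight * gradient_penalty)
    return loss
```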
- gen_loss(disc_pred_fake: torch.Tensor) Tuple [source]¶
Get gen loss. WGAN-GP uses the WGAN loss to train the generator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
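For reference, the generic WGAN generator objective simply maximizes the critic’s score on fake samples; the sketch below illustrates that formula and is not mmagic’s internal code.

```python
# Generic WGAN generator objective (illustration): loss_G = -E[D(G(z))]
import torch

def wgan_gen_loss(disc_pred_fake: torch.Tensor) -> torch.Tensor:
    return -disc_pred_fake.mean()
```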
- train_discriminator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor] [source]¶
Train discriminator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (DataSample) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_generator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor] [source]¶
Train generator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (DataSample) – Data samples from dataloader. Not used in the generator’s training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
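These train_* hooks are normally invoked by mmengine’s training loop on a model built from a config rather than being called by hand. A hypothetical config fragment is sketched below; the key names and values are assumptions for illustration and may differ from the shipped configs.

```python
# Hypothetical config sketch for building a WGANGP model through mmagic's
# registry; keys and values are illustrative and may not match official configs.
model = dict(
    type='WGANGP',
    generator=dict(type='WGANGPGenerator', noise_size=128, out_scale=128),
    discriminator=dict(type='WGANGPDiscriminator', in_channel=3, in_scale=128),
    # WGAN-GP typically takes several critic steps per generator step.
    discriminator_steps=5)
```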