mmagic.models.editors.stylegan2.stylegan2
Module Contents
Attributes
- class mmagic.models.editors.stylegan2.stylegan2.StyleGAN2(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, ema_config: Optional[Dict] = None, loss_config=dict())
Bases: mmagic.models.base_models.BaseGAN
Implementation of Analyzing and Improving the Image Quality of StyleGAN.
Paper link: https://openaccess.thecvf.com/content_CVPR_2020/html/Karras_Analyzing_and_Improving_the_Image_Quality_of_StyleGAN_CVPR_2020_paper.html
The generator and discriminator architectures are documented in StyleGAN2Generator and StyleGAN2Discriminator.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor. Defaults to None.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
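As a rough illustration of what ema_config switches on, the PyTorch sketch below keeps an exponential moving average (EMA) of a toy generator's weights. It is not MMagic's implementation: the helper `ema_update` and the `decay` convention (the weight kept on the old average) are assumptions, and averaging utilities differ on whether their momentum coefficient weights the old average or the new parameters.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float) -> None:
    """Hypothetical helper: ema = decay * ema + (1 - decay) * param, per parameter."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


generator = nn.Linear(4, 4)
# The averaged copy is never optimized directly, so it needs no gradients.
ema_generator = copy.deepcopy(generator).requires_grad_(False)

with torch.no_grad():
    for p in ema_generator.parameters():
        p.zero_()  # start the average at zero so one update is easy to check
    for p in generator.parameters():
        p.fill_(1.0)  # pretend training has moved every weight to 1.0

ema_update(ema_generator, generator, decay=0.5)
```

After one update with `decay=0.5`, every averaged parameter sits halfway between its old value (0) and the live parameter (1).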
- disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, real_imgs: torch.Tensor) → Tuple
Get the discriminator loss. StyleGANv2 uses the non-saturating loss and the R1 gradient penalty to train the discriminator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
disc_pred_real (Tensor) – Discriminator’s prediction of the real images.
real_imgs (Tensor) – Input real images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
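The following PyTorch sketch shows the two terms disc_loss combines: the non-saturating logistic loss and the R1 gradient penalty, γ/2 · E[‖∇D(x)‖²] over real images. The function name `disc_loss_sketch`, the `r1_gamma` default of 10.0, and the log-variable keys are illustrative assumptions rather than MMagic's values. It also applies R1 on every call, whereas the StyleGAN2 paper evaluates its regularizers lazily, less often than the main loss.

```python
import torch
import torch.nn.functional as F
from torch import nn


def disc_loss_sketch(disc_pred_fake: torch.Tensor,
                     disc_pred_real: torch.Tensor,
                     real_imgs: torch.Tensor,
                     r1_gamma: float = 10.0):
    """Non-saturating logistic loss plus an R1 penalty on real images.

    ``real_imgs`` must have ``requires_grad=True`` and ``disc_pred_real`` must
    have been computed from it, so that dD(x)/dx can be taken.
    """
    # softplus(D(G(z))) = -log(1 - sigmoid(D(G(z)))): push fake logits down.
    loss_fake = F.softplus(disc_pred_fake).mean()
    # softplus(-D(x)) = -log(sigmoid(D(x))): push real logits up.
    loss_real = F.softplus(-disc_pred_real).mean()

    # R1: squared L2 norm of the gradient of D with respect to real images.
    # create_graph=True keeps the penalty differentiable w.r.t. D's weights.
    (grad_real,) = torch.autograd.grad(
        outputs=disc_pred_real.sum(), inputs=real_imgs, create_graph=True)
    loss_r1 = grad_real.pow(2).flatten(start_dim=1).sum(dim=1).mean()

    loss = loss_fake + loss_real + 0.5 * r1_gamma * loss_r1
    log_vars = dict(loss_disc_fake=loss_fake.detach(),
                    loss_disc_real=loss_real.detach(),
                    loss_r1=loss_r1.detach())
    return loss, log_vars


# Toy discriminator and data, for illustration only.
torch.manual_seed(0)
disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real_imgs = torch.randn(4, 3, 8, 8, requires_grad=True)
fake_imgs = torch.randn(4, 3, 8, 8)

loss, log_vars = disc_loss_sketch(disc(fake_imgs), disc(real_imgs), real_imgs)
loss.backward()  # gradients reach disc's weights, including through the R1 term
```

For this linear toy discriminator the input gradient is just the weight vector, so the logged R1 value equals the squared norm of that weight, which makes the penalty easy to sanity-check.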
- gen_loss(disc_pred_fake: torch.Tensor, batch_size: int) → Tuple
Get the generator loss. StyleGANv2 uses the non-saturating loss and path length regularization to train the generator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
batch_size (int) – Batch size for generating fake images.
- Returns
Loss value and a dict of log variables.
- Return type
tuple[Tensor, dict]
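A comparable sketch for gen_loss follows, under stated assumptions: a toy linear generator maps 2-D latents of shape (N, C) to images, the random image-space direction is scaled by 1/sqrt(H·W), and the running mean path length is kept as a Python float. Unlike gen_loss, which uses batch_size to generate its own fake images, the hypothetical `gen_loss_sketch` receives the fake images and their latents directly; `pl_weight=2.0` and `decay=0.01` are illustrative defaults.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


def gen_loss_sketch(disc_pred_fake, fake_imgs, latents, mean_path_length,
                    pl_weight=2.0, decay=0.01):
    """Non-saturating generator loss plus path length regularization.

    ``fake_imgs`` must be produced from ``latents`` (which require grad).
    Returns the loss, log variables, and the updated running mean path length.
    """
    # softplus(-D(G(z))) = -log(sigmoid(D(G(z)))): the non-saturating loss.
    loss_gan = F.softplus(-disc_pred_fake).mean()

    # Random image-space direction y, scaled by 1/sqrt(H * W).
    height, width = fake_imgs.shape[-2:]
    noise = torch.randn_like(fake_imgs) / math.sqrt(height * width)
    # J^T y for every sample, obtained with one backward pass through G.
    (grad,) = torch.autograd.grad(
        outputs=(fake_imgs * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = grad.pow(2).sum(dim=1).sqrt()

    # Running mean of the path lengths; .item() keeps it out of the graph.
    new_mean = mean_path_length + decay * (path_lengths.mean().item() - mean_path_length)
    loss_pl = (path_lengths - new_mean).pow(2).mean()

    loss = loss_gan + pl_weight * loss_pl
    log_vars = dict(loss_gen=loss_gan.detach(), loss_path_regular=loss_pl.detach())
    return loss, log_vars, new_mean


# Toy generator/discriminator, for illustration only.
torch.manual_seed(0)
batch_size, latent_dim = 4, 16
gen = nn.Sequential(nn.Linear(latent_dim, 3 * 8 * 8), nn.Unflatten(1, (3, 8, 8)))
disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))

latents = torch.randn(batch_size, latent_dim, requires_grad=True)
fake_imgs = gen(latents)
loss, log_vars, mean_path_length = gen_loss_sketch(
    disc(fake_imgs), fake_imgs, latents, mean_path_length=0.0)
loss.backward()
```

The penalty pulls each sample's path length toward the running mean, which encourages a fixed-size step in latent space to produce a fixed-magnitude change in the image.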
- train_discriminator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Train the discriminator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (DataSample) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
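The control flow of one discriminator update can be sketched in plain PyTorch as below. The toy networks, the hypothetical `train_discriminator_sketch`, and the use of `torch.optim.SGD` in place of an OptimWrapper are assumptions made to keep the example self-contained; the R1 term is left out for brevity.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
latent_dim = 16
generator = nn.Sequential(nn.Linear(latent_dim, 3 * 8 * 8), nn.Unflatten(1, (3, 8, 8)))
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
# A plain torch optimizer stands in for the OptimWrapper in this sketch.
disc_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.1)


def train_discriminator_sketch(real_imgs: torch.Tensor) -> dict:
    """One discriminator update; returns a dict of tensors for logging."""
    # Fake images are generated without tracking gradients: only D is trained here.
    with torch.no_grad():
        fake_imgs = generator(torch.randn(real_imgs.shape[0], latent_dim))
    disc_pred_fake = discriminator(fake_imgs)
    disc_pred_real = discriminator(real_imgs)
    # Non-saturating loss only; the R1 penalty is omitted to keep this short.
    loss = F.softplus(disc_pred_fake).mean() + F.softplus(-disc_pred_real).mean()
    disc_optimizer.zero_grad()
    loss.backward()
    disc_optimizer.step()
    return dict(loss_disc=loss.detach())


gen_before = [p.detach().clone() for p in generator.parameters()]
disc_before = [p.detach().clone() for p in discriminator.parameters()]
log_vars = train_discriminator_sketch(torch.randn(4, 3, 8, 8))
```

Because the fake images are produced under `torch.no_grad()`, the generator's weights are left untouched by this step.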
- train_generator(inputs: dict, data_samples: mmagic.structures.DataSample, optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Train the generator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (DataSample) – Data samples from dataloader. Not used in the generator’s training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
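A matching sketch for the generator update is below. Freezing the discriminator with `requires_grad_(False)` during this step is an assumption (a common GAN practice), as are the toy networks and the plain SGD optimizer standing in for an OptimWrapper. Consistent with the note that data_samples is unused, the step needs no real images.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
latent_dim, batch_size = 16, 4
generator = nn.Sequential(nn.Linear(latent_dim, 3 * 8 * 8), nn.Unflatten(1, (3, 8, 8)))
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
# Only the generator's parameters are handed to this optimizer.
gen_optimizer = torch.optim.SGD(generator.parameters(), lr=0.1)


def train_generator_sketch() -> dict:
    """One generator update; returns a dict of tensors for logging."""
    # Freeze D so backward does not accumulate gradients in its parameters.
    discriminator.requires_grad_(False)
    fake_imgs = generator(torch.randn(batch_size, latent_dim))
    # Non-saturating generator loss: softplus(-D(G(z))).
    loss = F.softplus(-discriminator(fake_imgs)).mean()
    gen_optimizer.zero_grad()
    loss.backward()
    gen_optimizer.step()
    discriminator.requires_grad_(True)  # unfreeze D for its next update
    return dict(loss_gen=loss.detach())


gen_before = [p.detach().clone() for p in generator.parameters()]
log_vars = train_generator_sketch()
```

Gradients still flow through the frozen discriminator to the fake images, so the generator is updated while the discriminator's `.grad` fields stay empty.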
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train the GAN model. In GAN training, the generator and the discriminator are updated alternately. In MMagic’s design, self.train_step is called with data input, so the discriminator, whose update relies on real data, is always updated; whether the generator is also updated is then decided based on the current iteration number. More details about when the generator is updated can be found in should_gen_update().
- Parameters
data (dict) – Data sampled from dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
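The alternating schedule described for train_step can be simulated in pure Python. The modulo rule in the hypothetical `update_schedule` is only a simplified stand-in for should_gen_update(): it assumes the generator is updated generator_steps times after every discriminator_steps discriminator updates, and it ignores gradient accumulation.

```python
from typing import List


def update_schedule(num_iters: int, discriminator_steps: int = 1,
                    generator_steps: int = 1) -> List[str]:
    """Hypothetical sketch of the per-iteration update order.

    'D' marks a discriminator update and 'G' a generator update. The
    discriminator is updated on every call (every batch of real data).
    """
    events = []
    for curr_iter in range(num_iters):
        events.append('D')
        # Simplified stand-in for should_gen_update(): a modulo check on the iteration count.
        if (curr_iter + 1) % discriminator_steps == 0:
            events.extend(['G'] * generator_steps)
    return events


print(update_schedule(4, discriminator_steps=2, generator_steps=1))
# ['D', 'D', 'G', 'D', 'D', 'G']
```

With the defaults (`discriminator_steps=1`, `generator_steps=1`) the two networks simply take turns, one update each per iteration.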