# Source code for mmagic.models.editors.dcgan.dcgan
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple
import torch
import torch.nn.functional as F
from mmengine.optim import OptimWrapper
from torch import Tensor
from mmagic.registry import MODELS
from mmagic.structures import DataSample
from ...base_models import BaseGAN
@MODELS.register_module()
class DCGAN(BaseGAN):
    """Implementation of `Unsupervised Representation Learning with Deep
    Convolutional Generative Adversarial Networks
    <https://arxiv.org/abs/1511.06434>`_ (DCGAN).

    Both networks are trained with the vanilla GAN loss, i.e. binary
    cross-entropy on the discriminator's logits. The generator uses the
    non-saturating form (fake predictions are pushed towards label 1).

    Detailed architecture can be found in
    :class:`~mmagic.models.editors.dcgan.DCGANGenerator`  # noqa
    and
    :class:`~mmagic.models.editors.dcgan.DCGANDiscriminator`  # noqa
    """

    def disc_loss(self, disc_pred_fake: Tensor,
                  disc_pred_real: Tensor) -> Tuple[Tensor, dict]:
        r"""Get disc loss. DCGAN uses the vanilla GAN loss to train the
        discriminator.

        The discriminator outputs are treated as logits (the sigmoid is
        applied inside ``F.binary_cross_entropy_with_logits``). Fake
        predictions are pushed towards label 0 and real predictions
        towards label 1.

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction (logits) of
                the fake images.
            disc_pred_real (Tensor): Discriminator's prediction (logits) of
                the real images.

        Returns:
            tuple[Tensor, dict]: Loss value and a dict of log variables, as
            returned by ``self.parse_losses``.
        """
        losses_dict = dict()
        # ``*_like`` targets match each prediction's shape, dtype and device.
        losses_dict['loss_disc_fake'] = F.binary_cross_entropy_with_logits(
            disc_pred_fake, torch.zeros_like(disc_pred_fake))
        losses_dict['loss_disc_real'] = F.binary_cross_entropy_with_logits(
            disc_pred_real, torch.ones_like(disc_pred_real))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def gen_loss(self, disc_pred_fake: Tensor) -> Tuple[Tensor, dict]:
        """Get gen loss. DCGAN uses the vanilla GAN loss to train the
        generator.

        The generator is rewarded when the discriminator classifies its
        samples as real, i.e. BCE-with-logits against an all-ones target.

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction (logits) of
                the fake images.

        Returns:
            tuple[Tensor, dict]: Loss value and a dict of log variables, as
            returned by ``self.parse_losses``.
        """
        losses_dict = dict()
        losses_dict['loss_gen'] = F.binary_cross_entropy_with_logits(
            disc_pred_fake, torch.ones_like(disc_pred_fake))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def train_discriminator(self, inputs: dict, data_samples: DataSample,
                            optimizer_wrapper: OptimWrapper
                            ) -> Dict[str, Tensor]:
        """Train discriminator.

        Samples one batch of noise, generates fake images without tracking
        generator gradients, scores fake and real images with the
        discriminator, and applies one update through ``optimizer_wrapper``.

        Args:
            inputs (dict): Inputs from dataloader. Not used here; real
                images are read from ``data_samples.gt_img``.
            data_samples (DataSample): Data samples from dataloader.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        real_imgs = data_samples.gt_img
        # The leading dim of the real batch sets how many fakes to draw.
        num_batches = real_imgs.shape[0]
        noise_batch = self.noise_fn(num_batches=num_batches)
        # The generator is not being trained in this step, so skip building
        # its autograd graph.
        with torch.no_grad():
            fake_imgs = self.generator(noise=noise_batch, return_noise=False)
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        optimizer_wrapper.update_params(parsed_losses)
        return log_vars

    def train_generator(self, inputs: dict, data_samples: DataSample,
                        optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
        """Train generator.

        Samples one batch of noise, generates fake images, scores them with
        the discriminator, and applies one update through
        ``optimizer_wrapper``.

        Args:
            inputs (dict): Inputs from dataloader. Not used in generator's
                training.
            data_samples (DataSample): Data samples from dataloader. Only
                its length is used, as the batch size.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        num_batches = len(data_samples)
        noise = self.noise_fn(num_batches=num_batches)
        fake_imgs = self.generator(noise=noise, return_noise=False)
        # Gradients flow back through the discriminator into the generator.
        # NOTE(review): freezing the discriminator's parameters during this
        # step is presumably handled by the caller (not visible here).
        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)
        optimizer_wrapper.update_params(parsed_loss)
        return log_vars