Shortcuts

Source code for mmagic.models.editors.pggan.pggan

# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import mmengine
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.dist import get_world_size
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch import Tensor

from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import ForwardInputs, SampleList
from ...base_models import BaseGAN
from ...utils import get_valid_num_batches, set_requires_grad

# Type aliases: a model may be given as a config dict or a built module, and
# a training input may be a dict or a single tensor.
ModelType = Union[Dict, nn.Module]
TrainInput = Union[dict, Tensor]
@MODELS.register_module('PGGAN')
@MODELS.register_module()
class ProgressiveGrowingGAN(BaseGAN):
    """Progressive Growing Unconditional GAN.

    In this GAN model, we implement the progressive growing training schedule
    proposed in Progressive Growing of GANs for Improved Quality, Stability
    and Variation, ICLR 2018.

    We highly recommend using ``GrowScaleImgDataset`` to save computational
    load in data pre-processing.

    Notes for **using PGGAN**:

    #. In the official implementation, Tero uses gradient penalty with
       ``norm_mode="HWC"``.
    #. We do not implement ``minibatch_repeats``, which is used in the
       official Tensorflow implementation.

    Notes for resuming progressive growing GANs:
    Users should specify the ``prev_stage`` in ``train_cfg``. Otherwise, the
    optimizer status may be reset, which will bring inferior performance.
    For example, if your model is resumed from the `256` stage, you should
    set ``train_cfg=dict(prev_stage=256)``.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        data_preprocessor (dict | nn.Module): Data preprocessor, passed to
            ``BaseGAN``.
        nkimgs_per_scale (dict): Mapping from a training scale to the number
            of kimgs (thousands of images) trained at that scale. Keys may be
            ``str`` (e.g. ``'64'``), ``int`` or a tuple of ints; the first
            element of the keys must be strictly increasing in iteration
            order.
        noise_size (int, optional): Size of the input noise, passed to
            ``BaseGAN``. Defaults to None.
        interp_real (dict, optional): Keyword arguments of
            ``F.interpolate`` used to resize real images to the current
            scale. Defaults to ``dict(mode='bilinear', align_corners=True)``.
        transition_kimgs (int): Number of kimgs over which the transition
            weight between the previous and the new torgb layer grows
            linearly from 0 to 1 after each stage switch. Defaults to 600.
        prev_stage (int): Initial value of the previously seen stage. A
            stage switch is recorded whenever the current stage differs from
            it. NOTE(review): it is compared against the stage *index* in
            ``train_step`` while the example above uses a scale value;
            confirm the intended unit. Defaults to 0.
        ema_config (dict, optional): Config for the EMA generator, passed to
            ``BaseGAN``. Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 data_preprocessor,
                 nkimgs_per_scale: dict,
                 noise_size: Optional[int] = None,
                 interp_real: Optional[dict] = None,
                 transition_kimgs: int = 600,
                 prev_stage: int = 0,
                 ema_config: Optional[Dict] = None):
        # NOTE(review): the two positional ``1``s are presumably
        # ``generator_steps`` and ``discriminator_steps`` of ``BaseGAN``.
        super().__init__(generator, discriminator, data_preprocessor, 1, 1,
                         noise_size, ema_config)

        # register necessary training status
        self.register_buffer('shown_nkimg', torch.tensor(0.))
        self.register_buffer('_curr_transition_weight', torch.tensor(1.))

        if interp_real is None:
            interp_real = dict(mode='bilinear', align_corners=True)
        self.interp_real_to = partial(F.interpolate, **interp_real)

        self.scales, self.nkimgs = [], []
        for k, v in nkimgs_per_scale.items():
            # support for different data types
            if isinstance(k, str):
                k = (int(k), int(k))
            elif isinstance(k, int):
                k = (k, k)
            else:
                assert mmengine.is_tuple_of(k, int)

            # sanity check for the order of scales
            assert len(self.scales) == 0 or k[0] > self.scales[-1][0]
            self.scales.append(k)
            self.nkimgs.append(v)

        # cum_nkimgs[i] is the number of kimgs shown at the end of stage i
        self.cum_nkimgs = np.cumsum(self.nkimgs)
        self.curr_stage = 0
        # dirty workaround for avoiding optimizer bug in resuming
        self.prev_stage = prev_stage
        # actually nkimgs shown at the end of per training stage
        self._actual_nkimgs = []
        # In each scale, transit from previous torgb layer to newer torgb layer
        # with `transition_kimgs` imgs
        self.transition_kimgs = transition_kimgs

        # this buffer is used to resume model easily
        self.register_buffer(
            '_next_scale_int',
            torch.tensor(self.scales[0][0], dtype=torch.int32))
        # TODO: init it with the same value as `_next_scale_int`
        # a dirty workaround for testing
        self.register_buffer(
            '_curr_scale_int',
            torch.tensor(self.scales[-1][0], dtype=torch.int32))

    def forward(self,
                inputs: ForwardInputs,
                data_samples: Optional[list] = None,
                mode: Optional[str] = None) -> SampleList:
        """Sample images from noises by using the generator.

        Args:
            inputs (ForwardInputs): Either a noise ``Tensor`` or a dict which
                may contain ``noise``, ``curr_scale``, ``transition_weight``
                and ``img``. The batch size of a dict input is resolved by
                ``get_valid_num_batches`` and the sample model is chosen by
                ``self._get_valid_model``.
            data_samples (Optional[list]): Data samples collated by
                :attr:`data_preprocessor`. Defaults to None.
            mode (Optional[str]): `mode` is not used in
                :class:`ProgressiveGrowingGAN`. Defaults to None.

        Returns:
            SampleList: A list of ``DataSample`` contain generated results.
        """
        if isinstance(inputs, Tensor):
            # a bare tensor is used directly as the input noise
            noise = inputs
            curr_scale = transition_weight = None
        else:
            noise = inputs.get('noise', None)
            num_batches = get_valid_num_batches(inputs, data_samples)
            noise = self.noise_fn(noise, num_batches=num_batches)
            curr_scale = inputs.get('curr_scale', None)
            transition_weight = inputs.get('transition_weight', None)

        # use `self.curr_scale` if curr_scale is None
        if curr_scale is None:
            # in training, 'curr_scale' will be set as attribute
            if hasattr(self, 'curr_scale'):
                curr_scale = self.curr_scale[0]
            # in testing, adopt '_curr_scale_int' from buffer as testing scale
            else:
                curr_scale = self._curr_scale_int.item()

        # use `self._curr_transition_weight` if `transition_weight` is None
        if transition_weight is None:
            transition_weight = self._curr_transition_weight.item()

        sample_model = self._get_valid_model(inputs)
        batch_sample_list = []
        if sample_model in ['ema', 'orig']:
            if sample_model == 'ema':
                generator = self.generator_ema
            else:
                generator = self.generator
            outputs = generator(
                noise,
                curr_scale=curr_scale,
                transition_weight=transition_weight)
            outputs = self.data_preprocessor.destruct(outputs, data_samples)

            gen_sample = DataSample()
            if data_samples:
                gen_sample.update(data_samples)
            if isinstance(inputs, dict) and 'img' in inputs:
                gen_sample.gt_img = inputs['img']
            gen_sample.fake_img = outputs
            gen_sample.sample_model = sample_model
            gen_sample.noise = noise
            batch_sample_list = gen_sample.split(allow_nonseq_value=True)

        else:  # sample model is 'ema/orig'
            outputs_orig = self.generator(
                noise,
                curr_scale=curr_scale,
                transition_weight=transition_weight)
            outputs_ema = self.generator_ema(
                noise,
                curr_scale=curr_scale,
                transition_weight=transition_weight)
            outputs_orig = self.data_preprocessor.destruct(
                outputs_orig, data_samples)
            outputs_ema = self.data_preprocessor.destruct(
                outputs_ema, data_samples)

            gen_sample = DataSample()
            if data_samples:
                gen_sample.update(data_samples)
            if isinstance(inputs, dict) and 'img' in inputs:
                gen_sample.gt_img = inputs['img']
            gen_sample.ema = DataSample(fake_img=outputs_ema)
            gen_sample.orig = DataSample(fake_img=outputs_orig)
            gen_sample.noise = noise
            gen_sample.sample_model = 'ema/orig'
            batch_sample_list = gen_sample.split(allow_nonseq_value=True)

        return batch_sample_list

    def train_discriminator(self, inputs: Tensor,
                            data_samples: List[DataSample],
                            optimizer_wrapper: OptimWrapper
                            ) -> Dict[str, Tensor]:
        """Train discriminator.

        Args:
            inputs (Tensor): Real images already resized to the current
                training scale.
            data_samples (List[DataSample]): Data samples from dataloader.
                Only their length is used, as the number of fake samples.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        real_imgs = inputs
        num_batches = len(data_samples)

        noise_batch = self.noise_fn(num_batches=num_batches)
        # fake images are only inputs to D here, no generator gradient needed
        with torch.no_grad():
            fake_imgs = self.generator(
                noise_batch,
                curr_scale=self.curr_scale[0],
                transition_weight=self._curr_transition_weight,
                return_noise=False)

        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=self._curr_transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=self._curr_transition_weight)

        parsed_loss, log_vars = self.disc_loss(disc_pred_fake, disc_pred_real,
                                               fake_imgs, real_imgs)
        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

    def disc_loss(self, disc_pred_fake: Tensor, disc_pred_real: Tensor,
                  fake_data: Tensor, real_data: Tensor) -> Tuple[Tensor, dict]:
        r"""Get disc loss. PGGAN use WGAN-GP's loss and discriminator shift
        loss to train the discriminator.

        .. math::
            L_{D} = \mathbb{E}_{z\sim{p_{z}}}D\left(G\left(z\right)\right)
                - \mathbb{E}_{x\sim{p_{data}}}D\left(x\right)
                + L_{GP} + L_{shift} \\
            L_{GP} = 10\,\mathbb{E}\left(\Vert\nabla_{\tilde{x}}D(\tilde{x})
                \Vert_2-1\right)^2 \\
            \tilde{x} = \epsilon x + (1-\epsilon)G(z) \\
            L_{shift} = 0.001 \cdot 0.5 \cdot \mathbb{E}\left(
                D(G(z))^2 + D(x)^2\right)

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction of the fake
                images.
            disc_pred_real (Tensor): Discriminator's prediction of the real
                images.
            fake_data (Tensor): Generated images, used to calculate gradient
                penalty.
            real_data (Tensor): Real images, used to calculate gradient
                penalty.

        Returns:
            Tuple[Tensor, dict]: Loss value and a dict of log variables.
        """

        losses_dict = dict()
        losses_dict['loss_disc_fake'] = disc_pred_fake.mean()
        losses_dict['loss_disc_real'] = -disc_pred_real.mean()

        # gradient penalty
        batch_size = real_data.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)
        # interpolate between real_data and fake_data; make it a fresh leaf
        # tensor so the gradient below is taken w.r.t. the interpolates only
        # (same semantics as the deprecated `autograd.Variable` wrapper)
        interpolates = (alpha * real_data +
                        (1. - alpha) * fake_data).detach().requires_grad_(True)

        disc_interpolates = self.discriminator(
            interpolates,
            curr_scale=self.curr_scale[0],
            transition_weight=self._curr_transition_weight)
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]

        # norm_mode is 'HWC': per-sample norm over all non-batch dimensions
        gradients_penalty = ((
            gradients.reshape(batch_size, -1).norm(2, dim=1) - 1)**2).mean()
        losses_dict['loss_gp'] = 10 * gradients_penalty

        # drift term keeping discriminator outputs close to zero; reduced
        # to a scalar explicitly like every other loss term
        losses_dict['loss_disc_shift'] = (0.001 * 0.5 * (
            disc_pred_fake**2 + disc_pred_real**2)).mean()

        parsed_loss, log_vars = self.parse_losses(losses_dict)
        return parsed_loss, log_vars

    def train_generator(self, inputs: Tensor, data_samples: List[DataSample],
                        optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
        """Train generator.

        Args:
            inputs (Tensor): Inputs from current resolution training. Not
                read by this method.
            data_samples (List[DataSample]): Data samples from dataloader.
                Only their length is used, as the number of fake samples.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        num_batches = len(data_samples)
        noise_batch = self.noise_fn(num_batches=num_batches)
        fake_imgs = self.generator(
            noise_batch,
            num_batches=num_batches,
            curr_scale=self.curr_scale[0],
            transition_weight=self._curr_transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=self._curr_transition_weight)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake_g)

        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

    def gen_loss(self, disc_pred_fake: Tensor) -> Tuple[Tensor, dict]:
        r"""Generator loss for PGGAN. PGGAN use WGAN's loss to train the
        generator.

        .. math::
            L_{G} = -\mathbb{E}_{z\sim{p_{z}}}D\left(G\left(z\right)\right)

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction of the fake
                images.

        Returns:
            Tuple[Tensor, dict]: Loss value and a dict of log variables.
        """
        losses_dict = dict()
        losses_dict['loss_gen'] = -disc_pred_fake.mean()
        loss, log_vars = self.parse_losses(losses_dict)
        return loss, log_vars

    def _stage_from_shown_nkimg(self) -> int:
        """Return the stage index implied by the current ``shown_nkimg``."""
        num_finished = int((self.cum_nkimgs <= self.shown_nkimg.item()).sum())
        # stay in the last stage once its budget is exhausted
        return min(num_finished, len(self.scales) - 1)

    def _unwrapped_generator(self) -> nn.Module:
        """Return the generator, unwrapped if it is a model wrapper."""
        if is_model_wrapper(self.generator):
            return self.generator.module
        return self.generator

    def _get_transition_weight(self) -> float:
        """Return the torgb transition weight for ``self.curr_stage``.

        The weight is always 1 in the first stage. In later stages it grows
        linearly from 0 to 1 during the first ``self.transition_kimgs``
        kimgs after the recorded stage switch.
        """
        if self.curr_stage == 0:
            return 1.
        if self._actual_nkimgs:
            stage_start = self._actual_nkimgs[-1]
        else:
            # No stage switch has been recorded in this run (e.g. resumed
            # with ``prev_stage`` equal to the current stage index); use the
            # nominal start of the stage instead of raising IndexError.
            stage_start = float(self.cum_nkimgs[self.curr_stage - 1])
        weight = (self.shown_nkimg.item() - stage_start) / self.transition_kimgs
        # clip to [0, 1]
        return min(max(weight, 0.), 1.)

    def _resize_real_imgs(self, real_imgs: Tensor) -> Tensor:
        """Resize real images to ``self.curr_scale`` when needed.

        Raises:
            RuntimeError: If the real images are smaller than the current
                scale in either spatial dimension.
        """
        if real_imgs.shape[2:] == self.curr_scale:
            return real_imgs
        if (real_imgs.shape[2] >= self.curr_scale[0]
                and real_imgs.shape[3] >= self.curr_scale[1]):
            return self.interp_real_to(real_imgs, size=self.curr_scale)
        raise RuntimeError(
            f'The scale of real image {real_imgs.shape[2:]} is smaller '
            f'than current scale {self.curr_scale}.')

    def train_step(self, data: dict, optim_wrapper: OptimWrapperDict) -> dict:
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the
        newly updated discriminator. Before that, the current stage, scale
        and torgb transition weight are derived from ``self.shown_nkimg``.

        Args:
            data (dict): Input data from dataloader. It is processed by
                ``self.data_preprocessor`` and its ``data_samples`` must
                provide ``gt_img``.
            optim_wrapper (OptimWrapperDict): Optimizer wrappers keyed by
                ``'discriminator_{scale}'`` and ``'generator_{scale}'`` for
                every training scale.

        Returns:
            dict: Log variables of the discriminator (and generator, when it
            is updated in this iteration), plus ``shown_nkimg``,
            ``curr_scale`` and ``transition_weight``.
        """
        message_hub = MessageHub.get_current_instance()
        curr_iter = message_hub.get_info('iter')

        # update current stage
        self.curr_stage = self._stage_from_shown_nkimg()
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            self._actual_nkimgs.append(self.shown_nkimg.item())

        data = self.data_preprocessor(data, True)
        data_sample = data['data_samples']
        real_imgs = data_sample.gt_img

        curr_scale = str(self.curr_scale[0])
        disc_optimizer_wrapper: OptimWrapper = optim_wrapper[
            f'discriminator_{curr_scale}']
        gen_optimizer_wrapper: OptimWrapper = optim_wrapper[
            f'generator_{curr_scale}']
        disc_accu_iters = disc_optimizer_wrapper._accumulative_counts

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        transition_weight = self._get_transition_weight()
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        real_imgs = self._resize_real_imgs(real_imgs)

        # normal gan training process
        with disc_optimizer_wrapper.optim_context(self.discriminator):
            log_vars = self.train_discriminator(real_imgs, data_sample,
                                                disc_optimizer_wrapper)

        # add 1 to `curr_iter` because iter is updated in train loop.
        # Whether to update the generator. We update generator with
        # discriminator is fully updated for `self.n_discriminator_steps`
        # iterations. And one full updating for discriminator contains
        # `disc_accu_counts` times of grad accumulations.
        if (curr_iter + 1) % (self.discriminator_steps * disc_accu_iters) == 0:

            set_requires_grad(self.discriminator, False)
            gen_accu_iters = gen_optimizer_wrapper._accumulative_counts

            log_vars_gen_list = []
            # init optimizer wrapper status for generator manually
            gen_optimizer_wrapper.initialize_count_status(
                self.generator, 0, self.generator_steps * gen_accu_iters)
            for _ in range(self.generator_steps * gen_accu_iters):
                with gen_optimizer_wrapper.optim_context(self.generator):
                    log_vars_gen = self.train_generator(
                        real_imgs, data_sample, gen_optimizer_wrapper)

                log_vars_gen_list.append(log_vars_gen)
            log_vars_gen = self.gather_log_vars(log_vars_gen_list)
            log_vars_gen.pop('loss', None)  # remove 'loss' from gen logs

            set_requires_grad(self.discriminator, True)

            # only do ema after generator update
            unwrapped_gen = self._unwrapped_generator()
            if self.with_ema_gen and (curr_iter + 1) >= (
                    self.ema_start * self.discriminator_steps *
                    disc_accu_iters):
                self.generator_ema.update_parameters(unwrapped_gen)
                # if not update buffer, copy buffer from orig model
                if not self.generator_ema.update_buffers:
                    self.generator_ema.sync_buffers(unwrapped_gen)
            elif self.with_ema_gen:
                # before ema, copy weights from orig
                self.generator_ema.sync_parameters(unwrapped_gen)

            log_vars.update(log_vars_gen)

        # add batch size info to log_vars
        _batch_size = real_imgs.shape[0] * get_world_size()
        self.shown_nkimg += (_batch_size / 1000.)
        log_vars.update(
            dict(
                shown_nkimg=self.shown_nkimg.item(),
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        # check if a new scale will be added in the next iteration; if so,
        # `self._next_scale_int` is updated at the end of `train_step`
        if self._stage_from_shown_nkimg() != self.curr_stage:
            self._next_scale_int = self._next_scale_int * 2
        return log_vars
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.