Shortcuts

Source code for mmagic.models.editors.singan.singan

# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from mmengine import Config, MessageHub, print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch import Tensor

from mmagic.models.utils import get_module_device
from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils import ForwardInputs, SampleList
from ...base_models import BaseGAN
from ...utils import set_requires_grad

# A model argument may be given either as a config dict or as an
# already-built ``nn.Module``.
ModelType = Union[Dict, nn.Module]
# A training input is either a dict or a single tensor.
TrainInput = Union[dict, Tensor]
@MODELS.register_module()
class SinGAN(BaseGAN):
    """SinGAN.

    This model implements the single image generative adversarial model
    proposed in: SinGAN: Learning a Generative Model from a Single Natural
    Image, ICCV'19.

    Notes for training:

    - This model should be trained with our dataset ``SinGANDataset``.
    - In training, the ``total_iters`` argument is related to the number of
      scales in the image pyramid and ``iters_per_scale`` in the
      ``train_cfg``. You should set it carefully in the training config file.

    Notes for model architectures:

    - The generator and discriminator need ``num_scales`` in initialization.
      However, this argument is generated by ``create_real_pyramid`` function
      from the ``singan_dataset.py``. The last element in the returned list
      (``stop_scale``) is the value for ``num_scales``. Pay attention that
      this scale is counted from zero. Please see our tutorial for SinGAN to
      obtain more details or our standard config for reference.

    Args:
        generator (ModelType): The config or model of the generator.
        discriminator (Optional[ModelType]): The config or model of the
            discriminator. Defaults to None.
        data_preprocessor (Optional[Union[dict, Config]]): The pre-process
            config or :class:`~mmagic.models.DataPreprocessor`.
        generator_steps (int): The number of times the generator is completely
            updated before the discriminator is updated. Defaults to 1.
        discriminator_steps (int): The number of times the discriminator is
            completely updated before the generator is updated. Defaults to 1.
        num_scales (int): The number of scales/stages in generator/
            discriminator. Note that this number is counted from zero, which
            is the same as the original paper. Defaults to None.
        iters_per_scale (int): The training iteration for each resolution
            scale. Defaults to 2000.
        noise_weight_init (float): The initial weight of fixed noise.
            Defaults to 0.1.
        lr_scheduler_args (Optional[dict]): Arguments for learning schedulers.
            Note that in SinGAN, we use MultiStepLR, which is the same as the
            original paper. If not passed, no learning schedule will be used.
            Defaults to None.
        test_pkl_data (Optional[str]): The path of pickle file which contains
            fixed noise and noise weight. This is a must for test.
            Defaults to None.
        ema_confg (Optional[dict]): The config for generator's exponential
            moving average setting. The (misspelled) parameter name is kept
            as-is for backward compatibility. Defaults to None.
    """

    def __init__(self,
                 generator: ModelType,
                 discriminator: Optional[ModelType] = None,
                 data_preprocessor: Optional[Union[dict, Config]] = None,
                 generator_steps: int = 1,
                 discriminator_steps: int = 1,
                 num_scales: Optional[int] = None,
                 iters_per_scale: int = 2000,
                 noise_weight_init: float = 0.1,
                 lr_scheduler_args: Optional[dict] = None,
                 test_pkl_data: Optional[str] = None,
                 ema_confg: Optional[dict] = None):
        super().__init__(generator, discriminator, data_preprocessor,
                         generator_steps, discriminator_steps, None,
                         ema_confg)

        self.iters_per_scale = iters_per_scale
        self.noise_weight_init = noise_weight_init

        # Keep a private copy so later use cannot mutate the caller's dict.
        if lr_scheduler_args:
            self.lr_scheduler_args = deepcopy(lr_scheduler_args)
        else:
            self.lr_scheduler_args = None

        # related to train
        self.num_scales = num_scales
        # -1 means "no stage started yet"; incremented in ``train_step``.
        self.curr_stage = -1
        self.noise_weights = [1]
        self.fixed_noises = []
        self.reals = []

        # related to test
        self.loaded_test_pkl = False
        self.pkl_data = test_pkl_data

    def load_test_pkl(self):
        """Load pickle for test.

        When ``self.pkl_data`` is set, read ``fixed_noises``,
        ``noise_weights`` and ``curr_stage`` from that pickle file and mark
        the pickle as loaded. Does nothing when ``self.pkl_data`` is None.
        """
        if self.pkl_data is None:
            return

        # NOTE(security): ``pickle.load`` can execute arbitrary code. Only
        # point ``test_pkl_data`` at files from a trusted source.
        with open(self.pkl_data, 'rb') as f:
            data = pickle.load(f)
        self.fixed_noises = self._from_numpy(data['fixed_noises'])
        self.noise_weights = self._from_numpy(data['noise_weights'])
        self.curr_stage = data['curr_stage']
        print_log(f'Load pkl data from {self.pkl_data}', 'current')
        self.loaded_test_pkl = True

    def _from_numpy(
        self, data: Union[list, np.ndarray]
    ) -> Union[Tensor, List[Tensor]]:
        """Convert input numpy array or list of numpy array to Tensor or list
        of Tensor.

        Lists are converted element-wise (recursively). Arrays are moved to
        the generator's device. Any other value is returned unchanged.

        Args:
            data (Union[list, np.ndarray]): Input data to convert.

        Returns:
            Union[Tensor, List[Tensor]]: Converted Tensor or list of tensor.
        """
        if isinstance(data, list):
            return [self._from_numpy(x) for x in data]

        if isinstance(data, np.ndarray):
            device = get_module_device(self.generator)
            return torch.from_numpy(data).to(device)

        return data

    def get_module(self, model: nn.Module, module_name: str) -> nn.Module:
        """Get an inner module from model.

        Since some models may be wrapped (e.g. with DDP), the attribute is
        looked up on ``model.module`` when that attribute exists, otherwise on
        ``model`` itself.

        Args:
            model (nn.Module): This model may be wrapped with DDP or not.
            module_name (str): The name of specific module.

        Returns:
            nn.Module: Returned sub module.
        """
        module = model.module if hasattr(model, 'module') else model
        return getattr(module, module_name)

    def construct_fixed_noises(self):
        """Construct the fixed noises list used in SinGAN.

        One noise tensor is appended to ``self.fixed_noises`` per entry of
        ``self.reals``: random Gaussian noise of shape ``(1, 1, h, w)`` for
        the first scale and zeros shaped like the real image for the others.
        """
        for i, real in enumerate(self.reals):
            if i == 0:
                h, w = real.shape[-2:]
                # ``.to(real)`` matches the dtype/device of the real image.
                noise = torch.randn(1, 1, h, w).to(real)
            else:
                noise = torch.zeros_like(real)
            self.fixed_noises.append(noise)

    def _sample_from(self, generator: nn.Module, mode: str, curr_scale: int,
                     gen_kwargs: dict):
        """Run ``generator`` once with the stored noises and weights."""
        return generator(
            None,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights,
            rand_mode=mode,
            num_batches=1,
            curr_scale=curr_scale,
            **gen_kwargs)

    def _destruct_outputs(self, outputs, data_samples: Optional[list]):
        """Apply ``data_preprocessor.destruct`` to generator outputs.

        Dict outputs have ``'fake_img'`` and every entry of
        ``'prev_res_list'`` destructed in place; other outputs are destructed
        directly.
        """
        destruct = self.data_preprocessor.destruct
        if isinstance(outputs, dict):
            outputs['fake_img'] = destruct(outputs['fake_img'], data_samples)
            outputs['prev_res_list'] = [
                destruct(r, data_samples) for r in outputs['prev_res_list']
            ]
            return outputs
        return destruct(outputs, data_samples)

    @staticmethod
    def _store_outputs(sample: DataSample, outputs, idx: int) -> None:
        """Save the ``idx``-th generated result of ``outputs`` to ``sample``."""
        if isinstance(outputs, dict):
            sample.fake_img = outputs['fake_img'][idx]
            sample.prev_res_list = [
                r[idx] for r in outputs['prev_res_list']
            ]
        else:
            sample.fake_img = outputs[idx]

    def forward(self,
                inputs: ForwardInputs,
                data_samples: Optional[list] = None,
                mode: Optional[str] = None) -> List[DataSample]:
        """Forward function for SinGAN.

        For SinGAN, `inputs` should be a dict contains 'num_batches', 'mode'
        and other input arguments for the generator.

        Args:
            inputs (dict): Dict containing the necessary information
                (e.g., noise, num_batches, mode) to generate image.
            data_samples (Optional[list]): Data samples collated by
                :attr:`data_preprocessor`. Defaults to None.
            mode (Optional[str]): Fallback value for the generator's
                ``rand_mode`` when ``inputs`` has no ``'mode'`` key. If both
                are missing, ``'rand'`` is used. Defaults to None.

        Returns:
            List[DataSample]: Generated results. For the ``'ema'`` / ``'orig'``
            sample models each sample carries ``fake_img`` (and
            ``prev_res_list`` for dict outputs); otherwise the results are
            stored under the ``ema`` and ``orig`` sub-samples.
        """
        # handle batch_inputs
        assert isinstance(inputs, dict), (
            'SinGAN only support dict type inputs in forward function.')
        gen_kwargs = deepcopy(inputs)
        num_batches = gen_kwargs.pop('num_batches', 1)
        assert num_batches == 1, (
            'SinGAN only support \'num_batches\' as 1, but receive '
            f'{num_batches}.')

        sample_model = self._get_valid_model(inputs)
        gen_kwargs.pop('sample_model', None)  # remove sample_model

        mode = gen_kwargs.pop('mode', mode)
        mode = 'rand' if mode is None else mode
        curr_scale = gen_kwargs.pop('curr_scale', self.curr_stage)

        self.fixed_noises = [
            x.to(self.data_preprocessor.device) for x in self.fixed_noises
        ]

        batch_sample_list = []
        if sample_model in ['ema', 'orig']:
            if sample_model == 'ema':
                generator = self.generator_ema
            else:
                generator = self.generator

            outputs = self._sample_from(generator, mode, curr_scale,
                                        gen_kwargs)
            outputs = self._destruct_outputs(outputs, data_samples)

            # save to data sample
            for idx in range(num_batches):
                gen_sample = DataSample()
                # save inputs to data sample
                if data_samples:
                    gen_sample.update(data_samples[idx])
                self._store_outputs(gen_sample, outputs, idx)
                gen_sample.sample_model = sample_model
                batch_sample_list.append(gen_sample)

        else:  # sample model is 'ema/orig'
            # Keep the call order (orig first, then ema) so that random
            # number consumption stays the same.
            outputs_orig = self._sample_from(self.generator, mode, curr_scale,
                                             gen_kwargs)
            outputs_ema = self._sample_from(self.generator_ema, mode,
                                            curr_scale, gen_kwargs)

            outputs_orig = self._destruct_outputs(outputs_orig, data_samples)
            outputs_ema = self._destruct_outputs(outputs_ema, data_samples)

            # save to data sample
            for idx in range(num_batches):
                gen_sample = DataSample()
                gen_sample.ema = DataSample()
                gen_sample.orig = DataSample()
                # save inputs to data sample
                if data_samples:
                    gen_sample.update(data_samples[idx])
                self._store_outputs(gen_sample.ema, outputs_ema, idx)
                self._store_outputs(gen_sample.orig, outputs_orig, idx)
                gen_sample.sample_model = sample_model
                batch_sample_list.append(gen_sample)

        return batch_sample_list

    def gen_loss(self, disc_pred_fake: Tensor,
                 recon_imgs: Tensor) -> Tuple[Tensor, dict]:
        r"""Generator loss for SinGAN.

        SinGAN uses WGAN's loss and a (10x weighted) MSE reconstruction loss
        to train the generator.

        .. math::
            L_{G} = -\mathbb{E}_{z\sim{p_{z}}}D\left(G\left(z\right)\right)
                + 10 \cdot L_{MSE} \\
            L_{MSE} = \text{mean} \Vert x - G(z) \Vert_2^2

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction of the fake
                images.
            recon_imgs (Tensor): Reconstructive images.

        Returns:
            Tuple[Tensor, dict]: Loss value and a dict of log variables.
        """
        losses_dict = dict()
        losses_dict['loss_gen'] = -disc_pred_fake.mean()
        # Reconstruction target is the real image of the current stage.
        losses_dict['loss_mse'] = 10 * F.mse_loss(recon_imgs,
                                                  self.reals[self.curr_stage])

        loss, log_vars = self.parse_losses(losses_dict)
        return loss, log_vars

    def disc_loss(self, disc_pred_fake: Tensor, disc_pred_real: Tensor,
                  fake_data: Tensor, real_data: Tensor) -> Tuple[Tensor, dict]:
        r"""Get disc loss.

        SinGAN uses the WGAN loss with a gradient penalty (weighted by 0.1)
        to train the discriminator.

        .. math::
            L_{D} = \mathbb{E}_{z\sim{p_{z}}}D\left(G\left(z\right)\right)
                - \mathbb{E}_{x\sim{p_{data}}}D\left(x\right) + L_{GP} \\
            L_{GP} = \lambda\mathbb{E}(\Vert\nabla_{\tilde{x}}D(\tilde{x})
                \Vert_2-1)^2 \\
            \tilde{x} = \epsilon x + (1-\epsilon)G(z)

        Args:
            disc_pred_fake (Tensor): Discriminator's prediction of the fake
                images.
            disc_pred_real (Tensor): Discriminator's prediction of the real
                images.
            fake_data (Tensor): Generated images, used to calculate gradient
                penalty.
            real_data (Tensor): Real images, used to calculate gradient
                penalty.

        Returns:
            Tuple[Tensor, dict]: Loss value and a dict of log variables.
        """
        losses_dict = dict()
        losses_dict['loss_disc_fake'] = disc_pred_fake.mean()
        losses_dict['loss_disc_real'] = -disc_pred_real.mean()

        # gradient penalty
        batch_size = real_data.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

        # interpolate between real_data and fake_data
        interpolates = alpha * real_data + (1. - alpha) * fake_data
        # Cut the graph to the generator and make the interpolation a leaf
        # that requires grad (replaces the deprecated ``autograd.Variable``).
        interpolates = interpolates.detach().requires_grad_(True)

        disc_interpolates = self.discriminator(
            interpolates, curr_scale=self.curr_stage)
        # ``create_graph=True`` so the penalty itself can be back-propagated.
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,
            retain_graph=True)[0]

        # norm_mode is 'pixel'
        gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        losses_dict['loss_gp'] = 0.1 * gradients_penalty

        parsed_loss, log_vars = self.parse_losses(losses_dict)
        return parsed_loss, log_vars

    def train_generator(self, inputs: dict, data_samples: List[DataSample],
                        optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
        """Train generator.

        Args:
            inputs (dict): Inputs from dataloader.
            data_samples (List[DataSample]): Data samples from dataloader.
                Not used in generator's training.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        input_sample = inputs['input_sample']

        fake_imgs = self.generator(
            input_sample,
            self.fixed_noises,
            self.noise_weights,
            rand_mode='rand',
            curr_scale=self.curr_stage)
        disc_pred_fake_g = self.discriminator(
            fake_imgs, curr_scale=self.curr_stage)

        recon_imgs = self.generator(
            input_sample,
            self.fixed_noises,
            self.noise_weights,
            rand_mode='recon',
            curr_scale=self.curr_stage)

        parsed_loss, log_vars = self.gen_loss(disc_pred_fake_g, recon_imgs)
        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

    def train_discriminator(
            self, inputs: dict, data_samples: List[DataSample],
            optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
        """Train discriminator.

        Args:
            inputs (dict): Inputs from dataloader.
            data_samples (List[DataSample]): Data samples from dataloader.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        input_sample = inputs['input_sample']
        fake_imgs = self.generator(
            input_sample,
            self.fixed_noises,
            self.noise_weights,
            rand_mode='rand',
            curr_scale=self.curr_stage)

        # disc pred for fake imgs and real_imgs
        real_imgs = self.reals[self.curr_stage]
        disc_pred_fake = self.discriminator(
            fake_imgs.detach(), curr_scale=self.curr_stage)
        disc_pred_real = self.discriminator(
            real_imgs, curr_scale=self.curr_stage)

        parsed_loss, log_vars = self.disc_loss(disc_pred_fake, disc_pred_real,
                                               fake_imgs, real_imgs)
        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

    def train_gan(self, inputs_dict: dict, data_sample: List[DataSample],
                  optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
        """Train GAN model. In the training of GAN models, generator and
        discriminator are updated alternatively. In MMagic's design,
        `self.train_step` is called with data input. Therefore we always update
        discriminator, whose updating relies on real data, and then determine
        if the generator needs to be updated based on the current number of
        iterations. More details about whether to update generator can be found
        in :meth:`should_gen_update`.

        Args:
            inputs_dict (dict): Data sampled from dataloader.
            data_sample (List[DataSample]): List of data sample contains GT
                and meta information.
            optim_wrapper (OptimWrapperDict): OptimWrapperDict instance
                contains OptimWrapper of generator and discriminator.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
        """
        message_hub = MessageHub.get_current_instance()
        curr_iter = message_hub.get_info('iter')

        disc_optimizer_wrapper: OptimWrapper = optim_wrapper['discriminator']
        disc_accu_iters = disc_optimizer_wrapper._accumulative_counts

        with disc_optimizer_wrapper.optim_context(self.discriminator):
            log_vars = self.train_discriminator(inputs_dict, data_sample,
                                                disc_optimizer_wrapper)

        # add 1 to `curr_iter` because iter is updated in train loop.
        # Whether to update the generator. We update generator with
        # discriminator is fully updated for `self.discriminator_steps`
        # iterations. And one full updating for discriminator contains
        # `disc_accu_iters` times of grad accumulations.
        if (curr_iter + 1) % (self.discriminator_steps * disc_accu_iters) != 0:
            return log_vars

        set_requires_grad(self.discriminator, False)
        gen_optimizer_wrapper = optim_wrapper['generator']
        gen_accu_iters = gen_optimizer_wrapper._accumulative_counts

        log_vars_gen_list = []
        # init optimizer wrapper status for generator manually
        gen_optimizer_wrapper.initialize_count_status(
            self.generator, 0, self.generator_steps * gen_accu_iters)
        for _ in range(self.generator_steps * gen_accu_iters):
            with gen_optimizer_wrapper.optim_context(self.generator):
                log_vars_gen = self.train_generator(inputs_dict, data_sample,
                                                    gen_optimizer_wrapper)

            log_vars_gen_list.append(log_vars_gen)
        log_vars_gen = self.gather_log_vars(log_vars_gen_list)
        log_vars_gen.pop('loss', None)  # remove 'loss' from gen logs

        set_requires_grad(self.discriminator, True)

        if self.with_ema_gen:
            # EMA always tracks the bare generator, not its wrapper.
            if is_model_wrapper(self.generator):
                src_generator = self.generator.module
            else:
                src_generator = self.generator

            # only do ema after generator update
            if (curr_iter + 1) >= (
                    self.ema_start * self.discriminator_steps *
                    disc_accu_iters):
                self.generator_ema.update_parameters(src_generator)
                # if not update buffer, copy buffer from orig model
                if not self.generator_ema.update_buffers:
                    self.generator_ema.sync_buffers(src_generator)
            else:
                # before ema, copy weights from orig
                self.generator_ema.sync_parameters(src_generator)

        log_vars.update(log_vars_gen)
        return log_vars

    def train_step(self, data: dict,
                   optim_wrapper: OptimWrapperDict) -> Dict[str, Tensor]:
        """Train step for SinGAN model. SinGAN is trained with multi-resolution
        images, and each resolution is trained for
        :attr:`iters_per_scale` times.

        We initialize the weight and learning rate scheduler of the
        corresponding module at the start of each resolution's training. At
        the end of each resolution's training, we update the weight of the
        noise of current resolution by mse loss between reconstructed image
        and real image.

        Args:
            data (dict): Data sampled from dataloader.
            optim_wrapper (OptimWrapperDict): OptimWrapperDict instance
                contains OptimWrapper of generator and discriminator.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
        """
        message_hub = MessageHub.get_current_instance()
        curr_iter = message_hub.get_info('iter')
        iters_per_stage = self.iters_per_scale * self.discriminator_steps

        if curr_iter % iters_per_stage == 0:
            self.curr_stage += 1
            # load weights from prev scale
            self.get_module(self.generator, 'check_and_load_prev_weight')(
                self.curr_stage)
            self.get_module(self.discriminator, 'check_and_load_prev_weight')(
                self.curr_stage)

            # assert grad_accumulation step is 1
            curr_gen_optim = optim_wrapper[f'generator_{self.curr_stage}']
            curr_disc_optim = optim_wrapper[f'discriminator_{self.curr_stage}']
            _warning_msg = ('SinGAN do set batch size as 1 during training '
                            'and do not support gradient accumulation.')
            assert curr_gen_optim._accumulative_counts == 1, _warning_msg
            assert curr_disc_optim._accumulative_counts == 1, _warning_msg

            # build parameters scheduler manually, because parameters_schedule
            # hook update all scheduler at the same time
            if self.lr_scheduler_args:
                self.gen_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    curr_gen_optim.optimizer, **self.lr_scheduler_args)
                self.disc_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    curr_disc_optim.optimizer, **self.lr_scheduler_args)

        # get current optimizer_wrapper
        curr_optimizer_wrapper = OptimWrapperDict(
            generator=optim_wrapper[f'generator_{self.curr_stage}'],
            discriminator=optim_wrapper[f'discriminator_{self.curr_stage}'])

        # handle inputs
        data = self.data_preprocessor(data)
        inputs_dict, data_samples = data['inputs'], data['data_samples']

        # setup fixed noises and reals pyramid
        if curr_iter == 0 or len(self.reals) == 0:
            scales = sum(1 for k in inputs_dict if 'real_scale' in k)
            self.reals = [inputs_dict[f'real_scale{s}'] for s in range(scales)]

            # here we do not padding fixed noises
            self.construct_fixed_noises()

        # standard train step
        log_vars = self.train_gan(inputs_dict, data_samples,
                                  curr_optimizer_wrapper)
        log_vars['curr_stage'] = self.curr_stage

        # update noise weight: at the last iteration of a stage (except the
        # final stage), derive the next stage's weight from the RMSE between
        # the upsampled reconstruction and the next-scale real image.
        if ((curr_iter + 1) % iters_per_stage == 0) and (
                self.curr_stage < len(self.reals) - 1):
            next_real = self.reals[self.curr_stage + 1]
            with torch.no_grad():
                g_recon = self.generator(
                    inputs_dict['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='recon',
                    curr_scale=self.curr_stage)
                if isinstance(g_recon, dict):
                    g_recon = g_recon['fake_img']
                g_recon = F.interpolate(g_recon, next_real.shape[-2:])

            mse = F.mse_loss(g_recon.detach(), next_real)
            rmse = torch.sqrt(mse)
            self.noise_weights.append(self.noise_weight_init * rmse.item())

        # call scheduler when all submodules are fully updated.
        if (curr_iter + 1) % self.discriminator_steps == 0:
            if self.lr_scheduler_args:
                self.disc_scheduler.step()
                self.gen_scheduler.step()

        return log_vars

    def test_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data in test progress. Before
        generate images, we call :meth:`load_test_pkl` to load the fixed noise
        and current stage of the model from the pickle file.

        Args:
            data (dict): Data sampled from metric specific sampler. More
                details in `Metrics` and `Evaluator`.

        Returns:
            SampleList: A list of ``DataSample`` contain generated results.
        """
        # Loading is retried on every call until it has succeeded once.
        if not self.loaded_test_pkl:
            self.load_test_pkl()
        return super().test_step(data)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.