Source code for mmagic.models.editors.singan.singan
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from mmengine import Config, MessageHub, print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch import Tensor
from mmagic.models.utils import get_module_device
from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils import ForwardInputs, SampleList
from ...base_models import BaseGAN
from ...utils import set_requires_grad
@MODELS.register_module()
[docs]class SinGAN(BaseGAN):
    """SinGAN.

    This model implements the single image generative adversarial model
    proposed in: Singan: Learning a Generative Model from a Single Natural
    Image, ICCV'19.

    Notes for training:

    - This model should be trained with our dataset ``SinGANDataset``.
    - In training, the ``total_iters`` argument is related to the number of
      scales in the image pyramid and ``iters_per_scale`` in the
      ``train_cfg``. You should set it carefully in the training config file.

    Notes for model architectures:

    - The generator and discriminator need ``num_scales`` in initialization.
      However, this argument is generated by the ``create_real_pyramid``
      function from ``singan_dataset.py``. The last element in the returned
      list (``stop_scale``) is the value for ``num_scales``. Pay attention
      that this scale is counted from zero. Please see our tutorial for
      SinGAN to obtain more details or our standard config for reference.

    Args:
        generator (ModelType): The config or model of the generator.
        discriminator (Optional[ModelType]): The config or model of the
            discriminator. Defaults to None.
        data_preprocessor (Optional[Union[dict, Config]]): The pre-process
            config or :class:`~mmagic.models.DataPreprocessor`.
        generator_steps (int): The number of times the generator is completely
            updated before the discriminator is updated. Defaults to 1.
        discriminator_steps (int): The number of times the discriminator is
            completely updated before the generator is updated. Defaults to 1.
        num_scales (int): The number of scales/stages in generator/
            discriminator. Note that this number is counted from zero, which
            is the same as the original paper. Defaults to None.
        iters_per_scale (int): The number of training iterations for each
            resolution scale. Defaults to 2000.
        noise_weight_init (float): The initial weight of the fixed noise.
            Defaults to 0.1.
        lr_scheduler_args (Optional[dict]): Arguments for the learning rate
            schedulers. Note that in SinGAN, we use MultiStepLR, which is the
            same as the original paper. If not passed, no learning rate
            schedule will be used. Defaults to None.
        test_pkl_data (Optional[str]): The path of the pickle file which
            contains the fixed noises and noise weights. This is required for
            testing. Defaults to None.
        ema_confg (Optional[dict]): The config for the generator's
            exponential moving average setting. Note that the keyword is
            spelled ``ema_confg`` in the constructor signature. Defaults to
            None.
    """
def __init__(self,
             generator: Union[Dict, nn.Module],
             discriminator: Optional[Union[Dict, nn.Module]] = None,
             data_preprocessor: Optional[Union[dict, Config]] = None,
             generator_steps: int = 1,
             discriminator_steps: int = 1,
             num_scales: Optional[int] = None,
             iters_per_scale: int = 2000,
             noise_weight_init: float = 0.1,
             lr_scheduler_args: Optional[dict] = None,
             test_pkl_data: Optional[str] = None,
             ema_confg: Optional[dict] = None):
    """Build the SinGAN model and initialise its training/testing state.

    Args:
        generator (Union[Dict, nn.Module]): The config or model of the
            generator.
        discriminator (Optional[Union[Dict, nn.Module]]): The config or
            model of the discriminator. Defaults to None.
        data_preprocessor (Optional[Union[dict, Config]]): The
            pre-process config. Defaults to None.
        generator_steps (int): Forwarded to the base class.
            Defaults to 1.
        discriminator_steps (int): Forwarded to the base class.
            Defaults to 1.
        num_scales (Optional[int]): The number of scales, stored on the
            instance. Defaults to None.
        iters_per_scale (int): The number of training iterations for
            each scale. Defaults to 2000.
        noise_weight_init (float): Factor applied to the reconstruction
            RMSE when a new noise weight is appended in ``train_step``.
            Defaults to 0.1.
        lr_scheduler_args (Optional[dict]): Keyword arguments for the
            ``MultiStepLR`` schedulers built in ``train_step``. A falsy
            value (``None`` or an empty dict) disables scheduling.
            Defaults to None.
        test_pkl_data (Optional[str]): Path of the pickle file read by
            ``load_test_pkl``. Defaults to None.
        ema_confg (Optional[dict]): EMA config forwarded to the base
            class. The (misspelled) keyword name is kept for backward
            compatibility. Defaults to None.
    """
    # The sixth positional argument of ``BaseGAN.__init__`` is passed as
    # ``None`` (presumably the noise size, which SinGAN does not use --
    # TODO confirm against ``BaseGAN``).
    super().__init__(generator, discriminator, data_preprocessor,
                     generator_steps, discriminator_steps, None, ema_confg)
    self.iters_per_scale = iters_per_scale
    self.noise_weight_init = noise_weight_init

    # Keep a private copy so later changes to the caller's dict do not
    # leak into the model; falsy values are normalised to ``None``.
    self.lr_scheduler_args = (
        deepcopy(lr_scheduler_args) if lr_scheduler_args else None)

    # related to train
    self.num_scales = num_scales
    # ``train_step`` increments this before first use, so -1 means that no
    # stage has been started yet.
    self.curr_stage = -1
    # One more weight is appended by ``train_step`` at the end of every
    # stage except the last one.
    self.noise_weights = [1]
    self.fixed_noises = []
    self.reals = []

    # related to test
    self.loaded_test_pkl = False
    self.pkl_data = test_pkl_data
def load_test_pkl(self):
    """Load the pickled test state (fixed noises, noise weights, stage).

    Nothing happens when no pickle path is stored in ``self.pkl_data``.
    Otherwise the file is unpickled, ``fixed_noises`` and
    ``noise_weights`` are converted with ``self._from_numpy``,
    ``curr_stage`` is restored and ``self.loaded_test_pkl`` is set to
    ``True``.
    """
    if self.pkl_data is None:
        return

    # NOTE(security): ``pickle.load`` can execute arbitrary code while
    # loading. Only point ``self.pkl_data`` at files from a trusted source.
    with open(self.pkl_data, 'rb') as f:
        data = pickle.load(f)

    # The file handle is no longer needed once the payload is in memory.
    self.fixed_noises = self._from_numpy(data['fixed_noises'])
    self.noise_weights = self._from_numpy(data['noise_weights'])
    self.curr_stage = data['curr_stage']

    print_log(f'Load pkl data from {self.pkl_data}', 'current')
    self.loaded_test_pkl = True
def _from_numpy(self,
                data: Tuple[list,
                            np.ndarray]) -> Tuple[Tensor, List[Tensor]]:
    """Recursively turn numpy arrays into tensors on the generator device.

    Args:
        data (Tuple[list, np.ndarray]): A numpy array, or a (possibly
            nested) list whose elements are converted one by one.

    Returns:
        Tuple[Tensor, List[Tensor]]: The converted tensor, or a list with
        the converted elements. Anything that is neither a list nor a
        numpy array is returned unchanged.
    """
    if isinstance(data, np.ndarray):
        tensor = torch.from_numpy(data)
        # place the tensor on the device reported for ``self.generator``
        return tensor.to(get_module_device(self.generator))
    if isinstance(data, list):
        return list(map(self._from_numpy, data))
    return data
def get_module(self, model: nn.Module, module_name: str) -> nn.Module:
    """Fetch an attribute from a model, looking through a wrapper if any.

    A model that exposes a ``module`` attribute (e.g. when it is wrapped by
    DDP) is unwrapped first, then ``module_name`` is looked up on the
    result.

    Args:
        model (nn.Module): The model, wrapped or not.
        module_name (str): The name of the attribute to fetch.

    Return:
        nn.Module: The attribute named ``module_name``.
    """
    unwrapped = getattr(model, 'module', model)
    return getattr(unwrapped, module_name)
def construct_fixed_noises(self):
    """Append one fixed noise per entry of ``self.reals``.

    The first entry gets a single-channel Gaussian noise map with the
    spatial size of that real image (converted with ``.to(real)``); every
    later entry gets an all-zero tensor shaped like its real image.
    """
    for level, real in enumerate(self.reals):
        height, width = real.shape[-2:]
        noise = (
            torch.zeros_like(real)
            if level else torch.randn(1, 1, height, width).to(real))
        self.fixed_noises.append(noise)
def forward(self,
            inputs: ForwardInputs,
            data_samples: Optional[list] = None,
            mode=None) -> List[DataSample]:
    """Sample images from the generator(s).

    ``inputs`` must be a dict. The keys ``'num_batches'`` (must be 1),
    ``'sample_model'``, ``'mode'`` and ``'curr_scale'`` are consumed here;
    every remaining entry is forwarded to the generator as a keyword
    argument.

    Args:
        inputs (dict): Dict containing the necessary information
            (e.g., num_batches, mode) to generate image.
        data_samples (Optional[list]): Data samples collated by
            :attr:`data_preprocessor`. Defaults to None.
        mode (Optional[str]): Fallback for the generator's ``rand_mode``
            when ``inputs`` has no ``'mode'`` entry. If both are missing,
            ``'rand'`` is used. Defaults to None.

    Returns:
        List[DataSample]: One sample per batch item. When a single model
        is sampled, ``fake_img`` (and ``prev_res_list`` for dict outputs)
        is stored on the sample itself; otherwise the results are stored
        on its ``ema`` and ``orig`` sub-samples.
    """
    # handle batch_inputs
    assert isinstance(inputs, dict), (
        'SinGAN only support dict type inputs in forward function.')
    gen_kwargs = deepcopy(inputs)
    num_batches = gen_kwargs.pop('num_batches', 1)
    assert num_batches == 1, (
        'SinGAN only support \'num_batches\' as 1, but receive '
        f'{num_batches}.')

    sample_model = self._get_valid_model(inputs)
    gen_kwargs.pop('sample_model', None)  # remove sample_model

    mode = gen_kwargs.pop('mode', mode)
    mode = 'rand' if mode is None else mode
    curr_scale = gen_kwargs.pop('curr_scale', self.curr_stage)

    # make sure the fixed noises live on the preprocessor's device
    self.fixed_noises = [
        x.to(self.data_preprocessor.device) for x in self.fixed_noises
    ]

    def _generate(generator):
        """Run one generator with the shared sampling arguments."""
        return generator(
            None,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights,
            rand_mode=mode,
            num_batches=1,
            curr_scale=curr_scale,
            **gen_kwargs)

    def _destruct(outputs):
        """Destruct a raw generator output (tensor or dict) exactly once."""
        destruct = self.data_preprocessor.destruct
        if not isinstance(outputs, dict):
            return destruct(outputs, data_samples)
        outputs['fake_img'] = destruct(outputs['fake_img'], data_samples)
        outputs['prev_res_list'] = [
            destruct(r, data_samples) for r in outputs['prev_res_list']
        ]
        return outputs

    def _store(sample, outputs, idx):
        """Copy the ``idx``-th destructed result onto ``sample``."""
        if isinstance(outputs, dict):
            sample.fake_img = outputs['fake_img'][idx]
            sample.prev_res_list = [
                r[idx] for r in outputs['prev_res_list']
            ]
        else:
            sample.fake_img = outputs[idx]

    batch_sample_list = []
    if sample_model in ['ema', 'orig']:
        generator = (
            self.generator_ema if sample_model == 'ema' else self.generator)
        outputs = _destruct(_generate(generator))

        # save to data sample
        for idx in range(num_batches):
            gen_sample = DataSample()
            # save inputs to data sample
            if data_samples:
                gen_sample.update(data_samples[idx])
            _store(gen_sample, outputs, idx)
            gen_sample.sample_model = sample_model
            batch_sample_list.append(gen_sample)

    else:  # sample model is 'ema/orig'
        # both generators are run before any result is destructed
        outputs_orig = _generate(self.generator)
        outputs_ema = _generate(self.generator_ema)
        outputs_orig = _destruct(outputs_orig)
        outputs_ema = _destruct(outputs_ema)

        # save to data sample
        for idx in range(num_batches):
            gen_sample = DataSample()
            gen_sample.ema = DataSample()
            gen_sample.orig = DataSample()
            # save inputs to data sample
            if data_samples:
                gen_sample.update(data_samples[idx])
            _store(gen_sample.ema, outputs_ema, idx)
            _store(gen_sample.orig, outputs_orig, idx)
            gen_sample.sample_model = sample_model
            batch_sample_list.append(gen_sample)

    return batch_sample_list
def gen_loss(self, disc_pred_fake: Tensor,
             recon_imgs: Tensor) -> Tuple[Tensor, dict]:
    r"""Generator loss for SinGAN: a WGAN term plus a weighted MSE term.

    .. math::
        L_{G} = -\mathbb{E}_{z\sim{p_{z}}}D\left(G\left(z\right)\right)
            + L_{MSE}

        L_{MSE} = 10 \cdot \text{mean} \Vert x - G(z) \Vert_2^2

    The reconstruction target :math:`x` is the real image of the current
    stage, ``self.reals[self.curr_stage]``.

    Args:
        disc_pred_fake (Tensor): Discriminator's prediction of the fake
            images.
        recon_imgs (Tensor): Reconstructive images.

    Returns:
        Tuple[Tensor, dict]: Loss value and a dict of log variables, as
        produced by ``self.parse_losses``.
    """
    adversarial_term = -disc_pred_fake.mean()
    recon_term = 10 * F.mse_loss(recon_imgs, self.reals[self.curr_stage])
    total, log_vars = self.parse_losses(
        dict(loss_gen=adversarial_term, loss_mse=recon_term))
    return total, log_vars
def disc_loss(self, disc_pred_fake: Tensor, disc_pred_real: Tensor,
              fake_data: Tensor, real_data: Tensor) -> Tuple[Tensor, dict]:
    r"""Discriminator loss: Wasserstein terms plus a gradient penalty.

    .. math::
        L_{D} = \mathbb{E}_{z\sim{p_{z}}}D\left(G\left(z\right)\right)
            - \mathbb{E}_{x\sim{p_{data}}}D\left(x\right) + L_{GP} \\
        L_{GP} = \lambda\mathbb{E}(\Vert\nabla_{\tilde{x}}D(\tilde{x})
            \Vert_2-1)^2 \\
        \tilde{x} = \epsilon x + (1-\epsilon)G(z)

    Here :math:`\lambda` is 0.1, :math:`\epsilon` is drawn uniformly per
    sample, and the gradient norm is taken over dimension 1.

    Args:
        disc_pred_fake (Tensor): Discriminator's prediction of the fake
            images.
        disc_pred_real (Tensor): Discriminator's prediction of the real
            images.
        fake_data (Tensor): Generated images, used to calculate gradient
            penalty.
        real_data (Tensor): Real images, used to calculate gradient
            penalty.

    Returns:
        Tuple[Tensor, dict]: Loss value and a dict of log variables, as
        produced by ``self.parse_losses``.
    """
    losses = {
        'loss_disc_fake': disc_pred_fake.mean(),
        'loss_disc_real': -disc_pred_real.mean(),
    }

    # gradient penalty: one random mixing coefficient per sample
    num_samples = real_data.size(0)
    mix = torch.rand(num_samples, 1, 1, 1).to(real_data)
    blended = mix * real_data + (1. - mix) * fake_data
    blended = autograd.Variable(blended, requires_grad=True)

    critic_out = self.discriminator(blended, curr_scale=self.curr_stage)
    grad_tuple = autograd.grad(
        outputs=critic_out,
        inputs=blended,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)
    grads = grad_tuple[0]

    # norm_mode is 'pixel'
    grad_norm = grads.norm(2, dim=1)
    penalty = (grad_norm - 1).pow(2).mean()
    losses['loss_gp'] = 0.1 * penalty

    total, log_vars = self.parse_losses(losses)
    return total, log_vars
def train_generator(self, inputs: dict, data_samples: List[DataSample],
                    optimizer_wrapper: OptimWrapper) -> Dict[str, Tensor]:
    """Run one generator update at the current stage.

    The generator is run in ``'rand'`` mode for the adversarial term and
    in ``'recon'`` mode for the reconstruction term of ``gen_loss``.

    Args:
        inputs (dict): Inputs from dataloader; ``inputs['input_sample']``
            is fed to the generator.
        data_samples (List[DataSample]): Data samples from dataloader.
            Not used in the generator's training.
        optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
            update model parameters.

    Returns:
        Dict[str, Tensor]: A ``dict`` of tensor for logging.
    """

    def _generate(rand_mode):
        # both passes share every argument except the noise mode
        return self.generator(
            inputs['input_sample'],
            self.fixed_noises,
            self.noise_weights,
            rand_mode=rand_mode,
            curr_scale=self.curr_stage)

    random_fakes = _generate('rand')
    fake_scores = self.discriminator(
        random_fakes, curr_scale=self.curr_stage)
    reconstructions = _generate('recon')

    parsed_loss, log_vars = self.gen_loss(fake_scores, reconstructions)
    optimizer_wrapper.update_params(parsed_loss)
    return log_vars
def train_discriminator(self, inputs: dict, data_samples: List[DataSample],
                        optimizer_wrapper: OptimWrapper
                        ) -> Dict[str, Tensor]:
    """Run one discriminator update at the current stage.

    Args:
        inputs (dict): Inputs from dataloader; ``inputs['input_sample']``
            is fed to the generator.
        data_samples (List[DataSample]): Data samples from dataloader.
        optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
            update model parameters.

    Returns:
        Dict[str, Tensor]: A ``dict`` of tensor for logging.
    """
    stage = self.curr_stage
    generated = self.generator(
        inputs['input_sample'],
        self.fixed_noises,
        self.noise_weights,
        rand_mode='rand',
        curr_scale=stage)

    # score the detached fakes and the real image of this stage
    target_real = self.reals[self.curr_stage]
    fake_scores = self.discriminator(generated.detach(), self.curr_stage)
    real_scores = self.discriminator(target_real, self.curr_stage)

    parsed_loss, log_vars = self.disc_loss(fake_scores, real_scores,
                                           generated, target_real)
    optimizer_wrapper.update_params(parsed_loss)
    return log_vars
def train_gan(self, inputs_dict: dict, data_sample: List[DataSample],
              optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
    """Train the discriminator, then the generator when it is due.

    The discriminator is updated on every call. The generator (and the
    EMA generator, when enabled) is only updated when ``curr_iter + 1`` is
    a multiple of ``discriminator_steps`` times the discriminator's
    gradient-accumulation count.

    Args:
        inputs_dict (dict): Inputs sampled from dataloader.
        data_sample (List[DataSample]): List of data sample contains GT
            and meta information.
        optim_wrapper (OptimWrapperDict): OptimWrapperDict instance
            contains OptimWrapper of generator and discriminator.

    Returns:
        Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
    """

    def _bare_generator():
        # strip the model wrapper (if any) before handing it to the EMA
        if is_model_wrapper(self.generator):
            return self.generator.module
        return self.generator

    curr_iter = MessageHub.get_current_instance().get_info('iter')

    disc_wrapper: OptimWrapper = optim_wrapper['discriminator']
    disc_accu_iters = disc_wrapper._accumulative_counts
    with disc_wrapper.optim_context(self.discriminator):
        log_vars = self.train_discriminator(inputs_dict, data_sample,
                                            disc_wrapper)

    # add 1 to `curr_iter` because iter is updated in train loop.
    # The generator is only trained once the discriminator has been fully
    # updated, i.e. after `discriminator_steps` updates of
    # `disc_accu_iters` accumulation steps each.
    if (curr_iter + 1) % (self.discriminator_steps * disc_accu_iters) != 0:
        return log_vars

    set_requires_grad(self.discriminator, False)
    gen_wrapper = optim_wrapper['generator']
    gen_accu_iters = gen_wrapper._accumulative_counts
    total_gen_iters = self.generator_steps * gen_accu_iters

    # init optimizer wrapper status for generator manually
    gen_wrapper.initialize_count_status(self.generator, 0, total_gen_iters)
    per_step_logs = []
    for _ in range(total_gen_iters):
        with gen_wrapper.optim_context(self.generator):
            step_log = self.train_generator(inputs_dict, data_sample,
                                            gen_wrapper)
        per_step_logs.append(step_log)
    log_vars_gen = self.gather_log_vars(per_step_logs)
    log_vars_gen.pop('loss', None)  # remove 'loss' from gen logs

    set_requires_grad(self.discriminator, True)

    # only do ema after generator update
    if self.with_ema_gen:
        ema_begin = (
            self.ema_start * self.discriminator_steps * disc_accu_iters)
        if (curr_iter + 1) >= ema_begin:
            self.generator_ema.update_parameters(_bare_generator())
            # if not update buffer, copy buffer from orig model
            if not self.generator_ema.update_buffers:
                self.generator_ema.sync_buffers(_bare_generator())
        else:
            # before ema, copy weights from orig
            self.generator_ema.sync_parameters(_bare_generator())

    log_vars.update(log_vars_gen)
    return log_vars
def train_step(self, data: dict,
               optim_wrapper: OptimWrapperDict) -> Dict[str, Tensor]:
    """Train step for SinGAN model.

    SinGAN is trained scale by scale. A new stage starts whenever the
    current iteration is a multiple of
    ``iters_per_scale * discriminator_steps``: the stage index is
    incremented, ``check_and_load_prev_weight`` is called on both
    networks and, if configured, fresh ``MultiStepLR`` schedulers are
    built for the stage's optimizers. On the last iteration of a stage
    (except the final one) a new noise weight is appended, computed as
    ``noise_weight_init`` times the RMSE between the resized
    reconstruction and the real image of the next scale.

    Args:
        data (dict): Data sampled from dataloader.
        optim_wrapper (OptimWrapperDict): OptimWrapperDict instance
            holding ``generator_{stage}`` and ``discriminator_{stage}``
            wrappers.

    Returns:
        Dict[str, torch.Tensor]: A ``dict`` of tensor for logging, with
        the extra key ``'curr_stage'``.
    """
    curr_iter = MessageHub.get_current_instance().get_info('iter')
    iters_per_stage = self.iters_per_scale * self.discriminator_steps

    if curr_iter % iters_per_stage == 0:
        self.curr_stage += 1

        # load weights from prev scale
        load_gen_weight = self.get_module(self.generator,
                                          'check_and_load_prev_weight')
        load_gen_weight(self.curr_stage)
        load_disc_weight = self.get_module(self.discriminator,
                                           'check_and_load_prev_weight')
        load_disc_weight(self.curr_stage)

        # assert grad_accumulation step is 1
        stage_gen_optim = optim_wrapper[f'generator_{self.curr_stage}']
        stage_disc_optim = optim_wrapper[f'discriminator_{self.curr_stage}']
        _warning_msg = ('SinGAN do set batch size as 1 during training '
                        'and do not support gradient accumulation.')
        assert stage_gen_optim._accumulative_counts == 1, _warning_msg
        assert stage_disc_optim._accumulative_counts == 1, _warning_msg

        # build the schedulers manually, so that only the optimizers of
        # the current stage are scheduled
        if self.lr_scheduler_args:
            scheduler_cls = torch.optim.lr_scheduler.MultiStepLR
            self.gen_scheduler = scheduler_cls(stage_gen_optim.optimizer,
                                               **self.lr_scheduler_args)
            self.disc_scheduler = scheduler_cls(stage_disc_optim.optimizer,
                                                **self.lr_scheduler_args)

    # get current optimizer_wrapper
    stage_wrappers = OptimWrapperDict(
        generator=optim_wrapper[f'generator_{self.curr_stage}'],
        discriminator=optim_wrapper[f'discriminator_{self.curr_stage}'])

    # handle inputs
    data = self.data_preprocessor(data)
    inputs_dict = data['inputs']
    data_samples = data['data_samples']

    # setup fixed noises and reals pyramid
    if curr_iter == 0 or len(self.reals) == 0:
        n_reals = sum(1 for key in inputs_dict.keys() if 'real_scale' in key)
        self.reals = [
            inputs_dict[f'real_scale{level}'] for level in range(n_reals)
        ]
        # here we do not padding fixed noises
        self.construct_fixed_noises()

    # standard train step
    log_vars = self.train_gan(inputs_dict, data_samples, stage_wrappers)
    log_vars['curr_stage'] = self.curr_stage

    # update noise weight
    stage_finished = (curr_iter + 1) % iters_per_stage == 0
    if stage_finished and (self.curr_stage < len(self.reals) - 1):
        with torch.no_grad():
            g_recon = self.generator(
                inputs_dict['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='recon',
                curr_scale=self.curr_stage)
            if isinstance(g_recon, dict):
                g_recon = g_recon['fake_img']
            # resize to the spatial size of the next scale's real image
            g_recon = F.interpolate(
                g_recon, self.reals[self.curr_stage + 1].shape[-2:])

        rmse = torch.sqrt(
            F.mse_loss(g_recon.detach(), self.reals[self.curr_stage + 1]))
        self.noise_weights.append(self.noise_weight_init * rmse.item())

    # call scheduler when all submodules are fully updated.
    fully_updated = (curr_iter + 1) % self.discriminator_steps == 0
    if fully_updated and self.lr_scheduler_args:
        self.disc_scheduler.step()
        self.gen_scheduler.step()

    return log_vars
def test_step(self, data: dict) -> SampleList:
    """Generate images for the given data during testing.

    On a call where ``self.loaded_test_pkl`` is still false,
    ``self.load_test_pkl`` is invoked first; the actual test step is then
    delegated to the parent class.

    Args:
        data (dict): Data sampled from metric specific
            sampler. More details in `Metrics` and `Evaluator`.

    Returns:
        SampleList: A list of ``DataSample`` contain generated results.
    """
    pkl_pending = not self.loaded_test_pkl
    if pkl_pending:
        self.load_test_pkl()
    samples = super().test_step(data)
    return samples