Source code for mmagic.models.editors.wgan_gp.wgan_generator
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmagic.models.utils import get_module_device
from mmagic.registry import MODELS
from .wgan_gp_module import WGANNoiseTo2DFeat
@MODELS.register_module()
[docs]class WGANGPGenerator(BaseModule):
r"""Generator for WGANGP.
Implementation Details for WGANGP generator the same as training
configuration (a) described in PGGAN paper:
PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf # noqa
#. Adopt convolution architecture specified in appendix A.2;
#. Use batchnorm in the generator except for the final output layer;
#. Use ReLU in the generator except for the final output layer;
#. Use Tanh in the last layer;
#. Initialize all weights using He’s initializer.
Args:
noise_size (int): Size of the input noise vector.
out_scale (int): Output scale for the generated image.
conv_module_cfg (dict, optional): Config for the convolution
module used in this generator. Defaults to None.
upsample_cfg (dict, optional): Config for the upsampling operation.
Defaults to None.
init_cfg (dict, optional): Initialization config dict.
"""
[docs] _default_channels_per_scale = {
'4': 512,
'8': 512,
'16': 256,
'32': 128,
'64': 64,
'128': 32
}
[docs] _default_conv_module_cfg = dict(
conv_cfg=None,
kernel_size=3,
stride=1,
padding=1,
bias=True,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
order=('conv', 'norm', 'act'))
def __init__(self,
noise_size,
out_scale,
conv_module_cfg=None,
upsample_cfg=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
# set initial params
self.noise_size = noise_size
self.out_scale = out_scale
self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
if conv_module_cfg is not None:
self.conv_module_cfg.update(conv_module_cfg)
self.upsample_cfg = upsample_cfg if upsample_cfg else deepcopy(
self._default_upsample_cfg)
# set noise2feat head
self.noise2feat = WGANNoiseTo2DFeat(
self.noise_size, self._default_channels_per_scale['4'])
# set conv_blocks
self.conv_blocks = nn.ModuleList()
self.conv_blocks.append(ConvModule(512, 512, **self.conv_module_cfg))
log2scale = int(np.log2(self.out_scale))
for i in range(3, log2scale + 1):
self.conv_blocks.append(MODELS.build(self._default_upsample_cfg))
self.conv_blocks.append(
ConvModule(self._default_channels_per_scale[str(2**(i - 1))],
self._default_channels_per_scale[str(2**i)],
**self.conv_module_cfg))
self.conv_blocks.append(
ConvModule(self._default_channels_per_scale[str(2**i)],
self._default_channels_per_scale[str(2**i)],
**self.conv_module_cfg))
self.to_rgb = ConvModule(
self._default_channels_per_scale[str(self.out_scale)],
kernel_size=1,
out_channels=3,
act_cfg=dict(type='Tanh'))
[docs] def forward(self, noise, num_batches=0, return_noise=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size. Defaults to
0.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
Returns:
torch.Tensor | dict: If not ``return_noise``, only the output image
will be returned. Otherwise, a dict contains ``fake_img`` and
``noise_batch`` will be returned.
"""
# receive noise and conduct sanity check.
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == self.noise_size
assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
f'but got {noise.shape}')
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
# noise vector to 2D feature
x = self.noise2feat(noise_batch)
for conv in self.conv_blocks:
x = conv(x)
out_img = self.to_rgb(x)
if return_noise:
output = dict(fake_img=out_img, noise_batch=noise_batch)
return output
return out_img