Source code for mmagic.models.editors.lsgan.lsgan_generator
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmagic.registry import MODELS
from ...utils import get_module_device
@MODELS.register_module()
[docs]class LSGANGenerator(BaseModule):
"""Generator for LSGAN.
Implementation Details for LSGAN architecture:
#. Adopt transposed convolution in the generator;
#. Use batchnorm in the generator except for the final output layer;
#. Use ReLU in the generator in addition to the final output layer;
#. Keep channels of feature maps unchanged in the convolution backbone;
#. Use one more 3x3 conv every upsampling in the convolution backbone.
We follow the implementation details of the origin paper:
Least Squares Generative Adversarial Networks
https://arxiv.org/pdf/1611.04076.pdf
Args:
output_scale (int, optional): Output scale for the generated image.
Defaults to 128.
out_channels (int, optional): The channel number of the output feature.
Defaults to 3.
base_channels (int, optional): The basic channel number of the
generator. The other layers contains channels based on this number.
Defaults to 256.
input_scale (int, optional): The scale of the input 2D feature map.
Defaults to 8.
noise_size (int, optional): Size of the input noise
vector. Defaults to 1024.
conv_cfg (dict, optional): Config for the convolution module used in
this generator. Defaults to dict(type='ConvTranspose2d').
default_norm_cfg (dict, optional): Norm config for all of layers
except for the final output layer. Defaults to dict(type='BN').
default_act_cfg (dict, optional): Activation config for all of layers
except for the final output layer. Defaults to dict(type='ReLU').
out_act_cfg (dict, optional): Activation config for the final output
layer. Defaults to dict(type='Tanh').
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
output_scale=128,
out_channels=3,
base_channels=256,
input_scale=8,
noise_size=1024,
conv_cfg=dict(type='ConvTranspose2d'),
default_norm_cfg=dict(type='BN'),
default_act_cfg=dict(type='ReLU'),
out_act_cfg=dict(type='Tanh'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
assert output_scale % input_scale == 0
assert output_scale // input_scale >= 4
self.output_scale = output_scale
self.base_channels = base_channels
self.input_scale = input_scale
self.noise_size = noise_size
self.noise2feat_head = nn.Sequential(
nn.Linear(noise_size, input_scale * input_scale * base_channels))
self.noise2feat_tail = nn.Sequential(nn.BatchNorm2d(base_channels))
if default_act_cfg is not None:
self.noise2feat_tail.add_module('act',
MODELS.build(default_act_cfg))
# the number of times for upsampling
self.num_upsamples = int(np.log2(output_scale // input_scale)) - 2
# build up convolution backbone (excluding the output layer)
self.conv_blocks = nn.ModuleList()
for _ in range(self.num_upsamples):
self.conv_blocks.append(
ConvModule(
base_channels,
base_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(conv_cfg, output_padding=1),
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
self.conv_blocks.append(
ConvModule(
base_channels,
base_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
# output blocks
self.conv_blocks.append(
ConvModule(
base_channels,
int(base_channels // 2),
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(conv_cfg, output_padding=1),
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
self.conv_blocks.append(
ConvModule(
int(base_channels // 2),
int(base_channels // 4),
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(conv_cfg, output_padding=1),
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
self.conv_blocks.append(
ConvModule(
int(base_channels // 4),
out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=None,
act_cfg=out_act_cfg))
[docs] def forward(self, noise, num_batches=0, return_noise=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
Returns:
torch.Tensor | dict: If not ``return_noise``, only the output image
will be returned. Otherwise, a dict contains ``fake_img`` and
``noise_batch`` will be returned.
"""
# receive noise and conduct sanity check.
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == self.noise_size
if noise.ndim == 2:
noise_batch = noise
else:
raise ValueError('The noise should be in shape of (n, c)'
f'but got {noise.shape}')
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
# noise2feat
x = self.noise2feat_head(noise_batch)
x = x.reshape(
(-1, self.base_channels, self.input_scale, self.input_scale))
x = self.noise2feat_tail(x)
# conv module
for conv in self.conv_blocks:
x = conv(x)
if return_noise:
return dict(fake_img=x, noise_batch=noise_batch)
return x