
Source code for mmagic.models.editors.lsgan.lsgan_generator

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmagic.registry import MODELS
from ...utils import get_module_device


@MODELS.register_module()
class LSGANGenerator(BaseModule):
    """Generator for LSGAN.

    Implementation details of the LSGAN architecture:

    #. Adopt transposed convolution in the generator;
    #. Use batchnorm in the generator except for the final output layer;
    #. Use ReLU in the generator except for the final output layer, which
       uses Tanh;
    #. Keep channels of feature maps unchanged in the convolution backbone;
    #. Use one more 3x3 conv every upsampling in the convolution backbone.

    We follow the implementation details of the original paper:
    Least Squares Generative Adversarial Networks
    https://arxiv.org/pdf/1611.04076.pdf

    Args:
        output_scale (int, optional): Output scale for the generated image.
            Defaults to 128.
        out_channels (int, optional): The channel number of the output
            feature. Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this
            number. Defaults to 256.
        input_scale (int, optional): The scale of the input 2D feature map.
            Defaults to 8.
        noise_size (int, optional): Size of the input noise vector.
            Defaults to 1024.
        conv_cfg (dict, optional): Config for the convolution module used in
            this generator. Defaults to dict(type='ConvTranspose2d').
        default_norm_cfg (dict, optional): Norm config for all layers
            except for the final output layer. Defaults to dict(type='BN').
        default_act_cfg (dict, optional): Activation config for all layers
            except for the final output layer. Defaults to dict(type='ReLU').
        out_act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to dict(type='Tanh').
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 output_scale=128,
                 out_channels=3,
                 base_channels=256,
                 input_scale=8,
                 noise_size=1024,
                 conv_cfg=dict(type='ConvTranspose2d'),
                 default_norm_cfg=dict(type='BN'),
                 default_act_cfg=dict(type='ReLU'),
                 out_act_cfg=dict(type='Tanh'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert output_scale % input_scale == 0
        assert output_scale // input_scale >= 4

        self.output_scale = output_scale
        self.base_channels = base_channels
        self.input_scale = input_scale
        self.noise_size = noise_size

        # project the input noise vector to a 2D feature map
        self.noise2feat_head = nn.Sequential(
            nn.Linear(noise_size, input_scale * input_scale * base_channels))
        self.noise2feat_tail = nn.Sequential(nn.BatchNorm2d(base_channels))
        if default_act_cfg is not None:
            self.noise2feat_tail.add_module('act',
                                            MODELS.build(default_act_cfg))

        # the number of times for upsampling
        self.num_upsamples = int(np.log2(output_scale // input_scale)) - 2

        # build up convolution backbone (excluding the output layer)
        self.conv_blocks = nn.ModuleList()
        for _ in range(self.num_upsamples):
            self.conv_blocks.append(
                ConvModule(
                    base_channels,
                    base_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=dict(conv_cfg, output_padding=1),
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))
            self.conv_blocks.append(
                ConvModule(
                    base_channels,
                    base_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))

        # output blocks
        self.conv_blocks.append(
            ConvModule(
                base_channels,
                int(base_channels // 2),
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=dict(conv_cfg, output_padding=1),
                norm_cfg=default_norm_cfg,
                act_cfg=default_act_cfg))
        self.conv_blocks.append(
            ConvModule(
                int(base_channels // 2),
                int(base_channels // 4),
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=dict(conv_cfg, output_padding=1),
                norm_cfg=default_norm_cfg,
                act_cfg=default_act_cfg))
        self.conv_blocks.append(
            ConvModule(
                int(base_channels // 4),
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=None,
                act_cfg=out_act_cfg))

    def forward(self, noise, num_batches=0, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, ``None``
                indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img`` and ``noise_batch`` will be returned.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            if noise.ndim == 2:
                noise_batch = noise
            else:
                raise ValueError('The noise should be in shape of (n, c), '
                                 f'but got {noise.shape}')
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))

        # put the noise batch on the same device as the module
        noise_batch = noise_batch.to(get_module_device(self))

        # noise2feat: project and reshape the noise to a 2D feature map
        x = self.noise2feat_head(noise_batch)
        x = x.reshape(
            (-1, self.base_channels, self.input_scale, self.input_scale))
        x = self.noise2feat_tail(x)

        # convolution backbone and output blocks
        for conv in self.conv_blocks:
            x = conv(x)

        if return_noise:
            return dict(fake_img=x, noise_batch=noise_batch)
        return x
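
The snippet below is a minimal usage sketch, not part of the module itself. It assumes mmagic is installed and imports the class directly via the module path shown in the page title; the default constructor arguments are used, and only the ``forward`` interface documented above is exercised.

# Usage sketch (assumption: LSGANGenerator is instantiated directly rather
# than built from an MMEngine config).
import torch

from mmagic.models.editors.lsgan.lsgan_generator import LSGANGenerator

gen = LSGANGenerator(output_scale=128, noise_size=1024)

# With the defaults (input_scale=8, output_scale=128), the backbone performs
# log2(128 // 8) - 2 = 2 strided upsamples, and the two output blocks add two
# more: 8 -> 16 -> 32 -> 64 -> 128.

# Let the module sample noise internally: pass noise=None and a batch size.
imgs = gen(None, num_batches=4)
assert imgs.shape == (4, 3, 128, 128)

# Or provide an explicit (n, c) noise tensor and ask for it back as well.
noise = torch.randn(4, 1024)
out = gen(noise, return_noise=True)
print(out['fake_img'].shape, out['noise_batch'].shape)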