Shortcuts

Source code for mmagic.models.editors.pix2pix.pix2pix_generator

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import BaseModule

from mmagic.registry import MODELS
from ...utils import generation_init_weights
from .pix2pix_modules import UnetSkipConnectionBlock


@MODELS.register_module()
[docs]class UnetGenerator(BaseModule): """Construct the Unet-based generator from the innermost layer to the outermost layer, which is a recursive process. Args: in_channels (int): Number of channels in input images. out_channels (int): Number of channels in output images. num_down (int): Number of downsamplings in Unet. If `num_down` is 8, the image with size 256x256 will become 1x1 at the bottleneck. Default: 8. base_channels (int): Number of channels at the last conv layer. Default: 64. norm_cfg (dict): Config dict to build norm layer. Default: `dict(type='BN')`. use_dropout (bool): Whether to use dropout layers. Default: False. init_cfg (dict): Config dict for initialization. `type`: The name of our initialization method. Default: 'normal'. `gain`: Scaling factor for normal, xavier and orthogonal. Default: 0.02. """ def __init__(self, in_channels, out_channels, num_down=8, base_channels=64, norm_cfg=dict(type='BN'), use_dropout=False, init_cfg=dict(type='normal', gain=0.02)): super().__init__(init_cfg=init_cfg) # We use norm layers in the unet generator. assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but" f'got {type(norm_cfg)}') assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'" # add the innermost layer unet_block = UnetSkipConnectionBlock( base_channels * 8, base_channels * 8, in_channels=None, submodule=None, norm_cfg=norm_cfg, is_innermost=True) # add intermediate layers with base_channels * 8 filters for _ in range(num_down - 5): unet_block = UnetSkipConnectionBlock( base_channels * 8, base_channels * 8, in_channels=None, submodule=unet_block, norm_cfg=norm_cfg, use_dropout=use_dropout) # gradually reduce the number of filters # from base_channels * 8 to base_channels unet_block = UnetSkipConnectionBlock( base_channels * 4, base_channels * 8, in_channels=None, submodule=unet_block, norm_cfg=norm_cfg) unet_block = UnetSkipConnectionBlock( base_channels * 2, base_channels * 4, in_channels=None, submodule=unet_block, norm_cfg=norm_cfg) unet_block = UnetSkipConnectionBlock( base_channels, base_channels * 2, in_channels=None, submodule=unet_block, norm_cfg=norm_cfg) # add the outermost layer self.model = UnetSkipConnectionBlock( out_channels, base_channels, in_channels=in_channels, submodule=unet_block, is_outermost=True, norm_cfg=norm_cfg) self.init_type = 'normal' if init_cfg is None else init_cfg.get( 'type', 'normal') self.init_gain = 0.02 if init_cfg is None else init_cfg.get( 'gain', 0.02)
[docs] def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ return self.model(x)
[docs] def init_weights(self): """Initialize weights for the model. Args: pretrained (str, optional): Path for pretrained weights. If given None, pretrained weights will not be loaded. Default: None. strict (bool, optional): Whether to allow different params for the model and checkpoint. Default: True. """ if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained': super().init_weights() return generation_init_weights( self, init_type=self.init_type, init_gain=self.init_gain) self._is_init = True