Shortcuts

Source code for mmagic.models.editors.stylegan2.stylegan2_modules

# Copyright (c) OpenMMLab. All rights reserved.
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmcv.ops.fused_bias_leakyrelu import (FusedBiasLeakyReLU,
                                           fused_bias_leakyrelu)
from mmcv.ops.upfirdn2d import upfirdn2d
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule
from mmengine.runner.amp import autocast

from mmagic.models.archs import AllGatherLayer
from ..pggan import EqualizedLRConvModule, equalized_lr
from ..stylegan1 import Blur, EqualLinearActModule, NoiseInjection, make_kernel

try:
    from mmcv.ops import conv2d, conv_transpose2d
except ImportError:
    # The mmcv convolution ops could not be imported: bind the same two
    # module-level names to the plain PyTorch functional convolutions and
    # report the situation on stdout.
    from torch.nn import functional as F

    conv2d, conv_transpose2d = F.conv2d, F.conv_transpose2d
    print('Warning: mmcv.ops.conv2d, mmcv.ops.conv_transpose2d'
          ' and mmcv.ops.upfirdn2d are not available.')
class _FusedBiasLeakyReLU(FusedBiasLeakyReLU):
    """FP16-friendly variant of ``FusedBiasLeakyReLU``.

    The only difference to the parent class is that the bias is cast to the
    dtype of the incoming tensor before the fused op is invoked.
    """

    def forward(self, x):
        """Apply the fused bias + leaky ReLU.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
        # Cast the bias so that it matches the dtype of ``x``.
        bias = self.bias.to(x.dtype)
        return fused_bias_leakyrelu(x, bias, self.negative_slope, self.scale)
class UpsampleUpFIRDn(BaseModule):
    """Upsampling layer built on ``upfirdn2d``.

    StyleGAN2 uses this module inside the ``to_rgb`` layers to upsample
    images.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        factor (int, optional): Upsampling factor. Defaults to 2.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        # The filter is scaled by ``factor ** 2`` before being stored.
        scaled_kernel = make_kernel(kernel) * (factor**2)
        self.register_buffer('kernel', scaled_kernel)

        # Padding is derived from the filter size and the factor.
        extra = scaled_kernel.shape[0] - factor
        before = (extra + 1) // 2 + factor - 1
        after = extra // 2
        self.pad = (before, after, before, after)

    def forward(self, x):
        """Upsample ``x``.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # The stored filter is cast to the dtype of the input first.
        filt = self.kernel.to(x.dtype)
        return upfirdn2d(x, filt, up=self.factor, down=1, padding=self.pad)
class DownsampleUpFIRDn(BaseModule):
    """UpFIRDn for Downsampling.

    This module is mentioned in StyleGAN2 for downsampling the feature maps.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        factor (int, optional): Downsampling factor. Defaults to 2.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        # Spell out all four sides (x_before, x_after, y_before, y_after),
        # exactly as ``UpsampleUpFIRDn`` does. A 2-element padding is read by
        # mmcv's ``upfirdn2d`` as (pad_x, pad_y) applied symmetrically, which
        # only coincides with (before, after) when ``pad0 == pad1`` (true for
        # the default 4-tap kernel with factor 2, so defaults are unchanged).
        self.pad = (pad0, pad1, pad0, pad1)

    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # The stored filter is cast to the dtype of the input before use.
        out = upfirdn2d(
            input,
            self.kernel.to(input.dtype),
            up=1,
            down=self.factor,
            padding=self.pad)

        return out
class ModulatedConv2d(BaseModule):
    r"""Modulated Conv2d in StyleGANv2.

    This module implements the modulated convolution layers proposed in
    StyleGAN2. Details can be found in Analyzing and Improving the Image
    Quality of StyleGAN, CVPR2020.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`. Must be a
            positive odd integer.
        style_channels (int): Channels for the style codes.
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        downsample (bool, optional): Whether to adopt downsampling in
            features. Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
            Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        padding (int | None, optional): Padding used by the plain (neither
            upsampling nor downsampling) convolution. ``None`` selects
            ``kernel_size // 2``; an explicit value, including ``0``, is used
            as given. Defaults to None.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
        fp16_enabled (bool, optional): Whether to run this module in FP16.
            Defaults to False.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            style_channels,
            demodulate=True,
            upsample=False,
            downsample=False,
            blur_kernel=[1, 3, 3, 1],
            equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
            style_mod_cfg=dict(bias_init=1.),
            style_bias=0.,
            padding=None,  # self define padding
            eps=1e-8,
            fp16_enabled=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_channels = style_channels
        self.demodulate = demodulate
        self.fp16_enabled = fp16_enabled
        # sanity check for kernel size
        assert isinstance(self.kernel_size,
                          int) and (self.kernel_size >= 1
                                    and self.kernel_size % 2 == 1)
        self.upsample = upsample
        self.downsample = downsample
        self.style_bias = style_bias
        self.eps = eps

        # build style modulation module
        style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg

        self.style_modulation = EqualLinearActModule(style_channels,
                                                     in_channels,
                                                     **style_mod_cfg)
        # set lr_mul for conv weight
        lr_mul_ = 1.
        if equalized_lr_cfg is not None:
            lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)

        # The weight carries a leading singleton axis so that it broadcasts
        # against the per-sample style in ``forward``.
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size,
                        kernel_size).div_(lr_mul_))

        # build blurry layer for upsampling
        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)

        # build blurry layer for downsampling
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # add equalized_lr hook for conv weight
        if equalized_lr_cfg is not None:
            equalized_lr(self, **equalized_lr_cfg)

        # Only fall back to the "same" padding when the caller gave nothing;
        # a truthiness test would silently replace an explicit ``padding=0``.
        self.padding = kernel_size // 2 if padding is None else padding

    def forward(self, x, style, input_gain=None):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style code; it is fed to the style modulation
                layer and the result is viewed as (N, 1, C, 1, 1).
            input_gain (Tensor, optional): Per-input-channel gain that is
                expanded to (N, in_channels) and multiplied into the weight.
                Defaults to None.

        Returns:
            Tensor: Output features with ``out_channels`` channels.
        """
        n, c, h, w = x.shape

        weight = self.weight
        # Pre-normalize inputs to avoid FP16 overflow.
        # NOTE(review): ``weight`` is 5-D here (1, out, in, k, k), so
        # ``dim=[1, 2, 3]`` reduces over the out/in/first-kernel axes, which
        # does not match the "max_Ikk" annotation -- confirm whether
        # ``dim=[2, 3, 4]`` was intended. Left unchanged to keep the numerics
        # of existing FP16 checkpoints.
        if self.fp16_enabled and self.demodulate:
            weight = weight * (
                1 / np.sqrt(
                    self.in_channels * self.kernel_size * self.kernel_size) /
                weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)
            )  # max_Ikk
            style = style / style.norm(
                float('inf'), dim=1, keepdim=True)  # max_I

        with autocast(enabled=self.fp16_enabled):
            # process style code
            style = self.style_modulation(style).view(n, 1, c, 1,
                                                      1) + self.style_bias

            # combine weight and style
            weight = weight * style
            if self.demodulate:
                demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
                weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

            if input_gain is not None:
                # input_gain shape [batch, in_ch]
                input_gain = input_gain.expand(n, self.in_channels)
                # weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
                weight = weight * input_gain.unsqueeze(1).unsqueeze(
                    3).unsqueeze(4)

            # Fold the batch axis into the output channels so that a single
            # grouped convolution (groups=n) applies one kernel per sample.
            weight = weight.view(n * self.out_channels, c, self.kernel_size,
                                 self.kernel_size)

            if self.fp16_enabled:
                weight = weight.to(torch.float16)
                x = x.to(torch.float16)

            if self.upsample:
                x = x.reshape(1, n * c, h, w)
                weight = weight.view(n, self.out_channels, c,
                                     self.kernel_size, self.kernel_size)
                # conv_transpose2d expects (in, out / groups, k, k).
                weight = weight.transpose(1, 2).reshape(
                    n * c, self.out_channels, self.kernel_size,
                    self.kernel_size)
                x = conv_transpose2d(
                    x, weight, padding=0, stride=2, groups=n)
                x = x.reshape(n, self.out_channels, *x.shape[-2:])
                x = self.blur(x)
            elif self.downsample:
                x = self.blur(x)
                x = x.view(1, n * self.in_channels, *x.shape[-2:])
                x = conv2d(x, weight, stride=2, padding=0, groups=n)
                x = x.view(n, self.out_channels, *x.shape[-2:])
            else:
                x = x.reshape(1, n * c, h, w)
                x = conv2d(
                    x, weight, stride=1, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
class ModulatedStyleConv(BaseModule):
    """Modulated convolution followed by noise injection and activation.

    The module chains a :class:`ModulatedConv2d`, a ``NoiseInjection`` layer
    and a fused bias + leaky ReLU activation.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`.
        style_channels (int): Channels for the style codes.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to ``0.``.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Magnitude the output is clamped to
            when ``fp16_enabled`` is set. Defaults to `256.0`.
        fixed_noise (bool, optional): Forwarded to ``NoiseInjection``.
            Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 fp16_enabled=False,
                 conv_clamp=256,
                 fixed_noise=False):
        super().__init__()

        # FP16 related settings
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)

        # Keyword options handed to the modulated convolution.
        conv_options = dict(
            demodulate=demodulate,
            upsample=upsample,
            blur_kernel=blur_kernel,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias,
            fp16_enabled=fp16_enabled)
        self.conv = ModulatedConv2d(in_channels, out_channels, kernel_size,
                                    style_channels, **conv_options)

        self.noise_injector = NoiseInjection(fixed_noise=fixed_noise)
        self.activate = _FusedBiasLeakyReLU(out_channels)

    def forward(self,
                x,
                style,
                noise=None,
                add_noise=True,
                return_noise=False):
        """Run convolution, optional noise injection and activation.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            add_noise (bool, optional): Whether apply noise injection to
                feature. Defaults to True.
            return_noise (bool, optional): Whether to return noise tensors.
                Defaults to False.

        Returns:
            Tensor | tuple[Tensor]: The output features, or the pair
            ``(features, noise)`` when ``return_noise`` is set.
        """
        with autocast(enabled=self.fp16_enabled):
            feat = self.conv(x, style)

            if add_noise:
                injected = self.noise_injector(
                    feat, noise=noise, return_noise=return_noise)
                # With ``return_noise`` the injector result is unpacked into
                # the features and the noise.
                if return_noise:
                    feat, noise = injected
                else:
                    feat = injected

            # TODO: FP16 in activate layers
            feat = self.activate(feat)

            if self.fp16_enabled:
                feat = torch.clamp(
                    feat, min=-self.conv_clamp, max=self.conv_clamp)

        return (feat, noise) if return_noise else feat
class ModulatedToRGB(BaseModule):
    """To RGB layer.

    Produces the image tensor in StyleGAN2 with a 1x1 modulated convolution
    (no demodulation) plus a learnable bias, optionally adding an upsampled
    skip image.

    Args:
        in_channels (int): Input channels.
        style_channels (int): Channels for the style codes.
        out_channels (int, optional): Output channels. Defaults to 3.
        upsample (bool, optional): Whether to build an upsampling layer that
            is applied to the skip input. Defaults to True.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Magnitude the output is clamped to
            when ``fp16_enabled`` is set. Defaults to `256.0`.
        out_fp32 (bool, optional): Whether to convert the output feature map
            to `torch.float32`. Defaults to `True`.
    """

    def __init__(self,
                 in_channels,
                 style_channels,
                 out_channels=3,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1],
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 fp16_enabled=False,
                 conv_clamp=256,
                 out_fp32=True):
        super().__init__()

        # The attribute only exists when upsampling was requested; ``forward``
        # checks for its presence.
        if upsample:
            self.upsample = UpsampleUpFIRDn(blur_kernel)

        # FP16 related settings
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)

        conv_options = dict(
            out_channels=out_channels,
            kernel_size=1,
            style_channels=style_channels,
            demodulate=False,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias,
            fp16_enabled=fp16_enabled)
        self.conv = ModulatedConv2d(in_channels, **conv_options)

        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

        # enforce the output to be fp32 (follow Tero's implementation)
        self.out_fp32 = out_fp32

    def forward(self, x, style, skip=None):
        """Convert features to an image tensor.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            skip (Tensor, optional): Tensor for skip link. Defaults to None.

        Returns:
            Tensor: Output features with shape of (N, C, H, W)
        """
        with autocast(enabled=self.fp16_enabled):
            rgb = self.conv(x, style)
            rgb = rgb + self.bias.to(x.dtype)

            if self.fp16_enabled:
                rgb = torch.clamp(
                    rgb, min=-self.conv_clamp, max=self.conv_clamp)

            # Here, Tero adopts FP16 at `skip`.
            if skip is not None:
                rgb = rgb + (
                    self.upsample(skip) if hasattr(self, 'upsample') else skip)

            if self.out_fp32:
                rgb = rgb.to(torch.float32)

        return rgb
class ConvDownLayer(nn.Sequential):
    """Convolution layer with optional blur + strided downsampling.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`.
        downsample (bool, optional): Whether to adopt downsampling in
            features. Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        bias (bool, optional): Whether to use bias parameter.
            Defaults to True.
        act_cfg (dict, optional): Activation configs.
            Defaults to dict(type='fused_bias').
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Magnitude the output is clamped to
            when ``fp16_enabled`` is set. Defaults to `256.0`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 bias=True,
                 act_cfg=dict(type='fused_bias'),
                 fp16_enabled=False,
                 conv_clamp=256.):
        # The sub-modules are handed to ``nn.Sequential.__init__`` at the end,
        # so the plain attributes are stored before that call.
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)

        modules = []
        if downsample:
            down_factor = 2
            total_pad = (len(blur_kernel) - down_factor) + (kernel_size - 1)
            blur_pad = ((total_pad + 1) // 2, total_pad // 2)
            modules.append(Blur(blur_kernel, pad=blur_pad))
            # The blur layer receives the padding; the conv itself gets none.
            stride, self.padding = 2, 0
        else:
            stride, self.padding = 1, kernel_size // 2

        fused = act_cfg is not None and act_cfg.get('type') == 'fused_bias'
        self.with_fused_bias = fused

        # With the fused activation, bias and activation are provided by the
        # separate fused layer appended below instead of the conv module.
        modules.append(
            EqualizedLRConvModule(
                in_channels,
                out_channels,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not fused,
                norm_cfg=None,
                act_cfg=None if fused else act_cfg,
                equalized_lr_cfg=dict(mode='fan_in', gain=1.)))

        if fused:
            modules.append(_FusedBiasLeakyReLU(out_channels))

        super().__init__(*modules)

    def forward(self, x):
        """Run the stacked layers, clamping the result in FP16 mode.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output feature map.
        """
        with autocast(enabled=self.fp16_enabled):
            x = super().forward(x)
            if self.fp16_enabled:
                limit = self.conv_clamp
                x = torch.clamp(x, min=-limit, max=limit)
        return x
class ResBlock(BaseModule):
    """Residual block used in the discriminator of StyleGAN2.

    Two 3x3 :class:`ConvDownLayer` modules (the second one downsampling) form
    the main path; a downsampling 1x1 layer without bias or activation forms
    the skip path.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        convert_input_fp32 (bool, optional): Whether to convert input type to
            fp32 if not `fp16_enabled`. This argument is designed to deal with
            the cases where some modules are run in FP16 and others in FP32.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 blur_kernel=[1, 3, 3, 1],
                 fp16_enabled=False,
                 convert_input_fp32=True):
        super().__init__()

        self.fp16_enabled = fp16_enabled
        self.convert_input_fp32 = convert_input_fp32

        # Options common to all three layers.
        common = dict(fp16_enabled=fp16_enabled, blur_kernel=blur_kernel)

        self.conv1 = ConvDownLayer(in_channels, in_channels, 3, **common)
        self.conv2 = ConvDownLayer(
            in_channels, out_channels, 3, downsample=True, **common)
        self.skip = ConvDownLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            act_cfg=None,
            bias=False,
            **common)

    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # TODO: study whether this explicit datatype transfer will harm the
        # apex training speed
        if self.convert_input_fp32 and not self.fp16_enabled:
            input = input.to(torch.float32)

        with autocast(enabled=self.fp16_enabled):
            main = self.conv2(self.conv1(input))
            shortcut = self.skip(input)
            # The sum of both paths is divided by sqrt(2).
            out = (main + shortcut) / np.sqrt(2)

        return out
class ModMBStddevLayer(BaseModule):
    """Modified MiniBatch Stddev Layer.

    This layer is modified from ``MiniBatchStddevLayer`` used in PGGAN. In
    StyleGAN2, the authors add a new feature, `channel_groups`, into this
    layer.

    Note that to accelerate the training procedure, we also add a new feature
    of ``sync_std`` to achieve multi-nodes/machine training. This feature is
    still in beta version and we have tested it on 256 scales.

    Args:
        group_size (int, optional): The size of groups in batch dimension.
            Defaults to 4.
        channel_groups (int, optional): The size of groups in channel
            dimension. Defaults to 1.
        sync_std (bool, optional): Whether to use synchronized std feature.
            Defaults to False.
        sync_groups (int | None, optional): The size of groups in node
            dimension. ``None`` falls back to ``group_size``.
            Defaults to None.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
    """

    def __init__(self,
                 group_size=4,
                 channel_groups=1,
                 sync_std=False,
                 sync_groups=None,
                 eps=1e-8):
        super().__init__()
        self.group_size = group_size
        self.eps = eps
        self.channel_groups = channel_groups
        self.sync_std = sync_std
        if sync_groups is None:
            self.sync_groups = group_size
        else:
            self.sync_groups = sync_groups

        if self.sync_std:
            assert torch.distributed.is_initialized(
            ), 'Only in distributed training can the sync_std be activated.'
            mmengine.print_log('Adopt synced minibatch stddev layer', 'mmagic')

    def forward(self, x):
        """Append the grouped standard-deviation statistics as channels.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Input concatenated with the statistics along the channel
            axis, i.e. shape (N, C + channel_groups, H, W).
        """
        if self.sync_std:
            # gather the features of every rank along the batch axis
            gathered = torch.cat(AllGatherLayer.apply(x), dim=0)
            total = gathered.shape[0]

            # pick the window of ``sync_groups`` samples for this rank
            rank, world_size = get_dist_info()
            offset = (total // world_size) * rank
            # shift the window back when it would run past the tail
            if offset + self.sync_groups > total:
                begin = total - self.sync_groups
            else:
                begin = offset
            end = min(offset + self.sync_groups, total)

            x = gathered[begin:end]

        # batch size should be smaller than or equal to group size. Otherwise,
        # batch size should be divisible by the group size.
        assert x.shape[
            0] <= self.group_size or x.shape[0] % self.group_size == 0, (
                'Batch size be smaller than or equal '
                'to group size. Otherwise,'
                ' batch size should be divisible by the group size.'
                f'But got batch size {x.shape[0]},'
                f' group size {self.group_size}')
        assert x.shape[1] % self.channel_groups == 0, (
            '"channel_groups" must be divided by the feature channels. '
            f'channel_groups: {self.channel_groups}, '
            f'feature channels: {x.shape[1]}')

        n, c, h, w = x.shape
        groups = min(n, self.group_size)

        # [G, M, Gc, C', H, W]
        stats = x.reshape(groups, -1, self.channel_groups,
                          c // self.channel_groups, h, w)
        # biased std over the group axis -> [M, Gc, C', H, W]
        stats = (stats.var(dim=0, unbiased=False) + self.eps).sqrt()
        # average over C', H, W -> [M, Gc, 1, 1]
        stats = stats.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)
        # tile back over the batch and spatial axes -> [G * M, Gc, H, W]
        stats = stats.repeat(groups, 1, h, w)

        return torch.cat([x, stats], dim=1)