# Source code for mmagic.models.editors.mspie.mspie_stylegan2_discriminator
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmengine.model import BaseModule
from mmagic.registry import MODELS
from ..stylegan1 import EqualLinearActModule
from ..stylegan2 import ConvDownLayer, ModMBStddevLayer, ResBlock
@MODELS.register_module()
class MSStyleGAN2Discriminator(BaseModule):
    """StyleGAN2 discriminator with optional adaptive average pooling.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    On top of the StyleGAN2 layout, an adaptive average pooling layer can be
    inserted between the last convolution and the linear head, so that the
    linear head always receives a feature map of ``pool_size``.

    Args:
        in_size (int): The input size of images. Must be one of the
            resolutions in the internal channel table (4, 8, ..., 1024).
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        with_adaptive_pool (bool, optional): Whether to apply adaptive
            average pooling before the final linear layers.
            Defaults to False.
        pool_size (int | tuple[int], optional): Output size of the adaptive
            average pooling layer. An int ``s`` is treated as ``(s, s)``.
            Only used when ``with_adaptive_pool`` is True.
            Defaults to (2, 2).

    Raises:
        KeyError: If ``in_size`` is not a supported resolution.
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=None,
                 mbstd_cfg=None,
                 with_adaptive_pool=False,
                 pool_size=(2, 2)):
        super().__init__()

        # Build the defaults per call instead of sharing one mutable
        # list/dict object between every instance of the class.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        if mbstd_cfg is None:
            mbstd_cfg = dict(group_size=4, channel_groups=1)
        # ``nn.AdaptiveAvgPool2d`` accepts a single int for a square output;
        # normalise it so the linear input size can be computed below.
        if isinstance(pool_size, int):
            pool_size = (pool_size, pool_size)

        self.with_adaptive_pool = with_adaptive_pool
        self.pool_size = pool_size

        # Number of feature channels used at each spatial resolution.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # Fail early with an explicit message rather than a bare KeyError
        # (or an unrelated error from ``log2``) for unsupported sizes.
        if in_size not in channels:
            raise KeyError(
                f'Unsupported in_size {in_size!r}; expected one of '
                f'{sorted(channels)}.')

        log_size = int(np.log2(in_size))
        in_channels = channels[in_size]

        # A 1x1 conv lifts the 3-channel image to feature space, followed
        # by one ResBlock per resolution step from ``in_size`` down to 4.
        convs = [ConvDownLayer(3, channels[in_size], 1)]
        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]
            convs.append(ResBlock(in_channels, out_channel, blur_kernel))
            in_channels = out_channel
        self.convs = nn.Sequential(*convs)

        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)
        # NOTE(review): ``in_channels + 1`` assumes the minibatch-stddev
        # layer appends exactly one channel -- confirm this still holds
        # when ``channel_groups`` in ``mbstd_cfg`` is not 1.
        self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)

        if self.with_adaptive_pool:
            self.adaptive_pool = nn.AdaptiveAvgPool2d(pool_size)
            linear_in_channels = channels[4] * pool_size[0] * pool_size[1]
        else:
            # Without pooling the linear head is sized for a 4x4 feature
            # map coming out of ``final_conv``.
            linear_in_channels = channels[4] * 4 * 4

        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                linear_in_channels,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        x = self.convs(x)
        x = self.mbstd_layer(x)
        x = self.final_conv(x)
        if self.with_adaptive_pool:
            x = self.adaptive_pool(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)
        return x