# Source code for mmagic.models.editors.stylegan2.stylegan2_generator
# Copyright (c) OpenMMLab. All rights reserved.
import random
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.runner.amp import autocast
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from mmagic.registry import MODELS
from ...utils import get_module_device
from ..pggan import PixelNorm
from ..stylegan1 import (ConstantInput, EqualLinearActModule, get_mean_latent,
style_mixing)
from .stylegan2_modules import ModulatedStyleConv, ModulatedToRGB
@MODELS.register_module('StyleGANv2Generator')
@MODELS.register_module()
class StyleGAN2Generator(BaseModule):
    r"""StyleGAN2 Generator.

    In StyleGAN2, we use a static architecture composed of a style mapping
    module and a number of convolutional style blocks. More details can be
    found in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

    You can load a pretrained model by passing information into the
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: https://download.openmmlab.com/mmediting/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: https://download.openmmlab.com/mmediting/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: https://download.openmmlab.com/mmediting/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: https://download.openmmlab.com/mmediting/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: https://download.openmmlab.com/mmediting/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the ema model, you can just use the following code:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path=ckpt_http,
                                            prefix='generator_ema'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with a local path. If you just want to load the original
    generator (not the ema model), please set the prefix with 'generator'.

    Note that our implementation allows to generate BGR image, while the
    original StyleGAN2 outputs RGB images by default. Thus, we provide the
    ``bgr2rgb`` argument to convert the image space.

    Args:
        out_size (int): The output size of the StyleGAN2 generator. The
            number of synthesis stages is ``int(log2(out_size))`` and the
            per-resolution channel table only covers 4 to 1024.
        style_channels (int): The number of channels for style code.
        out_channels (int): The number of channels for output. Defaults to 3.
        noise_size (int, optional): The size of (number of channels) the input
            noise. If not passed, will be set the same value as
            :attr:`style_channels`. Defaults to None.
        cond_size (int, optional): The size of the conditional input. If not
            passed or less than 1, no conditional embedding will be used.
            Defaults to None.
        cond_mapping_channels (int, optional): The channels of the
            conditional mapping layers. If not passed, will use the same value
            as :attr:`style_channels`. Defaults to None.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we adopt mixing style mode in default. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        norm_eps (float, optional): The ``eps`` passed to the ``PixelNorm``
            layers of this module. Defaults to 1e-6.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to ``0.9``.
        update_mean_latent_with_ema (bool, optional): Whether update mean
            latent code (w) with EMA. Defaults to False.
        w_avg_beta (float, optional): The value used for update `w_avg`.
            Defaults to 0.998.
        num_fp16_scales (int, optional): The number of resolutions to use auto
            fp16 training. Different from ``fp16_enabled``, this argument
            allows users to adopt FP16 training only in several blocks.
            This behaviour is much more similar to the official implementation
            by Tero. Defaults to 0.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. If this flag is `True`, every style block is built with
            fp16 enabled and the synthesis part of :meth:`forward` runs
            inside ``autocast``. Defaults to False.
        bgr2rgb (bool, optional): If True, the channel dimension of the
            generated image is flipped before it is returned.
            Defaults to False.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
        fixed_noise (bool, optional): Forwarded as ``fixed_noise`` to the
            4x4 ``ModulatedStyleConv`` only. Defaults to False.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 out_channels=3,
                 noise_size=None,
                 cond_size=None,
                 cond_mapping_channels=None,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 norm_eps=1e-6,
                 mix_prob=0.9,
                 update_mean_latent_with_ema=False,
                 w_avg_beta=0.998,
                 num_fp16_scales=0,
                 fp16_enabled=False,
                 bgr2rgb=False,
                 pretrained=None,
                 fixed_noise=False):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.out_channels = out_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        # `_default_style_mode` remembers the training-time mode so that
        # `train()` can restore it after an `eval()` switch.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.num_fp16_scales = num_fp16_scales
        self.fp16_enabled = fp16_enabled
        self.bgr2rgb = bgr2rgb
        self.noise_size = style_channels if noise_size is None else noise_size
        self.cond_size = cond_size

        if self.cond_size is not None and self.cond_size > 0:
            if cond_mapping_channels is None:
                cond_mapping_channels = style_channels
            self.embed = EqualLinearActModule(cond_size, cond_mapping_channels)
            # NOTE: conditional input is passed, do 2nd moment norm for
            # embedding and noise input respectively, therefore mapping layer
            # start with FC layer
            mapping_layers = []
        else:
            cond_mapping_channels = 0
            # NOTE: conditional input is not passed, put 2nd moment norm at
            # the start of mapping layers
            mapping_layers = [PixelNorm(eps=norm_eps)]
        in_feat = cond_mapping_channels + self.noise_size

        # define pixel norm (applied explicitly in `forward` when a
        # conditional input is used)
        self.pixel_norm = PixelNorm(eps=norm_eps)

        # define style mapping layers
        for idx in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    in_feat if idx == 0 else style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))
        self.style_mapping = nn.Sequential(*mapping_layers)

        # resolution -> number of feature channels
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # constant input layer
        self.constant_input = ConstantInput(self.channels[4])
        # 4x4 stage
        # NOTE(review): `fixed_noise` is only forwarded to this first conv,
        # not to the higher-resolution convs below -- confirm this is
        # intended.
        self.conv1 = ModulatedStyleConv(
            self.channels[4],
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel,
            fp16_enabled=fp16_enabled,
            fixed_noise=fixed_noise)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4],
            style_channels,
            out_channels=out_channels,
            upsample=False,
            fp16_enabled=fp16_enabled)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        blk_in_channels_ = self.channels[4]  # in channels of the conv blocks

        for i in range(3, self.log_size + 1):
            blk_out_channels_ = self.channels[2**i]

            # If `fp16_enabled` is True, all of layers will be run in auto
            # FP16. In the case of `num_fp16_scales` > 0, only partial
            # layers will be run in fp16.
            _use_fp16 = (self.log_size - i) < num_fp16_scales or fp16_enabled

            # each resolution contributes an upsampling conv, a plain conv
            # and a to-RGB layer
            self.convs.append(
                ModulatedStyleConv(
                    blk_in_channels_,
                    blk_out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    fp16_enabled=_use_fp16))
            self.convs.append(
                ModulatedStyleConv(
                    blk_out_channels_,
                    blk_out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel,
                    fp16_enabled=_use_fp16))
            self.to_rgbs.append(
                ModulatedToRGB(
                    blk_out_channels_,
                    style_channels,
                    out_channels=out_channels,
                    upsample=True,
                    fp16_enabled=_use_fp16))

            blk_in_channels_ = blk_out_channels_

        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises: one 4x4 map for the first
        # conv, then two maps per higher resolution
        for layer_idx in range(self.num_injected_noises):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

        if (self.cond_size is not None
                and self.cond_size > 0) or update_mean_latent_with_ema:
            # Due to `get_mean_latent` cannot handle conditional input,
            # assign avg style code here and update with EMA.
            self.register_buffer('w_avg', torch.zeros([style_channels]))
            self.w_avg_beta = w_avg_beta
            mmengine.print_log('Mean latent code (w) is updated with EMA.')

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load weights for this generator from a checkpoint.

        Args:
            ckpt_path (str): Path or URL of the checkpoint.
            prefix (str, optional): Prefix selecting the sub state dict to
                load (e.g. 'generator_ema'). Defaults to ''.
            map_location (str, optional): Passed to the checkpoint loader.
                Defaults to 'cpu'.
            strict (bool, optional): Passed to ``load_state_dict``.
                Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmengine.print_log(f'Load pretrained model from {ckpt_path}')

    def train(self, mode=True):
        """Switch train/eval mode and the matching style mode.

        In training mode the style mode given at construction time is
        restored; in evaluation mode ``eval_style_mode`` is used. A log
        message is printed only when the style mode actually changes.

        Args:
            mode (bool, optional): Whether to set training mode.
                Defaults to True.

        Returns:
            The return value of the parent class ``train``.
        """
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmengine.print_log(
                    f'Switch to train style mode: {self._default_style_mode}')
            self.default_style_mode = self._default_style_mode
        else:
            if self.default_style_mode != self.eval_style_mode:
                mmengine.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}')
            self.default_style_mode = self.eval_style_mode

        return super(StyleGAN2Generator, self).train(mode)

    def make_injected_noise(self):
        """make noises that will be injected into feature maps.

        Returns:
            list[Tensor]: List of layer-wise noise tensor.
        """
        device = get_module_device(self)

        # one 4x4 map for the first conv, then two maps per resolution
        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        If the ``w_avg`` buffer has been registered, it is returned
        directly; otherwise the mean latent is estimated by sampling.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        if hasattr(self, 'w_avg'):
            mmengine.print_log('Get latent code (w) which is updated by EMA.')
            return self.w_avg
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7):
        """Run the shared ``style_mixing`` helper with this generator.

        All arguments are forwarded unchanged, together with this
        generator's ``style_channels``.

        Args:
            n_source (int): Forwarded as ``n_source``.
            n_target (int): Forwarded as ``n_target``.
            inject_index (int, optional): Forwarded as ``inject_index``.
                Defaults to 1.
            truncation_latent (torch.Tensor, optional): Forwarded as
                ``truncation_latent``. Defaults to None.
            truncation (float, optional): Forwarded as ``truncation``.
                Defaults to 0.7.

        Returns:
            The return value of the ``style_mixing`` helper.
        """
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels)

    def forward(self,
                styles,
                label=None,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                add_noise=True,
                randomize_noise=True,
                update_ws=False,
                return_features=False,
                feat_idx=5,
                return_latent_only=False):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler. The last dimension of every tensor must equal
                ``style_channels`` (latent input) or ``noise_size`` (noise
                input).
            label (torch.Tensor, optional): Conditional inputs for the
                generator. Required when ``cond_size`` > 0 and the input is
                not a latent. Defaults to None.
            num_batches (int, optional): The number of batch size. Must be
                positive when noise has to be sampled inside this function.
                Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (list[torch.Tensor | None] | None, optional): A
                per-layer sequence of noise maps, indexed in the same order
                as the ``injected_noise_{i}`` buffers. If given, the random
                noise will be fixed as this input injected noise.
                Defaults to None.
            add_noise (bool): Whether apply noise injection. Defaults to True.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.
            update_ws (bool): Whether update latent code with EMA. Only work
                when `w_avg` is registered. Defaults to False.
            return_features (bool, optional): If True, ``feats[feat_idx]`` is
                added to the output dict under the key ``feats``. Only takes
                effect when ``return_latents`` or ``return_noise`` is also
                True. Defaults to False.
            feat_idx (int, optional): Index into the list of per-resolution
                feature maps (index 0 is the 4x4 stage) selecting the feature
                returned by ``return_features``. Defaults to 5.
            return_latent_only (bool, optional): If True, return the
                per-layer latent tensor and skip image synthesis.
                Defaults to False.

        Returns:
            torch.Tensor | dict: Generated image tensor, or the latent tensor
            if ``return_latent_only`` is True, or a dictionary containing
            ``fake_img``, ``latent``, ``inject_index`` and ``noise_batch``
            (plus ``feats`` when requested).
        """
        input_dim = self.style_channels if input_is_latent else self.noise_size
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            # Check the last (channel) dimension, exactly as done for list
            # inputs below. Checking dim 1 would wrongly reject a per-layer
            # latent of shape (N, num_latents, C), which the "no style
            # mixing" branch further down explicitly supports.
            assert styles.shape[-1] == input_dim
            styles = [styles]
        elif mmengine.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == input_dim
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, input_dim)) for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, input_dim))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, input_dim)) for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, input_dim))]
            styles = [s.to(device) for s in styles]

        # no amp for style-mapping and condition-embedding
        if not input_is_latent:
            # keep a reference to the raw (un-normalised) noise for output
            noise_batch = styles
            use_cond = self.cond_size is not None and self.cond_size > 0
            if use_cond:
                assert label is not None, (
                    "'cond_size' is set, so 'label' must be passed.")
                assert label.shape[1] == self.cond_size
                # NOTE: If conditional input is passed, do norm for cond
                # embedding and noise input respectively
                # do pixel_norm (2nd_moment_norm) to cond embedding
                embedding = self.pixel_norm(self.embed(label))
                # do pixel_norm (2nd_moment_norm) to noise input
                styles = [self.pixel_norm(s) for s in styles]
                styles = [
                    self.style_mapping(torch.cat([s, embedding], dim=1))
                    for s in styles
                ]
            else:
                styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        # update w_avg during training, if need
        if hasattr(self, 'w_avg') and self.training and update_ws:
            # only update w_avg with the first style code;
            # lerp(mean, w_avg, beta) == beta * w_avg + (1 - beta) * mean
            self.w_avg.copy_(styles[0].detach().mean(
                dim=0).lerp(self.w_avg, self.w_avg_beta))

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]

        # use truncation trick
        if truncation < 1:
            if truncation_latent is None:
                # calculate truncation latent on the fly and cache it on the
                # module for later calls
                if not hasattr(self, 'truncation_latent'):
                    self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                # broadcast the single style code to every layer
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # already a per-layer latent
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        if return_latent_only:
            return latent

        feats = []
        with autocast(enabled=self.fp16_enabled):
            # 4x4 stage
            out = self.constant_input(latent)
            if self.fp16_enabled:
                out = out.to(torch.float16)
            out = self.conv1(
                out,
                latent[:, 0],
                noise=injected_noise[0],
                add_noise=add_noise)
            feats.append(out)
            skip = self.to_rgb1(out, latent[:, 1])
            _index = 1

            # 8x8 ---> higher resolutions
            for up_conv, conv, noise1, noise2, to_rgb in zip(
                    self.convs[::2], self.convs[1::2], injected_noise[1::2],
                    injected_noise[2::2], self.to_rgbs):
                out = up_conv(
                    out, latent[:, _index], noise=noise1, add_noise=add_noise)
                out = conv(
                    out,
                    latent[:, _index + 1],
                    noise=noise2,
                    add_noise=add_noise)
                feats.append(out)
                skip = to_rgb(out, latent[:, _index + 2], skip)
                _index += 2

        # make sure the output image is torch.float32 to avoid RunTime Error
        # in other modules
        img = skip.to(torch.float32)

        if self.bgr2rgb:
            img = torch.flip(img, dims=[1])

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            if return_features:
                output_dict['feats'] = feats[feat_idx]
            return output_dict

        return img