Source code for mmagic.models.editors.stylegan1.stylegan_utils
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ...utils import get_module_device
@torch.no_grad()
def get_mean_latent(generator, num_samples=4096, bs_per_repeat=1024):
    """Estimate the mean latent of the W space of a Style-based GAN.

    Random noise is drawn in batches of ``bs_per_repeat``, passed through
    ``generator.style_mapping`` and averaged over the batch dimension. The
    per-batch averages are then averaged over all batches.

    Args:
        generator (BaseModule): Generator of a Style-based GAN. It must
            expose ``style_mapping`` and ``style_channels``.
        num_samples (int, optional): Total number of noise samples.
            Defaults to 4096.
        bs_per_repeat (int, optional): Number of noise samples per batch.
            Defaults to 1024.

    Returns:
        Tensor: Mean latent of this generator, with the batch dimension
        kept (``keepdim=True``).

    Raises:
        AssertionError: If ``num_samples`` is not an exact multiple of
            ``bs_per_repeat``.
    """
    device = get_module_device(generator)

    # Only whole batches are supported, so the split must be exact.
    n_repeat = num_samples // bs_per_repeat
    assert n_repeat * bs_per_repeat == num_samples

    running_sum = None
    for _ in range(n_repeat):
        # Noise is sampled on the CPU and then moved to the module device.
        noise = torch.randn(bs_per_repeat, generator.style_channels)
        latents = generator.style_mapping(noise.to(device))
        batch_mean = latents.mean(0, keepdim=True)
        if running_sum is None:
            running_sum = batch_mean
        else:
            running_sum = running_sum + batch_mean

    # Average the per-batch means over the number of batches.
    return running_sum / float(n_repeat)
@torch.no_grad()
def style_mixing(generator,
                 n_source,
                 n_target,
                 inject_index=1,
                 truncation_latent=None,
                 truncation=0.7,
                 style_channels=512,
                 **kwargs):
    """Generate a batch of images for a style-mixing grid.

    Random source and target codes are sampled, images are generated for
    each of them, and then, for every target code, the generator is called
    with the pair ``[target code repeated n_source times, source codes]``
    together with ``inject_index`` to produce the mixed images.

    The returned batch is ordered as::

        [blank, source_0 .. source_{n_source-1},
         target_0, mixed_0_0 .. mixed_0_{n_source-1},
         target_1, mixed_1_0 .. mixed_1_{n_source-1},
         ...]

    so it holds ``(n_target + 1) * (n_source + 1)`` images, presumably
    meant to be laid out as a grid with ``n_source + 1`` images per row.

    Args:
        generator (nn.Module): Generator that accepts either a single
            code tensor or a list of two code tensors, plus the keyword
            arguments ``truncation_latent``, ``truncation`` and
            ``inject_index``. It is expected to return an image tensor of
            shape ``(N, C, H, W)``.
        n_source (int): Number of source codes to sample.
        n_target (int): Number of target codes to sample.
        inject_index (int, optional): Forwarded to the generator for the
            mixed images. Defaults to 1.
        truncation_latent (Tensor, optional): Forwarded to the generator.
            Defaults to None.
        truncation (float, optional): Forwarded to the generator.
            Defaults to 0.7.
        style_channels (int, optional): Size of each sampled code.
            Defaults to 512.
        **kwargs: Extra keyword arguments forwarded to every generator
            call.

    Returns:
        Tensor: All images concatenated along the batch dimension in the
        order described above.
    """
    device = get_module_device(generator)

    # Sampling order (source first, then target) is kept so that seeded
    # runs reproduce the same codes.
    source_code = torch.randn(n_source, style_channels).to(device)
    target_code = torch.randn(n_target, style_channels).to(device)

    source_image = generator(
        source_code,
        truncation_latent=truncation_latent,
        truncation=truncation,
        **kwargs)
    target_image = generator(
        target_code,
        truncation_latent=truncation_latent,
        truncation=truncation,
        **kwargs)

    # Blank tile for the top-left corner, filled with -1. The channel
    # count follows the generated images instead of being fixed to 3, so
    # non-RGB generators can be concatenated as well.
    channels = source_image.shape[1]
    h, w = source_image.shape[-2:]
    blank = torch.full((1, channels, h, w), -1.0, device=device)

    images = [blank, source_image]
    for i in range(n_target):
        # Pair the i-th target code with every source code.
        repeated_target = target_code[i].unsqueeze(0).repeat(n_source, 1)
        mixed = generator(
            [repeated_target, source_code],
            truncation_latent=truncation_latent,
            truncation=truncation,
            inject_index=inject_index,
            **kwargs)
        images.append(target_image[i].unsqueeze(0))
        images.append(mixed)

    return torch.cat(images, 0)