Shortcuts

Source code for mmagic.models.editors.stylegan3.stylegan3

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from mmengine import Config, MessageHub
from mmengine.optim import OptimWrapper
from torch import Tensor

from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import SampleList
from ...utils import get_module_device, get_valid_num_batches
from ..stylegan2 import StyleGAN2
from .stylegan3_utils import (apply_fractional_pseudo_rotation,
                              apply_fractional_rotation,
                              apply_fractional_translation,
                              apply_integer_translation, rotation_matrix)

# A model component may be given either as a config dict or as an
# already-constructed ``nn.Module``.
ModelType = Union[Dict, nn.Module]
@MODELS.register_module()
class StyleGAN3(StyleGAN2):
    """Implementation of `Alias-Free Generative Adversarial Networks`.

    Paper link: https://nvlabs-fi-cdn.nvidia.com/stylegan3/stylegan3-paper.pdf

    Detailed architecture can be found in
    :class:`~mmagic.models.editors.stylegan3.StyleGAN3Generator`
    and
    :class:`~mmagic.models.editors.stylegan2.StyleGAN2Discriminator`

    Args:
        generator (ModelType): Generator, given as a config dict or a
            module. Forwarded to :class:`StyleGAN2`.
        discriminator (Optional[ModelType]): Discriminator, given as a
            config dict or a module. Forwarded to :class:`StyleGAN2`.
            Defaults to None.
        data_preprocessor (Optional[Union[dict, Config]]): Forwarded to
            :class:`StyleGAN2`. Defaults to None.
        generator_steps (int): Forwarded to :class:`StyleGAN2`.
            Defaults to 1.
        discriminator_steps (int): Forwarded to :class:`StyleGAN2`.
            Defaults to 1.
        forward_kwargs (Optional[Dict]): Extra keyword arguments for the
            generator forward call. The entry under ``'disc'`` is used in
            :meth:`train_discriminator` and the entry under ``'gen'`` in
            :meth:`train_generator`. A missing ``'disc'`` entry defaults to
            ``dict(update_emas=True, force_fp32=False)`` and a missing
            ``'gen'`` entry to ``dict(force_fp32=False)``. The passed
            mapping itself is not modified. Defaults to None.
        ema_config (Optional[Dict]): Forwarded to :class:`StyleGAN2`.
            Defaults to None.
        loss_config (Optional[Dict]): Forwarded to :class:`StyleGAN2`.
            ``None`` is treated as an empty dict. Defaults to None.
    """

    # Defaults used by :meth:`sample_equivarience_pairs` for every key that
    # the caller's ``eq_cfg`` leaves out.
    _DEFAULT_EQ_CFG = dict(
        compute_eqt_int=False,
        compute_eqt_frac=False,
        compute_eqr=False,
        translate_max=0.125,
        rotate_max=1)

    def __init__(self,
                 generator: ModelType,
                 discriminator: Optional[ModelType] = None,
                 data_preprocessor: Optional[Union[dict, Config]] = None,
                 generator_steps: int = 1,
                 discriminator_steps: int = 1,
                 forward_kwargs: Optional[Dict] = None,
                 ema_config: Optional[Dict] = None,
                 loss_config: Optional[Dict] = None):
        # ``None`` stands in for "empty" so that no dict object is shared
        # between instances through the signature default.
        loss_config = dict() if loss_config is None else loss_config
        super().__init__(generator, discriminator, data_preprocessor,
                         generator_steps, discriminator_steps, ema_config,
                         loss_config)

        self.noise_size = getattr(self.generator, 'noise_size', 512)

        # Work on a shallow copy so the ``setdefault`` calls below do not
        # write into the mapping owned by the caller.
        forward_kwargs = (
            dict() if forward_kwargs is None else forward_kwargs.copy())
        forward_kwargs.setdefault(
            'disc', dict(update_emas=True, force_fp32=False))
        forward_kwargs.setdefault('gen', dict(force_fp32=False))
        self.forward_kwargs = forward_kwargs

    def _eval_step(self, data: dict) -> SampleList:
        """Shared implementation of :meth:`val_step` and :meth:`test_step`.

        Args:
            data (dict): Data sampled from metric specific sampler. More
                details in `Metrics` and `Evaluator`.

        Returns:
            SampleList: A list of ``DataSample`` contain generated results.
        """
        data = self.data_preprocessor(data)
        inputs_dict, data_samples = data['inputs'], data['data_samples']

        # hard code to compute equivarience: a ``mode`` entry carrying an
        # ``eq_cfg`` selects equivariance sampling instead of plain forward.
        if 'mode' in inputs_dict and 'eq_cfg' in inputs_dict['mode']:
            mode = inputs_dict['mode']
            batch_size = get_valid_num_batches(inputs_dict, data_samples)
            return self.sample_equivarience_pairs(
                batch_size,
                sample_mode=mode['sample_mode'],
                eq_cfg=mode['eq_cfg'],
                sample_kwargs=mode['sample_kwargs'])

        return self(inputs_dict, data_samples)

    def test_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data. Same as :meth:`val_step`.

        Args:
            data (dict): Data sampled from metric specific sampler. More
                details in `Metrics` and `Evaluator`.

        Returns:
            SampleList: A list of ``DataSample`` contain generated results.
        """
        return self._eval_step(data)

    def val_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data. Same as :meth:`test_step`.

        Args:
            data (dict): Data sampled from metric specific sampler. More
                details in `Metrics` and `Evaluator`.

        Returns:
            SampleList: A list of ``DataSample`` contain generated results.
        """
        return self._eval_step(data)

    def train_discriminator(self, inputs: dict, data_samples: DataSample,
                            optimizer_wrapper: OptimWrapper
                            ) -> Dict[str, Tensor]:
        """Train discriminator.

        Args:
            inputs (dict): Inputs from dataloader. Not read by this method.
            data_samples (DataSample): Data samples from dataloader. The
                real images are taken from ``data_samples.gt_img``.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        real_imgs = data_samples.gt_img
        num_batches = real_imgs.shape[0]

        noise_batch = self.noise_fn(num_batches=num_batches)
        # The generator is not updated in this step, so its forward pass
        # does not need gradients.
        with torch.no_grad():
            fake_imgs = self.generator(
                noise_batch,
                return_noise=False,
                **self.forward_kwargs['disc'])

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real, real_imgs)
        optimizer_wrapper.update_params(parsed_losses)

        # save ada info
        message_hub = MessageHub.get_current_instance()
        message_hub.update_info('disc_pred_real', disc_pred_real)
        return log_vars

    def train_generator(self, inputs: dict, data_samples: DataSample,
                        optimizer_wrapper: OptimWrapper
                        ) -> Dict[str, Tensor]:
        """Train generator.

        Args:
            inputs (dict): Inputs from dataloader. Not read by this method.
            data_samples (DataSample): Data samples from dataloader. Only
                their count (``len(data_samples)``) is used here.
            optimizer_wrapper (OptimWrapper): OptimWrapper instance used to
                update model parameters.

        Returns:
            Dict[str, Tensor]: A ``dict`` of tensor for logging.
        """
        num_batches = len(data_samples)

        noise = self.noise_fn(num_batches=num_batches)
        fake_imgs = self.generator(
            noise, return_noise=False, **self.forward_kwargs['gen'])

        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake, num_batches)

        optimizer_wrapper.update_params(parsed_loss)
        return log_vars

    @staticmethod
    def _attach_eq_result(batch_sample, key: str, diff, mask) -> None:
        """Store ``diff`` and ``mask`` under ``key`` on every sample.

        Each sample gets its own ``DataSample`` wrapper, but all wrappers
        hold the same batch-level ``diff`` and ``mask`` objects.
        """
        for data_sample in batch_sample:
            setattr(data_sample, key, DataSample(diff=diff, mask=mask))

    def sample_equivarience_pairs(self,
                                  batch_size: int,
                                  sample_mode: str = 'ema',
                                  eq_cfg: Optional[Dict] = None,
                                  sample_kwargs: Optional[Dict] = None
                                  ) -> SampleList:
        """Sample reference/transformed image pairs for equivariance metrics.

        A reference batch is synthesised with the generator's input
        transform set to identity. For every enabled check the transform is
        overwritten, the batch is synthesised again from the same latents,
        and the masked squared difference to the correspondingly
        transformed reference is recorded.

        Args:
            batch_size (int): Number of latent codes to sample.
            sample_mode (str): ``'orig'`` uses ``self.generator``; any other
                value uses ``self.generator_ema``. Defaults to ``'ema'``.
            eq_cfg (Optional[Dict]): Equivariance settings. Missing keys
                fall back to ``compute_eqt_int=False``,
                ``compute_eqt_frac=False``, ``compute_eqr=False``,
                ``translate_max=0.125`` and ``rotate_max=1``.
                Defaults to None (all defaults).
            sample_kwargs (Optional[Dict]): Extra keyword arguments passed
                to every ``generator.synthesis`` call. Defaults to None.

        Returns:
            SampleList: ``batch_size`` ``DataSample`` objects. For each
            enabled check, every sample carries an attribute ``eqt_int``,
            ``eqt_frac`` or ``eqr`` holding a ``DataSample`` with ``diff``
            and ``mask``. These are the tensors of the whole batch, not a
            per-sample slice.

        Raises:
            TypeError: If the selected generator does not expose
                ``synthesis.input.transform``.
        """
        eq_cfg = {**self._DEFAULT_EQ_CFG, **(eq_cfg or {})}
        sample_kwargs = {} if sample_kwargs is None else sample_kwargs

        generator = (
            self.generator if sample_mode == 'orig' else self.generator_ema)
        # Unwrap wrappers that keep the real model under ``.module``.
        if hasattr(generator, 'module'):
            generator = generator.module

        device = get_module_device(generator)
        identity_matrix = torch.eye(3, device=device)

        # Run mapping network.
        z = torch.randn([batch_size, self.noise_size], device=device)
        ws = generator.style_mapping(z=z)

        synthesis_input = getattr(
            getattr(generator, 'synthesis', None), 'input', None)
        transform_matrix = getattr(synthesis_input, 'transform', None)
        if transform_matrix is None:
            raise TypeError(
                'The generator does not provide `synthesis.input.transform`, '
                'which is required to sample equivariance pairs.')

        batch_sample = [DataSample() for _ in range(batch_size)]

        # The transform is edited in place below and every later synthesis
        # call on this generator reads it, so remember the current value and
        # put it back when done.
        initial_transform = transform_matrix.clone()
        try:
            # Generate reference image.
            transform_matrix[:] = identity_matrix
            orig = generator.synthesis(ws=ws, **sample_kwargs)

            # Integer translation (EQ-T).
            if eq_cfg['compute_eqt_int']:
                t = (torch.rand(2, device=device) * 2 -
                     1) * eq_cfg['translate_max']
                # Snap the offset to a multiple of ``1 / out_size``.
                t = (t * generator.out_size).round() / generator.out_size
                transform_matrix[:] = identity_matrix
                transform_matrix[:2, 2] = -t
                img = generator.synthesis(ws=ws, **sample_kwargs)
                ref, mask = apply_integer_translation(orig, t[0], t[1])
                diff = (ref - img).square() * mask
                self._attach_eq_result(batch_sample, 'eqt_int', diff, mask)

            # Fractional translation (EQ-T_frac).
            if eq_cfg['compute_eqt_frac']:
                t = (torch.rand(2, device=device) * 2 -
                     1) * eq_cfg['translate_max']
                transform_matrix[:] = identity_matrix
                transform_matrix[:2, 2] = -t
                img = generator.synthesis(ws=ws, **sample_kwargs)
                ref, mask = apply_fractional_translation(orig, t[0], t[1])
                diff = (ref - img).square() * mask
                self._attach_eq_result(batch_sample, 'eqt_frac', diff, mask)

            # Rotation (EQ-R).
            if eq_cfg['compute_eqr']:
                angle = (torch.rand([], device=device) * 2 - 1) * (
                    eq_cfg['rotate_max'] * np.pi)
                transform_matrix[:] = rotation_matrix(-angle)
                img = generator.synthesis(ws=ws, **sample_kwargs)
                ref, ref_mask = apply_fractional_rotation(orig, angle)
                pseudo, pseudo_mask = apply_fractional_pseudo_rotation(
                    img, angle)
                mask = ref_mask * pseudo_mask
                diff = (ref - pseudo).square() * mask
                self._attach_eq_result(batch_sample, 'eqr', diff, mask)
        finally:
            transform_matrix[:] = initial_transform

        return batch_sample
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.