Shortcuts

Source code for mmagic.models.editors.cyclegan.cyclegan

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.optim import OptimWrapperDict

from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import SampleList
from ...base_models import BaseTranslationModel
from ...utils import set_requires_grad
from .cyclegan_modules import GANImageBuffer


@MODELS.register_module()
[docs]class CycleGAN(BaseTranslationModel): """CycleGAN model for unpaired image-to-image translation. Ref: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks """ def __init__(self, *args, buffer_size=50, loss_config=dict(cycle_loss_weight=10., id_loss_weight=0.5), **kwargs): super().__init__(*args, **kwargs) # GAN image buffers self.image_buffers = dict() self.buffer_size = buffer_size for domain in self._reachable_domains: self.image_buffers[domain] = GANImageBuffer(self.buffer_size) self.loss_config = loss_config
[docs] def forward_test(self, img, target_domain, **kwargs): """Forward function for testing. Args: img (tensor): Input image tensor. target_domain (str): Target domain of output image. kwargs (dict): Other arguments. Returns: dict: Forward results. """ # This is a trick for CycleGAN # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa self.train() target = self.translation(img, target_domain=target_domain, **kwargs) results = dict(source=img, target=target) return results
[docs] def _get_disc_loss(self, outputs): """Backward function for the discriminators. Args: outputs (dict): Dict of forward results. Returns: dict: Discriminators' loss and loss dict. """ discriminators = self.get_module(self.discriminators) log_vars_d = dict() loss_d = 0 # GAN loss for discriminators['a'] for domain in self._reachable_domains: losses = dict() fake_img = self.image_buffers[domain].query( outputs[f'fake_{domain}']) fake_pred = discriminators[domain](fake_img.detach()) losses[f'loss_gan_d_{domain}_fake'] = F.mse_loss( fake_pred, 0. * torch.ones_like(fake_pred)) real_pred = discriminators[domain](outputs[f'real_{domain}']) losses[f'loss_gan_d_{domain}_real'] = F.mse_loss( real_pred, 1. * torch.ones_like(real_pred)) _loss_d, _log_vars_d = self.parse_losses(losses) _loss_d *= 0.5 loss_d += _loss_d log_vars_d[f'loss_gan_d_{domain}'] = _log_vars_d['loss'] * 0.5 return loss_d, log_vars_d
[docs] def _get_gen_loss(self, outputs): """Backward function for the generators. Args: outputs (dict): Dict of forward results. Returns: dict: Generators' loss and loss dict. """ generators = self.get_module(self.generators) discriminators = self.get_module(self.discriminators) losses = dict() # gan loss for domain in self._reachable_domains: # Identity reconstruction for generators outputs[f'identity_{domain}'] = generators[domain]( outputs[f'real_{domain}']) # GAN loss for generators fake_pred = discriminators[domain](outputs[f'fake_{domain}']) # LSGAN loss losses[f'loss_gan_g_{domain}'] = F.mse_loss( fake_pred, 1. * torch.ones_like(fake_pred)) # cycle loss loss_weight = self.loss_config['cycle_loss_weight'] losses['cycle_loss'] = 0. for domain in self._reachable_domains: losses['cycle_loss'] += loss_weight * F.l1_loss( outputs[f'cycle_{domain}'], outputs[f'real_{domain}'], reduction='mean') # id loss loss_weight = self.loss_config['id_loss_weight'] if loss_weight != 0.: losses['id_loss'] = 0. for domain in self._reachable_domains: losses['id_loss'] += loss_weight * F.l1_loss( outputs[f'identity_{domain}'], outputs[f'real_{domain}'], reduction='mean') loss_g, log_vars_g = self.parse_losses(losses) return loss_g, log_vars_g
[docs] def _get_opposite_domain(self, domain): """Get the opposite domain respect to the input domain. Args: domain (str): The input domain. Returns: str: The opposite domain. """ for item in self._reachable_domains: if item != domain: return item return None
[docs] def train_step(self, data: dict, optim_wrapper: OptimWrapperDict): """Training step function. Args: data_batch (dict): Dict of the input data batch. optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for the generators and discriminators. ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp. It is used to prepare for ``backward()`` in ddp. Defaults to None. running_status (dict | None, optional): Contains necessary basic information for training, e.g., iteration number. Defaults to None. Returns: dict: Dict of loss, information for logger, the number of samples\ and results for visualization. """ message_hub = MessageHub.get_current_instance() curr_iter = message_hub.get_info('iter') data = self.data_preprocessor(data, True) disc_optimizer_wrapper = optim_wrapper['discriminators'] inputs_dict = data['inputs'] outputs, log_vars = dict(), dict() # forward generators with disc_optimizer_wrapper.optim_context(self.discriminators): for target_domain in self._reachable_domains: # fetch data by domain source_domain = self.get_other_domains(target_domain)[0] img = inputs_dict[f'img_{source_domain}'] # translation process results = self( img, test_mode=False, target_domain=target_domain) outputs[f'real_{source_domain}'] = results['source'] outputs[f'fake_{target_domain}'] = results['target'] # cycle process results = self( results['target'], test_mode=False, target_domain=source_domain) outputs[f'cycle_{source_domain}'] = results['target'] # update discriminators disc_accu_iters = disc_optimizer_wrapper._accumulative_counts loss_d, log_vars_d = self._get_disc_loss(outputs) disc_optimizer_wrapper.update_params(loss_d) log_vars.update(log_vars_d) # generators, no updates to discriminator parameters. if ((curr_iter + 1) % (self.discriminator_steps * disc_accu_iters) == 0 and curr_iter >= self.disc_init_steps): set_requires_grad(self.discriminators, False) # update generator gen_optimizer_wrapper = optim_wrapper['generators'] with gen_optimizer_wrapper.optim_context(self.generators): loss_g, log_vars_g = self._get_gen_loss(outputs) gen_optimizer_wrapper.update_params(loss_g) log_vars.update(log_vars_g) set_requires_grad(self.discriminators, True) return log_vars
[docs] def test_step(self, data: dict) -> SampleList: """Gets the generated image of given data. Same as :meth:`val_step`. Args: data (dict): Data sampled from metric specific sampler. More details in `Metrics` and `Evaluator`. Returns: SampleList: A list of ``DataSample`` contain generated results. """ data = self.data_preprocessor(data) inputs_dict, data_samples = data['inputs'], data['data_samples'] outputs = {} for src_domain in self._reachable_domains: # Identity reconstruction for generators target_domain = self.get_other_domains(src_domain)[0] target = self.forward_test( inputs_dict[f'img_{src_domain}'], target_domain=target_domain)['target'] outputs[f'img_{target_domain}'] = target batch_sample_list = [] num_batches = next(iter(outputs.values())).shape[0] data_samples = data_samples.split() for idx in range(num_batches): gen_sample = DataSample() if data_samples: gen_sample.update(data_samples[idx]) for src_domain in self._reachable_domains: target_domain = self.get_other_domains(src_domain)[0] fake_img = outputs[f'img_{target_domain}'][idx] fake_img = self.data_preprocessor.destruct( fake_img, data_samples[idx], f'img_{target_domain}') setattr(gen_sample, f'fake_{target_domain}', fake_img) batch_sample_list.append(gen_sample) return batch_sample_list
[docs] def val_step(self, data: dict) -> SampleList: """Gets the generated image of given data. Same as :meth:`val_step`. Args: data (dict): Data sampled from metric specific sampler. More details in `Metrics` and `Evaluator`. Returns: SampleList: A list of ``DataSample`` contain generated results. """ data = self.data_preprocessor(data) inputs_dict, data_samples = data['inputs'], data['data_samples'] outputs = {} for src_domain in self._reachable_domains: # Identity reconstruction for generators target_domain = self.get_other_domains(src_domain)[0] target = self.forward_test( inputs_dict[f'img_{src_domain}'], target_domain=target_domain)['target'] outputs[f'img_{target_domain}'] = target batch_sample_list = [] num_batches = next(iter(outputs.values())).shape[0] data_samples = data_samples.split() for idx in range(num_batches): gen_sample = DataSample() if data_samples: gen_sample.update(data_samples[idx]) for src_domain in self._reachable_domains: target_domain = self.get_other_domains(src_domain)[0] fake_img = outputs[f'img_{target_domain}'][idx] fake_img = self.data_preprocessor.destruct( fake_img, data_samples[idx], f'img_{target_domain}') gen_sample.set_tensor_data({f'fake_{target_domain}': fake_img}) batch_sample_list.append(gen_sample) return batch_sample_list
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.