# Source code for mmagic.models.losses.loss_comps.gan_loss_comps
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmagic.registry import MODELS
@MODELS.register_module()
class GANLossComps(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported
            types.
    """

    def __init__(self,
                 gan_type: str,
                 real_label_val: float = 1.0,
                 fake_label_val: float = 0.0,
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # ``self.loss`` is called as ``loss(input, target)`` for every type
        # except 'hinge', where it is the ReLU applied inside ``forward``.
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor:
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss, ``-mean(input)`` when ``target`` is truthy
            and ``mean(input)`` otherwise.
        """
        return -input.mean() if target else input.mean()

    def _wgan_logistic_ns_loss(self, input: torch.Tensor,
                               target: bool) -> torch.Tensor:
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss, ``mean(softplus(-input))`` when ``target`` is
            truthy and ``mean(softplus(input))`` otherwise.
        """
        return F.softplus(-input).mean() if target else F.softplus(
            input).mean()

    def get_target_label(self, input: torch.Tensor,
                         target_is_real: bool) -> Union[bool, torch.Tensor]:
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
            return a Tensor of the same size as ``input`` filled with the
            real or fake label value.
        """
        if self.gan_type in ['wgan', 'wgan-logistic-ns']:
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        # Build the filled tensor in one step instead of ones * value.
        return input.new_full(input.size(), target_val)

    def forward(self,
                input: torch.Tensor,
                target_is_real: bool,
                is_disc: bool = False) -> torch.Tensor:
        """Compute the GAN loss for a prediction.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                # mean(relu(1 - input)) for real, mean(relu(1 + input)) for
                # fake.
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                # ``target_is_real`` is not used on this path.
                loss = -input.mean()
        else:  # other gan types
            # The label is only needed here; the hinge branch never uses it,
            # so it is not built for hinge.
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight