mmagic.models.losses.loss_comps.gan_loss_comps
Module Contents
Classes
GANLossComps – Define GAN loss.
- class mmagic.models.losses.loss_comps.gan_loss_comps.GANLossComps(gan_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0, loss_weight: float = 1.0)
Bases: torch.nn.Module
Define GAN loss.
- Parameters
gan_type (str) – Type of GAN loss. Supported values: ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’, ‘wgan-logistic-ns’.
real_label_val (float) – The value of the real label. Default: 1.0.
fake_label_val (float) – The value of the fake label. Default: 0.0.
loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight applies only to generators; it is always 1.0 for discriminators.
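The constructor contract described above can be summarized in a small, framework-free sketch: only the listed gan_type values are accepted, and loss_weight scales generator losses while discriminator losses are left unscaled. The helper names `SUPPORTED_GAN_TYPES`, `check_gan_type`, and `apply_loss_weight` are hypothetical and not part of mmagic. This sketch raises `ValueError` for an unsupported type; the exact exception mmagic raises is not stated here.

```python
# Hypothetical helpers illustrating the documented GANLossComps constructor
# contract; this is not mmagic code.
SUPPORTED_GAN_TYPES = ("vanilla", "lsgan", "wgan", "hinge", "wgan-logistic-ns")


def check_gan_type(gan_type: str) -> str:
    """Reject gan_type values outside the documented set (assumed behavior)."""
    if gan_type not in SUPPORTED_GAN_TYPES:
        raise ValueError(f"Unsupported gan_type: {gan_type!r}")
    return gan_type


def apply_loss_weight(loss: float, loss_weight: float, is_disc: bool) -> float:
    """Scale generator losses by loss_weight; discriminator losses stay unscaled."""
    return loss if is_disc else loss * loss_weight


check_gan_type("hinge")
print(apply_loss_weight(0.5, 2.0, is_disc=False))  # generator: 1.0
print(apply_loss_weight(0.5, 2.0, is_disc=True))   # discriminator: 0.5
```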
- _wgan_loss(input: torch.Tensor, target: bool) → torch.Tensor
Compute the WGAN loss.
- Parameters
input (Tensor) – Input tensor.
target (bool) – Target label.
- Returns
WGAN loss.
- Return type
Tensor
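The docstring leaves the WGAN formula implicit. A common formulation, assumed here, returns the negated mean prediction when the target is real and the plain mean when it is fake, so minimizing it pushes real scores up and fake scores down. The sketch below uses a sequence of floats in place of a `Tensor`; `wgan_loss` is a hypothetical stand-in for `_wgan_loss`.

```python
from statistics import fmean
from typing import Sequence


def wgan_loss(preds: Sequence[float], target: bool) -> float:
    """WGAN loss sketch: -mean(preds) for real targets, mean(preds) for fake."""
    mean = fmean(preds)
    return -mean if target else mean


critic_scores = [0.5, 1.5, -1.0, 3.0]  # mean is 1.0
print(wgan_loss(critic_scores, target=True))   # -1.0
print(wgan_loss(critic_scores, target=False))  # 1.0
```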
- _wgan_logistic_ns_loss(input: torch.Tensor, target: bool) → torch.Tensor
WGAN loss in logistic non-saturating mode.
This loss is widely used in StyleGANv2.
- Parameters
input (Tensor) – Input tensor.
target (bool) – Target label.
- Returns
WGAN loss in logistic non-saturating mode.
- Return type
Tensor
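For the logistic non-saturating mode, the usual formulation (assumed here) applies softplus(x) = log(1 + e^x) to the raw scores: mean(softplus(-x)) for real targets and mean(softplus(x)) for fake targets. The result is always non-negative, and for a generator the gradient stays large when the discriminator confidently rejects fakes. The sketch below uses floats and a numerically stable softplus; `softplus` and `wgan_logistic_ns_loss` are hypothetical names.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    # log(1 + e^x) == max(x, 0) + log(1 + e^-|x|), which never overflows.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def wgan_logistic_ns_loss(preds: Sequence[float], target: bool) -> float:
    """mean(softplus(-x)) for real targets, mean(softplus(x)) for fake ones."""
    sign = -1.0 if target else 1.0
    return sum(softplus(sign * p) for p in preds) / len(preds)


print(wgan_logistic_ns_loss([0.0, 0.0], target=True))  # log(2), about 0.6931
```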
- get_target_label(input: torch.Tensor, target_is_real: bool) → Union[bool, torch.Tensor]
Get target label.
- Parameters
input (Tensor) – Input tensor.
target_is_real (bool) – Whether the target is real or fake.
- Returns
Target label: a bool for wgan, otherwise a Tensor.
- Return type
(bool | Tensor)
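The rule above can be mirrored with a flat list standing in for the tensor. For the non-WGAN types, the label is assumed to have the same shape as `input` and to be filled with `real_label_val` or `fake_label_val`. Treating ‘wgan-logistic-ns’ like ‘wgan’ is also an assumption, based on both private loss methods taking a bool target. `get_target_label` here is a hypothetical free function, not the mmagic method.

```python
from typing import List, Sequence, Union


def get_target_label(
    preds: Sequence[float],
    target_is_real: bool,
    gan_type: str,
    real_label_val: float = 1.0,
    fake_label_val: float = 0.0,
) -> Union[bool, List[float]]:
    """Return a bool for WGAN-style losses, else a label list shaped like preds."""
    # Assumption: both WGAN variants consume the bool target directly.
    if gan_type in ("wgan", "wgan-logistic-ns"):
        return target_is_real
    value = real_label_val if target_is_real else fake_label_val
    return [value] * len(preds)


print(get_target_label([0.2, -0.7], True, "wgan"))    # True
print(get_target_label([0.2, -0.7], False, "lsgan"))  # [0.0, 0.0]
```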
- forward(input: torch.Tensor, target_is_real: bool, is_disc: bool = False) → torch.Tensor
Compute the GAN loss.
- Parameters
input (Tensor) – The input for the loss module, i.e., the network prediction.
target_is_real (bool) – Whether the target is real or fake.
is_disc (bool) – Whether the loss is computed for discriminators. Default: False.
- Returns
GAN loss value.
- Return type
Tensor
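A framework-free sketch of the forward dispatch follows. It assumes these standard formulations: binary cross-entropy on logits for ‘vanilla’; mean squared error against the label value for ‘lsgan’; the WGAN and logistic non-saturating losses for ‘wgan’ and ‘wgan-logistic-ns’; and, for ‘hinge’, mean(relu(1 - x)) or mean(relu(1 + x)) for real or fake discriminator targets and -mean(x) for the generator. These formulas and the helper name `gan_loss` are assumptions for illustration, not a copy of mmagic's code.

```python
import math
from typing import Sequence


def _softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def _mean(values) -> float:
    values = list(values)
    return sum(values) / len(values)


def gan_loss(
    preds: Sequence[float],
    target_is_real: bool,
    gan_type: str,
    is_disc: bool = False,
    real_label_val: float = 1.0,
    fake_label_val: float = 0.0,
    loss_weight: float = 1.0,
) -> float:
    """Hypothetical sketch of GANLossComps.forward on raw predictions (logits)."""
    label = real_label_val if target_is_real else fake_label_val
    if gan_type == "vanilla":
        # BCE with logits: y * softplus(-x) + (1 - y) * softplus(x).
        loss = _mean(label * _softplus(-x) + (1 - label) * _softplus(x) for x in preds)
    elif gan_type == "lsgan":
        loss = _mean((x - label) ** 2 for x in preds)
    elif gan_type == "wgan":
        loss = -_mean(preds) if target_is_real else _mean(preds)
    elif gan_type == "wgan-logistic-ns":
        sign = -1.0 if target_is_real else 1.0
        loss = _mean(_softplus(sign * x) for x in preds)
    elif gan_type == "hinge":
        if is_disc:
            # Real targets: relu(1 - x); fake targets: relu(1 + x).
            sign = -1.0 if target_is_real else 1.0
            loss = _mean(max(0.0, 1.0 + sign * x) for x in preds)
        else:
            # Generator: maximize the discriminator score.
            loss = -_mean(preds)
    else:
        raise ValueError(f"Unsupported gan_type: {gan_type!r}")
    # loss_weight scales generator losses only.
    return loss if is_disc else loss * loss_weight


scores = [1.0, 3.0]
print(gan_loss(scores, True, "hinge", is_disc=True))     # 0.0
print(gan_loss(scores, False, "hinge", is_disc=True))    # 3.0
print(gan_loss(scores, True, "hinge", loss_weight=0.5))  # -1.0
```

Keeping `loss_weight` out of the discriminator branch matches the parameter note above: the weight only rebalances the generator's adversarial term against its other losses.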