
mmagic.models.losses.loss_comps.gan_loss_comps

Module Contents

Classes

GANLossComps

Define GAN loss.

class mmagic.models.losses.loss_comps.gan_loss_comps.GANLossComps(gan_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0, loss_weight: float = 1.0)

Bases: torch.nn.Module

Define GAN loss.

Parameters
  • gan_type (str) – Type of GAN loss. Supported values: ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’, ‘wgan-logistic-ns’.

  • real_label_val (float) – The value for real label. Default: 1.0.

  • fake_label_val (float) – The value for fake label. Default: 0.0.

  • loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight applies only to the generator loss; the discriminator loss is always weighted by 1.0.
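Each gan_type string selects a different criterion. The sketch below is a hypothetical dispatch helper, assuming the conventional choice for each name (sigmoid cross-entropy on logits for ‘vanilla’, mean squared error for ‘lsgan’, signed mean critic score for ‘wgan’, softplus for ‘wgan-logistic-ns’). The helper name select_criterion is illustrative only and is not taken from the mmagic source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def select_criterion(gan_type: str):
    """Hypothetical helper: map a gan_type string to a loss callable.

    The mapping reflects the usual definition of each GAN variant and is
    an assumption, not the mmagic implementation.
    """
    if gan_type == 'vanilla':
        # Standard GAN: sigmoid cross-entropy computed on raw logits.
        return nn.BCEWithLogitsLoss()
    if gan_type == 'lsgan':
        # Least-squares GAN: mean squared error to the label values.
        return nn.MSELoss()
    if gan_type == 'wgan':
        # Critic: raise scores on real (target=True), lower them on fake.
        return lambda pred, target: -pred.mean() if target else pred.mean()
    if gan_type == 'wgan-logistic-ns':
        # Logistic non-saturating variant: softplus of the signed score.
        return lambda pred, target: F.softplus(-pred if target else pred).mean()
    if gan_type == 'hinge':
        # Hinge depends on whether the generator or the discriminator is
        # being trained, so there is no single criterion to return here.
        return None
    raise NotImplementedError(f'GAN type {gan_type} is not implemented.')
```

For example, select_criterion('vanilla')(logits, torch.ones_like(logits)) would give a discriminator loss on real samples with the default real_label_val of 1.0.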

_wgan_loss(input: torch.Tensor, target: bool) → torch.Tensor

WGAN loss.

Parameters
  • input (Tensor) – Input tensor.

  • target (bool) – Target label: True for real, False for fake.

Returns

WGAN loss.

Return type

Tensor
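The WGAN critic loss is the mean critic score with its sign chosen by the target: minimizing -mean(D(x)) raises scores on real samples, and minimizing mean(D(x)) lowers them on fake ones. A minimal standalone sketch, assuming target=True means real; the function name wgan_loss is hypothetical and stands in for the private method.

```python
import torch


def wgan_loss(pred: torch.Tensor, target: bool) -> torch.Tensor:
    """Hypothetical standalone WGAN loss.

    pred: raw critic scores (no sigmoid applied).
    target: True if the samples should be treated as real.
    """
    # Real targets: maximize the score, i.e. minimize its negation.
    # Fake targets: minimize the score directly.
    return -pred.mean() if target else pred.mean()


scores = torch.tensor([0.5, 1.5, -1.0])
real_loss = wgan_loss(scores, True)   # -(0.5 + 1.5 - 1.0) / 3
fake_loss = wgan_loss(scores, False)  #  (0.5 + 1.5 - 1.0) / 3
```

Because no sigmoid or log is applied, the critic output is unbounded; WGAN training normally relies on a separate Lipschitz constraint (weight clipping or a gradient penalty) that this loss alone does not provide.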

_wgan_logistic_ns_loss(input: torch.Tensor, target: bool) → torch.Tensor

WGAN loss in the logistic non-saturating form.

This loss is commonly used to train StyleGANv2.

Parameters
  • input (Tensor) – Input tensor.

  • target (bool) – Target label: True for real, False for fake.

Returns

WGAN logistic non-saturating loss.

Return type

Tensor
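In its usual StyleGAN2 form, the logistic non-saturating loss applies softplus to the critic output: softplus(-D(x)) for real targets and softplus(D(x)) for fake ones, which equals binary cross-entropy on logits with labels 1 and 0. A minimal sketch under that assumption; wgan_logistic_ns_loss is a hypothetical name and is not necessarily how mmagic implements the private method.

```python
import torch
import torch.nn.functional as F


def wgan_logistic_ns_loss(pred: torch.Tensor, target: bool) -> torch.Tensor:
    """Hypothetical logistic non-saturating loss on raw critic scores."""
    # softplus(-x) = -log(sigmoid(x)): penalizes low scores on real samples.
    # softplus(x)  = -log(1 - sigmoid(x)): penalizes high scores on fake samples.
    return F.softplus(-pred).mean() if target else F.softplus(pred).mean()


# Sanity check: matches BCE-with-logits against constant 1/0 labels.
scores = torch.randn(8)
bce_real = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores))
bce_fake = F.binary_cross_entropy_with_logits(scores, torch.zeros_like(scores))
```

The generator minimizes softplus(-D(G(z))) instead of the saturating -softplus(D(G(z))), so its gradient does not vanish when the discriminator confidently rejects fakes; that is what "non-saturating" refers to.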

get_target_label(input: torch.Tensor, target_is_real: bool) → Union[bool, torch.Tensor]

Get target label.

Parameters
  • input (Tensor) – Input tensor.

  • target_is_real (bool) – Whether the target is real or fake.

Returns

The target label: a bool for the WGAN-based types (‘wgan’ and ‘wgan-logistic-ns’), otherwise a Tensor.

Return type

(bool | Tensor)
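A minimal sketch of how such a target label can be built, written as a hypothetical free function that receives gan_type and the label values explicitly (in the class they are constructor arguments). For label-based losses the target is a tensor shaped like the prediction and filled with real_label_val or fake_label_val; for the WGAN-based types only the bool is needed to pick the sign.

```python
from typing import Union

import torch


def get_target_label(pred: torch.Tensor, target_is_real: bool, gan_type: str,
                     real_label_val: float = 1.0,
                     fake_label_val: float = 0.0) -> Union[bool, torch.Tensor]:
    """Hypothetical standalone version of the label lookup."""
    # WGAN-style losses only need to know which sign to apply.
    if gan_type in ('wgan', 'wgan-logistic-ns'):
        return target_is_real
    # Other losses compare element-wise against a constant tensor
    # with the same shape, dtype, and device as the prediction.
    val = real_label_val if target_is_real else fake_label_val
    return torch.full_like(pred, val)


# One-sided label smoothing: real targets of 0.9 instead of 1.0.
smoothed = get_target_label(torch.zeros(2, 3), True, 'lsgan', real_label_val=0.9)
```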

forward(input: torch.Tensor, target_is_real: bool, is_disc: bool = False) → torch.Tensor

Compute the GAN loss.

Parameters
  • input (Tensor) – The input for the loss module, i.e., the network prediction.

  • target_is_real (bool) – Whether the target is real or fake.

  • is_disc (bool) – Whether the loss is computed for the discriminator. Default: False.

Returns

GAN loss value.

Return type

Tensor
