mmagic.models.losses.gan_loss
Module Contents
Classes
GANLoss – Define GAN loss.
GaussianBlur – A Gaussian filter which blurs a given tensor with a two-dimensional Gaussian kernel.
GradientPenaltyLoss – Gradient penalty loss for WGAN-GP.
DiscShiftLoss – Disc shift loss.
Functions
gradient_penalty_loss – Calculate gradient penalty for WGAN-GP.
disc_shift_loss – Disc shift loss.
r1_gradient_penalty_loss – Calculate R1 gradient penalty for WGAN-GP.
gen_path_regularizer – Generator path regularization.
- class mmagic.models.losses.gan_loss.GANLoss(gan_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0, loss_weight: float = 1.0)
Bases: torch.nn.Module
Define GAN loss.
- Parameters
gan_type (str) – Type of GAN loss. Supported values: ‘vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’, ‘l1’.
real_label_val (float) – The value for real label. Default: 1.0.
fake_label_val (float) – The value for fake label. Default: 0.0.
loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight applies only to generators; for discriminators it is always 1.0.
- _wgan_loss(input: torch.Tensor, target: bool) → torch.Tensor
WGAN loss.
- Parameters
input (Tensor) – Input tensor.
target (bool) – Target label.
- Returns
WGAN loss.
- Return type
Tensor
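The WGAN branch needs no label tensor: the boolean target only selects the sign of the mean critic score. A minimal sketch, assuming `_wgan_loss` follows the usual WGAN convention (the function name `wgan_loss` is hypothetical, not part of mmagic):

```python
import torch


def wgan_loss(pred: torch.Tensor, target_is_real: bool) -> torch.Tensor:
    """Hypothetical stand-in for GANLoss._wgan_loss (usual WGAN convention)."""
    # Real target: minimizing -mean(pred) pushes critic scores up.
    # Fake target: minimizing +mean(pred) pushes critic scores down.
    return -pred.mean() if target_is_real else pred.mean()


scores = torch.tensor([1.0, 2.0, 3.0])
print(wgan_loss(scores, target_is_real=True).item())   # -2.0
print(wgan_loss(scores, target_is_real=False).item())  # 2.0
```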
- get_target_label(input: torch.Tensor, target_is_real: bool) → Union[bool, torch.Tensor]
Get target label.
- Parameters
input (Tensor) – Input tensor.
target_is_real (bool) – Whether the target is real or fake.
- Returns
Target label. A bool is returned for ‘wgan’; otherwise, a Tensor is returned.
- Return type
(bool | Tensor)
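Outside the WGAN case, the target label is a tensor shaped like the prediction and filled with `real_label_val` or `fake_label_val`. The sketch below illustrates that behavior; the helper `get_target_label` here is hypothetical and takes `gan_type` and the label values as explicit arguments instead of reading them from the module:

```python
from typing import Union

import torch


def get_target_label(pred: torch.Tensor,
                     target_is_real: bool,
                     gan_type: str = 'vanilla',
                     real_label_val: float = 1.0,
                     fake_label_val: float = 0.0) -> Union[bool, torch.Tensor]:
    """Hypothetical sketch of GANLoss.get_target_label."""
    if gan_type == 'wgan':
        # WGAN has no label tensor; the bool only selects the sign.
        return target_is_real
    value = real_label_val if target_is_real else fake_label_val
    # Same shape, dtype and device as the prediction.
    return torch.full_like(pred, value)


pred = torch.randn(2, 1, 4, 4)
label = get_target_label(pred, True, real_label_val=0.9)
print(label.shape)  # torch.Size([2, 1, 4, 4])
```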
- forward(input: torch.Tensor, target_is_real: bool, is_disc: bool = False, mask: Optional[torch.Tensor] = None) → torch.Tensor
- Parameters
input (Tensor) – The input for the loss module, i.e., the network prediction.
target_is_real (bool) – Whether the target is real or fake.
is_disc (bool) – Whether the loss is computed for the discriminator. Default: False.
mask (Tensor) – The mask tensor. Default: None.
- Returns
GAN loss value.
- Return type
Tensor
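To show how `forward` combines the label and the loss type, the sketch below assumes the common definition for each type: binary cross-entropy on logits for ‘vanilla’, mean squared error for ‘lsgan’, L1 for ‘l1’, and the hinge formulation for ‘hinge’ (separate real/fake terms for the discriminator, −mean(pred) for the generator). `gan_loss_forward` is a hypothetical function; ‘wgan’ and the `mask` argument are left out for brevity. As stated in the parameter description, `loss_weight` scales only the generator loss:

```python
import torch
import torch.nn.functional as F


def gan_loss_forward(pred: torch.Tensor,
                     target_is_real: bool,
                     gan_type: str = 'vanilla',
                     is_disc: bool = False,
                     loss_weight: float = 1.0,
                     real_label_val: float = 1.0,
                     fake_label_val: float = 0.0) -> torch.Tensor:
    """Hypothetical sketch of GANLoss.forward (no 'wgan', no mask)."""
    if gan_type == 'hinge':
        if is_disc:
            # Discriminator: mean(relu(1 - pred)) on real, mean(relu(1 + pred)) on fake.
            signed = -pred if target_is_real else pred
            loss = F.relu(1.0 + signed).mean()
        else:
            # Generator: raise the discriminator's score on generated samples.
            loss = -pred.mean()
    else:
        label = real_label_val if target_is_real else fake_label_val
        target = torch.full_like(pred, label)
        if gan_type == 'vanilla':
            loss = F.binary_cross_entropy_with_logits(pred, target)
        elif gan_type == 'lsgan':
            loss = F.mse_loss(pred, target)
        elif gan_type == 'l1':
            loss = F.l1_loss(pred, target)
        else:
            raise NotImplementedError(f'{gan_type} is not covered by this sketch.')
    # loss_weight scales the generator loss only.
    return loss if is_disc else loss * loss_weight


pred = torch.tensor([2.0, 0.0])
d_loss = gan_loss_forward(pred, True, 'hinge', is_disc=True)
g_loss = gan_loss_forward(pred, True, 'hinge', loss_weight=0.5)
print(d_loss.item(), g_loss.item())  # 0.5 -0.5
```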
- class mmagic.models.losses.gan_loss.GaussianBlur(kernel_size: Tuple[int, int] = (71, 71), sigma: Tuple[float, float] = (10.0, 10.0))
Bases: torch.nn.Module
A Gaussian filter which blurs a given tensor with a two-dimensional Gaussian kernel by convolving it along each channel. Batch operation is supported.
This class is adapted from kornia.filters.gaussian: <https://kornia.readthedocs.io/en/latest/_modules/kornia/filters/gaussian.html>.
- Parameters
kernel_size (tuple[int]) – The size of the kernel. Default: (71, 71).
sigma (tuple[float]) – The standard deviation of the kernel. Default: (10.0, 10.0).
- Returns
The Gaussian-blurred tensor.
- Return type
Tensor
- Shape:
input: Tensor with shape of (n, c, h, w)
output: Tensor with shape of (n, c, h, w)
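Blurring each channel independently amounts to a depthwise convolution: one copy of the 2D kernel per channel with `groups=c`, plus zero padding of (k − 1) // 2 so that h and w are preserved. A minimal sketch, assuming a separable normalized Gaussian kernel and zero padding; the function `gaussian_blur` and its small 5×5 default are illustrative choices, not mmagic's defaults:

```python
import torch
import torch.nn.functional as F


def gaussian_blur(x: torch.Tensor,
                  kernel_size=(5, 5),
                  sigma=(1.0, 1.0)) -> torch.Tensor:
    """Hypothetical per-channel Gaussian blur: (n, c, h, w) -> (n, c, h, w)."""
    c = x.shape[1]

    def kernel_1d(k: int, s: float) -> torch.Tensor:
        # Normalized 1D Gaussian centred on the middle tap (k assumed odd).
        coords = torch.arange(k, dtype=x.dtype, device=x.device) - (k - 1) / 2
        g = torch.exp(-coords ** 2 / (2 * s ** 2))
        return g / g.sum()

    # Separable kernel: outer product of the two 1D kernels.
    kernel_2d = torch.outer(kernel_1d(kernel_size[0], sigma[0]),
                            kernel_1d(kernel_size[1], sigma[1]))
    # One kernel copy per channel; groups=c convolves channels independently.
    weight = kernel_2d.view(1, 1, *kernel_size).repeat(c, 1, 1, 1)
    padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
    return F.conv2d(x, weight, padding=padding, groups=c)


x = torch.ones(2, 3, 9, 9)
y = gaussian_blur(x)
print(y.shape)  # torch.Size([2, 3, 9, 9])
```

Because the kernel sums to 1, a constant input stays constant wherever the whole kernel fits inside the image; near the border the zero padding pulls values down.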
- static compute_zero_padding(kernel_size: Tuple[int, int]) → tuple
Compute zero padding tuple.
- Parameters
kernel_size (tuple[int]) – The size of the kernel.
- Returns
Padding of height and width.
- Return type
tuple
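For an odd kernel size k, padding of (k − 1) // 2 pixels on each side keeps the height and width unchanged under a stride-1 convolution. A minimal sketch of that rule (an assumption about the formula, not copied from the library):

```python
from typing import Tuple


def compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]:
    """Hypothetical sketch: 'same' padding for odd kernels, per spatial dimension."""
    # (k - 1) // 2 on each side preserves the size for odd k and stride 1.
    return (kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2


print(compute_zero_padding((71, 71)))  # (35, 35)
```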
- get_2d_gaussian_kernel(kernel_size: Tuple[int, int], sigma: Tuple[float, float]) → torch.Tensor
Get the two-dimensional Gaussian filter matrix coefficients.
- Parameters
kernel_size (tuple[int]) – Kernel filter size in the x and y direction. The kernel sizes should be odd and positive.
sigma (tuple[float]) – Gaussian standard deviation in the x and y direction.
- Returns
kernel_2d – A 2D torch tensor with Gaussian filter matrix coefficients.
- Return type
Tensor
- get_1d_gaussian_kernel(kernel_size: int, sigma: float) → torch.Tensor
Get the Gaussian filter coefficients in one dimension (x or y direction).
- Parameters
kernel_size (int) – Kernel filter size in x or y direction. Should be odd and positive.
sigma (float) – Gaussian standard deviation in x or y direction.
- Returns
kernel_1d – A 1D torch tensor with Gaussian filter coefficients in x or y direction.
- Return type
Tensor
- gaussian(kernel_size: int, sigma: float) → torch.Tensor
Gaussian function.
- Parameters
kernel_size (int) – Kernel filter size in x or y direction. Should be odd and positive.
sigma (float) – Gaussian standard deviation in x or y direction.
- Returns
A 1D torch tensor with Gaussian filter coefficients in x or y direction.
- Return type
Tensor
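The three kernel helpers compose in the usual way. Assuming the common definition (the docstrings only require odd, positive sizes, so the exact centring below is an assumption), `gaussian` samples and normalizes a 1D Gaussian of size k, `get_1d_gaussian_kernel` returns it for one direction, and `get_2d_gaussian_kernel` forms the separable 2D kernel as an outer product, with k_x, σ_x taken from `kernel_size[0]`, `sigma[0]` and k_y, σ_y from `kernel_size[1]`, `sigma[1]`:

```latex
g_i(k, \sigma) = \frac{\exp\left(-\frac{(i - c)^2}{2\sigma^2}\right)}
                      {\sum_{j=0}^{k-1} \exp\left(-\frac{(j - c)^2}{2\sigma^2}\right)},
\qquad i = 0, \dots, k - 1, \qquad c = \frac{k - 1}{2}

K = g(k_x, \sigma_x)\, g(k_y, \sigma_y)^{\top} \in \mathbb{R}^{k_x \times k_y},
\qquad \sum_{i, j} K_{ij} = 1
```

For example, with k = 3 and σ = 1 the centre is c = 1, the unnormalized samples are (e^{-1/2}, 1, e^{-1/2}) ≈ (0.607, 1, 0.607), and dividing by their sum ≈ 2.213 gives g ≈ (0.274, 0.452, 0.274).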
- mmagic.models.losses.gan_loss.gradient_penalty_loss(discriminator: torch.nn.Module, real_data: torch.Tensor, fake_data: torch.Tensor, mask: Optional[torch.Tensor] = None, norm_mode: str = 'pixel') → torch.Tensor
Calculate gradient penalty for WGAN-GP.
- Parameters
discriminator (nn.Module) – Network for the discriminator.
real_data (Tensor) – Real input data.
fake_data (Tensor) – Fake input data.
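mask (Tensor) – Masks for inpainting. Default: None.
norm_mode (str) – This argument decides along which dimension the norm of the gradients will be calculated. Currently, we support [“pixel”, “HWC”]. Defaults to “pixel”.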
- Returns
A tensor for gradient penalty.
- Return type
Tensor
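WGAN-GP evaluates the discriminator on random interpolations between real and fake samples and pushes the gradient norm there toward 1. The sketch below is a minimal illustration under stated assumptions: `gradient_penalty` is a hypothetical name, ‘pixel’ is taken to mean an L2 norm over channels at each pixel and ‘HWC’ an L2 norm over each whole sample, and the division by the mask mean is a guess rather than confirmed mmagic behavior:

```python
from typing import Optional

import torch
import torch.nn as nn


def gradient_penalty(discriminator: nn.Module,
                     real_data: torch.Tensor,
                     fake_data: torch.Tensor,
                     mask: Optional[torch.Tensor] = None,
                     norm_mode: str = 'pixel') -> torch.Tensor:
    """Hypothetical WGAN-GP sketch: mean of (||grad D(x_hat)|| - 1)^2."""
    batch_size = real_data.size(0)
    # One interpolation coefficient per sample, broadcast over (c, h, w).
    alpha = torch.rand(batch_size, 1, 1, 1,
                       device=real_data.device, dtype=real_data.dtype)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)

    disc_out = discriminator(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_out,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_out),
        create_graph=True)[0]  # keep the graph so the penalty is differentiable

    if mask is not None:
        gradients = gradients * mask

    if norm_mode == 'pixel':
        grad_norm = gradients.norm(2, dim=1)  # over channels, per pixel
    elif norm_mode == 'HWC':
        grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)  # per sample
    else:
        raise NotImplementedError(f'Unsupported norm_mode: {norm_mode}')

    penalty = ((grad_norm - 1) ** 2).mean()
    if mask is not None:
        penalty = penalty / mask.mean()  # assumption: rescale by mask coverage
    return penalty


disc = nn.Conv2d(3, 1, kernel_size=1, bias=False)
with torch.no_grad():
    disc.weight.copy_(torch.tensor([2.0, 0.0, 0.0]).view(1, 3, 1, 1))
real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
print(gradient_penalty(disc, real, fake).item())  # 1.0
```

With a 1×1 convolution whose weight is (2, 0, 0), the input gradient is (2, 0, 0) at every pixel whatever the interpolation, so the per-pixel norm is 2 and the penalty is (2 − 1)² = 1.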
- class mmagic.models.losses.gan_loss.GradientPenaltyLoss(loss_weight: float = 1.0)
Bases: torch.nn.Module
Gradient penalty loss for WGAN-GP.
- Parameters
loss_weight (float) – Loss weight. Default: 1.0.
- forward(discriminator: torch.nn.Module, real_data: torch.Tensor, fake_data: torch.Tensor, mask: Optional[torch.Tensor] = None) → torch.Tensor
Forward function.
- Parameters
discriminator (nn.Module) – Network for the discriminator.
real_data (Tensor) – Real input data.
fake_data (Tensor) – Fake input data.
mask (Tensor) – Masks for inpainting. Default: None.
- Returns
Loss.
- Return type
Tensor
- mmagic.models.losses.gan_loss.disc_shift_loss(pred: torch.Tensor) → torch.Tensor
Disc shift loss.
This loss is proposed in PGGAN as an auxiliary loss for the discriminator.
- Parameters
pred (Tensor) – Input tensor.
- Returns
Loss tensor.
- Return type
torch.Tensor
- class mmagic.models.losses.gan_loss.DiscShiftLoss(loss_weight: float = 0.1)
Bases: torch.nn.Module
Disc shift loss.
- Parameters
loss_weight (float, optional) – Loss weight. Defaults to 0.1.
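The PGGAN drift penalty keeps the discriminator's outputs from drifting far from zero by penalizing their squared value. A minimal sketch, assuming both `disc_shift_loss` and `DiscShiftLoss` amount to a (weighted) mean of squared predictions; `disc_shift` is a hypothetical name and the mean reduction is an assumption:

```python
import torch


def disc_shift(pred: torch.Tensor, loss_weight: float = 0.1) -> torch.Tensor:
    """Hypothetical sketch: loss_weight * mean(pred ** 2)."""
    # Squaring penalizes large scores of either sign, so the
    # discriminator output stays anchored around zero.
    return loss_weight * (pred ** 2).mean()


pred = torch.tensor([1.0, -3.0])
loss = disc_shift(pred)  # 0.1 * (1 + 9) / 2
print(loss.item())
```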
- mmagic.models.losses.gan_loss.r1_gradient_penalty_loss(discriminator: torch.nn.Module, real_data: torch.Tensor, mask: Optional[torch.Tensor] = None, norm_mode: str = 'pixel', loss_scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None, use_apex_amp: bool = False) → torch.Tensor
Calculate R1 gradient penalty for WGAN-GP.
The R1 regularizer comes from “Which Training Methods for GANs do actually Converge?” (ICML 2018).
Unlike the original gradient penalty, this regularizer only penalizes gradients w.r.t. real data.
- Parameters
discriminator (nn.Module) – Network for the discriminator.
real_data (Tensor) – Real input data.
mask (Tensor) – Masks for inpainting. Default: None.
norm_mode (str) – This argument decides along which dimension the norm of the gradients will be calculated. Currently, we support [“pixel”, “HWC”]. Defaults to “pixel”.
- Returns
A tensor for gradient penalty.
- Return type
Tensor
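R1 differs from the WGAN-GP penalty in two ways: the gradient is taken at the real samples themselves rather than at interpolations, and the squared gradient norm is penalized directly instead of its distance from 1. A minimal sketch, assuming ‘pixel’ means a squared L2 norm over channels averaged over pixels and ‘HWC’ a squared L2 norm over each whole sample; `r1_penalty` is hypothetical, the mask normalization is a guess, and `loss_scaler`, `use_apex_amp`, and the paper's γ/2 factor are left out:

```python
from typing import Optional

import torch
import torch.nn as nn


def r1_penalty(discriminator: nn.Module,
               real_data: torch.Tensor,
               mask: Optional[torch.Tensor] = None,
               norm_mode: str = 'pixel') -> torch.Tensor:
    """Hypothetical R1 sketch: squared gradient norm of D at real samples."""
    batch_size = real_data.size(0)
    real_data = real_data.detach().requires_grad_(True)
    disc_out = discriminator(real_data)
    gradients = torch.autograd.grad(
        outputs=disc_out,
        inputs=real_data,
        grad_outputs=torch.ones_like(disc_out),
        create_graph=True)[0]

    if mask is not None:
        gradients = gradients * mask

    if norm_mode == 'pixel':
        # Squared L2 norm over channels, averaged over pixels and samples.
        penalty = gradients.pow(2).sum(dim=1).mean()
    elif norm_mode == 'HWC':
        # Squared L2 norm over each whole sample, averaged over the batch.
        penalty = gradients.pow(2).reshape(batch_size, -1).sum(dim=1).mean()
    else:
        raise NotImplementedError(f'Unsupported norm_mode: {norm_mode}')

    if mask is not None:
        penalty = penalty / mask.mean()  # assumption: rescale by mask coverage
    return penalty


disc = nn.Conv2d(3, 1, kernel_size=1, bias=False)
with torch.no_grad():
    disc.weight.copy_(torch.tensor([2.0, 0.0, 0.0]).view(1, 3, 1, 1))
real = torch.randn(4, 3, 8, 8)
print(r1_penalty(disc, real).item())                   # 4.0
print(r1_penalty(disc, real, norm_mode='HWC').item())  # 256.0
```

With a 1×1 convolution whose weight is (2, 0, 0), every pixel's gradient has squared norm 4, so ‘pixel’ mode yields 4 and ‘HWC’ mode yields 4 × 8 × 8 = 256 for 8×8 inputs.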
- mmagic.models.losses.gan_loss.gen_path_regularizer(generator: torch.nn.Module, num_batches: int, mean_path_length: torch.Tensor, pl_batch_shrink: int = 1, decay: float = 0.01, weight: float = 1.0, pl_batch_size: Optional[int] = None, sync_mean_buffer: bool = False, loss_scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None, use_apex_amp: bool = False) → Tuple[torch.Tensor]
Generator Path Regularization.
Path regularization is proposed in StyleGAN2 and can help improve the continuity of the latent space. More details can be found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR 2020.
- Parameters
generator (nn.Module) – The generator module. Note that this loss requires the generator to provide a return_latents interface, through which the latent code of the current sample can be obtained.
num_batches (int) – The number of samples used in calculating this loss.
mean_path_length (Tensor) – The mean path length, calculated by moving average.
pl_batch_shrink (int, optional) – The factor of shrinking the batch size for saving GPU memory. Defaults to 1.
decay (float, optional) – Decay for moving average of mean path length. Defaults to 0.01.
weight (float, optional) – Weight of this loss item. Defaults to 1.0.
pl_batch_size (int | None, optional) – The batch size in calculating generator path. Once this argument is set, num_batches will be overridden with this argument and won’t be affected by pl_batch_shrink. Defaults to None.
sync_mean_buffer (bool, optional) – Whether to sync mean path length across all GPUs. Defaults to False.
- Returns
The penalty loss, detached mean path tensor, and current path length.
- Return type
tuple[Tensor]
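The core of path length regularization is the product Jᵀy of the generator Jacobian with random image-space noise y: differentiate the noise-weighted image with respect to the latents, measure the length of that gradient, and penalize its deviation from the moving average `mean_path_length`. The sketch below follows the StyleGAN2 formulation under loud assumptions: `ToyGenerator` and its `forward(num_batches)` returning `(images, latents)` are hypothetical stand-ins for a generator with `return_latents` support, `path_regularizer` is a hypothetical name, and `pl_batch_shrink`, `pl_batch_size`, `sync_mean_buffer`, and AMP handling are omitted:

```python
import math
from typing import Tuple

import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical generator returning images and the latents that produced them."""

    def __init__(self, num_ws: int = 4, latent_dim: int = 8, size: int = 4):
        super().__init__()
        self.num_ws, self.latent_dim, self.size = num_ws, latent_dim, size
        self.fc = nn.Linear(num_ws * latent_dim, 3 * size * size)

    def forward(self, num_batches: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Leaf latents with requires_grad=True so we can differentiate w.r.t. them.
        latents = torch.randn(num_batches, self.num_ws, self.latent_dim,
                              requires_grad=True)
        img = self.fc(latents.flatten(1)).view(num_batches, 3, self.size, self.size)
        return img, latents


def path_regularizer(generator: nn.Module,
                     num_batches: int,
                     mean_path_length: torch.Tensor,
                     decay: float = 0.01,
                     weight: float = 1.0
                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of StyleGAN2-style path length regularization."""
    fake_img, latents = generator(num_batches)
    # Random image-space direction, scaled by 1/sqrt(h*w) so the path length
    # does not grow with resolution.
    h, w = fake_img.shape[2:]
    noise = torch.randn_like(fake_img) / math.sqrt(h * w)
    # J^T y via autograd; create_graph=True so the penalty can be backpropagated.
    grad = torch.autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    # Per-sample length: squared norm over latent_dim, averaged over num_ws.
    path_lengths = grad.pow(2).sum(dim=2).mean(dim=1).sqrt()
    # Exponential moving average of the path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean() * weight
    return path_penalty, path_mean.detach(), path_lengths


gen = ToyGenerator()
loss, new_mean, lengths = path_regularizer(
    gen, num_batches=2, mean_path_length=torch.tensor(0.0))
loss.backward()  # gradients reach the generator's weights
```

The detached moving average is what a caller would store back into its `mean_path_length` buffer for the next step, mirroring the three outputs listed under Returns.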