Source code for mmagic.models.losses.pixelwise_loss
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmagic.registry import MODELS
from .loss_wrapper import masked_loss
@masked_loss
def l1_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """L1 loss.

    Computes the element-wise absolute difference ``|pred - target|``
    without any reduction. Weighting and reduction are left to the
    ``masked_loss`` decorator; callers in this module pass ``weight``,
    ``reduction`` and ``sample_wise`` through it.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Unreduced, element-wise L1 loss (before whatever reduction
        ``masked_loss`` applies).
    """
    return F.l1_loss(pred, target, reduction='none')
@masked_loss
def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE loss.

    Computes the element-wise squared difference ``(pred - target) ** 2``
    without any reduction. Weighting and reduction are left to the
    ``masked_loss`` decorator; callers in this module pass ``weight``,
    ``reduction`` and ``sample_wise`` through it.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Unreduced, element-wise squared error (before whatever
        reduction ``masked_loss`` applies).
    """
    return F.mse_loss(pred, target, reduction='none')
@masked_loss
def charbonnier_loss(pred: torch.Tensor,
                     target: torch.Tensor,
                     eps: float = 1e-12) -> torch.Tensor:
    """Charbonnier loss.

    Computes ``sqrt((pred - target) ** 2 + eps)`` element-wise without any
    reduction. This is a smooth approximation of the absolute difference;
    ``eps`` keeps the square root differentiable at zero. Weighting and
    reduction are left to the ``masked_loss`` decorator.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Returns:
        Tensor: Unreduced, element-wise Charbonnier loss (before whatever
        reduction ``masked_loss`` applies).
    """
    return torch.sqrt((pred - target)**2 + eps)
def tv_loss(input: torch.Tensor) -> torch.Tensor:
    """L2 total variation loss, as in Mahendran et al.

    The input is replicate-padded by one pixel on the right and bottom.
    Because of this padding, the last column and row contribute zero
    horizontal and vertical difference respectively. For every pixel the
    squared horizontal and vertical forward differences are summed. The
    result is then averaged over every dimension except the first.

    Args:
        input (Tensor): Tensor with shape (n, c, h, w) or (n, h, w).

    Returns:
        Tensor: Per-sample loss with shape (n, ).
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    x_diff = padded[..., :-1, 1:] - padded[..., :-1, :-1]
    y_diff = padded[..., 1:, :-1] - padded[..., :-1, :-1]
    # Reduce over all non-batch dims so that 3-D inputs work as well as 4-D
    # ones; for 4-D input this is exactly dims (1, 2, 3).
    reduce_dims = tuple(range(1, padded.dim()))
    return (x_diff**2 + y_diff**2).mean(dim=reduce_dims)
@MODELS.register_module()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and `weight`
            (argument of `forward()`) is not None. It will first reduce loss
            with 'mean' per-sample, and then it means over all the samples.
            Default: False.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean',
                 sample_wise: bool = False) -> None:
        super().__init__()
        # Defined locally so the error path cannot fail with a NameError on
        # a missing module-level constant.
        supported_reductions = ['none', 'mean', 'sum']
        if reduction not in supported_reductions:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {supported_reductions}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Forward Function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tensor: ``loss_weight`` times the L1 loss, reduced according to
            ``self.reduction`` and ``self.sample_wise``.
        """
        return self.loss_weight * l1_loss(
            pred,
            target,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@MODELS.register_module()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and `weight`
            (argument of `forward()`) is not None. It will first reduce loss
            with 'mean' per-sample, and then it means over all the samples.
            Default: False.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean',
                 sample_wise: bool = False) -> None:
        super().__init__()
        # Defined locally so the error path cannot fail with a NameError on
        # a missing module-level constant.
        supported_reductions = ['none', 'mean', 'sum']
        if reduction not in supported_reductions:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {supported_reductions}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Forward Function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tensor: ``loss_weight`` times the MSE loss, reduced according to
            ``self.reduction`` and ``self.sample_wise``.
        """
        return self.loss_weight * mse_loss(
            pred,
            target,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@MODELS.register_module()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable variant
    of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and `weight`
            (argument of `forward()`) is not None. It will first reduce loss
            with 'mean' per-sample, and then it means over all the samples.
            Default: False.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean',
                 sample_wise: bool = False,
                 eps: float = 1e-12) -> None:
        super().__init__()
        # Defined locally so the error path cannot fail with a NameError on
        # a missing module-level constant.
        supported_reductions = ['none', 'mean', 'sum']
        if reduction not in supported_reductions:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {supported_reductions}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Forward Function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tensor: ``loss_weight`` times the Charbonnier loss, reduced
            according to ``self.reduction`` and ``self.sample_wise``.
        """
        return self.loss_weight * charbonnier_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@MODELS.register_module()
class MaskedTVLoss(L1Loss):
    """Masked TV loss.

    Total variation is measured as the L1 distance between each pixel and
    its neighbour below and its neighbour to the right. An optional mask
    weights each pixel pair. The inherited ``L1Loss`` settings are used:
    ``reduction='mean'`` and ``sample_wise=False``.

    Args:
        loss_weight (float, optional): Loss weight. Defaults to 1.0.
    """

    def __init__(self, loss_weight: float = 1.0) -> None:
        super().__init__(loss_weight=loss_weight)

    def forward(self,
                pred: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): Tensor with shape of (n, c, h, w).
            mask (torch.Tensor, optional): Tensor with shape of (n, 1, h, w).
                If None, all pixel pairs are weighted equally.
                Defaults to None.

        Returns:
            torch.Tensor: Sum of the vertical and horizontal L1 terms, each
            already multiplied by ``loss_weight`` and reduced by
            ``L1Loss.forward``.
        """
        if mask is None:
            # weight=None is L1Loss.forward's own default, i.e. unweighted.
            y_weight = x_weight = None
        else:
            # Crop the mask to the first element of each neighbour pair.
            y_weight = mask[:, :, :-1, :]
            x_weight = mask[:, :, :, :-1]
        y_diff = super().forward(
            pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super().forward(
            pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
        return x_diff + y_diff
@MODELS.register_module()
class PSNRLoss(nn.Module):
    """PSNR Loss in "HINet: Half Instance Normalization Network for Image
    Restoration".

    The loss is
    ``loss_weight * 10 / ln(10) * mean_n(ln(MSE_n + 1e-8))``.
    ``MSE_n`` is the mean squared error of sample ``n`` over all channels
    and pixels, so the loss equals ``loss_weight`` times the batch-averaged
    ``10 * log10(MSE_n + 1e-8)``. Minimizing it therefore maximizes PSNR.

    Args:
        loss_weight (float, optional): Loss weight. Defaults to 1.0.
        toY (bool, optional): If True, both inputs are converted from RGB to
            the Y channel of YCbCr before the loss is computed. The inputs
            are assumed to be RGB in [0, 1] with 3 channels. Defaults to
            False.
    """

    def __init__(self, loss_weight: float = 1.0, toY: bool = False) -> None:
        super().__init__()
        import math
        self.loss_weight = loss_weight
        # 10 / ln(10) turns a natural log into 10 * log10.
        self.scale = 10 / math.log(10)
        self.toY = toY
        # ITU-R BT.601 RGB -> Y weights for RGB in [0, 1]. With the +16
        # offset applied in forward() they give Y on the [16, 235] scale.
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        # Not read by forward(); kept so code that accesses it still works.
        self.first = True

    def forward(self, pred: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted tensor of shape (n, c, h, w); c must be
                3 when ``toY`` is True.
            target (Tensor): Ground-truth tensor of the same shape.

        Returns:
            Tensor: Scalar loss.
        """
        assert len(pred.size()) == 4
        if self.toY:
            # Follow the input's device/dtype on every call, so moving the
            # inputs to another device later still works.
            coef = self.coef.to(device=pred.device, dtype=pred.dtype)
            pred = (pred * coef).sum(dim=1, keepdim=True) + 16.
            target = (target * coef).sum(dim=1, keepdim=True) + 16.
            pred, target = pred / 255., target / 255.
        mse = ((pred - target)**2).mean(dim=(1, 2, 3))
        return self.loss_weight * self.scale * torch.log(mse + 1e-8).mean()