Shortcuts

Source code for mmagic.models.losses.composition_loss

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn

from mmagic.registry import MODELS
from .pixelwise_loss import charbonnier_loss, l1_loss, mse_loss

[docs]_reduction_modes = ['none', 'mean', 'sum']
@MODELS.register_module()
class L1CompositionLoss(nn.Module):
    """L1 composition loss.

    The predicted alpha matte is used to blend ``fg`` over ``bg``; the L1
    loss is then computed between that composite and ``ori_merged``.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce the loss with 'mean' per-sample, and then take the mean
            over all the samples. Default: False.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean',
                 sample_wise: bool = False) -> None:
        super().__init__()

        # Validate against the module-level list so the check and the error
        # message can never disagree.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise

    def forward(self,
                pred_alpha: torch.Tensor,
                fg: torch.Tensor,
                bg: torch.Tensor,
                ori_merged: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Compute the L1 composition loss.

        Args:
            pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
            fg (Tensor): of shape (N, 3, H, W). Tensor of foreground object.
            bg (Tensor): of shape (N, 3, H, W). Tensor of background object.
            ori_merged (Tensor): of shape (N, 3, H, W). Tensor of origin
                merged image before normalized by ImageNet mean and std.
            weight (Tensor, optional): of shape (N, 1, H, W). It is an
                indicating matrix: weight[trimap == 128] = 1.
                Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: The result of ``l1_loss`` on the composited image,
            multiplied by ``loss_weight``.
        """
        # Composite the image from the predicted alpha: alpha * F + (1 - alpha) * B.
        pred_merged = pred_alpha * fg + (1. - pred_alpha) * bg

        if weight is not None:
            # Broadcast the single-channel weight across the 3 image channels
            # so it lines up element-wise with the composited image.
            weight = weight.expand(-1, 3, -1, -1)

        return self.loss_weight * l1_loss(
            pred_merged,
            ori_merged,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@MODELS.register_module()
class MSECompositionLoss(nn.Module):
    """MSE (L2) composition loss.

    The predicted alpha matte is used to blend ``fg`` over ``bg``; the MSE
    loss is then computed between that composite and ``ori_merged``.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce the loss with 'mean' per-sample, and then take the mean
            over all the samples. Default: False.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean',
                 sample_wise: bool = False) -> None:
        super().__init__()

        # Validate against the module-level list so the check and the error
        # message can never disagree.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise

    def forward(self,
                pred_alpha: torch.Tensor,
                fg: torch.Tensor,
                bg: torch.Tensor,
                ori_merged: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Compute the MSE composition loss.

        Args:
            pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
            fg (Tensor): of shape (N, 3, H, W). Tensor of foreground object.
            bg (Tensor): of shape (N, 3, H, W). Tensor of background object.
            ori_merged (Tensor): of shape (N, 3, H, W). Tensor of origin
                merged image before normalized by ImageNet mean and std.
            weight (Tensor, optional): of shape (N, 1, H, W). It is an
                indicating matrix: weight[trimap == 128] = 1.
                Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: The result of ``mse_loss`` on the composited image,
            multiplied by ``loss_weight``.
        """
        # Composite the image from the predicted alpha: alpha * F + (1 - alpha) * B.
        pred_merged = pred_alpha * fg + (1. - pred_alpha) * bg

        if weight is not None:
            # Broadcast the single-channel weight across the 3 image channels
            # so it lines up element-wise with the composited image.
            weight = weight.expand(-1, 3, -1, -1)

        return self.loss_weight * mse_loss(
            pred_merged,
            ori_merged,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@MODELS.register_module()
class CharbonnierCompLoss(nn.Module):
    """Charbonnier composition loss.

    The predicted alpha matte is used to blend ``fg`` over ``bg``; the
    Charbonnier loss is then computed between that composite and
    ``ori_merged``.

    Args:
        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce the loss with 'mean' per-sample, and then take the mean
            over all the samples. Default: False.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 reduction: str = 'mean',
                 sample_wise: bool = False,
                 eps: float = 1e-12) -> None:
        super().__init__()

        # Validate against the module-level list so the check and the error
        # message can never disagree.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise
        self.eps = eps

    def forward(self,
                pred_alpha: torch.Tensor,
                fg: torch.Tensor,
                bg: torch.Tensor,
                ori_merged: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Compute the Charbonnier composition loss.

        Args:
            pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
            fg (Tensor): of shape (N, 3, H, W). Tensor of foreground object.
            bg (Tensor): of shape (N, 3, H, W). Tensor of background object.
            ori_merged (Tensor): of shape (N, 3, H, W). Tensor of origin
                merged image before normalized by ImageNet mean and std.
            weight (Tensor, optional): of shape (N, 1, H, W). It is an
                indicating matrix: weight[trimap == 128] = 1.
                Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: The result of ``charbonnier_loss`` on the composited
            image, multiplied by ``loss_weight``.
        """
        # Composite the image from the predicted alpha: alpha * F + (1 - alpha) * B.
        pred_merged = pred_alpha * fg + (1. - pred_alpha) * bg

        if weight is not None:
            # Broadcast the single-channel weight across the 3 image channels
            # so it lines up element-wise with the composited image.
            weight = weight.expand(-1, 3, -1, -1)

        return self.loss_weight * charbonnier_loss(
            pred_merged,
            ori_merged,
            weight,
            eps=self.eps,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.