mmagic.models.losses.composition_loss
Module Contents
Classes
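| Class | Description |
| --- | --- |
| L1CompositionLoss | L1 composition loss. |
| MSECompositionLoss | MSE (L2) composition loss. |
| CharbonnierCompLoss | Charbonnier composition loss. |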
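Judging from the `forward()` signatures, all three classes appear to share one objective: re-composite an image from the predicted alpha matte (`pred_alpha` as $\hat{\alpha}$, `fg` as $F$, `bg` as $B$), compare it with `ori_merged` ($I$), optionally restrict the comparison with the `weight` mask ($W$), and scale by `loss_weight` ($\lambda$). As a sketch, assuming the standard matting composition equation and that the classes differ only in the element-wise penalty $\rho$ (with `eps` as $\varepsilon$ for the Charbonnier case):

```latex
\hat{I} = \hat{\alpha}\,F + (1 - \hat{\alpha})\,B,
\qquad
\mathcal{L} = \lambda \cdot \operatorname{reduce}\!\left( W \odot \rho\big(\hat{I} - I\big) \right),
\qquad
\rho(d) \in \left\{\, |d|,\; d^{2},\; \sqrt{d^{2} + \varepsilon} \,\right\}
```

Here $\operatorname{reduce}$ stands for the `reduction` option ('none', 'mean' or 'sum').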
- class mmagic.models.losses.composition_loss.L1CompositionLoss(loss_weight: float = 1.0, reduction: str = 'mean', sample_wise: bool = False)
Bases: torch.nn.Module
L1 composition loss.
- Parameters
loss_weight (float) – Loss weight for L1 loss. Default: 1.0.
reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is ‘mean’ and weight (the argument of forward()) is not None. The loss is first reduced with ‘mean’ within each sample, and the per-sample values are then averaged over all samples. Default: False.
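To make the interaction between `reduction`, `weight` and `sample_wise` concrete, here is a minimal sketch of a weighted reduction over an element-wise loss of shape (N, C, H, W). It assumes that, with a `weight` mask and `reduction='mean'`, the loss is averaged over the masked elements only; `masked_reduce` is a hypothetical helper written for illustration, not part of the mmagic API.

```python
from typing import Optional

import torch


def masked_reduce(loss: torch.Tensor,
                  weight: Optional[torch.Tensor] = None,
                  reduction: str = "mean",
                  sample_wise: bool = False) -> torch.Tensor:
    """Hypothetical helper: reduce an element-wise loss of shape (N, C, H, W)."""
    if weight is not None:
        weight = weight.expand_as(loss)  # broadcast (N, 1, H, W) -> (N, C, H, W)
        loss = loss * weight
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    if weight is None:
        return loss.mean()  # plain mean: sample_wise has no effect here
    eps = 1e-12  # guards against division by zero for an all-zero mask
    if sample_wise:
        # Mean over each sample's masked elements, then mean over the batch.
        per_sample = loss.sum(dim=(1, 2, 3)) / (weight.sum(dim=(1, 2, 3)) + eps)
        return per_sample.mean()
    # Pixel-wise: a single mean over all masked elements in the batch.
    return loss.sum() / (weight.sum() + eps)


# Sample 0: large loss on 1 masked pixel; sample 1: small loss on 4 masked pixels.
loss = torch.cat([torch.full((1, 3, 2, 2), 4.0), torch.ones(1, 3, 2, 2)])
weight = torch.zeros(2, 1, 2, 2)
weight[0, :, 0, 0] = 1.0
weight[1] = 1.0
print(masked_reduce(loss, weight).item())                    # ≈ 1.6
print(masked_reduce(loss, weight, sample_wise=True).item())  # ≈ 2.5
```

With pixel-wise averaging, samples whose masks cover more pixels dominate the result; the sample-wise variant gives every sample equal weight regardless of mask size.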
- forward(pred_alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor, ori_merged: torch.Tensor, weight: Optional[torch.Tensor] = None, **kwargs) → torch.Tensor
- Parameters
pred_alpha (Tensor) – Predicted alpha matte of shape (N, 1, H, W).
fg (Tensor) – Foreground image of shape (N, 3, H, W).
bg (Tensor) – Background image of shape (N, 3, H, W).
ori_merged (Tensor) – Original merged (composited) image of shape (N, 3, H, W), before normalization with the ImageNet mean and std.
weight (Tensor, optional) – Indicator mask of shape (N, 1, H, W), e.g. weight[trimap == 128] = 1 so that only the unknown region of the trimap contributes. Default: None.
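A minimal sketch of what this forward pass computes, assuming it re-composites the image from `pred_alpha`, `fg` and `bg` with the standard matting equation and compares the result with `ori_merged`. `l1_composition_loss` is a hypothetical stand-in for `L1CompositionLoss.forward`, and only `reduction='mean'` with `sample_wise=False` is modelled.

```python
from typing import Optional

import torch


def l1_composition_loss(pred_alpha: torch.Tensor, fg: torch.Tensor,
                        bg: torch.Tensor, ori_merged: torch.Tensor,
                        weight: Optional[torch.Tensor] = None,
                        loss_weight: float = 1.0) -> torch.Tensor:
    """Hypothetical sketch of an L1 composition loss (reduction='mean')."""
    # Re-composite: alpha * F + (1 - alpha) * B; (N, 1, H, W) broadcasts over RGB.
    pred_merged = pred_alpha * fg + (1.0 - pred_alpha) * bg
    diff = (pred_merged - ori_merged).abs()  # element-wise L1, (N, 3, H, W)
    if weight is None:
        return loss_weight * diff.mean()
    # Average over the masked pixels only, counting all 3 channels.
    weight = weight.expand_as(diff)
    return loss_weight * (diff * weight).sum() / (weight.sum() + 1e-12)


# Alpha = 0 reproduces the background; the target differs at one pixel only.
alpha = torch.zeros(1, 1, 2, 2)
fg, bg = torch.ones(1, 3, 2, 2), torch.zeros(1, 3, 2, 2)
target = torch.zeros(1, 3, 2, 2)
target[..., 0, 0] = 1.0
mask = torch.zeros(1, 1, 2, 2)
mask[..., 0, 0] = 1.0  # e.g. the single unknown trimap pixel
print(l1_composition_loss(alpha, fg, bg, target).item())        # ≈ 0.25
print(l1_composition_loss(alpha, fg, bg, target, mask).item())  # ≈ 1.0
```

Restricting the average to the masked region keeps the error in the unknown region from being diluted by pixels where the composite is already exact.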
- class mmagic.models.losses.composition_loss.MSECompositionLoss(loss_weight: float = 1.0, reduction: str = 'mean', sample_wise: bool = False)
Bases: torch.nn.Module
MSE (L2) composition loss.
- Parameters
loss_weight (float) – Loss weight for MSE loss. Default: 1.0.
reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is ‘mean’ and weight (the argument of forward()) is not None. The loss is first reduced with ‘mean’ within each sample, and the per-sample values are then averaged over all samples. Default: False.
- forward(pred_alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor, ori_merged: torch.Tensor, weight: Optional[torch.Tensor] = None, **kwargs) → torch.Tensor
- Parameters
pred_alpha (Tensor) – Predicted alpha matte of shape (N, 1, H, W).
fg (Tensor) – Foreground image of shape (N, 3, H, W).
bg (Tensor) – Background image of shape (N, 3, H, W).
ori_merged (Tensor) – Original merged (composited) image of shape (N, 3, H, W), before normalization with the ImageNet mean and std.
weight (Tensor, optional) – Indicator mask of shape (N, 1, H, W), e.g. weight[trimap == 128] = 1 so that only the unknown region of the trimap contributes. Default: None.
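A hypothetical sketch of the MSE variant, assuming it re-composites the image from `pred_alpha`, `fg` and `bg` and applies a squared error against `ori_merged`, which penalizes large compositing errors more heavily than L1. It also models `sample_wise` under `reduction='mean'`; `mse_composition_loss` is an illustrative name, not an mmagic function.

```python
from typing import Optional

import torch


def mse_composition_loss(pred_alpha: torch.Tensor, fg: torch.Tensor,
                         bg: torch.Tensor, ori_merged: torch.Tensor,
                         weight: Optional[torch.Tensor] = None,
                         loss_weight: float = 1.0,
                         sample_wise: bool = False) -> torch.Tensor:
    """Hypothetical sketch of an MSE composition loss (reduction='mean')."""
    # Re-composite: alpha * F + (1 - alpha) * B, broadcasting alpha over RGB.
    pred_merged = pred_alpha * fg + (1.0 - pred_alpha) * bg
    sq = (pred_merged - ori_merged) ** 2  # element-wise squared error
    if weight is None:
        return loss_weight * sq.mean()
    weight = weight.expand_as(sq)  # (N, 1, H, W) -> (N, 3, H, W)
    masked = sq * weight
    eps = 1e-12  # avoids division by zero for an all-zero mask
    if sample_wise:
        # Masked mean per sample, then mean over the batch.
        per_sample = masked.sum(dim=(1, 2, 3)) / (weight.sum(dim=(1, 2, 3)) + eps)
        return loss_weight * per_sample.mean()
    # Pixel-wise masked mean over the whole batch.
    return loss_weight * masked.sum() / (weight.sum() + eps)
```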
- class mmagic.models.losses.composition_loss.CharbonnierCompLoss(loss_weight: float = 1.0, reduction: str = 'mean', sample_wise: bool = False, eps: float = 1e-12)
Bases: torch.nn.Module
Charbonnier composition loss.
- Parameters
loss_weight (float) – Loss weight for Charbonnier loss. Default: 1.0.
reduction (str) – Specifies the reduction to apply to the output. Supported choices are ‘none’ | ‘mean’ | ‘sum’. Default: ‘mean’.
sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is ‘mean’ and weight (the argument of forward()) is not None. The loss is first reduced with ‘mean’ within each sample, and the per-sample values are then averaged over all samples. Default: False.
eps (float) – A value used to control the curvature near zero. Default: 1e-12.
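The role of `eps` is easiest to see on the element-wise penalty itself. A minimal sketch, assuming the common Charbonnier definition sqrt(d² + eps) for a residual d; `charbonnier` is a hypothetical helper. The penalty behaves like |d| once |d| is much larger than sqrt(eps), stays smooth through d = 0, and bottoms out at sqrt(eps) rather than 0.

```python
import torch


def charbonnier(diff: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical helper: Charbonnier penalty sqrt(diff**2 + eps)."""
    # Smooth approximation of |diff|; its minimum is sqrt(eps), reached at diff = 0.
    return torch.sqrt(diff * diff + eps)


d = torch.tensor([0.0, 0.001, 0.5])
print(charbonnier(d))            # close to |d| (the value at 0 is sqrt(1e-12) = 1e-6)
print(charbonnier(d, eps=1e-2))  # larger eps: a rounder, flatter bottom near zero
```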
- forward(pred_alpha: torch.Tensor, fg: torch.Tensor, bg: torch.Tensor, ori_merged: torch.Tensor, weight: Optional[torch.Tensor] = None, **kwargs) → torch.Tensor
- Parameters
pred_alpha (Tensor) – Predicted alpha matte of shape (N, 1, H, W).
fg (Tensor) – Foreground image of shape (N, 3, H, W).
bg (Tensor) – Background image of shape (N, 3, H, W).
ori_merged (Tensor) – Original merged (composited) image of shape (N, 3, H, W), before normalization with the ImageNet mean and std.
weight (Tensor, optional) – Indicator mask of shape (N, 1, H, W), e.g. weight[trimap == 128] = 1 so that only the unknown region of the trimap contributes. Default: None.
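A hypothetical sketch of the Charbonnier forward pass, assuming it re-composites the image as pred_alpha * fg + (1 - pred_alpha) * bg, applies sqrt(d² + eps) to the residual against `ori_merged`, and averages over the `weight` mask (`reduction='mean'`, `sample_wise=False`). The function name is illustrative only.

```python
from typing import Optional

import torch


def charbonnier_composition_loss(pred_alpha: torch.Tensor, fg: torch.Tensor,
                                 bg: torch.Tensor, ori_merged: torch.Tensor,
                                 weight: Optional[torch.Tensor] = None,
                                 loss_weight: float = 1.0,
                                 eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical sketch of a Charbonnier composition loss (reduction='mean')."""
    # Re-composite with the predicted alpha (broadcast over the RGB channels).
    pred_merged = pred_alpha * fg + (1.0 - pred_alpha) * bg
    # Charbonnier penalty: a smooth stand-in for |pred_merged - ori_merged|.
    penalty = torch.sqrt((pred_merged - ori_merged) ** 2 + eps)
    if weight is None:
        return loss_weight * penalty.mean()
    # Pixel-wise mean over the masked region only.
    weight = weight.expand_as(penalty)
    return loss_weight * (penalty * weight).sum() / (weight.sum() + 1e-12)


alpha = torch.zeros(1, 1, 2, 2)
fg, bg = torch.ones(1, 3, 2, 2), torch.zeros(1, 3, 2, 2)
target = torch.zeros(1, 3, 2, 2)
target[..., 0, 0] = 3.0  # residual of 3 at a single pixel
mask = torch.zeros(1, 1, 2, 2)
mask[..., 0, 0] = 1.0
print(charbonnier_composition_loss(alpha, fg, bg, target, mask).item())  # ≈ 3.0
```

One consequence of `eps`: even a perfect prediction yields a loss of about sqrt(eps) (1e-6 with the default), not exactly zero.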