
mmagic.models.losses.loss_wrapper

Module Contents

Functions

reduce_loss(→ torch.Tensor)

Reduce loss as specified.

mask_reduce_loss(→ torch.Tensor)

Apply element-wise weight and reduce loss.

masked_loss(loss_func)

Create a masked version of a given loss function.

mmagic.models.losses.loss_wrapper.reduce_loss(loss: torch.Tensor, reduction: str) → torch.Tensor

Reduce loss as specified.

Parameters
  • loss (Tensor) – Element-wise loss tensor.

  • reduction (str) – Options are 'none', 'mean' and 'sum'.

Returns

Reduced loss tensor.

Return type

Tensor
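
For reference, a minimal doctest-style sketch of the three reduction modes (the printed values are the expected results, written by hand rather than captured from a session):

>>> import torch
>>> from mmagic.models.losses.loss_wrapper import reduce_loss
>>> loss = torch.Tensor([1., 1., 2.])
>>> reduce_loss(loss, 'none')
tensor([1., 1., 2.])
>>> reduce_loss(loss, 'mean')
tensor(1.3333)
>>> reduce_loss(loss, 'sum')
tensor(4.)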

mmagic.models.losses.loss_wrapper.mask_reduce_loss(loss: torch.Tensor, weight: Optional[torch.Tensor] = None, reduction: str = 'mean', sample_wise: bool = False) → torch.Tensor

Apply element-wise weight and reduce loss.

Parameters
  • loss (Tensor) – Element-wise loss.

  • weight (Tensor, optional) – Element-wise weights. Default: None.

  • reduction (str) – Same as the built-in losses of PyTorch. Options are 'none', 'mean' and 'sum'. Default: 'mean'.

  • sample_wise (bool) – Whether to calculate the loss sample-wise. This argument only takes effect when reduction is 'mean' and weight (argument of forward()) is not None. If True, the loss is first reduced with 'mean' per sample and then averaged over all samples. Default: False.

Returns

Processed loss values.

Return type

Tensor
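
A short illustrative sketch of how weight and sample_wise interact, assuming 4-D (N, C, H, W) tensors; the printed values are the expected results computed by hand. The mask zeroes out the second sample, so the plain masked mean averages only over the kept elements, while sample_wise first averages within each sample and then across the two samples:

>>> import torch
>>> from mmagic.models.losses.loss_wrapper import mask_reduce_loss
>>> loss = torch.ones(2, 1, 2, 2)   # two samples, element-wise loss of 1
>>> weight = torch.ones(2, 1, 2, 2)
>>> weight[1, ...] = 0              # mask out the second sample entirely
>>> mask_reduce_loss(loss, weight, reduction='mean')
tensor(1.)
>>> mask_reduce_loss(loss, weight, reduction='mean', sample_wise=True)
tensor(0.5000)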

mmagic.models.losses.loss_wrapper.masked_loss(loss_func)

Create a masked version of a given loss function.

To use this decorator, the loss function must have the signature loss_func(pred, target, **kwargs) and only needs to compute the element-wise loss without any reduction. The decorator adds weight and reduction arguments, so the decorated function has the signature loss_func(pred, target, weight=None, reduction='mean', avg_factor=None, **kwargs).

Example

>>> import torch
>>> @masked_loss
... def l1_loss(pred, target):
...     return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.5000)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, reduction='sum')
tensor(3.)