# Source code for mmagic.models.losses.loss_wrapper

# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Optional

import torch
import torch.nn.functional as F

def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """Reduce an element-wise loss tensor as requested.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Returns:
        Tensor: Reduced loss tensor.
    """
    # PyTorch's internal helper maps the string to an enum
    # (none: 0, elementwise_mean: 1, sum: 2) and itself raises a
    # ValueError for any other string.
    mode = F._Reduction.get_enum(reduction)
    reducers = {
        0: lambda t: t,          # "none": pass through unchanged
        1: lambda t: t.mean(),   # "mean"
        2: lambda t: t.sum(),    # "sum"
    }
    if mode in reducers:
        return reducers[mode](loss)
    raise ValueError(f'reduction type {reduction} not supported')

def mask_reduce_loss(loss: torch.Tensor,
                     weight: Optional[torch.Tensor] = None,
                     reduction: str = 'mean',
                     sample_wise: bool = False) -> torch.Tensor:
    """Apply element-wise weight and reduce loss.

    NOTE(review): the ``def`` header line of this function was lost in the
    page extraction (the source jumped straight into the parameter list);
    it is restored here so the module parses again.

    Args:
        loss (Tensor): Element-wise loss, assumed NCHW-shaped in the
            'mean' branch (weight is summed over dims 1-3).
        weight (Tensor): Element-wise weights, same rank as ``loss`` with
            channel dim of size 1 or matching ``loss``. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            "none", "mean" and "sum". Default: 'mean'.
        sample_wise (bool): Whether calculate the loss sample-wise. This
            argument only takes effect when reduction is 'mean' and weight
            (argument of forward()) is not None. It will first reduces loss
            with 'mean' per-sample, and then it means over all the samples.
            Default: False.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over masked region
    elif reduction == 'mean':
        # expand weight from N1HW to NCHW
        if weight.size(1) == 1:
            weight = weight.expand_as(loss)
        # small value to prevent division by zero
        eps = 1e-12

        # perform sample-wise mean
        if sample_wise:
            weight = weight.sum(dim=[1, 2, 3], keepdim=True)  # NCHW to N111
            loss = (loss / (weight + eps)).sum() / weight.size(0)
        # perform pixel-wise mean
        else:
            loss = loss.sum() / (weight.sum() + eps)

    return loss

"""Create a masked version of a given loss function.

To use this decorator, the loss function must have the signature like
loss_func(pred, target, **kwargs). The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs).

:Example:

>>> import torch
>>> def l1_loss(pred, target):
>>>     return (pred - target).abs()

>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])

>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.5000)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, reduction='sum')
tensor(3.)
"""

@functools.wraps(loss_func)
def wrapper(pred: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
reduction: str = 'mean',
sample_wise: bool = False,
**kwargs) -> torch.Tensor:
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = mask_reduce_loss(loss, weight, reduction, sample_wise)
return loss

return wrapper


# © Copyright 2023, MMagic Authors. Revision 0a560bba.