# Source code for mmagic.models.losses.feature_loss
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional
import torch
import torch.nn as nn
from mmengine import MMLogger
from mmengine.runner import load_checkpoint
from mmagic.models.editors.dic import LightCNN
from mmagic.registry import MODELS
class LightCNNFeature(nn.Module):
    """Feature extractor taken from the ``features`` part of LightCNN.

    It is used to train DICGAN. Every parameter of the extractor is frozen
    (``requires_grad_(False)``), so it only provides features and is not
    updated during training.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): ``3`` is presumably the input channel count of
        # LightCNN — confirm against its constructor.
        model = LightCNN(3)
        # Keep only the ``features`` sub-network; the rest of ``model`` is
        # discarded when this local goes out of scope.
        self.features = nn.Sequential(*model.features.children())
        self.features.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output of the LightCNN feature sub-network.
        """
        return self.features(x)

    def init_weights(self,
                     pretrained: Optional[str] = None,
                     strict: bool = True) -> None:
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
@MODELS.register_module()
class LightCNNFeatureLoss(nn.Module):
    """Feature loss of DICGAN, based on LightCNN.

    Features of the prediction and of the ground truth are extracted with a
    frozen ``LightCNNFeature`` model, compared with an L1 or MSE criterion,
    and the result is scaled by ``loss_weight``.

    Args:
        pretrained (str): Path for pretrained weights. ``None`` is accepted
            but triggers a warning, and no pretrained weights are loaded.
        loss_weight (float): Loss weight. Default: 1.0.
        criterion (str): Criterion type. Options are 'l1' and 'mse'.
            Default: 'l1'.

    Raises:
        ValueError: If ``criterion`` is neither 'l1' nor 'mse'.
        TypeError: If ``pretrained`` is neither a str nor None (raised by
            ``LightCNNFeature.init_weights``).
    """

    def __init__(self,
                 pretrained: Optional[str],
                 loss_weight: float = 1.0,
                 criterion: str = 'l1') -> None:
        super().__init__()
        # Validate the criterion first so that a typo fails fast, before the
        # feature extractor is built and a checkpoint is loaded.
        if criterion == 'l1':
            loss_fn = torch.nn.L1Loss()
        elif criterion == 'mse':
            loss_fn = torch.nn.MSELoss()
        else:
            raise ValueError("'criterion' should be 'l1' or 'mse', "
                             f'but got {criterion}')

        self.model = LightCNNFeature()
        if not isinstance(pretrained, str):
            warnings.warn('`LightCNNFeature` model in FeatureLoss ' +
                          'should be pretrained')
        self.model.init_weights(pretrained)
        self.model.eval()
        self.loss_weight = loss_weight
        # Registered after ``self.model`` to keep the original submodule order.
        self.criterion = loss_fn

    def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted tensor.
            gt (Tensor): GT tensor.

        Returns:
            Tensor: Weighted feature loss between ``pred`` and ``gt``.
        """
        # ``nn.Module.train()`` on a parent recursively flips submodules to
        # training mode; force the extractor back to eval mode every call.
        self.model.eval()
        pred_feature = self.model(pred)
        # Target features never need gradients; computing them under
        # ``no_grad`` avoids recording an autograd graph only to detach it.
        with torch.no_grad():
            gt_feature = self.model(gt)
        feature_loss = self.criterion(pred_feature, gt_feature)
        return feature_loss * self.loss_weight