mmagic.models.editors.tdan.tdan 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from mmagic.models import BaseEditModel
from mmagic.registry import MODELS

class TDAN(BaseEditModel):
    """TDAN model for video super-resolution.

    Paper:
        TDAN: Temporally-Deformable Alignment Network for Video Super-
        Resolution, CVPR, 2020

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        lq_pixel_loss (dict): Config for pixel-wise loss for the LQ images.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
    """

    def __init__(self,
                 generator,
                 pixel_loss,
                 lq_pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 data_preprocessor=None):
        super().__init__(
            generator=generator,
            pixel_loss=pixel_loss,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor)

        # ``lq_pixel_loss`` arrives as a config dict but is called as a loss
        # in ``forward_train``, so it is built through the registry here.
        self.lq_pixel_loss = MODELS.build(lq_pixel_loss)

    def forward_train(self, inputs, data_samples=None, **kwargs):
        """Forward training. Returns dict of losses of training.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`. It is indexed with five
                dimensions, and dimension 1 is sliced per frame.
            data_samples (List[BaseDataElement], optional): data samples
                collated by :attr:`data_preprocessor`. Although the default
                is None, ``data_samples.gt_img`` is read unconditionally, so
                a value exposing ``gt_img`` must be supplied.

        Returns:
            dict: Dict of losses with keys ``loss_pix`` and ``loss_pix_lq``.
        """
        feats, aligned_img = self.forward_tensor(
            inputs, data_samples, training=True, **kwargs)
        batch_gt_data = data_samples.gt_img

        losses = dict()

        # loss on the HR image
        losses['loss_pix'] = self.pixel_loss(feats, batch_gt_data)

        # loss on the aligned LR images: every aligned frame is compared
        # against the middle input frame (index ``t // 2``), which is
        # broadcast along dimension 1 to match the ``t`` aligned frames.
        t = aligned_img.size(1)
        center = t // 2
        lq_ref = inputs[:, center:center + 1, :, :, :].expand(
            -1, t, -1, -1, -1)
        losses['loss_pix_lq'] = self.lq_pixel_loss(aligned_img, lq_ref)

        return losses

    def forward_tensor(self,
                       inputs,
                       data_samples=None,
                       training=False,
                       **kwargs):
        """Forward tensor. Returns result of simple forward.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (List[BaseDataElement], optional): data samples
                collated by :attr:`data_preprocessor`. Not used by this
                method.
            training (bool): Whether is training. Default: False.

        Returns:
            (Tensor | List[Tensor]): results of forward inference and
            forward train. When ``training`` is True the full generator
            output is returned; otherwise only its first element.
        """
        outputs = self.generator(inputs, **kwargs)

        # ``forward_train`` unpacks the full output as
        # ``(feats, aligned_img)``; inference only needs the first element.
        return outputs if training else outputs[0]
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.