Source code for mmagic.models.editors.basicvsr.basicvsr

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmagic.models import BaseEditModel
from mmagic.registry import MODELS
from mmagic.structures import DataSample

[docs]class BasicVSR(BaseEditModel): """BasicVSR model for video super-resolution. Note that this model is used for IconVSR. Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021 Args: generator (dict): Config for the generator structure. pixel_loss (dict): Config for pixel-wise loss. ensemble (dict): Config for ensemble. Default: None. train_cfg (dict): Config for training. Default: None. test_cfg (dict): Config for testing. Default: None. init_cfg (dict, optional): The weight initialized config for :class:`BaseModule`. data_preprocessor (dict, optional): The pre-process config of :class:`BaseDataPreprocessor`. """ def __init__(self, generator, pixel_loss, ensemble=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None): super().__init__( generator=generator, pixel_loss=pixel_loss, train_cfg=train_cfg, test_cfg=test_cfg, init_cfg=init_cfg, data_preprocessor=data_preprocessor) # fix pre-trained networks self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0 self.is_weight_fixed = False # count training steps self.register_buffer('step_counter', torch.zeros(1)) # ensemble self.forward_ensemble = None if ensemble is not None: if ensemble['type'] == 'SpatialTemporalEnsemble': from mmagic.models.archs import SpatialTemporalEnsemble is_temporal = ensemble.get('is_temporal_ensemble', False) self.forward_ensemble = SpatialTemporalEnsemble(is_temporal) else: raise NotImplementedError( 'Currently support only ' '"SpatialTemporalEnsemble", but got type ' f'[{ensemble["type"]}]')
[docs] def check_if_mirror_extended(self, lrs): """Check whether the input is a mirror-extended sequence. If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame. Args: lrs (tensor): Input LR images with shape (n, t, c, h, w) """ is_mirror_extended = False if lrs.size(1) % 2 == 0: lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1) if torch.norm(lrs_1 - lrs_2.flip(1)) == 0: is_mirror_extended = True return is_mirror_extended
[docs] def forward_train(self, inputs, data_samples=None, **kwargs): """Forward training. Returns dict of losses of training. Args: inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. data_samples (List[BaseDataElement], optional): data samples collated by :attr:`data_preprocessor`. Returns: dict: Dict of losses. """ # fix SPyNet and EDVR at the beginning if self.step_counter < self.fix_iter: if not self.is_weight_fixed: self.is_weight_fixed = True for k, v in self.generator.named_parameters(): if 'spynet' in k or 'edvr' in k: v.requires_grad_(False) elif self.step_counter == self.fix_iter: # train all the parameters self.generator.requires_grad_(True) feats = self.forward_tensor(inputs, data_samples, **kwargs) batch_gt_data = data_samples.gt_img loss = self.pixel_loss(feats, batch_gt_data) self.step_counter += 1 return dict(loss=loss)
[docs] def forward_inference(self, inputs, data_samples=None, **kwargs): """Forward inference. Returns predictions of validation, testing. Args: inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. data_samples (List[BaseDataElement], optional): data samples collated by :attr:`data_preprocessor`. Returns: DataSample: predictions. """ feats = self.forward_tensor(inputs, data_samples, **kwargs) # feats.shape = [b, t, c, h, w] feats = self.data_preprocessor.destruct(feats, data_samples) # If the GT is an image (i.e. the center frame), the output sequence is # turned to an image. gt = data_samples.gt_img[0] if gt is not None and == 3: t = feats.size(1) if self.check_if_mirror_extended(inputs): # with mirror extension feats = 0.5 * (feats[:, t // 4] + feats[:, -1 - t // 4]) else: # without mirror extension feats = feats[:, t // 2] # create a stacked data sample predictions = DataSample( pred_img=feats.cpu(), metainfo=data_samples.metainfo) return predictions
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.