Source code for mmagic.models.editors.iconvsr.iconvsr_net

# Copyright (c) OpenMMLab. All rights reserved.
from logging import WARNING

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine import MMLogger, print_log
from mmengine.model import BaseModule
from mmengine.runner import load_checkpoint

from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import flow_warp, make_layer
from mmagic.registry import MODELS
from ..basicvsr.basicvsr_net import ResidualBlocksWithInputConv, SPyNet
from ..edvr.edvr_net import PCDAlignment, TSAFusion


@MODELS.register_module()
class IconVSRNet(BaseModule):
    """IconVSR network structure for video super-resolution.

    Supports only x4 upsampling.

    Paper:
        BasicVSR: The Search for Essential Components in Video
        Super-Resolution and Beyond, CVPR, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 64.
        num_blocks (int): Number of residual blocks in each propagation
            branch. Default: 30.
        keyframe_stride (int): Number determining the keyframes. If stride=5,
            then the (0, 5, 10, 15, ...)-th frames will be the keyframes.
            Default: 5.
        padding (int): Number of frames to be padded at two ends of the
            sequence. 2 for REDS and 3 for Vimeo-90K. Default: 2.
        spynet_pretrained (str): Pre-trained model path of SPyNet.
            Default: None.
        edvr_pretrained (str): Pre-trained model path of EDVR (for refill).
            Default: None.
    """

    def __init__(self,
                 mid_channels=64,
                 num_blocks=30,
                 keyframe_stride=5,
                 padding=2,
                 spynet_pretrained=None,
                 edvr_pretrained=None):
        super().__init__()

        self.mid_channels = mid_channels
        self.padding = padding
        self.keyframe_stride = keyframe_stride

        # optical flow network for alignment
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # information-refill
        self.edvr = EDVRFeatureExtractor(
            num_frames=padding * 2 + 1,
            center_frame_idx=padding,
            pretrained=edvr_pretrained)
        self.backward_fusion = nn.Conv2d(
            2 * mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.forward_fusion = nn.Conv2d(
            2 * mid_channels, mid_channels, 3, 1, 1, bias=True)

        # propagation branches
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            2 * mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, 64, 2, upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self._raised_warning = False
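For orientation, a minimal usage sketch (it assumes the module imports above; no pretrained paths are given, so the weights are random and the output is useful only for shape checking):

    model = IconVSRNet(mid_channels=64, num_blocks=30).eval()
    lrs = torch.rand(1, 5, 3, 64, 64)  # (n, t, c, h, w)
    with torch.no_grad():
        srs = model(lrs)
    print(srs.shape)  # torch.Size([1, 5, 3, 256, 256]), i.e. x4 upsampling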
    def spatial_padding(self, lrs):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a
        multiple of 4, we apply padding to the input LR images if their
        resolution is not divisible by 4.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = lrs.size()

        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding
        lrs = lrs.view(-1, c, h, w)
        lrs = F.pad(lrs, [0, pad_w, 0, pad_h], mode='reflect')

        return lrs.view(n, t, c, h + pad_h, w + pad_w)
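The (4 - h % 4) % 4 idiom pads only when needed; a standalone check of the arithmetic:

    for h in (64, 70, 181):
        pad_h = (4 - h % 4) % 4
        print(h, '->', h + pad_h)  # 64 -> 64, 70 -> 72, 181 -> 184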
    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)
        """
        self.is_mirror_extended = False
        if lrs.size(1) % 2 == 0:
            lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1)
            if torch.norm(lrs_1 - lrs_2.flip(1)) == 0:
                self.is_mirror_extended = True
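A quick illustration of the property being tested: a clip concatenated with its temporal reflection passes the check.

    clip = torch.rand(1, 3, 3, 16, 16)
    mirrored = torch.cat([clip, clip.flip(1)], dim=1)  # frames 0,1,2,2,1,0
    half_1, half_2 = torch.chunk(mirrored, 2, dim=1)
    print(torch.norm(half_1 - half_2.flip(1)) == 0)  # tensor(True)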
    def compute_refill_features(self, lrs, keyframe_idx):
        """Compute keyframe features for information-refill.

        Since EDVR-M is used, padding is performed before feature
        computation.

        Args:
            lrs (Tensor): Input LR images with shape (n, t, c, h, w)
            keyframe_idx (list(int)): The indices specifying the keyframes.

        Returns:
            dict(Tensor): The keyframe features. Each key corresponds to an
                index in keyframe_idx.
        """
        if self.padding == 2:
            lrs = [lrs[:, [4, 3]], lrs, lrs[:, [-4, -5]]]  # padding
        elif self.padding == 3:
            lrs = [lrs[:, [6, 5, 4]], lrs, lrs[:, [-5, -6, -7]]]  # padding
        lrs = torch.cat(lrs, dim=1)

        num_frames = 2 * self.padding + 1
        feats_refill = {}
        for i in keyframe_idx:
            feats_refill[i] = self.edvr(lrs[:, i:i + num_frames].contiguous())
        return feats_refill
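The index juggling is easier to see on a toy tensor that stores its own frame index (a sketch for padding=2 and t=7):

    idx = torch.arange(7, dtype=torch.float).view(1, 7, 1, 1, 1)
    padded = torch.cat([idx[:, [4, 3]], idx, idx[:, [-4, -5]]], dim=1)
    print(padded.flatten().tolist())
    # [4.0, 3.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 2.0]

Each keyframe i then reads the 5-frame window padded[:, i:i + 5], whose center is the original frame i, matching center_frame_idx=padding.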
    def compute_flow(self, lrs):
        """Compute optical flow using SPyNet for feature warping.

        Note that if the input is a mirror-extended sequence,
        'flows_forward' is not needed, since it is equal to
        'flows_backward.flip(1)'.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Returns:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to
                previous). 'flows_backward' corresponds to the flows used
                for backward-time propagation (current to next).
        """
        n, t, c, h, w = lrs.size()
        lrs_1 = lrs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lrs_2 = lrs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lrs_1, lrs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = None
        else:
            flows_forward = self.spynet(lrs_2, lrs_1).view(n, t - 1, 2, h, w)

        return flows_forward, flows_backward
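The flow conventions, as a shape sketch (random weights, so only shapes are meaningful):

    model = IconVSRNet()
    lrs = torch.rand(1, 5, 3, 64, 64)
    model.check_if_mirror_extended(lrs)  # compute_flow reads this flag
    flows_forward, flows_backward = model.compute_flow(lrs)
    print(flows_backward.shape)  # torch.Size([1, 4, 2, 64, 64])
    # flows_backward[:, i]: flow from frame i to i+1, used to warp frame
    # i+1's feature back onto frame i during backward propagation.
    # flows_forward[:, i]: flow from frame i+1 to i, used to warp frame
    # i's feature onto frame i+1 during forward propagation.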
    def forward(self, lrs):
        """Forward function for IconVSR.

        Args:
            lrs (Tensor): Input LR tensor with shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR tensor with shape (n, t, c, 4h, 4w).
        """
        n, t, c, h_input, w_input = lrs.size()

        if (h_input < 64 or w_input < 64) and not self._raised_warning:
            print_log(
                f'{self.__class__.__name__} is designed for input '
                'larger than 64x64, but the resolution of the current '
                f'image is {h_input}x{w_input}. We recommend checking '
                'your input.', 'current', WARNING)
            self._raised_warning = True

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lrs)

        lrs = self.spatial_padding(lrs)
        h, w = lrs.size(3), lrs.size(4)

        # get the keyframe indices for information-refill
        keyframe_idx = list(range(0, t, self.keyframe_stride))
        if keyframe_idx[-1] != t - 1:
            keyframe_idx.append(t - 1)  # the last frame must be a keyframe

        # compute optical flow and compute features for information-refill
        flows_forward, flows_backward = self.compute_flow(lrs)
        feats_refill = self.compute_refill_features(lrs, keyframe_idx)

        # backward-time propagation
        outputs = []
        feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
        for i in range(t - 1, -1, -1):
            lr_curr = lrs[:, i, :, :, :]
            if i < t - 1:  # no warping for the last timestep
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_idx:
                feat_prop = torch.cat([feat_prop, feats_refill[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)
            feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
            feat_prop = self.backward_resblocks(feat_prop)
            outputs.append(feat_prop)
        outputs = outputs[::-1]

        # forward-time propagation and upsampling
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, t):
            lr_curr = lrs[:, i, :, :, :]
            if i > 0:  # no warping for the first timestep
                if flows_forward is not None:
                    flow = flows_forward[:, i - 1, :, :, :]
                else:
                    flow = flows_backward[:, -i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_idx:  # information-refill
                feat_prop = torch.cat([feat_prop, feats_refill[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            feat_prop = torch.cat([lr_curr, outputs[i], feat_prop], dim=1)
            feat_prop = self.forward_resblocks(feat_prop)

            out = self.lrelu(self.upsample1(feat_prop))
            out = self.lrelu(self.upsample2(out))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = self.img_upsample(lr_curr)
            out += base
            outputs[i] = out

        return torch.stack(
            outputs, dim=1)[:, :, :, :4 * h_input, :4 * w_input]
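The keyframe bookkeeping from the forward pass, run standalone for t=10 and the default stride of 5:

    t, stride = 10, 5
    keyframe_idx = list(range(0, t, stride))
    if keyframe_idx[-1] != t - 1:
        keyframe_idx.append(t - 1)
    print(keyframe_idx)  # [0, 5, 9]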
class EDVRFeatureExtractor(BaseModule):
    """EDVR feature extractor for information-refill in IconVSR.

    We use EDVR-M in IconVSR. To adopt pretrained models, please
    specify "pretrained".

    Paper:
        EDVR: Video Restoration with Enhanced Deformable Convolutional
        Networks.

    Args:
        in_channels (int): Channel number of inputs. Default: 3.
        out_channels (int): Channel number of outputs. Default: 3.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_frames (int): Number of input frames. Default: 5.
        deform_groups (int): Deformable groups. Default: 8.
        num_blocks_extraction (int): Number of blocks for feature extraction.
            Default: 5.
        num_blocks_reconstruction (int): Number of blocks for reconstruction.
            Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting
            from 0. Default: 2.
        with_tsa (bool): Whether to use TSA module. Default: True.
        pretrained (str): The pretrained model path. Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_channels=64,
                 num_frames=5,
                 deform_groups=8,
                 num_blocks_extraction=5,
                 num_blocks_reconstruction=10,
                 center_frame_idx=2,
                 with_tsa=True,
                 pretrained=None):
        super().__init__()

        self.center_frame_idx = center_frame_idx
        self.with_tsa = with_tsa
        act_cfg = dict(type='LeakyReLU', negative_slope=0.1)

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
        self.feature_extraction = make_layer(
            ResidualBlockNoBN,
            num_blocks_extraction,
            mid_channels=mid_channels)

        # generate pyramid features
        self.feat_l2_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l2_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)
        self.feat_l3_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l3_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)

        # pcd alignment
        self.pcd_alignment = PCDAlignment(
            mid_channels=mid_channels, deform_groups=deform_groups)

        # fusion
        if self.with_tsa:
            self.fusion = TSAFusion(
                mid_channels=mid_channels,
                num_frames=num_frames,
                center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frames * mid_channels, mid_channels,
                                    1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=True, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
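For reference, IconVSRNet above instantiates this extractor with a window derived from its padding argument:

    edvr = EDVRFeatureExtractor(
        num_frames=2 * 2 + 1,   # padding=2 -> 5-frame windows
        center_frame_idx=2,     # the keyframe sits at the window center
        pretrained=None)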
    def forward(self, x):
        """Forward function for EDVRFeatureExtractor.

        Args:
            x (Tensor): Input tensor with shape (n, t, 3, h, w).

        Returns:
            Tensor: Intermediate feature with shape (n, mid_channels, h, w).
        """
        n, t, c, h, w = x.size()

        # extract LR features
        # L1
        l1_feat = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
        l1_feat = self.feature_extraction(l1_feat)
        # L2
        l2_feat = self.feat_l2_conv2(self.feat_l2_conv1(l1_feat))
        # L3
        l3_feat = self.feat_l3_conv2(self.feat_l3_conv1(l2_feat))

        l1_feat = l1_feat.view(n, t, -1, h, w)
        l2_feat = l2_feat.view(n, t, -1, h // 2, w // 2)
        l3_feat = l3_feat.view(n, t, -1, h // 4, w // 4)

        # pcd alignment
        ref_feats = [  # reference feature list
            l1_feat[:, self.center_frame_idx, :, :, :].clone(),
            l2_feat[:, self.center_frame_idx, :, :, :].clone(),
            l3_feat[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            neighbor_feats = [
                l1_feat[:, i, :, :, :].clone(),
                l2_feat[:, i, :, :, :].clone(),
                l3_feat[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_alignment(neighbor_feats, ref_feats))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (n, t, c, h, w)

        if self.with_tsa:
            feat = self.fusion(aligned_feat)
        else:
            aligned_feat = aligned_feat.view(n, -1, h, w)
            feat = self.fusion(aligned_feat)

        return feat
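A minimal forward sketch (random weights; h and w must be multiples of 4 for PCD alignment, which is why IconVSRNet applies spatial_padding first):

    edvr = EDVRFeatureExtractor(num_frames=5, center_frame_idx=2).eval()
    clip = torch.rand(1, 5, 3, 64, 64)
    with torch.no_grad():
        feat = edvr(clip)
    print(feat.shape)  # torch.Size([1, 64, 64, 64]), (n, mid_channels, h, w)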