# Source code for mmagic.models.editors.iconvsr.iconvsr_net
# Copyright (c) OpenMMLab. All rights reserved.
from logging import WARNING
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine import MMLogger, print_log
from mmengine.model import BaseModule
from mmengine.runner import load_checkpoint
from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import flow_warp, make_layer
from mmagic.registry import MODELS
from ..basicvsr.basicvsr_net import ResidualBlocksWithInputConv, SPyNet
from ..edvr.edvr_net import PCDAlignment, TSAFusion
@MODELS.register_module()
class IconVSRNet(BaseModule):
    """IconVSR network structure for video super-resolution.

    Support only x4 upsampling.

    Paper:
        BasicVSR: The Search for Essential Components in Video Super-Resolution
        and Beyond, CVPR, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 64.
        num_blocks (int): Number of residual blocks in each propagation branch.
            Default: 30.
        keyframe_stride (int): Number determining the keyframes. If stride=5,
            then the (0, 5, 10, 15, ...)-th frame will be the keyframes. The
            last frame of the sequence is always added as a keyframe too.
            Default: 5.
        padding (int): Number of frames to be padded at two ends of the
            sequence. 2 for REDS and 3 for Vimeo-90K. The refill feature
            extractor consumes ``2 * padding + 1`` frames per keyframe, so the
            input sequence must contain at least that many frames.
            Default: 2.
        spynet_pretrained (str): Pre-trained model path of SPyNet.
            Default: None.
        edvr_pretrained (str): Pre-trained model path of EDVR (for refill).
            Default: None.
    """

    def __init__(self,
                 mid_channels=64,
                 num_blocks=30,
                 keyframe_stride=5,
                 padding=2,
                 spynet_pretrained=None,
                 edvr_pretrained=None):

        super().__init__()

        self.mid_channels = mid_channels
        self.padding = padding
        self.keyframe_stride = keyframe_stride

        # Updated by ``check_if_mirror_extended`` on every forward pass.
        # Initialised here so that ``compute_flow`` does not raise an
        # AttributeError when it is called before ``forward``.
        self.is_mirror_extended = False

        # optical flow network for alignment
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # information-refill: EDVR sees a window of ``2 * padding + 1``
        # frames whose centre is the keyframe
        self.edvr = EDVRFeatureExtractor(
            num_frames=padding * 2 + 1,
            center_frame_idx=padding,
            pretrained=edvr_pretrained)

        # fuse propagated features with refill features at keyframes
        self.backward_fusion = nn.Conv2d(
            2 * mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.forward_fusion = nn.Conv2d(
            2 * mid_channels, mid_channels, 3, 1, 1, bias=True)

        # propagation branches; the forward branch additionally consumes the
        # features produced by the backward branch
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            2 * mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, 64, 2, upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # makes the small-input warning in ``forward`` fire only once
        self._raised_warning = False

    def spatial_padding(self, lrs):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = lrs.size()
        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # reflect-pad only at the bottom and right edges; ``forward`` crops
        # the corresponding area away from the output
        lrs = lrs.view(-1, c, h, w)
        lrs = F.pad(lrs, [0, pad_w, 0, pad_h], mode='reflect')

        return lrs.view(n, t, c, h + pad_h, w + pad_w)

    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame. The result is stored in ``self.is_mirror_extended``.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)
        """
        self.is_mirror_extended = False
        # only an even-length sequence can be split into two mirrored halves
        if lrs.size(1) % 2 == 0:
            lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1)
            if torch.norm(lrs_1 - lrs_2.flip(1)) == 0:
                self.is_mirror_extended = True

    def compute_refill_features(self, lrs, keyframe_idx):
        """Compute keyframe features for information-refill.

        Since EDVR-M is used, padding is performed before feature computation.
        Frames are padded in time with 'reflection circle' order: for
        ``padding == 2`` the frames ``[4, 3]`` are prepended and ``[-4, -5]``
        appended; for ``padding == 3`` the frames ``[6, 5, 4]`` are prepended
        and ``[-5, -6, -7]`` appended. Other padding values follow the same
        pattern. The sequence therefore needs at least ``2 * padding + 1``
        frames.

        Args:
            lrs (Tensor): Input LR images with shape (n, t, c, h, w)
            keyframe_idx (list(int)): The indices specifying the keyframes.

        Return:
            dict(Tensor): The keyframe features. Each key corresponds to the
                indices in keyframe_idx.
        """
        p = self.padding
        if p > 0:
            head_idx = list(range(2 * p, p, -1))  # p=2 -> [4, 3]
            tail_idx = list(range(-(p + 2), -(2 * p + 2), -1))  # -> [-4, -5]
            lrs = torch.cat([lrs[:, head_idx], lrs, lrs[:, tail_idx]], dim=1)

        # after temporal padding, original frame ``i`` is the centre of the
        # window ``lrs[:, i:i + num_frames]``
        num_frames = 2 * p + 1
        feats_refill = {}
        for i in keyframe_idx:
            feats_refill[i] = self.edvr(lrs[:, i:i + num_frames].contiguous())

        return feats_refill

    def compute_flow(self, lrs):
        """Compute optical flow using SPyNet for feature warping.

        Note that if the input is an mirror-extended sequence, 'flows_forward'
        is not needed, since it is equal to 'flows_backward.flip(1)'.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Return:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to previous).
                'flows_backward' corresponds to the flows used for
                backward-time propagation (current to next). 'flows_forward'
                is None for mirror-extended input.
        """
        n, t, c, h, w = lrs.size()
        lrs_1 = lrs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lrs_2 = lrs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lrs_1, lrs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = None
        else:
            flows_forward = self.spynet(lrs_2, lrs_1).view(n, t - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, lrs):
        """Forward function for IconVSR.

        Args:
            lrs (Tensor): Input LR tensor with shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR tensor with shape (n, t, c, 4h, 4w).
        """
        n, t, _, h_input, w_input = lrs.size()

        if (h_input < 64 or w_input < 64) and not self._raised_warning:
            print_log(
                f'{self.__class__.__name__} is designed for input '
                'larger than 64x64, but the resolution of current image '
                f'is {h_input}x{w_input}. We recommend you to check your '
                'input.', 'current', WARNING)
            self._raised_warning = True

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lrs)

        lrs = self.spatial_padding(lrs)
        h, w = lrs.size(3), lrs.size(4)

        # get the keyframe indices for information-refill
        keyframe_idx = list(range(0, t, self.keyframe_stride))
        if keyframe_idx[-1] != t - 1:
            keyframe_idx.append(t - 1)  # the last frame must be a keyframe

        # compute optical flow and compute features for information-refill;
        # ``feats_refill`` is keyed by keyframe index, so membership tests
        # below are O(1)
        flows_forward, flows_backward = self.compute_flow(lrs)
        feats_refill = self.compute_refill_features(lrs, keyframe_idx)

        # backward-time propagation
        outputs = []
        feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
        for i in range(t - 1, -1, -1):
            lr_curr = lrs[:, i, :, :, :]
            if i < t - 1:  # no warping for the last timestep
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in feats_refill:  # information-refill
                feat_prop = torch.cat([feat_prop, feats_refill[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)

            feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
            feat_prop = self.backward_resblocks(feat_prop)

            outputs.append(feat_prop)
        outputs = outputs[::-1]  # restore temporal order

        # forward-time propagation and upsampling
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, t):
            lr_curr = lrs[:, i, :, :, :]
            if i > 0:  # no warping for the first timestep
                if flows_forward is not None:
                    flow = flows_forward[:, i - 1, :, :, :]
                else:
                    # mirror-extended input: reuse the backward flows
                    flow = flows_backward[:, -i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in feats_refill:  # information-refill
                feat_prop = torch.cat([feat_prop, feats_refill[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            feat_prop = torch.cat([lr_curr, outputs[i], feat_prop], dim=1)
            feat_prop = self.forward_resblocks(feat_prop)

            # reconstruct the HR frame and add the bilinearly upsampled input
            out = self.lrelu(self.upsample1(feat_prop))
            out = self.lrelu(self.upsample2(out))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = self.img_upsample(lr_curr)
            out += base
            outputs[i] = out

        # crop away the spatial padding added by ``spatial_padding``
        return torch.stack(outputs, dim=1)[:, :, :, :4 * h_input, :4 * w_input]
class EDVRFeatureExtractor(BaseModule):
    """EDVR feature extractor for information-refill in IconVSR.

    We use EDVR-M in IconVSR. To adopt pretrained models, please
    specify "pretrained".

    Paper:
        EDVR: Video Restoration with Enhanced Deformable Convolutional
        Networks.

    Args:
        in_channels (int): Channel number of inputs. Default: 3.
        out_channel (int): Channel number of outputs. Accepted for interface
            compatibility but not used by this module, whose output has
            ``mid_channels`` channels. Default: 3.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_frames (int): Number of input frames. Default: 5.
        deform_groups (int): Deformable groups. Default: 8.
        num_blocks_extraction (int): Number of blocks for feature extraction.
            Default: 5.
        num_blocks_reconstruction (int): Number of blocks for reconstruction.
            Accepted for interface compatibility but not used by this feature
            extractor. Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting from
            0. Default: 2.
        with_tsa (bool): Whether to use TSA module. Default: True.
        pretrained (str): The pretrained model path. Default: None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
    """

    def __init__(self,
                 in_channels=3,
                 out_channel=3,
                 mid_channels=64,
                 num_frames=5,
                 deform_groups=8,
                 num_blocks_extraction=5,
                 num_blocks_reconstruction=10,
                 center_frame_idx=2,
                 with_tsa=True,
                 pretrained=None):

        super().__init__()

        self.center_frame_idx = center_frame_idx
        self.with_tsa = with_tsa
        act_cfg = dict(type='LeakyReLU', negative_slope=0.1)

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
        self.feature_extraction = make_layer(
            ResidualBlockNoBN,
            num_blocks_extraction,
            mid_channels=mid_channels)

        # generate pyramid features (L2 and L3 each use a stride-2 conv)
        self.feat_l2_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l2_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)
        self.feat_l3_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l3_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)

        # pcd alignment
        self.pcd_alignment = PCDAlignment(
            mid_channels=mid_channels, deform_groups=deform_groups)

        # fusion: TSA, or a 1x1 conv over the channel-concatenated frames
        if self.with_tsa:
            self.fusion = TSAFusion(
                mid_channels=mid_channels,
                num_frames=num_frames,
                center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frames * mid_channels, mid_channels, 1,
                                    1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=True, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')

    def forward(self, x):
        """Forward function for EDVRFeatureExtractor.

        Args:
            x (Tensor): Input tensor with shape (n, t, 3, h, w). ``h`` and
                ``w`` must be divisible by 4 so that the two stride-2 pyramid
                levels have exactly (h // 2, w // 2) and (h // 4, w // 4)
                resolution.

        Returns:
            Tensor: Intermediate feature with shape (n, mid_channels, h, w).
        """
        n, t, c, h, w = x.size()

        # extract LR features
        # L1 (``reshape`` rather than ``view`` so non-contiguous inputs work)
        l1_feat = self.lrelu(self.conv_first(x.reshape(-1, c, h, w)))
        l1_feat = self.feature_extraction(l1_feat)
        # L2
        l2_feat = self.feat_l2_conv2(self.feat_l2_conv1(l1_feat))
        # L3
        l3_feat = self.feat_l3_conv2(self.feat_l3_conv1(l2_feat))

        l1_feat = l1_feat.view(n, t, -1, h, w)
        l2_feat = l2_feat.view(n, t, -1, h // 2, w // 2)
        l3_feat = l3_feat.view(n, t, -1, h // 4, w // 4)

        # pcd alignment: align every frame to the centre frame
        ref_feats = [  # reference feature list
            l1_feat[:, self.center_frame_idx, :, :, :].clone(),
            l2_feat[:, self.center_frame_idx, :, :, :].clone(),
            l3_feat[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            neighbor_feats = [
                l1_feat[:, i, :, :, :].clone(), l2_feat[:, i, :, :, :].clone(),
                l3_feat[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_alignment(neighbor_feats, ref_feats))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (n, t, c, h, w)

        if self.with_tsa:
            feat = self.fusion(aligned_feat)
        else:
            # merge the time axis into channels for the 1x1 fusion conv
            aligned_feat = aligned_feat.view(n, -1, h, w)
            feat = self.fusion(aligned_feat)

        return feat