Source code for mmagic.models.editors.basicvsr.basicvsr_net
# Copyright (c) OpenMMLab. All rights reserved.
from logging import WARNING
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine import MMLogger, print_log
from mmengine.model import BaseModule
from mmengine.runner import load_checkpoint
from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import flow_warp, make_layer
from mmagic.registry import MODELS
@MODELS.register_module()
class BasicVSRNet(BaseModule):
    """BasicVSR network structure for video super-resolution.

    Support only x4 upsampling.

    Paper:
        BasicVSR: The Search for Essential Components in Video
        Super-Resolution and Beyond, CVPR, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 64.
        num_blocks (int): Number of residual blocks in each propagation
            branch. Default: 30.
        spynet_pretrained (str): Pre-trained model path of SPyNet.
            Default: None.
    """

    def __init__(self, mid_channels=64, num_blocks=30, spynet_pretrained=None):
        super().__init__()

        self.mid_channels = mid_channels

        # optical flow network for feature alignment
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # propagation branches; the "+ 3" accounts for the LR frame that is
        # concatenated with the propagated feature at every timestep
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.fusion = nn.Conv2d(
            mid_channels * 2, mid_channels, 1, 1, 0, bias=True)
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, 64, 2, upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # the small-input warning in ``forward`` is emitted at most once
        self._raised_warning = False

        # Give the flag a defined value from construction time so that
        # ``compute_flow`` can be called on its own (without a preceding
        # ``check_if_mirror_extended``) instead of raising AttributeError.
        self.is_mirror_extended = False

    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame. The result is stored in
        ``self.is_mirror_extended``; nothing is returned.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)
        """
        self.is_mirror_extended = False
        # only a sequence with an even number of frames can be split into
        # two halves that mirror each other
        if lrs.size(1) % 2 == 0:
            lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1)
            if torch.norm(lrs_1 - lrs_2.flip(1)) == 0:
                self.is_mirror_extended = True

    def compute_flow(self, lrs):
        """Compute optical flow using SPyNet for feature warping.

        Note that if the input is an mirror-extended sequence,
        'flows_forward' is not needed, since it is equal to
        'flows_backward.flip(1)'.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Return:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to
                previous). 'flows_backward' corresponds to the flows used
                for backward-time propagation (current to next).
                'flows_forward' is None when ``self.is_mirror_extended`` is
                True.
        """
        n, t, c, h, w = lrs.size()
        # pair every frame with its successor and fold the time axis into
        # the batch axis so SPyNet processes all pairs in one call
        lrs_1 = lrs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lrs_2 = lrs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lrs_1, lrs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = None
        else:
            flows_forward = self.spynet(lrs_2, lrs_1).view(n, t - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, lrs):
        """Forward function for BasicVSR.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
        """
        n, t, c, h, w = lrs.size()
        if (h < 64 or w < 64) and not self._raised_warning:
            print_log(
                f'{self.__class__.__name__} is designed for input '
                'larger than 64x64, but the resolution of current image '
                f'is {h}x{w}. We recommend you to check your input.',
                'current', WARNING)
            self._raised_warning = True

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lrs)

        # compute optical flow
        flows_forward, flows_backward = self.compute_flow(lrs)

        # backward-time propagation (last frame to first frame)
        outputs = []
        feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
        for i in range(t - 1, -1, -1):
            if i < t - 1:  # no warping required for the last timestep
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            feat_prop = torch.cat([lrs[:, i, :, :, :], feat_prop], dim=1)
            feat_prop = self.backward_resblocks(feat_prop)

            outputs.append(feat_prop)
        # features were collected from the last frame to the first one;
        # restore chronological order so outputs[i] belongs to frame i
        outputs = outputs[::-1]

        # forward-time propagation and upsampling
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, t):
            lr_curr = lrs[:, i, :, :, :]
            if i > 0:  # no warping required for the first timestep
                if flows_forward is not None:
                    flow = flows_forward[:, i - 1, :, :, :]
                else:
                    # mirror-extended input: index -i of flows_backward is
                    # index i - 1 of flows_backward.flip(1)
                    flow = flows_backward[:, -i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
            feat_prop = self.forward_resblocks(feat_prop)

            # upsampling given the backward and forward features
            out = torch.cat([outputs[i], feat_prop], dim=1)
            out = self.lrelu(self.fusion(out))
            out = self.lrelu(self.upsample1(out))
            out = self.lrelu(self.upsample2(out))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            # the network output is added on top of the bilinearly
            # upsampled LR frame
            base = self.img_upsample(lr_curr)
            out += base
            # reuse the list slot: the backward feature is no longer needed
            outputs[i] = out

        return torch.stack(outputs, dim=1)
class ResidualBlocksWithInputConv(BaseModule):
    """A stack of residual blocks preceded by one input convolution.

    Args:
        in_channels (int): Number of input channels of the first conv.
        out_channels (int): Number of channels of the residual blocks.
            Default: 64.
        num_blocks (int): Number of residual blocks. Default: 30.
    """

    def __init__(self, in_channels, out_channels=64, num_blocks=30):
        super().__init__()

        # 3x3 convolution that maps the input to the channel count used by
        # the residual blocks
        input_conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True)
        activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # the residual trunk itself
        trunk = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=out_channels)

        # children keep the positions 0 (conv), 1 (activation), 2 (trunk)
        self.main = nn.Sequential(input_conv, activation, trunk)

    def forward(self, feat):
        """Run the input convolution followed by the residual blocks.

        Args:
            feat (Tensor): Input feature with shape (n, in_channels, h, w)

        Returns:
            Tensor: Output feature with shape (n, out_channels, h, w)
        """
        return self.main(feat)
class SPyNet(BaseModule):
    """SPyNet network structure.

    The difference to the SPyNet in [tof.py] is that
        1. more SPyNetBasicModule is used in this version, and
        2. no batch normalization is used in this version.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017

    Args:
        pretrained (str): path for pre-trained SPyNet. Default: None.
    """

    def __init__(self, pretrained):
        super().__init__()

        # one basic module per pyramid level (six levels in total)
        self.basic_module = nn.ModuleList(
            [SPyNetBasicModule() for _ in range(6)])

        if pretrained is not None:
            if not isinstance(pretrained, str):
                raise TypeError('[pretrained] should be str or None, '
                                f'but got {type(pretrained)}.')
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=True, logger=logger)

        # NOTE(review): the normalization buffers are registered only after
        # the checkpoint has been loaded; presumably this keeps the strict
        # load from requiring 'mean'/'std' keys -- confirm before reordering.
        self.register_buffer(
            'mean',
            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std',
            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def compute_flow(self, ref, supp):
        """Compute flow from ref to supp.

        Note that in this function, the images are already resized to a
        multiple of 32.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        n, _, h, w = ref.size()

        # normalize the input images; index 0 holds the full resolution
        ref_pyramid = [(ref - self.mean) / self.std]
        supp_pyramid = [(supp - self.mean) / self.std]

        # append five successively halved copies of each image
        for _ in range(5):
            ref_pyramid.append(
                F.avg_pool2d(
                    ref_pyramid[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
            supp_pyramid.append(
                F.avg_pool2d(
                    supp_pyramid[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))

        # process from the coarsest level to the finest one
        ref_pyramid.reverse()
        supp_pyramid.reverse()

        # start from a zero flow at the coarsest (1/32) resolution
        flow = ref_pyramid[0].new_zeros(n, 2, h // 32, w // 32)
        levels = zip(ref_pyramid, supp_pyramid)
        for level, (ref_level, supp_level) in enumerate(levels):
            if level > 0:
                # bring the previous estimate to this level's resolution;
                # flow magnitudes double together with the image size
                flow = F.interpolate(
                    flow,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=True) * 2.0

            warped_supp = flow_warp(
                supp_level, flow.permute(0, 2, 3, 1), padding_mode='border')
            residual = self.basic_module[level](
                torch.cat([ref_level, warped_supp, flow], 1))

            # add the residue to the upsampled flow
            flow = flow + residual

        return flow

    def forward(self, ref, supp):
        """Forward function of SPyNet.

        This function computes the optical flow from ref to supp.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        # upsize to a multiple of 32 (sizes already divisible stay as is)
        h, w = ref.shape[2:4]
        h_up = ((h + 31) // 32) * 32
        w_up = ((w + 31) // 32) * 32
        padded_size = (h_up, w_up)

        ref = F.interpolate(
            ref, size=padded_size, mode='bilinear', align_corners=False)
        supp = F.interpolate(
            supp, size=padded_size, mode='bilinear', align_corners=False)

        # compute flow, and resize back to the original resolution
        coarse_flow = self.compute_flow(ref, supp)
        flow = F.interpolate(
            coarse_flow, size=(h, w), mode='bilinear', align_corners=False)

        # adjust the flow values: channel 0 is scaled by the width ratio
        # and channel 1 by the height ratio
        flow[:, 0, :, :] *= float(w) / float(w_up)
        flow[:, 1, :, :] *= float(h) / float(h_up)

        return flow
class SPyNetBasicModule(BaseModule):
    """Basic Module for SPyNet.

    Five 7x7 convolutions with channel widths 8 -> 32 -> 64 -> 32 -> 16 -> 2.
    Every convolution except the last one is followed by a ReLU.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    """

    def __init__(self):
        super().__init__()

        # channel width before and after each of the five convolutions
        widths = (8, 32, 64, 32, 16, 2)
        last_index = len(widths) - 2

        convs = []
        for index, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # the final convolution outputs the flow and has no activation
            act_cfg = dict(type='ReLU') if index < last_index else None
            convs.append(
                ConvModule(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=7,
                    stride=1,
                    padding=3,
                    norm_cfg=None,
                    act_cfg=act_cfg))

        self.basic_module = nn.Sequential(*convs)

    def forward(self, tensor_input):
        """Forward function of SPyNetBasicModule.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Refined flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)