Source code for mmagic.models.editors.real_basicvsr.real_basicvsr_net
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmagic.registry import MODELS
from ..basicvsr.basicvsr_net import BasicVSRNet, ResidualBlocksWithInputConv
@MODELS.register_module()
[docs]class RealBasicVSRNet(BaseModule):
"""RealBasicVSR network structure for real-world video super-resolution.
Support only x4 upsampling.
Paper:
Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv
Args:
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
num_propagation_blocks (int, optional): Number of residual blocks in
each propagation branch. Default: 20.
num_cleaning_blocks (int, optional): Number of residual blocks in the
image cleaning module. Default: 20.
dynamic_refine_thres (int, optional): Stop cleaning the images when
the residue is smaller than this value. Default: 255.
spynet_pretrained (str, optional): Pre-trained model path of SPyNet.
Default: None.
is_fix_cleaning (bool, optional): Whether to fix the weights of
the image cleaning module during training. Default: False.
is_sequential_cleaning (bool, optional): Whether to clean the images
sequentially. This is used to save GPU memory, but the speed is
slightly slower. Default: False.
"""
def __init__(self,
mid_channels=64,
num_propagation_blocks=20,
num_cleaning_blocks=20,
dynamic_refine_thres=255,
spynet_pretrained=None,
is_fix_cleaning=False,
is_sequential_cleaning=False):
super().__init__()
self.dynamic_refine_thres = dynamic_refine_thres / 255.
self.is_sequential_cleaning = is_sequential_cleaning
# image cleaning module
self.image_cleaning = nn.Sequential(
ResidualBlocksWithInputConv(3, mid_channels, num_cleaning_blocks),
nn.Conv2d(mid_channels, 3, 3, 1, 1, bias=True),
)
if is_fix_cleaning: # keep the weights of the cleaning module fixed
self.image_cleaning.requires_grad_(False)
# BasicVSR
self.basicvsr = BasicVSRNet(mid_channels, num_propagation_blocks,
spynet_pretrained)
self.basicvsr.spynet.requires_grad_(False)
[docs] def forward(self, lqs, return_lqs=False):
"""Forward function for BasicVSR++.
Args:
lqs (tensor): Input low quality (LQ) sequence with
shape (n, t, c, h, w).
return_lqs (bool): Whether to return LQ sequence. Default: False.
Returns:
Tensor: Output HR sequence.
"""
n, t, c, h, w = lqs.size()
for _ in range(0, 3): # at most 3 cleaning, determined empirically
if self.is_sequential_cleaning:
residues = []
for i in range(0, t):
residue_i = self.image_cleaning(lqs[:, i, :, :, :])
lqs[:, i, :, :, :] += residue_i
residues.append(residue_i)
residues = torch.stack(residues, dim=1)
else: # time -> batch, then apply cleaning at once
lqs = lqs.view(-1, c, h, w)
residues = self.image_cleaning(lqs)
lqs = (lqs + residues).view(n, t, c, h, w)
# determine whether to continue cleaning
if torch.mean(torch.abs(residues)) < self.dynamic_refine_thres:
break
# Super-resolution (BasicVSR)
outputs = self.basicvsr(lqs)
if return_lqs:
return outputs, lqs
else:
return outputs