Shortcuts

Source code for mmagic.models.editors.tdan.tdan_net

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2d, DeformConv2dPack, deform_conv2d
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from torch.nn.modules.utils import _pair

from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import make_layer
from mmagic.registry import MODELS


@MODELS.register_module()
class TDANNet(BaseModule):
    """TDAN network structure for video super-resolution.

    Support only x4 upsampling.

    Paper:
        TDAN: Temporally-Deformable Alignment Network for Video Super-
        Resolution, CVPR, 2020

    Args:
        in_channels (int): Number of channels of the input image. Default: 3.
        mid_channels (int): Number of channels of the intermediate features.
            Default: 64.
        out_channels (int): Number of channels of the output image. Default: 3.
        num_blocks_before_align (int): Number of residual blocks before
            temporal alignment. Default: 5.
        num_blocks_after_align (int): Number of residual blocks after
            temporal alignment. Default: 10.

    Note:
        The first convolution of the reconstruction branch takes
        ``in_channels * 5`` channels, so ``forward`` expects sequences of
        exactly 5 frames.
    """

    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 out_channels=3,
                 num_blocks_before_align=5,
                 num_blocks_after_align=10):

        super().__init__()

        # Per-frame feature extractor, applied to every LR frame.
        self.feat_extract = nn.Sequential(
            ConvModule(in_channels, mid_channels, 3, padding=1),
            make_layer(
                ResidualBlockNoBN,
                num_blocks_before_align,
                mid_channels=mid_channels))

        # Fuses the concatenated (center, neighbor) features into the
        # feature that drives the offset prediction of ``align_1``.
        self.feat_aggregate = nn.Sequential(
            nn.Conv2d(mid_channels * 2, mid_channels, 3, padding=1, bias=True),
            DeformConv2dPack(
                mid_channels, mid_channels, 3, padding=1, deform_groups=8),
            DeformConv2dPack(
                mid_channels, mid_channels, 3, padding=1, deform_groups=8))
        self.align_1 = AugmentedDeformConv2dPack(
            mid_channels, mid_channels, 3, padding=1, deform_groups=8)
        self.align_2 = DeformConv2dPack(
            mid_channels, mid_channels, 3, padding=1, deform_groups=8)

        # Project aligned features back to image space. The result is
        # concatenated with the raw center frame and later viewed as
        # (n, t, c, h, w), so it must have ``in_channels`` channels rather
        # than a hard-coded 3 (identical for the default in_channels=3).
        self.to_rgb = nn.Conv2d(
            mid_channels, in_channels, 3, padding=1, bias=True)

        # Reconstruction branch: 5 stacked aligned frames -> x4 HR frame.
        self.reconstruct = nn.Sequential(
            ConvModule(in_channels * 5, mid_channels, 3, padding=1),
            make_layer(
                ResidualBlockNoBN,
                num_blocks_after_align,
                mid_channels=mid_channels),
            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False))

    def forward(self, lrs):
        """Forward function for TDANNet.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            tuple[Tensor]: Output HR image with shape
            (n, out_channels, 4h, 4w) and aligned LR images with shape
            (n, t, c, h, w).
        """
        n, t, c, h, w = lrs.size()
        center_idx = t // 2
        lr_center = lrs[:, center_idx, :, :, :]  # LR center frame

        # extract features; ``reshape`` (instead of ``view``) also accepts
        # non-contiguous input sequences
        feats = self.feat_extract(lrs.reshape(-1, c, h, w)).view(
            n, t, -1, h, w)

        # alignment of LR frames
        feat_center = feats[:, center_idx, :, :, :].contiguous()
        aligned_lrs = []
        for i in range(t):
            if i == center_idx:
                # the reference frame is passed through untouched
                aligned_lrs.append(lr_center)
                continue
            feat_neig = feats[:, i, :, :, :].contiguous()
            feat_agg = torch.cat([feat_center, feat_neig], dim=1)
            feat_agg = self.feat_aggregate(feat_agg)
            # offsets come from the aggregated feature, sampling is done on
            # the neighbor feature
            aligned_feat = self.align_2(self.align_1(feat_neig, feat_agg))
            aligned_lrs.append(self.to_rgb(aligned_feat))

        # frames are stacked along the channel axis: (n, t * c, h, w)
        aligned_lrs = torch.cat(aligned_lrs, dim=1)

        # output HR center frame and the aligned LR frames
        return self.reconstruct(aligned_lrs), aligned_lrs.view(n, t, c, h, w)
class AugmentedDeformConv2dPack(DeformConv2d):
    """Augmented Deformable Convolution Pack.

    Different from DeformConv2dPack, which generates offsets from the
    preceding feature, this AugmentedDeformConv2dPack takes another feature to
    generate the offsets.

    All positional and keyword arguments are forwarded unchanged to
    ``DeformConv2d``.

    Args:
        in_channels (int): Number of channels in the input feature.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple[int]): Size of the convolving kernel.
        stride (int or tuple[int]): Stride of the convolution. Default: 1.
        padding (int or tuple[int]): Zero-padding added to both sides of the
            input. Default: 0.
        dilation (int or tuple[int]): Spacing between kernel elements.
            Default: 1.
        groups (int): Number of blocked connections from input channels to
            output channels. Default: 1.
        deform_groups (int): Number of deformable group partitions.
        bias (bool or str): Forwarded to ``DeformConv2d``. Note that
            ``forward`` of this class does not add a bias term.
            NOTE(review): accepted values are decided by the base class --
            confirm against the installed mmcv version.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The offset branch predicts 2 values (one 2-D offset) per kernel
        # location and deformable group. It mirrors the geometry of the
        # deformable convolution -- including dilation -- so that the offset
        # map has the same spatial size as the convolution output.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)

        self.init_offset()

    def init_offset(self):
        """Init constant offset.

        Weights and bias of the offset branch are set to zero, so the
        predicted offsets are all zero until the branch is trained.
        """
        constant_init(self.conv_offset, val=0, bias=0)

    def forward(self, x, extra_feat):
        """Forward function.

        Args:
            x (Tensor): Feature to be sampled by the deformable convolution.
            extra_feat (Tensor): Feature from which the sampling offsets are
                predicted.

        Returns:
            Tensor: Result of the deformable convolution applied to ``x``.
        """
        offset = self.conv_offset(extra_feat)
        return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
                             self.dilation, self.groups, self.deform_groups)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.