# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmagic.registry import MODELS

@MODELS.register_module()
class FLAVRNet(BaseModule):
    """PyTorch implementation of FLAVR for video frame interpolation.

    Paper: FLAVR: Flow-Agnostic Video Representations for Fast Frame
    Interpolation

    Args:
        num_input_frames (int): Number of input frames.
        num_output_frames (int): Number of output frames.
        mid_channels_list (list[int] | tuple[int]): Decoder channels ordered
            from the deepest to the shallowest stage; the encoder uses the
            reversed order. Default: (512, 256, 128, 64)
        encoder_layers_list (list[int] | tuple[int]): Number of blocks in
            each of the four encoder stages. Default: (2, 2, 2, 2)
        bias (bool): If ``True``, adds a learnable bias to the encoder conv
            layers. Default: ``False``
        norm_cfg (dict | None): Config dict for normalization layer. It is
            also forwarded to the decoder as its ``batchnorm`` flag, so a
            non-empty config enables BatchNorm3d there too. Default: None
        join_type (str): Join type of tensors from decoder and encoder.
            Candidates are ``concat`` and ``add``. Default: ``concat``
        up_mode (str): Up-mode UpConv3d, candidates are ``transpose`` and
            ``trilinear``. Default: ``transpose``
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 num_input_frames,
                 num_output_frames,
                 mid_channels_list=(512, 256, 128, 64),
                 encoder_layers_list=(2, 2, 2, 2),
                 bias=False,
                 norm_cfg=None,
                 join_type='concat',
                 up_mode='transpose',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # The encoder goes shallow -> deep, hence the reversed channel list.
        self.encoder = Encoder(
            block=BasicBlock,
            layers=encoder_layers_list,
            stem_layer=BasicStem,
            mid_channels_list=mid_channels_list[::-1],
            bias=bias,
            norm_cfg=norm_cfg)

        self.decoder = Decoder(
            join_type=join_type,
            up_mode=up_mode,
            mid_channels_list=mid_channels_list,
            batchnorm=norm_cfg)

        # The decoder folds every frame into the channel axis; this 1x1 conv
        # fuses the stacked frames back to ``mid_channels_list[3]`` channels.
        self.feature_fuse = ConvModule(
            mid_channels_list[3] * num_input_frames,
            mid_channels_list[3],
            kernel_size=1,
            stride=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2, inplace=True))

        # 3 channels per predicted frame.
        out_channels = 3 * num_output_frames
        # Reflection padding of 3 followed by an unpadded 7x7 conv keeps the
        # spatial size unchanged.
        self.conv_last = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(
                mid_channels_list[3],
                out_channels=out_channels,
                kernel_size=7,
                stride=1,
                padding=0))

    def forward(self, images: torch.Tensor):
        """Forward function.

        Args:
            images (Tensor): Input frames tensor with shape (N, T, C, H, W).

        Returns:
            Tensor: Predicted frames with shape (N, T_out, 3, h, w), where
            ``(h, w)`` is the spatial size produced by ``conv_last``. The
            frame axis is squeezed away when ``T_out == 1``, giving
            (N, 3, h, w).
        """
        # from [b, t, c, h, w] to [b, c, d, h, w], where t==d
        images = images.permute((0, 2, 1, 3, 4))

        # Batch mean normalization works slightly better than global mean
        # normalization: subtract a per-sample, per-channel mean here and add
        # it back to the prediction at the end.
        mean_ = images.mean((2, 3, 4), keepdim=True)
        images = images - mean_

        xs = self.encoder(images)
        dx_out = self.decoder(xs)

        out = self.feature_fuse(dx_out)
        out = self.conv_last(out)  # b, t*c, h, w

        b, c_all, h, w = out.shape
        t = c_all // 3
        mean_ = mean_.view(b, 1, 3, 1, 1)
        out = out.view(b, t, 3, h, w)
        out = out + mean_

        # if t==1, which means the output only contains one frame.
        out = out.squeeze(1)

        return out
class Encoder(nn.Module):
    """Encoder of FLAVR.

    A stem layer followed by four stages of residual 3D blocks. ``layer2``
    and ``layer3`` use spatial stride 2 with temporal stride 1; ``layer1``
    and ``layer4`` use stride 1.

    Args:
        block (nn.Module): Basic block of encoder. Must expose an
            ``expansion`` class attribute.
        layers (list[int]): Number of blocks in each of the four stages.
        stem_layer (nn.Module): Stem layer (first conv), constructed as
            ``stem_layer(mid_channels_list[0], bias, norm_cfg)``.
        mid_channels_list (list[int]): Channels of the four stages, from
            shallow to deep.
        norm_cfg (dict | None): Config dict for normalization layer.
        bias (bool): If ``True``, adds a learnable bias to the conv layers.
    """

    def __init__(self, block, layers, stem_layer, mid_channels_list, norm_cfg,
                 bias):
        super().__init__()

        # Running input-channel count, advanced by ``_make_layer``.
        self.in_channels = mid_channels_list[0]
        self.bias = bias

        self.stem_layer = stem_layer(mid_channels_list[0], bias, norm_cfg)

        self.layer1 = self._make_layer(
            block,
            mid_channels_list[0],
            layers[0],
            norm_cfg=norm_cfg,
            stride=1)
        self.layer2 = self._make_layer(
            block,
            mid_channels_list[1],
            layers[1],
            norm_cfg=norm_cfg,
            stride=2,
            temporal_stride=1)
        self.layer3 = self._make_layer(
            block,
            mid_channels_list[2],
            layers[2],
            norm_cfg=norm_cfg,
            stride=2,
            temporal_stride=1)
        self.layer4 = self._make_layer(
            block,
            mid_channels_list[3],
            layers[3],
            norm_cfg=norm_cfg,
            stride=1,
            temporal_stride=1)

        # init weights
        self._initialize_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor; ``FLAVRNet.forward`` passes it in
                (N, C, D, H, W) layout.

        Returns:
            tuple(Tensor): Outputs of the stem and the four stages
            ``(x_0, x_1, x_2, x_3, x_4)``, shallow to deep.
        """
        x_0 = self.stem_layer(x)
        x_1 = self.layer1(x_0)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)
        return x_0, x_1, x_2, x_3, x_4

    def _make_layer(self,
                    block,
                    mid_channels,
                    num_blocks,
                    norm_cfg,
                    stride=1,
                    temporal_stride=None):
        """Make layers by stacking the blocks.

        Args:
            block (nn.Module): Block class to instantiate.
            mid_channels (int): Middle channels of each block.
            num_blocks (int): Number of blocks in this stage.
            norm_cfg (dict | None): Config dict for normalization layer.
            stride (int): Spatial stride of the first block.
            temporal_stride (int | None): Temporal stride of the first block.
                When falsy, ``stride`` is also used temporally.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        downsample = None
        if stride != 1 or self.in_channels != mid_channels * block.expansion:
            if temporal_stride:
                ds_stride = (temporal_stride, stride, stride)
            else:
                ds_stride = (stride, stride, stride)
            # 1x1x1 projection so the residual matches the block output.
            downsample = ConvModule(
                self.in_channels,
                mid_channels * block.expansion,
                kernel_size=1,
                stride=ds_stride,
                bias=False,
                conv_cfg=dict(type='Conv3d'),
                norm_cfg=norm_cfg,
                act_cfg=None)
            stride = ds_stride

        layers = []
        layers.append(
            block(
                self.in_channels,
                mid_channels,
                norm_cfg=norm_cfg,
                stride=stride,
                bias=self.bias,
                downsample=downsample))

        self.in_channels = mid_channels * block.expansion
        for _ in range(1, num_blocks):
            layers.append(
                block(
                    self.in_channels,
                    mid_channels,
                    norm_cfg=norm_cfg,
                    bias=self.bias))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Init weights for models.

        Conv3d: Kaiming-normal weights, zero bias. BatchNorm3d: weight 1,
        bias 0. Linear: N(0, 0.01) weights, zero bias. Missing bias/affine
        tensors (``bias=False`` / ``affine=False``) are skipped.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                # ``affine=False`` norm layers have no weight/bias tensors.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
class Decoder(nn.Module):
    """Decoder of FLAVR.

    Args:
        join_type (str): Join type of tensors from decoder and encoder.
            Candidates are ``concat`` and ``add``. Default: ``concat``
        up_mode (str): Up-mode UpConv3d, candidates are ``transpose`` and
            ``trilinear``. Default: ``transpose``
        mid_channels_list (list[int] | tuple[int]): Channels ordered from
            the deepest to the shallowest stage.
            Default: (512, 256, 128, 64)
        batchnorm (bool): Whether each conv block ends with BatchNorm3d.
            The value is forwarded to ``Conv3d``/``UpConv3d``, which only
            test its truthiness, so a norm config dict also works.
            Default: False.
    """

    def __init__(self,
                 join_type,
                 up_mode,
                 mid_channels_list=(512, 256, 128, 64),
                 batchnorm=False):
        super().__init__()

        # Concatenated skip connections double the input channel count.
        growth = 2 if join_type == 'concat' else 1
        self.join_type = join_type
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.layer0 = Conv3d(
            mid_channels_list[0],
            mid_channels_list[1],
            kernel_size=3,
            padding=1,
            bias=True,
            batchnorm=batchnorm)
        self.layer1 = UpConv3d(
            mid_channels_list[1] * growth,
            mid_channels_list[2],
            kernel_size=(3, 4, 4),
            stride=(1, 2, 2),
            padding=(1, 1, 1),
            up_mode=up_mode,
            batchnorm=batchnorm)
        self.layer2 = UpConv3d(
            mid_channels_list[2] * growth,
            mid_channels_list[3],
            kernel_size=(3, 4, 4),
            stride=(1, 2, 2),
            padding=(1, 1, 1),
            up_mode=up_mode,
            batchnorm=batchnorm)
        self.layer3 = Conv3d(
            mid_channels_list[3] * growth,
            mid_channels_list[3],
            kernel_size=3,
            padding=1,
            bias=True,
            batchnorm=batchnorm)
        self.layer4 = UpConv3d(
            mid_channels_list[3] * growth,
            mid_channels_list[3],
            kernel_size=(3, 4, 4),
            stride=(1, 2, 2),
            padding=(1, 1, 1),
            up_mode=up_mode,
            batchnorm=batchnorm)

    def forward(self, xs):
        """Forward function.

        Args:
            xs (tuple[Tensor]): The five encoder features
                ``(x_0, x_1, x_2, x_3, x_4)``, shallow to deep.

        Returns:
            Tensor: The decoded 5D feature (N, C, D, H, W) with its temporal
            axis folded into channels frame by frame, i.e. (N, D * C, H, W).
        """
        dx_3 = self.lrelu(self.layer0(xs[4]))
        dx_3 = self._join_tensors(dx_3, xs[3])

        dx_2 = self.lrelu(self.layer1(dx_3))
        dx_2 = self._join_tensors(dx_2, xs[2])

        dx_1 = self.lrelu(self.layer2(dx_2))
        dx_1 = self._join_tensors(dx_1, xs[1])

        dx_0 = self.lrelu(self.layer3(dx_1))
        dx_0 = self._join_tensors(dx_0, xs[0])

        dx_out = self.lrelu(self.layer4(dx_0))
        # Split along the temporal axis and stack the frames on channels.
        dx_out = torch.cat(torch.unbind(dx_out, 2), 1)

        return dx_out

    def _join_tensors(self, x1, x2):
        """Concat or Add two tensors.

        Args:
            x1 (Tensor): The first input tensor.
            x2 (Tensor): The second input tensor.

        Returns:
            Tensor: ``cat([x1, x2], dim=1)`` when ``join_type == 'concat'``,
            otherwise ``x1 + x2``.
        """
        if self.join_type == 'concat':
            return torch.cat([x1, x2], dim=1)
        else:
            return x1 + x2
class UpConv3d(nn.Module):
    """Upsampling block: (transposed conv | trilinear upsample + 1x1 conv),
    then SE gating, then an optional BatchNorm3d.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the block.
        kernel_size (int | tuple[int]): Kernel size of the transposed conv.
            Only used when ``up_mode == 'transpose'``.
        stride (int | tuple[int]): Stride of the transposed conv. Only used
            when ``up_mode == 'transpose'``.
        padding (int | tuple[int]): Padding of the transposed conv. Only
            used when ``up_mode == 'transpose'``.
        up_mode (str): ``'transpose'`` selects ``nn.ConvTranspose3d``; any
            other value selects trilinear upsampling by (1, 2, 2) followed by
            a 1x1x1 conv. Default: ``'transpose'``.
        batchnorm (bool): Whether to append BatchNorm3d. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 up_mode='transpose',
                 batchnorm=False):
        super().__init__()

        self.up_mode = up_mode

        if up_mode == 'transpose':
            stages = [
                nn.ConvTranspose3d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding),
                SEGating(out_channels),
            ]
        else:
            stages = [
                nn.Upsample(
                    mode='trilinear',
                    scale_factor=(1, 2, 2),
                    align_corners=False),
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1),
                SEGating(out_channels),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm3d(out_channels))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Forward function."""
        return self.upconv(x)
class Conv3d(nn.Module):
    """Conv3d followed by SE gating and an optional BatchNorm3d.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride (int | tuple[int]): Stride of the convolution. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of the
            input. Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the conv layer.
            Default: ``True``
        batchnorm (bool): Whether to append BatchNorm3d. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 batchnorm=False):
        super().__init__()

        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias)
        gate = SEGating(out_channels)

        pieces = [conv, gate]
        if batchnorm:
            pieces.append(nn.BatchNorm3d(out_channels))

        self.conv = nn.Sequential(*pieces)

    def forward(self, x):
        """Forward function."""
        return self.conv(x)
class BasicStem(ConvModule):
    """Stem of the FLAVR encoder: a single 3D ``ConvModule`` on 3-channel
    input.

    Uses a (3, 7, 7) kernel with stride (1, 2, 2) and padding (1, 3, 3).
    Normalization follows ``norm_cfg``; the activation is ``ConvModule``'s
    default, built with ``inplace=False``.

    Args:
        out_channels (int): Number of output channels. Default: 64
        bias (bool): If ``True``, adds a learnable bias to the conv layer.
            Default: ``False``
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: None.
    """

    def __init__(self, out_channels=64, bias=False, norm_cfg=None):
        stem_options = dict(
            kernel_size=(3, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=bias,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=norm_cfg,
            inplace=False)
        super().__init__(3, out_channels, **stem_options)
class BasicBlock(nn.Module):
    """Basic block of encoder in FLAVR.

    Residual block: ``conv1`` (``ConvModule`` with its default activation),
    then ``conv2`` (``ConvModule`` without activation), then SE feature
    gating. The (optionally downsampled) input is added, then ReLU is
    applied.

    Args:
        in_channels (int): Number of channels in the input feature map.
        mid_channels (int): Number of middle (and output) channels.
        stride (int | tuple[int]): Stride of the first convolution.
            Default: 1.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: None.
        bias (bool): If ``True``, adds a learnable bias to the conv layers.
            Default: ``False``
        downsample (None | torch.nn.Module): Down-sample layer applied to the
            residual branch. Default: None.
    """

    # Output channels = mid_channels * expansion; read by the encoder's
    # ``_make_layer`` to size downsample projections.
    expansion = 1

    def __init__(
        self,
        in_channels,
        mid_channels,
        stride=1,
        norm_cfg=None,
        bias=False,
        downsample=None,
    ):
        super().__init__()

        self.conv1 = ConvModule(
            in_channels,
            mid_channels,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=(1, 1, 1),
            bias=bias,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=norm_cfg)

        self.conv2 = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
            padding=(1, 1, 1),
            bias=bias,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.fg = SEGating(mid_channels)  # Feature Gating

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor.
        """
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.fg(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class SEGating(nn.Module):
    """Squeeze-and-excitation style channel gating.

    Global average pooling squeezes each channel to one value. A 1x1x1
    convolution followed by a sigmoid maps it to a per-channel weight in
    (0, 1), and the input is rescaled by that weight.

    Args:
        in_channels (int): Number of channels in the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool3d(1)
        gate_conv = nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, bias=True)
        self.attn_layer = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: ``x`` scaled channel-wise by the learned gate.
        """
        channel_weights = self.attn_layer(self.pool(x))
        return x * channel_weights
