# Source code for mmagic.models.editors.srgan.sr_resnet
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.model import BaseModule
from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import default_init_weights, make_layer
from mmagic.registry import MODELS
@MODELS.register_module()
class MSRResNet(BaseModule):
    """Modified SRResNet.

    A compacted version modified from SRResNet in "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network".

    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factor.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs. Because the
            bilinearly upsampled input is added to the network output in
            ``forward``, this must be broadcast-compatible with
            ``in_channels`` (normally the two are equal).
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support x2, x3 and x4.
            Default: 4.

    Raises:
        ValueError: If ``upscale_factor`` is not one of
            ``_supported_upscale_factors``.
    """

    # Scale factors accepted by ``__init__``; also quoted in its error
    # message.
    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super().__init__()
        # Fail fast, before any layers are allocated.
        if upscale_factor not in self._supported_upscale_factors:
            raise ValueError(
                f'Unsupported scale factor {upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        self.conv_first = nn.Conv2d(
            in_channels, mid_channels, 3, 1, 1, bias=True)
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # upsampling: x2 / x3 use one PixelShufflePack stage with that
        # factor; x4 stacks two x2 stages.
        if self.upscale_factor in [2, 3]:
            self.upsample1 = PixelShufflePack(
                mid_channels,
                mid_channels,
                self.upscale_factor,
                upsample_kernel=3)
        else:  # upscale_factor == 4 (guaranteed by the check above)
            self.upsample1 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
            self.upsample2 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)

        self.conv_hr = nn.Conv2d(
            mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)
        # Bilinear upsampling of the raw input, used as the global residual
        # that the network output is added to in ``forward``.
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bilinear',
            align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results. The spatial size is scaled by
            ``upscale_factor`` to match the bilinearly upsampled input it is
            added to.
        """
        feat = self.lrelu(self.conv_first(x))
        out = self.trunk_net(feat)

        # ``__init__`` guarantees upscale_factor is 2, 3 or 4, so
        # ``upsample1`` always exists and ``upsample2`` exists only for x4.
        out = self.upsample1(out)
        if self.upscale_factor == 4:
            out = self.upsample2(out)

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        # Global residual: the network predicts a correction on top of the
        # bilinearly upsampled input.
        upsampled_img = self.img_upsampler(x)
        out += upsampled_img
        return out

    def init_weights(self):
        """Init weights for models.

        Applies ``default_init_weights`` with scale 0.1 to the head/tail
        convolutions. The trunk and upsampling modules are not touched here.
        """
        for m in [self.conv_first, self.conv_hr, self.conv_last]:
            default_init_weights(m, 0.1)