
Source code for mmagic.models.editors.srgan.sr_resnet

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.model import BaseModule

from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import default_init_weights, make_layer
from mmagic.registry import MODELS


@MODELS.register_module()
class MSRResNet(BaseModule):
    """Modified SRResNet.

    A compacted version modified from SRResNet in "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network".

    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factors.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support x2, x3 and x4.
            Default: 4.
    """

    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        self.conv_first = nn.Conv2d(
            in_channels, mid_channels, 3, 1, 1, bias=True)
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # upsampling
        if self.upscale_factor in [2, 3]:
            self.upsample1 = PixelShufflePack(
                mid_channels,
                mid_channels,
                self.upscale_factor,
                upsample_kernel=3)
        elif self.upscale_factor == 4:
            self.upsample1 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
            self.upsample2 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
        else:
            raise ValueError(
                f'Unsupported scale factor {self.upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.conv_hr = nn.Conv2d(
            mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)

        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bilinear',
            align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        feat = self.lrelu(self.conv_first(x))
        out = self.trunk_net(feat)

        if self.upscale_factor in [2, 3]:
            out = self.upsample1(out)
        elif self.upscale_factor == 4:
            out = self.upsample1(out)
            out = self.upsample2(out)

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        upsampled_img = self.img_upsampler(x)
        out += upsampled_img
        return out

    def init_weights(self):
        """Init weights for models."""
        for m in [self.conv_first, self.conv_hr, self.conv_last]:
            default_init_weights(m, 0.1)
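
A minimal usage sketch of the class above, assuming mmagic and its dependencies are installed; the batch shape and the inline note about the registry-based build are illustrative and not part of the source listing.

# Minimal sketch: instantiate MSRResNet and run a forward pass on dummy data.
import torch

from mmagic.models.editors.srgan.sr_resnet import MSRResNet

# Direct construction; the same arguments can also be supplied through the
# registry since the class is decorated with @MODELS.register_module().
model = MSRResNet(
    in_channels=3, out_channels=3, mid_channels=64,
    num_blocks=16, upscale_factor=4)
model.eval()

# Dummy low-resolution RGB batch with shape (n, c, h, w).
lr = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)

# The output spatial size is the input size times upscale_factor.
print(sr.shape)  # torch.Size([1, 3, 128, 128])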