# Source code for mmagic.models.archs.sr_backbone
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from torch import Tensor
from ..utils import default_init_weights
# def default_init_weights(module, scale=1):
#     """Initialize network weights.
#
#     Args:
#         module (nn.Module): Module to be initialized.
#         scale (float): Scale initialized weights, especially for residual
#             blocks. Default: 1.
#     """
#     for m in module.modules():
#         if isinstance(m, nn.Conv2d):
#             kaiming_init(m, a=0, mode='fan_in', bias=0)
#             m.weight.data *= scale
#         elif isinstance(m, nn.Linear):
#             kaiming_init(m, a=0, mode='fan_in', bias=0)
#             m.weight.data *= scale
#         elif isinstance(m, _BatchNorm):
#             constant_init(m.weight, val=1, bias=0)
# def make_layer(block, num_blocks, **kwarg):
#     """Make layers by stacking the same blocks.
#
#     Args:
#         block (nn.Module): nn.Module class for basic block.
#         num_blocks (int): Number of blocks.
#
#     Returns:
#         nn.Sequential: Stacked blocks in nn.Sequential.
#     """
#     layers = []
#     for _ in range(num_blocks):
#         layers.append(block(**kwarg))
#     return nn.Sequential(*layers)
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:

    ::

        ---Conv-ReLU-Conv-+-
         |________________|

    The two convolutions keep the channel count at ``mid_channels`` and use
    a 3x3 kernel with stride 1 and padding 1, so the residual branch has the
    same shape as the input and can be added to it directly.

    Args:
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
    """

    def __init__(self, mid_channels: int = 64, res_scale: float = 1.0):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)

        # In-place activation: the output of conv1 is consumed only by this
        # ReLU (see ``forward``), so overwriting it is safe.
        self.relu = nn.ReLU(inplace=True)

        # if res_scale < 1.0, use the default initialization, as in EDSR.
        # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet.
        if res_scale == 1.0:
            self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights for ResidualBlockNoBN.

        Initialization methods like `kaiming_init` are for VGG-style modules.
        For modules with residual paths, using smaller std is better for
        stability and performance. We empirically use 0.1. See more details in
        "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks"
        """
        # Both convolutions are re-initialized with a scale factor of 0.1.
        for m in [self.conv1, self.conv2]:
            default_init_weights(m, 0.1)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results, ``x + res_scale * conv2(relu(conv1(x)))``.
        """
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        # Scale only the residual branch; the skip connection is left as-is.
        return identity + out * self.res_scale