Shortcuts

Source code for mmagic.models.editors.edsr.edsr_net

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import make_layer
from mmagic.registry import MODELS


@MODELS.register_module()
class EDSRNet(BaseModule):
    """EDSR network structure.

    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
    Ref repo: https://github.com/thstkdgus35/EDSR-PyTorch

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        rgb_mean (Sequence[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
        rgb_std (Sequence[float]): Image std in RGB orders. In EDSR, it uses
            (1.0, 1.0, 1.0). Default: (1.0, 1.0, 1.0).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4,
                 res_scale=1,
                 rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        # Non-persistent buffers follow ``model.to(device)`` / ``.half()``
        # like parameters, but stay out of ``state_dict`` so checkpoints
        # saved when these were plain attributes still load unchanged.
        # Shaped (1, C, 1, 1) to broadcast over NCHW inputs.
        self.register_buffer(
            'mean', torch.Tensor(rgb_mean).view(1, -1, 1, 1), persistent=False)
        self.register_buffer(
            'std', torch.Tensor(rgb_std).view(1, -1, 1, 1), persistent=False)

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, padding=1)
        self.body = make_layer(
            ResidualBlockNoBN,
            num_blocks,
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.upsample = UpsampleModule(upscale_factor, mid_channels)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # Cast to the input's device/dtype locally instead of rebinding
        # module attributes on every call.
        mean = self.mean.to(x)
        std = self.std.to(x)

        x = (x - mean) / std
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x  # global residual connection around the trunk
        x = self.conv_last(self.upsample(res))
        # Undo the input normalization on the output.
        return x * std + mean
class UpsampleModule(nn.Sequential):
    """Upsample module used in EDSR.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        mid_channels (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, mid_channels):
        modules = []
        # ``scale > 0`` is required: 0 also satisfies the bit trick below.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # Exact integer log2 of a power of two; avoids float rounding
            # from math.log. scale == 1 yields no stages (identity).
            num_stages = int(scale).bit_length() - 1
            for _ in range(num_stages):
                modules.append(
                    PixelShufflePack(
                        mid_channels, mid_channels, 2, upsample_kernel=3))
        elif scale == 3:
            modules.append(
                PixelShufflePack(
                    mid_channels, mid_channels, scale, upsample_kernel=3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super().__init__(*modules)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.