Shortcuts

Source code for mmagic.models.editors.swinir.swinir_net

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_

from mmagic.registry import MODELS
from .swinir_modules import PatchEmbed, PatchUnEmbed, Upsample, UpsampleOneStep
from .swinir_rstb import RSTB


@MODELS.register_module()
class SwinIRNet(BaseModule):
    r"""SwinIR network.

    A PyTorch implementation of `SwinIR: Image Restoration Using Swin
    Transformer`, based on Swin Transformer.

    Ref repo: https://github.com/JingyunLiang/SwinIR

    The network has three stages: a single convolution for shallow feature
    extraction, a stack of Residual Swin Transformer Blocks (RSTB) for deep
    feature extraction, and a reconstruction head selected by ``upsampler``.

    Args:
        img_size (int | tuple(int)): Input image size. Default: 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. The output has the
            same number of channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
            Default: (6, 6, 6, 6)
        num_heads (tuple(int)): Number of attention heads in different
            layers. Must provide one entry per entry of ``depths``.
            Default: (6, 6, 6, 6)
        window_size (int): Window size. Inputs are padded so that their
            height and width are multiples of it. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key,
            value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if
            set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. It is the maximum
            of a linear schedule spread over all blocks. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch
            embedding. Default: False
        patch_norm (bool): If True, add normalization after patch
            embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False
        upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for
            denoising and compress artifact reduction. Default: 2
        img_range (float): Image range. 1. or 255. Default: 1.0
        upsampler (string, optional): The reconstruction module.
            'pixelshuffle' / 'pixelshuffledirect' / 'nearest+conv'. Any
            other value (e.g. '' or None) selects the plain residual
            convolution head used for denoising and JPEG artifact
            reduction. Note that the 'nearest+conv' head only builds x2
            and x4 upsampling paths. Default: ''
        resi_connection (string): The convolutional block before residual
            connection. '1conv'/'3conv'. Default: '1conv'
        **kwargs: Accepted for config compatibility and ignored.

    Raises:
        ValueError: If ``resi_connection`` is neither '1conv' nor '3conv'.
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 in_chans=3,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super().__init__()

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range

        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            mean = torch.zeros(1, 1, 1, 1)
        # Registered as a non-persistent buffer so that it follows the
        # module across devices/dtypes without adding a key to state_dict
        # (keeps existing checkpoints loadable).
        self.register_buffer('mean', mean, persistent=False)

        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        # 1, shallow feature extraction
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # 2, deep feature extraction
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        # NOTE(review): the embedding has a fixed length of `num_patches`,
        # so `ape=True` presumably requires inputs whose token count
        # matches `img_size` -- confirm against PatchEmbed.
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per block, increasing
        # linearly from 0 to `drop_path_rate` over the whole network
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        block_start = 0
        for i_layer in range(self.num_layers):
            block_end = block_start + depths[i_layer]
            layer = RSTB(
                dim=embed_dim,
                input_resolution=(patches_resolution[0],
                                  patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                # the slice of the schedule that belongs to this layer
                drop_path=dpr[block_start:block_end],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
            block_start = block_end
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
        else:
            # Fail at construction time instead of leaving
            # `conv_after_body` undefined until the first forward pass.
            raise ValueError(
                "resi_connection must be '1conv' or '3conv', "
                f'but got {resi_connection!r}')

        # 3, high quality image reconstruction
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(
                upscale, embed_dim, num_out_ch,
                (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            # Only a x2 stage is always built; a second x2 stage is added
            # for upscale == 4. Other factors are not handled by this head.
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            if self.upscale == 4:
                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize the weights of a single sub-module.

        Linear layers get truncated-normal weights (std 0.02) and zero
        bias; LayerNorm layers get unit weight and zero bias. Other module
        types are left untouched.

        Args:
            m (nn.Module): Sub-module visited by ``self.apply``.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters are None when the layer was built with
            # elementwise_affine=False; skip them in that case.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return parameter-name keywords to exclude from weight decay."""
        return {'relative_position_bias_table'}

    def check_image_size(self, x):
        """Pad the image so its height and width are multiples of the
        window size.

        Padding is added on the right and bottom edges only, using
        reflection. Reflection padding requires the pad amount to be
        smaller than the corresponding input dimension, so very small
        inputs may be rejected by ``F.pad``.

        Args:
            x (Tensor): Input tensor image with (B, C, H, W) shape.

        Returns:
            Tensor: The padded tensor.
        """
        _, _, h, w = x.size()
        win = self.window_size
        # the outer modulo maps "already a multiple" to a pad of 0
        mod_pad_h = (win - h % win) % win
        mod_pad_w = (win - w % win) % win
        # pad order for the last two dims is (left, right, top, bottom)
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward_features(self, x):
        """Forward function of Deep Feature Extraction.

        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).

        Returns:
            Tensor: Forward results.
        """
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        """Forward function.

        The input is padded to a multiple of the window size, normalized
        with ``mean`` and ``img_range``, passed through the network, then
        de-normalized and cropped to ``upscale`` times the original size.

        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).

        Returns:
            Tensor: Forward results.
        """
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        # Cast locally instead of rebinding ``self.mean`` so that forward
        # does not mutate module state.
        mean = self.mean.type_as(x)
        x = (x - mean) * self.img_range

        # Shallow features plus the residual deep features are shared by
        # every reconstruction head.
        shallow = self.conv_first(x)
        feat = self.conv_after_body(self.forward_features(shallow)) + shallow

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            out = self.conv_before_upsample(feat)
            out = self.conv_last(self.upsample(out))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            out = self.upsample(feat)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            out = self.conv_before_upsample(feat)
            out = self.lrelu(
                self.conv_up1(
                    F.interpolate(out, scale_factor=2, mode='nearest')))
            if self.upscale == 4:
                out = self.lrelu(
                    self.conv_up2(
                        F.interpolate(out, scale_factor=2, mode='nearest')))
            out = self.conv_last(self.lrelu(self.conv_hr(out)))
        else:
            # for image denoising and JPEG compression artifact reduction:
            # the network predicts a residual on top of the input
            out = x + self.conv_last(feat)

        out = out / self.img_range + mean

        # drop the rows/columns that correspond to the window padding
        return out[:, :, :H * self.upscale, :W * self.upscale]
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.