Shortcuts

mmagic.models.editors.ttsr.ttsr_net 源代码

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from mmengine.model import BaseModule

from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import make_layer
from mmagic.registry import MODELS

# Use partial to specify some default arguments
# Factory for 3x3 Conv2d layers: ``build_conv_layer`` pre-bound with a
# Conv2d config, ``kernel_size=3`` and ``padding=1``.  Keyword arguments
# given at the call site (e.g. ``stride=2``) are forwarded and take
# precedence over the bound ones, as with any ``functools.partial``.
_conv3x3_layer = partial(
    build_conv_layer,
    {'type': 'Conv2d'},
    padding=1,
    kernel_size=3,
)
# Factory for 1x1 (point-wise) Conv2d layers: ``build_conv_layer``
# pre-bound with a Conv2d config, ``kernel_size=1`` and ``padding=0``.
# Remaining positional/keyword arguments are supplied at the call site.
_conv1x1_layer = partial(
    build_conv_layer,
    {'type': 'Conv2d'},
    padding=0,
    kernel_size=1,
)
@MODELS.register_module()
class TTSRNet(BaseModule):
    """Main network of TTSR for reference-based super-resolution.

    Paper: Learning Texture Transformer Network for Image Super-Resolution

    Adapted from 'https://github.com/researchmm/TTSR.git'
    'https://github.com/researchmm/TTSR'
    Copyright permission at 'https://github.com/researchmm/TTSR/issues/38'.

    The network runs three stages. Stage 1 works on the 1x feature only,
    stage 2 adds a 2x branch and stage 3 adds a 4x branch. In every stage
    the newest branch is fused with the matching transferred texture,
    weighted by the (bicubically resized) soft-attention map, and the
    branches exchange information through ``CSFI2`` / ``CSFI3`` before
    being merged by ``MergeFeatures``.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels in the output image.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        texture_channels (int): Number of texture channels. Default: 64.
        num_blocks (tuple[int]): Block numbers in the trunk network.
            Index 0 is used by the structural feature encoder and
            indices 1, 2, 3 by the residual trunks of stages 1, 2, 3.
            Default: (16, 16, 8, 4).
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.0.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 texture_channels=64,
                 num_blocks=(16, 16, 8, 4),
                 res_scale=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.texture_channels = texture_channels

        # Small local factories for the pieces that every stage repeats.
        # The assignments below stay fully spelled out because attribute
        # names and creation order define the parameter registration.
        def residual_trunk(depth):
            return make_layer(
                ResidualBlockNoBN,
                depth,
                mid_channels=mid_channels,
                res_scale=res_scale)

        def keep_width_conv():
            return _conv3x3_layer(mid_channels, mid_channels)

        def double_resolution():
            return PixelShufflePack(
                in_channels=mid_channels,
                out_channels=mid_channels,
                scale_factor=2,
                upsample_kernel=3)

        def fuse_conv(texture_multiplier):
            # Input is the branch feature concatenated with a texture
            # carrying ``texture_multiplier * texture_channels`` channels.
            fused_channels = (
                texture_multiplier * texture_channels + mid_channels)
            return _conv3x3_layer(fused_channels, mid_channels)

        self.sfe = SFE(in_channels, mid_channels, num_blocks[0], res_scale)

        # stage 1
        self.conv_first1 = fuse_conv(4)
        self.res_block1 = residual_trunk(num_blocks[1])
        self.conv_last1 = keep_width_conv()

        # up-sampling 1 -> 2
        self.up1 = double_resolution()

        # stage 2
        self.conv_first2 = fuse_conv(2)
        self.csfi2 = CSFI2(mid_channels)
        self.res_block2_1 = residual_trunk(num_blocks[2])
        self.res_block2_2 = residual_trunk(num_blocks[2])
        self.conv_last2_1 = keep_width_conv()
        self.conv_last2_2 = keep_width_conv()

        # up-sampling 2 -> 3
        self.up2 = double_resolution()

        # stage 3
        self.conv_first3 = fuse_conv(1)
        self.csfi3 = CSFI3(mid_channels)
        self.res_block3_1 = residual_trunk(num_blocks[3])
        self.res_block3_2 = residual_trunk(num_blocks[3])
        self.res_block3_3 = residual_trunk(num_blocks[3])
        self.conv_last3_1 = keep_width_conv()
        self.conv_last3_2 = keep_width_conv()
        self.conv_last3_3 = keep_width_conv()

        # end, merge features
        self.merge_features = MergeFeatures(mid_channels, out_channels)

    def forward(self, x, soft_attention, textures):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            soft_attention (Tensor): Soft-Attention tensor with shape
                (n, 1, h, w).
            textures (Tuple[Tensor]): Transferred HR texture tensors.
                [(N, C, H, W), (N, C/2, 2H, 2W), ...]

        Returns:
            Tensor: Forward results.
        """
        assert textures[-1].shape[1] == self.texture_channels

        # Bicubic resize used to bring the soft-attention map to the
        # resolution of the 2x and 4x branches.
        resize_bicubic = partial(
            F.interpolate, mode='bicubic', align_corners=False)

        # ---- stage 1: 1x branch only
        feat_1x = self.sfe(x)
        fused = self.conv_first1(torch.cat((feat_1x, textures[0]), dim=1))
        feat_1x = feat_1x + fused * soft_attention
        feat_1x = feat_1x + self.conv_last1(self.res_block1(feat_1x))

        # ---- stage 2: add the 2x branch
        feat_2x = F.relu(self.up1(feat_1x))
        fused = self.conv_first2(torch.cat((feat_2x, textures[1]), dim=1))
        attention_2x = resize_bicubic(soft_attention, scale_factor=2)
        feat_2x = feat_2x + fused * attention_2x

        mixed_1x, mixed_2x = self.csfi2(feat_1x, feat_2x)
        feat_1x = feat_1x + self.conv_last2_1(self.res_block2_1(mixed_1x))
        feat_2x = feat_2x + self.conv_last2_2(self.res_block2_2(mixed_2x))

        # ---- stage 3: add the 4x branch
        feat_4x = F.relu(self.up2(feat_2x))
        fused = self.conv_first3(torch.cat((feat_4x, textures[2]), dim=1))
        attention_4x = resize_bicubic(soft_attention, scale_factor=4)
        feat_4x = feat_4x + fused * attention_4x

        mixed_1x, mixed_2x, mixed_4x = self.csfi3(feat_1x, feat_2x, feat_4x)
        feat_1x = feat_1x + self.conv_last3_1(self.res_block3_1(mixed_1x))
        feat_2x = feat_2x + self.conv_last3_2(self.res_block3_2(mixed_2x))
        feat_4x = feat_4x + self.conv_last3_3(self.res_block3_3(mixed_4x))

        # ---- merge the three branches into the output image
        return self.merge_features(feat_1x, feat_2x, feat_4x)
class SFE(nn.Module):
    """Structural Feature Encoder.

    Backbone of Texture Transformer Network for Image Super-Resolution.
    A 3x3 convolution followed by ReLU produces a shallow feature, a
    stack of residual blocks plus a closing 3x3 convolution refines it,
    and the shallow feature is added back as a long skip connection.

    Args:
        in_channels (int): Number of channels in the input image.
        mid_channels (int): Channel number of intermediate features.
        num_blocks (int): Block number in the trunk network.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.0.
    """

    def __init__(self, in_channels, mid_channels, num_blocks, res_scale=1.0):
        super().__init__()

        self.num_blocks = num_blocks

        self.conv_first = _conv3x3_layer(in_channels, mid_channels)
        self.body = make_layer(
            ResidualBlockNoBN,
            num_blocks,
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.conv_last = _conv3x3_layer(mid_channels, mid_channels)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # Shallow feature; kept for the long skip connection below.
        shallow = F.relu(self.conv_first(x))
        refined = self.conv_last(self.body(shallow))
        return refined + shallow
class CSFI2(nn.Module):
    """Cross-Scale Feature Integration between 1x and 2x features.

    Cross-Scale Feature Integration in Texture Transformer Network for
    Image Super-Resolution.

    Each scale receives a resampled copy of the other scale, concatenates
    it with its own feature along the channel axis and merges the result
    with a 3x3 convolution. Layer names encode the direction: `conv2to1`
    is the conv layer from the 2x feature to the 1x feature.
    Down-sampling is achieved by conv layer with stride=2, and
    up-sampling is achieved by bicubic interpolate and conv layer.

    Args:
        mid_channels (int): Channel number of intermediate features.
    """

    def __init__(self, mid_channels):
        super().__init__()

        width = mid_channels

        # cross-scale transfer layers
        self.conv1to2 = _conv1x1_layer(width, width)
        self.conv2to1 = _conv3x3_layer(width, width, stride=2)

        # per-scale merge of [own feature, transferred feature]
        self.conv_merge1 = _conv3x3_layer(2 * width, width)
        self.conv_merge2 = _conv3x3_layer(2 * width, width)

    def forward(self, x1, x2):
        """Forward function.

        Args:
            x1 (Tensor): Input tensor with shape (n, c, h, w).
            x2 (Tensor): Input tensor with shape (n, c, 2h, 2w).

        Returns:
            x1 (Tensor): Output tensor with shape (n, c, h, w).
            x2 (Tensor): Output tensor with shape (n, c, 2h, 2w).
        """
        # 1x -> 2x: bicubic enlargement, then point-wise conv + ReLU.
        enlarged = F.interpolate(
            x1, scale_factor=2, mode='bicubic', align_corners=False)
        from_1x = F.relu(self.conv1to2(enlarged))

        # 2x -> 1x: strided conv + ReLU.
        from_2x = F.relu(self.conv2to1(x2))

        out_1x = F.relu(self.conv_merge1(torch.cat([x1, from_2x], dim=1)))
        out_2x = F.relu(self.conv_merge2(torch.cat([x2, from_1x], dim=1)))
        return out_1x, out_2x
class CSFI3(nn.Module):
    """Cross-Scale Feature Integration between 1x, 2x, and 4x features.

    Cross-Scale Feature Integration in Texture Transformer Network for
    Image Super-Resolution.

    Every scale receives a resampled copy of the two other scales,
    concatenates all three along the channel axis and merges them with a
    3x3 convolution. Layer names encode the direction: `conv2to1` is the
    conv layer from the 2x feature to the 1x feature. Down-sampling is
    achieved by conv layer with stride=2 (applied twice for 4x -> 1x),
    and up-sampling is achieved by bicubic interpolate and conv layer.

    Args:
        mid_channels (int): Channel number of intermediate features.
    """

    def __init__(self, mid_channels):
        super().__init__()

        width = mid_channels

        # transfers starting from the 1x feature
        self.conv1to2 = _conv1x1_layer(width, width)
        self.conv1to4 = _conv1x1_layer(width, width)

        # transfers starting from the 2x feature
        self.conv2to1 = _conv3x3_layer(width, width, stride=2)
        self.conv2to4 = _conv1x1_layer(width, width)

        # transfers starting from the 4x feature
        self.conv4to1_1 = _conv3x3_layer(width, width, stride=2)
        self.conv4to1_2 = _conv3x3_layer(width, width, stride=2)
        self.conv4to2 = _conv3x3_layer(width, width, stride=2)

        # per-scale merge of [own feature, two transferred features]
        self.conv_merge1 = _conv3x3_layer(3 * width, width)
        self.conv_merge2 = _conv3x3_layer(3 * width, width)
        self.conv_merge4 = _conv3x3_layer(3 * width, width)

    def forward(self, x1, x2, x4):
        """Forward function.

        Args:
            x1 (Tensor): Input tensor with shape (n, c, h, w).
            x2 (Tensor): Input tensor with shape (n, c, 2h, 2w).
            x4 (Tensor): Input tensor with shape (n, c, 4h, 4w).

        Returns:
            x1 (Tensor): Output tensor with shape (n, c, h, w).
            x2 (Tensor): Output tensor with shape (n, c, 2h, 2w).
            x4 (Tensor): Output tensor with shape (n, c, 4h, 4w).
        """

        def enlarge(feature, factor):
            return F.interpolate(
                feature,
                scale_factor=factor,
                mode='bicubic',
                align_corners=False)

        # up-sampling paths: bicubic resize, point-wise conv, ReLU
        up_1to2 = F.relu(self.conv1to2(enlarge(x1, 2)))
        up_1to4 = F.relu(self.conv1to4(enlarge(x1, 4)))
        up_2to4 = F.relu(self.conv2to4(enlarge(x2, 2)))

        # down-sampling paths: strided conv(s), each followed by ReLU
        down_2to1 = F.relu(self.conv2to1(x2))
        down_4to1 = F.relu(self.conv4to1_2(F.relu(self.conv4to1_1(x4))))
        down_4to2 = F.relu(self.conv4to2(x4))

        out_1x = F.relu(
            self.conv_merge1(torch.cat([x1, down_2to1, down_4to1], dim=1)))
        out_2x = F.relu(
            self.conv_merge2(torch.cat([x2, up_1to2, down_4to2], dim=1)))
        out_4x = F.relu(
            self.conv_merge4(torch.cat([x4, up_1to4, up_2to4], dim=1)))
        return out_1x, out_2x, out_4x
class MergeFeatures(nn.Module):
    """Merge Features. Merge 1x, 2x, and 4x features.

    Final module of Texture Transformer Network for Image
    Super-Resolution. The 1x and 2x features are bicubically resized to
    the 4x resolution, concatenated with the 4x feature, merged by a 3x3
    convolution and projected to the output channels. The result is
    clamped to the range [-1, 1].

    Args:
        mid_channels (int): Channel number of intermediate features.
        out_channels (int): Number of channels in the output image.
    """

    def __init__(self, mid_channels, out_channels):
        super().__init__()

        half_width = mid_channels // 2

        # point-wise convs applied after resizing the 1x / 2x features
        self.conv1to4 = _conv1x1_layer(mid_channels, mid_channels)
        self.conv2to4 = _conv1x1_layer(mid_channels, mid_channels)

        # merge of the three concatenated features
        self.conv_merge = _conv3x3_layer(3 * mid_channels, mid_channels)

        # output head: halve the width, then project to the image channels
        self.conv_last1 = _conv3x3_layer(mid_channels, half_width)
        self.conv_last2 = _conv1x1_layer(half_width, out_channels)

    def forward(self, x1, x2, x4):
        """Forward function.

        Args:
            x1 (Tensor): Input tensor with shape (n, c, h, w).
            x2 (Tensor): Input tensor with shape (n, c, 2h, 2w).
            x4 (Tensor): Input tensor with shape (n, c, 4h, 4w).

        Returns:
            x (Tensor): Output tensor with shape (n, c_out, 4h, 4w).
        """
        resized_1x = F.interpolate(
            x1, scale_factor=4, mode='bicubic', align_corners=False)
        resized_2x = F.interpolate(
            x2, scale_factor=2, mode='bicubic', align_corners=False)

        from_1x = F.relu(self.conv1to4(resized_1x))
        from_2x = F.relu(self.conv2to4(resized_2x))

        stacked = torch.cat([x4, from_1x, from_2x], dim=1)
        merged = F.relu(self.conv_merge(stacked))

        # No activation between the two head convolutions.
        image = self.conv_last2(self.conv_last1(merged))
        return torch.clamp(image, min=-1, max=1)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.