Source code for mmagic.models.editors.dic.dic_net

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmagic.models.utils import make_layer
from mmagic.registry import MODELS
from .feedback_hour_glass import FeedbackHourglass, reduce_to_five_heatmaps


@MODELS.register_module()
class DICNet(BaseModule):
    """DIC network structure for face super-resolution.

    Paper: Deep Face Super-Resolution with Iterative Collaboration between
    Attentive Recovery and Landmark Estimation

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels in the output image.
        mid_channels (int): Channel number of intermediate features.
            Default: 64
        num_blocks (int): Block number in the trunk network. Default: 6
        hg_mid_channels (int): Channel number of intermediate features of
            HourGlass. Default: 256
        hg_num_keypoints (int): Keypoint number of HourGlass. Default: 68
        num_steps (int): Number of iterative steps. Default: 4
        upscale_factor (int): Upsampling factor. Default: 8
        detach_attention (bool): Whether to detach the heatmap from the
            current graph before using it as attention. Default: False
        prelu_init (float): `init` of PReLU. Default: 0.2
        num_heatmaps (int): Number of heatmaps. Default: 5
        num_fusion_blocks (int): Number of fusion blocks. Default: 7
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 num_blocks=6,
                 hg_mid_channels=256,
                 hg_num_keypoints=68,
                 num_steps=4,
                 upscale_factor=8,
                 detach_attention=False,
                 prelu_init=0.2,
                 num_heatmaps=5,
                 num_fusion_blocks=7,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.num_steps = num_steps
        self.detach_attention = detach_attention

        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels * 4, 3, 1, 1),
            nn.PReLU(init=prelu_init), nn.PixelShuffle(2))

        self.first_block = FeedbackBlockCustom(
            in_channels=mid_channels,
            mid_channels=mid_channels,
            num_blocks=num_blocks,
            upscale_factor=upscale_factor)

        self.block = FeedbackBlockHeatmapAttention(
            mid_channels=mid_channels,
            num_blocks=num_blocks,
            upscale_factor=upscale_factor,
            num_heatmaps=num_heatmaps,
            num_fusion_blocks=num_fusion_blocks)
        # DICNet seeds the hidden state manually in `forward`, so the
        # attention block must not reset itself on the first call.
        self.block.need_reset = False

        self.hour_glass = FeedbackHourglass(
            mid_channels=hg_mid_channels, num_keypoints=hg_num_keypoints)

        self.conv_last = nn.Sequential(
            nn.ConvTranspose2d(mid_channels, mid_channels, 8, 4, 2),
            nn.PReLU(init=prelu_init),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor.

        Returns:
            tuple(list[Tensor], list[Tensor]):

                - sr_outputs (list[Tensor]): SR result of each step.
                - heatmap_outputs (list[Tensor]): Heatmap of each step.
        """
        inter_res = nn.functional.interpolate(
            x, size=(128, 128), mode='bilinear', align_corners=False)

        x = self.conv_first(x)

        sr_outputs = []
        heatmap_outputs = []
        last_hidden = None
        heatmap = None

        for step in range(self.num_steps):
            if step == 0:
                sr_feature = self.first_block(x)
                self.block.last_hidden = sr_feature
            else:
                heatmap = reduce_to_five_heatmaps(heatmap,
                                                  self.detach_attention)
                sr_feature = self.block(x, heatmap)

            sr = self.conv_last(sr_feature)
            sr = torch.add(inter_res, sr)

            heatmap, last_hidden = self.hour_glass(sr, last_hidden)

            sr_outputs.append(sr)
            heatmap_outputs.append(heatmap)

        return sr_outputs, heatmap_outputs
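

# A minimal smoke test for DICNet (an illustrative sketch, not part of the
# module). The input size is implied by the code above: `inter_res` is fixed
# at 128x128 and the trunk upsamples 8x overall, so the dummy input must be
# 16x16. Heatmap shapes depend on FeedbackHourglass internals and are not
# asserted here.
def _dicnet_smoke_test():
    net = DICNet(in_channels=3, out_channels=3, mid_channels=64)
    lr_face = torch.rand(1, 3, 16, 16)  # dummy low-resolution face
    sr_outputs, heatmap_outputs = net(lr_face)
    # one SR image and one heatmap per iterative step
    assert len(sr_outputs) == net.num_steps == len(heatmap_outputs)
    # every step produces a full-resolution 128x128 estimate
    assert all(sr.shape == (1, 3, 128, 128) for sr in sr_outputs)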


class FeedbackBlock(nn.Module):
    """Feedback Block of DIC.

    It has a style of:

    .. code-block:: text

        ----- Module ----->
          ^            |
          |____________|

    Args:
        mid_channels (int): Number of channels in the intermediate features.
        num_blocks (int): Number of blocks.
        upscale_factor (int): Upscale factor.
        padding (int): Padding size. Default: 2.
        prelu_init (float): `init` of PReLU. Default: 0.2
    """

    def __init__(self,
                 mid_channels,
                 num_blocks,
                 upscale_factor,
                 padding=2,
                 prelu_init=0.2):
        super().__init__()

        stride = upscale_factor
        kernel_size = upscale_factor + 4
        self.num_blocks = num_blocks
        self.need_reset = True
        self.last_hidden = None

        self.conv_first = nn.Sequential(
            nn.Conv2d(2 * mid_channels, mid_channels, kernel_size=1),
            nn.PReLU(init=prelu_init))

        self.up_blocks = nn.ModuleList()
        self.down_blocks = nn.ModuleList()
        self.lr_blocks = nn.ModuleList()
        self.hr_blocks = nn.ModuleList()

        for idx in range(self.num_blocks):
            self.up_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose2d(mid_channels, mid_channels,
                                       kernel_size, stride, padding),
                    nn.PReLU(init=prelu_init)))
            self.down_blocks.append(
                nn.Sequential(
                    nn.Conv2d(mid_channels, mid_channels, kernel_size, stride,
                              padding), nn.PReLU(init=prelu_init)))
            if idx > 0:
                self.lr_blocks.append(
                    nn.Sequential(
                        nn.Conv2d(
                            mid_channels * (idx + 1),
                            mid_channels,
                            kernel_size=1), nn.PReLU(init=prelu_init)))
                self.hr_blocks.append(
                    nn.Sequential(
                        nn.Conv2d(
                            mid_channels * (idx + 1),
                            mid_channels,
                            kernel_size=1), nn.PReLU(init=prelu_init)))

        self.conv_last = nn.Sequential(
            nn.Conv2d(num_blocks * mid_channels, mid_channels, kernel_size=1),
            nn.PReLU(init=prelu_init))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.need_reset:
            self.last_hidden = x
            self.need_reset = False

        x = torch.cat((x, self.last_hidden), dim=1)
        x = self.conv_first(x)

        lr_features = [x]
        hr_features = []

        for idx in range(self.num_blocks):
            # when idx == 0, lr_features == [x]
            lr = torch.cat(lr_features, 1)
            if idx > 0:
                lr = self.lr_blocks[idx - 1](lr)
            hr = self.up_blocks[idx](lr)
            hr_features.append(hr)

            hr = torch.cat(hr_features, 1)
            if idx > 0:
                hr = self.hr_blocks[idx - 1](hr)
            lr = self.down_blocks[idx](hr)
            lr_features.append(lr)

        output = torch.cat(lr_features[1:], 1)
        output = self.conv_last(output)

        self.last_hidden = output

        return output
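

# Illustrative sketch of the feedback mechanism (not part of the module).
# The first call seeds `last_hidden` with the input itself; each later call
# concatenates the new input with the previous output, so state persists
# across iterative steps until `need_reset` is set back to True.
def _feedback_block_demo():
    block = FeedbackBlock(mid_channels=8, num_blocks=2, upscale_factor=4)
    x = torch.rand(1, 8, 16, 16)
    out_first = block(x)   # first call: hidden state is x itself
    out_second = block(x)  # second call: hidden state is out_first
    assert out_first.shape == out_second.shape == x.shape
    block.need_reset = True  # start a fresh sequence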


class FeedbackBlockCustom(FeedbackBlock):
    """Custom feedback block, will be used as the first feedback block.

    Args:
        in_channels (int): Number of channels in the input features.
        mid_channels (int): Number of channels in the intermediate features.
        num_blocks (int): Number of blocks.
        upscale_factor (int): Upscale factor.
    """

    def __init__(self, in_channels, mid_channels, num_blocks, upscale_factor):
        super().__init__(mid_channels, num_blocks, upscale_factor)

        prelu_init = 0.2
        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1),
            nn.PReLU(init=prelu_init))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        x = self.conv_first(x)

        lr_features = [x]
        hr_features = []

        for idx in range(self.num_blocks):
            # when idx == 0, lr_features == [x]
            lr = torch.cat(lr_features, 1)
            if idx > 0:
                lr = self.lr_blocks[idx - 1](lr)
            hr = self.up_blocks[idx](lr)
            hr_features.append(hr)

            hr = torch.cat(hr_features, 1)
            if idx > 0:
                hr = self.hr_blocks[idx - 1](hr)
            lr = self.down_blocks[idx](hr)
            lr_features.append(lr)

        output = torch.cat(lr_features[1:], 1)
        output = self.conv_last(output)

        return output


class GroupResBlock(nn.Module):
    """ResBlock with Group Conv.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output features.
        mid_channels (int): Channel number of intermediate features.
        groups (int): Number of blocked connections from input to output.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 groups,
                 res_scale=1.0):
        super().__init__()
        self.res = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, 1, 1, groups=groups),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 1, groups=groups))
        self.res_scale = res_scale

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        res = self.res(x).mul(self.res_scale)
        return x + res
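

# Brief sketch of GroupResBlock usage (illustrative only). With `groups` set
# to the number of heatmaps, each channel group is convolved independently,
# and the residual is scaled by `res_scale` before the skip addition; the
# output keeps the input shape when in_channels == out_channels.
def _group_res_block_demo():
    block = GroupResBlock(
        in_channels=20,
        out_channels=20,
        mid_channels=20,
        groups=5,
        res_scale=0.5)
    x = torch.rand(1, 20, 32, 32)
    assert block(x).shape == x.shape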


class FeatureHeatmapFusingBlock(nn.Module):
    """Fusing Feature and Heatmap.

    Args:
        in_channels (int): Number of channels in the input features.
        num_heatmaps (int): Number of heatmaps.
        num_blocks (int): Number of blocks.
        mid_channels (int | None): Number of channels in the intermediate
            features. Default: None
    """

    def __init__(self,
                 in_channels,
                 num_heatmaps,
                 num_blocks,
                 mid_channels=None):
        super().__init__()
        self.num_heatmaps = num_heatmaps
        res_block_channel = in_channels * num_heatmaps
        if mid_channels is None:
            self.mid_channels = num_heatmaps * in_channels
        else:
            self.mid_channels = mid_channels

        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels, res_block_channel, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

        self.body = make_layer(
            GroupResBlock,
            num_blocks,
            in_channels=res_block_channel,
            out_channels=res_block_channel,
            mid_channels=self.mid_channels,
            groups=num_heatmaps)

    def forward(self, feature, heatmap):
        """Forward function.

        Args:
            feature (Tensor): Input feature tensor.
            heatmap (Tensor): Input heatmap tensor.

        Returns:
            Tensor: Forward results.
        """
        assert self.num_heatmaps == heatmap.size(1)
        batch_size = heatmap.size(0)
        h, w = feature.shape[-2:]

        feature = self.conv_first(feature)
        # (batch_size, num_heatmaps * in_channels, h, w)
        feature = self.body(feature)
        attention = nn.functional.softmax(
            heatmap, dim=1)  # (batch_size, num_heatmaps, h, w)

        feature = feature.view(batch_size, self.num_heatmaps, -1, h,
                               w) * attention.unsqueeze(2)
        feature = feature.sum(1)
        return feature
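

# Illustrative sketch of the fusion step (not part of the module). The
# feature is expanded into one branch per heatmap channel; a softmax over
# the heatmap channels then weights the branches before they are summed,
# so the output keeps the input feature shape.
def _fusion_block_demo():
    fusion = FeatureHeatmapFusingBlock(
        in_channels=16, num_heatmaps=5, num_blocks=2)
    feature = torch.rand(2, 16, 32, 32)
    heatmap = torch.rand(2, 5, 32, 32)  # same spatial size as the feature
    assert fusion(feature, heatmap).shape == feature.shape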


class FeedbackBlockHeatmapAttention(FeedbackBlock):
    """Feedback block with HeatmapAttention.

    Args:
        mid_channels (int): Number of channels in the intermediate features.
        num_blocks (int): Number of blocks.
        upscale_factor (int): Upscale factor.
        num_heatmaps (int): Number of heatmaps.
        num_fusion_blocks (int): Number of fusion blocks.
        padding (int): Padding size. Default: 2.
        prelu_init (float): `init` of PReLU. Default: 0.2
    """

    def __init__(self,
                 mid_channels,
                 num_blocks,
                 upscale_factor,
                 num_heatmaps,
                 num_fusion_blocks,
                 padding=2,
                 prelu_init=0.2):
        super().__init__(
            mid_channels,
            num_blocks,
            upscale_factor,
            padding=padding,
            prelu_init=prelu_init)
        self.fusion_block = FeatureHeatmapFusingBlock(mid_channels,
                                                      num_heatmaps,
                                                      num_fusion_blocks)

    def forward(self, x, heatmap):
        """Forward function.

        Args:
            x (Tensor): Input feature tensor.
            heatmap (Tensor): Input heatmap tensor.

        Returns:
            Tensor: Forward results.
        """
        if self.need_reset:
            self.last_hidden = x
            self.need_reset = False

        x = torch.cat((x, self.last_hidden), dim=1)
        x = self.conv_first(x)

        # fusion
        x = self.fusion_block(x, heatmap)

        lr_features = [x]
        hr_features = []

        for idx in range(self.num_blocks):
            # when idx == 0, lr_features == [x]
            lr = torch.cat(lr_features, 1)
            if idx > 0:
                lr = self.lr_blocks[idx - 1](lr)
            hr = self.up_blocks[idx](lr)
            hr_features.append(hr)

            hr = torch.cat(hr_features, 1)
            if idx > 0:
                hr = self.hr_blocks[idx - 1](hr)
            lr = self.down_blocks[idx](hr)
            lr_features.append(lr)

        output = torch.cat(lr_features[1:], 1)
        output = self.conv_last(output)

        self.last_hidden = output

        return output
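

# Illustrative sketch (not part of the module). The attention variant takes
# both the feature and a reduced 5-channel heatmap. As in DICNet, the hidden
# state is seeded externally here, with `need_reset` disabled beforehand.
def _attention_block_demo():
    block = FeedbackBlockHeatmapAttention(
        mid_channels=8,
        num_blocks=2,
        upscale_factor=4,
        num_heatmaps=5,
        num_fusion_blocks=2)
    block.need_reset = False
    block.last_hidden = torch.rand(1, 8, 16, 16)  # e.g. first-step features
    x = torch.rand(1, 8, 16, 16)
    heatmap = torch.rand(1, 5, 16, 16)
    assert block(x, heatmap).shape == x.shape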