Shortcuts

Source code for mmagic.models.editors.deblurganv2.deblurganv2_util

# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import absolute_import, division, print_function
import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

# Checkpoint URL shared by every Inception-ResNet-v2 entry below.
_INCEPTIONRESNETV2_URL = (
    'https://download.openxlab.org.cn/models/xiaomile/DeblurGANv2/'
    'weight/inceptionresnetv2-520b38e4.pth')


def _inceptionresnetv2_settings(num_classes):
    """Return one settings entry for the given class count.

    A new dict (with new lists) is built on every call so that the
    entries in ``pretrained_settings`` never share mutable state.
    """
    return {
        'url': _INCEPTIONRESNETV2_URL,
        'input_space': 'RGB',
        'input_size': [3, 299, 299],
        'input_range': [0, 1],
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'num_classes': num_classes,
    }


# Lookup table: architecture name -> pretrained variant -> settings.
# The two variants differ only in their declared number of classes.
pretrained_settings = {
    'inceptionresnetv2': {
        'imagenet': _inceptionresnetv2_settings(1000),
        'imagenet+background': _inceptionresnetv2_settings(1001),
    }
}
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and ``ReLU``.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        kernel_size (int | tuple): Convolution kernel size.
        stride (int | tuple): Convolution stride.
        padding (int | tuple): Convolution padding. Default: 0.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,  # verify bias false
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU in sequence.

        Args:
            x (torch.Tensor): Input tensor accepted by ``nn.Conv2d``.

        Returns:
            torch.Tensor: The activated feature map.
        """
        return self.relu(self.bn(self.conv(x)))
class Mixed_5b(nn.Module):
    """Four parallel branches over a 192-channel input, concatenated.

    The branches emit 96, 64, 96 and 64 channels, so the output has 320
    channels. Every layer uses stride 1 with size-preserving padding.
    """

    def __init__(self):
        super().__init__()

        # 1x1 branch.
        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)

        # 1x1 reduction followed by a 5x5 convolution.
        self.branch1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2),
        )

        # 1x1 reduction followed by two 3x3 convolutions.
        self.branch2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        )

        # Average pooling followed by a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along dim 1.

        Args:
            x (torch.Tensor): Input feature map with 192 channels.

        Returns:
            torch.Tensor: Concatenated branch outputs.
        """
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat(tuple(branch(x) for branch in branches), 1)
class Block35(nn.Module):
    """Residual block over a 320-channel feature map.

    Three branches (32 + 32 + 64 channels) are concatenated, projected
    back to 320 channels by a 1x1 convolution, scaled, added to the input
    and passed through ReLU.

    Args:
        scale (float): Multiplier applied to the projected branch output
            before it is added to the input. Default: 1.0.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1),
        )

        # Projects the 128 concatenated channels back to the input width.
        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Compute ``relu(conv2d(cat(branches(x))) * scale + x)``.

        Args:
            x (torch.Tensor): Input feature map with 320 channels.

        Returns:
            torch.Tensor: Output with the same shape as ``x``.
        """
        branch_outputs = (self.branch0(x), self.branch1(x), self.branch2(x))
        residual = self.conv2d(torch.cat(branch_outputs, 1))
        return self.relu(residual * self.scale + x)
class Mixed_6a(nn.Module):
    """Stride-2 reduction block over a 320-channel input.

    Two convolutional branches (384 channels each) and a max-pooling
    branch (which keeps the 320 input channels) are concatenated,
    giving 1088 output channels.
    """

    def __init__(self):
        super().__init__()

        # Single stride-2 3x3 convolution.
        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)

        # 1x1 -> 3x3 -> stride-2 3x3.
        self.branch1 = nn.Sequential(
            BasicConv2d(320, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )

        # Parameter-free downsampling path.
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along dim 1.

        Args:
            x (torch.Tensor): Input feature map with 320 channels.

        Returns:
            torch.Tensor: Concatenated, spatially reduced feature map.
        """
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat(tuple(branch(x) for branch in branches), 1)
class Block17(nn.Module):
    """Residual block over a 1088-channel feature map.

    A 1x1 branch and a factorized 1x7 / 7x1 branch (192 channels each)
    are concatenated, projected back to 1088 channels, scaled, added to
    the input and passed through ReLU.

    Args:
        scale (float): Multiplier applied to the projected branch output
            before it is added to the input. Default: 1.0.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 128, kernel_size=1, stride=1),
            BasicConv2d(
                128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(
                160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )

        # Projects the 384 concatenated channels back to the input width.
        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Compute ``relu(conv2d(cat(branches(x))) * scale + x)``.

        Args:
            x (torch.Tensor): Input feature map with 1088 channels.

        Returns:
            torch.Tensor: Output with the same shape as ``x``.
        """
        branch_outputs = (self.branch0(x), self.branch1(x))
        residual = self.conv2d(torch.cat(branch_outputs, 1))
        return self.relu(residual * self.scale + x)
class Mixed_7a(nn.Module):
    """Stride-2 reduction block over a 1088-channel input.

    Three convolutional branches (384, 288 and 320 channels) and a
    max-pooling branch (which keeps the 1088 input channels) are
    concatenated, giving 2080 output channels.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2),
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2),
        )

        # Parameter-free downsampling path.
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along dim 1.

        Args:
            x (torch.Tensor): Input feature map with 1088 channels.

        Returns:
            torch.Tensor: Concatenated, spatially reduced feature map.
        """
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat(tuple(branch(x) for branch in branches), 1)
class Block8(nn.Module):
    """Residual block over a 2080-channel feature map.

    A 1x1 branch (192 channels) and a factorized 1x3 / 3x1 branch (256
    channels) are concatenated, projected back to 2080 channels, scaled
    and added to the input. The final ReLU is optional.

    Args:
        scale (float): Multiplier applied to the projected branch output
            before it is added to the input. Default: 1.0.
        noReLU (bool): If True, no ReLU module is created and the sum is
            returned without activation. Default: False.
    """

    def __init__(self, scale=1.0, noReLU=False):
        super().__init__()
        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(2080, 192, kernel_size=1, stride=1),
            BasicConv2d(
                192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(
                224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)),
        )

        # Projects the 448 concatenated channels back to the input width.
        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)

        # The activation only exists when it is going to be used.
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Compute ``conv2d(cat(branches(x))) * scale + x``, then ReLU
        unless ``noReLU`` is set.

        Args:
            x (torch.Tensor): Input feature map with 2080 channels.

        Returns:
            torch.Tensor: Output with the same shape as ``x``.
        """
        branch_outputs = (self.branch0(x), self.branch1(x))
        residual = self.conv2d(torch.cat(branch_outputs, 1))
        out = residual * self.scale + x
        if self.noReLU:
            return out
        return self.relu(out)
class InceptionResNetV2(nn.Module):
    """Inception-ResNet-v2 network.

    The model is a convolutional stem, three stacks of residual blocks
    (``Block35``, ``Block17``, ``Block8``) separated by the reduction
    blocks ``Mixed_6a`` / ``Mixed_7a``, and a pooled linear classifier.

    Args:
        num_classes (int): Output size of ``last_linear``. Default: 1001.
    """

    def __init__(self, num_classes=1001):
        super().__init__()

        # Special attributes. These are placeholders here;
        # ``inceptionresnetv2()`` overwrites them from the pretrained
        # settings when weights are loaded.
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.mean = None
        self.std = None

        # Stem
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(
            32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()

        # 10 x Block35 on 320 channels
        self.repeat = nn.Sequential(
            *(Block35(scale=0.17) for _ in range(10)))
        self.mixed_6a = Mixed_6a()

        # 20 x Block17 on 1088 channels
        self.repeat_1 = nn.Sequential(
            *(Block17(scale=0.10) for _ in range(20)))
        self.mixed_7a = Mixed_7a()

        # 9 x Block8 on 2080 channels, then one more without ReLU
        self.repeat_2 = nn.Sequential(
            *(Block8(scale=0.20) for _ in range(9)))
        self.block8 = Block8(noReLU=True)

        # Head
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False)
        self.last_linear = nn.Linear(1536, num_classes)

    def features(self, input):
        """Run the convolutional trunk, up to and including ``conv2d_7b``.

        Args:
            input (torch.Tensor): Image batch with 3 channels.

        Returns:
            torch.Tensor: Feature map with 1536 channels.
        """
        stages = (
            self.conv2d_1a,
            self.conv2d_2a,
            self.conv2d_2b,
            self.maxpool_3a,
            self.conv2d_3b,
            self.conv2d_4a,
            self.maxpool_5a,
            self.mixed_5b,
            self.repeat,
            self.mixed_6a,
            self.repeat_1,
            self.mixed_7a,
            self.repeat_2,
            self.block8,
            self.conv2d_7b,
        )
        x = input
        for stage in stages:
            x = stage(x)
        return x

    def logits(self, features):
        """Pool, flatten, apply dropout and classify a feature map.

        Dropout is only active while the module is in training mode.

        Args:
            features (torch.Tensor): Output of :meth:`features`.

        Returns:
            torch.Tensor: Class scores produced by ``last_linear``.
        """
        pooled = self.avgpool_1a(features)
        flat = pooled.view(pooled.size(0), -1)
        dropped = F.dropout(flat, training=self.training)
        return self.last_linear(dropped)

    def forward(self, input):
        """Compute class scores for an image batch.

        Args:
            input (torch.Tensor): Image batch with 3 channels.

        Returns:
            torch.Tensor: Class scores produced by ``last_linear``.
        """
        return self.logits(self.features(input))
def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
    """Build an Inception-ResNet-v2 network, optionally pretrained.

    Args:
        num_classes (int): Size of the classifier when ``pretrained`` is
            falsy. It is ignored when pretrained weights are loaded: the
            classifier then has 1000 outputs for ``'imagenet'`` and 1001
            for any other setting. Default: 1000.
        pretrained (str | None): Key into
            ``pretrained_settings['inceptionresnetv2']``, or a falsy
            value to get a randomly initialized model.
            Default: 'imagenet'.

    Returns:
        InceptionResNetV2: The constructed model. When pretrained, its
        ``input_space``, ``input_size``, ``input_range``, ``mean`` and
        ``std`` attributes are copied from the selected settings.

    Raises:
        KeyError: If ``pretrained`` is truthy but is not a known setting.
    """
    if not pretrained:
        return InceptionResNetV2(num_classes=num_classes)

    settings = pretrained_settings['inceptionresnetv2'][pretrained]

    # The checkpoint is always loaded into a 1001-way classifier; the
    # head is trimmed afterwards when the 1000-class variant is requested.
    model = InceptionResNetV2(num_classes=1001)

    # Deserialize onto the CPU so that a checkpoint whose tensors were
    # saved on a CUDA device can still be loaded on a CPU-only host.
    # ``load_state_dict`` copies the values into the model's own
    # parameters, so the resulting model is unaffected by this choice.
    state_dict = model_zoo.load_url(settings['url'], map_location='cpu')
    model.load_state_dict(state_dict)

    if pretrained == 'imagenet':
        # Drop output 0 (presumably the background class - confirm
        # against the checkpoint) to obtain a 1000-way classifier.
        new_last_linear = nn.Linear(1536, 1000)
        new_last_linear.weight.data = model.last_linear.weight.data[1:]
        new_last_linear.bias.data = model.last_linear.bias.data[1:]
        model.last_linear = new_last_linear

    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']
    return model
def conv_bn(inp, oup, stride):
    """Build a 3x3 convolution + BatchNorm2d + ReLU6 stack.

    Args:
        inp (int): Number of input channels.
        oup (int): Number of output channels.
        stride (int): Stride of the 3x3 convolution.

    Returns:
        nn.Sequential: ``Conv2d`` (padding 1, no bias), ``BatchNorm2d``
        and in-place ``ReLU6``, in that order.
    """
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
def conv_1x1_bn(inp, oup):
    """Build a 1x1 convolution + BatchNorm2d + ReLU6 stack.

    Args:
        inp (int): Number of input channels.
        oup (int): Number of output channels.

    Returns:
        nn.Sequential: ``Conv2d`` (stride 1, no padding, no bias),
        ``BatchNorm2d`` and in-place ``ReLU6``, in that order.
    """
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
class InvertedResidual(nn.Module):
    """Pointwise-expand / depthwise / pointwise-project block.

    The pointwise expansion is skipped when ``expand_ratio`` equals 1.
    A skip connection is used only when ``stride == 1`` and
    ``inp == oup``.

    Args:
        inp (int): Number of input channels.
        oup (int): Number of output channels.
        stride (int): Stride of the depthwise convolution; 1 or 2.
        expand_ratio (int | float): Ratio of hidden to input channels.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ])
        layers.extend([
            # dw
            nn.Conv2d(
                hidden_dim,
                hidden_dim,
                3,
                stride,
                1,
                groups=hidden_dim,
                bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block, adding the input back when the skip is enabled.

        Args:
            x (torch.Tensor): Input feature map with ``inp`` channels.

        Returns:
            torch.Tensor: ``x + conv(x)`` when the skip connection is
            used, otherwise ``conv(x)``.
        """
        if not self.use_res_connect:
            return self.conv(x)
        return x + self.conv(x)
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from ``InvertedResidual`` blocks.

    Args:
        n_class (int): Output size of the classifier. Default: 1000.
        input_size (int): Must be a multiple of 32; it is only validated
            and not otherwise used. Default: 224.
        width_mult (float): Multiplier applied to the channel counts.
            The final feature width is only scaled when this is greater
            than 1.0. Default: 1.0.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super().__init__()

        # One row per stage: (t, c, n, s) = expansion ratio, output
        # channels, number of blocks, stride of the stage's first block.
        stage_settings = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        in_channels = int(32 * width_mult)
        self.last_channel = int(
            1280 * width_mult) if width_mult > 1.0 else 1280
        layers = [conv_bn(3, in_channels, 2)]

        # building inverted residual blocks
        for t, c, n, s in stage_settings:
            out_channels = int(c * width_mult)
            for block_idx in range(n):
                # Only the first block of a stage uses the stage stride.
                block_stride = s if block_idx == 0 else 1
                layers.append(
                    InvertedResidual(
                        in_channels,
                        out_channels,
                        block_stride,
                        expand_ratio=t))
                in_channels = out_channels

        # building last several layers
        layers.append(conv_1x1_bn(in_channels, self.last_channel))
        self.features = nn.Sequential(*layers)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        """Extract features, average over dims 3 and 2, then classify.

        Args:
            x (torch.Tensor): Image batch with 3 channels.

        Returns:
            torch.Tensor: Class scores produced by ``classifier``.
        """
        feature_map = self.features(x)
        pooled = feature_map.mean(3).mean(2)
        return self.classifier(pooled)

    def _initialize_weights(self):
        """Initialize every Conv2d, BatchNorm2d and Linear submodule.

        * Conv2d: weights ~ N(0, sqrt(2 / (kh * kw * out_channels))),
          bias (if any) zeroed.
        * BatchNorm2d: weight set to 1, bias zeroed.
        * Linear: weights ~ N(0, 0.01), bias zeroed.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested normalization layer.

    Args:
        norm_type (str): ``'batch'`` for ``nn.BatchNorm2d`` with
            ``affine=True``, or ``'instance'`` for ``nn.InstanceNorm2d``
            with ``affine=False`` and ``track_running_stats=True``.
            Default: 'instance'.

    Returns:
        functools.partial: Callable that builds the layer; the remaining
        constructor arguments are supplied by the caller.

    Raises:
        NotImplementedError: If ``norm_type`` is not a supported name.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' %
                              norm_type)