Shortcuts

Source code for mmagic.models.editors.plain.plain_decoder

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import xavier_init
from torch.autograd import Function
from torch.nn.modules.pooling import _MaxUnpoolNd
from torch.nn.modules.utils import _pair

from mmagic.registry import MODELS


class MaxUnpool2dop(Function):
    """Wrap ``torch.nn.functional.max_unpool2d`` with a ``symbolic`` method.

    The extra ``symbolic`` static method is needed while exporting to ONNX.
    Users should not call this function directly.
    """

    @staticmethod
    def forward(ctx, input, indices, kernel_size, stride, padding,
                output_size):
        """Forward function of MaxUnpool2dop.

        Args:
            ctx: Autograd context object (unused).
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            kernel_size (Tuple): Size of the max pooling window.
            stride (Tuple): Stride of the max pooling window.
            padding (Tuple): Padding that was added to the input.
            output_size (List or Tuple): The shape of output tensor.

        Returns:
            Tensor: Output tensor.
        """
        return F.max_unpool2d(input, indices, kernel_size, stride, padding,
                              output_size)

    @staticmethod
    def symbolic(g, input, indices, kernel_size, stride, padding,
                 output_size):
        """Emit the ONNX nodes that represent this unpooling operation.

        The incoming ``indices`` are shifted by a per-channel and a per-batch
        offset, so that each one becomes a position inside the flattened
        ``(batch, channel, height, width)`` output, and are then handed to a
        ``MaxUnpool`` node.

        Args:
            g: Graph object supplied by the exporter. Only its
                ``op(name, *inputs, **attrs)`` method is used.
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            kernel_size (Tuple): Size of the max pooling window, indexed as
                ``(height, width)``.
            stride (Tuple): Stride of the max pooling window, indexed as
                ``(height, width)``.
            padding (Tuple): Padding that was added to the input. Not used by
                the exported graph.
            output_size (List or Tuple): The shape of output tensor. Not used
                by the exported graph.

        Returns:
            The value returned by ``g.op('MaxUnpool', ...)``.
        """
        # NOTE(review): ``padding`` and ``output_size`` are accepted only to
        # mirror ``forward``; the output extent below assumes zero padding.

        def const(value):
            return g.op('Constant', value_t=torch.tensor(value))

        input_shape = g.op('Shape', input)
        zero = const(0)
        one = const(1)

        def unpooled_extent(shape_axis, pair_index):
            # extent = (extent - 1) * stride + kernel_size
            extent = g.op(
                'Gather', input_shape, const(shape_axis), axis_i=0)
            extent = g.op('Sub', extent, one)
            extent = g.op('Mul', extent, const(stride[pair_index]))
            return g.op('Add', extent, const(kernel_size[pair_index]))

        batch_size = g.op('Gather', input_shape, zero, axis_i=0)
        channel = g.op('Gather', input_shape, one, axis_i=0)

        # Element 0 of the (kernel_size, stride) pairs belongs to the height
        # axis and element 1 to the width axis, matching ``F.max_unpool2d``
        # as called in ``forward``.
        height = unpooled_extent(2, 0)
        width = unpooled_extent(3, 1)

        # Number of output elements per channel plane and per batch item.
        channel_step = g.op('Mul', height, width)
        batch_step = g.op('Mul', channel_step, channel)

        # Channel offset, broadcast along axis 1.
        range_channel = g.op('Range', zero, channel, one)
        range_channel = g.op('Reshape', range_channel, const([1, -1, 1, 1]))
        range_channel = g.op('Mul', range_channel, channel_step)
        range_channel = g.op('Cast', range_channel, to_i=7)  # 7 is int64

        # Batch offset, broadcast along axis 0.
        range_batch = g.op('Range', zero, batch_size, one)
        range_batch = g.op('Reshape', range_batch, const([-1, 1, 1, 1]))
        range_batch = g.op('Mul', range_batch, batch_step)
        range_batch = g.op('Cast', range_batch, to_i=7)  # 7 is int64

        # Update indices with both offsets.
        indices = g.op('Add', indices, range_channel)
        indices = g.op('Add', indices, range_batch)

        return g.op(
            'MaxUnpool',
            input,
            indices,
            kernel_shape_i=kernel_size,
            strides_i=stride)
class MaxUnpool2d(_MaxUnpoolNd):
    """Variant of the PyTorch ``MaxUnpool2d`` module.

    The actual computation is delegated to ``MaxUnpool2dop``.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            Default: None (falls back to ``kernel_size``).
        padding (int or tuple): Padding that is added to the input.
            Default: 0.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        # A falsy stride (e.g. ``None``) means "use the kernel size".
        effective_stride = stride if stride else kernel_size
        # Every setting is stored as a 2-tuple.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(effective_stride)
        self.padding = _pair(padding)

    def forward(self, input, indices, output_size=None):
        """Unpool ``input`` using the positions stored in ``indices``.

        Args:
            input (Tensor): Tensor needed to upsample.
            indices (Tensor): Indices output of the previous MaxPool.
            output_size (List or Tuple): The shape of output tensor.
                Default: None.

        Returns:
            Tensor: Output tensor.
        """
        window_args = (self.kernel_size, self.stride, self.padding)
        return MaxUnpool2dop.apply(input, indices, *window_args, output_size)
@MODELS.register_module()
class PlainDecoder(BaseModule):
    """Simple decoder from Deep Image Matting.

    Args:
        in_channels (int): Channel num of input features.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self, in_channels, init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)

        def conv5x5(src_channels, dst_channels):
            # All convolutions after the first share the 5x5 / padding 2 setup.
            return nn.Conv2d(
                src_channels, dst_channels, kernel_size=5, padding=2)

        self.deconv6_1 = nn.Conv2d(in_channels, 512, kernel_size=1)
        self.deconv5_1 = conv5x5(512, 512)
        self.deconv4_1 = conv5x5(512, 256)
        self.deconv3_1 = conv5x5(256, 128)
        self.deconv2_1 = conv5x5(128, 64)
        self.deconv1_1 = conv5x5(64, 64)
        self.deconv1 = conv5x5(64, 1)

        self.relu = nn.ReLU(inplace=True)

        # Two unpooling modules are kept; ``forward`` picks the custom one
        # only while an ONNX export is in progress.
        self.max_unpool2d_for_onnx = MaxUnpool2d(kernel_size=2, stride=2)
        self.max_unpool2d = nn.MaxUnpool2d(kernel_size=2, stride=2)

    def init_weights(self):
        """Init weights for the module."""
        if self.init_cfg is not None:
            super().init_weights()
            return

        # Default initialization: Xavier for every convolution.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            xavier_init(module)

    def forward(self, inputs):
        """Forward function of PlainDecoder.

        Args:
            inputs (dict): Output dictionary of the VGG encoder containing:

              - out (Tensor): Output of the VGG encoder.
              - max_idx_1 (Tensor): Index of the first maxpooling layer in the
                VGG encoder.
              - max_idx_2 (Tensor): Index of the second maxpooling layer in the
                VGG encoder.
              - max_idx_3 (Tensor): Index of the third maxpooling layer in the
                VGG encoder.
              - max_idx_4 (Tensor): Index of the fourth maxpooling layer in the
                VGG encoder.
              - max_idx_5 (Tensor): Index of the fifth maxpooling layer in the
                VGG encoder.

        Returns:
            Tensor: Output tensor.
        """
        # Pooling indices ordered from the first to the fifth layer.
        pool_indices = [inputs[f'max_idx_{level}'] for level in range(1, 6)]
        features = inputs['out']

        if torch.onnx.is_in_onnx_export():
            unpool = self.max_unpool2d_for_onnx
        else:
            unpool = self.max_unpool2d

        # The indices are consumed in reverse: fifth layer first.
        stages = (self.deconv5_1, self.deconv4_1, self.deconv3_1,
                  self.deconv2_1, self.deconv1_1)

        out = self.relu(self.deconv6_1(features))
        for conv, max_idx in zip(stages, reversed(pool_indices)):
            out = unpool(out, max_idx)
            out = self.relu(conv(out))

        # The last convolution is returned without an activation.
        return self.deconv1(out)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.