mmagic.models.editors.plain.plain_decoder
Module Contents
Classes
- MaxUnpool2dop – We wrap torch.nn.functional.max_unpool2d with an extra symbolic method, which is needed when exporting to ONNX.
- MaxUnpool2d – This module is modified from the PyTorch MaxUnpool2d module.
- PlainDecoder – Simple decoder from Deep Image Matting.
- class mmagic.models.editors.plain.plain_decoder.MaxUnpool2dop(*args, **kwargs)
Bases: torch.autograd.Function
We wrap torch.nn.functional.max_unpool2d with an extra symbolic method, which is needed when exporting to ONNX.
Users should not call this function directly.
- static forward(ctx, input, indices, kernel_size, stride, padding, output_size)
Forward function of MaxUnpool2dop.
- Parameters
input (Tensor) – Tensor to be upsampled.
indices (Tensor) – Indices returned by the preceding max-pooling layer.
kernel_size (Tuple) – Size of the max pooling window.
stride (Tuple) – Stride of the max pooling window.
padding (Tuple) – Padding that was added to the input.
output_size (List or Tuple) – The shape of the output tensor.
- Returns
Output tensor.
- Return type
Tensor
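To make the semantics of forward concrete, here is a framework-free sketch of what max unpooling computes for a single (N, C) plane: each pooled value is written back to the position recorded in indices, and every other position becomes zero. It assumes kernel_size == stride and no padding, and the helper names max_pool_plane_with_indices and max_unpool_plane are hypothetical illustrations, not part of mmagic or PyTorch.

```python
from typing import List, Sequence, Tuple

Plane = List[List[float]]


def max_pool_plane_with_indices(
    plane: Sequence[Sequence[float]], k: int
) -> Tuple[Plane, List[List[int]]]:
    """Max-pool one H x W plane with kernel = stride = k and no padding.

    Returns the pooled values and, for each of them, the flat index
    (row * W + col) of the winning element in the input plane.
    """
    h, w = len(plane), len(plane[0])
    values, indices = [], []
    for top in range(0, h - k + 1, k):
        row_vals, row_idx = [], []
        for left in range(0, w - k + 1, k):
            # Scan the k x k window and remember where the maximum sits.
            best_r, best_c = top, left
            for r in range(top, top + k):
                for c in range(left, left + k):
                    if plane[r][c] > plane[best_r][best_c]:
                        best_r, best_c = r, c
            row_vals.append(plane[best_r][best_c])
            row_idx.append(best_r * w + best_c)
        values.append(row_vals)
        indices.append(row_idx)
    return values, indices


def max_unpool_plane(values, indices, out_h: int, out_w: int) -> Plane:
    """Inverse step: scatter each value back to its recorded flat position.

    Every position that did not hold a window maximum is filled with zero.
    """
    out = [[0.0] * out_w for _ in range(out_h)]
    for row_vals, row_idx in zip(values, indices):
        for value, flat in zip(row_vals, row_idx):
            r, c = divmod(flat, out_w)
            out[r][c] = value
    return out
```

Pooling a 4 × 4 plane with k=2 and unpooling back to 4 × 4 keeps the four window maxima at their original positions and zeroes the remaining twelve entries, which is why the decoder needs the indices saved by the encoder.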
- static symbolic(g, input, indices, kernel_size, stride, padding, output_size)
Symbolic function that defines how MaxUnpool is represented in the graph when exporting to ONNX.
- Parameters
g – Graph context supplied by the ONNX exporter; ONNX nodes are created through g.op(...).
input (Tensor) – Tensor to be upsampled.
indices (Tensor) – Indices returned by the preceding max-pooling layer.
kernel_size (Tuple) – Size of the max pooling window.
stride (Tuple) – Stride of the max pooling window.
padding (Tuple) – Padding that was added to the input.
output_size (List or Tuple) – The shape of the output tensor.
- Returns
The output of the MaxUnpool node added to the ONNX graph.
- Return type
torch._C.Value
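One detail a MaxUnpool symbolic has to handle is the indexing convention. PyTorch's max-pooling indices are flattened within each H × W plane, whereas the ONNX operator specification defines MaxPool's Indices output (which MaxUnpool consumes) as flattened over the whole N × C × H × W tensor. Bridging the two amounts to adding a per-(batch, channel) offset. The sketch below shows only that arithmetic in plain Python; it is an illustration of the index conversion, not the graph-building code inside symbolic, and plane_to_tensor_indices is a hypothetical helper.

```python
from typing import List


def plane_to_tensor_indices(
    indices: List[List[List[int]]], height: int, width: int
) -> List[List[List[int]]]:
    """Convert per-plane flat indices into whole-tensor flat indices.

    indices[n][c] holds the flattened indices for sample n, channel c,
    each in [0, height * width). The result adds the offset of that
    (n, c) plane inside a flattened N x C x height x width tensor.
    """
    channels = len(indices[0])
    plane_size = height * width  # distance between consecutive channels
    converted = []
    for n, sample in enumerate(indices):
        new_sample = []
        for c, plane in enumerate(sample):
            # Planes are laid out batch-major, then channel-major.
            offset = (n * channels + c) * plane_size
            new_sample.append([i + offset for i in plane])
        converted.append(new_sample)
    return converted
```

Subtracting the same offsets recovers the PyTorch convention, so the conversion is lossless in both directions.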
- class mmagic.models.editors.plain.plain_decoder.MaxUnpool2d(kernel_size, stride=None, padding=0)
Bases: torch.nn.modules.pooling._MaxUnpoolNd
This module is modified from the PyTorch MaxUnpool2d module (torch.nn.MaxUnpool2d).
- Parameters
kernel_size (int or tuple) – Size of the max pooling window.
stride (int or tuple) – Stride of the max pooling window. Default: None (falls back to kernel_size).
padding (int or tuple) – Padding that is added to the input. Default: 0.
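These defaults mean an int argument applies to both spatial dimensions and stride=None falls back to kernel_size. When output_size is not given, PyTorch's MaxUnpool2d documentation gives the inferred output height as (H_in − 1) × stride − 2 × padding + kernel_size, and likewise for the width. The sketch below reproduces that normalization and formula in plain Python; to_pair and unpool_output_size are hypothetical helpers written for illustration.

```python
from typing import Optional, Tuple, Union

IntOrPair = Union[int, Tuple[int, int]]


def to_pair(x: IntOrPair) -> Tuple[int, int]:
    """Expand an int to (x, x); pass a 2-tuple through unchanged."""
    return (x, x) if isinstance(x, int) else tuple(x)


def unpool_output_size(
    in_hw: Tuple[int, int],
    kernel_size: IntOrPair,
    stride: Optional[IntOrPair] = None,
    padding: IntOrPair = 0,
) -> Tuple[int, int]:
    """Default (H_out, W_out) of 2D max unpooling when output_size is omitted."""
    kh, kw = to_pair(kernel_size)
    # stride=None falls back to kernel_size, as in the parameter list above.
    sh, sw = to_pair(stride if stride is not None else kernel_size)
    ph, pw = to_pair(padding)
    h, w = in_hw
    return ((h - 1) * sh - 2 * ph + kh, (w - 1) * sw - 2 * pw + kw)
```

Because max pooling rounds sizes down, several input sizes can pool to the same shape; passing output_size explicitly removes that ambiguity.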
- class mmagic.models.editors.plain.plain_decoder.PlainDecoder(in_channels, init_cfg: Optional[dict] = None)
Bases: mmengine.model.BaseModule
Simple decoder from Deep Image Matting.
- Parameters
in_channels (int) – Number of channels of the input features.
init_cfg (dict, optional) – Initialization config dict. Default: None.
- forward(inputs)
Forward function of PlainDecoder.
- Parameters
inputs (dict) – Output dictionary of the VGG encoder, containing:
out (Tensor): Output feature map of the VGG encoder.
max_idx_1 (Tensor): Indices returned by the first max-pooling layer of the VGG encoder.
max_idx_2 (Tensor): Indices returned by the second max-pooling layer of the VGG encoder.
max_idx_3 (Tensor): Indices returned by the third max-pooling layer of the VGG encoder.
max_idx_4 (Tensor): Indices returned by the fourth max-pooling layer of the VGG encoder.
max_idx_5 (Tensor): Indices returned by the fifth max-pooling layer of the VGG encoder.
- Returns
Output tensor.
- Return type
Tensor
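Because the decoder mirrors the VGG encoder, each stored set of indices would be consumed by the unpooling step that undoes the matching pooling layer, deepest first. The sketch below assumes this mirrored ordering rather than quoting mmagic's code: it checks that an encoder output dict carries the keys listed above and returns the order in which the indices would be used. unpool_schedule is a hypothetical helper.

```python
from typing import Dict, List

# Keys documented for the VGG encoder output.
INDEX_KEYS = [f"max_idx_{i}" for i in range(1, 6)]


def unpool_schedule(inputs: Dict[str, object]) -> List[str]:
    """Validate the encoder output and list index keys in decoding order."""
    missing = [k for k in ["out", *INDEX_KEYS] if k not in inputs]
    if missing:
        raise KeyError(f"encoder output is missing: {missing}")
    # The deepest pooling (max_idx_5) is undone first, the shallowest last.
    return list(reversed(INDEX_KEYS))
```

Undoing the poolings in reverse order means each unpooling step receives indices recorded at the same spatial resolution as its input feature map.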