Shortcuts

Source code for mmagic.models.editors.global_local.gl_encoder_decoder

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import BaseModule

from mmagic.registry import MODELS


@MODELS.register_module()
class GLEncoderDecoder(BaseModule):
    """Encoder-Decoder used in Global&Local model.

    This implementation follows:
    Globally and Locally Consistent Image Completion

    The architecture of the encoder-decoder is:
    (conv2d x 6) --> (dilated conv2d x 4) --> (conv2d or deconv2d x 7)

    Args:
        encoder (dict, optional): Config dict to build encoder. When ``None``
            (the default), ``dict(type='GLEncoder')`` is used.
        decoder (dict, optional): Config dict to build decoder. When ``None``
            (the default), ``dict(type='GLDecoder')`` is used.
        dilation_neck (dict, optional): Config dict to build dilation neck.
            When ``None`` (the default), ``dict(type='GLDilationNeck')`` is
            used.
    """

    def __init__(self, encoder=None, decoder=None, dilation_neck=None):
        super().__init__()
        # Create fresh default configs on every call rather than sharing a
        # single mutable dict default object across all instances.
        if encoder is None:
            encoder = dict(type='GLEncoder')
        if decoder is None:
            decoder = dict(type='GLDecoder')
        if dilation_neck is None:
            dilation_neck = dict(type='GLDilationNeck')

        # Sub-modules are built in the original order (encoder, decoder,
        # dilation_neck) so module registration order is unchanged.
        self.encoder = MODELS.build(encoder)
        self.decoder = MODELS.build(decoder)
        self.dilation_neck = MODELS.build(dilation_neck)

        # support fp16
        self.fp16_enabled = False

    def forward(self, x):
        """Forward Function.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Output tensor with shape of (n, c, h', w').
        """
        x = self.encoder(x)
        # The encoder output may be a dict; only its 'out' entry is
        # forwarded to the dilation neck.
        if isinstance(x, dict):
            x = x['out']
        x = self.dilation_neck(x)
        x = self.decoder(x)
        return x
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.