Shortcuts

Source code for mmagic.models.archs.simple_encoder_decoder

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmengine.model import BaseModule
from torch import Tensor

from mmagic.registry import MODELS


@MODELS.register_module()
[docs]class SimpleEncoderDecoder(BaseModule): """Simple encoder-decoder model from matting. Args: encoder (dict): Config of the encoder. decoder (dict): Config of the decoder. init_cfg (dict, optional): Initialization config dict. """ def __init__(self, encoder: dict, decoder: dict, init_cfg: Optional[dict] = None): super().__init__(init_cfg) self.encoder = MODELS.build(encoder) if hasattr(self.encoder, 'out_channels'): decoder['in_channels'] = self.encoder.out_channels self.decoder = MODELS.build(decoder)
[docs] def forward(self, *args, **kwargs) -> Tensor: """Forward function. Returns: Tensor: The output tensor of the decoder. """ out = self.encoder(*args, **kwargs) out = self.decoder(out) return out
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.