Shortcuts

Source code for mmagic.models.editors.deepfillv1.deepfill_disc

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.model.weight_init import normal_init

from mmagic.registry import MODELS


@MODELS.register_module()
class DeepFillv1Discriminators(BaseModule):
    """Discriminators used in DeepFillv1 model.

    In DeepFillv1 model, the discriminators are independent without any
    concatenation like Global&Local model. Thus, we call this model
    `DeepFillv1Discriminators`. There exist a global discriminator and a local
    discriminator with global and local input respectively.

    The details can be found in:
    Generative Image Inpainting with Contextual Attention.

    Args:
        global_disc_cfg (dict): Config dict for global discriminator.
        local_disc_cfg (dict): Config dict for local discriminator.
    """

    def __init__(self, global_disc_cfg, local_disc_cfg):
        super().__init__()
        # Both sub-discriminators are built through the MODELS registry, so
        # their concrete types are decided entirely by the given configs.
        self.global_disc = MODELS.build(global_disc_cfg)
        self.local_disc = MODELS.build(local_disc_cfg)

    def forward(self, x):
        """Forward function.

        Args:
            x (tuple[torch.Tensor]): Contains global image and the local
                image patch, in that order.

        Returns:
            tuple[torch.Tensor]: Contains the prediction from discriminators
            in global image and local image patch, in that order.
        """
        # ``x`` must unpack into exactly two items: (global, local).
        global_img, local_img = x
        # The two branches are independent: each discriminator only sees its
        # own input and no features are shared or concatenated.
        global_pred = self.global_disc(global_img)
        local_pred = self.local_disc(local_img)
        return global_pred, local_pred

    def init_weights(self):
        """Init weights for models.

        Walks every submodule of this model, including those inside both
        sub-discriminators. Each ``nn.Linear`` and ``nn.Conv2d`` is passed to
        ``normal_init`` with mean 0.0 and std 0.02. Other module types are
        left untouched.

        NOTE(review): bias handling follows ``normal_init``'s own default,
        and this override does not call ``super().init_weights()``.
        Presumably that means BaseModule's ``init_cfg`` handling is bypassed
        for this model; confirm against mmengine.
        """
        for m in self.modules():
            # Linear and Conv2d layers share the same initialization scheme.
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                normal_init(m, 0.0, std=0.02)

        # Presumably marks this module as initialized for BaseModule, so the
        # framework does not initialize it again; verify against mmengine.
        self._is_init = True
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.