mmagic.models.editors.aotgan.aot_inpaintor
Module Contents
Classes
AOTInpaintor: Inpaintor for AOT-GAN method.
- class mmagic.models.editors.aotgan.aot_inpaintor.AOTInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None)
Bases: mmagic.models.base_models.OneStageInpaintor
Inpaintor for AOT-GAN method.
This inpaintor is implemented according to the paper: Aggregated Contextual Transformations for High-Resolution Image Inpainting
- forward_train_d(data_batch, is_real, is_disc, mask)
Forward function in discriminator training step.
This function computes the discriminator prediction for a data batch (real or fake). The standard GAN loss is computed together with the additional losses proposed for stable training.
- Parameters
data_batch (torch.Tensor) – Batch of real data or fake data.
is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.
is_disc (bool) – If True, this function is called in the discriminator training step. Otherwise, it is called in the generator training step. The flag is needed to compute different types of adversarial loss, such as LSGAN.
mask (torch.Tensor) – Mask of data.
- Returns
The loss items computed in this function.
- Return type
dict
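The role of the is_real and is_disc flags can be illustrated with a least-squares (LSGAN-style) loss. This is a minimal sketch in plain Python, not the library's loss implementation: the function name lsgan_loss, the 1.0/0.0 targets, and the choice to apply loss_weight only in the generator step are all assumptions made for illustration.

```python
def lsgan_loss(preds, is_real, is_disc, loss_weight=1.0):
    """Least-squares GAN loss over a flat list of discriminator scores.

    Assumptions (not taken from the library): the target is 1.0 for a
    batch treated as real and 0.0 for a batch treated as fake, and
    loss_weight is applied only in the generator step.
    """
    target = 1.0 if is_real else 0.0
    loss = sum((p - target) ** 2 for p in preds) / len(preds)
    return loss if is_disc else loss * loss_weight


# Discriminator step: real scores are pushed toward 1, fake scores toward 0.
d_real = lsgan_loss([0.5, 1.0], is_real=True, is_disc=True)
d_fake = lsgan_loss([0.5, 0.0], is_real=False, is_disc=True)

# Generator step: the fake batch is scored as if it were real.
g_loss = lsgan_loss([0.5, 0.0], is_real=True, is_disc=False, loss_weight=0.5)
```

The same batch of fake scores therefore yields different loss values depending on which network is being trained, which is why both flags are passed explicitly.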
- generator_loss(fake_res, fake_img, gt, mask, masked_img)
Forward function in generator training step.
This function computes the generator loss items from the given (fake_res, fake_img) pair. Here fake_res is the direct output of the generator, and fake_img is the composition of that output with the ground-truth image.
- Parameters
fake_res (torch.Tensor) – Direct output of the generator.
fake_img (torch.Tensor) – Composition of fake_res and ground-truth image.
gt (torch.Tensor) – Ground-truth image.
mask (torch.Tensor) – Mask image.
masked_img (torch.Tensor) – Composition of mask image and ground-truth image.
- Returns
A dict of results computed in this function for visualization, and a dict of the loss items computed in this function.
- Return type
tuple(dict)
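The relationship between fake_res, fake_img, gt, mask and masked_img can be sketched as a per-pixel blend. The helpers compose and mask_out below are hypothetical, and the mask convention (1 inside the hole, 0 in the valid region) is an assumption; the expressions use only arithmetic, so they would apply elementwise to torch.Tensor inputs as well as to the scalar pixels used here.

```python
def compose(fake_res, gt, mask):
    """Blend the generator output into the ground truth.

    Assumption: mask is 1 inside the hole and 0 in the valid region, so
    the generator output is kept only where pixels were missing.
    """
    return fake_res * mask + gt * (1 - mask)


def mask_out(gt, mask):
    """Hypothetical helper: zero the hole region of the ground truth."""
    return gt * (1 - mask)


# Single-pixel illustration: generator predicts 0.75, ground truth is 0.25.
hole_pixel = compose(0.75, 0.25, mask=1)   # pixel inside the hole
valid_pixel = compose(0.75, 0.25, mask=0)  # pixel in the valid region
masked_hole = mask_out(0.25, mask=1)       # network input at a hole pixel
```

Under this convention, losses on fake_img only differ from losses on gt inside the hole, while losses on fake_res also penalize errors in the valid region.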
- forward_tensor(inputs, data_samples)
Forward function in tensor mode.
- Parameters
inputs (torch.Tensor) – Input tensor.
data_samples (List[dict]) – List of data sample dict.
- Returns
Direct output of the generator (fake_res) and the composition of fake_res and the ground-truth image.
- Return type
tuple
- train_step(data: List[dict], optim_wrapper)
Train step function.
In this function, the inpaintor completes one training step with the following pipeline:
1. Get the fake res/image.
2. Compute the reconstruction losses for the generator.
3. Compute the adversarial loss for the discriminator.
4. Optimize the generator.
5. Optimize the discriminator.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for the generator and the discriminator (if present).
- Returns
Dict with the loss, information for the logger, the number of samples, and results for visualization.
- Return type
dict
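The five-step order documented for train_step can be sketched as a skeleton that takes each stage as a callable. Everything here is hypothetical scaffolding: the name train_step_sketch, its parameters, and the keys of the returned dict are assumptions for illustration and do not mirror the actual method body.

```python
def train_step_sketch(batch, generate, recon_loss, disc_loss, step_g, step_d):
    """Run one iteration in the documented order.

    All five callables are hypothetical stand-ins for the real model,
    loss and optimizer code.
    """
    fake_res, fake_img = generate(batch)              # 1. get fake res/image
    g_losses = recon_loss(fake_res, fake_img, batch)  # 2. generator losses
    d_losses = disc_loss(fake_img, batch)             # 3. discriminator loss
    step_g(g_losses)                                  # 4. optimize generator
    step_d(d_losses)                                  # 5. optimize discriminator
    # Assumed return layout: merged loss items plus the batch size.
    return {**g_losses, **d_losses, 'num_samples': len(batch)}


# Stub stages that only record the order in which they are called.
calls = []


def generate(batch):
    calls.append('generate')
    return 'fake_res', 'fake_img'


def recon_loss(fake_res, fake_img, batch):
    calls.append('recon_loss')
    return {'loss_l1_valid': 0.5}


def disc_loss(fake_img, batch):
    calls.append('disc_loss')
    return {'loss_disc': 0.25}


out = train_step_sketch(
    [{}, {}], generate, recon_loss, disc_loss,
    step_g=lambda losses: calls.append('step_g'),
    step_d=lambda losses: calls.append('step_d'),
)
```

Passing the stages in as callables keeps the ordering visible and testable without needing a real generator, discriminator or optimizer.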