mmagic.models.editors.deepfillv1.deepfillv1
Module Contents
Classes
DeepFillv1Inpaintor | Inpaintor for deepfillv1 method.
- class mmagic.models.editors.deepfillv1.deepfillv1.DeepFillv1Inpaintor(data_preprocessor: dict, encdec: dict, disc=None, loss_gan=None, loss_gp=None, loss_disc_shift=None, loss_composed_percep=None, loss_out_percep=False, loss_l1_hole=None, loss_l1_valid=None, loss_tv=None, stage1_loss_type=None, stage2_loss_type=None, train_cfg=None, test_cfg=None, init_cfg: Optional[dict] = None)
Bases:
mmagic.models.base_models.TwoStageInpaintor
Inpaintor for deepfillv1 method.
This inpaintor is implemented according to the paper: Generative image inpainting with contextual attention.
Importantly, this inpaintor is an example of using a custom training schedule based on TwoStageInpaintor.
The training pipeline of deepfillv1 is as follows:
    if cur_iter < iter_tc:
        update generator with only l1 loss
    else:
        update discriminator
        if cur_iter > iter_td:
            update generator with l1 loss and adversarial loss
The new attribute cur_iter records the current iteration number. The train_cfg contains the settings of the training schedule:
    train_cfg = dict(
        start_iter=0,
        disc_step=1,
        iter_tc=90000,
        iter_td=100000
    )
iter_tc and iter_td correspond to the notation \(T_C\) and \(T_D\) of the original paper.
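The schedule above can be sketched in plain Python to make the boundary behaviour concrete. This is only an illustration of the branching logic as documented here, not the library's code; the function name `scheduled_updates` and the string labels are hypothetical.

```python
def scheduled_updates(cur_iter, iter_tc, iter_td):
    """Return which updates the documented schedule performs at cur_iter.

    Hypothetical helper: it mirrors the pseudocode only and touches no model.
    """
    if cur_iter < iter_tc:
        # Warm-up phase: the generator is trained with l1 loss only.
        return ["generator: l1"]
    # From iter_tc onward the discriminator is always updated.
    updates = ["discriminator"]
    if cur_iter > iter_td:
        # Strictly after iter_td the generator also receives adversarial loss.
        updates.append("generator: l1 + adversarial")
    return updates


train_cfg = dict(start_iter=0, disc_step=1, iter_tc=90000, iter_td=100000)
for it in (0, 95000, 100001):
    print(it, scheduled_updates(it, train_cfg["iter_tc"], train_cfg["iter_td"]))
```

Note that, read literally, the pseudocode updates only the discriminator for iterations from iter_tc through iter_td inclusive, because the second comparison is strict.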
- Parameters
data_preprocessor (dict) – Config for the data preprocessor.
encdec (dict) – Config for encoder-decoder style generator.
disc (dict) – Config for discriminator.
loss_gan (dict) – Config for adversarial loss.
loss_gp (dict) – Config for gradient penalty loss.
loss_disc_shift (dict) – Config for discriminator shift loss.
loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.
loss_out_percep (bool) – Whether to compute perceptual and style loss with direct output as input. Defaults to False.
loss_l1_hole (dict) – Config for l1 loss in the hole.
loss_l1_valid (dict) – Config for l1 loss in the valid region.
loss_tv (dict) – Config for total variation loss.
train_cfg (dict) – Configs for training scheduler. It must contain disc_step, which indicates the number of discriminator updates in each training step.
test_cfg (dict) – Configs for testing scheduler.
init_cfg (dict, optional) – Initialization config dict.
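A minimal config sketch shows how these parameters might fit together. The keyword names come from the signature above and the train_cfg values from the example in the class description; every `type` string, the loss weights, and the form of the `stage1_loss_type` / `stage2_loss_type` values are assumptions for illustration and should be checked against the actual registry and shipped configs.

```python
# Training schedule taken from the class description.
train_cfg = dict(start_iter=0, disc_step=1, iter_tc=90000, iter_td=100000)

# Hypothetical model config: all `type` strings and weights are assumed.
model = dict(
    type='DeepFillv1Inpaintor',
    data_preprocessor=dict(type='DataPreprocessor'),  # assumed name
    encdec=dict(type='DeepFillEncoderDecoder'),       # assumed name
    disc=dict(type='DeepFillv1Discriminators'),       # assumed name
    loss_gan=dict(type='GANLoss', loss_weight=0.0001),    # assumed
    loss_l1_hole=dict(type='L1Loss', loss_weight=1.0),    # assumed
    loss_l1_valid=dict(type='L1Loss', loss_weight=1.0),   # assumed
    # Assumed convention: each stage lists the losses it applies by name.
    stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
    stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
    train_cfg=train_cfg,
)

# Sanity check: the warm-up threshold should not exceed the adversarial one.
assert model['train_cfg']['iter_tc'] <= model['train_cfg']['iter_td']
```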
- forward_train_d(data_batch, is_real, is_disc)
Forward function in discriminator training step.
This function overrides the default implementation, which assumes only one discriminator. The DeepFillv1 model uses two separate discriminators, one for global and one for local consistency.
- Parameters
data_batch (torch.Tensor) – Batch of real data or fake data.
is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.
is_disc (bool) – If True, this function is called in discriminator training step. Otherwise, this function is called in generator training step. This will help us to compute different types of adversarial loss, like LSGAN.
- Returns
Contains the loss items computed in this function.
- Return type
dict
- two_stage_loss(stage1_data, stage2_data, gt, mask, masked_img)
Calculate two-stage loss.
- Parameters
stage1_data (dict) – Contains stage1 results.
stage2_data (dict) – Contains stage2 results.
gt (torch.Tensor) – Ground-truth image.
mask (torch.Tensor) – Mask image.
masked_img (torch.Tensor) – Composition of mask image and ground-truth image.
- Returns
A dict containing the results computed within this function for visualization, and a dict containing the loss items computed in this function.
- Return type
tuple(dict)
- calculate_loss_with_type(loss_type, fake_res, fake_img, gt, mask, prefix='stage1_', fake_local=None)
Calculate multiple types of losses.
- Parameters
loss_type (str) – Type of the loss.
fake_res (torch.Tensor) – Direct results from model.
fake_img (torch.Tensor) – Composited results from model.
gt (torch.Tensor) – Ground-truth tensor.
mask (torch.Tensor) – Mask tensor.
prefix (str, optional) – Prefix for loss name. Defaults to 'stage1_'.
fake_local (torch.Tensor, optional) – Local results from model. Defaults to None.
- Returns
Contains each loss value with its name.
- Return type
dict
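The difference between fake_res (direct results) and fake_img (composited results) can be illustrated with a small elementwise sketch. It assumes the common inpainting convention that mask is 1 inside the hole and 0 in the valid region; that convention and the helper name `composite` are assumptions, not taken from this page. Plain Python lists stand in for tensors, and the same arithmetic applies elementwise to torch.Tensor inputs.

```python
def composite(fake_res, gt, mask):
    """Blend per element: fake_res where mask == 1, gt where mask == 0.

    Hypothetical helper; assumes mask marks the hole with 1.
    """
    return [f * m + g * (1.0 - m) for f, g, m in zip(fake_res, gt, mask)]


fake_res = [0.9, 0.8, 0.7, 0.6]   # direct generator output
gt = [0.1, 0.2, 0.3, 0.4]         # ground-truth pixels
mask = [0.0, 1.0, 1.0, 0.0]       # 1 = hole, 0 = valid (assumed)

fake_img = composite(fake_res, gt, mask)
print(fake_img)
```

Under this assumption the composited image keeps ground-truth pixels in the valid region, so only the hole region differs between fake_img and gt.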
- train_step(data: List[dict], optim_wrapper)
Train step function.
In this function, the inpaintor completes the train step with the following pipeline:
1. Get the fake result/image.
2. Optimize the discriminator (if present).
3. Optimize the generator.
If self.train_cfg.disc_step > 1, the discriminator is optimized for disc_step iterations, each with different input data, and the generator is then optimized for a single iteration.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for the generator and the discriminator (if present).
- Returns
Dict with loss, information for logger, the number of samples and results for visualization.
- Return type
dict
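The effect of disc_step on train_step can be sketched with a simple counter. This is a simplified illustration of the counting behaviour described above, assuming one discriminator update per call; the class `StepCounter` and its attribute names are hypothetical and it ignores the iter_tc / iter_td schedule.

```python
class StepCounter:
    """Hypothetical stand-in that tracks when the generator is updated."""

    def __init__(self, disc_step):
        self.disc_step = disc_step
        self.disc_step_count = 0

    def step(self):
        # Each call updates the discriminator once on a new batch.
        updates = ["discriminator"]
        self.disc_step_count = (self.disc_step_count + 1) % self.disc_step
        if self.disc_step_count == 0:
            # After disc_step discriminator updates, update the generator once.
            updates.append("generator")
        return updates


counter = StepCounter(disc_step=2)
for _ in range(4):
    print(counter.step())
```

With disc_step=1 every call updates both networks; with larger values the generator is updated only on every disc_step-th call.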