mmagic.models.base_models.two_stage
Module Contents
Classes
- class mmagic.models.base_models.two_stage.TwoStageInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, stage1_loss_type: Optional[Sequence[str]] = ('loss_l1_hole',), stage2_loss_type: Optional[Sequence[str]] = ('loss_l1_hole', 'loss_gan'), input_with_ones: bool = True, disc_input_with_mask: bool = False)
Bases:
mmagic.models.base_models.one_stage.OneStageInpaintor
Standard two-stage inpaintor with commonly used losses. A two-stage inpaintor contains two encoder-decoder style generators that inpaint the masked regions. Each of the two stages currently supports the following loss types:
['loss_gan', 'loss_l1_hole', 'loss_l1_valid', 'loss_composed_percep', 'loss_out_percep', 'loss_tv']. The values of stage1_loss_type and stage2_loss_type must be chosen from these loss types.
- Parameters
data_preprocessor (dict) – Config of data_preprocessor.
encdec (dict) – Config for encoder-decoder style generator.
disc (dict) – Config for discriminator.
loss_gan (dict) – Config for adversarial loss.
loss_gp (dict) – Config for gradient penalty loss.
loss_disc_shift (dict) – Config for discriminator shift loss.
loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.
loss_out_percep (bool) – Whether to compute perceptual and style loss with direct output as input. Default: False.
loss_l1_hole (dict) – Config for l1 loss in the hole.
loss_l1_valid (dict) – Config for l1 loss in the valid region.
loss_tv (dict) – Config for total variation loss.
train_cfg (dict) – Configs for training scheduler. It must contain disc_step, which indicates the number of discriminator updating steps in each training step.
test_cfg (dict) – Configs for testing scheduler.
init_cfg (dict, optional) – Initialization config dict.
stage1_loss_type (tuple[str]) – Contains the loss names used in the first stage model. Default: ('loss_l1_hole',).
stage2_loss_type (tuple[str]) – Contains the loss names used in the second stage model. Default: ('loss_l1_hole', 'loss_gan').
input_with_ones (bool) – Whether to concatenate an extra ones tensor in input. Default: True.
disc_input_with_mask (bool) – Whether to add mask as input in discriminator. Default: False.
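The parameters above are usually supplied through a config dict. The following is a minimal sketch under stated assumptions, not a configuration taken from the MMagic repository: every nested `type` value except `'TwoStageInpaintor'` is a hypothetical placeholder, and the helper `check_stage_loss_types` is invented here purely to illustrate that `stage1_loss_type` and `stage2_loss_type` should only name supported losses. The extra check that a selected loss is also configured is an assumption of this sketch.

```python
# Loss names the class documentation lists as supported in each stage.
SUPPORTED_LOSS_TYPES = (
    'loss_gan', 'loss_l1_hole', 'loss_l1_valid',
    'loss_composed_percep', 'loss_out_percep', 'loss_tv',
)


def check_stage_loss_types(cfg):
    """Hypothetical helper (not part of MMagic): verify that each stage
    only names supported losses that are also configured in cfg."""
    for key in ('stage1_loss_type', 'stage2_loss_type'):
        for name in cfg[key]:
            if name not in SUPPORTED_LOSS_TYPES:
                raise ValueError(f'{key}: unsupported loss {name!r}')
            if not cfg.get(name):
                raise ValueError(f'{key}: {name!r} selected but not configured')
    return True


model = dict(
    type='TwoStageInpaintor',
    data_preprocessor=dict(type='MyDataPreprocessor'),  # placeholder type
    encdec=dict(type='MyTwoStageEncoderDecoder'),       # placeholder type
    disc=dict(type='MyDiscriminator'),                  # placeholder type
    loss_gan=dict(type='MyGANLoss'),                    # placeholder type
    loss_l1_hole=dict(type='MyL1Loss'),                 # placeholder type
    loss_l1_valid=dict(type='MyL1Loss'),                # placeholder type
    train_cfg=dict(disc_step=1),
    stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
    stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
    input_with_ones=True,
    disc_input_with_mask=False,
)

check_stage_loss_types(model)
```

Here the first stage is trained with the two L1 terms only, while the second stage adds the adversarial term, which mirrors the pattern of the documented defaults.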
- forward_tensor(inputs: torch.Tensor, data_samples: mmagic.utils.SampleList) -> Tuple[torch.Tensor, torch.Tensor]
Forward function in tensor mode.
- Parameters
inputs (torch.Tensor) – Input tensor.
data_samples (mmagic.utils.SampleList) – List of data samples.
- Returns
Tuple of two output tensors.
- Return type
Tuple[torch.Tensor, torch.Tensor]
- two_stage_loss(stage1_data: dict, stage2_data: dict, gt: torch.Tensor, mask: torch.Tensor, masked_img: torch.Tensor) -> Tuple[dict, dict]
Calculate two-stage loss.
- Parameters
stage1_data (dict) – Contains the stage 1 results.
stage2_data (dict) – Contains the stage 2 results.
gt (torch.Tensor) – Ground-truth image.
mask (torch.Tensor) – Mask image.
masked_img (torch.Tensor) – Composition of mask image and ground-truth image.
- Returns
A dict of the results computed within this function for visualization, and a dict of the loss items computed in this function.
- Return type
tuple(dict)
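As a rough illustration of how the loss items of both stages might be gathered into one dict, consider the sketch below. It is not the library's implementation: `collect_two_stage_losses`, `loss_fn`, and `dummy_loss_fn` are hypothetical names, plain floats stand in for tensors, and the key scheme (stage prefix followed by the loss name) is an assumption based on the `prefix` argument of `calculate_loss_with_type`.

```python
def collect_two_stage_losses(stage1_loss_type, stage2_loss_type, loss_fn):
    """Hypothetical sketch: gather the loss items of both stages.

    loss_fn(loss_type, prefix) stands in for a per-loss computation and
    must return a dict mapping a prefixed loss name to its value.
    """
    loss = {}
    for prefix, loss_types in (('stage1_', stage1_loss_type),
                               ('stage2_', stage2_loss_type)):
        # A stage with no configured losses (None) contributes nothing.
        for loss_type in loss_types or ():
            loss.update(loss_fn(loss_type, prefix))
    return loss


def dummy_loss_fn(loss_type, prefix):
    # Placeholder: a real implementation would compute the loss from tensors.
    return {prefix + loss_type: 0.0}


losses = collect_two_stage_losses(
    ('loss_l1_hole',), ('loss_l1_hole', 'loss_gan'), dummy_loss_fn)
print(sorted(losses))
```

Because the prefix is part of each key, the same loss type can be used in both stages without one entry overwriting the other.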
- calculate_loss_with_type(loss_type: str, fake_res: torch.Tensor, fake_img: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, prefix: Optional[str] = 'stage1_') -> dict
Calculate multiple types of losses.
- Parameters
loss_type (str) – Type of the loss.
fake_res (torch.Tensor) – Direct results from model.
fake_img (torch.Tensor) – Composited results from model.
gt (torch.Tensor) – Ground-truth tensor.
mask (torch.Tensor) – Mask tensor.
prefix (str, optional) – Prefix for the loss name. Defaults to 'stage1_'.
- Returns
Dict containing the loss value keyed by its name.
- Return type
dict
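To make the idea of a per-type loss with a prefixed name concrete, here is a simplified sketch that covers only the two L1 variants. It rests on several assumptions that the documentation above does not confirm: that `mask == 1` marks the hole and `mask == 0` the valid region, that the loss is a mask-weighted mean absolute error (the library's normalization may differ), and that the returned key is the prefix followed by the loss type. `masked_l1` and `loss_with_type` are hypothetical names, and flat lists of floats stand in for tensors.

```python
def masked_l1(pred, target, weight):
    """Weighted mean absolute error over flat lists of floats.

    Simplified on purpose; the normalization of a real L1 loss may differ.
    """
    total = sum(abs(p - t) * w for p, t, w in zip(pred, target, weight))
    return total / len(pred)


def loss_with_type(loss_type, fake_res, gt, mask, prefix='stage1_'):
    """Hypothetical stand-in that handles only the two L1 loss types.

    Assumes mask == 1 marks the hole and mask == 0 the valid region.
    """
    if loss_type == 'loss_l1_hole':
        weight = mask
    elif loss_type == 'loss_l1_valid':
        weight = [1.0 - m for m in mask]
    else:
        raise NotImplementedError(loss_type)
    return {prefix + loss_type: masked_l1(fake_res, gt, weight)}


fake_res = [0.0, 0.5, 1.0, 1.0]
gt = [0.0, 1.0, 0.0, 1.0]
mask = [0.0, 1.0, 1.0, 0.0]  # the two middle pixels form the hole
hole = loss_with_type('loss_l1_hole', fake_res, gt, mask)
valid = loss_with_type('loss_l1_valid', fake_res, gt, mask, prefix='stage2_')
```

In this toy input the prediction matches the ground truth everywhere outside the hole, so only the hole term is nonzero.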
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) -> dict
Train step function.
In this function, the inpaintor completes a train step in the following order:
1. Get the fake result and the fake image.
2. Optimize the discriminator (if present).
3. Optimize the generator.
If self.train_cfg.disc_step > 1, a full cycle consists of disc_step discriminator iterations, each on different input data, followed by a single generator iteration.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (mmengine.optim.OptimWrapperDict) – Dict with optimizer wrappers for the generator and the discriminator (if present).
- Returns
Dict with loss, information for logger, the number of samples and results for visualization.
- Return type
dict
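The interaction between `disc_step` and the generator update can be sketched as a small simulation. This is an illustrative model only, assuming that each call to the train step performs one discriminator iteration and that a counter decides when the generator is updated; `simulate_schedule` and `disc_step_count` are hypothetical names, and no real optimizers are involved.

```python
def simulate_schedule(num_train_steps, disc_step):
    """Return the update order, 'D' for discriminator and 'G' for generator.

    Hypothetical sketch: a counter tracks discriminator iterations; the
    generator is updated only when the counter wraps around to zero,
    i.e. after every disc_step discriminator iterations.
    """
    updates = []
    disc_step_count = 0
    for _ in range(num_train_steps):
        updates.append('D')  # optimize the discriminator
        disc_step_count = (disc_step_count + 1) % disc_step
        if disc_step_count != 0:
            continue  # not enough discriminator iterations yet
        updates.append('G')  # optimize the generator
    return updates


print(''.join(simulate_schedule(4, disc_step=2)))  # DDGDDG
```

With `disc_step=1` the two networks simply alternate, so the generator is updated in every train step.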