
mmagic.models.base_models.one_stage

Module Contents

Classes

OneStageInpaintor

Standard one-stage inpaintor with commonly used losses.

Attributes

FORWARD_RETURN_TYPE

mmagic.models.base_models.one_stage.FORWARD_RETURN_TYPE[source]
class mmagic.models.base_models.one_stage.OneStageInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None)[source]

Bases: mmengine.model.BaseModel

Standard one-stage inpaintor with commonly used losses.

An inpaintor must contain an encoder-decoder style generator to inpaint masked regions. A discriminator will be adopted when adversarial training is needed.

In this class, we provide a common interface for inpaintors. Other inpaintors only need to override some functions to fit their input style or training schedule.

Parameters
  • data_preprocessor (dict or Config) – Config of data_preprocessor.

  • encdec (dict) – Config for encoder-decoder style generator.

  • disc (dict) – Config for discriminator.

  • loss_gan (dict) – Config for adversarial loss.

  • loss_gp (dict) – Config for gradient penalty loss.

  • loss_disc_shift (dict) – Config for discriminator shift loss.

  • loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.

  • loss_out_percep (bool) – Whether to also compute the perceptual and style loss with the direct output as input. Default: False.

  • loss_l1_hole (dict) – Config for the L1 loss in the hole region.

  • loss_l1_valid (dict) – Config for the L1 loss in the valid region.

  • loss_tv (dict) – Config for total variation loss.

  • train_cfg (dict) – Configs for the training scheduler. It must contain disc_step, which sets how many discriminator updates are performed for each generator update.

  • test_cfg (dict) – Configs for testing scheduler.

  • init_cfg (dict, optional) – Initialization config dict.
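To show how these arguments fit together, here is a hypothetical MMEngine-style config fragment for this class. The registry names 'DataPreprocessor', 'PerceptualLoss', 'L1Loss' and 'MaskedTVLoss' are assumed. 'MyEncoderDecoder' is a made-up placeholder. All loss weights are illustrative, not recommended values.

```python
# Hypothetical config fragment for OneStageInpaintor (MMEngine-style dict).
# Registry type names and loss weights are illustrative assumptions.
model = dict(
    type='OneStageInpaintor',
    data_preprocessor=dict(type='DataPreprocessor', mean=[127.5], std=[127.5]),
    encdec=dict(type='MyEncoderDecoder'),  # hypothetical placeholder generator
    disc=None,  # no discriminator, so no adversarial losses are used
    loss_composed_percep=dict(
        type='PerceptualLoss', perceptual_weight=0.05, style_weight=120.0),
    loss_out_percep=True,  # a bool flag, not a config dict
    loss_l1_hole=dict(type='L1Loss', loss_weight=6.0),
    loss_l1_valid=dict(type='L1Loss', loss_weight=1.0),
    loss_tv=dict(type='MaskedTVLoss', loss_weight=0.1),
    train_cfg=dict(disc_step=0),  # disc_step must be present in train_cfg
    test_cfg=dict(),
)
```

Note that loss_out_percep is passed as a boolean, matching its type in the signature, while the other loss arguments are config dicts.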

forward(inputs: torch.Tensor, data_samples: Optional[mmagic.utils.SampleList], mode: str = 'tensor') FORWARD_RETURN_TYPE[source]

Forward function.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • mode (str) –

    mode should be one of ‘loss’, ‘predict’ and ‘tensor’. Default: ‘tensor’.

    • loss: Called by train_step; returns a dict of losses used for logging.

    • predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.

    • tensor: Called by custom code to get Tensor-type results.

Returns

  • If mode == loss, return a dict of loss tensors used for backward and logging.

  • If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.

  • If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

ForwardResults
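The three-way mode convention comes from MMEngine's BaseModel. The toy class below sketches how a forward call can route on mode. ToyInpaintor and its helper methods are hypothetical placeholders that return dummy values; they are not mmagic code. The real OneStageInpaintor may treat individual modes differently; for example, training normally goes through train_step.

```python
from typing import Any, List, Optional


class ToyInpaintor:
    """Hypothetical stand-in showing how forward() dispatches on `mode`."""

    def forward(self, inputs: Any, data_samples: Optional[List[Any]] = None,
                mode: str = 'tensor') -> Any:
        if mode == 'tensor':
            return self._tensor(inputs, data_samples)
        if mode == 'predict':
            return self._predict(inputs, data_samples)
        if mode == 'loss':
            return self._loss(inputs, data_samples)
        raise ValueError(
            f"Invalid mode {mode!r}; expected 'loss', 'predict' or 'tensor'.")

    def _tensor(self, inputs: Any, data_samples: Optional[List[Any]]) -> tuple:
        # Raw outputs for custom use, standing in for (fake_res, fake_img).
        return ('fake_res', 'fake_img')

    def _predict(self, inputs: Any, data_samples: Optional[List[Any]]) -> list:
        # One result per data sample, as consumed by metrics.
        return [{'pred': 'fake_img'} for _ in (data_samples or [])]

    def _loss(self, inputs: Any, data_samples: Optional[List[Any]]) -> dict:
        # A dict of named losses for backward and logging.
        return {'loss_l1_hole': 0.0}
```

For example, ToyInpaintor().forward(None, [s1, s2], mode='predict') returns one result per sample.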

train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) dict[source]

Train step function.

In this function, the inpaintor completes the train step with the following pipeline:

  1. Compute the fake result and fake image (fake_res, fake_img).

  2. Optimize the discriminator (if present).

  3. Optimize the generator.

If self.train_cfg.disc_step > 1, the discriminator is optimized for disc_step iterations, each with different input data, and the generator is optimized only once after those disc_step iterations.

Parameters
  • data (List[dict]) – Batch of data as input.

  • optim_wrapper (OptimWrapperDict) – Dict of optimizer wrappers for the generator and the discriminator (if present).

Returns

Dict with losses, information for the logger, the number of samples and results for visualization.

Return type

dict
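The following pure-Python sketch simulates the disc_step schedule described above. It rests on an assumed model of the implementation. A counter advances on every train_step call, one call per batch, and wraps around at disc_step. The generator is updated only on the call where the counter wraps back to zero. plan_train_steps is a hypothetical name used only for illustration.

```python
from typing import List


def plan_train_steps(num_calls: int, disc_step: int,
                     with_gan: bool = True) -> List[List[str]]:
    """Simulate which networks each train_step call updates (assumed schedule)."""
    plan: List[List[str]] = []
    disc_step_count = 0
    for _ in range(num_calls):
        updated: List[str] = []
        if with_gan and disc_step > 0:
            updated.append('disc')  # the discriminator is updated on every call
            disc_step_count = (disc_step_count + 1) % disc_step
            if disc_step_count != 0:
                plan.append(updated)  # skip the generator on this call
                continue
        updated.append('generator')
        plan.append(updated)
    return plan
```

Under this sketch, with disc_step=3 the discriminator is updated on every call and the generator on every third call. With disc_step=0, or with no discriminator, only the generator is updated.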

abstract forward_train(*args, **kwargs) None[source]

Forward function for training.

This interface is not used in the current version.

forward_train_d(data_batch: torch.Tensor, is_real: bool, is_disc: bool) dict[source]

Forward function in discriminator training step.

In this function, we compute the prediction for each data batch (real or fake). Meanwhile, the standard GAN loss is computed, together with several additional losses for stable training.

Parameters
  • data_batch (torch.Tensor) – Batch of real data or fake data.

  • is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.

  • is_disc (bool) – If True, this function is called in discriminator training step. Otherwise, this function is called in generator training step. This will help us to compute different types of adversarial loss, like LSGAN.

Returns

A dict containing the loss items computed in this function.

Return type

dict
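The actual adversarial loss depends on the loss_gan config. To show why is_disc matters, the sketch below implements a hinge GAN loss on plain lists of discriminator scores. Hinge loss is chosen only as an example of a formulation whose discriminator and generator terms differ. The discriminator-shift term, written as loss_weight * mean(x ** 2), is an assumed form of loss_disc_shift. Both function names are hypothetical.

```python
from typing import Sequence


def hinge_gan_loss(pred: Sequence[float], is_real: bool, is_disc: bool) -> float:
    """Hinge GAN loss on a batch of raw discriminator scores (illustrative)."""
    n = len(pred)
    if is_disc:
        # Discriminator step: mean(relu(1 - x)) for real, mean(relu(1 + x)) for fake.
        sign = -1.0 if is_real else 1.0
        return sum(max(0.0, 1.0 + sign * x) for x in pred) / n
    # Generator step: is_real is ignored and the loss is -mean(x).
    return -sum(pred) / n


def disc_shift_loss(pred: Sequence[float], loss_weight: float = 0.1) -> float:
    """Penalize large discriminator outputs: loss_weight * mean(x ** 2) (assumed form)."""
    return loss_weight * sum(x * x for x in pred) / len(pred)
```

For scores [0.5, -2.0], the discriminator loss is 1.75 when the batch is treated as real and 0.75 when it is treated as fake. The generator loss is 0.75 regardless of is_real.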

generator_loss(fake_res: torch.Tensor, fake_img: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, masked_img: torch.Tensor) Tuple[dict, dict][source]

Forward function in generator training step.

In this function, we mainly compute the loss items for the generator from the given (fake_res, fake_img). In general, fake_res is the direct output of the generator and fake_img is the composition of the direct output and the ground-truth image.

Parameters
  • fake_res (torch.Tensor) – Direct output of the generator.

  • fake_img (torch.Tensor) – Composition of fake_res and ground-truth image.

  • gt (torch.Tensor) – Ground-truth image.

  • mask (torch.Tensor) – Mask image.

  • masked_img (torch.Tensor) – Composition of mask image and ground-truth image.

Returns

A dict containing the results computed within this function for visualization, and a dict containing the loss items computed in this function.

Return type

Tuple[dict, dict]
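The hole and valid L1 terms can be understood as L1 losses averaged only over the pixels selected by a mask. The sketch below works on flattened pixel lists and assumes mask is 1 inside the hole and 0 in the valid region. It normalizes by the sum of the weights. The direct output fake_res has to be used for the valid term, because the composed fake_img already equals the ground truth outside the hole. The helper names are hypothetical.

```python
from typing import Sequence, Tuple


def masked_l1(pred: Sequence[float], target: Sequence[float],
              weight: Sequence[float], eps: float = 1e-12) -> float:
    """L1 loss averaged over the pixels selected by `weight` (0/1 per pixel)."""
    total = sum(abs(p - t) * w for p, t, w in zip(pred, target, weight))
    return total / (sum(weight) + eps)  # eps avoids division by zero for empty masks


def l1_hole_and_valid(fake_res: Sequence[float], gt: Sequence[float],
                      mask: Sequence[float]) -> Tuple[float, float]:
    """Return (hole L1, valid L1); assumes mask == 1 inside the hole."""
    valid = [1.0 - m for m in mask]
    return masked_l1(fake_res, gt, mask), masked_l1(fake_res, gt, valid)
```

For gt = [1, 2, 3, 4], fake_res = [1.5, 2, 2, 4] and mask = [1, 1, 0, 0], the hole loss is 0.25 and the valid loss is 0.5. In practice each term would then be scaled by the loss_weight from its config.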

forward_tensor(inputs: torch.Tensor, data_samples: mmagic.utils.SampleList) Tuple[torch.Tensor, torch.Tensor][source]

Forward function in tensor mode.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

Direct output of the generator (fake_res) and the composition of fake_res and the ground-truth image (fake_img).

Return type

Tuple[torch.Tensor, torch.Tensor]
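Here is a minimal sketch of the tensor-mode computation, using nested lists (channels × flattened pixels) instead of tensors. It assumes the generator takes the masked image concatenated with the mask as an extra channel, and that mask is 1 in the hole. Outside the hole, masked_img agrees with the ground truth, so it stands in for gt in the composition. inpaint_forward and the dummy generator in the usage example are hypothetical.

```python
from typing import Callable, List, Tuple

Image = List[List[float]]  # channels x flattened pixels


def inpaint_forward(masked_img: Image, mask: List[float],
                    generator: Callable[[Image], Image]) -> Tuple[Image, Image]:
    """Sketch of tensor-mode forward: returns (fake_res, fake_img)."""
    # Generator input: masked image channels plus the mask as an extra channel.
    gen_input = [list(ch) for ch in masked_img] + [list(mask)]
    fake_res = generator(gen_input)
    # Keep generated pixels inside the hole (mask == 1), known pixels elsewhere.
    fake_img = [
        [r * m + x * (1.0 - m) for r, x, m in zip(res_ch, img_ch, mask)]
        for res_ch, img_ch in zip(fake_res, masked_img)
    ]
    return fake_res, fake_img
```

With a dummy generator that outputs 9.0 everywhere and mask [1.0, 0.0], only the first pixel of each channel in fake_img is replaced by 9.0.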

forward_test(inputs: torch.Tensor, data_samples: mmagic.utils.SampleList) mmagic.structures.DataSample[source]

Forward function for testing.

Parameters
  • inputs (torch.Tensor) – Input tensor.

  • data_samples (List[dict]) – List of data sample dict.

Returns

List of predictions saved in DataSample.

Return type

List[DataSample]

convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) List[mmagic.structures.DataSample][source]

Add predictions and destructed inputs (if passed) to data samples.

Parameters
  • predictions (DataSample) – The predictions of the model.

  • data_samples (DataSample) – The data samples loaded from dataloader.

  • inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.

Returns

Modified data samples.

Return type

List[DataSample]
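"Destructed" inputs are inputs mapped back from the normalized model space to pixel values. Assume the data preprocessor normalizes as (x - mean) / std, for example with mean = std = 127.5. Destructing then inverts that mapping, as sketched below. Any color-channel reordering the real preprocessor may perform is not modeled, and both helper names are hypothetical.

```python
from typing import List, Sequence


def normalize(values: Sequence[float], mean: float, std: float) -> List[float]:
    """Map pixel values into model space: (x - mean) / std."""
    return [(v - mean) / std for v in values]


def destruct(values: Sequence[float], mean: float, std: float) -> List[float]:
    """Invert the normalization: x * std + mean."""
    return [v * std + mean for v in values]
```

With mean = std = 127.5, the pixel range [0, 255] maps to [-1, 1] and destruct maps it back.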

forward_dummy(x: torch.Tensor) torch.Tensor[source]

Forward dummy function for getting flops.

Parameters

x (torch.Tensor) – Input tensor with shape of (n, c, h, w).

Returns

Result tensor with shape of (n, 3, h, w).

Return type

torch.Tensor
