
mmagic.models.base_models.base_mattor

Module Contents

Classes

BaseMattor

Base class for trimap-based matting models.

Functions

_pad(→ Tuple[torch.Tensor, Tuple[int, int]])

Pad image to a multiple of a given down-sampling factor.

_interpolate(→ Tuple[torch.Tensor, Tuple[int, int]])

Resize image to a multiple of a given down-sampling factor.

Attributes

DataSamples

ForwardResults

mmagic.models.base_models.base_mattor.DataSamples
mmagic.models.base_models.base_mattor.ForwardResults
mmagic.models.base_models.base_mattor._pad(batch_image: torch.Tensor, ds_factor: int, mode: str = 'reflect') → Tuple[torch.Tensor, Tuple[int, int]]

Pad image to a multiple of a given down-sampling factor.
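To show what padding to a multiple involves, the following minimal sketch computes the padded size and applies reflect padding to a single 2-D plane stored as nested lists. It assumes that padding is added only at the bottom and right and that the returned tuple is the padding amount (pad_h, pad_w). Neither point is confirmed by this page. `pad_to_multiple` is a hypothetical helper, not part of mmagic.

```python
from typing import List, Tuple

Plane = List[List[float]]


def pad_to_multiple(plane: Plane, ds_factor: int) -> Tuple[Plane, Tuple[int, int]]:
    """Reflect-pad a 2-D plane so that H and W become multiples of ds_factor.

    Hypothetical helper: assumes padding goes only on the bottom and right edges.
    """
    h, w = len(plane), len(plane[0])
    # Round each side up to the next multiple of ds_factor.
    new_h = ds_factor * ((h - 1) // ds_factor + 1)
    new_w = ds_factor * ((w - 1) // ds_factor + 1)
    pad_h, pad_w = new_h - h, new_w - w
    # Reflect padding mirrors around the edge pixel without repeating it,
    # so each pad amount must be smaller than the corresponding dimension.
    if pad_h >= h or pad_w >= w:
        raise ValueError("reflect padding must be smaller than the input size")
    # Mirror columns on the right, then mirror whole rows at the bottom.
    rows = [row + [row[w - 2 - i] for i in range(pad_w)] for row in plane]
    rows += [list(rows[h - 2 - i]) for i in range(pad_h)]
    return rows, (pad_h, pad_w)


padded, pad = pad_to_multiple([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 4)
# padded == [[1, 2, 3, 2], [4, 5, 6, 5], [7, 8, 9, 8], [4, 5, 6, 5]], pad == (1, 1)
```

The actual function operates on a batched torch.Tensor, where the same mirroring is what torch.nn.functional.pad does with mode='reflect'.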

mmagic.models.base_models.base_mattor._interpolate(batch_image: torch.Tensor, ds_factor: int, mode: str = 'bicubic') → Tuple[torch.Tensor, Tuple[int, int]]

Resize image to a multiple of a given down-sampling factor.
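Interpolation reaches a compatible size by resampling the image rather than adding border pixels. The sketch below assumes that each side is rounded up to the next multiple of ds_factor, although the library may round differently. It also substitutes nearest-neighbour resampling for the bicubic default, because nearest-neighbour fits in a few lines of plain Python. `resize_to_multiple` is hypothetical, and so is the choice to return the original (H, W).

```python
from typing import List, Tuple

Plane = List[List[float]]


def resize_to_multiple(plane: Plane, ds_factor: int) -> Tuple[Plane, Tuple[int, int]]:
    """Nearest-neighbour resize of a 2-D plane up to the next multiple of ds_factor.

    Hypothetical helper: returns the resized plane plus the original (H, W),
    an assumed convention so a caller could resize back later.
    """
    h, w = len(plane), len(plane[0])
    new_h = ds_factor * ((h - 1) // ds_factor + 1)
    new_w = ds_factor * ((w - 1) // ds_factor + 1)
    # Map each output pixel back to the source pixel whose cell contains it.
    resized = [
        [plane[i * h // new_h][j * w // new_w] for j in range(new_w)]
        for i in range(new_h)
    ]
    return resized, (h, w)


resized, ori_hw = resize_to_multiple([[1, 2], [3, 4]], 4)
# resized == [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], ori_hw == (2, 2)
```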

class mmagic.models.base_models.base_mattor.BaseMattor(data_preprocessor: Union[dict, mmengine.config.Config], backbone: dict, init_cfg: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None)

Bases: mmengine.model.BaseModel

Base class for trimap-based matting models.

A matting model must contain a backbone that produces pred_alpha, a dense prediction with the same height and width as the input image. In some cases (such as DIM), the model also has a refiner that refines the backbone's prediction.

Subclasses should override the following methods:

  • _forward_train(), to return a loss

  • _forward_test(), to return a prediction

  • _forward(), to return raw tensors
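The following toy sketch illustrates the override contract in plain Python. `ToyMattorBase` and `ToyMattor` are hypothetical stand-ins, not mmagic classes. The sketch uses abc so that Python itself enforces the requirement. This page does not say whether BaseMattor enforces it the same way.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

Sample = Dict[str, Any]


class ToyMattorBase(ABC):
    """Hypothetical stand-in for BaseMattor: subclasses must fill in all three hooks."""

    @abstractmethod
    def _forward_train(self, inputs: List[float], data_samples: List[Sample]) -> Dict[str, float]:
        """Return a dict of losses."""

    @abstractmethod
    def _forward_test(self, inputs: List[float], data_samples: List[Sample]) -> List[float]:
        """Return predictions."""

    @abstractmethod
    def _forward(self, inputs: List[float], data_samples: Optional[List[Sample]] = None) -> List[float]:
        """Return raw outputs."""


class ToyMattor(ToyMattorBase):
    def _forward(self, inputs, data_samples=None):
        # Toy "backbone": predict alpha as half of each input value.
        return [0.5 * x for x in inputs]

    def _forward_train(self, inputs, data_samples):
        pred = self._forward(inputs)
        gt = [s["gt_alpha"] for s in data_samples]
        # Mean absolute error as a stand-in for a real alpha loss.
        return {"loss_alpha": sum(abs(p - g) for p, g in zip(pred, gt)) / len(gt)}

    def _forward_test(self, inputs, data_samples):
        return self._forward(inputs)
```

Instantiating ToyMattorBase directly raises TypeError. ToyMattor works because it overrides all three hooks.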

At test time, this base class provides methods to resize inputs and to post-process pred_alpha into final predictions.

Parameters
  • backbone (dict) – Config of backbone.

  • data_preprocessor (dict or Config) – Config of the data preprocessor. See MattorPreprocessor for details.

  • init_cfg (dict, optional) – Initialization config dict.

  • train_cfg (dict, optional) – Config of training, customized by subclasses. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should also be specified.

  • test_cfg (dict, optional) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.
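The train_cfg rules above can be checked mechanically. The minimal sketch below uses plain dicts and encodes only the two rules stated on this page. `missing_train_cfg_keys` and its `has_refiner` flag are hypothetical and not part of mmagic.

```python
from typing import List, Optional


def missing_train_cfg_keys(train_cfg: Optional[dict], has_refiner: bool) -> List[str]:
    """Return the documented train_cfg keys that are absent (hypothetical helper)."""
    cfg = train_cfg or {}
    required = ["train_backbone"]
    if has_refiner:
        # Models with a refiner (DIM-style) must also say whether to train it.
        required.append("train_refiner")
    return [key for key in required if key not in cfg]


# A refiner-equipped model that trains only the backbone passes the check.
train_cfg = dict(train_backbone=True, train_refiner=False)
missing = missing_train_cfg_keys(train_cfg, has_refiner=True)  # []
```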

resize_inputs(batch_inputs: torch.Tensor) → torch.Tensor

Pad or interpolate images and trimaps to a multiple of the given factor.
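The key invariant is that the image and the trimap end up with the same spatial size, so they can be concatenated again along the channel dimension. The shape-level sketch below makes three assumptions. It assumes a (N, C, H, W) layout, treats the first 3 channels as the RGB image and the remaining channels as the trimap, and rounds each side up to the next multiple of the divisor. `resize_inputs_shape` is hypothetical.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]  # (N, C, H, W)


def round_up(x: int, factor: int) -> int:
    """Smallest multiple of factor that is >= x."""
    return factor * ((x + factor - 1) // factor)


def resize_inputs_shape(shape: Shape, size_divisor: int, image_channels: int = 3) -> Tuple[Shape, Shape, Shape]:
    """Return (image, trimap, merged) shapes after resizing (hypothetical helper)."""
    n, c, h, w = shape
    new_h, new_w = round_up(h, size_divisor), round_up(w, size_divisor)
    image_shape = (n, image_channels, new_h, new_w)
    trimap_shape = (n, c - image_channels, new_h, new_w)
    # Identical spatial sizes let the two parts be concatenated back on dim 1.
    merged_shape = (n, c, new_h, new_w)
    return image_shape, trimap_shape, merged_shape


shapes = resize_inputs_shape((1, 4, 100, 150), 32)
# ((1, 3, 128, 160), (1, 1, 128, 160), (1, 4, 128, 160))
```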

restore_size(pred_alpha: torch.Tensor, data_sample: mmagic.structures.DataSample) → torch.Tensor

Restore the predicted alpha to the original shape.

The shape of the predicted alpha may not be the same as the shape of original input image. This function restores the shape of the predicted alpha.

Parameters
  • pred_alpha (torch.Tensor) – A single predicted alpha of shape (1, H, W).

  • data_sample (DataSample) – Data sample containing the original shape as metadata.

Returns

The reshaped predicted alpha.

Return type

torch.Tensor
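If the input was padded, restoring the size reduces to cropping. The minimal sketch below assumes that padding was added at the bottom and right, which places the original pixels in the top-left corner. Interpolated inputs would instead need to be resized back. `restore_padded_alpha` and its `ori_h`/`ori_w` arguments are hypothetical. The real method reads the original shape from the data sample's metadata.

```python
from typing import List

Plane = List[List[float]]


def restore_padded_alpha(pred_alpha: Plane, ori_h: int, ori_w: int) -> Plane:
    """Crop a padded (H', W') alpha plane back to (ori_h, ori_w).

    Hypothetical helper: assumes padding was added at the bottom/right only.
    """
    if len(pred_alpha) < ori_h or len(pred_alpha[0]) < ori_w:
        raise ValueError("prediction is smaller than the original size")
    # Keep the top-left ori_h x ori_w block, dropping the padded border.
    return [row[:ori_w] for row in pred_alpha[:ori_h]]


padded = [[0.1, 0.2, 0.3, 0.0], [0.4, 0.5, 0.6, 0.0], [0.7, 0.8, 0.9, 0.0], [0.0, 0.0, 0.0, 0.0]]
restored = restore_padded_alpha(padded, 3, 3)
# [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
```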

postprocess(batch_pred_alpha: torch.Tensor, data_samples: mmagic.structures.DataSample) → List[mmagic.structures.DataSample]

Post-process alpha predictions.

This function contains the following steps:
  1. Restore padding or interpolation

  2. Mask alpha prediction with trimap

  3. Clamp alpha prediction to 0-1

  4. Convert alpha prediction to uint8

  5. Pack alpha prediction into DataSample

Currently, only a batch size of 1 is actually supported.

Parameters
  • batch_pred_alpha (torch.Tensor) – A batch of predicted alphas of shape (N, 1, H, W).

  • data_samples (List[DataSample]) – List of data samples.

Returns

A list of predictions.

Each data sample contains a pred_alpha, which is a torch.Tensor with dtype=uint8 on device cuda:0.

Return type

List[DataSample]
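Steps 2 to 4 of the post-processing list can be sketched per pixel in plain Python. The sketch assumes a trimap encoding of 0 for known background, 255 for known foreground, and any other value (commonly 128) for unknown. It also assumes that known regions overwrite the prediction. This page confirms neither detail. `postprocess_alpha` is a hypothetical helper.

```python
from typing import List

Plane = List[List[float]]


def postprocess_alpha(pred_alpha: Plane, trimap: List[List[int]]) -> List[List[int]]:
    """Mask with the trimap, clamp to [0, 1], and convert to 0-255 integers (hypothetical)."""
    out = []
    for pred_row, tri_row in zip(pred_alpha, trimap):
        row = []
        for a, t in zip(pred_row, tri_row):
            if t == 255:  # assumed: known foreground
                a = 1.0
            elif t == 0:  # assumed: known background
                a = 0.0
            a = min(max(a, 0.0), 1.0)  # clamp to [0, 1]
            row.append(round(a * 255))  # scale to the uint8 range
        out.append(row)
    return out


alpha_u8 = postprocess_alpha([[0.3, 0.4], [1.2, 0.2]], [[0, 128], [128, 255]])
# [[0, 102], [255, 255]]
```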

forward(inputs: torch.Tensor, data_samples: DataSamples = None, mode: str = 'tensor') → List[mmagic.structures.DataSample]

General forward function.

Parameters
  • inputs (torch.Tensor) – A batch of inputs, with the image and trimap concatenated along the channel dimension.

  • data_samples (List[DataSample], optional) – A list of data samples, containing:

    • ground-truth alpha / foreground / background to compute the loss

    • other meta information

  • mode (str) –

    mode should be one of loss, predict and tensor. Default: ‘tensor’.

    • loss: Called by train_step; returns a loss dict used for logging.

    • predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.

    • tensor: Called by custom code to get raw Tensor results.

Returns

Sequence of predictions packed into DataElement

Return type

List[DataElement]
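The three modes described above amount to a string dispatch onto the subclass hooks. The sketch below is a toy model with hypothetical names that shows this pattern in plain Python. Its hooks return simple lists and dicts, not tensors or DataSample objects.

```python
from typing import Any, List, Optional


class ToyModel:
    """Hypothetical model illustrating mode-based dispatch in forward()."""

    def forward(self, inputs: List[float], data_samples: Optional[list] = None, mode: str = "tensor") -> Any:
        if mode == "loss":  # training: return a loss dict
            return self._forward_train(inputs, data_samples)
        if mode == "predict":  # validation/testing: return packed predictions
            return self._forward_test(inputs, data_samples)
        if mode == "tensor":  # custom use: return raw outputs
            return self._forward(inputs, data_samples)
        raise ValueError(f"Invalid mode {mode!r}; expected 'loss', 'predict' or 'tensor'")

    def _forward(self, inputs, data_samples=None):
        return [2 * x for x in inputs]

    def _forward_train(self, inputs, data_samples):
        return {"loss": sum(self._forward(inputs))}

    def _forward_test(self, inputs, data_samples):
        return [{"pred": p} for p in self._forward(inputs)]


model = ToyModel()
raw = model.forward([1.0, 2.0])  # [2.0, 4.0], since mode defaults to 'tensor'
```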

convert_to_datasample(predictions: List[mmagic.structures.DataSample], data_samples: mmagic.structures.DataSample) → List[mmagic.structures.DataSample]

Add predictions to data samples.

Parameters
  • predictions (List[DataSample]) – The predictions of the model.

  • data_samples (DataSample) – The data samples loaded from dataloader.

Returns

Modified data samples.

Return type

List[DataSample]
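Adding predictions to data samples is essentially a one-to-one pairing. The sketch below represents samples as dicts and stores each prediction under a hypothetical "output" key. The field name the real method uses is not stated on this page. `attach_predictions` is also hypothetical.

```python
from typing import Any, Dict, List

Sample = Dict[str, Any]


def attach_predictions(predictions: List[Sample], data_samples: List[Sample]) -> List[Sample]:
    """Store each prediction on its matching data sample, in order (hypothetical)."""
    if len(predictions) != len(data_samples):
        raise ValueError("need exactly one prediction per data sample")
    for pred, sample in zip(predictions, data_samples):
        sample["output"] = pred  # hypothetical field name
    return data_samples


samples = attach_predictions([{"pred_alpha": 7}], [{"img_path": "a.png"}])
# [{'img_path': 'a.png', 'output': {'pred_alpha': 7}}]
```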
