mmagic.models.base_models.base_mattor¶
Module Contents¶
Classes¶
BaseMattor - Base class for trimap-based matting models.
Functions¶
_pad - Pad image to a multiple of a given down-sampling factor.
_interpolate - Resize image to a multiple of a given down-sampling factor.
Attributes¶
- mmagic.models.base_models.base_mattor._pad(batch_image: torch.Tensor, ds_factor: int, mode: str = 'reflect') Tuple[torch.Tensor, Tuple[int, int]] [source]¶
Pad image to a multiple of a given down-sampling factor.
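As a sketch of the padding arithmetic (pure Python; the helper name `pad_to_multiple` is illustrative and not part of mmagic), the bottom/right padding amounts can be computed as:

```python
def pad_to_multiple(h: int, w: int, ds_factor: int):
    """Compute bottom/right padding so (h, w) become multiples of ds_factor."""
    pad_h = (ds_factor - h % ds_factor) % ds_factor
    pad_w = (ds_factor - w % ds_factor) % ds_factor
    return pad_h, pad_w

# A 65x100 image with ds_factor=32 needs 31 rows and 28 columns of padding.
print(pad_to_multiple(65, 100, 32))  # (31, 28)
```

The outer `% ds_factor` ensures that inputs already divisible by the factor receive zero padding.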
- mmagic.models.base_models.base_mattor._interpolate(batch_image: torch.Tensor, ds_factor: int, mode: str = 'bicubic') Tuple[torch.Tensor, Tuple[int, int]] [source]¶
Resize image to a multiple of a given down-sampling factor.
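The interpolation target size can likewise be sketched by rounding each dimension up to the nearest multiple of the factor (`interp_target_size` is an illustrative name, not an mmagic API):

```python
def interp_target_size(h: int, w: int, ds_factor: int):
    """Round (h, w) up to the nearest multiples of ds_factor for resizing."""
    new_h = -(-h // ds_factor) * ds_factor  # ceiling division
    new_w = -(-w // ds_factor) * ds_factor
    return new_h, new_w

print(interp_target_size(65, 100, 32))  # (96, 128)
```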
- class mmagic.models.base_models.base_mattor.BaseMattor(data_preprocessor: Union[dict, mmengine.config.Config], backbone: dict, init_cfg: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None)[source]¶
Bases:
mmengine.model.BaseModel
Base class for trimap-based matting models.
A matting model must contain a backbone which produces pred_alpha, a dense prediction with the same height and width of input image. In some cases (such as DIM), the model has a refiner which refines the prediction of the backbone.
Subclasses should overwrite the following functions:
_forward_train(), to return a loss
_forward_test(), to return a prediction
_forward(), to return raw tensors
For testing, this base class provides functions to resize inputs and post-process pred_alphas to get predictions.
- Parameters
backbone (dict) – Config of backbone.
data_preprocessor (dict) – Config of data_preprocessor. See MattorPreprocessor for details.
init_cfg (dict, optional) – Initialization config dict.
train_cfg (dict) – Config of training. Customized by subclasses. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.
test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.
- resize_inputs(batch_inputs: torch.Tensor) torch.Tensor [source]¶
Pad or interpolate images and trimaps to a multiple of a given factor.
- restore_size(pred_alpha: torch.Tensor, data_sample: mmagic.structures.DataSample) torch.Tensor [source]¶
Restore the predicted alpha to the original shape.
The shape of the predicted alpha may not be the same as the shape of the original input image. This function restores the shape of the predicted alpha.
- Parameters
pred_alpha (torch.Tensor) – A single predicted alpha of shape (1, H, W).
data_sample (DataSample) – Data sample containing original shape as meta data.
- Returns
The reshaped predicted alpha.
- Return type
torch.Tensor
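A minimal sketch of restoring a padded prediction, using plain nested lists in place of tensors (the helper `restore_padded` is hypothetical; for inputs that were interpolated rather than padded, the prediction would be resized back instead of cropped):

```python
def restore_padded(pred, ori_h: int, ori_w: int):
    """Crop a padded 2-D prediction (nested lists) back to the original H x W."""
    return [row[:ori_w] for row in pred[:ori_h]]

# A 65x100 image padded to 96x128 (ds_factor=32) is cropped back after inference.
padded = [[0.5] * 128 for _ in range(96)]
restored = restore_padded(padded, 65, 100)
print(len(restored), len(restored[0]))  # 65 100
```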
- postprocess(batch_pred_alpha: torch.Tensor, data_samples: mmagic.structures.DataSample) List[mmagic.structures.DataSample] [source]¶
Post-process alpha predictions.
- This function contains the following steps:
Restore padding or interpolation
Mask alpha prediction with trimap
Clamp alpha prediction to 0-1
Convert alpha prediction to uint8
Pack alpha prediction into DataSample
Currently, only batch_size=1 is supported.
- Parameters
batch_pred_alpha (torch.Tensor) – A batch of predicted alpha of shape (N, 1, H, W).
data_samples (List[DataSample]) – List of data samples.
- Returns
- A list of predictions.
Each data sample contains a pred_alpha, which is a torch.Tensor with dtype=uint8 on the model's device (e.g. cuda:0).
- Return type
List[DataSample]
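The masking, clamping, and conversion steps above can be sketched on a flat list of alpha values (pure Python; the helper `postprocess_alpha` and the 0/128/255 trimap convention are assumptions for illustration):

```python
def postprocess_alpha(pred, trimap):
    """Mask predicted alphas with a trimap, clamp to [0, 1], convert to uint8 range.

    Assumed trimap convention: 0 = known background, 255 = known foreground,
    128 = unknown; only unknown pixels keep the network prediction.
    """
    out = []
    for a, t in zip(pred, trimap):
        if t == 0:
            a = 0.0                      # known background -> alpha 0
        elif t == 255:
            a = 1.0                      # known foreground -> alpha 1
        a = min(max(a, 0.0), 1.0)        # clamp to [0, 1]
        out.append(round(a * 255))       # scale to uint8 range
    return out

print(postprocess_alpha([1.2, 0.4, -0.1, 0.9], [255, 128, 128, 0]))
# [255, 102, 0, 0]
```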
- forward(inputs: torch.Tensor, data_samples: DataSamples = None, mode: str = 'tensor') List[mmagic.structures.DataSample] [source]¶
General forward function.
- Parameters
inputs (torch.Tensor) – A batch of inputs, with image and trimap concatenated along the channel dimension.
data_samples (List[DataSample], optional) – A list of data samples, containing: - ground-truth alpha / foreground / background to compute loss - other meta information
mode (str) – mode should be one of loss, predict and tensor. Default: 'tensor'.
loss: Called by train_step and returns a loss dict used for logging.
predict: Called by val_step and test_step and returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom use to get Tensor type results.
- Returns
Sequence of predictions packed into DataElement
- Return type
List[DataElement]
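A minimal sketch of the three-way mode dispatch described above, with placeholder bodies (the class `MattorSketch` and its return values are illustrative, not mmagic code):

```python
class MattorSketch:
    """Hypothetical mattor illustrating the loss / predict / tensor modes."""

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode == 'loss':
            return self._forward_train(inputs, data_samples)
        if mode == 'predict':
            return self._forward_test(inputs, data_samples)
        if mode == 'tensor':
            return self._forward(inputs)
        raise ValueError(f'Invalid mode "{mode}".')

    def _forward_train(self, inputs, data_samples):
        return {'loss_alpha': 0.0}       # dict of losses for logging

    def _forward_test(self, inputs, data_samples):
        return [{'pred_alpha': inputs}]  # list of prediction elements

    def _forward(self, inputs):
        return inputs                    # raw tensor output

m = MattorSketch()
print(m.forward([1, 2], mode='loss'))    # {'loss_alpha': 0.0}
print(m.forward([1, 2], mode='tensor'))  # [1, 2]
```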
- convert_to_datasample(predictions: List[mmagic.structures.DataSample], data_samples: mmagic.structures.DataSample) List[mmagic.structures.DataSample] [source]¶
Add predictions to data samples.
- Parameters
predictions (List[DataSample]) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
- Returns
Modified data samples.
- Return type
List[DataSample]