mmagic.models.editors.dim.dim
Module Contents
Classes
DIM | Deep Image Matting model.
- class mmagic.models.editors.dim.dim.DIM(data_preprocessor, backbone, refiner=None, train_cfg=None, test_cfg=None, loss_alpha=None, loss_comp=None, loss_refine=None, init_cfg: Optional[dict] = None)
Bases: mmagic.models.base_models.BaseMattor
Deep Image Matting model.
https://arxiv.org/abs/1703.03872
Note
For (self.train_cfg.train_backbone, self.train_cfg.train_refiner):
(True, False) corresponds to the encoder-decoder stage in the paper.
(False, True) corresponds to the refinement stage in the paper.
(True, True) corresponds to the fine-tune stage in the paper.
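The flag combinations in the note can be summarized as a lookup table. The following is a minimal sketch, not part of mmagic: `STAGES` and `stage_name` are hypothetical names, and `train_cfg` is assumed to be a plain dict rather than the config object the model actually holds.

```python
# Hypothetical helper, not part of mmagic: maps the two train_cfg flags
# to the training stage names listed in the note.
STAGES = {
    (True, False): "encoder-decoder",
    (False, True): "refinement",
    (True, True): "fine-tune",
}


def stage_name(train_cfg: dict) -> str:
    """Return the paper's stage name for a (train_backbone, train_refiner) pair."""
    key = (
        bool(train_cfg.get("train_backbone", False)),
        bool(train_cfg.get("train_refiner", False)),
    )
    if key not in STAGES:
        # (False, False) is not listed in the note, so it is rejected here.
        raise ValueError(f"no stage documented for flags {key}")
    return STAGES[key]


print(stage_name({"train_backbone": True, "train_refiner": False}))
```

The combination (False, False) is not described in the note, so the sketch treats it as an error rather than guessing a meaning.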
- Parameters
data_preprocessor (dict, optional) – Config of the data pre-processor.
backbone (dict) – Config of the backbone.
refiner (dict) – Config of the refiner.
loss_alpha (dict) – Config of the alpha prediction loss. Default: None.
loss_comp (dict) – Config of the composition loss. Default: None.
loss_refine (dict) – Config of the loss of the refiner. Default: None.
train_cfg (dict) – Config of training. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should also be specified.
test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
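The rule stated for train_cfg can be expressed as a small pre-flight check. This is only a sketch of the documented requirement: `check_train_cfg` and its `has_refiner` argument are hypothetical and do not exist in mmagic, and the config is assumed to be a plain dict.

```python
def check_train_cfg(train_cfg: dict, has_refiner: bool) -> None:
    """Hypothetical validator mirroring the documented train_cfg rule."""
    # train_backbone must always be present.
    if "train_backbone" not in train_cfg:
        raise KeyError("train_cfg must specify 'train_backbone'")
    # train_refiner is only required when a refiner is configured.
    if has_refiner and "train_refiner" not in train_cfg:
        raise KeyError(
            "train_cfg must specify 'train_refiner' when a refiner is configured"
        )


# Backbone-only model: train_refiner may be omitted.
check_train_cfg({"train_backbone": True}, has_refiner=False)
# Model with a refiner: both flags are given.
check_train_cfg({"train_backbone": False, "train_refiner": True}, has_refiner=True)
```

Running a check like this before building the model surfaces a missing flag early, instead of partway through training.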
- train(mode=True)
Mode switcher.
- Parameters
mode (bool) – Whether to set training mode (True) or evaluation mode (False). Default: True.
- _forward(x: torch.Tensor, *, refine: bool = True) → Tuple[torch.Tensor, torch.Tensor]
Raw forward function.
- Parameters
x (torch.Tensor) – Concatenation of merged image and trimap, with shape (N, 4, H, W).
refine (bool) – Whether to forward through the refiner. Default: True.
- Returns
A pair (pred_alpha, pred_refine): pred_alpha, with shape (N, 1, H, W), and pred_refine, with shape (N, 4, H, W).
- Return type
Tuple[torch.Tensor, torch.Tensor]
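The (N, 4, H, W) input described for x can be built by concatenating along the channel dimension. The sketch below only illustrates the tensor layout and does not call the model. It assumes a 3-channel merged image and a 1-channel trimap, which is one way to arrive at the four channels the signature expects; the sizes are arbitrary.

```python
import torch

N, H, W = 2, 32, 32  # arbitrary batch size and spatial size for illustration

# Assumed layout: RGB merged image plus a single-channel trimap.
merged = torch.rand(N, 3, H, W)
trimap = torch.rand(N, 1, H, W)

# Concatenate along the channel axis (dim=1) to obtain the 4-channel input.
x = torch.cat([merged, trimap], dim=1)
print(tuple(x.shape))
```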
- _forward_train(inputs, data_samples)
Defines the computation performed at every training call.
- Parameters
inputs (torch.Tensor) – Concatenation of normalized image and trimap, with shape (N, 4, H, W).
data_samples (list[DataSample]) – Data samples containing:
- gt_alpha (Tensor): Ground-truth alpha with shape (N, 1, H, W), normalized to the range 0 to 1.
- gt_fg (Tensor): Ground-truth foreground with shape (N, C, H, W), normalized to the range 0 to 1.
- gt_bg (Tensor): Ground-truth background with shape (N, C, H, W), normalized to the range 0 to 1.
- Returns
Contains the loss items and batch information.
- Return type
dict
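To show why gt_fg and gt_bg are needed alongside gt_alpha, here is a minimal pure-Python sketch of a composition loss in the spirit of the Deep Image Matting paper: the image is recomposed as alpha * fg + (1 - alpha) * bg and compared using a Charbonnier-style distance. The functions `composite` and `comp_loss` are hypothetical and operate on flat lists of floats. The loss actually used is whatever loss_comp configures, which may differ, for example in weighting or reduction.

```python
import math


def composite(alpha: float, fg: float, bg: float) -> float:
    """Alpha-blend one value: alpha * fg + (1 - alpha) * bg."""
    return alpha * fg + (1.0 - alpha) * bg


def comp_loss(pred_alpha, gt_alpha, gt_fg, gt_bg, eps: float = 1e-6) -> float:
    """Mean Charbonnier-style distance between the two recomposed images."""
    total = 0.0
    for a_pred, a_gt, fg, bg in zip(pred_alpha, gt_alpha, gt_fg, gt_bg):
        c_pred = composite(a_pred, fg, bg)  # image recomposed from prediction
        c_gt = composite(a_gt, fg, bg)      # image recomposed from ground truth
        total += math.sqrt((c_pred - c_gt) ** 2 + eps ** 2)
    return total / len(pred_alpha)


# A fully wrong alpha on a white foreground over a black background.
print(comp_loss([1.0], [0.0], [1.0], [0.0]))
```

When the predicted alpha equals the ground truth, the recomposed values coincide and the loss collapses to eps.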