mmagic.models.editors.gca.gca
Module Contents
Classes
GCA | Guided Contextual Attention image matting model.
- class mmagic.models.editors.gca.gca.GCA(data_preprocessor, backbone, loss_alpha=None, init_cfg: Optional[dict] = None, train_cfg=None, test_cfg=None)
Bases: mmagic.models.base_models.BaseMattor
Guided Contextual Attention image matting model.
https://arxiv.org/abs/2001.04069
- Parameters
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
backbone (dict) – Config of backbone.
loss_alpha (dict, optional) – Config of the alpha prediction loss. Default: None.
init_cfg (dict, optional) – Initialization config dict. Default: None.
train_cfg (dict, optional) – Config of training. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should also be specified. Default: None.
test_cfg (dict, optional) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified. Default: None.
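In OpenMMLab-style projects the constructor is usually driven by a config dict rather than called directly. The sketch below assumes that convention and that GCA is registered under its class name; the nested `type` values are hypothetical placeholders, not verified mmagic registry names, and the boolean value for `train_backbone` is an assumption (the parameter list only says the key should be specified).

```python
# Hypothetical GCA model config. Top-level keys mirror the documented
# constructor parameters; nested 'type' values are placeholders only.
model_cfg = dict(
    type='GCA',
    data_preprocessor=dict(type='PlaceholderPreprocessor'),  # hypothetical
    backbone=dict(type='PlaceholderEncoderDecoder'),          # hypothetical
    loss_alpha=dict(type='PlaceholderAlphaLoss'),             # hypothetical
    # The parameter docs say train_cfg should specify train_backbone;
    # the value True is an assumption.
    train_cfg=dict(train_backbone=True),
    test_cfg=dict(),
    init_cfg=None,
)
```

Swap the placeholder `type` strings for real registered components before building the model.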
- _forward(inputs)
Forward function.
- Parameters
inputs (torch.Tensor) – Input tensor.
- Returns
Output tensor.
- Return type
Tensor
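A minimal PyTorch sketch of the shape of `_forward`, assuming the backbone returns a raw single-channel prediction that is mapped into [0, 1] with `(tanh(x) + 1) / 2`. That mapping, the 6-channel input (RGB plus a 3-channel trimap), and the one-layer convolution standing in for the real encoder-decoder are assumptions for illustration; `TinyGCASketch` is a hypothetical class, not the mmagic implementation.

```python
import torch
import torch.nn as nn


class TinyGCASketch(nn.Module):
    """Hypothetical stand-in for GCA; not the mmagic implementation."""

    def __init__(self, in_channels: int = 6):
        super().__init__()
        # Placeholder backbone: one conv mapping image + trimap to one channel.
        self.backbone = nn.Conv2d(in_channels, 1, kernel_size=3, padding=1)

    def _forward(self, inputs: torch.Tensor) -> torch.Tensor:
        raw_alpha = self.backbone(inputs)
        # Assumed output mapping: tanh lies in [-1, 1], so this lies in [0, 1].
        return (raw_alpha.tanh() + 1.0) / 2.0


model = TinyGCASketch()
alpha = model._forward(torch.rand(2, 6, 32, 32))  # shape (2, 1, 32, 32)
```

Squashing through `tanh` keeps every predicted alpha value inside the valid opacity range without a separate clamp.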
- _forward_test(inputs)
Forward function for testing GCA model.
- Parameters
inputs (torch.Tensor) – Batch input tensor.
- Returns
Output tensor of model.
- Return type
Tensor
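Test-time inference for encoder-decoder matting networks commonly pads the input so that down- and up-sampling line up, then crops the prediction back to the original size. The sketch below assumes reflect padding to a multiple of 32 and an assumed `(tanh(x) + 1) / 2` mapping to [0, 1]; this page does not state whether GCA pads, resizes, or does neither. `forward_test_sketch` and `size_divisor` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def forward_test_sketch(backbone: nn.Module, inputs: torch.Tensor,
                        size_divisor: int = 32) -> torch.Tensor:
    """Hypothetical test-time helper; names and padding policy are assumptions."""
    h, w = inputs.shape[-2:]
    # Smallest padding that makes H and W multiples of size_divisor.
    pad_h = (size_divisor - h % size_divisor) % size_divisor
    pad_w = (size_divisor - w % size_divisor) % size_divisor
    # F.pad orders padding as (left, right, top, bottom) for the last two dims.
    padded = F.pad(inputs, (0, pad_w, 0, pad_h), mode='reflect')
    with torch.no_grad():  # no autograd graph at inference time
        raw_alpha = backbone(padded)
        pred_alpha = (raw_alpha.tanh() + 1.0) / 2.0
    # Crop back to the original spatial size.
    return pred_alpha[..., :h, :w]


backbone = nn.Conv2d(6, 1, kernel_size=3, padding=1).eval()
pred = forward_test_sketch(backbone, torch.rand(1, 6, 50, 70))
```

For a 50x70 input the helper pads to 64x96 internally and returns a 50x70 alpha map. Reflect padding avoids the hard zero border that constant padding would introduce at the image edge.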
- _forward_train(inputs, data_samples)
Forward function for training GCA model.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement]) – Data samples collated by data_preprocessor.
- Returns
Contains the loss items and batch information.
- Return type
dict
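Matting losses are often restricted to the trimap's unknown region, since alpha is already determined in the known foreground and background. The sketch below assumes that, plus an input layout where channels 3 to 5 hold a one-hot trimap ordered (background, unknown, foreground). The real method reads the ground-truth alpha from `data_samples`; here it is passed as a plain tensor, and `forward_train_sketch` is a hypothetical name. It returns a loss dict, matching the documented return type.

```python
import torch


def forward_train_sketch(backbone, inputs: torch.Tensor,
                         gt_alpha: torch.Tensor, eps: float = 1e-6) -> dict:
    """Hypothetical training step: L1 alpha loss over the unknown region only."""
    # Assumed layout: channels 0-2 are RGB, channels 3-5 a one-hot trimap
    # ordered (background, unknown, foreground), so channel 4 is "unknown".
    unknown = inputs[:, 4:5]
    pred_alpha = (backbone(inputs).tanh() + 1.0) / 2.0
    # Average the absolute error over unknown pixels; eps avoids division by 0.
    loss = ((pred_alpha - gt_alpha).abs() * unknown).sum() / (unknown.sum() + eps)
    return {'loss': loss}


def zero_backbone(x: torch.Tensor) -> torch.Tensor:
    # Always predicts 0, so pred_alpha is 0.5 everywhere.
    return torch.zeros(x.shape[0], 1, *x.shape[-2:])


inputs = torch.zeros(2, 6, 8, 8)
inputs[:, 4] = 1.0  # mark every pixel as unknown
losses = forward_train_sketch(zero_backbone, inputs, torch.ones(2, 1, 8, 8))
```

With every pixel unknown, a constant prediction of 0.5, and a ground truth of 1, the masked L1 loss is 0.5.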