mmagic.models.editors.fba
Package Contents
Classes
FBADecoder | Decoder for FBA matting.
FBAResnetDilated | ResNet-based encoder for FBA image matting.
- class mmagic.models.editors.fba.FBADecoder(pool_scales, in_channels, channels, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False)[source]
Bases: torch.nn.Module
Decoder for FBA matting.
- Parameters
pool_scales (tuple[int]) – Pooling scales used in the Pooling Pyramid Module.
in_channels (int) – Input channels.
channels (int) – Channels after modules, before conv_seg.
conv_cfg (dict|None) – Config of conv layers.
norm_cfg (dict|None) – Config of norm layers.
act_cfg (dict) – Config of activation layers.
align_corners (bool) – align_corners argument of F.interpolate.
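As a rough illustration of how these constructor arguments fit together, the sketch below writes them as a config dict in the style OpenMMLab projects use, where a `type` key names the class. The numeric values (`pool_scales=(1, 2, 3, 6)`, `in_channels=2048`, `channels=256`) are assumptions chosen for illustration, not values stated on this page; only the key names and the listed defaults come from the signature above.

```python
# Illustrative decoder config. The numeric values are assumptions, not
# documented defaults. Key names mirror the FBADecoder signature.
decoder_cfg = dict(
    type='FBADecoder',
    pool_scales=(1, 2, 3, 6),    # assumed pyramid pooling scales
    in_channels=2048,            # assumed number of encoder output channels
    channels=256,                # assumed width after the pooling modules
    conv_cfg=None,               # default from the signature
    norm_cfg=dict(type='BN'),    # default from the signature
    act_cfg=dict(type='ReLU'),   # default from the signature
    align_corners=False,         # default from the signature
)
```

Building the module from such a dict would normally go through the library's model registry; that step is left out here because it requires the library to be installed.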
- init_weights(pretrained=None)
Initialize the model weights.
- Parameters
pretrained (str, optional) – Path to pretrained weights. If None, no pretrained weights are loaded. Defaults to None.
- forward(inputs)
Forward function.
- Parameters
inputs (dict) – Output dict of FbaEncoder.
- Returns
Predicted alpha, fg and bg of the current batch.
- Return type
tuple(Tensor)
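The three outputs are usually related through the standard matting compositing equation, I = alpha * F + (1 - alpha) * B. The sketch below applies that equation to a single RGB pixel with plain Python floats. The helper name `composite_pixel` and the sample values are hypothetical; real outputs would be batched tensors rather than tuples, and this page does not say how the predictions are consumed downstream.

```python
def composite_pixel(alpha, fg, bg):
    """Blend one foreground and one background pixel with opacity alpha.

    Hypothetical helper for illustration; not part of the documented API.
    """
    return tuple(alpha * f + (1.0 - alpha) * b for f, b in zip(fg, bg))


# A quarter-opaque red foreground over a blue background.
pixel = composite_pixel(0.25, (1.0, 0.0, 0.0), (0.0, 0.0, 1.0))
# pixel == (0.25, 0.0, 0.75)
```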
- class mmagic.models.editors.fba.FBAResnetDilated(depth: int, in_channels: int = 3, stem_channels: int = 64, base_channels: int = 64, num_stages: int = 4, strides: Sequence[int] = (1, 2, 2, 2), dilations: Sequence[int] = (1, 1, 2, 4), deep_stem: bool = False, avg_down: bool = False, frozen_stages: int = -1, act_cfg: dict = dict(type='ReLU'), conv_cfg: Optional[dict] = None, norm_cfg: dict = dict(type='BN'), with_cp: bool = False, multi_grid: Optional[Sequence[int]] = None, contract_dilation: bool = False, zero_init_residual: bool = True)[source]
Bases: mmagic.models.archs.ResNet
ResNet-based encoder for FBA image matting.
- forward(x)
Forward function.
- Parameters
x (Tensor) – Input tensor with shape (N, C, H, W).
- Returns
Output tensor.
- Return type
Tensor
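For orientation, the encoder's constructor arguments can be collected in the same config-dict style. In the sketch below, `depth=50` is an assumption (the signature gives no default for it); the remaining values repeat the defaults shown in the signature. The helper `check_stage_args` is hypothetical: it encodes the plausible expectation that `strides` and `dilations` each carry one entry per stage, which this page does not state explicitly.

```python
# Illustrative encoder config. depth=50 is an assumption; the other values
# repeat the defaults shown in the FBAResnetDilated signature.
encoder_cfg = dict(
    type='FBAResnetDilated',
    depth=50,                  # assumed; the signature has no default
    in_channels=3,
    stem_channels=64,
    base_channels=64,
    num_stages=4,
    strides=(1, 2, 2, 2),
    dilations=(1, 1, 2, 4),
)


def check_stage_args(cfg):
    """Hypothetical sanity check: one stride and one dilation per stage."""
    n = cfg['num_stages']
    return len(cfg['strides']) == n and len(cfg['dilations']) == n


stages_ok = check_stage_args(encoder_cfg)
```

A check like this can catch a mismatched tuple length before any model is built.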