
Source code for mmagic.models.base_models.base_mattor

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel

from mmagic.registry import MODELS
from mmagic.structures import DataSample

DataSamples = Optional[Union[list, torch.Tensor]]
ForwardResults = Union[Dict[str, torch.Tensor], List[DataSample],
                       Tuple[torch.Tensor], torch.Tensor]
def _pad(batch_image: torch.Tensor,
         ds_factor: int,
         mode: str = 'reflect') -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Pad image to a multiple of the given down-sampling factor."""

    h, w = batch_image.shape[-2:]  # NCHW

    new_h = ds_factor * ((h - 1) // ds_factor + 1)
    new_w = ds_factor * ((w - 1) // ds_factor + 1)

    pad_h = new_h - h
    pad_w = new_w - w
    pad = (pad_h, pad_w)
    if new_h != h or new_w != w:
        pad_width = (0, pad_w, 0, pad_h)  # F.pad orders dims last-to-first
        batch_image = F.pad(batch_image, pad_width, mode)

    return batch_image, pad
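# Hypothetical sanity check of the padding arithmetic above (illustrative
# shapes, not part of the original module): a 37x41 image padded to the next
# multiple of 32 becomes 64x64, and the recorded pad is
# (new_h - h, new_w - w) = (27, 23).
#
#   x = torch.rand(1, 3, 37, 41)
#   y, pad = _pad(x, ds_factor=32, mode='reflect')
#   assert y.shape == (1, 3, 64, 64) and pad == (27, 23)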
def _interpolate(batch_image: torch.Tensor,
                 ds_factor: int,
                 mode: str = 'bicubic'
                 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Resize image to a multiple of the given down-sampling factor."""

    h, w = batch_image.shape[-2:]  # NCHW

    new_h = h - (h % ds_factor)
    new_w = w - (w % ds_factor)

    size = (new_h, new_w)
    if new_h != h or new_w != w:
        batch_image = F.interpolate(batch_image, size=size, mode=mode)

    return batch_image, size
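# Hypothetical sanity check (illustrative shapes): the interpolation variant
# resizes *down* to the nearest multiples instead of padding up, so a 37x41
# image becomes 32x32.
#
#   x = torch.rand(1, 3, 37, 41)
#   y, size = _interpolate(x, ds_factor=32, mode='bicubic')
#   assert y.shape == (1, 3, 32, 32) and size == (32, 32)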
class BaseMattor(BaseModel, metaclass=ABCMeta):
    """Base class for trimap-based matting models.

    A matting model must contain a backbone which produces ``pred_alpha``,
    a dense prediction with the same height and width as the input image.
    In some cases (such as DIM), the model has a refiner which refines the
    prediction of the backbone.

    Subclasses should overwrite the following functions:

    - :meth:`_forward_train`, to return a loss
    - :meth:`_forward_test`, to return a prediction
    - :meth:`_forward`, to return raw tensors

    For testing, this base class provides functions to resize inputs and
    post-process ``pred_alpha`` to get predictions.

    Args:
        backbone (dict): Config of backbone.
        data_preprocessor (dict): Config of data_preprocessor.
            See :class:`MattorPreprocessor` for details.
        init_cfg (dict, optional): Initialization config dict.
        train_cfg (dict): Config of training. Customized by subclasses.
            In ``train_cfg``, ``train_backbone`` should be specified.
            If the model has a refiner, ``train_refiner`` should be
            specified.
        test_cfg (dict): Config of testing. In ``test_cfg``, if the model
            has a refiner, ``train_refiner`` should be specified.
    """

    def __init__(self,
                 data_preprocessor: Union[dict, Config],
                 backbone: dict,
                 init_cfg: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None):
        # data_preprocessor is built in BaseModel;
        # weights are initialized in BaseModule.
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.train_cfg = ConfigDict(
            train_cfg) if train_cfg is not None else ConfigDict()
        self.test_cfg = ConfigDict(
            test_cfg) if test_cfg is not None else ConfigDict()

        self.backbone = MODELS.build(backbone)
    def resize_inputs(self, batch_inputs: torch.Tensor) -> torch.Tensor:
        """Pad or interpolate images and trimaps to a multiple of the given
        factor."""

        resize_method = self.test_cfg['resize_method']
        resize_mode = self.test_cfg['resize_mode']
        size_divisor = self.test_cfg['size_divisor']

        batch_images = batch_inputs[:, :3, :, :]
        batch_trimaps = batch_inputs[:, 3:, :, :]

        if resize_method == 'pad':
            batch_images, _ = _pad(batch_images, size_divisor, resize_mode)
            batch_trimaps, _ = _pad(batch_trimaps, size_divisor, resize_mode)
        elif resize_method == 'interp':
            batch_images, _ = _interpolate(batch_images, size_divisor,
                                           resize_mode)
            batch_trimaps, _ = _interpolate(batch_trimaps, size_divisor,
                                            'nearest')
        else:
            raise NotImplementedError

        return torch.cat((batch_images, batch_trimaps), dim=1)
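    # Hypothetical usage sketch, assuming `mattor` was built with
    # test_cfg=dict(resize_method='pad', resize_mode='reflect',
    # size_divisor=32): a 4-channel batch (RGB + trimap concatenated along
    # the channel dim) is padded up to multiples of 32.
    #
    #   batch = torch.rand(1, 4, 37, 41)
    #   mattor.resize_inputs(batch).shape  # -> torch.Size([1, 4, 64, 64])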
    def restore_size(self, pred_alpha: torch.Tensor,
                     data_sample: DataSample) -> torch.Tensor:
        """Restore the predicted alpha to the original shape.

        The shape of the predicted alpha may not be the same as the shape
        of the original input image. This function restores the shape of
        the predicted alpha.

        Args:
            pred_alpha (torch.Tensor): A single predicted alpha of
                shape (1, H, W).
            data_sample (DataSample): Data sample containing the original
                shape as meta data.

        Returns:
            torch.Tensor: The reshaped predicted alpha.
        """
        resize_method = self.test_cfg['resize_method']
        resize_mode = self.test_cfg['resize_mode']

        ori_h, ori_w = data_sample.ori_merged_shape[:2]
        if resize_method == 'pad':
            pred_alpha = pred_alpha[:, :ori_h, :ori_w]
        elif resize_method == 'interp':
            pred_alpha = F.interpolate(
                pred_alpha.unsqueeze(0), size=(ori_h, ori_w),
                mode=resize_mode)
            pred_alpha = pred_alpha[0]  # 1, H, W

        return pred_alpha
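    # Hypothetical illustration of the 'pad' branch above: with an original
    # merged shape of (37, 41), restoring a padded 1x64x64 prediction is a
    # plain crop.
    #
    #   pred = torch.rand(1, 64, 64)
    #   pred[:, :37, :41].shape  # -> torch.Size([1, 37, 41])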
    def postprocess(
        self,
        batch_pred_alpha: torch.Tensor,  # N, 1, H, W, float32
        data_samples: DataSample,
    ) -> List[DataSample]:
        """Post-process alpha predictions.

        This function contains the following steps:

        1. Restore padding or interpolation
        2. Mask alpha prediction with trimap
        3. Clamp alpha prediction to 0-1
        4. Convert alpha prediction to uint8
        5. Pack alpha prediction into DataSample

        Currently only batch_size 1 is actually supported.

        Args:
            batch_pred_alpha (torch.Tensor): A batch of predicted alpha
                of shape (N, 1, H, W).
            data_samples (List[DataSample]): List of data samples.

        Returns:
            List[DataSample]: A list of predictions. Each data sample
                contains a pred_alpha, which is a torch.Tensor with
                dtype=uint8, device=cuda:0.
        """
        assert batch_pred_alpha.ndim == 4  # N, 1, H, W, float32
        assert len(batch_pred_alpha) == 1

        # NOTE: for mattors, we split data samples here, not in
        # `convert_to_datasample`
        data_samples = data_samples.split()

        predictions = []
        for pa, ds in zip(batch_pred_alpha, data_samples):
            pa = self.restore_size(pa, ds)  # 1, H, W
            pa = pa[0]  # H, W

            pa.clamp_(min=0, max=1)
            ori_trimap = ds.ori_trimap[0, :, :]  # H, W
            pa[ori_trimap == 255] = 1
            pa[ori_trimap == 0] = 0

            pa *= 255
            pa.round_()
            pa = pa.to(dtype=torch.uint8)
            pa_sample = DataSample(pred_alpha=pa)
            predictions.append(pa_sample)

        return predictions
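    # Hypothetical example of the trimap masking in step 2: trimap values
    # 0 / 128 / 255 mark background / unknown / foreground, so only the
    # unknown pixel keeps the predicted value.
    #
    #   pa = torch.tensor([0.2, 0.5, 0.9])
    #   ori_trimap = torch.tensor([0, 128, 255])
    #   pa[ori_trimap == 255] = 1
    #   pa[ori_trimap == 0] = 0
    #   pa  # -> tensor([0.0000, 0.5000, 1.0000])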
    def forward(self,
                inputs: torch.Tensor,
                data_samples: DataSamples = None,
                mode: str = 'tensor') -> List[DataSample]:
        """General forward function.

        Args:
            inputs (torch.Tensor): A batch of inputs, with image and trimap
                concatenated along the channel dimension.
            data_samples (List[DataSample], optional): A list of data
                samples, containing:

                - Ground-truth alpha / foreground / background to compute
                  loss
                - other meta information

            mode (str): mode should be one of ``loss``, ``predict`` and
                ``tensor``. Default: 'tensor'.

                - ``loss``: Called by ``train_step`` and return a loss
                  ``dict`` used for logging
                - ``predict``: Called by ``val_step`` and ``test_step``
                  and return a list of ``BaseDataElement`` results used
                  for computing metrics
                - ``tensor``: Called by custom code to get ``Tensor`` type
                  results

        Returns:
            List[DataSample]: Sequence of predictions packed into
                DataSample.
        """
        if mode == 'tensor':
            raw = self._forward(inputs)
            return raw
        elif mode == 'predict':
            # Pre-processing runs in the runner
            inputs = self.resize_inputs(inputs)
            batch_pred_alpha = self._forward_test(inputs)
            predictions = self.postprocess(batch_pred_alpha, data_samples)
            predictions = self.convert_to_datasample(predictions,
                                                     data_samples)
            return predictions
        elif mode == 'loss':
            loss = self._forward_train(inputs, data_samples)
            return loss
        else:
            raise ValueError('Invalid forward mode.')
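    # Hypothetical call sketch (`mattor`, `inputs` and `data_samples` are
    # assumed to come from an mmengine runner):
    #
    #   raw = mattor(inputs, mode='tensor')                   # raw network output
    #   preds = mattor(inputs, data_samples, mode='predict')  # List[DataSample]
    #   losses = mattor(inputs, data_samples, mode='loss')    # dict for logging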
    def convert_to_datasample(self, predictions: List[DataSample],
                              data_samples: DataSample) -> List[DataSample]:
        """Add predictions to data samples.

        Args:
            predictions (List[DataSample]): The predictions of the model.
            data_samples (DataSample): The data samples loaded from
                dataloader.

        Returns:
            List[DataSample]: Modified data samples.
        """
        data_samples = data_samples.split()
        for data_sample, pred in zip(data_samples, predictions):
            data_sample.output = pred
        return data_samples
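# A minimal illustrative subclass sketch (hypothetical, not part of mmagic):
# the three ``_forward*`` hooks named in the class docstring are all a
# concrete mattor has to provide.
@MODELS.register_module()
class ToyMattor(BaseMattor):
    """Toy mattor; the name, loss and ``gt_alpha`` field are illustrative."""

    def _forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Raw dense prediction from the backbone.
        return self.backbone(inputs)

    def _forward_test(self, inputs: torch.Tensor) -> torch.Tensor:
        return self._forward(inputs)

    def _forward_train(self, inputs: torch.Tensor,
                       data_samples: DataSample) -> dict:
        pred_alpha = self._forward(inputs)
        # `gt_alpha` is an assumed field on the stacked data samples.
        return dict(loss_alpha=F.l1_loss(pred_alpha, data_samples.gt_alpha))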