Shortcuts

Source code for mmagic.apis.inferencers.matting_inferencer

# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List

import mmcv
import numpy as np
import torch
from mmengine import mkdir_or_exist
from mmengine.dataset import Compose

from mmagic.structures import DataSample
from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType


class MattingInferencer(BaseMMagicInferencer):
    """Inferencer that predicts alpha mattes with matting models.

    Given an image and its trimap, it builds the test pipeline from
    ``self.cfg.test_pipeline`` (with ground-truth alpha handling stripped
    out), runs the model in ``predict`` mode, and returns / optionally saves
    the predicted alpha.
    """

    # Names of the keyword arguments routed to each inference stage; they
    # match the parameters of the corresponding methods below (presumably
    # consumed by the base inferencer's dispatch logic).
    func_kwargs = dict(
        preprocess=['img', 'trimap'],
        forward=[],
        visualize=['result_out_dir'],
        postprocess=[])

    def preprocess(self, img: InputsType, trimap: InputsType) -> Dict:
        """Process the inputs into a model-feedable format.

        Entries for ``alpha`` / ``ori_alpha`` are first stripped from
        ``self.cfg.test_pipeline``, because only ``merged_path`` and
        ``trimap_path`` are supplied to the pipeline here. Note that this
        mutates the config in place.

        Args:
            img (InputsType): Image to be processed by models; fed to the
                pipeline as ``merged_path``.
            trimap (InputsType): Trimap corresponding to the input image;
                fed to the pipeline as ``trimap_path``.

        Returns:
            Dict: ``inputs`` and ``data_samples``, each a one-element list
            holding the corresponding output of the test pipeline.
        """
        self._remove_alpha_from_test_pipeline()

        # build the data pipeline
        test_pipeline = Compose(self.cfg.test_pipeline)

        # prepare data
        data = dict(merged_path=img, trimap_path=trimap)
        _data = test_pipeline(data)

        # wrap the single sample into a batch of size one
        return dict(
            inputs=[_data['inputs']],
            data_samples=[_data['data_samples']])

    def _remove_alpha_from_test_pipeline(self) -> None:
        """Strip ``alpha`` / ``ori_alpha`` from ``self.cfg.test_pipeline``.

        Transforms whose ``key`` equals one of those names are dropped. The
        names are removed from any ``keys`` / ``meta_keys`` lists, and a
        transform whose ``keys`` list becomes empty is dropped as well. The
        config is modified in place, so calling this again is a no-op.
        """
        for key in ('alpha', 'ori_alpha'):
            # iterate over a snapshot because entries may be removed
            for pipeline in list(self.cfg.test_pipeline):
                if 'key' in pipeline and key == pipeline['key']:
                    self.cfg.test_pipeline.remove(pipeline)
                    # Already dropped: inspecting it further could make the
                    # `keys` branch below call remove() a second time, which
                    # raises ValueError or drops a different, equal entry.
                    continue
                if 'keys' in pipeline and key in pipeline['keys']:
                    pipeline['keys'].remove(key)
                    if not pipeline['keys']:
                        self.cfg.test_pipeline.remove(pipeline)
                if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                    pipeline['meta_keys'].remove(key)

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model.

        The batch produced by ``preprocess`` is passed through the model's
        data preprocessor, and the model is then called in ``predict`` mode
        with gradients disabled.
        """
        inputs = self.model.data_preprocessor(inputs)
        with torch.no_grad():
            return self.model(mode='predict', **inputs)

    def visualize(self,
                  preds: PredType,
                  result_out_dir: str = None) -> torch.Tensor:
        """Extract the predicted alpha and optionally save it as an image.

        Args:
            preds (PredType): Forward results by the inferencer. Only the
                first element's ``output.pred_alpha`` is used.
            result_out_dir (str, optional): Path of the image file to write
                the alpha to; its parent directory is created first. Despite
                the name, this is the file path itself, not a directory.
                Defaults to None, in which case nothing is saved.

        Returns:
            torch.Tensor: The predicted alpha, moved to CPU.
        """
        pred_alpha = preds[0].output.pred_alpha.data.cpu()

        # save images
        if result_out_dir:
            mkdir_or_exist(os.path.dirname(result_out_dir))
            mmcv.imwrite(pred_alpha.numpy(), result_out_dir)

        return pred_alpha

    def _pred2dict(self, data_sample: DataSample) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        NOTE: the ``pred_alpha`` value is a CPU tensor (``.data.cpu()``), so
        the returned dict is not directly JSON-serializable as-is.

        Args:
            data_sample (DataSample): The data sample to be converted.

        Returns:
            dict: ``{'pred_alpha': <CPU tensor of the predicted alpha>}``.
        """
        result = {}
        result['pred_alpha'] = data_sample.output.pred_alpha.data.cpu()
        return result
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.