Shortcuts

Source code for mmagic.apis.inferencers.video_interpolation_inferencer

# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import os.path as osp
from typing import Dict, List, Optional, Tuple, Union

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate
from mmengine.logging import MMLogger
from mmengine.utils import ProgressBar

from .base_mmagic_inferencer import (BaseMMagicInferencer, InputsType,
                                     PredType, ResType)
from .inference_functions import VIDEO_EXTENSIONS, read_frames, read_image


class VideoInterpolationInferencer(BaseMMagicInferencer):
    """Inferencer that predicts with video interpolation models.

    The input is either a video file (its extension is in
    ``VIDEO_EXTENSIONS``) or a directory of frame images. The output is
    either a video file (extension in ``VIDEO_EXTENSIONS``) or a directory
    into which interpolated frames are written with ``filename_tmpl``.
    All reading, inference and saving happen in :meth:`forward`;
    :meth:`visualize` and :meth:`postprocess` only log a message.
    """

    func_kwargs = dict(
        preprocess=['video'],
        forward=['result_out_dir'],
        visualize=[],
        postprocess=[])

    extra_parameters = dict(
        # Index of the first input frame to process.
        start_idx=0,
        # Upper frame index passed to ``read_frames``; clipped to the input
        # length. ``None`` means "up to the last frame".
        end_idx=None,
        # Number of model steps processed per forward pass.
        batch_size=4,
        # Video inputs only: when non-zero, output fps is
        # ``fps_multiplier * input_fps``.
        fps_multiplier=0,
        # Explicit output fps when > 0. Otherwise the output uses twice the
        # input fps for video inputs and 60 for frame-directory inputs.
        fps=0,
        # Filename template for frames saved to an output directory.
        filename_tmpl='{:08d}.png')

    def preprocess(self, video: InputsType) -> InputsType:
        """Build ``self.test_pipeline`` and return the input unchanged.

        The pipeline is taken from the model config, preferring
        ``demo_pipeline``, then ``test_pipeline``, then ``val_pipeline``.
        The loading transforms ``GenerateSegmentIndices`` and
        ``LoadImageFromFile`` are removed because :meth:`forward` reads the
        frames itself and passes them to the pipeline directly.

        Args:
            video (InputsType): Path of the video file or frame directory to
                be interpolated.

        Returns:
            InputsType: ``video``, unchanged.
        """
        cfg = self.model.cfg
        if cfg.get('demo_pipeline', None):
            test_pipeline = cfg.demo_pipeline
        elif cfg.get('test_pipeline', None):
            test_pipeline = cfg.test_pipeline
        else:
            test_pipeline = cfg.val_pipeline

        # Frames are loaded in ``forward``; drop the data-loading transforms.
        skipped_types = ('GenerateSegmentIndices', 'LoadImageFromFile')
        test_pipeline = [
            transform for transform in test_pipeline
            if transform['type'] not in skipped_types
        ]
        self.test_pipeline = Compose(test_pipeline)
        return video

    def forward(self,
                inputs: InputsType,
                result_out_dir: InputsType = '') -> PredType:
        """Interpolate ``inputs`` and write the result to ``result_out_dir``.

        Args:
            inputs (InputsType): Path of the input video file, or of a
                directory containing the input frames.
            result_out_dir (str): Output path. If its extension is in
                ``VIDEO_EXTENSIONS`` a video is written there; otherwise
                frames are saved under it using ``filename_tmpl``.
                Defaults to ''.

        Returns:
            PredType: An empty dict; all results are written to disk.

        Raises:
            ValueError: If ``inputs`` is a directory that contains no files.
        """
        params = self.extra_parameters
        batch_size = params['batch_size']
        start_idx = params['start_idx']

        source, length, from_video, (h, w), output_fps = \
            self._open_source(inputs)

        # Keep the clipped end index local. ``self.extra_parameters`` is a
        # class-level dict, so writing it back would leak this input's
        # length into every later call and truncate longer inputs.
        end_idx = params['end_idx']
        end_idx = length if end_idx is None else min(end_idx, length)

        # Each step consumes ``length_per_step`` frames and advances by
        # ``step_size``; ``repeat_frame`` leading frames of every step after
        # the first were already produced by the previous step.
        step_size = self.model.step_frames * batch_size
        length_per_step = self.model.required_frames + \
            self.model.step_frames * (batch_size - 1)
        repeat_frame = self.model.required_frames - self.model.step_frames

        prog_bar = ProgressBar(
            math.ceil((end_idx + step_size - length_per_step - start_idx) /
                      step_size))
        filename_tmpl = params['filename_tmpl']
        output_index = start_idx

        target = self._open_target(result_out_dir, output_fps, (w, h))
        try:
            for start_index in range(start_idx, end_idx, step_size):
                images = read_frames(
                    source,
                    start_index,
                    length_per_step,
                    from_video,
                    end_index=end_idx)

                # data prepare
                data = dict(img=images, inputs_path=None, key=inputs)
                data = self.test_pipeline(data)['inputs'] / 255.0
                data = collate([data])  # data.shape: [1, t, c, h, w]

                # forward the model
                data = self.model.split_frames(data)
                input_tensors = data.detach().clone()
                with torch.no_grad():
                    output = self.model(data.to(self.device), mode='tensor')
                output_tensors = output.cpu()
                if output_tensors.dim() == 4:
                    # Restore the time axis for single-frame outputs.
                    output_tensors = output_tensors.unsqueeze(1)
                result = self.model.merge_frames(input_tensors,
                                                 output_tensors)
                if start_index != start_idx:
                    result = result[repeat_frame:]
                prog_bar.update()

                # save frames
                if target is not None:
                    for frame in result:
                        target.write(frame)
                else:
                    for frame in result:
                        save_path = osp.join(
                            result_out_dir, filename_tmpl.format(output_index))
                        mmcv.imwrite(frame, save_path)
                        output_index += 1

                if start_index + length_per_step >= end_idx:
                    break
        finally:
            # Release the writer even if inference fails part-way.
            if target is not None:
                target.release()

        logger: MMLogger = MMLogger.get_current_instance()
        logger.info(f'Output video is saved at {result_out_dir}.')
        return {}

    def _open_source(
            self, inputs: str
    ) -> Tuple[Union[object, List[str]], int, bool, Tuple[int, int], float]:
        """Open the input and determine its size and the output frame rate.

        Returns:
            tuple: ``(source, length, from_video, (h, w), output_fps)``.
            ``source`` is an ``mmcv.VideoReader`` for video inputs, or the
            sorted list of file paths in the directory otherwise.

        Raises:
            ValueError: If ``inputs`` is a directory that contains no files.
        """
        params = self.extra_parameters
        fps = params['fps']
        if osp.splitext(inputs)[1] in VIDEO_EXTENSIONS:
            source = mmcv.VideoReader(inputs)
            fps_multiplier = params['fps_multiplier']
            if fps_multiplier:
                assert fps_multiplier > 0, \
                    '`fps_multiplier` cannot be negative'
                output_fps = fps_multiplier * source.fps
            else:
                output_fps = fps if fps > 0 else source.fps * 2
            return (source, source.frame_cnt, True,
                    (source.height, source.width), output_fps)

        files = sorted(osp.join(inputs, f) for f in os.listdir(inputs))
        if not files:
            raise ValueError(f'No input frames found in {inputs!r}.')
        h, w = read_image(files[0]).shape[:2]
        output_fps = fps if fps > 0 else 60
        return files, len(files), False, (h, w), output_fps

    @staticmethod
    def _open_target(result_out_dir: str, output_fps: float,
                     size: Tuple[int, int]):
        """Create the output's parent directory and, for videos, a writer.

        Args:
            result_out_dir (str): Output video path or frame directory.
            output_fps (float): Frame rate of the output video.
            size (Tuple[int, int]): Output frame size as ``(w, h)``.

        Returns:
            cv2.VideoWriter | None: A writer when ``result_out_dir`` has a
            video extension, otherwise ``None`` (frames are saved as images).
        """
        mmengine.utils.mkdir_or_exist(osp.dirname(result_out_dir))
        if osp.splitext(result_out_dir)[1] not in VIDEO_EXTENSIONS:
            return None
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        return cv2.VideoWriter(result_out_dir, fourcc, output_fps, size)

    def visualize(self,
                  preds: PredType,
                  result_out_dir: str = '') -> List[np.ndarray]:
        """Visualize is not needed in this inferencer.

        Saving is done in :meth:`forward`; this only logs a message and
        returns ``None``.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        logger.info('Visualization is implemented in forward process.')
        return None

    def postprocess(
        self,
        preds: PredType,
        imgs: Optional[List[np.ndarray]] = None
    ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
        """Postprocess is not needed in this inferencer.

        Saving is done in :meth:`forward`; this only logs a message and
        returns ``None``.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        logger.info('Postprocess is implemented in forward process.')
        return None
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.