
Source code for mmagic.models.data_preprocessors.data_preprocessor

# Copyright (c) OpenMMLab. All rights reserved.
import math
from logging import WARNING
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine import print_log
from mmengine.model import ImgDataPreprocessor
from mmengine.utils import is_seq_of
from torch import Tensor

from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import SampleList

CastData = Union[tuple, dict, DataSample, Tensor, list]

@MODELS.register_module()
class DataPreprocessor(ImgDataPreprocessor):
    """Image pre-processor for generative models. This class provides
    normalization and BGR to RGB conversion for image tensor inputs.

    The input of this class should be a dict whose keys are `inputs` and
    `data_samples`. In addition to tensor `inputs`, this class also supports
    dict `inputs`.

    - If the value is a `Tensor` and the corresponding key is not contained
      in :attr:`_NON_IMAGE_KEYS`, it will be processed as an image tensor.
    - If the value is a `Tensor` and the corresponding key belongs to
      :attr:`_NON_IMAGE_KEYS`, it will remain unchanged.
    - If the value is a string or integer, it will remain unchanged.

    Args:
        mean (Sequence[float or int], float or int, optional): The pixel mean
            of image channels. Note that the normalization operation is
            performed *after channel order conversion*. If it is not
            specified, images will not be normalized. Defaults to 127.5.
        std (Sequence[float or int], float or int, optional): The pixel
            standard deviation of image channels. Note that the normalization
            operation is performed *after channel order conversion*. If it is
            not specified, images will not be normalized. Defaults to 127.5.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        pad_mode (str): Padding mode for ``torch.nn.functional.pad``.
            Defaults to 'constant'.
        non_image_keys (List[str] or str): Keys for fields that do not need
            to be processed (padding, channel conversion and normalization)
            as images. If not passed, the keys in :attr:`_NON_IMAGE_KEYS`
            will be used. This argument only works when `inputs` is a dict or
            a list of dicts. Defaults to None.
        non_concentate_keys (List[str] or str): Keys for fields that do not
            need to be concatenated. If not passed, the keys in
            :attr:`_NON_CONCATENATE_KEYS` will be used. This argument only
            works when `inputs` is a dict or a list of dicts. Defaults to
            None.
        output_channel_order (str, optional): The desired image channel order
            of the output of the data preprocessor. This is also the desired
            input channel order of the model (and most likely the output
            order of the model as well). If not passed, no channel order
            conversion will be performed. Defaults to None.
        data_keys (List[str] or str): Keys to preprocess in data samples.
            Defaults to 'gt_img'.
        input_view (tuple, optional): The view of the input tensor. This
            argument may be deleted in the future. Defaults to None.
        output_view (tuple, optional): The view of the output tensor. This
            argument may be deleted in the future. Defaults to None.
        stack_data_sample (bool): Whether to stack a list of data samples
            into one data sample. Only supported when the input data samples
            are `DataSample` objects. Defaults to True.
    """
    _NON_IMAGE_KEYS = ['noise']
    _NON_CONCATENATE_KEYS = ['num_batches', 'mode', 'sample_kwargs', 'eq_cfg']

    def __init__(self,
                 mean: Union[Sequence[Union[float, int]], float, int] = 127.5,
                 std: Union[Sequence[Union[float, int]], float, int] = 127.5,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mode: str = 'constant',
                 non_image_keys: Optional[Union[str, List[str]]] = None,
                 non_concentate_keys: Optional[Union[str, List[str]]] = None,
                 output_channel_order: Optional[str] = None,
                 data_keys: Union[List[str], str] = 'gt_img',
                 input_view: Optional[tuple] = None,
                 output_view: Optional[tuple] = None,
                 stack_data_sample=True):
        if not isinstance(mean, (list, tuple)) and mean is not None:
            mean = [mean]
        if not isinstance(std, (list, tuple)) and std is not None:
            std = [std]
        super().__init__(mean, std, pad_size_divisor, pad_value)

        # get channel order
        assert (output_channel_order is None
                or output_channel_order in ['RGB', 'BGR']), (
                    'Only support \'RGB\', \'BGR\' or None for '
                    '\'output_channel_order\', but received '
                    f'\'{output_channel_order}\'.')
        self.output_channel_order = output_channel_order

        # add user defined keys
        if non_image_keys is not None:
            if not isinstance(non_image_keys, list):
                non_image_keys = [non_image_keys]
            self._NON_IMAGE_KEYS += non_image_keys
        if non_concentate_keys is not None:
            if not isinstance(non_concentate_keys, list):
                non_concentate_keys = [non_concentate_keys]
            self._NON_CONCATENATE_KEYS += non_concentate_keys

        self.pad_mode = pad_mode
        self.pad_size_dict = dict()

        if data_keys is not None and not isinstance(data_keys, list):
            self.data_keys = [data_keys]
        else:
            self.data_keys = data_keys

        # TODO: can be removed since it is only used in LIIF
        self.input_view = input_view
        self.output_view = output_view

        self._done_padding = False  # flag for padding checking
        self._conversion_warning_raised = False  # flag for conversion warning

        self.stack_data_sample = stack_data_sample
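
    # A construction sketch (values here are illustrative, not taken from any
    # particular config): convert inputs to RGB and normalize pixel values
    # from [0, 255] to roughly [-1, 1].
    #
    #   >>> preprocessor = DataPreprocessor(
    #   ...     mean=127.5, std=127.5, output_channel_order='RGB')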

    def cast_data(self, data: CastData) -> CastData:
        """Copy data to the target device.

        Args:
            data (dict): Data returned by ``DataLoader``.

        Returns:
            CastData: Inputs and data samples on the target device.
        """
        if isinstance(data, (str, int, float)):
            return data
        return super().cast_data(data)

    @staticmethod
    def _parse_channel_index(inputs) -> int:
        """Parse the channel index of inputs."""
        channel_index_mapping = {2: 1, 3: 0, 4: 1, 5: 2}
        if isinstance(inputs, dict):
            ndim = inputs['fake_img'].ndim
            assert ndim in channel_index_mapping, (
                'Only support (H*W, C), (C, H, W), (N, C, H, W) or '
                '(N, t, C, H, W) inputs. But received '
                f'\'({inputs["fake_img"].shape})\'.')
            channel_index = channel_index_mapping[ndim]
        else:
            assert inputs.ndim in channel_index_mapping, (
                'Only support (H*W, C), (C, H, W), (N, C, H, W) or '
                '(N, t, C, H, W) inputs. But received '
                f'\'({inputs.shape})\'.')
            channel_index = channel_index_mapping[inputs.ndim]
        return channel_index
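
    # For example (illustrative), a (N, C, H, W) batch maps to channel
    # index 1:
    #
    #   >>> DataPreprocessor._parse_channel_index(torch.empty(2, 3, 8, 8))
    #   1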

    def _parse_channel_order(self,
                             key: str,
                             inputs: Tensor,
                             data_sample: Optional[DataSample] = None) -> str:
        """Parse the channel order of inputs from the data sample's metainfo,
        or infer it from the number of color channels."""
        channel_index = self._parse_channel_index(inputs)
        if isinstance(inputs, dict):
            num_color_channels = inputs['fake_img'].shape[channel_index]
        else:
            num_color_channels = inputs.shape[channel_index]

        # data sample is None, attempt to infer from the input tensor
        if data_sample is None:
            if num_color_channels == 1:
                return 'single'
            else:
                # default to BGR
                return 'BGR'

        # data sample is not None, infer from metainfo
        channel_order_key = 'gt_channel_order' if key == 'gt_img' \
            else f'{key}_channel_order'
        color_type_key = 'gt_color_type' if key == 'gt_img' \
            else f'{key}_color_type'

        # TODO: raise warning here, we can build a dict whose fields are keys
        # that have been parsed.
        color_flag = data_sample.metainfo.get(color_type_key, None)
        channel_order = data_sample.metainfo.get(channel_order_key, None)

        # handle stacked data samples, refer to `DataSample.stack`
        if isinstance(color_flag, list):
            assert all([c == color_flag[0] for c in color_flag])
            color_flag = color_flag[0]
        if isinstance(channel_order, list):
            assert all([c == channel_order[0] for c in channel_order])
            channel_order = channel_order[0]

        # NOTE: to handle inputs such as Y, users may modify the following
        # code
        if color_flag == 'grayscale':
            assert num_color_channels == 1
            return 'single'
        elif color_flag == 'unchanged':
            return 'single' if num_color_channels == 1 else 'BGR'
        else:
            # infer from channel_order
            if channel_order:
                return channel_order
            else:
                # no channel order, infer from the number of channels
                return 'single' if num_color_channels == 1 else 'BGR'
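
    # A minimal sketch of the metainfo-based parsing (illustrative; the
    # metainfo key follows the `gt_channel_order` convention used above):
    #
    #   >>> sample = DataSample(metainfo=dict(gt_channel_order='RGB'))
    #   >>> p = DataPreprocessor()
    #   >>> p._parse_channel_order('gt_img', torch.rand(3, 4, 4), sample)
    #   'RGB'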

    def _parse_batch_channel_order(self, key: str, inputs: Sequence,
                                   data_samples: Optional[Sequence[DataSample]]
                                   ) -> str:
        """Parse the channel order of inputs in a batch."""
        assert len(inputs) == len(data_samples)

        batch_inputs_orders = [
            self._parse_channel_order(key, inp, data_sample)
            for inp, data_sample in zip(inputs, data_samples)
        ]
        inputs_order = batch_inputs_orders[0]
        # sanity check for channel order
        assert all([
            inputs_order == order for order in batch_inputs_orders
        ]), (f'Channel order ({batch_inputs_orders}) of input targets '
             f'(\'{key}\') is inconsistent.')

        return inputs_order

    def _update_metainfo(self,
                         padding_info: Tensor,
                         channel_order_info: Optional[dict] = None,
                         data_samples: Optional[SampleList] = None
                         ) -> SampleList:
        """Update `padding_info` and `channel_order` to the metainfo of *a
        batch of `data_samples`*.

        For channel order, we assume the same field among data samples shares
        the same channel order. Therefore `channel_order_info` is passed as a
        dict, whose keys and values are field names and the corresponding
        channel orders respectively.

        For padding info, we assume the padding info is the same among all
        fields of a sample, but can vary between samples. Therefore, we pass
        `padding_info` as a Tensor of shape like (B, 3).

        Args:
            padding_info (Tensor): The padding info of each sample. Shape
                like (B, 3).
            channel_order_info (dict, optional): The channel order of target
                fields. Keys and values are field names and the corresponding
                channel orders respectively.
            data_samples (List[DataSample], optional): The data samples to be
                updated. If not passed, a list of empty data samples will be
                initialized. Defaults to None.

        Returns:
            List[DataSample]: The updated data samples.
        """
        n_samples = padding_info.shape[0]
        if data_samples is None:
            data_samples = [DataSample() for _ in range(n_samples)]
        else:
            assert len(data_samples) == n_samples, (
                f'The length of \'data_samples\'({len(data_samples)}) and '
                f'\'padding_info\'({n_samples}) are inconsistent. Please '
                'check your inputs.')

        # update padding info
        for pad_size, data_sample in zip(padding_info, data_samples):
            data_sample.set_metainfo({'padding_size': pad_size})

        # update channel order
        if channel_order_info is not None:
            for data_sample in data_samples:
                for key, channel_order in channel_order_info.items():
                    data_sample.set_metainfo(
                        {f'{key}_output_channel_order': channel_order})

        self._done_padding = padding_info.sum() != 0
        return data_samples

    def _do_conversion(self,
                       inputs: Tensor,
                       inputs_order: str = 'BGR',
                       target_order: Optional[str] = None
                       ) -> Tuple[Tensor, str]:
        """Conduct channel order conversion for *a batch of inputs*, and
        return the converted inputs and the order after conversion.

        inputs_order:
        * RGB / BGR: Convert to the target order.
        * SINGLE: Do not change.
        """
        if (target_order is None
                or inputs_order.upper() == target_order.upper()):
            # order is not changed, return the input one
            return inputs, inputs_order

        def conversion(inputs, channel_index):
            if inputs.shape[channel_index] == 4:
                new_index = [2, 1, 0, 3]
            else:
                new_index = [2, 1, 0]
            # do conversion
            inputs = torch.index_select(
                inputs, channel_index,
                torch.LongTensor(new_index).to(inputs.device))
            return inputs

        channel_index = self._parse_channel_index(inputs)
        if inputs_order.upper() in ['RGB', 'BGR']:
            inputs = conversion(inputs, channel_index)
            return inputs, target_order
        elif inputs_order.upper() == 'SINGLE':
            if not self._conversion_warning_raised:
                print_log(
                    'Cannot convert inputs with \'single\' channel order '
                    'to \'output_channel_order\' '
                    f'({self.output_channel_order}). '
                    'Return without conversion.', 'current', WARNING)
                self._conversion_warning_raised = True
            return inputs, inputs_order
        else:
            raise ValueError(f'Unsupported inputs order \'{inputs_order}\'.')
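
    # An illustrative BGR -> RGB flip on a (N, C, H, W) tensor; the channel
    # dim is re-indexed with [2, 1, 0]:
    #
    #   >>> p = DataPreprocessor()
    #   >>> x = torch.tensor([[[[1.]], [[2.]], [[3.]]]])  # (1, 3, 1, 1)
    #   >>> y, order = p._do_conversion(x, 'BGR', 'RGB')
    #   >>> y.flatten().tolist(), order
    #   ([3.0, 2.0, 1.0], 'RGB')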

    def _do_norm(self,
                 inputs: Tensor,
                 do_norm: Optional[bool] = None) -> Tensor:
        """Normalize inputs with `self.mean` and `self.std` if normalization
        is enabled."""
        do_norm = self._enable_normalize if do_norm is None else do_norm

        if do_norm:
            if self.input_view is None:
                if inputs.ndim == 2:
                    # special case for (H*W, C) tensor
                    target_shape = [1, -1]
                else:
                    target_shape = [1 for _ in range(inputs.ndim - 3)
                                    ] + [-1, 1, 1]
            else:
                target_shape = self.input_view
            mean = self.mean.view(target_shape)
            std = self.std.view(target_shape)

            # shape check to avoid broadcasting a single-channel tensor to
            # three channels
            channel_idx = self._parse_channel_index(inputs)
            n_channel_inputs = inputs.shape[channel_idx]
            n_channel_mean = mean.shape[channel_idx]
            n_channel_std = std.shape[channel_idx]
            assert n_channel_mean == 1 or n_channel_mean == n_channel_inputs
            assert n_channel_std == 1 or n_channel_std == n_channel_inputs

            inputs = (inputs - mean) / std

        return inputs
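
    # With the default mean=std=127.5, normalization maps [0, 255] to
    # [-1, 1] (illustrative):
    #
    #   >>> p = DataPreprocessor()
    #   >>> p._do_norm(torch.full((1, 3, 2, 2), 255.)).unique()
    #   tensor([1.])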

    def _preprocess_image_tensor(self,
                                 inputs: Tensor,
                                 data_samples: Optional[SampleList] = None,
                                 key: str = 'img'
                                 ) -> Tuple[Tensor, SampleList]:
        """Preprocess a batch of image tensors and update metainfo to the
        corresponding data samples.

        Args:
            inputs (Tensor): Image tensor with shape (C, H, W), (N, C, H, W)
                or (N, t, C, H, W) to preprocess.
            data_samples (List[DataSample], optional): The data samples of
                the corresponding inputs. If not passed, a list of empty data
                samples will be initialized to save metainfo. Defaults to
                None.
            key (str): The key of the image tensor in data samples. Defaults
                to 'img'.

        Returns:
            Tuple[Tensor, List[DataSample]]: The preprocessed image tensor
                and the updated data samples.
        """
        if not data_samples:  # none or empty list
            data_samples = [DataSample() for _ in range(inputs.shape[0])]

        assert inputs.dim() in [3, 4, 5], (
            'The input of `_preprocess_image_tensor` should be a (C, H, W), '
            '(N, C, H, W) or (N, t, C, H, W) tensor, but got a tensor with '
            f'shape: {inputs.shape}')

        channel_order = self._parse_batch_channel_order(
            key, inputs, data_samples)
        inputs, output_channel_order = self._do_conversion(
            inputs, channel_order, self.output_channel_order)
        inputs = self._do_norm(inputs)
        h, w = inputs.shape[-2:]
        target_h = math.ceil(
            h / self.pad_size_divisor) * self.pad_size_divisor
        target_w = math.ceil(
            w / self.pad_size_divisor) * self.pad_size_divisor
        pad_h = target_h - h
        pad_w = target_w - w
        batch_inputs = F.pad(inputs, (0, pad_w, 0, pad_h), self.pad_mode,
                             self.pad_value)

        padding_size = torch.FloatTensor((0, pad_h, pad_w))[None, ...]
        padding_size = padding_size.repeat(inputs.shape[0], 1)
        data_samples = self._update_metainfo(padding_size,
                                             {key: output_channel_order},
                                             data_samples)
        return batch_inputs, data_samples
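
    # Padding sketch (illustrative): with pad_size_divisor=16, a 30x30 input
    # is padded to 32x32 and the (0, pad_h, pad_w) info lands in each
    # sample's metainfo:
    #
    #   >>> p = DataPreprocessor(pad_size_divisor=16)
    #   >>> batch, samples = p._preprocess_image_tensor(
    #   ...     torch.rand(1, 3, 30, 30) * 255)
    #   >>> batch.shape, samples[0].metainfo['padding_size']
    #   (torch.Size([1, 3, 32, 32]), tensor([0., 2., 2.]))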

    def _preprocess_image_list(self,
                               tensor_list: List[Tensor],
                               data_samples: Optional[SampleList],
                               key: str = 'img'
                               ) -> Tuple[Tensor, SampleList]:
        """Preprocess a list of image tensors and update metainfo to the
        corresponding data samples.

        Args:
            tensor_list (List[Tensor]): Image tensor list to be preprocessed.
            data_samples (List[DataSample], optional): The data samples of
                the corresponding inputs. If not passed, a list of empty data
                samples will be initialized to save metainfo. Defaults to
                None.
            key (str): The key of the tensor list in data samples. Defaults
                to 'img'.

        Returns:
            Tuple[Tensor, List[DataSample]]: The preprocessed image tensor
                and the updated data samples.
        """
        if not data_samples:  # none or empty list
            data_samples = [DataSample() for _ in range(len(tensor_list))]

        channel_order = self._parse_batch_channel_order(
            key, tensor_list, data_samples)

        dim = tensor_list[0].dim()
        assert all([
            tensor.ndim == dim for tensor in tensor_list
        ]), ('Expected the dimensions of all tensors to be the same, '
             f'but got {[tensor.ndim for tensor in tensor_list]}')

        num_img = len(tensor_list)
        all_sizes: torch.Tensor = torch.Tensor(
            [tensor.shape for tensor in tensor_list])
        max_sizes = torch.ceil(
            torch.max(all_sizes, dim=0)[0] /
            self.pad_size_divisor) * self.pad_size_divisor
        padding_sizes = max_sizes - all_sizes
        # The dims of channel and frame index should not be padded.
        padding_sizes[:, :-2] = 0
        if padding_sizes.sum() == 0:
            stacked_tensor = torch.stack(tensor_list)
            stacked_tensor, output_channel_order = self._do_conversion(
                stacked_tensor, channel_order, self.output_channel_order)
            stacked_tensor = self._do_norm(stacked_tensor)
            data_samples = self._update_metainfo(padding_sizes,
                                                 {key: output_channel_order},
                                                 data_samples)
            return stacked_tensor, data_samples

        # `pad` is the second argument of `F.pad`. If pad is (1, 2, 3, 4), it
        # means padding the last dim with 1 (left) and 2 (right), and padding
        # the penultimate dim with 3 (top) and 4 (bottom). The order of `pad`
        # is opposite to that of `padding_sizes`. Therefore `padding_sizes`
        # needs to be reversed, and only the odd indices of `pad` should be
        # assigned to keep padding "right" and "bottom".
        pad = torch.zeros(num_img, 2 * dim, dtype=torch.int)
        pad[:, 1::2] = padding_sizes[:, range(dim - 1, -1, -1)]
        batch_tensor = []
        for idx, tensor in enumerate(tensor_list):
            padded_tensor = F.pad(tensor, tuple(pad[idx].tolist()),
                                  self.pad_mode, self.pad_value)
            batch_tensor.append(padded_tensor)
        stacked_tensor = torch.stack(batch_tensor)
        stacked_tensor, output_channel_order = self._do_conversion(
            stacked_tensor, channel_order, self.output_channel_order)
        stacked_tensor = self._do_norm(stacked_tensor)
        data_samples = self._update_metainfo(padding_sizes,
                                             {key: output_channel_order},
                                             data_samples)
        return stacked_tensor, data_samples

    def _preprocess_dict_inputs(self,
                                batch_inputs: dict,
                                data_samples: Optional[SampleList] = None
                                ) -> Tuple[dict, SampleList]:
        """Preprocess dict type inputs.

        Args:
            batch_inputs (dict): Input dict.
            data_samples (List[DataSample], optional): The data samples of
                the corresponding inputs. If not passed, a list of empty data
                samples will be initialized to save metainfo. Defaults to
                None.

        Returns:
            Tuple[dict, List[DataSample]]: The preprocessed dict and the
                updated data samples.
        """
        pad_size_dict = dict()
        for k, inputs in batch_inputs.items():
            # handle concatenation for values in a list
            if isinstance(inputs, list):
                if k in self._NON_CONCATENATE_KEYS:
                    # use the first value
                    assert all([
                        inputs[0] == inp for inp in inputs
                    ]), (f'NON_CONCATENATE_KEY \'{k}\' should be consistent '
                         'among the data list.')
                    batch_inputs[k] = inputs[0]
                else:
                    assert all([
                        isinstance(inp, torch.Tensor) for inp in inputs
                    ]), ('Only support stacking a list of Tensors in inputs '
                         f'dict. But \'{k}\' is a list of '
                         f'\'{type(inputs[0])}\'.')

                    if k not in self._NON_IMAGE_KEYS:
                        # preprocess as image
                        inputs, data_samples = self._preprocess_image_list(
                            inputs, data_samples, k)
                        pad_size_dict[k] = [
                            data.metainfo.get('padding_size')
                            for data in data_samples
                        ]
                    else:
                        # only stack
                        inputs = torch.stack(inputs)

                    batch_inputs[k] = inputs

            elif isinstance(inputs, Tensor) and k not in self._NON_IMAGE_KEYS:
                batch_inputs[k], data_samples = \
                    self._preprocess_image_tensor(inputs, data_samples, k)
                pad_size_dict[k] = [
                    data.metainfo.get('padding_size')
                    for data in data_samples
                ]

        # NOTE: we only support all keys sharing the same padding size
        if pad_size_dict:
            padding_sizes = list(pad_size_dict.values())[0]
            padding_key = list(pad_size_dict.keys())[0]
            for idx, tar_size in enumerate(padding_sizes):
                for k, sizes in pad_size_dict.items():
                    if (tar_size != sizes[idx]).any():
                        raise ValueError(
                            'All fields of a data sample should share the '
                            'same padding size, but got different sizes for '
                            f'\'{k}\' (\'{sizes[idx]}\') and '
                            f'\'{padding_key}\' (\'{tar_size}\') at index '
                            f'{idx}. Please check your data carefully.')

        return batch_inputs, data_samples
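
    # Dict-input sketch (illustrative): image-like fields are normalized and
    # padded, while `_NON_IMAGE_KEYS` entries such as 'noise' are only
    # stacked:
    #
    #   >>> p = DataPreprocessor()
    #   >>> batch = dict(noise=[torch.randn(4)] * 2,
    #   ...              img=[torch.rand(3, 8, 8) * 255] * 2)
    #   >>> out, samples = p._preprocess_dict_inputs(batch)
    #   >>> out['noise'].shape, out['img'].shape
    #   (torch.Size([2, 4]), torch.Size([2, 3, 8, 8]))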

    def _preprocess_data_sample(self, data_samples: SampleList,
                                training: bool) -> DataSample:
        """Preprocess data samples. When `training` is True, fields belonging
        to :attr:`self.data_keys` will be converted to
        :attr:`self.output_channel_order` and then normalized by `self.mean`
        and `self.std`. When `training` is False, we attempt to convert
        fields belonging to :attr:`self.data_keys` to 'BGR' without
        normalization. The metainfo related to normalization and channel
        order conversion will be updated to the data samples as well.

        Args:
            data_samples (List[DataSample]): A list of data samples to
                preprocess.
            training (bool): Whether in training mode.

        Returns:
            list or DataSample: The processed data samples, or a stacked
                data sample if ``stack_data_sample`` is True.
        """
        if not training:
            # set the default order to BGR at test time
            target_order, do_norm = 'BGR', False
        else:
            # normalize in training; use the default conversion order
            # (may be None)
            target_order, do_norm = self.output_channel_order, True

        for data_sample in data_samples:
            if not self.data_keys:
                break
            for key in self.data_keys:
                if not hasattr(data_sample, key):
                    # do not raise an error here
                    print_log(f'Cannot find key \'{key}\' in data sample.',
                              'current', WARNING)
                    break

                data = data_sample.get(key)
                data_channel_order = self._parse_channel_order(
                    key, data, data_sample)
                data, channel_order = self._do_conversion(
                    data, data_channel_order, target_order)
                data = self._do_norm(data, do_norm)
                data_sample.set_data({f'{key}': data})
                data_process_meta = {
                    f'{key}_enable_norm': self._enable_normalize,
                    f'{key}_output_channel_order': channel_order,
                    f'{key}_mean': self.mean,
                    f'{key}_std': self.std
                }
                data_sample.set_metainfo(data_process_meta)

        if self.stack_data_sample:
            assert is_seq_of(data_samples, DataSample), (
                'Only support \'stack_data_sample\' for DataSample '
                'objects. Please refer to \'DataSample.stack\'.')
            return DataSample.stack(data_samples)
        return data_samples

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and channel order conversion.

        Args:
            data (dict): Input data to process.
            training (bool): Whether in training mode. Default: False.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)
        _batch_inputs = data['inputs']
        _batch_data_samples = data.get('data_samples', None)

        # process input
        if isinstance(_batch_inputs, torch.Tensor):
            _batch_inputs, _batch_data_samples = \
                self._preprocess_image_tensor(
                    _batch_inputs, _batch_data_samples)
        elif is_seq_of(_batch_inputs, torch.Tensor):
            _batch_inputs, _batch_data_samples = \
                self._preprocess_image_list(
                    _batch_inputs, _batch_data_samples)
        elif isinstance(_batch_inputs, dict):
            _batch_inputs, _batch_data_samples = \
                self._preprocess_dict_inputs(
                    _batch_inputs, _batch_data_samples)
        elif is_seq_of(_batch_inputs, dict):
            # convert a list of dicts to a dict of lists
            keys = _batch_inputs[0].keys()
            dict_input = {k: [inp[k] for inp in _batch_inputs] for k in keys}
            _batch_inputs, _batch_data_samples = \
                self._preprocess_dict_inputs(
                    dict_input, _batch_data_samples)
        else:
            raise ValueError('Only support the following input types: '
                             '\'torch.Tensor\', \'List[torch.Tensor]\', '
                             '\'dict\', \'List[dict]\'. But received '
                             f'\'{type(_batch_inputs)}\'.')
        data['inputs'] = _batch_inputs

        # process data samples
        if _batch_data_samples:
            _batch_data_samples = self._preprocess_data_sample(
                _batch_data_samples, training)
        data['data_samples'] = _batch_data_samples

        return data
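
    # End-to-end forward sketch (illustrative): a plain tensor batch is
    # normalized to [-1, 1] and data samples are created on the fly;
    # data_keys=None skips the gt_img lookup for this toy batch.
    #
    #   >>> p = DataPreprocessor(data_keys=None)
    #   >>> data = p.forward(dict(inputs=torch.full((2, 3, 4, 4), 255.)))
    #   >>> data['inputs'].unique()
    #   tensor([1.])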

    def destruct(self,
                 outputs: Tensor,
                 data_samples: Union[SampleList, DataSample, None] = None,
                 key: str = 'img') -> Union[list, Tensor]:
        """Destruct padding and normalization, and convert the channel order
        back to BGR if possible.

        If `data_samples` is a list, outputs will be destructed as a batch of
        tensors. If `data_samples` is a `DataSample`, `outputs` will be
        destructed as a single tensor.

        Before feeding model outputs to the visualizer and evaluator, users
        should call this function on model outputs and inputs.

        Use cases:

        >>> # destruct model outputs.
        >>> # model outputs share the same preprocess information with inputs
        >>> # ('img') therefore use 'img' as key
        >>> feats = self.forward_tensor(inputs, data_samples, **kwargs)
        >>> feats = self.data_preprocessor.destruct(
        >>>     feats, data_samples, 'img')

        >>> # destruct model inputs for visualization
        >>> for idx, data_sample in enumerate(data_samples):
        >>>     destructed_input = self.data_preprocessor.destruct(
        >>>         inputs[idx], data_sample, key='img')
        >>>     data_sample.set_data({'input': destructed_input})

        Args:
            outputs (Tensor): Tensor to destruct.
            data_samples (Union[SampleList, DataSample], optional): Data
                samples (or a data sample) corresponding to `outputs`.
                Defaults to None.
            key (str): The key of the field in data samples. Defaults to
                'img'.

        Returns:
            Union[list, Tensor]: Destructed outputs.
        """
        # NOTE: only support passing tensor samples. If the output of the
        # model is a dict, users should call this manually, since we do not
        # know whether the outputs are image tensors.
        _batch_outputs = self._destruct_norm_and_conversion(
            outputs, data_samples, key)
        _batch_outputs = self._destruct_padding(_batch_outputs, data_samples)
        _batch_outputs = _batch_outputs.clamp_(0, 255)
        return _batch_outputs

    def _destruct_norm_and_conversion(self, batch_tensor: Tensor,
                                      data_samples: Union[SampleList,
                                                          DataSample, None],
                                      key: str) -> Tensor:
        """De-normalize and de-convert the channel order. Note that we
        de-normalize first and then de-convert, since the mean and std used
        in normalization are based on the channel order after conversion.

        Args:
            batch_tensor (Tensor): Tensor to destruct.
            data_samples (Union[SampleList, DataSample], optional): Data
                samples (or a data sample) corresponding to `outputs`.
            key (str): The key of the field in data samples.

        Returns:
            Tensor: Destructed tensor.
        """
        output_key = f'{key}_output'
        # get channel order from data samples
        if isinstance(data_samples, list):
            inputs_order = self._parse_batch_channel_order(
                output_key, batch_tensor, data_samples)
        else:
            inputs_order = self._parse_channel_order(output_key, batch_tensor,
                                                     data_samples)
        if self._enable_normalize:
            if self.output_view is None:
                if batch_tensor.ndim == 2:
                    # special case for (H*W, C) tensor
                    target_shape = [1, -1]
                else:
                    target_shape = [1 for _ in range(batch_tensor.ndim - 3)
                                    ] + [-1, 1, 1]
            else:
                target_shape = self.output_view
            mean = self.mean.view(target_shape)
            std = self.std.view(target_shape)
            batch_tensor = batch_tensor * std + mean

        # convert the output to 'BGR' if possible
        batch_tensor, _ = self._do_conversion(
            batch_tensor, inputs_order=inputs_order, target_order='BGR')

        return batch_tensor

    def _destruct_padding(self,
                          batch_tensor: Tensor,
                          data_samples: Union[SampleList, DataSample, None],
                          same_padding: bool = True) -> Union[list, Tensor]:
        """Destruct padding of the input tensor.

        Args:
            batch_tensor (Tensor): Tensor to destruct.
            data_samples (Union[SampleList, DataSample], optional): Data
                samples (or a data sample) corresponding to `outputs`.
            same_padding (bool): Whether all samples are un-padded with the
                padding info of the first sample, returning a stacked
                un-padded tensor. Otherwise, each sample is un-padded with
                the padding info saved in the corresponding data sample,
                returning a list of un-padded tensors, since the un-padded
                tensors may have different shapes. Defaults to True.

        Returns:
            Union[list, Tensor]: Destructed outputs.
        """
        if data_samples is None:
            return batch_tensor

        if isinstance(data_samples, list):
            is_batch_data = True
            if 'padding_size' in data_samples[0].metainfo_keys():
                pad_infos = [
                    sample.metainfo['padding_size']
                    for sample in data_samples
                ]
            else:
                pad_infos = None
        else:
            if 'padding_size' in data_samples.metainfo_keys():
                pad_infos = data_samples.metainfo['padding_size']
            else:
                pad_infos = None
            # NOTE: here we assume padding sizes in metainfo are saved as
            # tensors
            if not isinstance(pad_infos, list):
                pad_infos = [pad_infos]
                is_batch_data = False
            else:
                is_batch_data = True
        if all([pad_info is None for pad_info in pad_infos]):
            pad_infos = None

        if not is_batch_data:
            batch_tensor = batch_tensor[None, ...]

        if pad_infos is None:
            if self._done_padding:
                print_log(
                    'Cannot find padding information (\'padding_size\') in '
                    'the meta info of \'data_samples\'. Please check whether '
                    'you have called \'self.forward\' properly.', 'current',
                    WARNING)
            return batch_tensor if is_batch_data else batch_tensor[0]

        if same_padding:
            # un-pad with the padding info of the first sample
            padded_h, padded_w = pad_infos[0][-2:]
            padded_h, padded_w = int(padded_h), int(padded_w)
            h, w = batch_tensor.shape[-2:]
            batch_tensor = batch_tensor[..., :h - padded_h, :w - padded_w]
            return batch_tensor if is_batch_data else batch_tensor[0]
        else:
            # un-pad with the corresponding padding info
            unpadded_tensors = []
            for idx, pad_info in enumerate(pad_infos):
                padded_h, padded_w = pad_info[-2:]
                padded_h = int(padded_h)
                padded_w = int(padded_w)
                h, w = batch_tensor[idx].shape[-2:]
                unpadded_tensor = batch_tensor[idx][..., :h - padded_h,
                                                    :w - padded_w]
                unpadded_tensors.append(unpadded_tensor)
            return unpadded_tensors if is_batch_data else unpadded_tensors[0]
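
# A round-trip sketch (illustrative, not part of the module): `forward` pads
# and normalizes, and `destruct` restores the original value range and size.
#
#   >>> p = DataPreprocessor(pad_size_divisor=16, data_keys=None)
#   >>> data = p.forward(dict(inputs=torch.rand(2, 3, 30, 30) * 255))
#   >>> data['inputs'].shape  # padded to a multiple of 16
#   torch.Size([2, 3, 32, 32])
#   >>> restored = p.destruct(data['inputs'], data['data_samples'])
#   >>> restored.shape  # padding removed
#   torch.Size([2, 3, 30, 30])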