Source code for mmagic.datasets.transforms.loading
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional, Tuple
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmengine.fileio import get_file_backend, list_from_file
from mmagic.registry import TRANSFORMS
from mmagic.utils import (bbox2mask, brush_stroke_mask, get_irregular_mask,
random_bbox)
@TRANSFORMS.register_module()
[docs]class LoadImageFromFile(BaseTransform):
"""Load a single image or image frames from corresponding paths. Required
Keys:
- [Key]_path
New Keys:
- [KEY]
- ori_[KEY]_shape
- ori_[KEY]
Args:
key (str): Keys in results to find corresponding path.
color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
Defaults to 'color'.
channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
Default: 'bgr'.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
candidates are 'cv2', 'turbojpeg', 'pillow', and 'tifffile'.
Defaults to None.
use_cache (bool): If True, load all images at once. Default: False.
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
to_y_channel (bool): Whether to convert the loaded image to y channel.
Only support 'rgb2ycbcr' and 'rgb2ycbcr'
Defaults to False.
backend_args (dict, optional): Arguments to instantiate the prefix of
uri corresponding backend. Defaults to None.
"""
def __init__(
self,
key: str,
color_type: str = 'color',
channel_order: str = 'bgr',
imdecode_backend: Optional[str] = None,
use_cache: bool = False,
to_float32: bool = False,
to_y_channel: bool = False,
save_original_img: bool = False,
backend_args: Optional[dict] = None,
) -> None:
self.key = key
self.color_type = color_type
self.channel_order = channel_order
self.imdecode_backend = imdecode_backend
self.save_original_img = save_original_img
if backend_args is None:
# lasy init at loading
self.backend_args = None
self.file_backend = None
else:
self.backend_args = backend_args.copy()
self.file_backend = get_file_backend(backend_args=backend_args)
# cache
self.use_cache = use_cache
self.cache = dict()
# convert
self.to_float32 = to_float32
self.to_y_channel = to_y_channel
[docs] def transform(self, results: dict) -> dict:
"""Functions to load image or frames.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded image and meta information.
"""
filenames = results[f'{self.key}_path']
if not isinstance(filenames, (List, Tuple)):
filenames = [str(filenames)]
is_frames = False
else:
filenames = [str(v) for v in filenames]
is_frames = True
images = []
shapes = []
if self.save_original_img:
ori_imgs = []
for filename in filenames:
img = self._load_image(filename)
img = self._convert(img)
images.append(img)
shapes.append(img.shape)
if self.save_original_img:
ori_imgs.append(img.copy())
if not is_frames:
images = images[0]
shapes = shapes[0]
if self.save_original_img:
ori_imgs = ori_imgs[0]
results[self.key] = images
results[f'ori_{self.key}_shape'] = shapes
results[f'{self.key}_channel_order'] = self.channel_order
results[f'{self.key}_color_type'] = self.color_type
if self.save_original_img:
results[f'ori_{self.key}'] = ori_imgs
return results
[docs] def _load_image(self, filename):
"""Load an image from file.
Args:
filename (str): Path of image file.
Returns:
np.ndarray: Image.
"""
if self.file_backend is None:
self.file_backend = get_file_backend(
uri=filename, backend_args=self.backend_args)
if (self.backend_args is not None) and (self.backend_args.get(
'backend', None) == 'lmdb'):
filename, _ = osp.splitext(osp.basename(filename))
if filename in self.cache:
img_bytes = self.cache[filename]
else:
img_bytes = self.file_backend.get(filename)
if self.use_cache:
self.cache[filename] = img_bytes
img = mmcv.imfrombytes(
content=img_bytes,
flag=self.color_type,
channel_order=self.channel_order,
backend=self.imdecode_backend)
return img
[docs] def _convert(self, img: np.ndarray):
"""Convert an image to the require format.
Args:
img (np.ndarray): The original image.
Returns:
np.ndarray: The converted image.
"""
if self.to_y_channel:
if self.channel_order.lower() == 'rgb':
img = mmcv.rgb2ycbcr(img, y_only=True)
elif self.channel_order.lower() == 'bgr':
img = mmcv.bgr2ycbcr(img, y_only=True)
else:
raise ValueError('Currently support only "bgr2ycbcr" or '
'"bgr2ycbcr".')
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
if self.to_float32:
img = img.astype(np.float32)
return img
[docs] def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'key={self.key}, '
f'color_type={self.color_type}, '
f'channel_order={self.channel_order}, '
f'imdecode_backend={self.imdecode_backend}, '
f'use_cache={self.use_cache}, '
f'to_float32={self.to_float32}, '
f'to_y_channel={self.to_y_channel}, '
f'save_original_img={self.save_original_img}, '
f'backend_args={self.backend_args})')
return repr_str
@TRANSFORMS.register_module()
[docs]class LoadMask(BaseTransform):
"""Load Mask for multiple types.
For different types of mask, users need to provide the corresponding
config dict.
Example config for bbox:
.. code-block:: python
config = dict(img_shape=(256, 256), max_bbox_shape=128)
Example config for irregular:
.. code-block:: python
config = dict(
img_shape=(256, 256),
num_vertices=(4, 12),
max_angle=4.,
length_range=(10, 100),
brush_width=(10, 40),
area_ratio_range=(0.15, 0.5))
Example config for ff:
.. code-block:: python
config = dict(
img_shape=(256, 256),
num_vertices=(4, 12),
mean_angle=1.2,
angle_range=0.4,
brush_width=(12, 40))
Example config for set:
.. code-block:: python
config = dict(
mask_list_file='xxx/xxx/ooxx.txt',
prefix='/xxx/xxx/ooxx/',
io_backend='local',
color_type='unchanged',
file_client_kwargs=dict()
)
The mask_list_file contains the list of mask file name like this:
test1.jpeg
test2.jpeg
...
...
The prefix gives the data path.
Args:
mask_mode (str): Mask mode in ['bbox', 'irregular', 'ff', 'set',
'file']. Default: 'bbox'.
* bbox: square bounding box masks.
* irregular: irregular holes.
* ff: free-form holes from DeepFillv2.
* set: randomly get a mask from a mask set.
* file: get mask from 'mask_path' in results.
mask_config (dict): Params for creating masks. Each type of mask needs
different configs. Default: None.
"""
def __init__(self, mask_mode='bbox', mask_config=None):
self.mask_mode = mask_mode
self.mask_config = dict() if mask_config is None else mask_config
assert isinstance(self.mask_config, dict)
# set init info if needed in some modes
self._init_info()
[docs] def _init_info(self):
if self.mask_mode == 'set':
# get mask list information
self.io_backend = self.mask_config['io_backend']
self.color_type = self.mask_config['color_type']
self.file_prefix = self.mask_config['prefix']
self.file_client_kwargs = self.mask_config['file_client_kwargs']
self.file_backend = None
mask_list_file = self.mask_config['mask_list_file']
self.mask_list = list_from_file(
mask_list_file, backend_args=self.file_client_kwargs)
self.mask_list = [
osp.join(self.file_prefix, i) for i in self.mask_list
]
self.mask_set_size = len(self.mask_list)
elif self.mask_mode == 'file':
self.io_backend = 'local'
self.color_type = 'unchanged'
self.file_client_kwargs = dict()
self.file_backend = None
[docs] def _get_random_mask_from_set(self):
if self.file_backend is None:
self.file_backend = get_file_backend(
backend_args={'backend': self.io_backend})
# minus 1 to avoid out of range error
mask_idx = np.random.randint(0, self.mask_set_size)
mask_bytes = self.file_backend.get(self.mask_list[mask_idx])
mask = mmcv.imfrombytes(mask_bytes, flag=self.color_type) # HWC, BGR
if mask.ndim == 2:
mask = np.expand_dims(mask, axis=2)
else:
mask = mask[:, :, 0:1]
mask[mask > 0] = 1.
return mask
[docs] def _get_mask_from_file(self, path):
if self.file_backend is None:
backend_args = self.file_client_kwargs.copy()
backend_args['backend'] = self.io_backend
self.file_backend = get_file_backend(backend_args=backend_args)
mask_bytes = self.file_backend.get(path)
mask = mmcv.imfrombytes(mask_bytes, flag=self.color_type) # HWC, BGR
if mask.ndim == 2:
mask = np.expand_dims(mask, axis=2)
else:
mask = mask[:, :, 0:1]
mask[mask > 0] = 1.
return mask
[docs] def transform(self, results):
"""transform function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
if self.mask_mode == 'bbox':
mask_bbox = random_bbox(**self.mask_config)
mask = bbox2mask(self.mask_config['img_shape'], mask_bbox)
results['mask_bbox'] = mask_bbox
elif self.mask_mode == 'irregular':
mask = get_irregular_mask(**self.mask_config)
elif self.mask_mode == 'set':
mask = self._get_random_mask_from_set()
elif self.mask_mode == 'ff':
mask = brush_stroke_mask(**self.mask_config)
elif self.mask_mode == 'file':
mask = self._get_mask_from_file(results['mask_path'])
else:
raise NotImplementedError(
f'Mask mode {self.mask_mode} has not been implemented.')
results['mask'] = mask
return results
@TRANSFORMS.register_module()
[docs]class GetSpatialDiscountMask(BaseTransform):
"""Get spatial discounting mask constant.
Spatial discounting mask is first introduced in:
Generative Image Inpainting with Contextual Attention.
Args:
gamma (float, optional): Gamma for computing spatial discounting.
Defaults to 0.99.
beta (float, optional): Beta for computing spatial discounting.
Defaults to 1.5.
"""
def __init__(self, gamma=0.99, beta=1.5):
self.gamma = gamma
self.beta = beta
[docs] def spatial_discount_mask(self, mask_width, mask_height):
"""Generate spatial discounting mask constant.
Args:
mask_width (int): The width of bbox hole.
mask_height (int): The height of bbox height.
Returns:
np.ndarray: Spatial discounting mask.
"""
w, h = np.meshgrid(np.arange(mask_width), np.arange(mask_height))
grid_stack = np.stack([h, w], axis=2)
mask_values = (self.gamma**(np.minimum(
grid_stack, [mask_height - 1, mask_width - 1] - grid_stack) *
self.beta)).max(
axis=2, keepdims=True)
return mask_values
[docs] def transform(self, results):
"""transform function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
mask_bbox = results['mask_bbox']
mask = results['mask']
mask_height, mask_width = mask_bbox[-2:]
discount_hole = self.spatial_discount_mask(mask_width, mask_height)
discount_mask = np.zeros_like(mask)
discount_mask[mask_bbox[0]:mask_bbox[0] + mask_height,
mask_bbox[1]:mask_bbox[1] + mask_width,
...] = discount_hole
results['discount_mask'] = discount_mask
return results
f'beta={self.beta})')
@TRANSFORMS.register_module()
[docs]class LoadPairedImageFromFile(LoadImageFromFile):
"""Load a pair of images from file.
Each sample contains a pair of images, which are concatenated in the w
dimension (a|b). This is a special loading class for generation paired
dataset. It loads a pair of images as the common loader does and crops
it into two images with the same shape in different domains.
Required key is "pair_path". Added or modified keys are "pair",
"pair_ori_shape", "ori_pair", "img_{domain_a}", "img_{domain_b}",
"img_{domain_a}_path", "img_{domain_b}_path", "img_{domain_a}_ori_shape",
"img_{domain_b}_ori_shape", "ori_img_{domain_a}" and
"ori_img_{domain_b}".
Args:
key (str): Keys in results to find corresponding path.
domain_a (str, Optional): One of the paired image domain. Defaults
to 'A'.
domain_b (str, Optional): The other of the paired image domain.
Defaults to 'B'.
color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
Defaults to 'color'.
channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
Default: 'bgr'.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
candidates are 'cv2', 'turbojpeg', 'pillow', and 'tifffile'.
Defaults to None.
use_cache (bool): If True, load all images at once. Default: False.
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
to_y_channel (bool): Whether to convert the loaded image to y channel.
Only support 'rgb2ycbcr' and 'rgb2ycbcr'
Defaults to False.
backend_args (dict, optional): Arguments to instantiate the prefix of
uri corresponding backend. Defaults to None.
io_backend (str, optional): io backend where images are store. Defaults
to None.
"""
def __init__(self,
key: str,
domain_a: str = 'A',
domain_b: str = 'B',
color_type: str = 'color',
channel_order: str = 'bgr',
imdecode_backend: Optional[str] = None,
use_cache: bool = False,
to_float32: bool = False,
to_y_channel: bool = False,
save_original_img: bool = False,
backend_args: Optional[dict] = None):
super().__init__(key, color_type, channel_order, imdecode_backend,
use_cache, to_float32, to_y_channel,
save_original_img, backend_args)
assert isinstance(domain_a, str)
assert isinstance(domain_b, str)
self.domain_a = domain_a
self.domain_b = domain_b
[docs] def transform(self, results: dict) -> dict:
"""Functions to load paired images.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
filename = results[f'{self.key}_path']
image = self._load_image(filename)
image = self._convert(image)
if self.save_original_img:
ori_image = image.copy()
shape = image.shape
# crop pair into a and b
w = shape[1]
if w % 2 != 0:
raise ValueError(
f'The width of image pair must be even number, but got {w}.')
new_w = w // 2
image_a = image[:, :new_w, :]
image_b = image[:, new_w:, :]
results[f'img_{self.domain_a}'] = image_a
results[f'img_{self.domain_b}'] = image_b
results[f'img_{self.domain_a}_path'] = filename
results[f'img_{self.domain_b}_path'] = filename
results[f'img_{self.domain_a}_ori_shape'] = image_a.shape
results[f'img_{self.domain_b}_ori_shape'] = image_b.shape
if self.save_original_img:
results[f'ori_img_{self.domain_a}'] = image_a.copy()
results[f'ori_img_{self.domain_b}'] = image_b.copy()
results[self.key] = image
results[f'ori_{self.key}_shape'] = shape
results[f'{self.key}_channel_order'] = self.channel_order
results[f'{self.key}_color_type'] = self.color_type
if self.save_original_img:
results[f'ori_{self.key}'] = ori_image
return results