How to design your own data transforms

In this tutorial, we introduce the design of the transform pipeline in MMagic.

The structure of this guide is as follows:

- Data pipelines in MMagic
- Supported transforms in MMagic
- Extend and use custom pipelines

Data pipelines in MMagic

Following typical conventions, we use Dataset and DataLoader for data loading with multiple workers. Dataset returns a dict of data items corresponding to the arguments of the model's forward method.

The data preparation pipeline and the dataset are decoupled. Usually, a dataset defines how to process the annotations, and a data pipeline defines all the steps to prepare a data dict.

A pipeline consists of a sequence of operations. Each operation takes a dict as input and outputs a dict for the next transform.

The operations are categorized into data loading, pre-processing, and formatting.

In MMagic, all data transforms inherit from BaseTransform. The input and output of each transform are both dicts.
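
For example, a whole pipeline can be built from a list of such transform configs and then called on a data dict. The snippet below is only a sketch based on MMEngine's Compose and registry (it is not taken from the MMagic docs); the image path is a placeholder, and the init_default_scope call may be unnecessary depending on how you launch your code.

from mmengine.dataset import Compose
from mmengine.registry import init_default_scope

init_default_scope('mmagic')  # assumption: make MMagic's registered transforms resolvable

# Each config dict is built into a callable transform; Compose applies them
# in order, taking a dict as input and returning a dict.
pipeline = Compose([
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='PackInputs'),
])
data = pipeline(dict(img_path='./demo.png'))  # './demo.png' is a placeholder path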

A simple example of data transform

>>> from mmagic.transforms import LoadPairedImageFromFile
>>> transforms = LoadPairedImageFromFile(
>>>     key='pair',
>>>     domain_a='horse',
>>>     domain_b='zebra',
>>>     flag='color')
>>> data_dict = {'pair_path': './data/pix2pix/facades/train/1.png'}
>>> data_dict = transforms(data_dict)
>>> print(data_dict.keys())
dict_keys(['pair_path', 'pair', 'pair_ori_shape', 'img_mask', 'img_photo', 'img_mask_path', 'img_photo_path', 'img_mask_ori_shape', 'img_photo_ori_shape'])

Generally, the last step of the transform pipeline must be PackInputs. PackInputs packs the processed data into a dict containing two fields: inputs and data_samples. inputs is the variable you want to use as the model's input, which can be a torch.Tensor, a dict of torch.Tensor, or any other type you need. data_samples is a list of DataSample; each DataSample contains the ground truth and the necessary information for the corresponding input.
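
To make this concrete, here is a small sketch (our own, not from the MMagic docs) that builds PackInputs through the registry and packs the data_dict produced above, using the keys printed in the previous example; the exact contents of the packed dict may vary across MMagic versions.

from mmagic.registry import TRANSFORMS

# Build PackInputs from a config dict and pack the loaded images.
pack = TRANSFORMS.build(
    dict(type='PackInputs', keys=['img_mask', 'img_photo']))
packed = pack(data_dict)
print(packed.keys())  # expected: dict_keys(['inputs', 'data_samples'])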

An example of BasicVSR

Here is a pipeline example for BasicVSR.

train_pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='SetValues', dictionary=dict(scale=scale)),  # scale is defined elsewhere in the config
    dict(type='PairedRandomCrop', gt_patch_size=256),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal'),
    dict(
        type='Flip', keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'),
    dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
    dict(type='MirrorSequence', keys=['img', 'gt']),
    dict(type='PackInputs')
]

val_pipeline = [
    dict(type='GenerateSegmentIndices', interval_list=[1]),
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='MirrorSequence', keys=['img']),
    dict(type='PackInputs')
]
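
For completeness, a pipeline like this is normally plugged into the dataset part of the dataloader config. The snippet below is only a sketch: the dataset type, sampler, and paths are assumptions for illustration and are not part of the BasicVSR example above.

train_dataloader = dict(
    num_workers=4,
    batch_size=4,
    sampler=dict(type='InfiniteSampler', shuffle=True),  # assumed sampler
    dataset=dict(
        type='BasicFramesDataset',  # assumed dataset class for video frames
        data_root='./data/REDS',    # placeholder path
        pipeline=train_pipeline))   # the pipeline defined above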

For each transform, the tables in the "Supported transforms in MMagic" section below list the related dict fields that are added/updated/removed; the fields marked with '*' are optional.

An example of Pix2Pix

Here is a pipeline example for Pix2Pix training on the aerial2maps dataset. The TransformBroadcaster wrapper applies the same Resize and FixedCrop to both domains, and share_random_params=True ensures the two images are cropped at the same random position.

domain_a = 'aerial'
domain_b = 'map'

pipeline = [
    dict(
        type='LoadPairedImageFromFile',
        io_backend='disk',
        key='pair',
        domain_a=domain_a,
        domain_b=domain_b,
        flag='color'),
    dict(
        type='TransformBroadcaster',
        mapping={'img': [f'img_{domain_a}', f'img_{domain_b}']},
        auto_remap=True,
        share_random_params=True,
        transforms=[
            dict(
                type='mmagic.Resize', scale=(286, 286),
                interpolation='bicubic'),
            dict(type='mmagic.FixedCrop', crop_size=(256, 256))
        ]),
    dict(
        type='Flip',
        keys=[f'img_{domain_a}', f'img_{domain_b}'],
        direction='horizontal'),
    dict(
        type='PackInputs',
        keys=[f'img_{domain_a}', f'img_{domain_b}', 'pair'])
]

Supported transforms in MMagic

Data loading

| Transform | Modification of Results' keys |
| --- | --- |
| LoadImageFromFile | add: img, img_path, img_ori_shape, *ori_img |
| RandomLoadResizeBg | add: bg |
| LoadMask | add: mask |
| GetSpatialDiscountMask | add: discount_mask |

Pre-processing

| Transform | Modification of Results' keys |
| --- | --- |
| Resize | add: scale_factor, keep_ratio, interpolation, backend; update: specified by keys |
| MATLABLikeResize | add: scale, output_shape; update: specified by keys |
| RandomRotation | add: degrees; update: specified by keys |
| Flip | add: flip, flip_direction; update: specified by keys |
| RandomAffine | update: specified by keys |
| RandomJitter | update: fg (img) |
| ColorJitter | update: specified by keys |
| BinarizeImage | update: specified by keys |
| RandomMaskDilation | add: img_dilate_kernel_size |
| RandomTransposeHW | add: transpose |
| RandomDownSampling | update: scale, gt (img), lq (img) |
| RandomBlur | update: specified by keys |
| RandomResize | update: specified by keys |
| RandomNoise | update: specified by keys |
| RandomJPEGCompression | update: specified by keys |
| RandomVideoCompression | update: specified by keys |
| DegradationsWithShuffle | update: specified by keys |
| GenerateFrameIndices | update: img_path (gt_path, lq_path) |
| GenerateFrameIndiceswithPadding | update: img_path (gt_path, lq_path) |
| TemporalReverse | add: reverse; update: specified by keys |
| GenerateSegmentIndices | add: interval; update: img_path (gt_path, lq_path) |
| MirrorSequence | update: specified by keys |
| CopyValues | add: specified by dst_key |
| UnsharpMasking | add: img_unsharp |
| Crop | add: img_crop_bbox, crop_size; update: specified by keys |
| RandomResizedCrop | add: img_crop_bbox; update: specified by keys |
| FixedCrop | add: crop_size, crop_pos; update: specified by keys |
| PairedRandomCrop | update: gt (img), lq (img) |
| CropAroundCenter | add: crop_bbox; update: fg (img), alpha (img), trimap (img), bg (img) |
| CropAroundUnknown | add: crop_bbox; update: specified by keys |
| CropAroundFg | add: crop_bbox; update: specified by keys |
| ModCrop | update: gt (img) |
| CropLike | update: specified by target_key |
| GetMaskedImage | add: masked_img |
| GenerateFacialHeatmap | add: heatmap |
| GenerateCoordinateAndCell | add: coord, cell; update: gt (img) |
| Normalize | add: img_norm_cfg; update: specified by keys |
| RescaleToZeroOne | update: specified by keys |

Formatting

| Transform | Modification of Results' keys |
| --- | --- |
| ToTensor | update: specified by keys |
| FormatTrimap | update: trimap |
| PackInputs | add: inputs, data_sample; remove: all other keys |

Albumentations

MMagic supports adding custom transformations from the Albumentations library (make sure the albumentations package is installed). Please visit https://albumentations.ai/docs/getting_started/transforms_and_targets for more information.

An example of Albumentations transforms is as follows:

albu_transforms = [
    dict(type='Resize', height=100, width=100),
    dict(type='RandomFog', p=0.5),
    dict(type='RandomRain', p=0.5),
    dict(type='RandomSnow', p=0.5),
]

pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2'),
    dict(
        type='Albumentations',
        keys=['img'],
        transforms=albu_transforms),
    dict(type='PackInputs')
]

Extend and use custom pipelines

A simple example of MyTransform

  1. Write a new pipeline in a file, e.g., my_pipeline.py. It takes a dict as input and returns a dict.

import random
from mmcv.transforms import BaseTransform
from mmagic.registry import TRANSFORMS


@TRANSFORMS.register_module()
class MyTransform(BaseTransform):
    """Add your transform

    Args:
        p (float): Probability of shifts. Default 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def transform(self, results):
        # Add a dummy key with probability `p`.
        if random.random() < self.p:
            results['dummy'] = True
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(p={self.p})'
        return repr_str

  2. Import and use the pipeline in your config file.

Make sure my_pipeline.py can be imported from where your training script runs (i.e., it is on the import path).

train_pipeline = [
    ...
    dict(type='MyTransform', p=0.2),
    ...
]
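
Since a transform is just a callable that maps a dict to a dict, you can also sanity-check MyTransform directly; a quick sketch (the 'img' entry is only a placeholder):

transform = MyTransform(p=1.0)
results = transform(dict(img=None))  # any dict works here
print(results['dummy'])  # True, since p=1.0 always triggers the change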

An example of flipping

Here we use a simple flipping transformation as an example:

import mmcv
from mmcv.transforms import BaseTransform, TRANSFORMS

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

Thus, we can instantiate a MyFlip object and use it to process the data dict.

import numpy as np

transform = MyFlip(direction='horizontal')
data_dict = {'img': np.random.rand(224, 224, 3)}
data_dict = transform(data_dict)
processed_img = data_dict['img']

Or, we can use the MyFlip transform in the data pipeline of our config file.

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]

Note that if you want to use MyFlip in a config, you must ensure that the file containing MyFlip is imported at runtime so that the transform gets registered.
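
One common way to guarantee this is MMEngine's custom_imports field in the config; a sketch, assuming MyFlip lives in a module named my_flip that is on your import path:

custom_imports = dict(imports=['my_flip'], allow_failed_imports=False)

pipeline = [
    ...
    dict(type='MyFlip', direction='horizontal'),
    ...
]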
