Shortcuts

Source code for mmagic.datasets.transforms.albu_function

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import albumentations as albu
from mmcv.transforms import BaseTransform

from mmagic.registry import TRANSFORMS


@TRANSFORMS.register_module()
class PairedAlbuTransForms(BaseTransform):
    """PairedAlbuTransForms augmentation.

    Apply the same AlbuTransforms augmentation to paired images. The GT
    image is registered as an additional ``'target'`` image target of a
    single ``albu.Compose`` call, so both images go through the pipeline
    together (scope augmentation -> pad -> crop).

    Args:
        size (int): Height and width used for padding and cropping.
        lq_key (str): Key of the low-quality image in ``results``.
            Default: 'img'.
        gt_key (str): Key of the ground-truth image in ``results``.
            Default: 'gt'.
        scope (str): Augmentation preset, ``'weak'`` or ``'geometric'``.
            Any other value raises ``KeyError``. Default: 'geometric'.
        crop (str): Crop mode, ``'random'`` or ``'center'``. Any other
            value raises ``KeyError``. Default: 'random'.
        p (float): Probability passed to the container (``Compose`` for
            'weak', ``OneOf`` for 'geometric') of the scope augmentation.
            Default: 0.5.
    """

    def __init__(self,
                 size: int,
                 lq_key: str = 'img',
                 gt_key: str = 'gt',
                 scope: str = 'geometric',
                 crop: str = 'random',
                 p: float = 0.5):
        self.size = size
        self.lq_key = lq_key
        self.gt_key = gt_key
        self.scope = scope
        self.crop = crop
        self.p = p

        side = self.size

        # Both presets are built up front; the one named by ``scope`` is
        # selected afterwards.
        weak_aug = albu.Compose([albu.HorizontalFlip()], p=self.p)
        geometric_aug = albu.OneOf(
            [
                albu.HorizontalFlip(always_apply=True),
                albu.ShiftScaleRotate(always_apply=True),
                albu.Transpose(always_apply=True),
                albu.OpticalDistortion(always_apply=True),
                albu.ElasticTransform(always_apply=True),
            ],
            p=self.p)
        scope_aug = {'weak': weak_aug, 'geometric': geometric_aug}[self.scope]

        croppers = {
            'random': albu.RandomCrop(side, side, always_apply=True),
            'center': albu.CenterCrop(side, side, always_apply=True),
        }
        cropper = croppers[self.crop]

        # Pad before cropping so inputs smaller than ``size`` can be cropped.
        padder = albu.PadIfNeeded(side, side)
        self.pipeline = albu.Compose(
            [scope_aug, padder, cropper],
            additional_targets={'target': 'image'})

    def transform(self, results):
        """Run ``self.pipeline`` jointly on the LQ and GT images.

        Args:
            results (dict): contains the processed data through the
                transform pipeline. Must hold ``lq_key`` and ``gt_key``.

        Returns:
            dict: ``results`` with both images replaced by their
            augmented versions.
        """
        out = self.pipeline(
            image=results[self.lq_key], target=results[self.gt_key])
        results.update({
            self.lq_key: out['image'],
            self.gt_key: out['target'],
        })
        return results

    def __repr__(self):
        fields = (
            f'size={self.size}',
            f'lq_key={self.lq_key}',
            f'gt_key={self.gt_key}',
            f'scope={self.scope}',
            f'crop={self.crop}',
            f'p={self.p}',
        )
        return self.__class__.__name__ + '(' + ', '.join(fields) + ')'
@TRANSFORMS.register_module()
class AlbuTransForms(BaseTransform):
    """AlbuTransForms augmentation.

    Apply an albumentations pipeline (scope augmentation -> pad -> crop)
    to every image listed in ``keys``.

    Note:
        The pipeline is invoked once per key, so random decisions (whether
        to flip, which geometric op is chosen, crop position) are sampled
        independently for each key. Use ``PairedAlbuTransForms`` when two
        images must receive identical random parameters.

    Args:
        size (int): Height and width used for padding and cropping.
        keys (List): Keys of the images in ``results`` to augment.
        scope (str): Augmentation preset, ``'weak'`` or ``'geometric'``.
            Any other value raises ``KeyError``. Default: 'geometric'.
        crop (str): Crop mode, ``'random'`` or ``'center'``. Any other
            value raises ``KeyError``. Default: 'random'.
        p (float): Probability passed to the container (``Compose`` for
            'weak', ``OneOf`` for 'geometric') of the scope augmentation.
            Default: 0.5.
    """

    def __init__(self,
                 size: int,
                 keys: List,
                 scope: str = 'geometric',
                 crop: str = 'random',
                 p: float = 0.5):
        self.size = size
        self.keys = keys
        self.scope = scope
        self.crop = crop
        self.p = p

        augs = {
            # ``p`` is forwarded here too (it was previously dropped for
            # the 'weak' preset), matching PairedAlbuTransForms so that the
            # ``p`` argument is honoured for every scope.
            'weak':
            albu.Compose([
                albu.HorizontalFlip(),
            ], p=self.p),
            'geometric':
            albu.OneOf([
                albu.HorizontalFlip(always_apply=True),
                albu.ShiftScaleRotate(always_apply=True),
                albu.Transpose(always_apply=True),
                albu.OpticalDistortion(always_apply=True),
                albu.ElasticTransform(always_apply=True),
            ],
                       p=self.p)
        }

        aug_fn = augs[self.scope]
        crop_fn = {
            'random': albu.RandomCrop(self.size, self.size, always_apply=True),
            'center': albu.CenterCrop(self.size, self.size, always_apply=True)
        }[self.crop]
        # Pad before cropping so inputs smaller than ``size`` can be cropped.
        pad = albu.PadIfNeeded(self.size, self.size)

        self.pipeline = albu.Compose([aug_fn, pad, crop_fn])

    def transform(self, results):
        """Process every image in ``self.keys`` with ``self.pipeline``.

        Args:
            results (dict): contains the processed data through the
                transform pipeline. Must hold every key in ``self.keys``.

        Returns:
            dict: ``results`` with each listed image replaced by its
            augmented version.
        """
        for key in self.keys:
            r = self.pipeline(image=results[key])
            results[key] = r['image']

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(size={self.size}, '
                     f'keys={self.keys}, '
                     f'scope={self.scope}, '
                     f'crop={self.crop}, '
                     f'p={self.p})')

        return repr_str
@TRANSFORMS.register_module()
class PairedAlbuNormalize(BaseTransform):
    """PairedAlbuNormalize augmentation.

    Apply the same albumentations ``Normalize`` to a pair of images. The GT
    image is registered as an additional ``'target'`` image target, so both
    images go through one ``albu.Compose`` call. The GT image is optional
    at call time.

    Args:
        lq_key (str): Key of the low-quality image in ``results``.
        gt_key (str): Key of the ground-truth image in ``results``.
        mean (tuple): ``mean`` forwarded to ``albu.Normalize``.
            Default: (0.5, 0.5, 0.5).
        std (tuple): ``std`` forwarded to ``albu.Normalize``.
            Default: (0.5, 0.5, 0.5).
        max_pixel_value (float): ``max_pixel_value`` forwarded to
            ``albu.Normalize``. Default: 255.0.
        always_apply (bool): ``always_apply`` forwarded to
            ``albu.Normalize``. Default: False.
        p (float): ``p`` forwarded to ``albu.Normalize``. Default: 1.0.
    """

    def __init__(self,
                 lq_key: str,
                 gt_key: str,
                 mean=(0.5, 0.5, 0.5),
                 std=(0.5, 0.5, 0.5),
                 max_pixel_value: float = 255.0,
                 always_apply: bool = False,
                 p: float = 1.0):
        self.lq_key, self.gt_key = lq_key, gt_key
        self.mean, self.std = mean, std
        self.max_pixel_value = max_pixel_value
        self.always_apply = always_apply
        self.p = p

        self.normalize = albu.Compose(
            [
                albu.Normalize(
                    mean=self.mean,
                    std=self.std,
                    max_pixel_value=self.max_pixel_value,
                    always_apply=self.always_apply,
                    p=self.p)
            ],
            additional_targets={'target': 'image'})

    def transform(self, results):
        """Normalize the LQ image and, when present, the GT image.

        Args:
            results (dict): contains the processed data through the
                transform pipeline. Must hold ``lq_key``; ``gt_key`` is
                optional.

        Returns:
            dict: ``results`` with the normalized image(s) written back.
        """
        inputs = {'image': results[self.lq_key]}
        has_gt = self.gt_key in results
        if has_gt:
            inputs['target'] = results[self.gt_key]

        out = self.normalize(**inputs)

        if has_gt:
            results[self.gt_key] = out['target']
        results[self.lq_key] = out['image']
        return results

    def __repr__(self):
        parts = [
            f'lq_key={self.lq_key}',
            f'gt_key={self.gt_key}',
            f'mean={self.mean}',
            f'std={self.std}',
            f'max_pixel_value={self.max_pixel_value}',
            f'always_apply={self.always_apply}',
            f'p={self.p}',
        ]
        return f"{type(self).__name__}({', '.join(parts)}) "
@TRANSFORMS.register_module()
class AlbuNormalize(BaseTransform):
    """AlbuNormalize augmentation.

    Apply albumentations ``Normalize`` to every image listed in ``keys``,
    one call per key.

    Args:
        keys (List): Keys of the images in ``results`` to normalize.
        mean (tuple): ``mean`` forwarded to ``albu.Normalize``.
            Default: (0.5, 0.5, 0.5).
        std (tuple): ``std`` forwarded to ``albu.Normalize``.
            Default: (0.5, 0.5, 0.5).
        max_pixel_value (float): ``max_pixel_value`` forwarded to
            ``albu.Normalize``. Default: 255.0.
        always_apply (bool): ``always_apply`` forwarded to
            ``albu.Normalize``. Default: False.
        p (float): ``p`` forwarded to ``albu.Normalize``. Default: 1.0.
    """

    def __init__(self,
                 keys: List,
                 mean=(0.5, 0.5, 0.5),
                 std=(0.5, 0.5, 0.5),
                 max_pixel_value: float = 255.0,
                 always_apply: bool = False,
                 p: float = 1.0):
        self.keys = keys
        self.mean, self.std = mean, std
        self.max_pixel_value = max_pixel_value
        self.always_apply = always_apply
        self.p = p

        self.normalize = albu.Compose([
            albu.Normalize(
                mean=self.mean,
                std=self.std,
                max_pixel_value=self.max_pixel_value,
                always_apply=self.always_apply,
                p=self.p)
        ])

    def transform(self, results):
        """Normalize each image in ``self.keys`` in place.

        Args:
            results (dict): contains the processed data through the
                transform pipeline. Must hold every key in ``self.keys``.

        Returns:
            dict: ``results`` with each listed image normalized.
        """
        for name in self.keys:
            results[name] = self.normalize(image=results[name])['image']
        return results

    def __repr__(self):
        parts = [
            f'keys={self.keys}',
            f'mean={self.mean}',
            f'std={self.std}',
            f'max_pixel_value={self.max_pixel_value}',
            f'always_apply={self.always_apply}',
            f'p={self.p}',
        ]
        return f"{type(self).__name__}({', '.join(parts)}) "
[docs]def _resolve_aug_fn(name): d = { 'cutout': albu.Cutout, 'rgb_shift': albu.RGBShift, 'hsv_shift': albu.HueSaturationValue, 'motion_blur': albu.MotionBlur, 'median_blur': albu.MedianBlur, 'snow': albu.RandomSnow, 'shadow': albu.RandomShadow, 'fog': albu.RandomFog, 'brightness_contrast': albu.RandomBrightnessContrast, 'gamma': albu.RandomGamma, 'sun_flare': albu.RandomSunFlare, 'sharpen': albu.Sharpen, 'jpeg': albu.ImageCompression, 'gray': albu.ToGray, 'pixelize': albu.Downscale, # ToDo: partial gray } return d[name]
@TRANSFORMS.register_module()
class AlbuCorruptFunction(BaseTransform):
    """AlbuCorruptFunction augmentation.

    Build one albumentations transform per entry of ``config`` and wrap
    them in ``albu.OneOf``. On each call, that ``OneOf`` is applied
    separately to every image listed in ``keys``.

    Args:
        keys (List[str]): Keys of the images in ``results`` to corrupt.
        config (List[dict]): One dict per candidate corruption. Each dict
            must contain ``'name'`` (resolved by ``_resolve_aug_fn``). It
            may contain ``'prob'``, which is passed as the transform's
            ``p`` (default 0.5). All remaining items are forwarded as
            keyword arguments to the transform class. The dicts are not
            modified.
        p (float): Probability passed to the enclosing ``albu.OneOf``.
            Default: 1.0.
    """

    def __init__(self, keys: List[str], config: List[dict], p: float = 1.0):
        self.keys = keys
        self.config = config
        self.p = p

        augs = []
        for aug_cfg in self.config:
            # Pop from a shallow copy: mutating the caller's dict would strip
            # 'name'/'prob' from ``self.config`` (corrupting ``__repr__``).
            # Re-building from the same config would also fail with
            # KeyError: 'name'.
            params = dict(aug_cfg)
            name = params.pop('name')
            cls = _resolve_aug_fn(name)
            prob = params.pop('prob', .5)
            augs.append(cls(p=prob, **params))

        self.augs = albu.OneOf(augs, p=self.p)

    def transform(self, results):
        """Apply ``self.augs`` to every image in ``self.keys``.

        Args:
            results (dict): contains the processed data through the
                transform pipeline. Must hold every key in ``self.keys``.

        Returns:
            dict: ``results`` with each listed image replaced by its
            (possibly) corrupted version.
        """
        for key in self.keys:
            results[key] = self.augs(image=results[key])['image']

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(keys={self.keys}, '
                     f'config={self.config}, '
                     f'p={self.p}) ')

        return repr_str
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.