mmagic.models

Package Contents

Classes

BaseConditionalGAN

Base class for Conditional GAN models.

BaseEditModel

Base model for image and video editing.

BaseGAN

Base class for GAN models.

BaseMattor

Base class for trimap-based matting models.

BaseTranslationModel

Base Translation Model.

BasicInterpolator

Basic model for video interpolation.

ExponentialMovingAverage

Implements the exponential moving average (EMA) of the model.

DataPreprocessor

Image pre-processor for generative models.

MattorPreprocessor

DataPreprocessor for matting models.

class mmagic.models.BaseConditionalGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Bases: mmagic.models.base_models.base_gan.BaseGAN

Base class for Conditional GAN models.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.

  • num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

  • loss_config (Optional[Dict]) – Loss config used to build loss modules or define the loss weights. Defaults to None.

label_fn(label: mmagic.utils.typing.LabelVar = None, num_batches: int = 1) torch.Tensor

Sampling function for label. There are three scenarios in this function:

  • If label is a callable function, use it to sample num_batches of labels.

  • If label is None, sample num_batches of labels uniformly in the range [0, self.num_classes-1].

  • If label is a torch.Tensor, check that all its values are in [0, self.num_classes-1]. If they are, directly return label.

Parameters
  • label (Union[Tensor, Callable, List[int], None]) – You can directly give a batch of labels through a torch.Tensor or offer a callable function to sample a batch of labels. Otherwise, None indicates to use the default label sampler. Defaults to None.

  • num_batches (int, optional) – The number of labels to sample (the batch size). If label is a Tensor, this will be ignored. Defaults to 1.

Returns

Sampled label tensor.

Return type

Tensor
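The three label-sampling scenarios can be sketched in plain Python as below. This is not the mmagic implementation: lists stand in for torch.Tensor, sample_labels is a hypothetical name, and the callable is assumed to receive num_batches and return that many labels.

```python
import random
from typing import Callable, List, Sequence, Union

# Hypothetical stand-in for label_fn: plain lists replace torch.Tensor, and the
# callable is assumed to take num_batches and return that many labels.
LabelVar = Union[None, Callable[[int], List[int]], Sequence[int]]


def sample_labels(label: LabelVar, num_batches: int,
                  num_classes: int) -> List[int]:
    """Sketch of the three label-sampling scenarios of label_fn."""
    if callable(label):
        # Callable: delegate sampling to the user-provided function.
        return list(label(num_batches))
    if label is None:
        # None: sample uniformly from [0, num_classes - 1] (randint is inclusive).
        return [random.randint(0, num_classes - 1) for _ in range(num_batches)]
    # Explicit labels: check the range, then return them unchanged.
    labels = list(label)
    if not all(0 <= x <= num_classes - 1 for x in labels):
        raise ValueError(f'labels must lie in [0, {num_classes - 1}]')
    return labels


random.seed(0)
print(sample_labels(None, num_batches=4, num_classes=10))  # four ints in [0, 9]
```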

data_sample_to_label(data_sample: mmagic.structures.DataSample) Optional[torch.Tensor]

Get labels from the input data_sample and pack them into a torch.Tensor. If no label is found in the passed data_sample, None is returned.

Parameters

data_sample (DataSample) – Input data samples.

Returns

Packed label tensor.

Return type

Optional[torch.Tensor]

static _get_valid_num_classes(num_classes: Optional[int], generator: ModelType, discriminator: Optional[ModelType]) int

Try to get the value of num_classes from input, generator and discriminator and check the consistency of these values. If no conflict is found, return the num_classes.

Parameters
  • num_classes (Optional[int]) – num_classes passed to BaseConditionalGAN’s initializer.

  • generator (ModelType) – The config or the model of generator.

  • discriminator (Optional[ModelType]) – The config or model of discriminator.

Returns

The number of classes to be generated.

Return type

int
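A minimal sketch of the consistency check, assuming the three sources have already been reduced to plain integers (or None). resolve_num_classes is a hypothetical helper, and raising an error when no source defines num_classes is an assumption rather than documented behavior.

```python
from typing import Optional


def resolve_num_classes(from_init: Optional[int],
                        from_generator: Optional[int],
                        from_discriminator: Optional[int]) -> int:
    """Hypothetical sketch: merge num_classes from three sources and check consistency."""
    sources = {
        'init': from_init,
        'generator': from_generator,
        'discriminator': from_discriminator,
    }
    # Sources that do not define num_classes are ignored.
    given = {name: value for name, value in sources.items() if value is not None}
    if not given:
        raise ValueError('num_classes is not defined by any source')
    if len(set(given.values())) > 1:
        raise ValueError(f'conflicting num_classes: {given}')
    return next(iter(given.values()))


print(resolve_num_classes(None, 10, 10))  # 10
```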

forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) List[mmagic.structures.DataSample]

Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.

Parameters
  • inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate image.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.

Returns

Generated images or image dict.

Return type

List[DataSample]

train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Training function for generator. All GANs should implement this function by themselves.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[DataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Training function for discriminator. All GANs should implement this function by themselves.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[DataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

class mmagic.models.BaseEditModel(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)

Bases: mmengine.model.BaseModel

Base model for image and video editing.

It must contain a generator that takes images or frames as inputs and produces the edited outputs. It also has a pixel-wise loss for training.

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialization config for BaseModule. Defaults to None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

init_cfg

Initialization config dict.

Type

dict, optional

data_preprocessor

Used for pre-processing data sampled by dataloader to the format accepted by forward(). Default: None.

Type

BaseDataPreprocessor

forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) Union[torch.Tensor, List[mmagic.structures.DataSample], dict]

Returns losses or predictions of training, validation, testing, and simple inference process.

The forward method of BaseModel is an abstract method, so its subclasses must implement it.

Accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.

During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step.

During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • mode (str) –

    mode should be one of loss, predict and tensor. Default: ‘tensor’.

    • loss: Called by train_step and return loss dict used for logging

    • predict: Called by val_step and test_step and return list of BaseDataElement results used for computing metric.

    • tensor: Called by custom use to get Tensor type results.

Returns

  • If mode == loss, return a dict of loss tensors used for backward and logging.

  • If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.

  • If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

ForwardResults

convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) List[mmagic.structures.DataSample]

Add predictions and destructed inputs (if passed) to data samples.

Parameters
  • predictions (DataSample) – The predictions of the model.

  • data_samples (DataSample) – The data samples loaded from dataloader.

  • inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.

Returns

Modified data samples.

Return type

List[DataSample]

forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) torch.Tensor

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Result of simple forward.

Return type

Tensor

forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) mmagic.structures.DataSample

Forward inference. Returns predictions of validation, testing, and simple inference.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Predictions.

Return type

DataSample

forward_train(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) Dict[str, torch.Tensor]

Forward training. Returns dict of losses of training.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Dict of losses.

Return type

dict

class mmagic.models.BaseGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)

Bases: mmengine.model.BaseModel

Base class for GAN models.

Parameters
  • generator (ModelType) – The config or model of the generator.

  • discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.

  • data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.

  • generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

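  • noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.

  • ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.

  • loss_config (Optional[Dict]) – Loss config used to build loss modules or define the loss weights. Defaults to None.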

property generator_steps: int

The number of times the generator is completely updated before the discriminator is updated.

Type

int

property discriminator_steps: int

The number of times the discriminator is completely updated before the generator is updated.

Type

int

property device: torch.device

Get current device of the model.

Returns

The current device of the model.

Return type

torch.device

property with_ema_gen: bool

Whether the GAN adopts exponential moving average.

Returns

If True, this GAN model adopts exponential moving average for the generator; otherwise it does not.

Return type

bool

static gather_log_vars(log_vars_list: List[Dict[str, torch.Tensor]]) Dict[str, torch.Tensor]

Gather a list of log_vars.

Parameters

log_vars_list (List[Dict[str, Tensor]]) – The list of log_vars dicts to gather.

Returns

The gathered log_vars.

Return type

Dict[str, Tensor]
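The summary above does not state how several log_vars dicts are combined. The sketch below assumes a per-key mean over the list, with plain floats standing in for tensors; it is an illustration, not the library code.

```python
from statistics import mean
from typing import Dict, List


def gather_log_vars(log_vars_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Sketch: average each logged value over several steps (assumed strategy)."""
    if len(log_vars_list) == 1:
        return log_vars_list[0]
    gathered = {}
    for key in log_vars_list[0]:
        # Every step is expected to log the same keys.
        assert all(key in log_vars for log_vars in log_vars_list), f'{key!r} missing'
        gathered[key] = mean(log_vars[key] for log_vars in log_vars_list)
    return gathered


print(gather_log_vars([{'loss_gen': 1.0}, {'loss_gen': 3.0}]))  # {'loss_gen': 2.0}
```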

_init_loss(loss_config: Optional[Dict] = None) None

Initialize customized loss modules.

loss_config accepts three kinds of values:

  1. loss_config is None: Users implement all loss calculations in their own functions, and the weight of each loss term is hard-coded.

  2. loss_config is a dict of scalars or strings: Users implement all loss calculations and use the passed loss_config to control the weights or behavior of the loss calculation. Users unpack and use each field of this dict by themselves.

    loss_config = dict(gp_norm_mode='HWC', gp_loss_weight=10)

  3. loss_config is a dict of dicts: Each field in loss_config is used to build a corresponding loss module, and the loss is calculated with the loss calculation functions predefined by BaseGAN.

    loss_config = dict()

Examples

    loss_config = dict(
        # BaseGAN pre-defined fields
        gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
        disc_auxiliary_loss=dict(
            type='R1GradientPenalty',
            loss_weight=10. / 2.,
            interval=2,
            norm_mode='HWC',
            data_info=dict(
                real_data='real_imgs',
                discriminator='disc')),
        gen_auxiliary_loss=dict(
            type='GeneratorPathRegularizer',
            loss_weight=2,
            pl_batch_shrink=2,
            interval=g_reg_interval,
            data_info=dict(
                generator='gen',
                num_batches='batch_size')),
        # user-defined field for loss weights or loss calculation
        my_loss_2=dict(weight=2, norm_mode='L1'),
        my_loss_3=2,
        my_loss_4_norm_type='L2')

Parameters

loss_config (Optional[Dict], optional) – Loss config used to build loss modules or define the loss weights. Defaults to None.
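In the example above, the first three fields are marked as BaseGAN pre-defined and the rest as user-defined. A hypothetical helper that separates the two groups might look like this; treating exactly gan_loss, disc_auxiliary_loss and gen_auxiliary_loss as pre-defined is an assumption taken from the example, not a documented rule.

```python
from typing import Any, Dict, Optional, Tuple

# Field names taken from the example; treating exactly these three as
# "pre-defined" is an assumption of this sketch.
PREDEFINED_FIELDS = ('gan_loss', 'disc_auxiliary_loss', 'gen_auxiliary_loss')


def split_loss_config(
    loss_config: Optional[Dict[str, Any]]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate fields BaseGAN would build into modules from user-defined ones."""
    if loss_config is None:
        # Scenario 1: no config, loss weights are hard-coded by the user.
        return {}, {}
    predefined = {k: v for k, v in loss_config.items() if k in PREDEFINED_FIELDS}
    user_defined = {k: v for k, v in loss_config.items() if k not in PREDEFINED_FIELDS}
    return predefined, user_defined


cfg = dict(gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'), my_loss_3=2)
print(split_loss_config(cfg)[1])  # {'my_loss_3': 2}
```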

noise_fn(noise: mmagic.utils.typing.NoiseVar = None, num_batches: int = 1)

Sampling function for noise. There are three scenarios in this function:

  • If noise is a callable function, use it to sample num_batches of noise.

  • If noise is None, sample num_batches of noise from a Gaussian distribution.

  • If noise is a torch.Tensor, directly return noise.

Parameters
  • noise (Union[Tensor, Callable, List[int], None]) – You can directly give a batch of noise through a torch.Tensor or offer a callable function to sample a batch of noise data. Otherwise, None indicates to use the default noise sampler. Defaults to None.

  • num_batches (int, optional) – The number of noise samples to draw (the batch size). If noise is a Tensor, this will be ignored. Defaults to 1.

Returns

Sampled noise tensor.

Return type

Tensor
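The noise-sampling scenarios mirror those of label sampling. In this sketch NumPy replaces torch, sample_noise is a hypothetical name, and the callable is assumed to receive the output shape (num_batches, noise_size).

```python
from typing import Callable, Optional, Tuple, Union

import numpy as np

NoiseVar = Union[None, np.ndarray, Callable[[Tuple[int, int]], np.ndarray]]


def sample_noise(noise: NoiseVar, num_batches: int, noise_size: int,
                 rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Sketch of the three noise-sampling scenarios, using NumPy instead of torch."""
    if callable(noise):
        # Callable: the user's sampler receives the desired output shape.
        return noise((num_batches, noise_size))
    if noise is None:
        # None: draw from a standard Gaussian distribution.
        rng = rng if rng is not None else np.random.default_rng()
        return rng.standard_normal((num_batches, noise_size))
    # Explicit array: returned as-is; num_batches is ignored.
    return noise


print(sample_noise(None, 4, 8, np.random.default_rng(0)).shape)  # (4, 8)
```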

_init_ema_model(ema_config: dict)

Initialize an EMA model corresponding to the given ema_config. If ema_config is an empty dict or None, the EMA model will not be initialized.

Parameters

ema_config (dict) – Config to initialize the EMA model.

_get_valid_model(batch_inputs: mmagic.utils.typing.ForwardInputs) str

Try to get the valid forward model from inputs.

  • If forward model is defined in batch_inputs, it will be used as forward model.

  • If forward model is not defined in batch_inputs, ‘ema’ will be returned if with_ema_gen is True. Otherwise, ‘orig’ will be returned.

Parameters

batch_inputs (ForwardInputs) – Inputs passed to forward().

Returns

Forward model to generate image (‘orig’, ‘ema’ or ‘ema/orig’).

Return type

str
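The selection rule can be sketched as follows. The key name sample_model used to read the forward model from batch_inputs is an assumption; the rest follows the two bullet points above.

```python
from typing import Any, Dict

VALID_MODELS = ('orig', 'ema', 'ema/orig')


def get_valid_model(batch_inputs: Dict[str, Any], with_ema_gen: bool) -> str:
    """Sketch of forward-model selection; 'sample_model' is an assumed key name."""
    model = batch_inputs.get('sample_model')
    if model is None:
        # Fall back to the EMA generator when one is available.
        model = 'ema' if with_ema_gen else 'orig'
    assert model in VALID_MODELS, f'Unknown forward model: {model}'
    return model


print(get_valid_model({}, with_ema_gen=True))  # ema
```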

forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) mmagic.utils.typing.SampleList

Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.

Parameters
  • inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate image.

  • data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.

  • mode (Optional[str]) – mode is not used in BaseGAN. Defaults to None.

Returns

A list of DataSample containing generated results.

Return type

SampleList

val_step(data: dict) mmagic.utils.typing.SampleList

Gets the generated image of given data.

Calls self.data_preprocessor(data) and self(inputs, data_sample, mode=None) in order. Return the generated results which will be passed to evaluator.

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

SampleList

test_step(data: dict) mmagic.utils.typing.SampleList

Gets the generated image of given data. Same as val_step().

Parameters

data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.

Returns

Generated image or image dict.

Return type

List[DataSample]

train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train GAN model. In the training of GAN models, generator and discriminator are updated alternately. In MMagic’s design, self.train_step is called with data input. Therefore the discriminator, whose update relies on real data, is always updated, and whether the generator needs to be updated is then determined based on the current number of iterations. More details about whether to update generator can be found in should_gen_update().

Parameters
  • data (dict) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
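Under the reading that the discriminator is updated on every call and the generator only after the discriminator has completed discriminator_steps updates, the schedule looks like the sketch below. update_schedule is a hypothetical helper, and gradient accumulation is ignored.

```python
from typing import List, Tuple


def update_schedule(num_iters: int, discriminator_steps: int = 1,
                    generator_steps: int = 1) -> List[Tuple[int, int]]:
    """Return (discriminator_updates, generator_updates) for each iteration.

    Assumption: the discriminator is updated on every call, and the generator is
    updated generator_steps times once every discriminator_steps iterations.
    """
    schedule = []
    for curr_iter in range(num_iters):
        disc_updates = 1
        # curr_iter is 0-based, so (curr_iter + 1) counts completed iterations.
        if (curr_iter + 1) % discriminator_steps == 0:
            gen_updates = generator_steps
        else:
            gen_updates = 0
        schedule.append((disc_updates, gen_updates))
    return schedule


print(update_schedule(4, discriminator_steps=2))  # [(1, 0), (1, 1), (1, 0), (1, 1)]
```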

_get_gen_loss(out_dict)

_get_disc_loss(out_dict)

train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Training function for generator. All GANs should implement this function by themselves.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[DataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor]

Training function for discriminator. All GANs should implement this function by themselves.

Parameters
  • inputs (dict) – Inputs from dataloader.

  • data_samples (List[DataSample]) – Data samples from dataloader.

  • optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, Tensor]

class mmagic.models.BaseMattor(data_preprocessor: Union[dict, mmengine.config.Config], backbone: dict, init_cfg: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None)

Bases: mmengine.model.BaseModel

Base class for trimap-based matting models.

A matting model must contain a backbone which produces pred_alpha, a dense prediction with the same height and width as the input image. In some cases (such as DIM), the model has a refiner which refines the prediction of the backbone.

Subclasses should override the following functions:

  • _forward_train(), to return a loss

  • _forward_test(), to return a prediction

  • _forward(), to return raw tensors

For testing, this base class provides functions to resize inputs and post-process pred_alphas to get predictions.

Parameters
  • backbone (dict) – Config of backbone.

  • data_preprocessor (dict) – Config of data_preprocessor. See MattorPreprocessor for details.

  • init_cfg (dict, optional) – Initialization config dict.

  • train_cfg (dict) – Config of training, customized by subclasses. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.

  • test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.

resize_inputs(batch_inputs: torch.Tensor) torch.Tensor

Pad or interpolate images and trimaps to a multiple of a given factor.
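Padding to a multiple of a factor can be sketched with NumPy as below. Padding only on the bottom and right, and the name pad_to_multiple, are assumptions; the interpolation alternative is not shown.

```python
import numpy as np


def pad_to_multiple(batch: np.ndarray, divisor: int,
                    pad_value: float = 0) -> np.ndarray:
    """Pad an (N, C, H, W) array on the bottom/right so H and W divide by divisor."""
    _, _, h, w = batch.shape
    # The outer "% divisor" keeps the padding at 0 when the size already fits.
    pad_h = (divisor - h % divisor) % divisor
    pad_w = (divisor - w % divisor) % divisor
    # Only the spatial dimensions are padded.
    return np.pad(batch, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)),
                  mode='constant', constant_values=pad_value)


print(pad_to_multiple(np.zeros((1, 4, 30, 45)), 32).shape)  # (1, 4, 32, 64)
```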

restore_size(pred_alpha: torch.Tensor, data_sample: mmagic.structures.DataSample) torch.Tensor

Restore the predicted alpha to the original shape.

The shape of the predicted alpha may not be the same as the shape of original input image. This function restores the shape of the predicted alpha.

Parameters
  • pred_alpha (torch.Tensor) – A single predicted alpha of shape (1, H, W).

  • data_sample (DataSample) – Data sample containing original shape as meta data.

Returns

The reshaped predicted alpha.

Return type

torch.Tensor

postprocess(batch_pred_alpha: torch.Tensor, data_samples: mmagic.structures.DataSample) List[mmagic.structures.DataSample]

Post-process alpha predictions.

This function contains the following steps:
  1. Restore padding or interpolation

  2. Mask alpha prediction with trimap

  3. Clamp alpha prediction to 0-1

  4. Convert alpha prediction to uint8

  5. Pack alpha prediction into DataSample

Currently only batch_size 1 is actually supported.

Parameters
  • batch_pred_alpha (torch.Tensor) – A batch of predicted alpha of shape (N, 1, H, W).

  • data_samples (List[DataSample]) – List of data samples.

Returns

A list of predictions.

Each data sample contains a pred_alpha, which is a torch.Tensor with dtype=uint8, device=cuda:0.

Return type

List[DataSample]
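Steps 2 to 4 can be sketched with NumPy for a single (H, W) alpha map. The trimap convention (0 for background, 255 for foreground, anything else unknown) is an assumption, and steps 1 and 5 are omitted.

```python
import numpy as np


def postprocess_alpha(pred_alpha: np.ndarray, trimap: np.ndarray) -> np.ndarray:
    """Sketch: mask with the trimap, clamp to [0, 1], convert to uint8."""
    alpha = pred_alpha.astype(np.float32)  # astype returns a copy
    alpha[trimap == 255] = 1.0  # known foreground (assumed convention)
    alpha[trimap == 0] = 0.0  # known background (assumed convention)
    alpha = np.clip(alpha, 0.0, 1.0)
    return np.round(alpha * 255).astype(np.uint8)


pred = np.array([[0.4, 1.2], [-0.1, 0.3]])
trimap = np.array([[128, 128], [255, 0]], dtype=np.uint8)
print(postprocess_alpha(pred, trimap).tolist())  # [[102, 255], [255, 0]]
```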

forward(inputs: torch.Tensor, data_samples: DataSamples = None, mode: str = 'tensor') List[mmagic.structures.DataSample]

General forward function.

Parameters
  • inputs (torch.Tensor) – A batch of inputs, with image and trimap concatenated along the channel dimension.

  • data_samples (List[DataSample], optional) – A list of data samples, containing:

    • Ground-truth alpha / foreground / background to compute loss

    • Other meta information

  • mode (str) –

    mode should be one of loss, predict and tensor. Default: ‘tensor’.

    • loss: Called by train_step and return loss dict used for logging

    • predict: Called by val_step and test_step and return list of BaseDataElement results used for computing metric.

    • tensor: Called by custom use to get Tensor type results.

Returns

Sequence of predictions packed into DataElement.

Return type

List[DataElement]
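A short illustration of the expected input layout, with NumPy arrays standing in for tensors. Placing the image channels before the trimap channel is an assumption.

```python
import numpy as np

# Hypothetical batch of 2 RGB images and their single-channel trimaps (N, C, H, W).
images = np.zeros((2, 3, 64, 64), dtype=np.float32)
trimaps = np.zeros((2, 1, 64, 64), dtype=np.float32)

# Concatenate along the channel dimension (axis 1 in N, C, H, W layout).
inputs = np.concatenate([images, trimaps], axis=1)
print(inputs.shape)  # (2, 4, 64, 64)
```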

convert_to_datasample(predictions: List[mmagic.structures.DataSample], data_samples: mmagic.structures.DataSample) List[mmagic.structures.DataSample]

Add predictions to data samples.

Parameters
  • predictions (List[DataSample]) – The predictions of the model.

  • data_samples (DataSample) – The data samples loaded from dataloader.

Returns

Modified data samples.

Return type

List[DataSample]

class mmagic.models.BaseTranslationModel(generator, discriminator, default_domain: str, reachable_domains: List[str], related_domains: List[str], data_preprocessor, discriminator_steps: int = 1, disc_init_steps: int = 0, real_img_key: str = 'real_img', loss_config: Optional[dict] = None)

Bases: mmengine.model.BaseModel

Base Translation Model.

Translation models can transfer images from one domain to another. Domain information such as default_domain and reachable_domains is needed to initialize the class. Query functions such as is_domain_reachable and get_other_domains are also provided.

You can get a specific generator based on the domain, and by specifying target_domain in the forward function, you can decide the domain of generated images. Considering the differences among image translation models, only the external interfaces mentioned above are provided. When implementing image translation with a specific method, you can inherit both BaseTranslationModel and the method (e.g. BaseGAN) and implement the abstract methods.

Parameters
  • default_domain (str) – Default output domain.

  • reachable_domains (list[str]) – Domains that can be generated by the model.

  • related_domains (list[str]) – Domains involved in training and testing. reachable_domains must be contained in related_domains. However, related_domains may contain source domains that are used to retrieve source images from data_batch but not in reachable_domains.

  • discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.

  • disc_init_steps (int) – The number of initial steps used only to train discriminators. Defaults to 0.

init_weights()

Initialize weights for the module dict.

Parameters

pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Default: None.

get_module(module)

Get nn.ModuleDict to fit the MMDistributedDataParallel interface.

Parameters

module (MMDistributedDataParallel | nn.ModuleDict) – The input module that needs processing.

Returns

The ModuleDict of multiple networks.

Return type

nn.ModuleDict

forward(img, test_mode=False, **kwargs)

Forward function.

Parameters
  • img (tensor) – Input image tensor.

  • test_mode (bool) – Whether in test mode or not. Default: False.

  • kwargs (dict) – Other arguments.

forward_train(img, target_domain, **kwargs)

Forward function for training.

Parameters
  • img (tensor) – Input image tensor.

  • target_domain (str) – Target domain of output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

forward_test(img, target_domain, **kwargs)

Forward function for testing.

Parameters
  • img (tensor) – Input image tensor.

  • target_domain (str) – Target domain of output image.

  • kwargs (dict) – Other arguments.

Returns

Forward results.

Return type

dict

is_domain_reachable(domain)

Whether image of this domain can be generated.

get_other_domains(domain)

Get other domains.

_get_target_generator(domain)

Get the generator of the target domain.

_get_target_discriminator(domain)

Get the discriminator of the target domain.
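The domain queries can be sketched with a small hypothetical class. Returning the related domains other than the given one from get_other_domains is an assumption, since the summary only says “get other domains”.

```python
from typing import List


class DomainRegistry:
    """Hypothetical stand-in for the domain bookkeeping of BaseTranslationModel."""

    def __init__(self, default_domain: str, reachable_domains: List[str],
                 related_domains: List[str]):
        # The docs require reachable_domains to be contained in related_domains.
        assert set(reachable_domains) <= set(related_domains)
        self.default_domain = default_domain
        self.reachable_domains = reachable_domains
        self.related_domains = related_domains

    def is_domain_reachable(self, domain: str) -> bool:
        return domain in self.reachable_domains

    def get_other_domains(self, domain: str) -> List[str]:
        # Assumption: "other" means the related domains minus this one.
        return [d for d in self.related_domains if d != domain]


registry = DomainRegistry('photo', ['photo', 'painting'], ['photo', 'painting'])
print(registry.get_other_domains('photo'))  # ['painting']
```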

translation(image, target_domain=None, **kwargs)

Translate an image to the target style.

Parameters
  • image (tensor) – Image tensor with a shape of (N, C, H, W).

  • target_domain (str, optional) – Target domain of output image. Defaults to None.

Returns

Image tensor of target style.

Return type

dict

class mmagic.models.BasicInterpolator(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, required_frames: int = 2, step_frames: int = 1, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)

Bases: mmagic.models.base_models.base_edit_model.BaseEditModel

Basic model for video interpolation.

It must contain a generator that takes frames as inputs and outputs an interpolated frame. It also has a pixel-wise loss for training.

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • train_cfg (dict) – Config for training. Default: None.

  • test_cfg (dict) – Config for testing. Default: None.

  • required_frames (int) – Required frames in each process. Default: 2.

  • step_frames (int) – Step size of video frame interpolation. Default: 1.

  • init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.

init_cfg

Initialization config dict.

Type

dict, optional

data_preprocessor

Used for pre-processing data sampled by dataloader to the format accepted by forward().

Type

BaseDataPreprocessor

split_frames(input_tensors: torch.Tensor) torch.Tensor

Split input tensors for inference.

Parameters

input_tensors (Tensor) – Tensor of input frames with shape [1, t, c, h, w]

Returns

Split tensor with shape [t-1, 2, c, h, w].

Return type

Tensor
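With the defaults required_frames=2 and step_frames=1, splitting amounts to a sliding window of two frames along the time axis. A NumPy sketch under that assumption:

```python
import numpy as np


def split_frames(frames: np.ndarray, required_frames: int = 2,
                 step_frames: int = 1) -> np.ndarray:
    """Slide a window of required_frames over the time axis of a [1, t, c, h, w] array."""
    num_frames = frames.shape[1]
    windows = [
        frames[:, i:i + required_frames]
        for i in range(0, num_frames - required_frames + 1, step_frames)
    ]
    # Each window has shape [1, required_frames, c, h, w]; stack them on axis 0.
    return np.concatenate(windows, axis=0)


x = np.arange(5).reshape(1, 5, 1, 1, 1)  # t = 5 single-pixel frames
print(split_frames(x).shape)  # (4, 2, 1, 1, 1)
```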

static merge_frames(input_tensors: torch.Tensor, output_tensors: torch.Tensor) list

Merge input frames and output frames.

Each interpolated frame is placed between the two input frames it was generated from.

Merged from

[[in1, in2], [in2, in3], [in3, in4], …] and [[out1], [out2], [out3], …]

to

[in1, out1, in2, out2, in3, out3, in4, …]

Parameters
  • input_tensors (Tensor) – The input frames with shape [n, 2, c, h, w]

  • output_tensors (Tensor) – The output frames with shape [n, 1, c, h, w].

Returns

The final frames.

Return type

list[np.array]
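The merge order shown above can be reproduced with plain lists; merge_frames here is a standalone sketch that works on any items, not the library’s tensor version.

```python
from typing import List, Sequence, TypeVar

T = TypeVar('T')


def merge_frames(input_pairs: Sequence[Sequence[T]],
                 outputs: Sequence[Sequence[T]]) -> List[T]:
    """Interleave input pairs and interpolated outputs into one frame sequence."""
    merged: List[T] = []
    for pair, out in zip(input_pairs, outputs):
        merged.append(pair[0])  # first frame of the pair
        merged.append(out[0])  # frame interpolated after it
    merged.append(input_pairs[-1][1])  # the last input frame closes the sequence
    return merged


pairs = [['in1', 'in2'], ['in2', 'in3'], ['in3', 'in4']]
print(merge_frames(pairs, [['out1'], ['out2'], ['out3']]))
# ['in1', 'out1', 'in2', 'out2', 'in3', 'out3', 'in4']
```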

class mmagic.models.ExponentialMovingAverage(model: torch.nn.Module, momentum: float = 0.0002, interval: int = 1, device: Optional[torch.device] = None, update_buffers: bool = False)

Bases: mmengine.model.BaseAveragedModel

Implements the exponential moving average (EMA) of the model.

All parameters are updated by the formula as below:

\[X^{\text{ema}}_{t+1} = (1 - \text{momentum}) \cdot X^{\text{ema}}_{t} + \text{momentum} \cdot X_t\]

Parameters
  • model (nn.Module) – The model to be averaged.

  • momentum (float) – The momentum used for updating EMA parameters. Defaults to 0.0002. The EMA parameters are updated with the formula \(averaged\_param = (1-momentum) * averaged\_param + momentum * source\_param\).

  • interval (int) – Interval between two updates. Defaults to 1.

  • device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.

  • update_buffers (bool) – If True, running averages are computed for both the parameters and the buffers of the model. Defaults to False.
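The update formula can be checked on scalars. This sketch applies one EMA step per call to a list of floats; the interval logic and device handling are not modeled, and ema_update is a hypothetical name.

```python
from typing import List


def ema_update(averaged: List[float], source: List[float],
               momentum: float = 0.0002) -> None:
    """One in-place step: averaged = (1 - momentum) * averaged + momentum * source."""
    for i, (avg, src) in enumerate(zip(averaged, source)):
        averaged[i] = (1 - momentum) * avg + momentum * src


# With momentum 0.5, each step moves the average halfway towards the source.
avg = [0.0]
for _ in range(3):
    ema_update(avg, [1.0], momentum=0.5)
print(avg)  # [0.875]
```

A small momentum such as the default 0.0002 makes the averaged weights change slowly, which is why the EMA generator tends to produce smoother results than the raw one.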

avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor, steps: int) None

Compute the moving average of the parameters using exponential moving average.

Parameters
  • averaged_param (Tensor) – The averaged parameters.

  • source_param (Tensor) – The source parameters.

  • steps (int) – The number of times the parameters have been updated.

_load_from_state_dict(state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: list, unexpected_keys: list, error_msgs: List[str]) None

Overrides nn.Module._load_from_state_dict to support loading a state_dict in which the EMA module is not wrapped with BaseAveragedModel.

In OpenMMLab 1.0, models do not wrap the EMA submodule with BaseAveragedModel, so the EMA weight keys in state_dict lack the module prefix. Therefore, BaseAveragedModel needs to automatically add the module prefix if the corresponding key in state_dict misses it.

Parameters
  • state_dict (dict) – A dict containing parameters and persistent buffers.

  • prefix (str) – The prefix for parameters and buffers used in this module.

  • local_metadata (dict) – A dict containing the metadata for this module.

  • strict (bool) – Whether to strictly enforce that the keys in state_dict with prefix match the names of parameters and buffers in this module.

  • missing_keys (List[str]) – If strict=True, missing keys are added to this list.

  • unexpected_keys (List[str]) – If strict=True, unexpected keys are added to this list.

  • error_msgs (List[str]) – Error messages are added to this list and reported together in load_state_dict().
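The prefix repair described above can be sketched on a plain dict. Leaving keys that end with steps (the averaged model’s step counter) untouched is an assumption, and add_module_prefix is a hypothetical name.

```python
from typing import Any, Dict


def add_module_prefix(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Sketch: rename '<prefix><name>' keys to '<prefix>module.<name>' in place."""
    for key in list(state_dict):  # snapshot, since keys are renamed in the loop
        rest = key[len(prefix):]
        needs_prefix = (
            key.startswith(prefix)
            and not rest.startswith('module')  # already wrapped
            and not key.endswith('steps')  # assumed: step counter is kept as-is
        )
        if needs_prefix:
            state_dict[f'{prefix}module.{rest}'] = state_dict.pop(key)
    return state_dict


print(add_module_prefix({'ema.weight': 1, 'ema.steps': 0}, 'ema.'))
# {'ema.steps': 0, 'ema.module.weight': 1}
```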

sync_buffers(model: torch.nn.Module) None

Copy buffers from model to the averaged model.

Parameters

model (nn.Module) – The model whose parameters will be averaged.

sync_parameters(model: torch.nn.Module) None

Copy buffers and parameters from model to the averaged model.

Parameters

model (nn.Module) – The model whose parameters will be averaged.

class mmagic.models.DataPreprocessor(mean: Union[Sequence[Union[float, int]], float, int] = 127.5, std: Union[Sequence[Union[float, int]], float, int] = 127.5, pad_size_divisor: int = 1, pad_value: Union[float, int] = 0, pad_mode: str = 'constant', non_image_keys: Optional[Tuple[str, List[str]]] = None, non_concentate_keys: Optional[Tuple[str, List[str]]] = None, output_channel_order: Optional[str] = None, data_keys: Union[List[str], str] = 'gt_img', input_view: Optional[tuple] = None, output_view: Optional[tuple] = None, stack_data_sample=True)

Bases: mmengine.model.ImgDataPreprocessor

Image pre-processor for generative models. This class provides normalization and BGR-to-RGB conversion for image tensor inputs. The input of this class should be a dict whose keys are inputs and data_samples.

Besides tensor inputs, this class supports dicts as inputs:

  • If the value is a Tensor and the corresponding key is not contained in _NON_IMAGE_KEYS, it is processed as an image tensor.

  • If the value is a Tensor and the corresponding key belongs to _NON_IMAGE_KEYS, it remains unchanged.

  • If the value is a string or an integer, it remains unchanged.
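A minimal sketch of these rules, assuming scalar mean and std, NumPy arrays in place of tensors, and no padding or channel conversion. With the default mean=std=127.5, pixel values in [0, 255] map to [-1, 1]. preprocess_dict is a hypothetical name.

```python
from typing import Any, Dict, Sequence

import numpy as np

NON_IMAGE_KEYS = ['noise']  # mirrors DataPreprocessor._NON_IMAGE_KEYS


def preprocess_dict(inputs: Dict[str, Any], mean: float = 127.5,
                    std: float = 127.5,
                    non_image_keys: Sequence[str] = NON_IMAGE_KEYS) -> Dict[str, Any]:
    """Sketch of the dict-input rules: only image arrays are normalized."""
    out = {}
    for key, value in inputs.items():
        if isinstance(value, np.ndarray) and key not in non_image_keys:
            # Image array: normalize with (x - mean) / std.
            out[key] = (value.astype(np.float32) - mean) / std
        else:
            # Non-image arrays, strings and integers pass through unchanged.
            out[key] = value
    return out


batch = {'img': np.array([0, 127.5, 255]), 'noise': np.array([5.0]), 'num_batches': 2}
print(preprocess_dict(batch)['img'].tolist())  # [-1.0, 0.0, 1.0]
```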

参数
  • mean (Sequence[float or int], float or int, optional) – The pixel mean of image channels. Noted that normalization operation is performed after channel order conversion. If it is not specified, images will not be normalized. Defaults None.

  • std (Sequence[float or int], float or int, optional) – The pixel standard deviation of image channels. Noted that normalization operation is performed after channel order conversion. If it is not specified, images will not be normalized. Defaults None.

  • pad_size_divisor (int) – Images are padded so that their size is divisible by pad_size_divisor. Defaults to 1.

  • pad_value (float or int) – The padded pixel value. Defaults to 0.

  • pad_mode (str) – Padding mode for torch.nn.functional.pad. Defaults to ‘constant’.

  • non_image_keys (List[str] or str) – Keys of fields that do not need to be processed as images (padding, channel conversion and normalization). If not passed, the keys in _NON_IMAGE_KEYS will be used. This argument only takes effect when inputs is a dict or a list of dicts. Defaults to None.

  • non_concatenate_keys (List[str] or str) – Keys of fields that do not need to be concatenated. If not passed, the keys in _NON_CONCATENATE_KEYS will be used. This argument only takes effect when inputs is a dict or a list of dicts. Note that the signature spells this argument non_concentate_keys. Defaults to None.

  • output_channel_order (str, optional) – The desired channel order of the images output by the data preprocessor. This is also the desired input channel order of the model (and most likely the output channel order of the model as well). If not passed, no channel order conversion will be performed. Defaults to None.

  • data_keys (List[str] or str) – Keys to preprocess in data samples. Defaults to ‘gt_img’.

  • input_view (tuple, optional) – The view of the input tensor. This argument may be removed in the future. Defaults to None.

  • output_view (tuple, optional) – The view of the output tensor. This argument may be removed in the future. Defaults to None.

  • stack_data_sample (bool) – Whether to stack a list of data samples into one data sample. Only supported when the input data samples are DataSample instances. Defaults to True.
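With the default mean = std = 127.5, normalization computes (x - 127.5) / 127.5, which maps 8-bit pixel values in [0, 255] to [-1, 1]. The sketch below shows that arithmetic together with the image size that pad_size_divisor implies. It uses plain Python with a scalar mean and std (the class also accepts per-channel sequences); the helper names are hypothetical, and rounding each side up to the next multiple of the divisor is assumed here as the usual meaning of this kind of padding.

```python
import math
from typing import List, Sequence, Tuple


def normalize(pixels: Sequence[float], mean: float = 127.5,
              std: float = 127.5) -> List[float]:
    """(x - mean) / std; with the defaults, [0, 255] maps to [-1, 1]."""
    return [(p - mean) / std for p in pixels]


def denormalize(values: Sequence[float], mean: float = 127.5,
                std: float = 127.5) -> List[float]:
    """Inverse of normalize: x * std + mean."""
    return [v * std + mean for v in values]


def padded_size(h: int, w: int, divisor: int) -> Tuple[int, int]:
    """Smallest (H, W) >= (h, w) with both sides divisible by `divisor`."""
    return (math.ceil(h / divisor) * divisor,
            math.ceil(w / divisor) * divisor)


print(normalize([0, 127.5, 255]))  # [-1.0, 0.0, 1.0]
print(padded_size(250, 333, 32))   # (256, 352)
```

With pad_size_divisor=1 (the default) the size is left unchanged.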

_NON_IMAGE_KEYS = ['noise']
_NON_CONCATENATE_KEYS = ['num_batches', 'mode', 'sample_kwargs', 'eq_cfg']
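The dict-input rule described for this class, combined with the default _NON_IMAGE_KEYS, can be sketched as follows. FakeTensor is a hypothetical stand-in for torch.Tensor so the example runs without PyTorch, and route_inputs is an illustrative helper that only reports how each value would be treated; it is not part of mmagic.

```python
from typing import Any, Dict, Sequence

_NON_IMAGE_KEYS = ['noise']  # default value of the class attribute above


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (sketch only)."""

    def __init__(self, data: Any) -> None:
        self.data = data


def route_inputs(inputs: Dict[str, Any],
                 non_image_keys: Sequence[str] = tuple(_NON_IMAGE_KEYS)
                 ) -> Dict[str, str]:
    """Report whether each dict value would be treated as an image."""
    routes = {}
    for key, value in inputs.items():
        if isinstance(value, FakeTensor) and key not in non_image_keys:
            routes[key] = 'image'      # padded, channel-converted, normalized
        else:
            routes[key] = 'unchanged'  # non-image tensors, strings, integers
    return routes


batch = {'img': FakeTensor([[0.0]]), 'noise': FakeTensor([0.1]),
         'num_batches': 4, 'mode': 'ema'}
print(route_inputs(batch))
# {'img': 'image', 'noise': 'unchanged', 'num_batches': 'unchanged', 'mode': 'unchanged'}
```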
cast_data(data: CastData) → CastData

Copy data to the target device.

Parameters

data (dict) – Data returned by DataLoader.

Returns

Inputs and data samples on the target device.

Return type

CollatedResult

static _parse_channel_index(inputs) → int

Parse channel index of inputs.
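For the tensor layouts accepted by _preprocess_image_tensor, namely (C, H, W), (N, C, H, W) and (N, t, C, H, W), the channel axis is always third from the end. A minimal sketch under that assumption, working on shape tuples; parse_channel_index is illustrative and not the library's code.

```python
from typing import Sequence


def parse_channel_index(shape: Sequence[int]) -> int:
    """Channel axis for (C, H, W), (N, C, H, W) or (N, t, C, H, W)."""
    if len(shape) not in (3, 4, 5):
        raise ValueError(f'unsupported shape: {tuple(shape)}')
    # In all three layouts the channel axis is third from the end.
    return len(shape) - 3


print(parse_channel_index((8, 3, 64, 64)))  # 1
```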

_parse_channel_order(key: str, inputs: torch.Tensor, data_sample: Optional[mmagic.structures.DataSample] = None) → str
_parse_batch_channel_order(key: str, inputs: Sequence, data_samples: Optional[Sequence[mmagic.structures.DataSample]]) → str

Parse channel order of inputs in batch.

_update_metainfo(padding_info: torch.Tensor, channel_order_info: Optional[dict] = None, data_samples: Optional[mmagic.utils.typing.SampleList] = None) → mmagic.utils.typing.SampleList

Update padding and channel order information in the metainfo of a batch of data_samples. For channel order, we assume that the same field shares the same channel order across all data samples, so it is passed as a dict whose keys and values are field names and their corresponding channel orders. For padding info, we assume it is the same for all fields of one sample but may vary between samples, so padding_info is passed as a Tensor with shape (B, 1, 1).

Parameters
  • padding_info (Tensor) – The padding info of each sample. Shape like (B, 1, 1).

  • channel_order_info (dict, optional) – The channel order of target fields. Keys and values are field names and their corresponding channel orders, respectively.

  • data_samples (List[DataSample], optional) – The data samples to be updated. If not passed, a list of empty data samples will be initialized. Defaults to None.

Returns

The updated data samples.

Return type

List[DataSample]

_do_conversion(inputs: torch.Tensor, inputs_order: str = 'BGR', target_order: Optional[str] = None) → Tuple[torch.Tensor, str]

Conduct channel order conversion for a batch of inputs, and return the converted inputs and order after conversion.

inputs_order:
  • BGR / RGB: Convert to the target order.

  • SINGLE: Do not change.
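Because BGR and RGB differ only in the order of their three channels, converting between them amounts to reversing the channel axis. The sketch below mirrors the rules listed above on a plain list of per-channel planes. It assumes that no conversion happens when target_order is None (consistent with output_channel_order=None meaning "no conversion"); it is not the library's implementation.

```python
from typing import List, Optional, Tuple

Planes = List[list]  # one entry per channel, e.g. [B, G, R]


def do_conversion(planes: Planes, inputs_order: str = 'BGR',
                  target_order: Optional[str] = None) -> Tuple[Planes, str]:
    """Sketch of BGR/RGB channel-order conversion on channel planes."""
    if inputs_order not in ('BGR', 'RGB', 'SINGLE'):
        raise ValueError(f'unknown channel order: {inputs_order}')
    # Single-channel inputs, or no requested target: leave unchanged.
    if inputs_order == 'SINGLE' or target_order is None:
        return planes, inputs_order
    # Already in the requested order.
    if inputs_order == target_order:
        return planes, inputs_order
    # BGR <-> RGB: reverse the channel axis.
    return planes[::-1], target_order


print(do_conversion([[1], [2], [3]], 'BGR', 'RGB'))  # ([[3], [2], [1]], 'RGB')
```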

_do_norm(inputs: torch.Tensor, do_norm: Optional[bool] = None) → torch.Tensor
_preprocess_image_tensor(inputs: torch.Tensor, data_samples: Optional[mmagic.utils.typing.SampleList] = None, key: str = 'img') → Tuple[torch.Tensor, mmagic.utils.typing.SampleList]

Preprocess a batch of image tensor and update metainfo to corresponding data samples.

Parameters
  • inputs (Tensor) – Image tensor with shape (C, H, W), (N, C, H, W) or (N, t, C, H, W) to preprocess.

  • data_samples (List[DataSample], optional) – The data samples of corresponding inputs. If not passed, a list of empty data samples will be initialized to save metainfo. Defaults to None.

  • key (str) – The key of image tensor in data samples. Defaults to ‘img’.

Returns

The preprocessed image tensor and updated data samples.

Return type

Tuple[Tensor, List[DataSample]]

_preprocess_image_list(tensor_list: List[torch.Tensor], data_samples: Optional[mmagic.utils.typing.SampleList], key: str = 'img') → Tuple[torch.Tensor, mmagic.utils.typing.SampleList]

Preprocess a list of image tensor and update metainfo to corresponding data samples.

Parameters
  • tensor_list (List[Tensor]) – Image tensor list to be preprocessed.

  • data_samples (List[DataSample], optional) – The data samples of corresponding inputs. If not passed, a list of empty data samples will be initialized to save metainfo. Defaults to None.

  • key (str) – The key of tensor list in data samples. Defaults to ‘img’.

Returns

The preprocessed image tensor and updated data samples.

Return type

Tuple[Tensor, List[DataSample]]

_preprocess_dict_inputs(batch_inputs: dict, data_samples: Optional[mmagic.utils.typing.SampleList] = None) → Tuple[dict, mmagic.utils.typing.SampleList]

Preprocess dict type inputs.

Parameters
  • batch_inputs (dict) – Input dict.

  • data_samples (List[DataSample], optional) – The data samples of corresponding inputs. If not passed, a list of empty data samples will be initialized to save metainfo. Defaults to None.

Returns

The preprocessed dict and updated data samples.

Return type

Tuple[dict, List[DataSample]]

_preprocess_data_sample(data_samples: mmagic.utils.typing.SampleList, training: bool) → mmagic.structures.DataSample

Preprocess data samples. When training is True, fields belonging to self.data_keys will be converted to self.output_channel_order and then normalized by self.mean and self.std. When training is False, the preprocessor will attempt to convert fields belonging to self.data_keys to ‘BGR’, without normalization. The corresponding metainfo related to normalization and channel order conversion will be updated in the data samples as well.

Parameters
  • data_samples (List[DataSample]) – A list of data samples to preprocess.

  • training (bool) – Whether in training mode.

Returns

The list of processed data samples.

Return type

list
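The training/evaluation split described for _preprocess_data_sample can be sketched for a single field as below. Channels are plain Python lists and mean/std are scalars; preprocess_field is a hypothetical helper, whereas the real method works on DataSample objects and also records metainfo.

```python
from typing import List, Optional, Tuple

Planes = List[List[float]]  # one list of pixel values per channel


def preprocess_field(planes: Planes, order: str, training: bool,
                     output_order: Optional[str] = 'RGB',
                     mean: float = 127.5, std: float = 127.5
                     ) -> Tuple[Planes, str]:
    """Hypothetical single-field version of the train/eval behaviour."""
    # Training targets the configured output order; evaluation targets BGR.
    target = output_order if training else 'BGR'
    if target is not None and order != 'SINGLE' and order != target:
        planes = planes[::-1]  # BGR <-> RGB
        order = target
    if training:
        # Normalization is applied only in training mode.
        planes = [[(v - mean) / std for v in plane] for plane in planes]
    return planes, order


print(preprocess_field([[0.0], [127.5], [255.0]], 'BGR', training=True))
# ([[1.0], [0.0], [-1.0]], 'RGB')
```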

forward(data: dict, training: bool = False) → dict

Performs normalization, padding and channel order conversion.

Parameters
  • data (dict) – Input data to process.

  • training (bool) – Whether in training mode. Defaults to False.

Returns

Data in the same format as the model input.

Return type

dict

destruct(outputs: torch.Tensor, data_samples: Union[mmagic.utils.typing.SampleList, mmagic.structures.DataSample, None] = None, key: str = 'img') → Union[list, torch.Tensor]

Undo padding and normalization, and convert the channel order back to BGR if possible. If data_samples is a list, outputs will be destructed as a batch of tensors. If data_samples is a DataSample, outputs will be destructed as a single tensor.

Before feeding model outputs to the visualizer and evaluator, users should call this function on both model outputs and inputs.

Use cases:

>>> # destruct model outputs.
>>> # model outputs share the same preprocessing information as the
>>> # inputs ('img'), therefore 'img' is used as the key
>>> feats = self.forward_tensor(inputs, data_samples, **kwargs)
>>> feats = self.data_preprocessor.destruct(feats, data_samples, 'img')
>>> # destruct model inputs for visualization
>>> for idx, data_sample in enumerate(data_samples):
...     destructed_input = self.data_preprocessor.destruct(
...         inputs[idx], data_sample, key='img')
...     data_sample.set_data({'input': destructed_input})

Parameters
  • outputs (Tensor) – Tensor to destruct.

  • data_samples (Union[SampleList, DataSample], optional) – Data samples (or data sample) corresponding to outputs. Defaults to None.

  • key (str) – The key of field in data sample. Defaults to ‘img’.

Returns

Destructed outputs.

Return type

Union[list, Tensor]

_destruct_norm_and_conversion(batch_tensor: torch.Tensor, data_samples: Union[mmagic.utils.typing.SampleList, mmagic.structures.DataSample, None], key: str) → torch.Tensor

De-normalize and undo the channel order conversion. Note that de-normalization is performed first and the channel order de-conversion second, because the mean and std used in normalization are defined in the channel order after conversion.

Parameters
  • batch_tensor (Tensor) – Tensor to destruct.

  • data_samples (Union[SampleList, DataSample], optional) – Data samples (or data sample) corresponding to outputs.

  • key (str) – The key of field in data sample.

Returns

Destructed tensor.

Return type

Tensor
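A small worked example shows why de-normalization must come before de-conversion. Suppose a BGR pixel (10, 20, 30) was converted to RGB (30, 20, 10) and then normalized with a hypothetical RGB-order mean (100, 50, 0) and std 1, giving (-70, -30, 10). Undoing the steps in the documented order recovers the original pixel; reversing the channels first would pair each mean with the wrong channel and yield (110, 20, -70). The helper below is illustrative only.

```python
from typing import List, Sequence

Planes = List[List[float]]  # one list of values per channel


def destruct_norm_and_conversion(planes: Planes, mean: Sequence[float],
                                 std: Sequence[float], order: str) -> Planes:
    """De-normalize first, then undo the BGR -> RGB conversion (sketch)."""
    # 1) De-normalize: mean/std are indexed in the current (converted) order.
    restored = [[v * s + m for v in plane]
                for plane, m, s in zip(planes, mean, std)]
    # 2) Only then convert back to BGR.
    if order == 'RGB':
        restored = restored[::-1]
    return restored


normalized = [[-70.0], [-30.0], [10.0]]  # RGB order, normalized
print(destruct_norm_and_conversion(normalized, [100, 50, 0], [1, 1, 1], 'RGB'))
# [[10.0], [20.0], [30.0]]
```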

_destruct_padding(batch_tensor: torch.Tensor, data_samples: Union[mmagic.utils.typing.SampleList, mmagic.structures.DataSample, None], same_padding: bool = True) → Union[list, torch.Tensor]

Destruct padding of the input tensor.

Parameters
  • batch_tensor (Tensor) – Tensor to destruct.

  • data_samples (Union[SampleList, DataSample], optional) – Data samples (or data sample) corresponding to outputs. If

  • same_padding (bool) – If True, all samples are un-padded with the padding info of the first sample and a stacked un-padded tensor is returned. Otherwise, each sample is un-padded with the padding info saved in its corresponding data sample and a list of un-padded tensors is returned, since the un-padded tensors may have different shapes. Defaults to True.

Returns

Destructed outputs.

Return type

Union[list, Tensor]
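The two same_padding modes can be sketched on single-channel images stored as nested lists. This assumes, hypothetically, that each sample's padding info is a (pad_h, pad_w) pair and that padding was added at the bottom and right of the image; the actual metainfo layout in mmagic may differ. In the same_padding case the sketch returns equally shaped images that could be stacked, rather than an actual stacked tensor.

```python
from typing import List, Sequence, Tuple

Image = List[List[float]]  # H x W, single channel for brevity


def unpad(img: Image, pad_h: int, pad_w: int) -> Image:
    """Crop padding assumed to sit at the bottom and right of the image."""
    h = len(img) - pad_h
    w = len(img[0]) - pad_w
    return [row[:w] for row in img[:h]]


def destruct_padding(batch: Sequence[Image],
                     paddings: Sequence[Tuple[int, int]],
                     same_padding: bool = True) -> List[Image]:
    if same_padding:
        # Every sample uses the first sample's padding, so shapes match.
        pad_h, pad_w = paddings[0]
        return [unpad(img, pad_h, pad_w) for img in batch]
    # Each sample uses its own padding, so shapes may differ.
    return [unpad(img, ph, pw) for img, (ph, pw) in zip(batch, paddings)]


batch = [[[1.0] * 4 for _ in range(4)] for _ in range(2)]  # two 4x4 images
shapes = [(len(i), len(i[0])) for i in destruct_padding(batch, [(1, 2), (0, 0)], False)]
print(shapes)  # [(3, 2), (4, 4)]
```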

class mmagic.models.MattorPreprocessor(mean: MEAN_STD_TYPE = [123.675, 116.28, 103.53], std: MEAN_STD_TYPE = [58.395, 57.12, 57.375], output_channel_order: str = 'RGB', proc_trimap: str = 'rescale_to_zero_one', stack_data_sample=True)

Bases: mmagic.models.data_preprocessors.data_preprocessor.DataPreprocessor

DataPreprocessor for matting models.

See base class DataPreprocessor for detailed information.

Workflow as follows:

  • Collate and move data to the target device.

  • Convert inputs from BGR to RGB if the shape of input is (3, H, W).

  • Normalize image with defined std and mean.

  • Stack inputs to batch_inputs.
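For a single (3, H, W) input, the conversion and normalization steps of this workflow reduce to the per-pixel arithmetic below, using the default mean and std from the signature (given in RGB order, since normalization happens after conversion). Pixels are plain lists and both helpers are hypothetical; collation, device transfer and real tensor stacking are omitted.

```python
from typing import List, Sequence

MEAN = [123.675, 116.28, 103.53]  # RGB order, signature default
STD = [58.395, 57.12, 57.375]     # RGB order, signature default


def mattor_normalize_pixel(bgr: Sequence[float]) -> List[float]:
    """Convert one BGR pixel to RGB, then normalize per channel."""
    rgb = list(bgr)[::-1]  # BGR -> RGB
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


def normalize_batch(pixels: Sequence[Sequence[float]]) -> List[List[float]]:
    """Stand-in for processing every sample before stacking into a batch."""
    return [mattor_normalize_pixel(p) for p in pixels]


# A BGR pixel equal to the (reversed) mean normalizes to zero.
print(mattor_normalize_pixel([103.53, 116.28, 123.675]))  # [0.0, 0.0, 0.0]
```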

Parameters
  • mean (Sequence[float or int], float or int, optional) – The pixel mean of image channels. Note that normalization is performed after channel order conversion. If set to None, images will not be normalized. Defaults to [123.675, 116.28, 103.53].

  • std (Sequence[float or int], float or int, optional) – The pixel standard deviation of image channels. Note that normalization is performed after channel order conversion. If set to None, images will not be normalized. Defaults to [58.395, 57.12, 57.375].

  • output_channel_order (str) – The desired channel order of the images output by the data preprocessor. Defaults to ‘RGB’.

  • proc_trimap (str) – Method used to process trimap tensors. Available options are rescale_to_zero_one and as-is. Defaults to ‘rescale_to_zero_one’.

  • stack_data_sample (bool) – Whether to stack a list of data samples into one data sample. Only supported when the input data samples are DataSample instances. Defaults to True.
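The two proc_trimap options can be sketched as follows. This assumes trimaps are stored as 8-bit values (for example 0, 128 and 255) and that rescale_to_zero_one divides them by 255; both that assumption and the helper name proc_batch_trimap are illustrative, and the sketch works on flattened lists rather than tensors.

```python
from typing import List, Sequence


def proc_batch_trimap(trimaps: Sequence[Sequence[int]],
                      proc_trimap: str = 'rescale_to_zero_one'
                      ) -> List[List[float]]:
    """Sketch of the two proc_trimap options on flattened 8-bit trimaps."""
    if proc_trimap == 'rescale_to_zero_one':
        # Assumed: 8-bit values in [0, 255] are scaled to [0, 1].
        return [[v / 255.0 for v in t] for t in trimaps]
    if proc_trimap == 'as-is':
        # Values are kept, only cast to float.
        return [[float(v) for v in t] for t in trimaps]
    raise ValueError(f'unsupported proc_trimap: {proc_trimap}')


print(proc_batch_trimap([[0, 255]]))  # [[0.0, 1.0]]
```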

_proc_batch_trimap(batch_trimaps: torch.Tensor)
_preprocess_data_sample(data_samples: mmagic.utils.typing.SampleList, training: bool) → list

Preprocess data samples. When training is True, fields belonging to self.data_keys will be converted to self.output_channel_order and divided by 255. When training is False, the preprocessor will attempt to convert fields belonging to self.data_keys to ‘BGR’, without normalization. The corresponding metainfo related to normalization and channel order conversion will be updated in the data samples as well.

Parameters
  • data_samples (List[DataSample]) – A list of data samples to preprocess.

  • training (bool) – Whether in training mode.

Returns

The list of processed data samples.

Return type

list

forward(data: Sequence[dict], training: bool = False) → Tuple[torch.Tensor, list]

Pre-process input images, trimaps, ground-truth as configured.

Parameters
  • data (Sequence[dict]) – Data sampled from the dataloader.

  • training (bool) – Whether to enable training-time augmentation. Defaults to False.

Returns

Batched inputs and list of data samples.

Return type

Tuple[torch.Tensor, list]
