mmagic.models
Package Contents
Classes
BaseConditionalGAN – Base class for Conditional GAN models.
BaseEditModel – Base model for image and video editing.
BaseGAN – Base class for GAN models.
BaseMattor – Base class for trimap-based matting models.
BaseTranslationModel – Base Translation Model.
BasicInterpolator – Basic model for video interpolation.
ExponentialMovingAverage – Implements the exponential moving average (EMA) of the model.
DataPreprocessor – Image pre-processor for generative models. This class provides normalization and BGR-to-RGB conversion for image tensor inputs.
MattorPreprocessor – DataPreprocessor for matting models.
- class mmagic.models.BaseConditionalGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)
Bases:
mmagic.models.base_models.base_gan.BaseGAN
Base class for Conditional GAN models.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.
num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.
ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
- label_fn(label: mmagic.utils.typing.LabelVar = None, num_batches: int = 1) → torch.Tensor
Sampling function for label. There are three scenarios in this function:
If label is a callable function, sample num_batches of labels with passed label.
If label is None, sample num_batches of labels in range of [0, self.num_classes-1] uniformly.
If label is a torch.Tensor, check the range of the tensor is in [0, self.num_classes-1]. If all values are in valid range, directly return label.
- Parameters
label (Union[Tensor, Callable, List[int], None]) – You can directly give a batch of labels through a torch.Tensor or offer a callable function to sample a batch of label data. Otherwise, None indicates that the default label sampler is used. Defaults to None.
num_batches (int, optional) – The number of labels (batch size) to sample. If label is a Tensor, this is ignored. Defaults to 1.
- Returns
Sampled label tensor.
- Return type
Tensor
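To make the three label_fn scenarios concrete, here is a framework-free sketch. It is not mmagic's implementation: Python lists stand in for torch.Tensor, the name sample_labels is hypothetical, and it assumes a callable sampler returns one label per call (the text above does not say whether the callable is invoked per label or per batch).

```python
import random
from typing import Callable, List, Sequence, Union

# Hypothetical stand-in for LabelVar: lists of ints replace torch.Tensor here.
LabelVar = Union[Sequence[int], Callable[[], int], None]


def sample_labels(label: LabelVar, num_batches: int,
                  num_classes: int) -> List[int]:
    """Sketch of the three label_fn scenarios (not mmagic's implementation)."""
    if callable(label):
        # Callable: assumed to return one label per call.
        return [int(label()) for _ in range(num_batches)]
    if label is None:
        # None: sample uniformly from [0, num_classes - 1] (randint is inclusive).
        return [random.randint(0, num_classes - 1) for _ in range(num_batches)]
    # Given labels: check the range, then return them (num_batches is ignored).
    labels = list(label)
    if not all(0 <= x <= num_classes - 1 for x in labels):
        raise ValueError(f'labels must lie in [0, {num_classes - 1}]')
    return labels


print(sample_labels([0, 3, 9], num_batches=5, num_classes=10))  # [0, 3, 9]
```

Note that only the None branch actually depends on num_classes for sampling; the explicit-labels branch uses it purely for validation.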
- data_sample_to_label(data_sample: mmagic.structures.DataSample) → Optional[torch.Tensor]
Get labels from the input data_sample and pack them into a torch.Tensor. If no label is found in the passed data_sample, None is returned.
- Parameters
data_sample (DataSample) – Input data samples.
- Returns
Packed label tensor.
- Return type
Optional[torch.Tensor]
- static _get_valid_num_classes(num_classes: Optional[int], generator: ModelType, discriminator: Optional[ModelType]) → int
Try to get the value of num_classes from the input, the generator and the discriminator, and check that these values are consistent. If no conflict is found, return num_classes.
- Parameters
num_classes (Optional[int]) – num_classes passed to the __init__ function of BaseConditionalGAN.
generator (ModelType) – The config or the model of generator.
discriminator (Optional[ModelType]) – The config or model of discriminator.
- Returns
The number of classes to be generated.
- Return type
int
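A minimal sketch of the consistency check, assuming generator and discriminator are plain config dicts that may carry a num_classes key (the real method also accepts built models, which this sketch ignores). Raising an error when no source defines num_classes is likewise an assumption of the sketch.

```python
from typing import Optional


def get_valid_num_classes(num_classes: Optional[int], generator: dict,
                          discriminator: Optional[dict]) -> int:
    """Collect num_classes from every source that defines it and check agreement.

    Assumptions: generator/discriminator are plain config dicts, and raising
    when nothing defines num_classes is this sketch's choice.
    """
    found = []
    if num_classes is not None:
        found.append(('input', num_classes))
    for name, cfg in (('generator', generator), ('discriminator', discriminator)):
        if cfg is not None and cfg.get('num_classes') is not None:
            found.append((name, cfg['num_classes']))
    if not found:
        raise ValueError('num_classes is not defined by any source')
    values = {value for _, value in found}
    if len(values) > 1:
        raise ValueError(f'inconsistent num_classes: {found}')
    return values.pop()
```

Collecting (source, value) pairs before comparing keeps the error message informative about which component disagrees.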
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → List[mmagic.structures.DataSample]
Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.
- Returns
Generated images or image dict.
- Return type
List[DataSample]
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for the generator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for discriminator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- class mmagic.models.BaseEditModel(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases:
mmengine.model.BaseModel
Base model for image and video editing.
It must contain a generator that takes images or frames as inputs and outputs the edited result. It also has a pixel-wise loss for training.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for pixel-wise loss.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialized config for BaseModule.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
- init_cfg
Initialization config dict.
- Type
dict, optional
- data_preprocessor
Used for pre-processing data sampled by dataloader to the format accepted by forward(). Default: None.
- Type
BaseDataPreprocessor
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmagic.structures.DataSample], dict]
Returns losses or predictions of training, validation, testing, and simple inference process.
The forward method of BaseModel is an abstract method; its subclasses must implement this method.
Accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward will be called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step.
During distributed data parallel training, MMSeparateDistributedDataParallel.train_step will first call DistributedDataParallel.forward to enable automatic gradient synchronization, and then call forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – mode should be one of loss, predict and tensor. Default: ‘tensor’.
loss: Called by train_step and returns a loss dict used for logging.
predict: Called by val_step and test_step and returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom use to get Tensor type results.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
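The mode argument acts as a dispatcher. The sketch below shows one way such routing can look; the class name ModeDispatchSketch is hypothetical, and the helper bodies are placeholders that are only named after the forward_train, forward_inference and forward_tensor methods documented for this class.

```python
from typing import Any, Callable, Dict


class ModeDispatchSketch:
    """Hypothetical model showing how forward(mode=...) can route calls.

    The helper bodies are placeholders, not mmagic code.
    """

    def forward_train(self, inputs, data_samples=None):
        return {'loss': 0.0}  # placeholder loss dict for logging

    def forward_inference(self, inputs, data_samples=None):
        return [f'prediction for {x}' for x in inputs]  # placeholder results

    def forward_tensor(self, inputs, data_samples=None):
        return inputs  # placeholder raw output

    def forward(self, inputs, data_samples=None, mode: str = 'tensor') -> Any:
        handlers: Dict[str, Callable] = {
            'loss': self.forward_train,
            'predict': self.forward_inference,
            'tensor': self.forward_tensor,
        }
        if mode not in handlers:
            raise ValueError(f'Invalid forward mode {mode!r}; '
                             f'expected one of {sorted(handlers)}')
        return handlers[mode](inputs, data_samples)


model = ModeDispatchSketch()
print(model.forward([1, 2], mode='loss'))  # {'loss': 0.0}
```

A lookup table keeps the three modes explicit and rejects unknown strings early instead of silently falling through.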
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) → List[mmagic.structures.DataSample]
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → torch.Tensor
Forward tensor. Returns result of simple forward.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
result of simple forward.
- Return type
Tensor
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → mmagic.structures.DataSample
Forward inference. Returns predictions of validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
DataSample
- forward_train(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → Dict[str, torch.Tensor]
Forward training. Returns dict of losses of training.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Dict of losses.
- Return type
dict
- class mmagic.models.BaseGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)
Bases:
mmengine.model.BaseModel
Base class for GAN models.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
ema_config (Optional[Dict]) – The config for generator’s exponential moving average setting. Defaults to None.
- property generator_steps: int
The number of times the generator is completely updated before the discriminator is updated.
- Type
int
- property discriminator_steps: int
The number of times the discriminator is completely updated before the generator is updated.
- Type
int
- property device: torch.device
Get current device of the model.
- Returns
The current device of the model.
- Return type
torch.device
- property with_ema_gen: bool
Whether the GAN adopts exponential moving average.
- Returns
True if this GAN model adopts an exponential moving average of the generator, False otherwise.
- Return type
bool
- static gather_log_vars(log_vars_list: List[Dict[str, torch.Tensor]]) → Dict[str, torch.Tensor]
Gather a list of log_vars into a single dict.
- Parameters
log_vars_list (List[Dict[str, Tensor]]) – The list of log_vars to gather.
- Return type
Dict[str, Tensor]
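When the generator is updated several times within one train_step (generator_steps > 1), several log dicts are produced and have to be merged. The sketch below averages each key across the list; averaging is an assumption about the gathering rule, and plain floats stand in for scalar tensors.

```python
from statistics import fmean
from typing import Dict, List


def gather_log_vars(log_vars_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Merge several log dicts into one by averaging each key.

    Assumption: averaging is the gathering rule; floats stand in for tensors.
    """
    if len(log_vars_list) == 1:
        return dict(log_vars_list[0])
    keys = log_vars_list[0].keys()
    for log_vars in log_vars_list:
        # Every dict must log the same keys, otherwise averaging is ill-defined.
        assert log_vars.keys() == keys, 'log dicts have different keys'
    return {k: fmean(log_vars[k] for log_vars in log_vars_list) for k in keys}


print(gather_log_vars([{'loss_gen': 1.0}, {'loss_gen': 3.0}]))  # {'loss_gen': 2.0}
```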
- _init_loss(loss_config: Optional[Dict] = None) → None
Initialize customized loss modules.
If loss_config is a dict, several kinds of values are allowed for each field.
- loss_config is None: Users implement all loss calculations in their own functions. The weights of each loss term are hard-coded.
- loss_config is a dict of scalars or strings: Users implement all loss calculations and use the passed loss_config to control the weight or behavior of the loss calculation. Users unpack and use each field in this dict by themselves. For example:
    loss_config = dict(gp_norm_mode='HWC', gp_loss_weight=10)
- loss_config is a dict of dicts: Each field in loss_config is used to build a corresponding loss module, and the loss calculation functions predefined by BaseGAN are used to calculate the loss.
Example
    loss_config = dict(
        # BaseGAN pre-defined fields
        gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
        disc_auxiliary_loss=dict(
            type='R1GradientPenalty',
            loss_weight=10. / 2.,
            interval=2,
            norm_mode='HWC',
            data_info=dict(
                real_data='real_imgs',
                discriminator='disc')),
        gen_auxiliary_loss=dict(
            type='GeneratorPathRegularizer',
            loss_weight=2,
            pl_batch_shrink=2,
            interval=g_reg_interval,
            data_info=dict(
                generator='gen',
                num_batches='batch_size')),
        # user-defined field for loss weights or loss calculation
        my_loss_2=dict(weight=2, norm_mode='L1'),
        my_loss_3=2,
        my_loss_4_norm_type='L2')
- Parameters
loss_config (Optional[Dict], optional) – Loss config used to build loss modules or define the loss weights. Defaults to None.
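The example above mixes BaseGAN pre-defined fields with user-defined ones. The hypothetical helper below sorts the fields of a loss_config into these two groups, assuming that the pre-defined field names are exactly the three listed in the example (gan_loss, disc_auxiliary_loss, gen_auxiliary_loss). It only classifies; it does not build any loss module.

```python
from typing import Any, Dict, Optional

# Assumption: the pre-defined field names are those shown in the example above.
PREDEFINED_FIELDS = ('gan_loss', 'disc_auxiliary_loss', 'gen_auxiliary_loss')


def classify_loss_config(loss_config: Optional[Dict[str, Any]]) -> Dict[str, str]:
    """Hypothetical helper: tag each loss_config field by who consumes it."""
    if loss_config is None:
        # No config: all loss weights are hard-coded in the user's functions.
        return {}
    kinds = {}
    for key in loss_config:
        if key in PREDEFINED_FIELDS:
            kinds[key] = 'predefined'    # built into a loss module by BaseGAN
        else:
            kinds[key] = 'user-defined'  # unpacked by the user's own code
    return kinds


cfg = dict(
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    my_loss_3=2,
    my_loss_4_norm_type='L2')
print(classify_loss_config(cfg))
```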
- noise_fn(noise: mmagic.utils.typing.NoiseVar = None, num_batches: int = 1)
Sampling function for noise. There are three scenarios in this function:
If noise is a callable function, sample num_batches of noise with passed noise.
If noise is None, sample num_batches of noise from a Gaussian distribution.
If noise is a torch.Tensor, directly return noise.
- Parameters
noise (Union[Tensor, Callable, List[int], None]) – You can directly give a batch of noise through a torch.Tensor or offer a callable function to sample a batch of noise data. Otherwise, None indicates that the default noise sampler is used. Defaults to None.
num_batches (int, optional) – The number of noise samples (batch size) to draw. If noise is a Tensor, this is ignored. Defaults to 1.
- Returns
Sampled noise tensor.
- Return type
Tensor
- _init_ema_model(ema_config: dict)
Initialize an EMA model corresponding to the given ema_config. If ema_config is an empty dict or None, the EMA model will not be initialized.
- Parameters
ema_config (dict) – Config to initialize the EMA model.
- _get_valid_model(batch_inputs: mmagic.utils.typing.ForwardInputs) → str
Try to get the valid forward model from inputs.
If forward model is defined in batch_inputs, it will be used as forward model.
If the forward model is not defined in batch_inputs, ‘ema’ is returned if with_ema_gen is True. Otherwise, ‘orig’ is returned.
- Parameters
batch_inputs (ForwardInputs) – Inputs passed to forward().
- Returns
Forward model to generate images (‘orig’, ‘ema’ or ‘ema/orig’).
- Return type
str
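A sketch of the fallback logic just described. The dict key sample_model used to look up the requested forward model is an assumption (the text only says the forward model may be defined in batch_inputs), non-dict inputs are treated as carrying no choice, and raising when an EMA output is requested without an EMA generator is also this sketch's choice.

```python
from typing import Any

VALID_MODELS = ('orig', 'ema', 'ema/orig')


def get_valid_model(batch_inputs: Any, with_ema_gen: bool) -> str:
    """Pick the generator used for sampling: 'orig', 'ema' or 'ema/orig'.

    Assumption: the requested model is stored under the key 'sample_model'.
    """
    if isinstance(batch_inputs, dict):
        model = batch_inputs.get('sample_model')
    else:
        model = None  # e.g. a bare tensor carries no model choice
    if model is None:
        # Fall back to the EMA generator when it exists, else the original one.
        model = 'ema' if with_ema_gen else 'orig'
    if model not in VALID_MODELS:
        raise ValueError(f'unsupported forward model {model!r}')
    if 'ema' in model and not with_ema_gen:
        raise ValueError(f'{model!r} requested but the GAN has no EMA generator')
    return model


print(get_valid_model({}, with_ema_gen=True))  # ema
```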
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → mmagic.utils.typing.SampleList
Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – mode is not used in BaseGAN. Defaults to None.
- Returns
A list of DataSample containing generated results.
- Return type
SampleList
- val_step(data: dict) → mmagic.utils.typing.SampleList
Gets the generated images of the given data.
Calls self.data_preprocessor(data) and self(inputs, data_sample, mode=None) in order, and returns the generated results, which will be passed to the evaluator.
- Parameters
data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.
- Returns
Generated image or image dict.
- Return type
SampleList
- test_step(data: dict) → mmagic.utils.typing.SampleList
Gets the generated images of the given data. Same as val_step().
- Parameters
data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.
- Returns
Generated image or image dict.
- Return type
List[DataSample]
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train the GAN model. In the training of GAN models, the generator and the discriminator are updated alternately. In MMagic’s design, self.train_step is called with data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().
- Parameters
data (dict) – Data sampled from dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance contains OptimWrapper of generator and discriminator.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
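To illustrate the alternating schedule, the simulation below lists which network is updated at each iteration. It assumes the generator is updated generator_steps times whenever (iter + 1) is divisible by discriminator_steps, and it ignores gradient accumulation; should_gen_update() remains the authority on the actual rule.

```python
from typing import List, Tuple


def gan_update_schedule(num_iters: int,
                        discriminator_steps: int = 1,
                        generator_steps: int = 1) -> List[Tuple[int, str]]:
    """Simulate which networks ('D' or 'G') are updated at each iteration.

    Assumption: D is updated on every train_step call; G is updated
    generator_steps times when (iter + 1) % discriminator_steps == 0.
    """
    events = []
    for it in range(num_iters):
        events.append((it, 'D'))  # discriminator always sees real data
        if (it + 1) % discriminator_steps == 0:
            events.extend((it, 'G') for _ in range(generator_steps))
    return events


print(gan_update_schedule(4, discriminator_steps=2))
# [(0, 'D'), (1, 'D'), (1, 'G'), (2, 'D'), (3, 'D'), (3, 'G')]
```

With discriminator_steps=2, the discriminator receives two full updates for every generator update, which matches the parameter description above.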
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for the generator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for discriminator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- class mmagic.models.BaseMattor(data_preprocessor: Union[dict, mmengine.config.Config], backbone: dict, init_cfg: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None)
Bases:
mmengine.model.BaseModel
Base class for trimap-based matting models.
A matting model must contain a backbone which produces pred_alpha, a dense prediction with the same height and width as the input image. In some cases (such as DIM), the model has a refiner which refines the prediction of the backbone.
Subclasses should overwrite the following functions:
_forward_train(), to return a loss
_forward_test(), to return a prediction
_forward(), to return raw tensors
For testing, this base class provides functions to resize inputs and post-process pred_alphas to get predictions.
- Parameters
backbone (dict) – Config of backbone.
data_preprocessor (dict) – Config of data_preprocessor. See MattorPreprocessor for details.
init_cfg (dict, optional) – Initialization config dict.
train_cfg (dict) – Config of training. Customized by subclasses. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.
test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.
- resize_inputs(batch_inputs: torch.Tensor) → torch.Tensor
Pad or interpolate images and trimaps to a multiple of the given factor.
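For the padding branch of resize_inputs, the amount of padding follows from rounding each side up to the next multiple of the factor. A small sketch (the name pad_amounts is hypothetical, and the interpolation branch is not shown):

```python
from typing import Tuple


def pad_amounts(height: int, width: int, factor: int) -> Tuple[int, int]:
    """Return (pad_h, pad_w) so that both sides become multiples of factor."""
    # -(-a // b) is integer ceiling division, avoiding float rounding issues.
    target_h = -(-height // factor) * factor
    target_w = -(-width // factor) * factor
    return target_h - height, target_w - width


print(pad_amounts(500, 333, 32))  # (12, 19)
```

restore_size then has to undo exactly this padding, which is why the original shape is kept in the data sample's meta data.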
- restore_size(pred_alpha: torch.Tensor, data_sample: mmagic.structures.DataSample) → torch.Tensor
Restore the predicted alpha to the original shape.
The shape of the predicted alpha may not be the same as the shape of original input image. This function restores the shape of the predicted alpha.
- Parameters
pred_alpha (torch.Tensor) – A single predicted alpha of shape (1, H, W).
data_sample (DataSample) – Data sample containing original shape as meta data.
- Returns
The reshaped predicted alpha.
- Return type
torch.Tensor
- postprocess(batch_pred_alpha: torch.Tensor, data_samples: mmagic.structures.DataSample) → List[mmagic.structures.DataSample]
Post-process alpha predictions.
- This function contains the following steps:
Restore padding or interpolation
Mask alpha prediction with trimap
Clamp alpha prediction to 0-1
Convert alpha prediction to uint8
Pack alpha prediction into DataSample
Currently only batch_size 1 is actually supported.
- Parameters
batch_pred_alpha (torch.Tensor) – A batch of predicted alpha of shape (N, 1, H, W).
data_samples (List[DataSample]) – List of data samples.
- Returns
- A list of predictions.
Each data sample contains a pred_alpha, which is a torch.Tensor with dtype=uint8, device=cuda:0
- Return type
List[DataSample]
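The masking, clamping and uint8 conversion steps above can be sketched on flat lists instead of tensors. The trimap convention used here (0 for known background, 255 for known foreground, any other value for unknown) is an assumption, as is rounding before the conversion to the 0-255 range.

```python
from typing import List


def postprocess_alpha(pred_alpha: List[float], trimap: List[int]) -> List[int]:
    """Sketch: mask with trimap, clamp to [0, 1], convert to 0-255 integers.

    Assumptions: flat lists stand in for (H, W) tensors; trimap uses
    0 = background, 255 = foreground, anything else = unknown.
    """
    out = []
    for a, t in zip(pred_alpha, trimap):
        if t == 0:
            a = 0.0  # known background
        elif t == 255:
            a = 1.0  # known foreground
        a = min(max(a, 0.0), 1.0)  # clamp to [0, 1]
        out.append(round(a * 255))  # scale to the uint8 range
    return out


print(postprocess_alpha([0.4, 0.4, 1.3, -0.2, 0.2], [0, 255, 128, 128, 128]))
# [0, 255, 255, 0, 51]
```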
- forward(inputs: torch.Tensor, data_samples: DataSamples = None, mode: str = 'tensor') → List[mmagic.structures.DataSample]
General forward function.
- Parameters
inputs (torch.Tensor) – A batch of inputs, with image and trimap concatenated along the channel dimension.
data_samples (List[DataSample], optional) – A list of data samples, containing the ground-truth alpha / foreground / background used to compute the loss, and other meta information.
mode (str) – mode should be one of loss, predict and tensor. Default: ‘tensor’.
loss: Called by train_step and returns a loss dict used for logging.
predict: Called by val_step and test_step and returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom use to get Tensor type results.
- Returns
Sequence of predictions packed into DataElement
- Return type
List[DataElement]
- convert_to_datasample(predictions: List[mmagic.structures.DataSample], data_samples: mmagic.structures.DataSample) → List[mmagic.structures.DataSample]
Add predictions to data samples.
- Parameters
predictions (List[DataSample]) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
- Returns
Modified data samples.
- Return type
List[DataSample]
- class mmagic.models.BaseTranslationModel(generator, discriminator, default_domain: str, reachable_domains: List[str], related_domains: List[str], data_preprocessor, discriminator_steps: int = 1, disc_init_steps: int = 0, real_img_key: str = 'real_img', loss_config: Optional[dict] = None)
Bases:
mmengine.model.BaseModel
Base Translation Model.
Translation models can transfer images from one domain to another. Domain information such as default_domain and reachable_domains is needed to initialize the class. We also provide query functions such as is_domain_reachable and get_other_domains.
You can get a specific generator based on the domain, and by specifying target_domain in the forward function, you can decide the domain of generated images. Considering the differences among image translation models, we only provide the external interfaces mentioned above. When you implement image translation with a specific method, you can inherit both BaseTranslationModel and the method (e.g. BaseGAN) and implement the abstract methods.
- Parameters
default_domain (str) – Default output domain.
reachable_domains (list[str]) – Domains that can be generated by the model.
related_domains (list[str]) – Domains involved in training and testing. reachable_domains must be contained in related_domains. However, related_domains may contain source domains that are used to retrieve source images from data_batch but not in reachable_domains.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
disc_init_steps (int) – The number of initial steps used only to train discriminators. Defaults to 0.
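A hypothetical sketch of the domain bookkeeping described above, including the rule that reachable_domains must be contained in related_domains. The method names follow is_domain_reachable and get_other_domains from the description, but their exact signatures and the order of the returned list are assumptions.

```python
from typing import List


class DomainInfoSketch:
    """Hypothetical holder for the domain information of a translation model."""

    def __init__(self, default_domain: str, reachable_domains: List[str],
                 related_domains: List[str]):
        # reachable_domains must be contained in related_domains.
        missing = set(reachable_domains) - set(related_domains)
        if missing:
            raise ValueError(
                f'reachable domains {sorted(missing)} are not in related_domains')
        self.default_domain = default_domain
        self.reachable_domains = list(reachable_domains)
        self.related_domains = list(related_domains)

    def is_domain_reachable(self, domain: str) -> bool:
        return domain in self.reachable_domains

    def get_other_domains(self, domain: str) -> List[str]:
        # Assumption: "other" means the remaining related domains, in order.
        return [d for d in self.related_domains if d != domain]


info = DomainInfoSketch('photo', ['photo', 'painting'],
                        ['photo', 'painting', 'sketch'])
print(info.get_other_domains('photo'))  # ['painting', 'sketch']
```

Here 'sketch' plays the role of a source-only domain: it is related (used to fetch source images) but not reachable.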
- init_weights()
Initialize weights for the module dict.
- Parameters
pretrained (str, optional) – Path for pretrained weights. If given None, pretrained weights will not be loaded. Default: None.
- get_module(module)
Get nn.ModuleDict to fit the MMDistributedDataParallel interface.
- Parameters
module (MMDistributedDataParallel | nn.ModuleDict) – The input module that needs processing.
- Returns
The ModuleDict of multiple networks.
- Return type
nn.ModuleDict
- forward(img, test_mode=False, **kwargs)
Forward function.
- Parameters
img (tensor) – Input image tensor.
test_mode (bool) – Whether in test mode or not. Default: False.
kwargs (dict) – Other arguments.
- forward_train(img, target_domain, **kwargs)
Forward function for training.
- Parameters
img (tensor) – Input image tensor.
target_domain (str) – Target domain of output image.
kwargs (dict) – Other arguments.
- Returns
Forward results.
- Return type
dict
- class mmagic.models.BasicInterpolator(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, required_frames: int = 2, step_frames: int = 1, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases:
mmagic.models.base_models.base_edit_model.BaseEditModel
Basic model for video interpolation.
It must contain a generator that takes frames as inputs and outputs an interpolated frame. It also has a pixel-wise loss for training.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for pixel-wise loss.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
required_frames (int) – Required frames in each process. Default: 2
step_frames (int) – Step size of video frame interpolation. Default: 1
init_cfg (dict, optional) – The weight initialized config for BaseModule.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
- init_cfg
Initialization config dict.
- Type
dict, optional
- data_preprocessor
Used for pre-processing data sampled by dataloader to the format accepted by forward().
- Type
BaseDataPreprocessor
- split_frames(input_tensors: torch.Tensor) → torch.Tensor
Split input tensors for inference.
- Parameters
input_tensors (Tensor) – Tensor of input frames with shape [1, t, c, h, w]
- Returns
Split tensor with shape [t-1, 2, c, h, w]
- Return type
Tensor
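The reshape from [1, t, c, h, w] to [t-1, 2, c, h, w] pairs each frame with its successor. The sketch below shows the same pairing on a plain list of frames; the list representation and the name split_frames_list are assumptions for illustration.

```python
from typing import List, Sequence, TypeVar

T = TypeVar('T')


def split_frames_list(frames: Sequence[T]) -> List[List[T]]:
    """Pair each frame with its successor: t frames -> t - 1 pairs."""
    return [[frames[i], frames[i + 1]] for i in range(len(frames) - 1)]


print(split_frames_list(['in1', 'in2', 'in3', 'in4']))
# [['in1', 'in2'], ['in2', 'in3'], ['in3', 'in4']]
```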
- static merge_frames(input_tensors: torch.Tensor, output_tensors: torch.Tensor) → list
Merge input frames and output frames.
Interpolate a frame between the given two frames.
- Merged from
[[in1, in2], [in2, in3], [in3, in4], …] and [[out1], [out2], [out3], …]
- to
[in1, out1, in2, out2, in3, out3, in4, …]
- Parameters
input_tensors (Tensor) – The input frames with shape [n, 2, c, h, w]
output_tensors (Tensor) – The output frames with shape [n, 1, c, h, w].
- Returns
The final frames.
- Return type
list[np.array]
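The interleaving shown above can be sketched with plain lists: each input pair contributes its first frame, each output contributes the interpolated frame, and the last input frame closes the sequence. Placeholder strings stand in for the frame tensors, and the name merge_frames_list is hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar('T')


def merge_frames_list(input_pairs: Sequence[Sequence[T]],
                      outputs: Sequence[Sequence[T]]) -> List[T]:
    """Interleave input frames with interpolated output frames."""
    merged: List[T] = []
    for (first, _), (out,) in zip(input_pairs, outputs):
        merged.extend([first, out])
    merged.append(input_pairs[-1][1])  # the last input frame closes the sequence
    return merged


pairs = [['in1', 'in2'], ['in2', 'in3'], ['in3', 'in4']]
outs = [['out1'], ['out2'], ['out3']]
print(merge_frames_list(pairs, outs))
# ['in1', 'out1', 'in2', 'out2', 'in3', 'out3', 'in4']
```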
- class mmagic.models.ExponentialMovingAverage(model: torch.nn.Module, momentum: float = 0.0002, interval: int = 1, device: Optional[torch.device] = None, update_buffers: bool = False)
Bases:
mmengine.model.BaseAveragedModel
Implements the exponential moving average (EMA) of the model.
All parameters are updated by the formula below:
\[X^{ema}_{t+1} = (1 - \text{momentum}) \cdot X^{ema}_{t} + \text{momentum} \cdot X_t\]
- Parameters
model (nn.Module) – The model to be averaged.
momentum (float) – The momentum used for updating EMA parameters. Defaults to 0.0002. EMA parameters are updated with the formula \(averaged\_param = (1-momentum) * averaged\_param + momentum * source\_param\).
interval (int) – Interval between two updates. Defaults to 1.
device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.
update_buffers (bool) – If True, it will compute running averages for both the parameters and the buffers of the model. Defaults to False.
- avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor, steps: int) → None
Compute the moving average of the parameters using exponential moving average.
- Parameters
averaged_param (Tensor) – The averaged parameters.
source_param (Tensor) – The source parameters.
steps (int) – The number of times the parameters have been updated.
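A plain-Python sketch of the update rule, with dicts of floats standing in for parameter tensors. The interval handling and the choice to copy the source parameters on the very first update are assumptions of the sketch, not taken from the text above; the class name EMASketch is hypothetical.

```python
from typing import Dict

Params = Dict[str, float]


def ema_update(averaged: Params, source: Params, momentum: float) -> Params:
    """averaged = (1 - momentum) * averaged + momentum * source, per parameter."""
    return {name: (1 - momentum) * averaged[name] + momentum * source[name]
            for name in averaged}


class EMASketch:
    """Hypothetical EMA holder; dicts of floats stand in for parameter tensors."""

    def __init__(self, params: Params, momentum: float = 0.0002,
                 interval: int = 1):
        self.averaged = dict(params)
        self.momentum = momentum
        self.interval = interval
        self.steps = 0

    def update(self, params: Params) -> None:
        if self.steps == 0:
            # Assumption: the first update copies the source parameters.
            self.averaged = dict(params)
        elif self.steps % self.interval == 0:
            # Otherwise average only every `interval` updates.
            self.averaged = ema_update(self.averaged, params, self.momentum)
        self.steps += 1


ema = EMASketch({'w': 0.0}, momentum=0.5)
ema.update({'w': 1.0})
ema.update({'w': 3.0})
print(ema.averaged)  # {'w': 2.0}
```

With the default momentum of 0.0002, each update moves the average only 0.02% of the way toward the current weights, so the EMA copy changes very slowly.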
- _load_from_state_dict(state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: list, unexpected_keys: list, error_msgs: List[str]) → None
Overrides nn.Module._load_from_state_dict to support loading a state_dict without wrapping the ema module with BaseAveragedModel.
In OpenMMLab 1.0, the model does not wrap the ema submodule with BaseAveragedModel, so the ema weight keys in state_dict lack the module prefix. Therefore, BaseAveragedModel needs to automatically add the module prefix if the corresponding key in state_dict misses it.
- Parameters
state_dict (dict) – A dict containing parameters and persistent buffers.
prefix (str) – The prefix for parameters and buffers used in this module.
local_metadata (dict) – A dict containing the metadata for this module.
strict (bool) – Whether to strictly enforce that the keys in state_dict with prefix match the names of parameters and buffers in this module.
missing_keys (List[str]) – If strict=True, add missing keys to this list.
unexpected_keys (List[str]) – If strict=True, add unexpected keys to this list.
error_msgs (List[str]) – Error messages should be added to this list, and will be reported together in load_state_dict().
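The key fix-up can be sketched as a pure function over the state_dict keys. The function name and the exact prefix handling are assumptions for illustration; values are passed through untouched.

```python
from typing import Dict


def add_module_prefix(state_dict: Dict[str, object],
                      prefix: str) -> Dict[str, object]:
    """Insert 'module.' after prefix for keys under prefix that lack it."""
    fixed = {}
    for key, value in state_dict.items():
        if key.startswith(prefix) and not key.startswith(prefix + 'module.'):
            key = prefix + 'module.' + key[len(prefix):]
        fixed[key] = value
    return fixed


sd = {'generator_ema.conv.weight': 1, 'generator_ema.module.conv.bias': 2}
print(add_module_prefix(sd, 'generator_ema.'))
# {'generator_ema.module.conv.weight': 1, 'generator_ema.module.conv.bias': 2}
```

Keys that already carry the prefix are left alone, so checkpoints saved in either layout load the same way.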
- class mmagic.models.DataPreprocessor(mean: Union[Sequence[Union[float, int]], float, int] = 127.5, std: Union[Sequence[Union[float, int]], float, int] = 127.5, pad_size_divisor: int = 1, pad_value: Union[float, int] = 0, pad_mode: str = 'constant', non_image_keys: Optional[Tuple[str, List[str]]] = None, non_concentate_keys: Optional[Tuple[str, List[str]]] = None, output_channel_order: Optional[str] = None, data_keys: Union[List[str], str] = 'gt_img', input_view: Optional[tuple] = None, output_view: Optional[tuple] = None, stack_data_sample=True)
Bases:
mmengine.model.ImgDataPreprocessor
Image pre-processor for generative models. This class provides normalization and BGR-to-RGB conversion for image tensor inputs. The input to this class should be a dict whose keys are inputs and data_samples.
Besides processing tensor inputs, this class supports dicts as inputs:
- If the value is a Tensor and the corresponding key is not contained in _NON_IMAGE_KEYS, it will be processed as an image tensor.
- If the value is a Tensor and the corresponding key belongs to _NON_IMAGE_KEYS, it will remain unchanged.
- If the value is a string or integer, it will remain unchanged.
- Parameters
mean (Sequence[float or int], float or int, optional) – The pixel mean of image channels. Note that normalization is performed after channel order conversion. If it is set to None, images will not be normalized. Defaults to 127.5.
std (Sequence[float or int], float or int, optional) – The pixel standard deviation of image channels. Note that normalization is performed after channel order conversion. If it is set to None, images will not be normalized. Defaults to 127.5.
pad_size_divisor (int) – The size of the padded image should be divisible by pad_size_divisor. Defaults to 1.
pad_value (float or int) – The padded pixel value. Defaults to 0.
pad_mode (str) – Padding mode for torch.nn.functional.pad. Defaults to ‘constant’.
non_image_keys (List[str] or str) – Keys for fields that do not need to be processed (padding, channel conversion and normalization) as images. If not passed, the keys in _NON_IMAGE_KEYS will be used. This argument only works when inputs is a dict or a list of dicts. Defaults to None.
non_concentate_keys (List[str] or str) – Keys for fields that do not need to be concatenated. If not passed, the keys in _NON_CONCATENATE_KEYS will be used. This argument only works when inputs is a dict or a list of dicts. Defaults to None.
output_channel_order (str, optional) – The desired channel order of the images output by the data preprocessor. This is also the desired input channel order of the model (and most likely the output order of the model). If not passed, no channel order conversion will be performed. Defaults to None.
data_keys (List[str] or str) – Keys to preprocess in data samples. Defaults to ‘gt_img’.
input_view (tuple, optional) – The view of the input tensor. This argument may be deleted in the future. Defaults to None.
output_view (tuple, optional) – The view of the output tensor. This argument may be deleted in the future. Defaults to None.
stack_data_sample (bool) – Whether to stack a list of data samples into one data sample. Only supported when the input data samples are DataSample instances. Defaults to True.
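Because normalization runs after channel order conversion, mean and std must be given in the output channel order. The following is a minimal sketch in plain PyTorch, not the mmagic implementation; the helper name convert_then_normalize is hypothetical, and it assumes a BGR input with an RGB output_channel_order.

```python
import torch


def convert_then_normalize(img, mean, std, to_rgb=True):
    """Hypothetical helper: BGR -> RGB on the channel dim, then normalize.

    `img` has shape (N, 3, H, W). `mean` and `std` are given in the
    converted (RGB) order, since normalization runs after conversion.
    """
    if to_rgb:
        img = torch.flip(img, dims=[-3])  # reverse channels: BGR -> RGB
    mean_t = torch.tensor(mean, dtype=img.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype).view(1, -1, 1, 1)
    return (img - mean_t) / std_t


# A 1x1 image whose BGR pixel is (10, 20, 30), i.e. RGB (30, 20, 10).
bgr = torch.tensor([10.0, 20.0, 30.0]).view(1, 3, 1, 1)
out = convert_then_normalize(bgr, mean=[30.0, 20.0, 10.0], std=[2.0, 2.0, 2.0])
print(out.flatten().tolist())  # [0.0, 0.0, 0.0]
```

The mean here is listed in RGB order; passing it in BGR order would leave a non-zero result, which is the mistake the note about conversion order guards against.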
- _NON_IMAGE_KEYS = ['noise']¶
- _NON_CONCATENATE_KEYS = ['num_batches', 'mode', 'sample_kwargs', 'eq_cfg']¶
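The routing rules for dict inputs can be summarized with a small sketch. This is a simplification rather than the library's code: route_inputs is a hypothetical name, and the returned labels only describe what would happen to each entry.

```python
import torch

_NON_IMAGE_KEYS = ['noise']


def route_inputs(inputs, non_image_keys=None):
    """Hypothetical helper: label each entry of a dict input.

    Tensors whose key is not a non-image key are treated as images
    (padding, channel conversion, normalization); everything else,
    including non-image tensors, strings and integers, passes through.
    """
    if non_image_keys is None:
        non_image_keys = _NON_IMAGE_KEYS
    elif isinstance(non_image_keys, str):
        non_image_keys = [non_image_keys]  # a single key is also accepted
    routes = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor) and key not in non_image_keys:
            routes[key] = 'image'
        else:
            routes[key] = 'unchanged'
    return routes


batch = {
    'img': torch.zeros(2, 3, 8, 8),
    'noise': torch.randn(2, 128),
    'mode': 'ema',
    'num_batches': 2,
}
print(route_inputs(batch))
# {'img': 'image', 'noise': 'unchanged', 'mode': 'unchanged', 'num_batches': 'unchanged'}
```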
- cast_data(data: CastData) CastData [source]¶
Copy data to the target device.
- Parameters
data (dict) – Data returned by DataLoader.
- Returns
Inputs and data samples on the target device.
- Return type
CollatedResult
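Conceptually, cast_data walks the collated batch and copies every tensor to the preprocessor's device. A simplified recursive version is sketched below; move_to_device is a hypothetical name, and because it only handles plain tensors, dicts, lists and tuples (not namedtuples or DataSample objects), it may differ from the real method.

```python
import torch


def move_to_device(data, device):
    """Hypothetical sketch: recursively copy tensors to `device`."""
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (plain list or tuple only).
        return type(data)(move_to_device(v, device) for v in data)
    return data  # strings, numbers, None, ... are returned unchanged


batch = {'inputs': {'img': torch.zeros(2, 3, 8, 8)}, 'data_samples': ['a', 'b']}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
moved = move_to_device(batch, device)
print(moved['inputs']['img'].device.type)
```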
- _parse_channel_order(key: str, inputs: torch.Tensor, data_sample: Optional[mmagic.structures.DataSample] = None) str [source]¶
- _parse_batch_channel_order(key: str, inputs: Sequence, data_samples: Optional[Sequence[mmagic.structures.DataSample]]) str [source]¶
Parse channel order of inputs in batch.
- _update_metainfo(padding_info: torch.Tensor, channel_order_info: Optional[dict] = None, data_samples: Optional[mmagic.utils.typing.SampleList] = None) mmagic.utils.typing.SampleList [source]¶
Update padding_info and channel_order in the metainfo of a batch of `data_samples`. For channel order, we assume that the same field shares the same channel order across all data samples, so channel_order is passed as a dict whose keys and values are field names and the corresponding channel orders. For padding info, we assume that it is the same for all fields of a sample but may vary between samples, so padding_info is passed as a tensor of shape (B, 1, 1).
- Parameters
padding_info (Tensor) – The padding info of each sample. Shape like (B, 1, 1).
channel_order_info (dict, optional) – The channel order of the target fields. Keys and values are field names and the corresponding channel orders, respectively.
data_samples (List[DataSample], optional) – The data samples to be updated. If not passed, a list of empty data samples will be initialized. Defaults to None.
- Returns
The updated data samples.
- Return type
List[DataSample]
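The two conventions described for _update_metainfo (one channel order per field for the whole batch, one padding entry per sample) can be illustrated with plain dicts standing in for DataSample metainfo. The function name and the metainfo keys (padding_size, <field>_output_channel_order) are assumptions for illustration only.

```python
import torch


def update_metainfo(padding_info, channel_order_info=None, metainfos=None):
    """Hypothetical sketch; each dict plays the role of one sample's metainfo.

    padding_info: tensor of shape (B, 1, 1), one entry per sample.
    channel_order_info: {field_name: channel_order}, shared by all samples.
    """
    batch_size = padding_info.shape[0]
    if metainfos is None:
        # Mirrors "initialize a list of empty data samples".
        metainfos = [{} for _ in range(batch_size)]
    for idx, meta in enumerate(metainfos):
        meta['padding_size'] = padding_info[idx]  # may differ per sample
        for field, order in (channel_order_info or {}).items():
            meta[f'{field}_output_channel_order'] = order  # same for all
    return metainfos


padding_info = torch.tensor([0.0, 2.0]).view(2, 1, 1)
metas = update_metainfo(padding_info, {'img': 'RGB'})
print([m['padding_size'].item() for m in metas])  # [0.0, 2.0]
```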
- _do_conversion(inputs: torch.Tensor, inputs_order: str = 'BGR', target_order: Optional[str] = None) Tuple[torch.Tensor, str] [source]¶
Conduct channel order conversion for a batch of inputs, and return the converted inputs and order after conversion.
- inputs_order:
BGR / RGB: Convert to the target order.
SINGLE: No conversion is performed.
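A minimal sketch of this conversion rule, assuming the channel dimension is always dim -3 (which holds for the (C, H, W), (N, C, H, W) and (N, t, C, H, W) layouts used by this class). do_conversion is a hypothetical stand-in for _do_conversion, not its source.

```python
import torch


def do_conversion(inputs, inputs_order='BGR', target_order=None):
    """Hypothetical sketch of channel-order conversion for a batch."""
    if inputs_order.upper() in ('BGR', 'RGB'):
        if target_order is not None and target_order.upper() != inputs_order.upper():
            # BGR <-> RGB is a reversal of the three channels on dim -3.
            return torch.flip(inputs, dims=[-3]), target_order.upper()
        return inputs, inputs_order.upper()
    # 'SINGLE' (one-channel) or other orders: leave the tensor untouched.
    return inputs, inputs_order


x = torch.arange(3.0).view(1, 3, 1, 1)
y, order = do_conversion(x, 'BGR', 'RGB')
print(y.flatten().tolist(), order)  # [2.0, 1.0, 0.0] RGB
```

Flipping dim -3 rather than dim 1 lets the same code handle single images, batches and video clips.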
- _preprocess_image_tensor(inputs: torch.Tensor, data_samples: Optional[mmagic.utils.typing.SampleList] = None, key: str = 'img') Tuple[torch.Tensor, mmagic.utils.typing.SampleList] [source]¶
Preprocess a batch of image tensor and update metainfo to corresponding data samples.
- Parameters
inputs (Tensor) – Image tensor with shape (C, H, W), (N, C, H, W) or (N, t, C, H, W) to preprocess.
data_samples (List[DataSample], optional) – The data samples of corresponding inputs. If not passed, a list of empty data samples will be initialized to save metainfo. Defaults to None.
key (str) – The key of image tensor in data samples. Defaults to ‘img’.
- Returns
The preprocessed image tensor and updated data samples.
- Return type
Tuple[Tensor, List[DataSample]]
- _preprocess_image_list(tensor_list: List[torch.Tensor], data_samples: Optional[mmagic.utils.typing.SampleList], key: str = 'img') Tuple[torch.Tensor, mmagic.utils.typing.SampleList] [source]¶
Preprocess a list of image tensor and update metainfo to corresponding data samples.
- Parameters
tensor_list (List[Tensor]) – Image tensor list to be preprocessed.
data_samples (List[DataSample], optional) – The data samples of corresponding inputs. If not passed, a list of empty data samples will be initialized to save metainfo. Defaults to None.
key (str) – The key of tensor list in data samples. Defaults to ‘img’.
- Returns
The preprocessed image tensor and updated data samples.
- Return type
Tuple[Tensor, List[DataSample]]
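For a list of images with different sizes, a common approach is to pad every image up to a shared size that is divisible by pad_size_divisor and then stack them. The sketch below assumes bottom/right padding with pad_mode='constant'; pad_and_stack and its (pad_h, pad_w) output are hypothetical and not the padding metainfo format mmagic stores.

```python
import math

import torch
import torch.nn.functional as F


def pad_and_stack(tensor_list, pad_size_divisor=1, pad_value=0.0):
    """Hypothetical sketch: pad (C, H, W) tensors to a shared size and stack.

    The target size is the largest H/W in the list, rounded up to a
    multiple of `pad_size_divisor`. Padding is added on the bottom/right.
    """
    max_h = max(t.shape[-2] for t in tensor_list)
    max_w = max(t.shape[-1] for t in tensor_list)
    target_h = math.ceil(max_h / pad_size_divisor) * pad_size_divisor
    target_w = math.ceil(max_w / pad_size_divisor) * pad_size_divisor
    padded, padding_sizes = [], []
    for t in tensor_list:
        pad_h = target_h - t.shape[-2]
        pad_w = target_w - t.shape[-1]
        # F.pad takes (left, right, top, bottom) for the last two dims.
        padded.append(F.pad(t, (0, pad_w, 0, pad_h), mode='constant', value=pad_value))
        padding_sizes.append((pad_h, pad_w))
    return torch.stack(padded), padding_sizes


imgs = [torch.ones(3, 5, 7), torch.ones(3, 8, 4)]
batch, pads = pad_and_stack(imgs, pad_size_divisor=4)
print(tuple(batch.shape), pads)  # (2, 3, 8, 8) [(3, 1), (0, 4)]
```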
- _preprocess_dict_inputs(batch_inputs: dict, data_samples: Optional[mmagic.utils.typing.SampleList] = None) Tuple[dict, mmagic.utils.typing.SampleList] [source]¶
Preprocess dict type inputs.
- Parameters
batch_inputs (dict) – Input dict.
data_samples (List[DataSample], optional) – The data samples of corresponding inputs. If not passed, a list of empty data samples will be initialized to save metainfo. Defaults to None.
- Returns
The preprocessed dict and updated data samples.
- Return type
Tuple[dict, List[DataSample]]
- _preprocess_data_sample(data_samples: mmagic.utils.typing.SampleList, training: bool) mmagic.structures.DataSample [source]¶
Preprocess data samples. When training is True, fields belonging to
self.data_keys
will be converted to self.output_channel_order
and then normalized by self.mean and self.std. When training is False, fields belonging to self.data_keys
will be converted to ‘BGR’ if possible, without normalization. The corresponding metainfo related to normalization and channel order conversion will be updated in the data samples as well.
- Parameters
data_samples (List[DataSample]) – A list of data samples to preprocess.
training (bool) – Whether in training mode.
- Returns
The list of processed data samples.
- Return type
list
- forward(data: dict, training: bool = False) dict [source]¶
Performs normalization, padding and channel order conversion.
- Parameters
data (dict) – Input data to process.
training (bool) – Whether in training mode. Default: False.
- Returns
Data in the same format as the model input.
- Return type
dict
- destruct(outputs: torch.Tensor, data_samples: Union[mmagic.utils.typing.SampleList, mmagic.structures.DataSample, None] = None, key: str = 'img') Union[list, torch.Tensor] [source]¶
Undo padding and normalization, and convert the channel order back to BGR if possible. If data_samples is a list, outputs will be destructed as a batch of tensors. If data_samples is a DataSample, outputs will be destructed as a single tensor.
Before feeding model outputs to the visualizer and evaluator, users should call this function on model outputs and inputs.
Use cases:
>>> # destruct model outputs.
>>> # model outputs share the same preprocess information with inputs
>>> # ('img'), therefore use 'img' as key
>>> feats = self.forward_tensor(inputs, data_samples, **kwargs)
>>> feats = self.data_preprocessor.destruct(feats, data_samples, 'img')

>>> # destruct model inputs for visualization
>>> for idx, data_sample in enumerate(data_samples):
...     destructed_input = self.data_preprocessor.destruct(
...         inputs[idx], data_sample, key='img')
...     data_sample.set_data({'input': destructed_input})
- Parameters
outputs (Tensor) – Tensor to destruct.
data_samples (Union[SampleList, DataSample], optional) – Data samples (or data sample) corresponding to outputs. Defaults to None.
key (str) – The key of field in data sample. Defaults to ‘img’.
- Returns
Destructed outputs.
- Return type
Union[list, Tensor]
- _destruct_norm_and_conversion(batch_tensor: torch.Tensor, data_samples: Union[mmagic.utils.typing.SampleList, mmagic.structures.DataSample, None], key: str) torch.Tensor [source]¶
Denormalize and revert the channel order conversion. Note that denormalization is performed first and the channel order conversion is reverted afterwards, since the mean and std used in normalization are based on the channel order after conversion.
- Parameters
batch_tensor (Tensor) – Tensor to destruct.
data_samples (Union[SampleList, DataSample], optional) – Data samples (or data sample) corresponding to outputs.
key (str) – The key of field in data sample.
- Returns
Destructed tensor.
- Return type
Tensor
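The ordering constraint can be checked with a round trip: normalize in RGB order after flipping, then undo it by de-normalizing first and flipping back second. The helper name is hypothetical, and the mean/std values are borrowed from MattorPreprocessor's signature purely as sample numbers.

```python
import torch


def destruct_norm_and_conversion(x, mean, std, convert_back=True):
    """Hypothetical inverse of 'convert channel order, then normalize'.

    `mean`/`std` are in the converted (RGB) order, so they are applied
    before the channels are flipped back to BGR.
    """
    mean_t = torch.tensor(mean, dtype=x.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=x.dtype).view(1, -1, 1, 1)
    x = x * std_t + mean_t               # 1) de-normalize, still RGB
    if convert_back:
        x = torch.flip(x, dims=[-3])     # 2) RGB -> BGR
    return x


mean, std = [123.675, 116.28, 103.53], [58.395, 57.12, 57.375]
bgr = torch.rand(2, 3, 4, 4) * 255
m = torch.tensor(mean).view(1, -1, 1, 1)
s = torch.tensor(std).view(1, -1, 1, 1)
normed = (torch.flip(bgr, dims=[-3]) - m) / s   # forward: convert, then normalize
restored = destruct_norm_and_conversion(normed, mean, std)
print(torch.allclose(restored, bgr, atol=1e-3))  # True
```

Swapping the two steps (flip first, then de-normalize) would apply the red-channel statistics to the blue channel and vice versa, so the input would not be recovered.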
- _destruct_padding(batch_tensor: torch.Tensor, data_samples: Union[mmagic.utils.typing.SampleList, mmagic.structures.DataSample, None], same_padding: bool = True) Union[list, torch.Tensor] [source]¶
Destruct padding of the input tensor.
- Parameters
batch_tensor (Tensor) – Tensor to destruct.
data_samples (Union[SampleList, DataSample], optional) – Data samples (or data sample) corresponding to outputs. If
same_padding (bool) – If True, all samples are un-padded with the padding info of the first sample, and a stacked un-padded tensor is returned. Otherwise, each sample is un-padded with the padding info saved in its corresponding data sample, and a list of un-padded tensors is returned, since the un-padded tensors may have different shapes. Defaults to True.
- Returns
Destructed outputs.
- Return type
Union[list, Tensor]
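The two same_padding behaviours can be sketched as follows, assuming bottom/right padding and a hypothetical per-sample (pad_h, pad_w) list in place of the padding info the real method reads from data samples.

```python
import torch


def destruct_padding(batch, padding_sizes, same_padding=True):
    """Hypothetical sketch: crop bottom/right padding from a (B, C, H, W) batch.

    padding_sizes: list of (pad_h, pad_w) per sample (an assumed format).
    """
    def crop(t, pad_h, pad_w):
        h, w = t.shape[-2], t.shape[-1]
        return t[..., :h - pad_h, :w - pad_w]

    if same_padding:
        pad_h, pad_w = padding_sizes[0]
        return crop(batch, pad_h, pad_w)  # one stacked tensor
    # Per-sample crop: shapes may differ, so return a list.
    return [crop(t, ph, pw) for t, (ph, pw) in zip(batch, padding_sizes)]


batch = torch.zeros(2, 3, 8, 8)
pads = [(3, 1), (0, 4)]
print(tuple(destruct_padding(batch, pads).shape))  # (2, 3, 5, 7)
print([tuple(t.shape) for t in destruct_padding(batch, pads, same_padding=False)])
# [(3, 5, 7), (3, 8, 4)]
```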
- class mmagic.models.MattorPreprocessor(mean: MEAN_STD_TYPE = [123.675, 116.28, 103.53], std: MEAN_STD_TYPE = [58.395, 57.12, 57.375], output_channel_order: str = 'RGB', proc_trimap: str = 'rescale_to_zero_one', stack_data_sample=True)[source]¶
Bases:
mmagic.models.data_preprocessors.data_preprocessor.DataPreprocessor
DataPreprocessor for matting models.
See base class
DataPreprocessor
for detailed information.
Workflow as follows:
Collate and move data to the target device.
Convert inputs from BGR to RGB if the shape of input is (3, H, W).
Normalize image with defined std and mean.
Stack inputs to batch_inputs.
- Parameters
mean (Sequence[float or int], float or int, optional) – The pixel mean of image channels. Note that normalization is performed after channel order conversion. If it is not specified, images will not be normalized. Defaults to [123.675, 116.28, 103.53].
std (Sequence[float or int], float or int, optional) – The pixel standard deviation of image channels. Note that normalization is performed after channel order conversion. If it is not specified, images will not be normalized. Defaults to [58.395, 57.12, 57.375].
proc_trimap (str) – Method used to process trimap tensors. Default: ‘rescale_to_zero_one’. Available options are rescale_to_zero_one and as-is.
stack_data_sample (bool) – Whether to stack a list of data samples into one data sample. Only supported when the input data samples are DataSample instances. Defaults to True.
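The two proc_trimap options can be sketched as below. This assumes trimaps arrive as uint8 tensors with values in {0, 128, 255} and that rescale_to_zero_one divides by 255; the helper name and the exact mode strings are assumptions rather than the library's code.

```python
import torch


def proc_trimap(trimap, mode='rescale_to_zero_one'):
    """Hypothetical sketch of the two trimap-processing modes."""
    if mode == 'rescale_to_zero_one':
        return trimap.float() / 255.0   # map [0, 255] to [0, 1]
    if mode == 'as-is':
        return trimap.float()           # keep the original value range
    raise ValueError(f'Unsupported proc_trimap mode: {mode!r}')


trimap = torch.tensor([0, 128, 255], dtype=torch.uint8)
print(proc_trimap(trimap, 'as-is').tolist())  # [0.0, 128.0, 255.0]
```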
- _preprocess_data_sample(data_samples: mmagic.utils.typing.SampleList, training: bool) list [source]¶
Preprocess data samples. When training is True, fields belonging to
self.data_keys
will be converted to self.output_channel_order
and divided by 255. When training is False, fields belonging to self.data_keys
will be converted to ‘BGR’ if possible, without normalization. The corresponding metainfo related to normalization and channel order conversion will be updated in the data samples as well.
- Parameters
data_samples (List[DataSample]) – A list of data samples to preprocess.
training (bool) – Whether in training mode.
- Returns
The list of processed data samples.
- Return type
list
- forward(data: Sequence[dict], training: bool = False) Tuple[torch.Tensor, list] [source]¶
Preprocess input images, trimaps and ground truth as configured.
- Parameters
data (Sequence[dict]) – Data sampled from the dataloader.
training (bool) – Whether to enable training time augmentation. Default: False.
- Returns
Batched inputs and list of data samples.
- Return type
Tuple[torch.Tensor, list]