mmagic.models.base_models
Package Contents¶
Classes¶
ExponentialMovingAverage – Implements the exponential moving average (EMA) of the model.
RampUpEMA – Implements the exponential moving average with ramping up momentum.
BaseConditionalGAN – Base class for conditional GAN models.
BaseEditModel – Base model for image and video editing.
BaseGAN – Base class for GAN models.
BaseMattor – Base class for trimap-based matting models.
BaseTranslationModel – Base Translation Model.
BasicInterpolator – Basic model for video interpolation.
OneStageInpaintor – Standard one-stage inpaintor with commonly used losses.
TwoStageInpaintor – Standard two-stage inpaintor with commonly used losses.
- class mmagic.models.base_models.ExponentialMovingAverage(model: torch.nn.Module, momentum: float = 0.0002, interval: int = 1, device: Optional[torch.device] = None, update_buffers: bool = False)[source]¶
Bases:
mmengine.model.BaseAveragedModel
Implements the exponential moving average (EMA) of the model.
All parameters are updated by the formula as below:
\[X_{ema, t+1} = (1 - momentum) * X_{ema, t} + momentum * X_t\]
- Parameters
model (nn.Module) – The model to be averaged.
momentum (float) – The momentum used for updating ema parameter. Defaults to 0.0002. Ema’s parameter are updated with the formula \(averaged\_param = (1-momentum) * averaged\_param + momentum * source\_param\).
interval (int) – Interval between two updates. Defaults to 1.
device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.
update_buffers (bool) – If True, it will compute running averages for both the parameters and the buffers of the model. Defaults to False.
- avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor, steps: int) None ¶
Compute the moving average of the parameters using exponential moving average.
- Parameters
averaged_param (Tensor) – The averaged parameters.
source_param (Tensor) – The source parameters.
steps (int) – The number of times the parameters have been updated.
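As a concrete illustration, the update rule above can be sketched in plain Python. This is a minimal sketch of the formula only, not the mmagic implementation, which operates on torch tensors in place:

```python
def ema_update(averaged_param: float, source_param: float,
               momentum: float = 0.0002) -> float:
    """One EMA step: x_ema <- (1 - momentum) * x_ema + momentum * x."""
    return (1 - momentum) * averaged_param + momentum * source_param

# Repeated updates pull the averaged value toward the source value.
ema = 0.0
for _ in range(5):
    ema = ema_update(ema, 1.0, momentum=0.5)
```

With the default momentum of 0.0002 the averaged parameters move very slowly, which is why EMA weights are typically much smoother than the raw training weights.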
- _load_from_state_dict(state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: list, unexpected_keys: list, error_msgs: List[str]) None ¶
Overrides nn.Module._load_from_state_dict to support loading a state_dict whose ema module is not wrapped with BaseAveragedModel.
In OpenMMLab 1.0, the model did not wrap the ema submodule with BaseAveragedModel, so the ema weight keys in state_dict lack the module prefix. Therefore, BaseAveragedModel needs to add the module prefix automatically if the corresponding key in state_dict misses it.
- Parameters
state_dict (dict) – A dict containing parameters and persistent buffers.
prefix (str) – The prefix for parameters and buffers used in this module.
local_metadata (dict) – A dict containing the metadata for this module.
strict (bool) – Whether to strictly enforce that the keys in state_dict with prefix match the names of parameters and buffers in this module.
missing_keys (List[str]) – If strict=True, add missing keys to this list.
unexpected_keys (List[str]) – If strict=True, add unexpected keys to this list.
error_msgs (List[str]) – Error messages should be added to this list, and will be reported together in load_state_dict().
- sync_buffers(model: torch.nn.Module) None ¶
Copy buffer from model to averaged model.
- Parameters
model (nn.Module) – The model whose parameters will be averaged.
- sync_parameters(model: torch.nn.Module) None ¶
Copy buffer and parameters from model to averaged model.
- Parameters
model (nn.Module) – The model whose parameters will be averaged.
- class mmagic.models.base_models.RampUpEMA(model: torch.nn.Module, interval: int = 1, ema_kimg: int = 10, ema_rampup: float = 0.05, batch_size: int = 32, eps: float = 1e-08, start_iter: int = 0, device: Optional[torch.device] = None, update_buffers: bool = False)[source]¶
Bases:
mmengine.model.BaseAveragedModel
Implements the exponential moving average with ramping up momentum.
Ref: https://github.com/NVlabs/stylegan3/blob/master/training/training_loop.py # noqa
- Parameters
model (nn.Module) – The model to be averaged.
interval (int) – Interval between two updates. Defaults to 1.
ema_kimg (int, optional) – EMA kimgs. Defaults to 10.
ema_rampup (float, optional) – Ramp up rate. Defaults to 0.05.
batch_size (int, optional) – Global batch size. Defaults to 32.
eps (float, optional) – Ramp up epsilon. Defaults to 1e-8.
start_iter (int, optional) – EMA start iter. Defaults to 0.
device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.
update_buffers (bool) – If True, it will compute running averages for both the parameters and the buffers of the model. Defaults to False.
- static rampup(steps, ema_kimg=10, ema_rampup=0.05, batch_size=4, eps=1e-08)¶
Ramp up ema momentum.
- Parameters
steps (int) – Current training step.
ema_kimg (int, optional) – Half-life of the exponential moving average of generator weights. Defaults to 10.
ema_rampup (float, optional) – EMA ramp-up coefficient. If set to None, ramp-up will be disabled. Defaults to 0.05.
batch_size (int, optional) – Total batch size for one training iteration. Defaults to 4.
eps (float, optional) – Epsilon to avoid batch_size divided by zero. Defaults to 1e-8.
- Returns
Updated momentum.
- Return type
float
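The ramp-up schedule can be sketched as follows, mirroring the StyleGAN3 reference linked above. The variable names and the exact cap are assumptions drawn from that reference, not the mmagic implementation:

```python
def rampup_momentum(steps: int, ema_kimg: int = 10, ema_rampup: float = 0.05,
                    batch_size: int = 4, eps: float = 1e-8) -> float:
    """EMA decay with a half-life of ema_kimg thousand images, capped
    during early training by the ramp-up coefficient."""
    ema_nimg = ema_kimg * 1000
    if ema_rampup is not None:
        # Early in training, shrink the half-life so the EMA tracks faster.
        ema_nimg = min(ema_nimg, steps * batch_size * ema_rampup)
    # max(..., eps) avoids division by zero when steps == 0.
    return 0.5 ** (batch_size / max(ema_nimg, eps))
```

The returned momentum starts near zero and ramps up toward its asymptotic value as training progresses, so early EMA weights follow the generator closely while late ones average over a long horizon.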
- avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor, steps: int) None ¶
Compute the moving average of the parameters using exponential moving average.
- Parameters
averaged_param (Tensor) – The averaged parameters.
source_param (Tensor) – The source parameters.
steps (int) – The number of times the parameters have been updated.
- _load_from_state_dict(state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: list, unexpected_keys: list, error_msgs: List[str]) None ¶
Overrides nn.Module._load_from_state_dict to support loading a state_dict whose ema module is not wrapped with BaseAveragedModel.
In OpenMMLab 1.0, the model did not wrap the ema submodule with BaseAveragedModel, so the ema weight keys in state_dict lack the module prefix. Therefore, BaseAveragedModel needs to add the module prefix automatically if the corresponding key in state_dict misses it.
- Parameters
state_dict (dict) – A dict containing parameters and persistent buffers.
prefix (str) – The prefix for parameters and buffers used in this module.
local_metadata (dict) – A dict containing the metadata for this module.
strict (bool) – Whether to strictly enforce that the keys in state_dict with prefix match the names of parameters and buffers in this module.
missing_keys (List[str]) – If strict=True, add missing keys to this list.
unexpected_keys (List[str]) – If strict=True, add unexpected keys to this list.
error_msgs (List[str]) – Error messages should be added to this list, and will be reported together in load_state_dict().
- sync_buffers(model: torch.nn.Module) None ¶
Copy buffer from model to averaged model.
- Parameters
model (nn.Module) – The model whose parameters will be averaged.
- sync_parameters(model: torch.nn.Module) None ¶
Copy buffer and parameters from model to averaged model.
- Parameters
model (nn.Module) – The model whose parameters will be averaged.
- class mmagic.models.base_models.BaseConditionalGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)[source]¶
Bases:
mmagic.models.base_models.base_gan.BaseGAN
Base class for conditional GAN models.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.
num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.
ema_config (Optional[Dict]) – The config for the generator’s exponential moving average setting. Defaults to None.
- label_fn(label: mmagic.utils.typing.LabelVar = None, num_batches: int = 1) torch.Tensor ¶
Sampling function for label. There are three scenarios in this function:
If label is a callable function, sample num_batches of labels with passed label.
If label is None, sample num_batches of labels in range of [0, self.num_classes-1] uniformly.
If label is a torch.Tensor, check the range of the tensor is in [0, self.num_classes-1]. If all values are in valid range, directly return label.
- Parameters
label (Union[Tensor, Callable, List[int], None]) – You can directly give a batch of labels through a torch.Tensor or offer a callable function to sample a batch of label data. Otherwise, None indicates to use the default label sampler. Defaults to None.
num_batches (int, optional) – The number of labels to sample. If label is a Tensor, this will be ignored. Defaults to 1.
- Returns
Sampled label tensor.
- Return type
Tensor
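The three branches can be sketched with plain Python lists. This is a hypothetical simplification; the real method samples and returns torch.Tensor labels:

```python
import random

def label_fn(label=None, num_batches=1, num_classes=10):
    """Dispatch over the three label-sampling scenarios described above."""
    if callable(label):
        # 1) callable sampler: draw num_batches labels with it
        return [label() for _ in range(num_batches)]
    if label is None:
        # 2) default sampler: uniform over [0, num_classes - 1]
        return [random.randrange(num_classes) for _ in range(num_batches)]
    # 3) explicit labels: validate the range, then return unchanged
    assert all(0 <= lab < num_classes for lab in label), 'label out of range'
    return label
```

Note that when an explicit batch of labels is passed, num_batches is ignored, matching the documented behavior.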
- data_sample_to_label(data_sample: mmagic.structures.DataSample) Optional[torch.Tensor] ¶
Get labels from input data_sample and pack to torch.Tensor. If no label is found in the passed data_sample, None would be returned.
- Parameters
data_sample (DataSample) – Input data samples.
- Returns
Packed label tensor.
- Return type
Optional[torch.Tensor]
- static _get_valid_num_classes(num_classes: Optional[int], generator: ModelType, discriminator: Optional[ModelType]) int ¶
Try to get the value of num_classes from input, generator and discriminator and check the consistency of these values. If no conflict is found, return the num_classes.
- Parameters
num_classes (Optional[int]) – num_classes passed to BaseConditionalGAN’s initialization function.
generator (ModelType) – The config or the model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator.
- Returns
The number of classes to be generated.
- Return type
int
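The consistency check can be sketched as follows. This is a hypothetical helper, simplified from config objects down to plain Optional[int] values:

```python
from typing import Optional

def get_valid_num_classes(num_classes: Optional[int],
                          gen_num_classes: Optional[int],
                          disc_num_classes: Optional[int]) -> Optional[int]:
    """Collect all explicitly given values and require them to agree."""
    candidates = {c for c in (num_classes, gen_num_classes, disc_num_classes)
                  if c is not None}
    if len(candidates) > 1:
        raise ValueError(f'num_classes conflict: {sorted(candidates)}')
    # Return the agreed value, or None when nothing was specified.
    return candidates.pop() if candidates else None
```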
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) List[mmagic.structures.DataSample] ¶
Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.
- Returns
Generated images or image dict.
- Return type
List[DataSample]
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor] ¶
Training function for the generator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor] ¶
Training function for discriminator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- class mmagic.models.base_models.BaseEditModel(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)[source]¶
Bases:
mmengine.model.BaseModel
Base model for image and video editing.
It must contain a generator that takes frames as inputs and outputs an interpolated frame. It also has a pixel-wise loss for training.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for pixel-wise loss.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
- init_cfg¶
Initialization config dict.
- Type
dict, optional
- data_preprocessor¶
Used for pre-processing data sampled by dataloader to the format accepted by forward(). Default: None.
- Type
BaseDataPreprocessor
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) Union[torch.Tensor, List[mmagic.structures.DataSample], dict] ¶
Returns losses or predictions of the training, validation, testing, and simple inference process.
The forward method of BaseModel is an abstract method; its subclasses must implement this method.
Accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.
During the non-distributed training, validation, and testing process, forward will be called by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step directly.
During the distributed data parallel training process, MMSeparateDistributedDataParallel.train_step will first call DistributedDataParallel.forward to enable automatic gradient synchronization, and then call forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – mode should be one of loss, predict and tensor. Default: ‘tensor’.
loss: called by train_step and returns a loss dict used for logging.
predict: called by val_step and test_step and returns a list of BaseDataElement results used for computing metrics.
tensor: called by custom use to get Tensor type results.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
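The mode dispatch described above can be sketched as follows. The three helpers are dummy stand-ins so the sketch runs standalone; in the real class they are the forward_train, forward_inference and forward_tensor methods documented on this page:

```python
def forward(inputs, data_samples=None, mode='tensor'):
    """Route to the proper helper based on `mode` (sketch, not mmagic code)."""
    if mode == 'loss':
        return forward_train(inputs, data_samples)
    if mode == 'predict':
        return forward_inference(inputs, data_samples)
    if mode == 'tensor':
        return forward_tensor(inputs, data_samples)
    raise ValueError(f'Invalid mode: {mode!r}')

# Dummy helpers standing in for the real methods.
def forward_train(inputs, data_samples):
    return {'loss_pix': 0.0}          # loss dict for logging

def forward_inference(inputs, data_samples):
    return [inputs]                   # list of predictions

def forward_tensor(inputs, data_samples):
    return inputs                     # raw tensor result
```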
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) List[mmagic.structures.DataSample] ¶
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of the model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) torch.Tensor ¶
Forward tensor. Returns result of simple forward.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of simple forward.
- Return type
Tensor
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) mmagic.structures.DataSample ¶
Forward inference. Returns predictions of validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
DataSample
- forward_train(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) Dict[str, torch.Tensor] ¶
Forward training. Returns dict of losses of training.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Dict of losses.
- Return type
dict
- class mmagic.models.base_models.BaseGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)[source]¶
Bases:
mmengine.model.BaseModel
Base class for GAN models.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.
ema_config (Optional[Dict]) – The config for the generator’s exponential moving average setting. Defaults to None.
- property generator_steps: int¶
The number of times the generator is completely updated before the discriminator is updated.
- Type
int
- property discriminator_steps: int¶
The number of times the discriminator is completely updated before the generator is updated.
- Type
int
- property device: torch.device¶
Get current device of the model.
- Returns
The current device of the model.
- Return type
torch.device
- property with_ema_gen: bool¶
Whether the GAN adopts exponential moving average.
- Returns
True if this GAN model uses exponential moving average, otherwise False.
- Return type
bool
- static gather_log_vars(log_vars_list: List[Dict[str, torch.Tensor]]) Dict[str, torch.Tensor] ¶
Gather a list of log_vars.
- Parameters
log_vars_list (List[Dict[str, Tensor]]) – The list of log_vars dicts to gather.
- Returns
Dict[str, Tensor]
- _init_loss(loss_config: Optional[Dict] = None) None ¶
Initialize customized loss modules.
If loss_config is a dict, we allow several kinds of value for each field:
- loss_config is None: users implement all loss calculations in their own functions, and weights for each loss term are hard-coded.
- loss_config is a dict of scalars or strings: users implement all loss calculations and use the passed loss_config to control the weight or behavior of the loss calculation. Users unpack and use each field in this dict by themselves.
loss_config = dict(gp_norm_mode='HWC', gp_loss_weight=10)
- loss_config is a dict of dicts: each field in loss_config is used to build a corresponding loss module, and the loss calculation functions predefined by BaseGAN are used to calculate the loss.
Example
loss_config = dict(
    # BaseGAN pre-defined fields
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(
        type='R1GradientPenalty',
        loss_weight=10. / 2.,
        interval=2,
        norm_mode='HWC',
        data_info=dict(real_data='real_imgs', discriminator='disc')),
    gen_auxiliary_loss=dict(
        type='GeneratorPathRegularizer',
        loss_weight=2,
        pl_batch_shrink=2,
        interval=g_reg_interval,
        data_info=dict(generator='gen', num_batches='batch_size')),
    # user-defined fields for loss weights or loss calculation
    my_loss_2=dict(weight=2, norm_mode='L1'),
    my_loss_3=2,
    my_loss_4_norm_type='L2')
- Parameters
loss_config (Optional[Dict], optional) – Loss config used to build loss modules or define the loss weights. Defaults to None.
- noise_fn(noise: mmagic.utils.typing.NoiseVar = None, num_batches: int = 1)¶
Sampling function for noise. There are three scenarios in this function:
If noise is a callable function, sample num_batches of noise with passed noise.
If noise is None, sample num_batches of noise from gaussian distribution.
If noise is a torch.Tensor, directly return noise.
- Parameters
noise (Union[Tensor, Callable, List[int], None]) – You can directly give a batch of noise through a torch.Tensor or offer a callable function to sample a batch of noise data. Otherwise, None indicates to use the default noise sampler. Defaults to None.
num_batches (int, optional) – The number of batches of noise to sample. If noise is a Tensor, this will be ignored. Defaults to 1.
- Returns
Sampled noise tensor.
- Return type
Tensor
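Like label_fn above, the three branches can be sketched with plain Python lists. This is a hypothetical simplification; the real method returns a torch.Tensor sampled on the model's device:

```python
import random

def noise_fn(noise=None, num_batches=1, noise_size=4):
    """Dispatch over the three noise-sampling scenarios described above."""
    if callable(noise):
        # 1) callable sampler: delegate to it
        return noise(num_batches)
    if noise is None:
        # 2) default: standard Gaussian of shape (num_batches, noise_size)
        return [[random.gauss(0.0, 1.0) for _ in range(noise_size)]
                for _ in range(num_batches)]
    # 3) explicit noise: returned unchanged
    return noise
```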
- _init_ema_model(ema_config: dict)¶
Initialize an EMA model corresponding to the given ema_config. If ema_config is an empty dict or None, the EMA model will not be initialized.
- Parameters
ema_config (dict) – Config to initialize the EMA model.
- _get_valid_model(batch_inputs: mmagic.utils.typing.ForwardInputs) str ¶
Try to get the valid forward model from inputs.
If forward model is defined in batch_inputs, it will be used as forward model.
If forward model is not defined in batch_inputs, ‘ema’ will be returned if :property:`with_ema_gen` is true. Otherwise, ‘orig’ will be returned.
- Parameters
batch_inputs (ForwardInputs) – Inputs passed to forward().
- Returns
Forward model to generate images (‘orig’, ‘ema’ or ‘ema/orig’).
- Return type
str
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) mmagic.utils.typing.SampleList ¶
Sample images with the given inputs. If forward mode is ‘ema’ or ‘orig’, the image generated by corresponding generator will be returned. If forward mode is ‘ema/orig’, images generated by original generator and EMA generator will both be returned in a dict.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – mode is not used in BaseGAN. Defaults to None.
- Returns
A list of DataSample containing generated results.
- Return type
SampleList
- val_step(data: dict) mmagic.utils.typing.SampleList ¶
Gets the generated image of given data.
Calls self.data_preprocessor(data) and self(inputs, data_sample, mode=None) in order. Returns the generated results, which will be passed to the evaluator.
- Parameters
data (dict) – Data sampled from metric-specific sampler. More details in Metrics and Evaluator.
- Returns
Generated image or image dict.
- Return type
SampleList
- test_step(data: dict) mmagic.utils.typing.SampleList ¶
Gets the generated image of given data. Same as val_step().
- Parameters
data (dict) – Data sampled from metric-specific sampler. More details in Metrics and Evaluator.
- Returns
Generated image or image dict.
- Return type
List[DataSample]
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor] ¶
Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMagic’s design, self.train_step is called with data input, so we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().
- Parameters
data (dict) – Data sampled from dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of the generator and discriminator.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
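The alternating schedule can be sketched with a small helper. This is a hypothetical illustration of the role played by should_gen_update(), not its actual implementation:

```python
def should_gen_update(curr_iter: int, discriminator_steps: int) -> bool:
    """Generator updates once per `discriminator_steps` discriminator updates."""
    return (curr_iter + 1) % discriminator_steps == 0

# With discriminator_steps=2, the generator updates on every second iteration.
schedule = [should_gen_update(i, 2) for i in range(4)]
```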
- _get_gen_loss(out_dict)¶
- _get_disc_loss(out_dict)¶
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor] ¶
Training function for the generator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) Dict[str, torch.Tensor] ¶
Training function for discriminator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- class mmagic.models.base_models.BaseMattor(data_preprocessor: Union[dict, mmengine.config.Config], backbone: dict, init_cfg: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None)[source]¶
Bases:
mmengine.model.BaseModel
Base class for trimap-based matting models.
A matting model must contain a backbone which produces pred_alpha, a dense prediction with the same height and width of input image. In some cases (such as DIM), the model has a refiner which refines the prediction of the backbone.
Subclasses should overwrite the following functions:
_forward_train(), to return a loss
_forward_test(), to return a prediction
_forward(), to return raw tensors
For testing, this base class provides functions to resize inputs and post-process pred_alphas to get predictions.
- Parameters
backbone (dict) – Config of backbone.
data_preprocessor (dict) – Config of data_preprocessor. See MattorPreprocessor for details.
init_cfg (dict, optional) – Initialization config dict.
train_cfg (dict) – Config of training, customized by subclasses. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should be specified.
test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.
- resize_inputs(batch_inputs: torch.Tensor) torch.Tensor ¶
Pad or interpolate images and trimaps to multiple of given factor.
- restore_size(pred_alpha: torch.Tensor, data_sample: mmagic.structures.DataSample) torch.Tensor ¶
Restore the predicted alpha to the original shape.
The shape of the predicted alpha may not be the same as the shape of original input image. This function restores the shape of the predicted alpha.
- Parameters
pred_alpha (torch.Tensor) – A single predicted alpha of shape (1, H, W).
data_sample (DataSample) – Data sample containing the original shape as meta data.
- Returns
The reshaped predicted alpha.
- Return type
torch.Tensor
- postprocess(batch_pred_alpha: torch.Tensor, data_samples: mmagic.structures.DataSample) List[mmagic.structures.DataSample] ¶
Post-process alpha predictions.
- This function contains the following steps:
Restore padding or interpolation
Mask alpha prediction with trimap
Clamp alpha prediction to 0-1
Convert alpha prediction to uint8
Pack alpha prediction into DataSample
Currently only batch_size 1 is actually supported.
- Parameters
batch_pred_alpha (torch.Tensor) – A batch of predicted alpha of shape (N, 1, H, W).
data_samples (List[DataSample]) – List of data samples.
- Returns
A list of predictions. Each data sample contains a pred_alpha, which is a torch.Tensor with dtype=uint8.
- Return type
List[DataSample]
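The per-pixel part of these steps can be sketched in plain Python. This is a hypothetical scalar version assuming the common trimap convention (0 = background, 255 = foreground, other values = unknown); the real method operates on whole tensors and packs the results into DataSample objects:

```python
def postprocess_pixel(pred: float, trimap: int) -> int:
    """Mask with trimap, clamp to [0, 1], then convert to the uint8 range."""
    if trimap == 0:          # definite background: force alpha to 0
        pred = 0.0
    elif trimap == 255:      # definite foreground: force alpha to 1
        pred = 1.0
    pred = min(max(pred, 0.0), 1.0)   # clamp to [0, 1]
    return int(round(pred * 255))     # uint8 conversion

alphas = [postprocess_pixel(p, t)
          for p, t in [(1.3, 128), (-0.2, 128), (0.7, 0), (0.1, 255)]]
```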
- forward(inputs: torch.Tensor, data_samples: DataSamples = None, mode: str = 'tensor') List[mmagic.structures.DataSample] ¶
General forward function.
- Parameters
inputs (torch.Tensor) – A batch of inputs, with image and trimap concatenated along the channel dimension.
data_samples (List[DataSample], optional) – A list of data samples, containing ground-truth alpha / foreground / background to compute loss, and other meta information.
mode (str) – mode should be one of loss, predict and tensor. Default: ‘tensor’.
loss: called by train_step and returns a loss dict used for logging.
predict: called by val_step and test_step and returns a list of BaseDataElement results used for computing metrics.
tensor: called by custom use to get Tensor type results.
- Returns
Sequence of predictions packed into DataElement.
- Return type
List[DataElement]
- convert_to_datasample(predictions: List[mmagic.structures.DataSample], data_samples: mmagic.structures.DataSample) List[mmagic.structures.DataSample] ¶
Add predictions to data samples.
- Parameters
predictions (List[DataSample]) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
- Returns
Modified data samples.
- Return type
List[DataSample]
- class mmagic.models.base_models.BaseTranslationModel(generator, discriminator, default_domain: str, reachable_domains: List[str], related_domains: List[str], data_preprocessor, discriminator_steps: int = 1, disc_init_steps: int = 0, real_img_key: str = 'real_img', loss_config: Optional[dict] = None)[source]¶
Bases:
mmengine.model.BaseModel
Base Translation Model.
Translation models can transfer images from one domain to another. Domain information like default_domain, reachable_domains are needed to initialize the class. And we also provide query functions like is_domain_reachable, get_other_domains.
You can get a specific generator based on the domain, and by specifying target_domain in the forward function, you can decide the domain of the generated images. Considering the differences among image translation models, we only provide the external interfaces mentioned above. When you implement image translation with a specific method, you can inherit both BaseTranslationModel and the method (e.g. BaseGAN) and implement the abstract methods.
- Parameters
default_domain (str) – Default output domain.
reachable_domains (list[str]) – Domains that can be generated by the model.
related_domains (list[str]) – Domains involved in training and testing. reachable_domains must be contained in related_domains. However, related_domains may contain source domains that are used to retrieve source images from data_batch but not in reachable_domains.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
disc_init_steps (int) – The number of initial steps used only to train discriminators.
- init_weights()¶
Initialize weights for the module dict.
- Parameters
pretrained (str, optional) – Path for pretrained weights. If None, pretrained weights will not be loaded. Default: None.
- get_module(module)¶
Get nn.ModuleDict to fit the MMDistributedDataParallel interface.
- Parameters
module (MMDistributedDataParallel | nn.ModuleDict) – The input module that needs processing.
- Returns
The ModuleDict of multiple networks.
- Return type
nn.ModuleDict
- forward(img, test_mode=False, **kwargs)¶
Forward function.
- Parameters
img (tensor) – Input image tensor.
test_mode (bool) – Whether in test mode or not. Default: False.
kwargs (dict) – Other arguments.
- forward_train(img, target_domain, **kwargs)¶
Forward function for training.
- Parameters
img (tensor) – Input image tensor.
target_domain (str) – Target domain of output image.
kwargs (dict) – Other arguments.
- Returns
Forward results.
- Return type
dict
- forward_test(img, target_domain, **kwargs)¶
Forward function for testing.
- Parameters
img (tensor) – Input image tensor.
target_domain (str) – Target domain of output image.
kwargs (dict) – Other arguments.
- Returns
Forward results.
- Return type
dict
- is_domain_reachable(domain)¶
Whether image of this domain can be generated.
- get_other_domains(domain)¶
Get other domains.
- _get_target_generator(domain)¶
Get the target generator.
- _get_target_discriminator(domain)¶
Get the target discriminator.
- translation(image, target_domain=None, **kwargs)¶
Translate an image to the target style.
- Parameters
image (tensor) – Image tensor with a shape of (N, C, H, W).
target_domain (str, optional) – Target domain of output image. Defaults to None.
- Returns
Image tensor of target style.
- Return type
dict
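Selecting the per-domain generator can be sketched as follows. This is a hypothetical simplification using a plain dict of callables; the real model stores its generators in module containers keyed by domain:

```python
def translation(image, target_domain=None, *, generators, default_domain):
    """Pick the generator for the requested (or default) domain and apply it."""
    if target_domain is None:
        target_domain = default_domain
    return generators[target_domain](image)

# Toy generators: each "domain" just tags the input string.
gens = {'photo': lambda x: f'photo({x})', 'sketch': lambda x: f'sketch({x})'}
out = translation('img', target_domain='sketch',
                  generators=gens, default_domain='photo')
```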
- class mmagic.models.base_models.BasicInterpolator(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, required_frames: int = 2, step_frames: int = 1, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)[source]¶
Bases:
mmagic.models.base_models.base_edit_model.BaseEditModel
Basic model for video interpolation.
It must contain a generator that takes frames as inputs and outputs an interpolated frame. It also has a pixel-wise loss for training.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for pixel-wise loss.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
required_frames (int) – Required frames in each process. Default: 2.
step_frames (int) – Step size of video frame interpolation. Default: 1.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
- init_cfg¶
Initialization config dict.
- Type
dict, optional
- data_preprocessor¶
Used for pre-processing data sampled by the dataloader into the format accepted by forward().
- Type
BaseDataPreprocessor
- split_frames(input_tensors: torch.Tensor) torch.Tensor ¶
Split input tensors for inference.
- Parameters
input_tensors (Tensor) – Tensor of input frames with shape [1, t, c, h, w]
- Returns
Split tensor with shape [t-1, 2, c, h, w].
- Return type
Tensor
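The split above turns a clip of t frames into t-1 overlapping pairs, one pair per frame to be interpolated. A minimal sketch of that indexing, using plain lists in place of the [1, t, c, h, w] tensors:

```python
# Sketch of the split described above: a clip of t frames becomes t-1
# overlapping frame pairs, one pair per interpolated output frame.
# Plain lists stand in for tensors of shape [1, t, c, h, w] -> [t-1, 2, c, h, w].

def split_frames(frames):
    # Pair each frame with its successor: [f0, f1, f2, ...] -> [[f0, f1], [f1, f2], ...]
    return [[frames[i], frames[i + 1]] for i in range(len(frames) - 1)]


pairs = split_frames(['f0', 'f1', 'f2', 'f3'])
# pairs == [['f0', 'f1'], ['f1', 'f2'], ['f2', 'f3']]
```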
- static merge_frames(input_tensors: torch.Tensor, output_tensors: torch.Tensor) list ¶
Merge input frames and output frames.
Interpolate a frame between the given two frames.
- Merged from
[[in1, in2], [in2, in3], [in3, in4], …] [[out1], [out2], [out3], …]
- to
[in1, out1, in2, out2, in3, out3, in4, …]
- Parameters
input_tensors (Tensor) – The input frames with shape [n, 2, c, h, w]
output_tensors (Tensor) – The output frames with shape [n, 1, c, h, w].
- Returns
The final frames.
- Return type
list[np.ndarray]
- class mmagic.models.base_models.OneStageInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None)[source]¶
Bases:
mmengine.model.BaseModel
Standard one-stage inpaintor with commonly used losses.
An inpaintor must contain an encoder-decoder style generator to inpaint masked regions. A discriminator will be adopted when adversarial training is needed.
In this class, we provide a common interface for inpaintors. For other inpaintors, only some functions may need to be modified to fit the input style or training schedule.
- Parameters
data_preprocessor (dict) – Config of data_preprocessor.
encdec (dict) – Config for encoder-decoder style generator.
disc (dict) – Config for discriminator.
loss_gan (dict) – Config for adversarial loss.
loss_gp (dict) – Config for gradient penalty loss.
loss_disc_shift (dict) – Config for discriminator shift loss.
loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.
loss_out_percep (bool) – Whether to use perceptual and style loss with the direct output as input. Default: False.
loss_l1_hole (dict) – Config for l1 loss in the hole.
loss_l1_valid (dict) – Config for l1 loss in the valid region.
loss_tv (dict) – Config for total variation loss.
train_cfg (dict) – Configs for training scheduler. disc_step must be contained to indicate the number of discriminator update steps in each training step.
test_cfg (dict) – Configs for testing scheduler.
init_cfg (dict, optional) – Initialization config dict.
- forward(inputs: torch.Tensor, data_samples: Optional[mmagic.utils.SampleList], mode: str = 'tensor') FORWARD_RETURN_TYPE ¶
Forward function.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – Should be one of loss, predict and tensor. Default: 'tensor'.
loss: Called by train_step and returns a loss dict used for logging.
predict: Called by val_step and test_step and returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom use to get Tensor type results.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor, tuple of tensors or dict of tensors for custom use.
- Return type
ForwardResults
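The three-way mode dispatch above can be sketched as a plain function; the loss dict, prediction elements, and tensor passthrough below are toy stand-ins for mmagic's real computation:

```python
# Minimal sketch of the loss/predict/tensor mode dispatch described above.
# The bodies are stand-ins; only the dispatch shape mirrors the interface.

def forward(inputs, data_samples=None, mode='tensor'):
    if mode == 'loss':
        # train_step path: a dict of named loss values for backward/logging.
        return {'loss_l1_hole': 0.5, 'loss_gan': 0.1}
    if mode == 'predict':
        # val_step/test_step path: one prediction element per sample.
        return [{'pred': x} for x in inputs]
    if mode == 'tensor':
        # Raw-output path for custom use.
        return inputs
    raise ValueError(f'invalid mode: {mode}')
```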
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) dict ¶
Train step function.
In this function, the inpaintor will finish the train step following the pipeline:
get fake res/image
optimize discriminator (if have)
optimize generator
If self.train_cfg.disc_step > 1, each train step will run multiple iterations optimizing the discriminator with different input data, and only one iteration optimizing the generator after every disc_step discriminator iterations.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if have).
- Returns
Dict with loss, information for logger, the number of samples and results for visualization.
- Return type
dict
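The disc_step scheduling described above, with the discriminator updated every step and the generator once per disc_step steps, can be sketched with stub update functions (the class name and logging are illustrative, not mmagic's implementation):

```python
# Hedged sketch of disc_step scheduling: the discriminator is optimized on
# every train step, the generator only once every `disc_step` steps.

class Scheduler:
    def __init__(self, disc_step):
        self.disc_step = disc_step
        self.step_count = 0
        self.log = []

    def train_step(self, batch):
        self.log.append('disc')            # always optimize the discriminator
        self.step_count += 1
        if self.step_count % self.disc_step == 0:
            self.log.append('gen')         # generator once per disc_step steps


sched = Scheduler(disc_step=2)
for batch in range(4):
    sched.train_step(batch)
# sched.log == ['disc', 'disc', 'gen', 'disc', 'disc', 'gen']
```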
- abstract forward_train(*args, **kwargs) None ¶
Forward function for training.
In this version, we do not use this interface.
- forward_train_d(data_batch: torch.Tensor, is_real: bool, is_disc: bool) dict ¶
Forward function in discriminator training step.
In this function, we compute the prediction for each data batch (real or fake). Meanwhile, the standard gan loss will be computed with several proposed losses for stable training.
- Parameters
data_batch (torch.Tensor) – Batch of real data or fake data.
is_real (bool) – If True, the gan loss will regard this batch as real data. Otherwise, the gan loss will regard this batch as fake data.
is_disc (bool) – If True, this function is called in discriminator training step. Otherwise, this function is called in generator training step. This will help us to compute different types of adversarial loss, like LSGAN.
- Returns
Contains the loss items computed in this function.
- Return type
dict
- generator_loss(fake_res: torch.Tensor, fake_img: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, masked_img: torch.Tensor) Tuple[dict, dict] ¶
Forward function in generator training step.
In this function, we mainly compute the loss items for generator with the given (fake_res, fake_img). In general, the fake_res is the direct output of the generator and the fake_img is the composition of direct output and ground-truth image.
- Parameters
fake_res (torch.Tensor) – Direct output of the generator.
fake_img (torch.Tensor) – Composition of fake_res and ground-truth image.
gt (torch.Tensor) – Ground-truth image.
mask (torch.Tensor) – Mask image.
masked_img (torch.Tensor) – Composition of mask image and ground-truth image.
- Returns
Dict containing the results computed within this function for visualization and dict containing the loss items computed in this function.
- Return type
Tuple[dict, dict]
- forward_tensor(inputs: torch.Tensor, data_samples: mmagic.utils.SampleList) Tuple[torch.Tensor, torch.Tensor] ¶
Forward function in tensor mode.
- Parameters
inputs (torch.Tensor) – Input tensor.
data_samples (List[dict]) – List of data sample dict.
- Returns
Direct output of the generator and composition of fake_res and ground-truth image.
- Return type
tuple
- forward_test(inputs: torch.Tensor, data_samples: mmagic.utils.SampleList) mmagic.structures.DataSample ¶
Forward function for testing.
- Parameters
inputs (torch.Tensor) – Input tensor.
data_samples (List[dict]) – List of data sample dict.
- Returns
List of predictions saved in DataSample.
- Return type
List[DataSample]
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) List[mmagic.structures.DataSample] ¶
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
- forward_dummy(x: torch.Tensor) torch.Tensor ¶
Forward dummy function for getting flops.
- Parameters
x (torch.Tensor) – Input tensor with shape of (n, c, h, w).
- Returns
Results tensor with shape of (n, 3, h, w).
- Return type
torch.Tensor
- class mmagic.models.base_models.TwoStageInpaintor(data_preprocessor: Union[dict, mmengine.config.Config], encdec: dict, disc: Optional[dict] = None, loss_gan: Optional[dict] = None, loss_gp: Optional[dict] = None, loss_disc_shift: Optional[dict] = None, loss_composed_percep: Optional[dict] = None, loss_out_percep: bool = False, loss_l1_hole: Optional[dict] = None, loss_l1_valid: Optional[dict] = None, loss_tv: Optional[dict] = None, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, stage1_loss_type: Optional[Sequence[str]] = ('loss_l1_hole',), stage2_loss_type: Optional[Sequence[str]] = ('loss_l1_hole', 'loss_gan'), input_with_ones: bool = True, disc_input_with_mask: bool = False)[source]¶
Bases:
mmagic.models.base_models.one_stage.OneStageInpaintor
Standard two-stage inpaintor with commonly used losses. A two-stage inpaintor contains two encoder-decoder style generators to inpaint masked regions. Currently, we support the following loss types in each of the two stages:
['loss_gan', 'loss_l1_hole', 'loss_l1_valid', 'loss_composed_percep', 'loss_out_percep', 'loss_tv']. The stage1_loss_type and stage2_loss_type should be chosen from these loss types.
- Parameters
data_preprocessor (dict) – Config of data_preprocessor.
encdec (dict) – Config for encoder-decoder style generator.
disc (dict) – Config for discriminator.
loss_gan (dict) – Config for adversarial loss.
loss_gp (dict) – Config for gradient penalty loss.
loss_disc_shift (dict) – Config for discriminator shift loss.
loss_composed_percep (dict) – Config for perceptual and style loss with composed image as input.
loss_out_percep (bool) – Whether to use perceptual and style loss with the direct output as input. Default: False.
loss_l1_hole (dict) – Config for l1 loss in the hole.
loss_l1_valid (dict) – Config for l1 loss in the valid region.
loss_tv (dict) – Config for total variation loss.
train_cfg (dict) – Configs for training scheduler. disc_step must be contained to indicate the number of discriminator update steps in each training step.
test_cfg (dict) – Configs for testing scheduler.
init_cfg (dict, optional) – Initialization config dict.
stage1_loss_type (tuple[str]) – Contains the loss names used in the first stage model. Default: ('loss_l1_hole',).
stage2_loss_type (tuple[str]) – Contains the loss names used in the second stage model. Default: (‘loss_l1_hole’, ‘loss_gan’).
input_with_ones (bool) – Whether to concatenate an extra ones tensor in input. Default: True.
disc_input_with_mask (bool) – Whether to add mask as input in discriminator. Default: False.
- forward_tensor(inputs: torch.Tensor, data_samples: mmagic.utils.SampleList) Tuple[torch.Tensor, torch.Tensor] ¶
Forward function in tensor mode.
- Parameters
inputs (torch.Tensor) – Input tensor.
data_samples (List[dict]) – List of data sample dict.
- Returns
Dict containing output results.
- Return type
dict
- two_stage_loss(stage1_data: dict, stage2_data: dict, gt: torch.Tensor, mask: torch.Tensor, masked_img: torch.Tensor) Tuple[dict, dict] ¶
Calculate two-stage loss.
- Parameters
stage1_data (dict) – Contains stage1 results.
stage2_data (dict) – Contains stage2 results.
gt (torch.Tensor) – Ground-truth image.
mask (torch.Tensor) – Mask image.
masked_img (torch.Tensor) – Composition of mask image and ground-truth image.
- Returns
Dict containing the results computed within this function for visualization and dict containing the loss items computed in this function.
- Return type
Tuple[dict, dict]
- calculate_loss_with_type(loss_type: str, fake_res: torch.Tensor, fake_img: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, prefix: Optional[str] = 'stage1_') dict ¶
Calculate multiple types of losses.
- Parameters
loss_type (str) – Type of the loss.
fake_res (torch.Tensor) – Direct results from model.
fake_img (torch.Tensor) – Composited results from model.
gt (torch.Tensor) – Ground-truth tensor.
mask (torch.Tensor) – Mask tensor.
prefix (str, optional) – Prefix for loss name. Defaults to 'stage1_'.
- Returns
Contains the loss value with its name.
- Return type
dict
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) dict ¶
Train step function.
In this function, the inpaintor will finish the train step following the pipeline:
get fake res/image
optimize discriminator (if have)
optimize generator
If self.train_cfg.disc_step > 1, each train step will run multiple iterations optimizing the discriminator with different input data, and only one iteration optimizing the generator after every disc_step discriminator iterations.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (dict[torch.optim.Optimizer]) – Dict with optimizers for generator and discriminator (if have).
- Returns
Dict with loss, information for logger, the number of samples and results for visualization.
- Return type
dict