mmagic.models.base_models.base_edit_model
Module Contents
Classes
BaseEditModel – Base model for image and video editing.
- class mmagic.models.base_models.base_edit_model.BaseEditModel(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases: mmengine.model.BaseModel
Base model for image and video editing.
It must contain a generator that takes images (or video frames) as inputs and outputs the edited result. It also has a pixel-wise loss for training.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for the pixel-wise loss.
train_cfg (dict, optional) – Config for training. Default: None.
test_cfg (dict, optional) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
data_preprocessor (dict, optional) – The pre-processing config of BaseDataPreprocessor. Default: None.
- init_cfg
Initialization config dict.
- Type
dict, optional
- data_preprocessor
Used for pre-processing data sampled by the dataloader into the format accepted by forward(). Default: None.
- Type
BaseDataPreprocessor
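As an illustration of the constructor arguments, a config for a 4x super-resolution model might look like the following. The registered type names (`MSRResNet`, `L1Loss`, `DataPreprocessor`) and the generator arguments are assumptions modeled on common mmagic configs; check them against the registry of your installed version.

```python
# Hypothetical config for a 4x super-resolution model built on BaseEditModel.
# Type names and generator arguments are assumptions; verify them against
# the registry of your installed mmagic version.
model = dict(
    type='BaseEditModel',
    generator=dict(
        type='MSRResNet',
        in_channels=3,
        out_channels=3,
        mid_channels=64,
        num_blocks=16,
        upscale_factor=4),
    # Pixel-wise loss comparing generator output with the ground truth.
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    train_cfg=dict(),  # optional; the constructor default is None
    test_cfg=dict(),   # optional; the constructor default is None
    # Normalizes 0-255 pixel values into [0, 1] before they reach forward().
    data_preprocessor=dict(
        type='DataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.]))
```

With mmagic installed, such a dict is typically built through the model registry, e.g. `MODELS.build(model)` with `MODELS` imported from `mmagic.registry`.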
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmagic.structures.DataSample], dict]
Returns losses or predictions of the training, validation, testing, and simple inference processes.
The forward method of BaseModel is an abstract method; its subclasses must implement it. BaseEditModel implements it by dispatching to forward_train, forward_inference or forward_tensor according to mode.
Accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step.
During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – One of loss, predict and tensor. Default: 'tensor'.
loss: called by train_step; returns a dict of losses used for logging.
predict: called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.
tensor: called by custom code to get Tensor type results.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
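The routing by mode described above can be sketched as follows. `ToyEditModel` is a hypothetical stand-in, not mmagic's implementation: plain lists of floats replace tensors, dicts replace DataSample objects, and a mean absolute error replaces the configured pixel loss.

```python
from typing import Any, Dict, List, Optional


class ToyEditModel:
    """Hypothetical stand-in mimicking how ``forward`` dispatches on ``mode``.

    Inputs are lists of floats (one float per "sample") and data samples are
    dicts, so the sketch runs without PyTorch or mmagic.
    """

    def __init__(self, scale: float = 2.0) -> None:
        self.scale = scale  # toy "generator": multiplies each input by scale

    def forward_tensor(self, inputs: List[float],
                       data_samples: Optional[List[Dict]] = None) -> List[float]:
        # Raw generator output, no post-processing.
        return [self.scale * x for x in inputs]

    def forward_inference(self, inputs: List[float],
                          data_samples: Optional[List[Dict]] = None) -> List[float]:
        return self.forward_tensor(inputs, data_samples)

    def forward_train(self, inputs: List[float],
                      data_samples: List[Dict]) -> Dict[str, float]:
        preds = self.forward_tensor(inputs, data_samples)
        gts = [sample['gt'] for sample in data_samples]
        # Mean absolute error stands in for the configured pixel loss.
        loss = sum(abs(p - g) for p, g in zip(preds, gts)) / len(preds)
        return {'loss': loss}

    def forward(self, inputs: List[float],
                data_samples: Optional[List[Dict]] = None,
                mode: str = 'tensor') -> Any:
        if mode == 'tensor':
            return self.forward_tensor(inputs, data_samples)
        if mode == 'predict':
            preds = self.forward_inference(inputs, data_samples)
            # Attach each prediction to its sample, the role played by
            # convert_to_datasample in the documented API.
            return [dict(sample, output=pred)
                    for sample, pred in zip(data_samples, preds)]
        if mode == 'loss':
            return self.forward_train(inputs, data_samples)
        raise ValueError(f'Invalid mode {mode!r}; expected loss, predict or tensor')
```

For example, `ToyEditModel().forward([1.0, 2.0], [{'gt': 2.0}, {'gt': 3.0}], mode='loss')` returns `{'loss': 0.5}`, while the default `mode='tensor'` returns the raw outputs `[2.0, 4.0]`.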
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) → List[mmagic.structures.DataSample]
Adds the predictions and, if inputs are passed, the destructed inputs (inputs with the data preprocessor's transforms reverted) to the data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
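The bookkeeping described above can be illustrated in plain Python. In this hypothetical sketch, batched predictions and data samples are parallel lists, each prediction is stored under an assumed `'output'` key, and destructing the inputs is modeled only as reverting a normalization `x_norm = (x - mean) / std`; the real destruct step belongs to the data preprocessor and may do more.

```python
from typing import Dict, List, Optional, Sequence


def convert_to_datasample_sketch(
        predictions: Sequence[List[float]],
        data_samples: Sequence[Dict],
        inputs: Optional[Sequence[List[float]]] = None,
        mean: float = 0.0,
        std: float = 255.0) -> List[Dict]:
    """Hypothetical analogue of convert_to_datasample.

    Each batch entry gets its prediction under 'output' and, if ``inputs`` is
    given, an 'input' entry with the normalization reverted (x * std + mean),
    standing in for the data preprocessor's destruct step.
    """
    results = []
    for i, (sample, pred) in enumerate(zip(data_samples, predictions)):
        sample = dict(sample)  # copy so the caller's samples stay untouched
        if inputs is not None:
            sample['input'] = [x * std + mean for x in inputs[i]]
        sample['output'] = pred
        results.append(sample)
    return results
```

Calling `convert_to_datasample_sketch([[0.5]], [{'gt': [0.4]}], inputs=[[0.5, 1.0]])` yields `[{'gt': [0.4], 'input': [127.5, 255.0], 'output': [0.5]}]`.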
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → torch.Tensor
Forward tensor. Returns the result of a simple forward pass.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
result of simple forward.
- Return type
Tensor
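A simple forward pass is read here as running the generator and returning its raw output, with no loss computation, de-normalization, or DataSample packing. The hypothetical sketch below makes that concrete with a toy nearest-neighbour upsampler on nested lists in place of a real generator network and tensors; neither function is part of mmagic.

```python
from typing import Callable, List

Image = List[List[float]]  # a 2D grayscale image as nested lists


def nearest_upsample(img: Image, factor: int = 2) -> Image:
    """Toy "generator": repeats every pixel factor times in both directions."""
    return [[px for px in row for _ in range(factor)]
            for row in img for _ in range(factor)]


def forward_tensor_sketch(generator: Callable[[Image], Image],
                          inputs: List[Image]) -> List[Image]:
    """Hypothetical analogue of forward_tensor: apply the generator to each
    batch element and return the raw outputs unchanged."""
    return [generator(img) for img in inputs]
```

For a batch holding the single 1x2 image `[[1.0, 2.0]]`, the sketch returns one 2x4 image whose two rows are both `[1.0, 1.0, 2.0, 2.0]`.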
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → mmagic.structures.DataSample
Forward inference. Returns predictions of validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
predictions.
- Return type
DataSample
- forward_train(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → Dict[str, torch.Tensor]
Forward training. Returns dict of losses of training.
- Parameters
inputs (torch.Tensor) – batch input tensor collated by
data_preprocessor.data_samples (List[BaseDataElement], optional) – data samples collated by
data_preprocessor.
- Returns
Dict of losses.
- Return type
dict
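The training path can be sketched the same way: run the generator, compare its output with the ground truth through a pixel-wise loss, and return the losses as a dict. The L1 loss, the `loss_weight` scaling, and the `'loss'` key below are illustrative assumptions modeled on a `pixel_loss=dict(type='L1Loss', loss_weight=...)` style config, not mmagic's code.

```python
from typing import Callable, Dict, List

Image = List[List[float]]  # a 2D grayscale image as nested lists


def l1_pixel_loss(pred: List[Image], gt: List[Image],
                  loss_weight: float = 1.0) -> float:
    """Mean absolute error over every pixel in the batch, scaled by loss_weight."""
    diffs = [abs(p - g)
             for p_img, g_img in zip(pred, gt)
             for p_row, g_row in zip(p_img, g_img)
             for p, g in zip(p_row, g_row)]
    return loss_weight * sum(diffs) / len(diffs)


def forward_train_sketch(generator: Callable[[Image], Image],
                         inputs: List[Image], gt_imgs: List[Image],
                         loss_weight: float = 1.0) -> Dict[str, float]:
    """Hypothetical analogue of forward_train: generator output is compared
    with the ground truth, and the result is returned as a dict of losses."""
    preds = [generator(img) for img in inputs]
    return {'loss': l1_pixel_loss(preds, gt_imgs, loss_weight)}
```

With an identity generator, an input `[[1.0, 2.0], [3.0, 4.0]]` and an all-ones ground truth, the per-pixel errors are 0, 1, 2 and 3, so the sketch returns `{'loss': 1.5}`.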