mmagic.models.base_models.base_edit_model

Module Contents

Classes

BaseEditModel

Base model for image and video editing.

class mmagic.models.base_models.base_edit_model.BaseEditModel(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)[source]

Bases: mmengine.model.BaseModel

Base model for image and video editing.

It must contain a generator that takes images or frames as inputs and outputs the edited result. It also has a pixel-wise loss for training.

Parameters
  • generator (dict) – Config for the generator structure.

  • pixel_loss (dict) – Config for pixel-wise loss.

  • train_cfg (dict, optional) – Config for training. Default: None.

  • test_cfg (dict, optional) – Config for testing. Default: None.

  • init_cfg (dict, optional) – Weight initialization config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – Config for the BaseDataPreprocessor. Default: None.

init_cfg

Initialization config dict.

Type

dict, optional

data_preprocessor

Used for pre-processing data sampled by dataloader to the format accepted by forward(). Default: None.

Type

BaseDataPreprocessor
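
As a rough illustration of the constructor arguments above, a model config might look like the following sketch. The generator type (SRCNNNet), the loss type (L1Loss), the preprocessor type (DataPreprocessor), and all numeric values are assumptions chosen for illustration; substitute whatever components your project registers.

```python
# Hypothetical config fragment: every component type and value below is an
# assumption for illustration, not something mandated by BaseEditModel.
model = dict(
    type='BaseEditModel',
    # Config for the generator structure (assumed super-resolution network).
    generator=dict(
        type='SRCNNNet',
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=4),
    # Config for the pixel-wise loss (assumed L1 loss).
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    # Optional arguments; each of them defaults to None when omitted.
    train_cfg=dict(),
    test_cfg=dict(),
    data_preprocessor=dict(
        type='DataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.]))
```

Only generator and pixel_loss are required; the remaining keys can be left out.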

forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmagic.structures.DataSample], dict][source]

Returns losses or predictions of training, validation, testing, and simple inference process.

The forward method of BaseModel is abstract, so subclasses must implement it.

Accepts inputs and data_samples processed by data_preprocessor, and returns results according to mode arguments.

During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step.

During distributed data parallel training process, MMSeparateDistributedDataParallel.train_step will first call DistributedDataParallel.forward to enable automatic gradient synchronization, and then call forward to get training loss.
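
The relationship between the step methods and the mode argument can be sketched with a toy class. ToyStepModel is a hypothetical stand-in, not the mmengine implementation; the real step methods also run the data preprocessor, and train_step additionally performs the optimizer update.

```python
from typing import Any, Dict, List


class ToyStepModel:
    """Hypothetical stand-in that records the mode each step passes on."""

    def __init__(self) -> None:
        self.modes: List[str] = []

    def forward(self, inputs: Any, data_samples: Any = None,
                mode: str = 'tensor') -> str:
        # Record the requested mode instead of running a network.
        self.modes.append(mode)
        return mode

    def train_step(self, data: Dict[str, Any]) -> str:
        # Training asks forward for a loss dict.
        return self.forward(**data, mode='loss')

    def val_step(self, data: Dict[str, Any]) -> str:
        # Validation asks forward for predictions.
        return self.forward(**data, mode='predict')

    def test_step(self, data: Dict[str, Any]) -> str:
        # Testing asks forward for predictions as well.
        return self.forward(**data, mode='predict')


model = ToyStepModel()
batch = {'inputs': [0.5], 'data_samples': None}
model.train_step(batch)
model.val_step(batch)
model.test_step(batch)
```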

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • mode (str) –

    mode should be one of loss, predict, and tensor. Default: ‘tensor’.

    • loss: Called by train_step; returns a loss dict used for backward and logging.

    • predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.

    • tensor: Called for custom use to get Tensor-type results.

Returns

  • If mode == loss, return a dict of loss tensor used for backward and logging.

  • If mode == predict, return a list of BaseDataElement for computing metric and getting inference result.

  • If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

ForwardResults
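
The mode dispatch described above can be sketched in plain Python. ToyEditModel is hypothetical: lists of floats stand in for tensors, dicts stand in for data samples, the "generator" simply doubles its input, and the 'gt' and 'output' keys are illustrative names, so this shows the control flow only and is not the mmagic implementation.

```python
from typing import Any, Dict, List, Optional


class ToyEditModel:
    """Hypothetical stand-in showing how forward dispatches on mode."""

    def forward(self, inputs: List[float],
                data_samples: Optional[List[Dict[str, Any]]] = None,
                mode: str = 'tensor', **kwargs):
        if mode == 'tensor':
            return self.forward_tensor(inputs, data_samples, **kwargs)
        if mode == 'predict':
            return self.forward_inference(inputs, data_samples, **kwargs)
        if mode == 'loss':
            return self.forward_train(inputs, data_samples, **kwargs)
        raise ValueError(f'Invalid mode "{mode}".')

    def forward_tensor(self, inputs, data_samples=None, **kwargs):
        # Toy "generator": doubles every value.
        return [2.0 * x for x in inputs]

    def forward_inference(self, inputs, data_samples=None, **kwargs):
        # Wrap each raw output in a dict that plays the role of a data sample.
        outputs = self.forward_tensor(inputs, data_samples, **kwargs)
        return [{'output': y} for y in outputs]

    def forward_train(self, inputs, data_samples=None, **kwargs):
        # Mean absolute error against the ground truth, as a toy pixel loss.
        outputs = self.forward_tensor(inputs, data_samples, **kwargs)
        targets = [sample['gt'] for sample in data_samples]
        loss = sum(abs(y - t) for y, t in zip(outputs, targets)) / len(outputs)
        return {'loss': loss}


model = ToyEditModel()
samples = [{'gt': 2.0}, {'gt': 5.0}]
tensor_out = model.forward([1.0, 2.0], mode='tensor')
predict_out = model.forward([1.0, 2.0], samples, mode='predict')
loss_out = model.forward([1.0, 2.0], samples, mode='loss')
```

Here tensor mode returns the raw generator output, predict mode returns one result per sample, and loss mode returns a dict whose values would be used for backward and logging.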

convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) → List[mmagic.structures.DataSample][source]

Add predictions and destructed inputs (if passed) to data samples.

Parameters
  • predictions (DataSample) – The predictions of the model.

  • data_samples (DataSample) – The data samples loaded from dataloader.

  • inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.

Returns

Modified data samples.

Return type

List[DataSample]
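
The merging step can be illustrated with plain dicts in place of DataSample objects. The function name convert_to_datasample_sketch, the 'output' and 'input' keys, and the list-based arguments are all assumptions made for this sketch; the real method works on DataSample objects and destructs the inputs through the data preprocessor, which is skipped here.

```python
from typing import Any, Dict, List, Optional


def convert_to_datasample_sketch(
        predictions: List[Any],
        data_samples: List[Dict[str, Any]],
        inputs: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
    """Attach each prediction (and input, if passed) to its data sample."""
    merged = []
    for i, (sample, pred) in enumerate(zip(data_samples, predictions)):
        sample = dict(sample)  # copy so the caller's dict is left untouched
        sample['output'] = pred
        if inputs is not None:
            sample['input'] = inputs[i]
        merged.append(sample)
    return merged


samples = [{'gt': 'a.png'}, {'gt': 'b.png'}]
result = convert_to_datasample_sketch(
    ['pred_a', 'pred_b'], samples, ['in_a', 'in_b'])
without_inputs = convert_to_datasample_sketch(['pred_a', 'pred_b'], samples)
```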

forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → torch.Tensor[source]

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Result of simple forward.

Return type

Tensor

forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → mmagic.structures.DataSample[source]

Forward inference. Returns predictions of validation, testing, and simple inference.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Predictions.

Return type

DataSample

forward_train(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → Dict[str, torch.Tensor][source]

Forward training. Returns dict of losses of training.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Dict of losses.

Return type

dict
