mmagic.models.editors.deblurganv2
Package Contents
Classes
Base class for all algorithmic models.
Defines the discriminator for DeblurGanv2 with the specified arguments.
Defines the generator for DeblurGanv2 with the specified arguments.
- class mmagic.models.editors.deblurganv2.DeblurGanV2(generator: ModelType, discriminator: Optional[ModelType] = None, pixel_loss: Optional[Union[dict, str]] = None, disc_loss: Optional[Union[dict, str]] = None, adv_lambda: float = 0.001, warmup_num: int = 3, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases: mmengine.model.BaseModel
Base class for all algorithmic models.
BaseModel implements the basic functions of an algorithmic model, such as weight initialization, batch input preprocessing (see BaseDataPreprocessor for more information), loss parsing, and model parameter updates. Subclasses of BaseModel only need to implement the forward method, which contains the logic for computing losses and predictions; the model can then be trained in the runner.
Examples
>>> import torch
>>> import torch.nn as nn
>>> from mmengine.model import BaseModel
>>> from mmengine.registry import MODELS
>>> @MODELS.register_module()
... class ToyModel(BaseModel):
...
...     def __init__(self):
...         super().__init__()
...         self.backbone = nn.Sequential()
...         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
...         self.backbone.add_module('pool', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
...         # Pool again and flatten so 32x32 inputs reach fc1 as 16 * 5 * 5 features.
...         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('flatten', nn.Flatten())
...         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
...         self.backbone.add_module('fc2', nn.Linear(120, 84))
...         self.backbone.add_module('fc3', nn.Linear(84, 10))
...
...         self.criterion = nn.CrossEntropyLoss()
...
...     def forward(self, batch_inputs, data_samples, mode='tensor'):
...         data_samples = torch.stack(data_samples)
...         if mode == 'tensor':
...             return self.backbone(batch_inputs)
...         elif mode == 'predict':
...             feats = self.backbone(batch_inputs)
...             predictions = torch.argmax(feats, 1)
...             return predictions
...         elif mode == 'loss':
...             feats = self.backbone(batch_inputs)
...             loss = self.criterion(feats, data_samples)
...             return dict(loss=loss)
- Parameters
data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
- data_preprocessor
Used to preprocess data sampled by the dataloader into the format accepted by forward().
- Type
BaseDataPreprocessor
- init_cfg
Initialization config dict.
- Type
dict, optional
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmagic.structures.DataSample], dict]
Returns losses or predictions for the training, validation, testing, and simple inference processes.
The forward method of BaseModel is abstract; its subclasses must implement it. It accepts inputs and data_samples processed by data_preprocessor and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step.
During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – Must be one of loss, predict, or tensor. Default: 'tensor'.
loss: Called by train_step; returns a loss dict used for logging.
predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom code to get results of Tensor type.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == val, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
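The mode-based dispatch described above can be sketched in plain Python. This is a minimal illustration, not the DeblurGanV2 source: the class ModeDispatchSketch and its placeholder handler bodies are hypothetical, and only the handler names forward_train, forward_inference, and forward_tensor are taken from this page. It accepts only the three modes listed under the mode parameter.

```python
from typing import Any, Callable, Dict, List, Optional


class ModeDispatchSketch:
    """Hypothetical stand-in for a model whose forward() routes by mode.

    The handler bodies are placeholders that operate on lists of floats.
    """

    def forward_train(self, inputs: List[float],
                      data_samples: Optional[list] = None) -> Dict[str, float]:
        # Placeholder "loss": the mean of the inputs.
        return {"loss": sum(inputs) / len(inputs)}

    def forward_inference(self, inputs: List[float],
                          data_samples: Optional[list] = None) -> List[dict]:
        # Placeholder predictions: one dict per batch element.
        return [{"pred": x} for x in inputs]

    def forward_tensor(self, inputs: List[float],
                       data_samples: Optional[list] = None) -> List[float]:
        # Placeholder raw output: each input doubled.
        return [2 * x for x in inputs]

    def forward(self, inputs: List[float], data_samples: Optional[list] = None,
                mode: str = "tensor") -> Any:
        handlers: Dict[str, Callable[..., Any]] = {
            "loss": self.forward_train,
            "predict": self.forward_inference,
            "tensor": self.forward_tensor,
        }
        if mode not in handlers:
            raise ValueError(f"Invalid forward mode: {mode!r}")
        return handlers[mode](inputs, data_samples)
```

A dictionary of handlers keeps each mode's logic in its own method, which mirrors how this page splits forward into separate forward_* helpers.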
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) → List[mmagic.structures.DataSample]
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
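A rough sketch of what "add predictions and inputs to data samples" can look like, using plain dicts instead of DataSample. ToySample, convert_to_datasample_sketch, and the field names "output" and "input" are hypothetical; the real method also undoes preprocessing on the inputs (the "destructed" part), which is omitted here.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ToySample:
    """Hypothetical stand-in for mmagic.structures.DataSample."""
    fields: Dict[str, Any] = field(default_factory=dict)


def convert_to_datasample_sketch(predictions: List[Any],
                                 data_samples: List[ToySample],
                                 inputs: Optional[List[Any]] = None) -> List[ToySample]:
    """Attach per-sample predictions (and optionally inputs) in place.

    The keys 'output' and 'input' are assumptions for illustration.
    """
    if len(predictions) != len(data_samples):
        raise ValueError("predictions and data_samples must have equal length")
    if inputs is not None and len(inputs) != len(data_samples):
        raise ValueError("inputs and data_samples must have equal length")
    for idx, (pred, sample) in enumerate(zip(predictions, data_samples)):
        sample.fields["output"] = pred
        if inputs is not None:
            sample.fields["input"] = inputs[idx]
    return data_samples
```

The samples are modified in place and returned, matching the "Modified data samples" return description.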
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → torch.Tensor
Forward tensor. Returns the result of a simple forward pass.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of the simple forward pass.
- Return type
Tensor
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → List[mmagic.structures.DataSample]
Forward inference. Returns predictions for validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
List[DataSample]
- forward_train(inputs, data_samples=None, **kwargs)
Forward training. Training losses are calculated in train_step.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of forward_tensor with training=True.
- Return type
Tensor
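The three forward_* helpers documented above relate roughly as follows: forward_tensor runs the generator, forward_train is forward_tensor with training=True, and forward_inference wraps the outputs as per-sample predictions. The class ForwardPathsSketch, the list-of-floats "batch", and the "output" key are hypothetical simplifications; a callable stands in for the generator network.

```python
from typing import Callable, List, Optional


class ForwardPathsSketch:
    """Hypothetical sketch of how forward_tensor/train/inference relate."""

    def __init__(self, generator: Callable[[List[float]], List[float]]):
        self.generator = generator
        self.training_calls = 0  # counts passes made with training=True

    def forward_tensor(self, inputs: List[float],
                       data_samples: Optional[list] = None,
                       training: bool = False) -> List[float]:
        # Simple forward: run the generator on the whole batch.
        if training:
            self.training_calls += 1
        return self.generator(inputs)

    def forward_train(self, inputs: List[float],
                      data_samples: Optional[list] = None) -> List[float]:
        # Per the docstring: the result of forward_tensor with training=True.
        return self.forward_tensor(inputs, data_samples, training=True)

    def forward_inference(self, inputs: List[float],
                          data_samples: Optional[list] = None) -> List[dict]:
        # Run the generator, then wrap each element as a per-sample prediction.
        outputs = self.forward_tensor(inputs, data_samples, training=False)
        return [{"output": out} for out in outputs]
```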
- val_step(data: Union[tuple, dict, list]) → list
Gets the predictions of the given data.
Calls self.data_preprocessor(data, False) and self(inputs, data_sample, mode='predict') in order, and returns the predictions, which are passed to the evaluator.
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
- Returns
The predictions of the given data.
- Return type
list
- test_step(data: Union[dict, tuple, list]) → list
BaseModel implements test_step the same way as val_step.
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
- Returns
The predictions of the given data.
- Return type
list
- _run_forward(data: Union[dict, tuple, list], mode: str) → Union[Dict[str, torch.Tensor], list]
Unpacks data for forward().
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
mode (str) – Mode of forward.
- Returns
Results of training or testing mode.
- Return type
dict or list
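The flow of val_step, test_step, and _run_forward can be sketched without mmengine: preprocess, then unpack a dict as keyword arguments or a list/tuple as positional arguments, and call the model in predict mode. StepSketch is hypothetical; its data_preprocessor and forward are trivial placeholders so the control flow can run on its own.

```python
from typing import Any, List, Union

Data = Union[dict, tuple, list]


class StepSketch:
    """Plain-Python sketch of val_step / test_step / _run_forward."""

    def data_preprocessor(self, data: Data, training: bool = False) -> Data:
        # Placeholder: pass data through unchanged.
        return data

    def forward(self, inputs: Any, data_samples: Any = None,
                mode: str = "tensor") -> List[dict]:
        # Placeholder: echo what forward received, as a one-element list.
        return [{"mode": mode, "inputs": inputs, "data_samples": data_samples}]

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def _run_forward(self, data: Data, mode: str) -> Any:
        # dict -> keyword arguments; list/tuple -> positional arguments.
        if isinstance(data, dict):
            return self(**data, mode=mode)
        if isinstance(data, (list, tuple)):
            return self(*data, mode=mode)
        raise TypeError(f"Expected dict, tuple or list, got {type(data)}")

    def val_step(self, data: Data) -> List[dict]:
        data = self.data_preprocessor(data, False)
        return self._run_forward(data, mode="predict")

    def test_step(self, data: Data) -> List[dict]:
        # Implemented the same way as val_step.
        return self.val_step(data)
```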
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train step of a GAN-based method.
- Parameters
data (List[dict]) – Data sampled from the dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
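One alternating GAN iteration can be outlined as: run the generator once, update the discriminator, then update the generator, merging the logged losses. This is a generic sketch, not DeblurGanV2's train_step: the discriminator-before-generator order is an assumption, the warmup_num behavior is not modeled because this page does not describe it, and train_step_sketch with its callable arguments is hypothetical.

```python
from typing import Callable, Dict, List

StepFn = Callable[[List[float], List[float]], Dict[str, float]]


def train_step_sketch(batch_inputs: List[float],
                      batch_gt: List[float],
                      generator: Callable[[List[float]], List[float]],
                      d_step_with_optim: StepFn,
                      g_step_with_optim: StepFn) -> Dict[str, float]:
    """Hypothetical outline of one alternating GAN training iteration."""
    # 1. Restore the blurred batch with the generator.
    outputs = generator(batch_inputs)
    log_vars: Dict[str, float] = {}
    # 2. Update the discriminator (a PyTorch version would typically
    #    pass detached generator outputs here).
    log_vars.update(d_step_with_optim(outputs, batch_gt))
    # 3. Update the generator using the same outputs.
    log_vars.update(g_step_with_optim(outputs, batch_gt))
    return log_vars
```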
- g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
G step of DoubleGAN: calculates the generator losses.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
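A generator loss of this kind is commonly a pixel term plus a weighted adversarial term. The sketch below assumes L1 for pixel_loss, a least-squares GAN term for the adversarial part, and that adv_lambda scales only the adversarial term; the parameter name suggests the last point, but this page does not confirm any of these choices, and the function names are hypothetical.

```python
from typing import Dict, List


def l1_loss(pred: List[float], target: List[float]) -> float:
    # Mean absolute error (stands in for the configured pixel_loss).
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def lsgan_generator_loss(fake_scores: List[float]) -> float:
    # Least-squares GAN generator term: push D(fake) towards 1.
    return sum((s - 1.0) ** 2 for s in fake_scores) / len(fake_scores)


def g_step_sketch(outputs: List[float], gt: List[float],
                  fake_scores: List[float],
                  adv_lambda: float = 0.001) -> Dict[str, float]:
    """Hypothetical generator losses: pixel + adv_lambda * adversarial."""
    loss_pixel = l1_loss(outputs, gt)
    loss_adv = adv_lambda * lsgan_generator_loss(fake_scores)
    return {"loss_pixel": loss_pixel, "loss_adv": loss_adv,
            "loss_g": loss_pixel + loss_adv}
```

The small default adv_lambda (0.001) keeps the adversarial term from dominating the pixel reconstruction term.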
- d_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
D step of DoubleGAN: calculates the discriminator losses.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
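A discriminator step typically pushes scores on real (GT) images towards 1 and on generated images towards 0. The sketch assumes a least-squares GAN objective and, for the "double" discriminator, averages the loss over several heads (for example a patch-level and a full-image head). Both the objective and the averaging are assumptions, and the function names are hypothetical.

```python
from typing import Dict, List


def lsgan_discriminator_loss(real_scores: List[float],
                             fake_scores: List[float]) -> float:
    # Push D(real) towards 1 and D(fake) towards 0; average the two terms.
    real_term = sum((s - 1.0) ** 2 for s in real_scores) / len(real_scores)
    fake_term = sum(s ** 2 for s in fake_scores) / len(fake_scores)
    return 0.5 * (real_term + fake_term)


def d_step_sketch(real_scores_per_head: List[List[float]],
                  fake_scores_per_head: List[List[float]]) -> Dict[str, float]:
    """Hypothetical double discriminator: average the loss over heads."""
    losses = [lsgan_discriminator_loss(r, f)
              for r, f in zip(real_scores_per_head, fake_scores_per_head)]
    return {"loss_d": sum(losses) / len(losses)}
```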
- g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
G step of GAN with optimization: calculates the generator losses and runs the optimizer.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
- d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
D step of GAN with optimization: calculates the discriminator losses and runs the optimizer.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
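Both g_step_with_optim and d_step_with_optim follow the same pattern: compute a loss dict, reduce it to one scalar, and update only the matching sub-network through its entry in the optim wrapper dict. In this sketch ToyOptimWrapper is a hypothetical stand-in whose update_params only records the loss (mmengine's OptimWrapper.update_params runs backward, steps, and zeroes gradients), summing every loss entry is a simplification of loss parsing, and the keys "generator" and "discriminator" are assumed.

```python
from typing import Callable, Dict, List


class ToyOptimWrapper:
    """Hypothetical stand-in for an mmengine OptimWrapper."""

    def __init__(self) -> None:
        self.seen_losses: List[float] = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would backpropagate, step, and zero gradients.
        self.seen_losses.append(loss)


def step_with_optim_sketch(compute_losses: Callable[[], Dict[str, float]],
                           optim_wrapper: Dict[str, ToyOptimWrapper],
                           key: str) -> Dict[str, float]:
    """Compute losses, sum them, and update only the `key` sub-network."""
    losses = compute_losses()
    total = sum(losses.values())
    optim_wrapper[key].update_params(total)
    # Return the individual terms plus the summed total for logging.
    return {**losses, "loss": total}
```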
- extract_gt_data(data_samples)
Extracts GT data from data samples.
- Parameters
data_samples (list) – List of DataSample.
- Returns
Extracted GT data.
- Return type
Tensor
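Extracting GT data amounts to gathering each sample's ground-truth entry into one batch. In this sketch samples are plain dicts, the key "gt_img" is an assumption, and the result is a Python list; with tensors, the last step would typically be torch.stack. extract_gt_data_sketch is hypothetical.

```python
from typing import Any, Dict, List


def extract_gt_data_sketch(data_samples: List[Dict[str, Any]],
                           key: str = "gt_img") -> List[Any]:
    """Gather the ground-truth entry of every sample into one batch list."""
    missing = [i for i, s in enumerate(data_samples) if key not in s]
    if missing:
        raise KeyError(f"samples {missing} have no {key!r} field")
    return [sample[key] for sample in data_samples]
```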