mmagic.models.editors.deblurganv2
Package Contents
Classes
- Base class for all algorithmic models.
- Defines the discriminator for DeblurGanv2 with the specified arguments.
- Defines the generator for DeblurGanv2 with the specified arguments.
- class mmagic.models.editors.deblurganv2.DeblurGanV2(generator: ModelType, discriminator: Optional[ModelType] = None, pixel_loss: Optional[Union[dict, str]] = None, disc_loss: Optional[Union[dict, str]] = None, adv_lambda: float = 0.001, warmup_num: int = 3, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases: mmengine.model.BaseModel
Base class for all algorithmic models.
BaseModel implements the basic functions of an algorithmic model: weight initialization, batch input preprocessing (see BaseDataPreprocessor for more information), loss parsing, and model parameter updates. Subclasses that inherit from BaseModel only need to implement the forward method. This method contains the logic for computing losses and predictions, after which the model can be trained by the runner.
Examples
>>> import torch
>>> import torch.nn as nn
>>> from mmengine.model import BaseModel
>>> from mmengine.registry import MODELS
>>> @MODELS.register_module()
... class ToyModel(BaseModel):
...
...     def __init__(self):
...         super().__init__()
...         self.backbone = nn.Sequential()
...         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
...         self.backbone.add_module('pool', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
...         # A second pooling and a flatten are needed so that 32x32
...         # inputs reach fc1 as 16 * 5 * 5 features.
...         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('flatten', nn.Flatten())
...         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
...         self.backbone.add_module('fc2', nn.Linear(120, 84))
...         self.backbone.add_module('fc3', nn.Linear(84, 10))
...
...         self.criterion = nn.CrossEntropyLoss()
...
...     def forward(self, batch_inputs, data_samples=None, mode='tensor'):
...         if mode == 'tensor':
...             return self.backbone(batch_inputs)
...         elif mode == 'predict':
...             feats = self.backbone(batch_inputs)
...             predictions = torch.argmax(feats, 1)
...             return predictions
...         elif mode == 'loss':
...             # Labels are only needed, and only stacked, when computing the loss.
...             labels = torch.stack(data_samples)
...             feats = self.backbone(batch_inputs)
...             loss = self.criterion(feats, labels)
...             return dict(loss=loss)
- Parameters
data_preprocessor (dict, optional) – The pre-processing config of BaseDataPreprocessor.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
- data_preprocessor
Used to pre-process data sampled by the dataloader into the format accepted by forward().
- Type
BaseDataPreprocessor
- init_cfg
Initialization config dict.
- Type
dict, optional
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmagic.structures.DataSample], dict]
Returns losses or predictions for the training, validation, testing, and simple inference processes.
The forward method of BaseModel is abstract, so subclasses must implement it. It accepts inputs and data_samples processed by data_preprocessor and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step.
During distributed data-parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization. It then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – Must be one of loss, predict and tensor. Default: 'tensor'.
loss: Called by train_step. Returns a dict of losses used for logging.
predict: Called by val_step and test_step. Returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom code to get results of Tensor type.
- Returns
If mode == loss, returns a dict of loss tensors used for backward and logging.
If mode == predict, returns a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, returns a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
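The mode-based routing described above can be sketched in plain Python. `ModeDispatchSketch` and its toy "network", which just doubles each input, are hypothetical stand-ins and not mmagic code. The sketch assumes that `tensor`, `predict` and `loss` route to `forward_tensor`, `forward_inference` and `forward_train` respectively, as the method names suggest. Any other mode is rejected.

```python
from typing import Any, List, Optional


class ModeDispatchSketch:
    """Hypothetical stand-in showing how `mode` selects a code path."""

    def forward(self, inputs: List[Any], data_samples: Optional[List[Any]] = None,
                mode: str = 'tensor', **kwargs) -> Any:
        if mode == 'tensor':
            return self.forward_tensor(inputs, data_samples, **kwargs)
        if mode == 'predict':
            return self.forward_inference(inputs, data_samples, **kwargs)
        if mode == 'loss':
            return self.forward_train(inputs, data_samples, **kwargs)
        # Only the three documented modes are accepted.
        raise ValueError(f'Invalid forward mode: {mode!r}')

    def forward_tensor(self, inputs, data_samples=None, **kwargs):
        # Placeholder "network": doubles every input value.
        return [2 * x for x in inputs]

    def forward_inference(self, inputs, data_samples=None, **kwargs):
        # Wrap each raw output in a record, standing in for DataSample.
        return [{'pred': y} for y in self.forward_tensor(inputs, data_samples)]

    def forward_train(self, inputs, data_samples=None, **kwargs):
        # As documented: the result of forward_tensor with training=True.
        return self.forward_tensor(inputs, data_samples, training=True, **kwargs)
```

Because each mode lives in its own method, `train_step`, `val_step` and `test_step` can share one entry point. The per-mode logic can still be overridden separately.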
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) → List[mmagic.structures.DataSample]
Adds predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from the dataloader.
inputs (Optional[torch.Tensor]) – The input of the model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
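The following is a minimal sketch of the bookkeeping that `convert_to_datasample` performs. It uses plain dicts in place of `DataSample`, and the field names `output` and `input` are hypothetical. The real method stores destructed (de-normalized) inputs, but that step is omitted here.

```python
from typing import Any, Dict, List, Optional


def convert_to_datasample_sketch(predictions: List[Any],
                                 data_samples: List[Dict[str, Any]],
                                 inputs: Optional[List[Any]] = None
                                 ) -> List[Dict[str, Any]]:
    """Attach each prediction (and optionally each input) to its sample."""
    if len(predictions) != len(data_samples):
        raise ValueError('predictions and data_samples must have equal length')
    for i, (pred, sample) in enumerate(zip(predictions, data_samples)):
        sample['output'] = pred  # hypothetical field name
        if inputs is not None:
            sample['input'] = inputs[i]  # hypothetical field name
    # The samples are modified in place and also returned.
    return data_samples
```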
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → torch.Tensor
Forward tensor. Returns the result of a simple forward pass.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of the simple forward pass.
- Return type
Tensor
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → List[mmagic.structures.DataSample]
Forward inference. Returns predictions for validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
List[DataSample]
- forward_train(inputs, data_samples=None, **kwargs)
Forward training. Training losses are calculated in train_step.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of forward_tensor with training=True.
- Return type
Tensor
- val_step(data: Union[tuple, dict, list]) → list
Gets the predictions for the given data.
Calls self.data_preprocessor(data, False) and then self(inputs, data_sample, mode='predict'). Returns the predictions, which are passed to the evaluator.
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
- Returns
The predictions for the given data.
- Return type
list
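The call order of `val_step` can be sketched with plain callables. `val_step_sketch`, `data_preprocessor` and `model` are hypothetical stand-ins. The sketch assumes the preprocessor returns a dict with `inputs` and `data_samples` keys, which is the usual mmengine convention.

```python
from typing import Any, Callable, Dict, List

Preprocessor = Callable[[Dict[str, Any], bool], Dict[str, Any]]


def val_step_sketch(data: Dict[str, Any],
                    data_preprocessor: Preprocessor,
                    model: Callable[..., List[Any]]) -> List[Any]:
    # 1. Preprocess without training-time behavior (training=False).
    processed = data_preprocessor(data, False)
    # 2. Run the model in 'predict' mode; the result goes to the evaluator.
    return model(processed['inputs'], processed['data_samples'], mode='predict')
```

`test_step` follows the same two steps.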
- test_step(data: Union[dict, tuple, list]) → list
BaseModel implements test_step in the same way as val_step.
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
- Returns
The predictions for the given data.
- Return type
list
- _run_forward(data: Union[dict, tuple, list], mode: str) → Union[Dict[str, torch.Tensor], list]
Unpacks data for forward().
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
mode (str) – Mode of forward.
- Returns
Results of training or testing mode.
- Return type
dict or list
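The following sketches how `_run_forward` might unpack `data`, assuming it behaves like mmengine's `BaseModel._run_forward`. A dict is passed as keyword arguments, a tuple or list as positional arguments, and anything else is rejected. `run_forward_sketch` is a hypothetical free function that takes the model as an explicit argument.

```python
from typing import Any, Callable, Union


def run_forward_sketch(model: Callable[..., Any],
                       data: Union[dict, tuple, list], mode: str) -> Any:
    if isinstance(data, dict):
        # dict: keys become keyword arguments of forward().
        return model(**data, mode=mode)
    if isinstance(data, (list, tuple)):
        # list/tuple: items become positional arguments of forward().
        return model(*data, mode=mode)
    raise TypeError('Output of data_preprocessor should be a list, tuple or '
                    f'dict, but got {type(data)}')
```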
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train step of a GAN-based method.
- Parameters
data (List[dict]) – Data sampled from the dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
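One way to picture a GAN `train_step` is the orchestration below, written with plain callables. Every name here is hypothetical. The update order (discriminator first, then generator) is an assumption for illustration, since this page does not document the order DeblurGanV2 actually uses. A real implementation would also detach the generator outputs for the discriminator update.

```python
from typing import Any, Callable, Dict, List

LossFn = Callable[[Any, Any], Dict[str, float]]


def train_step_sketch(data: List[Dict[str, Any]],
                      preprocess: Callable[[List[Dict[str, Any]], bool], Dict[str, Any]],
                      generator: Callable[[Any], Any],
                      extract_gt: Callable[[Any], Any],
                      g_step_with_optim: LossFn,
                      d_step_with_optim: LossFn) -> Dict[str, float]:
    # Preprocess in training mode; assumes the mmengine convention of a
    # dict with 'inputs' and 'data_samples' keys.
    processed = preprocess(data, True)
    batch_outputs = generator(processed['inputs'])
    batch_gt_data = extract_gt(processed['data_samples'])
    log_vars: Dict[str, float] = {}
    # Assumed update order: discriminator first, then generator.
    log_vars.update(d_step_with_optim(batch_outputs, batch_gt_data))
    log_vars.update(g_step_with_optim(batch_outputs, batch_gt_data))
    return log_vars
```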
- g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
G step of DeblurGanV2: calculates the losses of the generator.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
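The constructor exposes `adv_lambda` (default 0.001), which suggests the adversarial term is down-weighted relative to the pixel (content) term. The sketch below assumes a simple weighted sum. `generator_loss_sketch` and its key names are hypothetical, and the configured `pixel_loss` may itself be a composite loss.

```python
from typing import Dict


def generator_loss_sketch(pixel_loss: float, adv_loss: float,
                          adv_lambda: float = 0.001) -> Dict[str, float]:
    """Return per-term generator losses under an assumed weighting."""
    return {
        'loss_pixel': pixel_loss,
        # The adversarial term is scaled down by adv_lambda.
        'loss_adv': adv_lambda * adv_loss,
    }
```

With `pixel_loss=0.2` and `adv_loss=5.0`, the adversarial contribution is about 0.005, so the total is about 0.205. The pixel term dominates, which is the intent of a small `adv_lambda`.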
- d_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
D step of DeblurGanV2: calculates the losses of the discriminator.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
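The DeblurGAN-v2 paper trains its discriminator with a relativistic average least-squares (RaGAN-LS) loss. The sketch below implements that formulation on plain lists of discriminator scores. Whether mmagic's `disc_loss` uses exactly this form depends on the configured loss, so treat it as illustrative. `ragan_ls_d_loss` is a hypothetical helper name.

```python
from statistics import fmean
from typing import Sequence


def ragan_ls_d_loss(real_scores: Sequence[float],
                    fake_scores: Sequence[float]) -> float:
    """Relativistic average least-squares discriminator loss.

    E_real[(D(real) - E[D(fake)] - 1)^2] + E_fake[(D(fake) - E[D(real)] + 1)^2]
    """
    mean_real = fmean(real_scores)
    mean_fake = fmean(fake_scores)
    # Real samples should score 1 above the average fake score.
    real_term = fmean([(r - mean_fake - 1.0) ** 2 for r in real_scores])
    # Fake samples should score 1 below the average real score.
    fake_term = fmean([(f - mean_real + 1.0) ** 2 for f in fake_scores])
    return real_term + fake_term
```

The loss is zero once real scores sit exactly one unit above the average fake score and fake scores one unit below the average real score. If the discriminator cannot tell the two apart (equal scores), the loss is 2.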
- g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
G step with optimizer update: calculates the losses of the generator and runs its optimizer.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
- d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
D step with optimizer update: calculates the losses of the discriminator and runs its optimizer.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
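Both `*_with_optim` methods compute a loss dict, reduce it to one scalar, and hand that scalar to an optimizer wrapper. The sketch assumes mmengine's convention of summing only entries whose key contains `loss`. It also assumes the wrapper dict is keyed by hypothetical names such as `generator` and `discriminator`. `RecordingOptimWrapper` is a stand-in that only records the loss; a real `OptimWrapper.update_params` runs backward, the optimizer step and gradient zeroing.

```python
from typing import Dict, List, Tuple


class RecordingOptimWrapper:
    """Hypothetical stand-in for an mmengine OptimWrapper."""

    def __init__(self) -> None:
        self.updates: List[float] = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would call backward(), step() and zero_grad().
        self.updates.append(loss)


def parse_losses_sketch(losses: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    # Only entries whose key contains 'loss' contribute to the total;
    # every entry is kept for logging, plus the total under 'loss'.
    total = sum(v for k, v in losses.items() if 'loss' in k)
    return total, dict(losses, loss=total)


def step_with_optim_sketch(losses: Dict[str, float],
                           optim_wrappers: Dict[str, RecordingOptimWrapper],
                           key: str) -> Dict[str, float]:
    total, log_vars = parse_losses_sketch(losses)
    optim_wrappers[key].update_params(total)
    return log_vars
```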
- extract_gt_data(data_samples)
Extracts GT data from data samples.
- Parameters
data_samples (list) – List of DataSample.
- Returns
Extracted GT data.
- Return type
Tensor
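Here is a minimal sketch of `extract_gt_data`, assuming each sample stores its ground truth under a hypothetical `gt_img` key. A real implementation would stack tensors into one batch tensor (for example with `torch.stack`) rather than return a list.

```python
from typing import Any, Dict, List


def extract_gt_data_sketch(data_samples: List[Dict[str, Any]]) -> List[Any]:
    """Collect the ground-truth entry of every sample into one batch list."""
    # Fail early, naming every sample that lacks ground truth.
    missing = [i for i, s in enumerate(data_samples) if 'gt_img' not in s]
    if missing:
        raise KeyError(f'samples without gt_img: {missing}')
    return [sample['gt_img'] for sample in data_samples]
```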