mmagic.models.editors.deblurganv2
Package Contents
Classes
DeblurGanV2 | Base class for all algorithmic models.
| Defines the discriminator for DeblurGanv2 with the specified arguments.
| Defines the generator for DeblurGanv2 with the specified arguments.
- class mmagic.models.editors.deblurganv2.DeblurGanV2(generator: ModelType, discriminator: Optional[ModelType] = None, pixel_loss: Optional[Union[dict, str]] = None, disc_loss: Optional[Union[dict, str]] = None, adv_lambda: float = 0.001, warmup_num: int = 3, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases: mmengine.model.BaseModel
Base class for all algorithmic models.
BaseModel implements the basic functions of an algorithmic model, such as weight initialization, batch input preprocessing (see BaseDataPreprocessor for details), loss parsing, and model parameter updates. Subclasses that inherit from BaseModel only need to implement the forward method, which contains the logic for computing losses and predictions; the model can then be trained by the runner.
Examples
>>> import torch
>>> import torch.nn as nn
>>> from mmengine.model import BaseModel
>>> from mmengine.registry import MODELS
>>>
>>> @MODELS.register_module()
... class ToyModel(BaseModel):
...
...     def __init__(self):
...         super().__init__()
...         # Two conv(5) + pool(2) stages map a 3x32x32 input to 16x5x5.
...         self.backbone = nn.Sequential()
...         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
...         self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
...         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('flatten', nn.Flatten())
...         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
...         self.backbone.add_module('fc2', nn.Linear(120, 84))
...         self.backbone.add_module('fc3', nn.Linear(84, 10))
...
...         self.criterion = nn.CrossEntropyLoss()
...
...     def forward(self, batch_inputs, data_samples, mode='tensor'):
...         data_samples = torch.stack(data_samples)
...         if mode == 'tensor':
...             return self.backbone(batch_inputs)
...         elif mode == 'predict':
...             feats = self.backbone(batch_inputs)
...             predictions = torch.argmax(feats, 1)
...             return predictions
...         elif mode == 'loss':
...             feats = self.backbone(batch_inputs)
...             loss = self.criterion(feats, data_samples)
...             return dict(loss=loss)
- Parameters
data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
- data_preprocessor
Used for preprocessing data sampled by the dataloader into the format accepted by forward().
- Type
BaseDataPreprocessor
- init_cfg
Initialization config dict.
- Type
dict, optional
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmagic.structures.DataSample], dict]
Returns losses or predictions for the training, validation, testing, and simple inference processes.
The forward method of BaseModel is an abstract method; its subclasses must implement it. It accepts inputs and data_samples processed by data_preprocessor and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step.
During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – One of loss, predict, and tensor. Default: 'tensor'.
loss: Called by train_step; returns a loss dict used for logging.
predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom code to get Tensor-type results.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
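The mode dispatch described above can be sketched in plain Python. This is a hypothetical illustration, not the BaseModel or DeblurGanV2 source: the class ToyForward and its helper methods are invented names, and lists of floats stand in for tensors. In this generic sketch 'loss' mode returns a loss dict directly, whereas DeblurGanV2 computes its losses inside train_step.

```python
from typing import Any, Dict, List, Optional


class ToyForward:
    """Hypothetical model that routes forward() by mode, as BaseModel subclasses do."""

    def forward(self, inputs: List[float],
                data_samples: Optional[List[float]] = None,
                mode: str = 'tensor') -> Any:
        # 'loss' is requested by train_step, 'predict' by val_step/test_step,
        # and 'tensor' by custom callers that want raw outputs.
        if mode == 'loss':
            return self._loss(inputs, data_samples)
        if mode == 'predict':
            return self._predict(inputs)
        if mode == 'tensor':
            return self._tensor(inputs)
        raise ValueError(f'Invalid mode {mode!r}; expected loss, predict or tensor')

    def _tensor(self, inputs: List[float]) -> List[float]:
        # Stand-in for the network: double every value.
        return [2.0 * x for x in inputs]

    def _predict(self, inputs: List[float]) -> List[Dict[str, float]]:
        # Wrap raw outputs in per-sample records, as an evaluator would expect.
        return [{'pred': y} for y in self._tensor(inputs)]

    def _loss(self, inputs: List[float],
              targets: Optional[List[float]]) -> Dict[str, float]:
        # Mean absolute error between outputs and targets, returned as a loss dict.
        if targets is None:
            raise ValueError("mode='loss' needs data_samples with targets")
        outputs = self._tensor(inputs)
        loss = sum(abs(o - t) for o, t in zip(outputs, targets)) / len(outputs)
        return {'loss': loss}
```

For example, ToyForward().forward([1.0, 2.0]) returns [2.0, 4.0], and an unknown mode such as 'val' raises ValueError.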
- convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) → List[mmagic.structures.DataSample]
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (DataSample) – The predictions of the model.
data_samples (DataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[DataSample]
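The bookkeeping that convert_to_datasample performs can be sketched with plain dicts standing in for DataSample objects. The field names output and input are hypothetical, and the real method stores the destructed inputs, whereas this sketch stores them unchanged.

```python
from typing import Any, Dict, List, Optional


def convert_to_datasample(predictions: List[Any],
                          data_samples: List[Dict[str, Any]],
                          inputs: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
    """Hypothetical sketch: attach each prediction (and optional input) to its sample."""
    if len(predictions) != len(data_samples):
        raise ValueError('predictions and data_samples must have the same length')
    if inputs is not None and len(inputs) != len(data_samples):
        raise ValueError('inputs and data_samples must have the same length')
    for i, sample in enumerate(data_samples):
        sample['output'] = predictions[i]  # hypothetical field name
        if inputs is not None:
            # Stored as-is here; mmagic destructs the inputs before storing them.
            sample['input'] = inputs[i]
    return data_samples
```

The samples are modified in place and also returned, which matches the "Modified data samples" return description.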
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → torch.Tensor
Forward tensor. Returns result of simple forward.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
result of simple forward.
- Return type
Tensor
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) → List[mmagic.structures.DataSample]
Forward inference. Returns predictions of validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
List[DataSample]
- forward_train(inputs, data_samples=None, **kwargs)
Forward training. Training losses are calculated in train_step.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of forward_tensor with training=True.
- Return type
Tensor
- val_step(data: Union[tuple, dict, list]) → list
Gets the predictions of the given data.
Calls self.data_preprocessor(data, False) and self(inputs, data_sample, mode='predict') in order, and returns the predictions, which are passed to the evaluator.
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
- Returns
The predictions of the given data.
- Return type
list
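The two calls that val_step makes can be sketched as follows. toy_preprocessor and toy_forward are hypothetical stand-ins for data_preprocessor and forward (dividing by 255 imitates a typical normalization), and the preprocessed batch is assumed to be a dict whose keys match the parameters of forward.

```python
from typing import Any, Callable, Dict, List


def val_step(data: Dict[str, Any],
             data_preprocessor: Callable[[Dict[str, Any], bool], Dict[str, Any]],
             forward: Callable[..., List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    # 1. Preprocess in non-training mode (second argument False).
    batch = data_preprocessor(data, False)
    # 2. Run forward in 'predict' mode; the result goes to the evaluator.
    return forward(**batch, mode='predict')


def toy_preprocessor(data: Dict[str, Any], training: bool) -> Dict[str, Any]:
    # Hypothetical preprocessing: scale 8-bit pixel values to [0, 1].
    return {'inputs': [x / 255.0 for x in data['inputs']],
            'data_samples': data['data_samples']}


def toy_forward(inputs: List[float], data_samples: List[Dict[str, Any]],
                mode: str = 'tensor') -> List[Dict[str, Any]]:
    # Hypothetical model: echo each input as the prediction and record the mode.
    return [dict(sample, pred=x, mode=mode) for x, sample in zip(inputs, data_samples)]
```

Because BaseModel implements test_step the same way, the same two-call sequence applies at test time.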
- test_step(data: Union[dict, tuple, list]) → list
BaseModel implements test_step the same way as val_step.
- Parameters
data (dict or tuple or list) – Data sampled from the dataset.
- Returns
The predictions of the given data.
- Return type
list
- _run_forward(data: Union[dict, tuple, list], mode: str) → Union[Dict[str, torch.Tensor], list]
Unpacks data for forward().
- Parameters
data (dict or tuple or list) – Data sampled from dataset.
mode (str) – Mode of forward.
- Returns
Results of training or testing mode.
- Return type
dict or list
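The unpacking rule of _run_forward can be sketched as a standalone function. As we understand mmengine's BaseModel, a dict is unpacked as keyword arguments, a list or tuple as positional arguments, and any other type is rejected; run_forward and toy_forward below are illustrative names, not library functions.

```python
from typing import Any, Callable, Dict, List, Tuple, Union


def run_forward(forward: Callable[..., Any],
                data: Union[Dict[str, Any], Tuple[Any, ...], List[Any]],
                mode: str) -> Any:
    if isinstance(data, dict):
        # Keys must match forward()'s parameter names, e.g. inputs, data_samples.
        return forward(**data, mode=mode)
    if isinstance(data, (list, tuple)):
        # Items are passed positionally, in forward()'s parameter order.
        return forward(*data, mode=mode)
    raise TypeError('data should be a dict, tuple or list, '
                    f'but got {type(data).__name__}')


def toy_forward(inputs: Any, data_samples: Any = None, mode: str = 'tensor') -> tuple:
    # Hypothetical forward that just reports what it received.
    return inputs, data_samples, mode
```

Keyword unpacking is the safer convention for preprocessors, since it does not depend on the parameter order of forward.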
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train step of the GAN-based method.
- Parameters
data (List[dict]) – Data sampled from the dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
- g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
G step of DoubleGAN: calculates the losses of the generator.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
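One plausible shape for the dict returned by g_step is sketched below, assuming (as we understand the original DeblurGAN-v2 training code) that the adversarial term is scaled by adv_lambda before being combined with the pixel loss. The helper names and the keys loss_pix and loss_adv are hypothetical; total_loss imitates how mmengine's parse_losses sums every entry whose key contains 'loss'.

```python
from typing import Dict


def g_step_losses(pixel_loss: float, adv_loss: float,
                  adv_lambda: float = 0.001) -> Dict[str, float]:
    # Hypothetical loss dict: only the adversarial term is weighted.
    # 0.001 matches the adv_lambda default in the DeblurGanV2 signature.
    return {'loss_pix': pixel_loss, 'loss_adv': adv_lambda * adv_loss}


def total_loss(losses: Dict[str, float]) -> float:
    # Sum every entry whose key contains 'loss', as parse_losses does.
    return sum(value for key, value in losses.items() if 'loss' in key)
```

With the default weight, the adversarial term enters the total at one thousandth of its raw value.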
- d_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
D step of DoubleGAN: calculates the losses of the discriminator.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
- g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
G step with optimizer: calculates the generator losses and runs the optimizer.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
- d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
D step with optimizer: calculates the discriminator losses and runs the optimizer.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
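How train_step might drive the two *_step_with_optim methods can be sketched with a stand-in optimizer wrapper. Assumptions: the OptimWrapperDict is keyed by 'discriminator' and 'generator', the discriminator is updated before the generator, and RecordingOptimWrapper is a hypothetical class whose update_params only records the loss (a real mmengine OptimWrapper.update_params runs the backward pass and, when due, the optimizer step and zero_grad).

```python
from typing import Callable, Dict, List


class RecordingOptimWrapper:
    """Hypothetical stand-in for an mmengine OptimWrapper: records update calls."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def update_params(self, loss: float) -> None:
        # A real OptimWrapper would backpropagate and step here.
        self.log.append(f'{self.name}:{loss}')


def train_step(d_losses: Callable[[], Dict[str, float]],
               g_losses: Callable[[], Dict[str, float]],
               optim_wrapper: Dict[str, RecordingOptimWrapper]) -> Dict[str, float]:
    log_vars: Dict[str, float] = {}
    # Discriminator first, then generator (the order is an assumption).
    for key, compute in (('discriminator', d_losses), ('generator', g_losses)):
        losses = compute()
        # Parse the loss dict into one scalar, as parse_losses does.
        loss = sum(v for k, v in losses.items() if 'loss' in k)
        optim_wrapper[key].update_params(loss)
        log_vars.update(losses)
    return log_vars
```

Keeping one wrapper per network lets each optimizer see only the gradients of its own loss.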