mmagic.models.editors.deblurganv2

Package Contents

Classes

DeblurGanV2

Base class for all algorithmic models.

DeblurGanV2Discriminator

Defines the discriminator for DeblurGanv2 with the specified arguments.

DeblurGanV2Generator

Defines the generator for DeblurGanv2 with the specified arguments.

class mmagic.models.editors.deblurganv2.DeblurGanV2(generator: ModelType, discriminator: Optional[ModelType] = None, pixel_loss: Optional[Union[dict, str]] = None, disc_loss: Optional[Union[dict, str]] = None, adv_lambda: float = 0.001, warmup_num: int = 3, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)[source]

Bases: mmengine.model.BaseModel

Base class for all algorithmic models.

BaseModel implements the basic functions of an algorithmic model: weight initialization, batch input preprocessing (see BaseDataPreprocessor for details), loss parsing, and model parameter updates.

Subclasses of BaseModel only need to implement the forward method, which contains the logic for computing losses and predictions; the model can then be trained by the runner.

Examples

>>> import torch
>>> import torch.nn as nn
>>> from mmengine.model import BaseModel
>>> from mmengine.registry import MODELS
>>>
>>> @MODELS.register_module()
... class ToyModel(BaseModel):
...
...     def __init__(self):
...         super().__init__()
...         # LeNet-style backbone; the layer sizes assume 3x32x32 inputs.
...         self.backbone = nn.Sequential()
...         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
...         self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
...         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
...         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
...         # Flatten (N, 16, 5, 5) to (N, 400) before the linear layers.
...         self.backbone.add_module('flatten', nn.Flatten())
...         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
...         self.backbone.add_module('fc2', nn.Linear(120, 84))
...         self.backbone.add_module('fc3', nn.Linear(84, 10))
...
...         self.criterion = nn.CrossEntropyLoss()
...
...     def forward(self, batch_inputs, data_samples=None, mode='tensor'):
...         feats = self.backbone(batch_inputs)
...         if mode == 'tensor':
...             return feats
...         elif mode == 'predict':
...             # Index of the highest-scoring class for each sample.
...             return torch.argmax(feats, 1)
...         elif mode == 'loss':
...             # Labels are only needed (and stacked) when computing the loss.
...             labels = torch.stack(data_samples)
...             loss = self.criterion(feats, labels)
...             return dict(loss=loss)

Parameters
  • data_preprocessor (dict, optional) – The pre-processing config for BaseDataPreprocessor.

  • init_cfg (dict, optional) – The weight initialization config for BaseModule.

data_preprocessor

Used for pre-processing data sampled by dataloader to the format accepted by forward().

Type

BaseDataPreprocessor

init_cfg

Initialization config dict.

Type

dict, optional

forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs) Union[torch.Tensor, List[mmagic.structures.DataSample], dict]

Returns losses or predictions of training, validation, testing, and simple inference process.

forward is an abstract method of BaseModel; subclasses must implement it.

Accepts inputs and data_samples processed by data_preprocessor, and returns results according to mode arguments.

During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step.

During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • mode (str) –

    One of loss, predict, or tensor. Default: 'tensor'.

    • loss: Called by train_step; returns a loss dict used for logging.

    • predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.

    • tensor: Called directly by users to get Tensor-type results.

Returns

  • If mode == loss, return a dict of loss tensors used for backward and logging.

  • If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.

  • If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.

Return type

ForwardResults
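
The mode contract described above can be illustrated with a minimal pure-Python sketch. ToyDispatchModel and its arithmetic are hypothetical stand-ins (plain lists replace tensors and DataSample objects) and are not mmagic code; only the three mode names come from this documentation.

```python
from typing import Dict, List, Optional, Union


class ToyDispatchModel:
    """Hypothetical stand-in mimicking the three-way ``mode`` dispatch."""

    def _network(self, inputs: List[float]) -> List[float]:
        # Stand-in for the real network: doubles every input value.
        return [2.0 * x for x in inputs]

    def forward(
        self,
        inputs: List[float],
        data_samples: Optional[List[float]] = None,
        mode: str = "tensor",
    ) -> Union[List[float], List[dict], Dict[str, float]]:
        outputs = self._network(inputs)
        if mode == "tensor":
            # Raw outputs for custom use.
            return outputs
        if mode == "predict":
            # One result record per sample, for metric computation.
            return [{"pred": y} for y in outputs]
        if mode == "loss":
            if data_samples is None:
                raise ValueError("mode='loss' needs targets in data_samples")
            # Mean squared error against the targets, returned as a loss dict.
            mse = sum((y - t) ** 2 for y, t in zip(outputs, data_samples)) / len(outputs)
            return {"loss": mse}
        raise ValueError(f"Invalid mode: {mode!r}")


model = ToyDispatchModel()
tensor_out = model.forward([1.0, 2.0])
predict_out = model.forward([1.0], mode="predict")
loss_out = model.forward([1.0, 2.0], data_samples=[2.0, 4.0], mode="loss")
```

The same inputs produce three differently shaped results depending on mode, which is why callers such as train_step and val_step each pass a fixed mode value.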

convert_to_datasample(predictions: mmagic.structures.DataSample, data_samples: mmagic.structures.DataSample, inputs: Optional[torch.Tensor]) List[mmagic.structures.DataSample]

Add predictions and destructed inputs (if passed) to data samples.

Parameters
  • predictions (DataSample) – The predictions of the model.

  • data_samples (DataSample) – The data samples loaded from dataloader.

  • inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.

Returns

Modified data samples.

Return type

List[DataSample]

forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) torch.Tensor

Forward tensor. Returns result of simple forward.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Result of a simple forward pass.

Return type

Tensor

forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, **kwargs) List[mmagic.structures.DataSample]

Forward inference. Returns predictions of validation, testing, and simple inference.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Predictions.

Return type

List[DataSample]

forward_train(inputs, data_samples=None, **kwargs)

Forward training. Training losses are calculated in train_step.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

Returns

Result of forward_tensor with training=True.

Return type

Tensor

val_step(data: Union[tuple, dict, list]) list

Gets the predictions of given data.

Calls self.data_preprocessor(data, False) and self(inputs, data_sample, mode='predict') in order, then returns the predictions, which are passed to the evaluator.

Parameters

data (dict or tuple or list) – Data sampled from dataset.

Returns

The predictions of given data.

Return type

list
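
The two-call order described for val_step can be sketched as follows. This is a minimal pure-Python illustration: ToyPreprocessor, ToyModel, and the 0-255 scaling are hypothetical stand-ins and are not taken from mmagic.

```python
class ToyPreprocessor:
    """Hypothetical stand-in for a data preprocessor."""

    def __call__(self, data: dict, training: bool = False) -> dict:
        # Scale raw 0-255 values to [0, 1]; pass data samples through untouched.
        return {
            "inputs": [x / 255.0 for x in data["inputs"]],
            "data_samples": data.get("data_samples"),
        }


class ToyModel:
    """Hypothetical model that only supports mode='predict'."""

    def __init__(self) -> None:
        self.data_preprocessor = ToyPreprocessor()

    def __call__(self, inputs, data_samples=None, mode="tensor"):
        if mode != "predict":
            raise ValueError("this sketch only implements mode='predict'")
        return [{"pred": x} for x in inputs]

    def val_step(self, data: dict) -> list:
        # Step 1: preprocess with training=False.
        data = self.data_preprocessor(data, False)
        # Step 2: run the model in predict mode and return the predictions.
        return self(data["inputs"], data["data_samples"], mode="predict")


predictions = ToyModel().val_step({"inputs": [0.0, 255.0]})
```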

test_step(data: Union[dict, tuple, list]) list

BaseModel implements test_step the same as val_step.

Parameters

data (dict or tuple or list) – Data sampled from dataset.

Returns

The predictions of given data.

Return type

list

_run_forward(data: Union[dict, tuple, list], mode: str) Union[Dict[str, torch.Tensor], list]

Unpacks data for forward().

Parameters
  • data (dict or tuple or list) – Data sampled from dataset.

  • mode (str) – Mode of forward.

Returns

Results of training or testing mode.

Return type

dict or list
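
A minimal sketch of what such unpacking could look like, assuming that a dict is expanded into keyword arguments and a tuple or list into positional arguments. The names run_forward and echo_model are hypothetical and exist only for this illustration.

```python
from typing import Any, Callable, Union


def run_forward(model: Callable[..., Any], data: Union[dict, tuple, list], mode: str) -> Any:
    """Unpack ``data`` and call ``model`` with the given ``mode``."""
    if isinstance(data, dict):
        # Dict keys become keyword arguments of the forward call.
        return model(**data, mode=mode)
    if isinstance(data, (list, tuple)):
        # Sequence items become positional arguments of the forward call.
        return model(*data, mode=mode)
    raise TypeError(f"data should be a dict, tuple, or list, but got {type(data)}")


def echo_model(inputs, data_samples=None, mode="tensor"):
    # Hypothetical forward function that simply reports what it received.
    return inputs, data_samples, mode


from_dict = run_forward(echo_model, {"inputs": [1, 2], "data_samples": ["a"]}, "loss")
from_tuple = run_forward(echo_model, ([1, 2], ["a"]), "predict")
```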

train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor]

Train step of GAN-based method.

Parameters
  • data (List[dict]) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance used to update model parameters.

Returns

A dict of tensor for logging.

Return type

Dict[str, torch.Tensor]

g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)

G step of DoubleGAN: calculate the generator losses.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

d_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)

D step of DoubleGAN: calculate the discriminator losses.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict

g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

G step with optim of GAN: Calculate losses of generator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict

d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

D step with optim of GAN: Calculate losses of discriminator and run optim.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Optim wrapper dict.

Returns

Dict of parsed losses.

Return type

dict
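
The "calculate losses, then run optim" pattern shared by g_step_with_optim and d_step_with_optim can be sketched in pure Python. Everything here is an assumption made for illustration: RecordingOptimWrapper, parse_losses, step_with_optim, the "generator" and "discriminator" keys, and the loss names are hypothetical, and the rule of summing entries whose name contains "loss" is one plausible parsing convention rather than a statement of what mmagic does.

```python
from typing import Callable, Dict, List, Tuple


class RecordingOptimWrapper:
    """Hypothetical stand-in that records the losses it is asked to optimize."""

    def __init__(self) -> None:
        self.updates: List[float] = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would run backward and an optimizer step here.
        self.updates.append(loss)


def parse_losses(losses: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    # Sum every entry whose name contains "loss"; keep all entries for logging.
    total = sum(value for name, value in losses.items() if "loss" in name)
    log_vars = dict(losses, loss=total)
    return total, log_vars


def step_with_optim(
    compute_losses: Callable[[], Dict[str, float]],
    optim_wrappers: Dict[str, RecordingOptimWrapper],
    key: str,
) -> Dict[str, float]:
    # 1) compute the loss dict, 2) reduce it to one scalar, 3) update one sub-model.
    losses = compute_losses()
    total, log_vars = parse_losses(losses)
    optim_wrappers[key].update_params(total)
    return log_vars


wrappers = {
    "generator": RecordingOptimWrapper(),
    "discriminator": RecordingOptimWrapper(),
}
g_log = step_with_optim(
    lambda: {"loss_pix": 0.5, "loss_adv": 0.25, "psnr": 30.0}, wrappers, "generator"
)
d_log = step_with_optim(lambda: {"loss_d": 1.5}, wrappers, "discriminator")
```

Keeping one wrapper per sub-model means each step updates only the network whose losses were just computed, while the returned dict still carries every entry for logging.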

extract_gt_data(data_samples)

Extract GT data from data samples.

Parameters

data_samples (list) – List of DataSample.

Returns

Extracted GT data.

Return type

Tensor

class mmagic.models.editors.deblurganv2.DeblurGanV2Discriminator[source]

Defines the discriminator for DeblurGanv2 with the specified arguments.

Parameters

model (str) – Type of the discriminator model.

class mmagic.models.editors.deblurganv2.DeblurGanV2Generator[source]

Defines the generator for DeblurGanv2 with the specified arguments.

Parameters

model (str) – Type of the generator model.
