mmagic.models.editors.dic.dic
Module Contents
Classes
DIC | DIC model for Face Super-Resolution.
- class mmagic.models.editors.dic.dic.DIC(generator, pixel_loss, align_loss, discriminator=None, gan_loss=None, feature_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)
Bases: mmagic.models.editors.srgan.SRGAN
DIC model for Face Super-Resolution.
- Paper: Deep Face Super-Resolution with Iterative Collaboration between
Attentive Recovery and Landmark Estimation.
- Parameters
generator (dict) – Config for the generator.
pixel_loss (dict) – Config for the pixel loss.
align_loss (dict) – Config for the align loss.
discriminator (dict) – Config for the discriminator. Default: None.
gan_loss (dict) – Config for the GAN loss. Default: None.
feature_loss (dict) – Config for the feature loss. Default: None.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor. Default: None.
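As an illustration of how these arguments fit together, here is a minimal sketch of an mmagic-style model config. It assumes the registry names `DICNet`, `L1Loss` and `MSELoss` and the generator arguments and loss weights shown; none of these values come from this page, so check the configs shipped with your mmagic version before relying on them. Only the three required components are set; the adversarial and feature components stay at their default of None.

```python
# Hypothetical DIC model config in the mmagic dict style.
# Registry type names other than 'DIC' and all hyperparameters are assumptions.
model = dict(
    type='DIC',
    generator=dict(
        type='DICNet',      # assumed generator registry name
        in_channels=3,      # RGB input
        out_channels=3,     # RGB output
        mid_channels=48),   # assumed width, illustrative only
    pixel_loss=dict(type='L1Loss', loss_weight=1.0),
    align_loss=dict(type='MSELoss', loss_weight=0.1),  # assumed weight
    # discriminator, gan_loss and feature_loss are not set, so the
    # constructor falls back to their default of None.
    train_cfg=dict(),
    test_cfg=dict(),
)

# The three required components are present; the optional ones are absent.
required = {'generator', 'pixel_loss', 'align_loss'}
optional = {'discriminator', 'gan_loss', 'feature_loss'}
assert required.issubset(model)
assert optional.isdisjoint(model)
```

In an mmagic setup such a dict would normally be built through the model registry (for example `MODELS` from `mmagic.registry`); that step is left out so the sketch runs without mmagic installed.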
- forward_tensor(inputs, data_samples=None, training=False)
Forward tensor. Returns the result of a simple forward pass.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
training (bool) – Whether the model is in training mode. Default: False.
- Returns
Results of forward inference (when training is False) or of forward training (when training is True).
- Return type
(Tensor | Tuple[List[Tensor]])
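The return type depends on the `training` flag. The sketch below illustrates that dispatch in plain Python, assuming (as the paper's iterative design suggests, but this page does not state) that the generator yields one super-resolved estimate and one landmark heatmap per iteration, and that inference keeps only the last estimate. `toy_generator` and the list-based stand-in for tensors are hypothetical.

```python
from typing import List, Tuple, Union

Image = List[float]  # hypothetical stand-in for an image tensor


def toy_generator(inputs: Image,
                  num_steps: int = 3) -> Tuple[List[Image], List[Image]]:
    """Hypothetical generator returning one SR estimate and one heatmap per step."""
    sr_list: List[Image] = []
    heatmap_list: List[Image] = []
    current = list(inputs)
    for step in range(num_steps):
        # Toy "refinement": each iteration shifts every value by 1.0.
        current = [x + 1.0 for x in current]
        sr_list.append(current)
        # Toy "heatmap": a constant map tagged with the iteration index.
        heatmap_list.append([float(step)] * len(current))
    return sr_list, heatmap_list


def forward_tensor(inputs: Image, training: bool = False
                   ) -> Union[Image, Tuple[List[Image], List[Image]]]:
    sr_list, heatmap_list = toy_generator(inputs)
    if training:
        # Training keeps every intermediate output for the per-step losses.
        return sr_list, heatmap_list
    # Inference returns only the final estimate.
    return sr_list[-1]
```

Returning all intermediate outputs during training is what allows a loss to be attached to every iteration, while inference only needs the most refined result.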
- g_step(batch_outputs, batch_gt_data)
Generator step of the GAN: calculates the generator losses.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
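A plausible shape of this computation, sketched in plain Python, assumes that `batch_outputs` is a pair of per-iteration SR images and heatmaps and that `batch_gt_data` is a pair of GT image and GT heatmap, giving one pixel loss and one alignment loss per iteration. The loss functions, the loss-key names and the optional `feature_loss` hook are hypothetical; loss weights and any GAN term used when a discriminator is configured are omitted.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Image = List[float]  # hypothetical stand-in for an image tensor


def l1_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    # Mean absolute error, used here as the pixel loss.
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    # Mean squared error, used here as the heatmap alignment loss.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def g_step(batch_outputs: Tuple[List[Image], List[Image]],
           batch_gt_data: Tuple[Image, Image],
           feature_loss: Optional[Callable[[Image, Image], float]] = None
           ) -> Dict[str, float]:
    sr_list, heatmap_list = batch_outputs
    gt, gt_heatmap = batch_gt_data
    losses: Dict[str, float] = {}
    # One pixel loss and one alignment loss for every iteration.
    for step, (sr, heatmap) in enumerate(zip(sr_list, heatmap_list)):
        losses[f'loss_pixel_v{step}'] = l1_loss(sr, gt)
        losses[f'loss_align_v{step}'] = mse_loss(heatmap, gt_heatmap)
    # Optional extra term computed on the final estimate only.
    if feature_loss is not None:
        losses['loss_feature'] = feature_loss(sr_list[-1], gt)
    return losses
```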
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train step of GAN-based method.
- Parameters
data (List[dict]) – Data sampled from dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
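A simplified sketch of the control flow of a GAN-style train step follows. It assumes the `optim_wrapper` mapping holds separate `generator` and `discriminator` entries and that the loss dicts have already been computed (for example by a `g_step`-like function). `ToyOptimWrapper` is hypothetical: it only mimics the `update_params` method name of MMEngine's `OptimWrapper` and records the loss instead of back-propagating. Data preprocessing and any scheduling of discriminator steps the real method may perform are not modeled.

```python
from typing import Dict, List, Mapping, Optional


class ToyOptimWrapper:
    """Hypothetical stand-in for an optimizer wrapper.

    It only records the scalar losses passed to update_params, so the
    control flow can be inspected without PyTorch.
    """

    def __init__(self) -> None:
        self.received: List[float] = []

    def update_params(self, loss: float) -> None:
        self.received.append(loss)


def sum_losses(losses: Mapping[str, float]) -> float:
    # Total objective: the sum of every entry whose key contains 'loss'.
    return sum(value for key, value in losses.items() if 'loss' in key)


def train_step(g_losses: Mapping[str, float],
               d_losses: Optional[Mapping[str, float]],
               optim_wrapper: Mapping[str, ToyOptimWrapper]
               ) -> Dict[str, float]:
    log_vars: Dict[str, float] = {}

    # 1. Generator update from the generator losses.
    optim_wrapper['generator'].update_params(sum_losses(g_losses))
    log_vars.update(g_losses)

    # 2. Discriminator update, only when a discriminator is configured.
    if d_losses is not None and 'discriminator' in optim_wrapper:
        optim_wrapper['discriminator'].update_params(sum_losses(d_losses))
        log_vars.update(d_losses)

    # Every logged term is returned in one flat dict keyed by loss name.
    return log_vars
```

Summing the entries whose key contains 'loss' follows a common MMEngine convention for turning a loss dict into one objective, but here it is an assumption of the sketch.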