mmagic.models.editors.srgan
Package Contents
Classes
ModifiedVGG | A modified VGG discriminator with input size 128 x 128.
MSRResNet | Modified SRResNet.
SRGAN | SRGAN model for single image super-resolution.
- class mmagic.models.editors.srgan.ModifiedVGG(in_channels, mid_channels)
Bases: mmengine.model.BaseModule
A modified VGG discriminator with input size 128 x 128.
It is used to train SRGAN and ESRGAN.
- Parameters
in_channels (int) – Channel number of inputs, typically 3. Required: the signature defines no default.
mid_channels (int) – Channel number of base intermediate features, typically 64. Required: the signature defines no default.
- forward(x)
Forward function.
- Parameters
x (Tensor) – Input tensor with shape (n, c, h, w).
- Returns
Discriminator output for the input batch.
- Return type
Tensor
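ModifiedVGG is documented to accept only 128 x 128 inputs. The sketch below is a hypothetical pre-flight check that validates an (n, c, h, w) shape against that constraint before a batch reaches the discriminator. The helper name `check_discriminator_input` is ours and is not part of mmagic. It operates on plain shape tuples, so it runs without PyTorch.

```python
from typing import Sequence


def check_discriminator_input(shape: Sequence[int], in_channels: int = 3,
                              size: int = 128) -> None:
    """Hypothetical helper: raise ValueError unless shape is (n, in_channels, size, size)."""
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D shape (n, c, h, w), got {tuple(shape)}")
    _, c, h, w = shape
    if c != in_channels:
        raise ValueError(f"expected {in_channels} input channels, got {c}")
    if (h, w) != (size, size):
        # ModifiedVGG is documented for 128 x 128 inputs only.
        raise ValueError(f"expected a {size}x{size} input, got {h}x{w}")


check_discriminator_input((16, 3, 128, 128))  # valid: returns None
```

Both real (GT) and generated patches passed to this discriminator therefore need a spatial size of 128 x 128.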
- class mmagic.models.editors.srgan.MSRResNet(in_channels, out_channels, mid_channels=64, num_blocks=16, upscale_factor=4)
Bases: mmengine.model.BaseModule
Modified SRResNet.
A compact version of SRResNet from “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”.
It uses residual blocks without BN, similar to EDSR. Currently, it supports x2, x3, and x4 upsampling scale factors.
- Parameters
in_channels (int) – Channel number of inputs.
out_channels (int) – Channel number of outputs.
mid_channels (int) – Channel number of intermediate features. Default: 64.
num_blocks (int) – Number of blocks in the trunk network. Default: 16.
upscale_factor (int) – Upsampling factor. Supports x2, x3, and x4. Default: 4.
- _supported_upscale_factors = [2, 3, 4]
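The sketch below computes the expected output shape for a given input. It assumes that MSRResNet maps an (n, in_channels, h, w) input to (n, out_channels, h * s, w * s), where s is `upscale_factor`. That is the usual behaviour of a super-resolution generator, but the exact mapping is an assumption here. The function name is hypothetical, and the supported-factor check mirrors `_supported_upscale_factors`.

```python
from typing import Tuple

# Mirrors the documented class attribute MSRResNet._supported_upscale_factors.
SUPPORTED_UPSCALE_FACTORS = (2, 3, 4)


def msrresnet_output_shape(input_shape: Tuple[int, int, int, int],
                           out_channels: int,
                           upscale_factor: int = 4) -> Tuple[int, int, int, int]:
    """Hypothetical helper: expected (n, c, h, w) output, assuming both sides scale by upscale_factor."""
    if upscale_factor not in SUPPORTED_UPSCALE_FACTORS:
        raise ValueError(
            f"upscale_factor must be one of {SUPPORTED_UPSCALE_FACTORS}, got {upscale_factor}")
    n, _, h, w = input_shape
    return (n, out_channels, h * upscale_factor, w * upscale_factor)


print(msrresnet_output_shape((1, 3, 32, 32), out_channels=3))  # (1, 3, 128, 128)
```

With the default x4 factor, a 32 x 32 input yields a 128 x 128 output, which matches the input size ModifiedVGG expects.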
- forward(x)
Forward function.
- Parameters
x (Tensor) – Input tensor with shape (n, c, h, w).
- Returns
Upsampled output tensor.
- Return type
Tensor
- init_weights()
Initialize the model weights.
- class mmagic.models.editors.srgan.SRGAN(generator, discriminator=None, gan_loss=None, pixel_loss=None, perceptual_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)
Bases: mmagic.models.base_models.BaseEditModel
SRGAN model for single image super-resolution.
Ref: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.
- Parameters
generator (dict) – Config for the generator.
discriminator (dict) – Config for the discriminator. Default: None.
gan_loss (dict) – Config for the GAN loss. Note that the loss weight in the GAN loss applies only to the generator. Default: None.
pixel_loss (dict) – Config for the pixel loss. Default: None.
perceptual_loss (dict) – Config for the perceptual loss. Default: None.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor. Default: None.
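In mmagic, models are typically described by nested config dicts whose `type` key names a registered class. The sketch below shows what an SRGAN model config might look like using the generator and discriminator documented on this page. The loss type names (`L1Loss`, `PerceptualLoss`, `GANLoss`), their keyword arguments, and all weight values are assumptions for illustration. Check them against the configs shipped with your mmagic version.

```python
# Hypothetical SRGAN model config: loss types, their kwargs, and weights are illustrative assumptions.
model = dict(
    type='SRGAN',
    generator=dict(
        type='MSRResNet',
        in_channels=3,
        out_channels=3,
        mid_channels=64,
        num_blocks=16,
        upscale_factor=4),
    discriminator=dict(type='ModifiedVGG', in_channels=3, mid_channels=64),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0),
    perceptual_loss=dict(
        type='PerceptualLoss', layer_weights={'34': 1.0}, perceptual_weight=1.0),
    # Per the SRGAN docstring, this weight only scales the generator's GAN loss.
    gan_loss=dict(type='GANLoss', gan_type='vanilla', loss_weight=1e-3),
    train_cfg=dict(),
    test_cfg=dict(),
)
```

With mmagic installed, such a dict would typically be built through the model registry, for example `MODELS.build(model)` with `MODELS` imported from `mmagic.registry`. Treat this as a pointer rather than a verified recipe.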
- forward_train(inputs, data_samples=None, **kwargs)
Forward pass for training. Training losses are calculated in train_step.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of forward_tensor with training=True.
- Return type
Tensor
- forward_tensor(inputs, data_samples=None, training=False)
Forward tensor. Returns the result of a simple forward pass.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
training (bool) – Whether the model is in training mode. Default: False.
- Returns
Result of a simple forward pass.
- Return type
Tensor
- if_run_g()
Determines whether the generator step should be run.
- if_run_d()
Determines whether the discriminator step should be run.
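The docstrings do not say how these decisions are made. The hypothetical sketch below uses a pattern common in GAN trainers:

- The generator is updated only every `disc_steps` iterations, and only after a warm-up of `disc_init_steps` iterations.
- The discriminator runs whenever both a discriminator and a GAN loss are configured.

The parameter names and the exact rule are assumptions, not confirmed mmagic behaviour.

```python
from dataclasses import dataclass


@dataclass
class GanStepScheduler:
    """Hypothetical scheduler mirroring the if_run_g / if_run_d decisions."""
    disc_steps: int = 1         # assumed: run the generator once every `disc_steps` iterations
    disc_init_steps: int = 0    # assumed: discriminator-only warm-up iterations
    has_discriminator: bool = True
    has_gan_loss: bool = True
    step_counter: int = 0

    def if_run_g(self) -> bool:
        return (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps)

    def if_run_d(self) -> bool:
        return self.has_discriminator and self.has_gan_loss


sched = GanStepScheduler(disc_steps=2, disc_init_steps=2)
runs = []
for _ in range(6):
    runs.append(sched.if_run_g())
    sched.step_counter += 1  # a trainer would advance this once per train_step
print(runs)  # [False, False, True, False, True, False]
```

Skipping generator updates lets the discriminator train for extra iterations, which some setups use to stabilize adversarial training.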
- g_step(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor)
Generator step of the GAN: calculates the generator losses.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
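The sketch below shows how g_step's loss dict could be assembled. It assumes that each configured loss contributes one entry, that unconfigured (None) losses are skipped, and that the GAN weight scales only the generator-side adversarial term, as the `gan_loss` parameter note states. The key names `loss_pix`, `loss_perceptual`, and `loss_gan` are hypothetical. Plain floats and callables stand in for tensors and loss modules.

```python
from typing import Callable, Dict, Optional

PairLoss = Callable[[float, float], float]  # stand-in for a loss module on (output, gt)


def g_step_sketch(output: float, gt: float,
                  pixel_loss: Optional[PairLoss] = None,
                  perceptual_loss: Optional[PairLoss] = None,
                  discriminator: Optional[Callable[[float], float]] = None,
                  gan_criterion: Optional[Callable[[float, bool], float]] = None,
                  gan_weight: float = 1.0) -> Dict[str, float]:
    """Hypothetical sketch: one dict entry per configured generator loss."""
    losses: Dict[str, float] = {}
    if pixel_loss is not None:
        losses['loss_pix'] = pixel_loss(output, gt)
    if perceptual_loss is not None:
        losses['loss_perceptual'] = perceptual_loss(output, gt)
    if discriminator is not None and gan_criterion is not None:
        # The generator wants D to judge its output as real (target_is_real=True).
        losses['loss_gan'] = gan_weight * gan_criterion(discriminator(output), True)
    return losses


losses = g_step_sketch(0.5, 1.0, pixel_loss=lambda o, g: abs(o - g))
print(losses)  # {'loss_pix': 0.5}
```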
- d_step_real(batch_outputs, batch_gt_data: torch.Tensor)
Real part of the discriminator step.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Real part of gan_loss for the discriminator.
- Return type
Tensor
- d_step_fake(batch_outputs: torch.Tensor, batch_gt_data)
Fake part of the discriminator step.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Fake part of gan_loss for the discriminator.
- Return type
Tensor
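The GAN loss is configurable. The sketch below assumes the vanilla formulation, binary cross-entropy on raw discriminator logits, to show what the real and fake parts measure:

- The real part pushes D to label GT images as real.
- The fake part pushes D to label generator outputs as fake.

The function names are hypothetical, and single float logits stand in for tensors.

```python
import math


def bce_with_logits(logit: float, target_is_real: bool) -> float:
    """Binary cross-entropy on one raw logit, in the numerically stable form."""
    # -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] == max(x, 0) - x*y + log(1 + exp(-|x|))
    y = 1.0 if target_is_real else 0.0
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))


def d_step_real_sketch(real_logit: float) -> float:
    # Real part: D's logit on a GT image should be labeled "real".
    return bce_with_logits(real_logit, True)


def d_step_fake_sketch(fake_logit: float) -> float:
    # Fake part: D's logit on a generator output should be labeled "fake".
    return bce_with_logits(fake_logit, False)


print(round(d_step_real_sketch(0.0), 4))  # 0.6931, i.e. ln 2 for an undecided D
```

A confident positive logit gives a small real-part loss but a large fake-part loss, which is why the two parts are computed on different inputs.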
- g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
Generator step with optimization: calculates the generator losses and runs the optimizer update.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
- d_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
Discriminator step with optimization: calculates the discriminator losses and runs the optimizer update.
- Parameters
batch_outputs (Tensor) – Batch output of the generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
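Both optimizer steps return "parsed" losses. mmengine's `BaseModel.parse_losses` reduces a loss dict by summing every entry whose key contains `'loss'`, and it puts that total first in the log dict. The real method averages tensors and returns a `(loss, log_vars)` pair. The plain-Python sketch below imitates only the summing rule on floats, and its function name is hypothetical.

```python
from collections import OrderedDict
from typing import Dict, Tuple


def parse_losses_sketch(losses: Dict[str, float]) -> Tuple[float, "OrderedDict[str, float]"]:
    """Imitates the summing rule: total = sum of entries whose key contains 'loss'."""
    log_vars = OrderedDict(losses)
    total = sum(value for key, value in log_vars.items() if 'loss' in key)
    log_vars['loss'] = total
    log_vars.move_to_end('loss', last=False)  # put the total first, as in the log output
    return total, log_vars


total, log_vars = parse_losses_sketch({'loss_pix': 0.5, 'loss_gan': 0.25, 'acc': 0.9})
print(total)  # 0.75
print(list(log_vars))  # ['loss', 'loss_pix', 'loss_gan', 'acc']
```

With real tensors, the total would then typically be passed to the relevant sub-wrapper's `update_params`, for example `optim_wrapper['generator'].update_params(total)`. The sub-wrapper key names and the exact call sequence inside SRGAN are assumptions here.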
- extract_gt_data(data_samples)
Extracts ground-truth (GT) data from data samples.
- Parameters
data_samples (list) – List of DataSample.
- Returns
Extracted GT data.
- Return type
Tensor
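extract_gt_data turns per-sample ground truth into the batch GT data consumed by the step methods. The sketch below assumes that each sample exposes its GT image as a `gt_img` attribute. Both that attribute name and the `FakeDataSample` class are hypothetical stand-ins. Nested lists replace tensors; with PyTorch, the final batching step would be a `torch.stack`.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple

Image = List[List[List[float]]]  # nested lists standing in for a (c, h, w) tensor


@dataclass
class FakeDataSample:
    """Hypothetical stand-in for a data sample that carries one GT image."""
    gt_img: Image


def image_shape(img: Image) -> Tuple[int, int, int]:
    return (len(img), len(img[0]), len(img[0][0]))


def extract_gt_data_sketch(data_samples: Sequence[FakeDataSample]) -> List[Image]:
    """Collect per-sample GT images into one batch, checking that their shapes agree."""
    batch = [sample.gt_img for sample in data_samples]
    shapes = {image_shape(img) for img in batch}
    if len(shapes) != 1:
        raise ValueError(f"GT images must share one shape to be batched, got {shapes}")
    return batch  # conceptually an (n, c, h, w) batch


samples = [FakeDataSample([[[0.0, 1.0], [1.0, 0.0]]]) for _ in range(2)]
gt_batch = extract_gt_data_sketch(samples)
print(len(gt_batch), image_shape(gt_batch[0]))  # 2 (1, 2, 2)
```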
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train step of GAN-based method.
- Parameters
data (List[dict]) – Data sampled from dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
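The methods above fit together in one training iteration. The hypothetical sketch below shows that control flow on scalar stand-ins:

1. Run the generator forward.
2. If the generator step is scheduled, run the generator update.
3. If the discriminator step is scheduled, run the discriminator update.
4. Merge the returned values for logging.

The ordering (generator update before discriminator update) and all names here are assumptions, not a copy of SRGAN's implementation.

```python
from typing import Callable, Dict, Optional

StepFn = Callable[[float, float], Dict[str, float]]  # stand-in for a *_step_with_optim method


def train_step_sketch(lq: float, gt: float,
                      generator: Callable[[float], float],
                      g_step_with_optim: StepFn,
                      d_step_with_optim: Optional[StepFn],
                      run_g: bool, run_d: bool) -> Dict[str, float]:
    """Hypothetical control flow for one GAN training iteration."""
    log_vars: Dict[str, float] = {}
    output = generator(lq)                       # 1. generator forward pass
    if run_g:                                    # 2. generator update, if scheduled
        log_vars.update(g_step_with_optim(output, gt))
    if run_d and d_step_with_optim is not None:  # 3. discriminator update, if scheduled
        log_vars.update(d_step_with_optim(output, gt))
    return log_vars                              # 4. merged values for logging


logs = train_step_sketch(
    lq=1.0, gt=4.0,
    generator=lambda x: 4.0 * x,
    g_step_with_optim=lambda out, gt: {'loss_pix': abs(out - gt)},
    d_step_with_optim=lambda out, gt: {'loss_d_real': 0.5, 'loss_d_fake': 0.25},
    run_g=True, run_d=True)
print(logs)  # {'loss_pix': 0.0, 'loss_d_real': 0.5, 'loss_d_fake': 0.25}
```

In a real trainer, `run_g` and `run_d` would come from if_run_g() and if_run_d(), and a step counter would advance after each call.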