
mmagic.models.editors.ttsr.ttsr

Module Contents

Classes

TTSR

TTSR model for Reference-based Image Super-Resolution.

class mmagic.models.editors.ttsr.ttsr.TTSR(generator, extractor, transformer, pixel_loss, discriminator=None, perceptual_loss=None, transferal_perceptual_loss=None, gan_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)

Bases: mmagic.models.editors.srgan.SRGAN

TTSR model for Reference-based Image Super-Resolution.

Paper: Learning Texture Transformer Network for Image Super-Resolution.

Parameters
  • generator (dict) – Config for the generator.

  • extractor (dict) – Config for the extractor.

  • transformer (dict) – Config for the transformer.

  • pixel_loss (dict) – Config for the pixel loss.

  • discriminator (dict, optional) – Config for the discriminator. Default: None.

  • perceptual_loss (dict, optional) – Config for the perceptual loss. Default: None.

  • transferal_perceptual_loss (dict, optional) – Config for the transferal perceptual loss. Default: None.

  • gan_loss (dict, optional) – Config for the GAN loss. Default: None.

  • train_cfg (dict, optional) – Config for training. Default: None.

  • test_cfg (dict, optional) – Config for testing. Default: None.

  • init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.

  • data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor. Default: None.
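For orientation, the constructor arguments map onto a model config dict in the usual OpenMMLab style. The sketch below is a minimal example under assumptions: the registry names (`TTSRNet`, `LTE`, `SearchTransformer`, `L1Loss`) and the generator hyperparameters are assumed from typical mmagic TTSR configs and should be checked against the configs shipped with your version. The helper `missing_required` is hypothetical and only shows which arguments have no default.

```python
# Sketch of a TTSR model config. Registry names and hyperparameters are
# assumptions; verify them against the mmagic configs for your version.
model = dict(
    type='TTSR',
    generator=dict(
        type='TTSRNet',  # assumed generator registry name
        in_channels=3,
        out_channels=3,
        mid_channels=64,
        num_blocks=(16, 16, 8, 4)),
    extractor=dict(type='LTE'),  # assumed texture extractor registry name
    transformer=dict(type='SearchTransformer'),  # assumed transformer name
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    # The optional components (discriminator, perceptual_loss,
    # transferal_perceptual_loss, gan_loss) default to None and are omitted.
    train_cfg=dict(),
    test_cfg=dict(),
)

# Arguments of TTSR.__init__ that have no default value.
REQUIRED_KEYS = ('generator', 'extractor', 'transformer', 'pixel_loss')


def missing_required(cfg):
    """Hypothetical helper: list required constructor arguments absent from cfg."""
    return [key for key in REQUIRED_KEYS if cfg.get(key) is None]


print(missing_required(model))  # []
```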

forward_tensor(inputs, data_samples=None, training=False)

Forward pass on tensors. Returns the raw result of a simple forward pass.

Parameters
  • inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.

  • data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.

  • training (bool) – Whether the model is in training mode. Default: False.

Returns

Results of forward inference (training=False) or of forward training (training=True).

Return type

(Tensor | Tuple[List[Tensor]])
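The training flag switches between the two return forms listed above. The following is a minimal, framework-free sketch under stated assumptions: following the TTSR paper, it assumes the extractor turns the upsampled LR image, the down/up-sampled reference, and the original reference into texture features, that the transformer returns a soft-attention map and transferred textures, and that the generator fuses them with the LR input. `forward_tensor_sketch` and the `demo_*` stand-ins are hypothetical and do not reproduce mmagic's actual code.

```python
from typing import Any, Callable, Tuple, Union


def forward_tensor_sketch(
    lr: Any,
    lr_up: Any,
    ref_down_up: Any,
    ref: Any,
    extractor: Callable[[Any], Any],
    transformer: Callable[[Any, Any, Any], Tuple[Any, Any]],
    generator: Callable[[Any, Any, Any], Any],
    training: bool = False,
) -> Union[Any, Tuple[Any, Any, Any]]:
    """Hypothetical TTSR-style forward pass (not mmagic's implementation)."""
    # Texture features for query (upsampled LR), key (down/up-sampled
    # reference) and value (original reference).
    query = extractor(lr_up)
    key = extractor(ref_down_up)
    value = extractor(ref)
    # The transformer returns a soft-attention map and transferred textures.
    soft_attention, textures = transformer(query, key, value)
    # The generator fuses the LR input with the transferred textures.
    pred = generator(lr, soft_attention, textures)
    if training:
        # Training also returns intermediates, e.g. the transferred
        # textures needed by the transferal perceptual loss.
        return pred, soft_attention, textures
    return pred


def demo_extractor(x):
    return f'feat({x})'


def demo_transformer(query, key, value):
    # Pretend the attention is a constant and the textures are the values.
    return 'S', f'T[{value}]'


def demo_generator(lr, soft_attention, textures):
    return f'SR({lr},{soft_attention},{textures})'


print(forward_tensor_sketch('lr', 'lr_up', 'ref_du', 'ref',
                            demo_extractor, demo_transformer, demo_generator))
# SR(lr,S,T[feat(ref)])
```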

if_run_g()

Determine whether the generator step needs to run.

if_run_d()

Determine whether the discriminator step needs to run.
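Predicates like these typically gate the alternating generator and discriminator updates of GAN training. A minimal sketch, assuming a step counter, a warm-up period (the hypothetical `disc_init_steps`) during which only the generator is trained, and that the discriminator step requires both `discriminator` and `gan_loss` to be configured. The class and attribute names are hypothetical, and mmagic's actual conditions may differ.

```python
class StepScheduleSketch:
    """Hypothetical scheduling logic for GAN-style training steps."""

    def __init__(self, discriminator=None, gan_loss=None, disc_init_steps=0):
        self.discriminator = discriminator
        self.gan_loss = gan_loss
        self.disc_init_steps = disc_init_steps  # generator-only warm-up
        self.step_counter = 0

    def if_run_g(self) -> bool:
        # In this sketch the generator is updated at every step.
        return True

    def if_run_d(self) -> bool:
        # The discriminator needs both a network and an adversarial loss,
        # and is only trained once the warm-up period is over.
        has_gan = self.discriminator is not None and self.gan_loss is not None
        return has_gan and self.step_counter >= self.disc_init_steps
```

With `discriminator` and `gan_loss` left at None, as in a reconstruction-only setup, `if_run_d()` never returns True.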

g_step(batch_outputs, batch_gt_data: mmagic.structures.DataSample)

Generator step of the GAN: calculate the generator losses.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

Returns

Dict of losses.

Return type

dict
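The generator losses are naturally collected into a dict with one entry per configured term. A minimal sketch, assuming each loss is a plain callable and that unset (None) losses are skipped. In the TTSR paper the transferal perceptual loss compares features of the output with the transferred textures; here that feature extraction is folded into the callable. The function name, its arguments, and the dict keys are hypothetical.

```python
from typing import Callable, Dict, Optional


def g_step_sketch(
    pred: float,
    gt: float,
    textures: float,
    pixel_loss: Callable[[float, float], float],
    perceptual_loss: Optional[Callable[[float, float], float]] = None,
    transferal_perceptual_loss: Optional[Callable[[float, float], float]] = None,
    gan_loss: Optional[Callable[[float], float]] = None,
    discriminator: Optional[Callable[[float], float]] = None,
) -> Dict[str, float]:
    """Hypothetical generator-loss computation; keys are illustrative."""
    losses = {'loss_pix': pixel_loss(pred, gt)}
    if perceptual_loss is not None:
        losses['loss_perceptual'] = perceptual_loss(pred, gt)
    if transferal_perceptual_loss is not None:
        # Compare the output with the transferred reference textures.
        losses['loss_transferal'] = transferal_perceptual_loss(pred, textures)
    if gan_loss is not None and discriminator is not None:
        # Adversarial term computed from the discriminator's score of the output.
        losses['loss_gan'] = gan_loss(discriminator(pred))
    return losses


def l1(a, b):
    return abs(a - b)


print(g_step_sketch(3.0, 1.0, 2.0, pixel_loss=l1))  # {'loss_pix': 2.0}
```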

g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)

Generator step of the GAN with optimization: calculate the generator losses and update the generator parameters.

Parameters
  • batch_outputs (Tensor) – Batch output of generator.

  • batch_gt_data (Tensor) – Batch GT data.

  • optim_wrapper (OptimWrapperDict) – Dict of optimizer wrappers.

Returns

Dict of parsed losses.

Return type

dict
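Running the optimizer on top of the loss dict amounts to reducing the dict to one scalar and letting the generator's wrapper perform backward, step, and zero-grad. A minimal sketch, assuming the wrapper exposes an `update_params(loss)` method (as mmengine's `OptimWrapper` does) and that, as in mmengine's `BaseModel.parse_losses`, every entry whose key contains `'loss'` is summed into the total. `FakeOptimWrapper`, `parse_losses_sketch`, and `g_step_with_optim_sketch` are hypothetical stand-ins, and a plain dict replaces `OptimWrapperDict`.

```python
from typing import Callable, Dict, Tuple


class FakeOptimWrapper:
    """Hypothetical stand-in for an optimizer wrapper; records each update."""

    def __init__(self):
        self.updates = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would run backward, step and zero_grad here.
        self.updates.append(loss)


def parse_losses_sketch(losses: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    """Sum every entry whose key contains 'loss' and log it under 'loss'."""
    total = sum(value for key, value in losses.items() if 'loss' in key)
    log_vars = {'loss': total, **losses}
    return total, log_vars


def g_step_with_optim_sketch(
    compute_losses: Callable[[], Dict[str, float]],
    optim_wrapper: Dict[str, FakeOptimWrapper],
) -> Dict[str, float]:
    """Hypothetical generator step: compute losses, reduce them, update G."""
    g_optim = optim_wrapper['generator']
    total, log_vars = parse_losses_sketch(compute_losses())
    g_optim.update_params(total)
    return log_vars


wrappers = {'generator': FakeOptimWrapper()}
log_vars = g_step_with_optim_sketch(lambda: {'loss_pix': 2.0, 'loss_gan': 0.5}, wrappers)
print(log_vars['loss'], wrappers['generator'].updates)  # 2.5 [2.5]
```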

train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) -> Dict[str, torch.Tensor]

Training step of a GAN-based method.

Parameters
  • data (List[dict]) – Data sampled from dataloader.

  • optim_wrapper (OptimWrapperDict) – Dict of OptimWrapper instances used to update the model parameters.

Returns

A dict of tensors for logging.

Return type

Dict[str, torch.Tensor]
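Putting the pieces together, one training step runs the forward pass, conditionally updates the generator and the discriminator, and returns the merged logs. A minimal orchestration sketch under assumptions: the batch is already preprocessed, the callables passed in stand for the forward, generator-step and discriminator-step logic, and the step counter advances once per call. All names are hypothetical rather than mmagic's actual `train_step`; a real implementation would also typically detach the generator output before the discriminator step.

```python
from typing import Callable, Dict


class TrainStepSketch:
    """Hypothetical orchestration of one GAN-style training step."""

    def __init__(
        self,
        forward: Callable[[dict], object],
        g_step: Callable[[object, dict], Dict[str, float]],
        d_step: Callable[[object, dict], Dict[str, float]],
        if_run_g: Callable[[int], bool],
        if_run_d: Callable[[int], bool],
    ):
        self.forward = forward
        self.g_step = g_step
        self.d_step = d_step
        self.if_run_g = if_run_g
        self.if_run_d = if_run_d
        self.step_counter = 0

    def train_step(self, data: dict) -> Dict[str, float]:
        log_vars: Dict[str, float] = {}
        outputs = self.forward(data)
        if self.if_run_g(self.step_counter):
            log_vars.update(self.g_step(outputs, data))
        if self.if_run_d(self.step_counter):
            log_vars.update(self.d_step(outputs, data))
        self.step_counter += 1
        return log_vars


step = TrainStepSketch(
    forward=lambda data: data['x'] * 2,
    g_step=lambda out, data: {'loss_g': out - data['x']},
    d_step=lambda out, data: {'loss_d': 0.5},
    if_run_g=lambda i: True,
    if_run_d=lambda i: i >= 1,  # skip the discriminator on the first step
)
print(step.train_step({'x': 3}))  # {'loss_g': 3}
print(step.train_step({'x': 3}))  # {'loss_g': 3, 'loss_d': 0.5}
```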
