mmagic.models.editors.ttsr
Package Contents
Classes
LTE | Learnable Texture Extractor.
SearchTransformer | Search texture reference by transformer.
TTSR | TTSR model for Reference-based Image Super-Resolution.
TTSRDiscriminator | A discriminator for TTSR.
TTSRNet | TTSR network structure (main-net) for reference-based super-resolution.
- class mmagic.models.editors.ttsr.LTE(requires_grad=True, pixel_range=1.0, load_pretrained_vgg=True, init_cfg=None)
Bases: mmengine.model.BaseModule
Learnable Texture Extractor.
Based on a pretrained VGG19. Generates features at 3 levels.
- Parameters
requires_grad (bool) – Whether the parameters require gradients. Default: True.
pixel_range (float) – Pixel range of feature. Default: 1.0.
load_pretrained_vgg (bool) – Whether to load the pretrained VGG from torchvision. Default: True. For training, the pretrained VGG must be loaded. For evaluation, it is not needed, because the pretrained LTE weights are loaded instead.
init_cfg (dict, optional) – Initialization config dict.
- forward(x)
Forward function.
- Parameters
x (Tensor) – Input tensor with shape (n, 3, h, w).
- Returns
Forward results in 3 levels.
x_level3: Forward results in level 3 (n, 256, h/4, w/4).
x_level2: Forward results in level 2 (n, 128, h/2, w/2).
x_level1: Forward results in level 1 (n, 64, h, w).
- Return type
Tuple[Tensor]
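The documented return shapes can be turned into a small shape calculator, which is handy for sanity-checking downstream modules. This is a minimal sketch in plain Python, not part of mmagic: the helper name `lte_output_shapes` is hypothetical, and it assumes `h` and `w` are multiples of 4 so that the divisions are exact.

```python
def lte_output_shapes(n, h, w):
    """Expected (x_level3, x_level2, x_level1) shapes for an (n, 3, h, w) input.

    Channel counts and scale factors are copied from the Returns section;
    h and w are assumed to be multiples of 4 (an assumption of this sketch).
    """
    if h % 4 or w % 4:
        raise ValueError("h and w are assumed to be multiples of 4")
    return (
        (n, 256, h // 4, w // 4),  # x_level3
        (n, 128, h // 2, w // 2),  # x_level2
        (n, 64, h, w),             # x_level1
    )


shapes = lte_output_shapes(n=2, h=160, w=160)
print(shapes)
```

Comparing these tuples against the `.shape` of the tensors returned by `forward` is one way to confirm the level ordering.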
- class mmagic.models.editors.ttsr.SearchTransformer(init_cfg: Union[dict, List[dict], None] = None)
Bases: mmengine.model.BaseModule
Search texture reference by transformer.
Includes relevance embedding, hard attention and soft attention.
- gather(inputs, dim, index)
Hard Attention. Gathers values along an axis specified by dim.
- Parameters
inputs (Tensor) – The source tensor. (N, C*k*k, H*W)
dim (int) – The axis along which to index.
index (Tensor) – The indices of elements to gather. (N, H*W)
- Returns
outputs (Tensor): The result tensor. (N, C*k*k, H*W)
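The indexing rule behind hard attention can be illustrated without tensors. The sketch below uses nested Python lists and a hypothetical helper `gather_last_dim`; it assumes `dim` is the last axis and that each `(N, H*W)` index row is shared by all `C*k*k` channels, which is an inference from the documented shapes rather than a statement about the library's implementation.

```python
def gather_last_dim(inputs, index):
    """Pick, for every output position, the input column named by ``index``.

    inputs: nested lists shaped (N, C*k*k, H*W)
    index:  nested lists shaped (N, H*W), values in range(H*W)
    Returns nested lists shaped (N, C*k*k, H*W) where
    outputs[n][c][p] == inputs[n][c][index[n][p]].
    """
    return [
        [[row[j] for j in sample_index] for row in sample]
        for sample, sample_index in zip(inputs, index)
    ]


# One sample, two channels, three positions.
inputs = [[[10, 11, 12], [20, 21, 22]]]
index = [[2, 0, 0]]
outputs = gather_last_dim(inputs, index)
print(outputs)
```

Each output position copies a whole channel column from the position chosen by the index, so the output keeps the input's shape.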
- forward(img_lq, ref_lq, refs)
Texture transformer.
Q = LTE(img_lq)
K = LTE(ref_lq)
V = LTE(ref), from V_level_n to V_level_1
- Relevance embedding: embeds the relevance between the LQ and Ref images by estimating the similarity between Q and K.
- Hard attention: transfers features only from the most relevant position in V for each query.
- Soft attention: synthesizes features from the transferred GT texture features T and the LQ features F from the backbone.
- Parameters
All arguments are features produced by the extractor. These features contain 3 levels. When upscale_factor=4, the size ratio of these features is level3:level2:level1 = 1:2:4.
img_lq (Tensor) – Tensor of the 4x bicubic-upsampled LQ image. (N, C, H, W)
ref_lq (Tensor) – Tensor of ref_lq, obtained by applying 4x bicubic down-sampling followed by 4x bicubic up-sampling to ref. (N, C, H, W)
refs (Tuple[Tensor]) – Tuple of ref tensors. [(N, C, H, W), (N, C/2, 2H, 2W), …]
- Returns
A tuple containing:
soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)
textures (Tuple[Tensor]): Transferred GT textures. [(N, C, H, W), (N, C/2, 2H, 2W), …]
- Return type
tuple
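The relevance embedding step can be sketched on flattened patch vectors. This is an illustrative plain-Python version under stated assumptions, not the library's code: it assumes similarity is the inner product of L2-normalized patches, that the maximum score per query serves as the soft-attention value, and that its position serves as the hard-attention index. The helper names `normalize` and `relevance` are hypothetical.

```python
import math


def normalize(vec):
    """Scale a vector to unit L2 norm (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def relevance(q_patches, k_patches):
    """Return (soft_attention, hard_index), one entry per query patch.

    q_patches: list of flattened patches taken from Q
    k_patches: list of flattened patches taken from K
    """
    q_unit = [normalize(q) for q in q_patches]
    k_unit = [normalize(k) for k in k_patches]
    soft_attention, hard_index = [], []
    for q in q_unit:
        # Similarity between this query and every key.
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_unit]
        best = max(range(len(scores)), key=scores.__getitem__)
        hard_index.append(best)            # most relevant key position
        soft_attention.append(scores[best])  # confidence of that match
    return soft_attention, hard_index


soft, hard = relevance([[1, 0], [0, 2]], [[0, 1], [3, 0], [1, 1]])
print(soft, hard)
```

In this sketch the hard index would select which position of V to transfer, and the soft value would weight how strongly the transferred texture is fused with the LQ features.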
- class mmagic.models.editors.ttsr.TTSR(generator, extractor, transformer, pixel_loss, discriminator=None, perceptual_loss=None, transferal_perceptual_loss=None, gan_loss=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)
Bases: mmagic.models.editors.srgan.SRGAN
TTSR model for Reference-based Image Super-Resolution.
Paper: Learning Texture Transformer Network for Image Super-Resolution.
- Parameters
generator (dict) – Config for the generator.
extractor (dict) – Config for the extractor.
transformer (dict) – Config for the transformer.
pixel_loss (dict) – Config for the pixel loss.
discriminator (dict) – Config for the discriminator. Default: None.
perceptual_loss (dict) – Config for the perceptual loss. Default: None.
transferal_perceptual_loss (dict) – Config for the transferal perceptual loss. Default: None.
gan_loss (dict) – Config for the GAN loss. Default: None.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor. Default: None.
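Since every constructor argument is a config dict, a model definition might look like the fragment below. Treat it as a sketch: the `type` strings assume the classes on this page are registered under their own names, `L1Loss` is an assumed pixel-loss type, and the channel counts of 3 assume RGB images. None of these values are confirmed by this page.

```python
# Hypothetical config fragment; entries marked "assumed" are not taken
# from this page.
model = dict(
    type='TTSR',
    generator=dict(
        type='TTSRNet',
        in_channels=3,   # assumed: RGB input
        out_channels=3,  # assumed: RGB output
        mid_channels=64,
        texture_channels=64,
        num_blocks=(16, 16, 8, 4),
        res_scale=1.0),
    extractor=dict(
        type='LTE',
        requires_grad=True,
        pixel_range=1.0,
        load_pretrained_vgg=True),
    transformer=dict(type='SearchTransformer'),
    pixel_loss=dict(type='L1Loss'),  # assumed loss type
    # GAN-related and perceptual components left at their documented
    # default of None.
    discriminator=None,
    perceptual_loss=None,
    transferal_perceptual_loss=None,
    gan_loss=None,
)
```

The generator, extractor and LTE values shown are the defaults listed in the class signatures on this page.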
- forward_tensor(inputs, data_samples=None, training=False)
Forward tensor. Returns result of simple forward.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
training (bool) – Whether the model is in training mode. Default: False.
- Returns
Results of forward inference and forward train.
- Return type
(Tensor | Tuple[List[Tensor]])
- g_step(batch_outputs, batch_gt_data: mmagic.structures.DataSample)
G step of GAN: Calculate losses of generator.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
- Returns
Dict of losses.
- Return type
dict
- g_step_with_optim(batch_outputs: torch.Tensor, batch_gt_data: torch.Tensor, optim_wrapper: mmengine.optim.OptimWrapperDict)
G step with optim of GAN: Calculate losses of generator and run optim.
- Parameters
batch_outputs (Tensor) – Batch output of generator.
batch_gt_data (Tensor) – Batch GT data.
optim_wrapper (OptimWrapperDict) – Optim wrapper dict.
- Returns
Dict of parsed losses.
- Return type
dict
- train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train step of GAN-based method.
- Parameters
data (List[dict]) – Data sampled from dataloader.
optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
- class mmagic.models.editors.ttsr.TTSRDiscriminator(in_channels=3, in_size=160, init_cfg=None)
Bases: mmengine.model.BaseModule
A discriminator for TTSR.
- Parameters
in_channels (int) – Channel number of inputs. Default: 3.
in_size (int) – Size of input image. Default: 160.
init_cfg (dict, optional) – Initialization config dict.
- class mmagic.models.editors.ttsr.TTSRNet(in_channels, out_channels, mid_channels=64, texture_channels=64, num_blocks=(16, 16, 8, 4), res_scale=1.0, init_cfg=None)
Bases: mmengine.model.BaseModule
TTSR network structure (main-net) for reference-based super-resolution.
Paper: Learning Texture Transformer Network for Image Super-Resolution.
Adapted from ‘https://github.com/researchmm/TTSR’. Copyright permission at ‘https://github.com/researchmm/TTSR/issues/38’.
- Parameters
in_channels (int) – Number of channels in the input image.
out_channels (int) – Number of channels in the output image.
mid_channels (int) – Channel number of intermediate features. Default: 64.
texture_channels (int) – Number of texture channels. Default: 64.
num_blocks (tuple[int]) – Block numbers in the trunk network. Default: (16, 16, 8, 4).
res_scale (float) – Used to scale the residual in residual block. Default: 1.0.
init_cfg (dict, optional) – Initialization config dict.
- forward(x, soft_attention, textures)
Forward function.
- Parameters
x (Tensor) – Input tensor with shape (n, c, h, w).
soft_attention (Tensor) – Soft-Attention tensor with shape (n, 1, h, w).
textures (Tuple[Tensor]) – Transferred HR texture tensors. [(N, C, H, W), (N, C/2, 2H, 2W), …]
- Returns
Forward results.
- Return type
Tensor
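Because `forward` takes three inputs whose shapes must agree, a shape check before the call can catch mismatches early. The following is a minimal plain-Python sketch with hypothetical helpers `texture_shapes` and `check_forward_inputs`. It assumes the first texture level shares the spatial size of `x`, and that each further level halves the channels and doubles the height and width, as the `[(N, C, H, W), (N, C/2, 2H, 2W), …]` notation suggests.

```python
def texture_shapes(n, c, h, w, levels=3):
    """Shapes [(N, C, H, W), (N, C/2, 2H, 2W), ...] for ``levels`` entries."""
    shapes = []
    for level in range(levels):
        factor = 2 ** level
        if c % factor:
            raise ValueError("c must be divisible by 2 ** (levels - 1)")
        shapes.append((n, c // factor, h * factor, w * factor))
    return shapes


def check_forward_inputs(x_shape, soft_attention_shape, textures_shapes):
    """Return True if the three shapes agree with the documented layout."""
    n, _, h, w = x_shape
    if tuple(soft_attention_shape) != (n, 1, h, w):
        return False
    c = textures_shapes[0][1]  # channel count of the first texture level
    expected = texture_shapes(n, c, h, w, levels=len(textures_shapes))
    return [tuple(s) for s in textures_shapes] == expected


textures = texture_shapes(n=1, c=256, h=40, w=40)
ok = check_forward_inputs((1, 3, 40, 40), (1, 1, 40, 40), textures)
print(textures, ok)
```

With real tensors, the same check could be run on `tensor.shape` values before passing them to the network.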