# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmagic.registry import MODELS


@MODELS.register_module()
class SearchTransformer(BaseModule):
    """Search texture reference by transformer.

    Include relevance embedding, hard-attention and soft-attention.
    """
    def gather(self, inputs, dim, index):
        """Hard attention: gather values along an axis specified by dim.

        Args:
            inputs (Tensor): The source tensor. (N, C*k*k, H*W)
            dim (int): The axis along which to index.
            index (Tensor): The indices of elements to gather. (N, H*W)

        Returns:
            outputs (Tensor): The result tensor. (N, C*k*k, H*W)
        """
views = [inputs.size(0)
] + [1 if i != dim else -1 for i in range(1, inputs.ndim)]
expansion = [
-1 if i in (0, dim) else d for i, d in enumerate(inputs.size())
]
index = index.view(views).expand(expansion)
outputs = torch.gather(inputs, dim, index)
return outputs
    def forward(self, img_lq, ref_lq, refs):
        """Texture transformer.

        Q = LTE(img_lq)
        K = LTE(ref_lq)
        V = LTE(ref), from V_level_n to V_level_1

        Relevance embedding aims to embed the relevance between the LQ and
        Ref image by estimating the similarity between Q and K.
        Hard-Attention: Only transfer features from the most relevant
        position in V for each query.
        Soft-Attention: Synthesize features from the transferred GT texture
        features T and the LQ features F from the backbone.

        Args:
            All args are features extracted by an extractor (such as LTE).
            These features contain 3 levels. When upscale_factor=4, the size
            ratio of these features is level3:level2:level1 = 1:2:4.

            img_lq (Tensor): Tensor of the 4x bicubic-upsampled LQ image.
                (N, C, H, W)
            ref_lq (Tensor): Tensor of ref_lq. ref_lq is obtained by applying
                bicubic down-sampling and up-sampling with factor 4x on ref.
                (N, C, H, W)
            refs (Tuple[Tensor]): Tuple of ref tensors.
                [(N, C, H, W), (N, C/2, 2H, 2W), ...]

        Returns:
            tuple: tuple contains:
                soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)
                textures (Tuple[Tensor]): Transferred GT textures.
                    [(N, C, H, W), (N, C/2, 2H, 2W), ...]
        """
levels = len(refs)
# query
query = F.unfold(img_lq, kernel_size=(3, 3), padding=1)
# key
key = F.unfold(ref_lq, kernel_size=(3, 3), padding=1)
key_t = key.permute(0, 2, 1)
# values
values = [
F.unfold(
refs[i],
kernel_size=3 * pow(2, i),
padding=pow(2, i),
stride=pow(2, i)) for i in range(levels)
]
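        # Each level is unfolded with kernel/stride scaled by 2**i, so every
        # level yields the same number of patches (H*W) and stays aligned
        # with the query positions; only the patch size differs per level.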
key_t = F.normalize(key_t, dim=2) # [N, H*W, C*k*k]
query = F.normalize(query, dim=1) # [N, C*k*k, H*W]
# Relevance embedding
rel_embedding = torch.bmm(key_t, query) # [N, H*W, H*W]
max_val, max_index = torch.max(rel_embedding, dim=1) # [N, H*W]
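        # Since key_t and query are L2-normalized, the batched matmul yields
        # cosine similarities; the max over the key axis gives, for each LQ
        # position, the most relevant ref position (the hard-attention
        # index) and its similarity score (used later as soft-attention).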
# hard-attention
textures = [self.gather(value, 2, max_index) for value in values]
# to tensor
h, w = img_lq.size()[-2:]
textures = [
F.fold(
textures[i],
output_size=(h * pow(2, i), w * pow(2, i)),
kernel_size=3 * pow(2, i),
padding=pow(2, i),
stride=pow(2, i)) / 9. for i in range(levels)
]
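        # F.fold sums overlapping patches; with these kernel/stride settings
        # every interior pixel is covered by 9 patches at each level, so
        # dividing by 9 turns the sums into averages.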
soft_attention = max_val.view(max_val.size(0), 1, h, w)
return soft_attention, textures
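

# A minimal shape-check sketch (illustrative; the sizes below are arbitrary
# assumptions, not values fixed by SearchTransformer). It builds a 3-level
# feature pyramid matching the docstring layout and checks the shapes a
# forward pass should produce.
if __name__ == '__main__':
    n, c, h, w = 1, 4, 6, 6
    img_lq = torch.rand(n, c, h, w)
    ref_lq = torch.rand(n, c, h, w)
    # From level 3 to level 1: channels halve while spatial size doubles.
    refs = (
        torch.rand(n, c, h, w),
        torch.rand(n, c // 2, h * 2, w * 2),
        torch.rand(n, c // 4, h * 4, w * 4),
    )
    model = SearchTransformer()
    soft_attention, textures = model(img_lq, ref_lq, refs)
    assert soft_attention.shape == (n, 1, h, w)
    for texture, ref in zip(textures, refs):
        assert texture.shape == ref.shape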