mmagic.models.editors.ttsr.search_transformer

Module Contents

Classes

SearchTransformer | Search texture reference by transformer.
- class mmagic.models.editors.ttsr.search_transformer.SearchTransformer(init_cfg: Union[dict, List[dict], None] = None)
Bases: mmengine.model.BaseModule
Search texture reference by transformer.
Includes relevance embedding, hard-attention and soft-attention.
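As context for the methods below, the relevance embedding can be pictured as a normalized patch correlation between the LQ and reference features. A minimal sketch, assuming 3x3 patches and standard PyTorch unfolding; the function name and parameters here are illustrative, not the library's exact implementation:

```python
import torch
import torch.nn.functional as F

def relevance_embedding(q_feat, k_feat, patch_size=3):
    """Correlate LQ query patches (Q) with reference key patches (K).

    q_feat, k_feat: (N, C, H, W) feature maps.
    Returns the per-query max relevance (soft-attention score) and its
    argmax position (hard-attention index), each of shape (N, H*W).
    """
    # Unfold into overlapping k*k patches: (N, C*k*k, H*W).
    q = F.unfold(q_feat, kernel_size=patch_size, padding=patch_size // 2)
    k = F.unfold(k_feat, kernel_size=patch_size, padding=patch_size // 2)
    # L2-normalize each patch so the inner product is a cosine similarity.
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    # Relevance matrix between every reference and query position:
    # (N, H*W_ref, H*W_query).
    relevance = torch.bmm(k.transpose(1, 2), q)
    # For each query, keep the best-matching reference position.
    soft_score, hard_index = torch.max(relevance, dim=1)
    return soft_score, hard_index
```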
- gather(inputs, dim, index)
Hard attention: gathers values along the axis specified by dim.
- Parameters
inputs (Tensor) – The source tensor. (N, C*k*k, H*W)
dim (int) – The axis along which to index.
index (Tensor) – The indices of elements to gather. (N, H*W)
- Returns
outputs (Tensor): The result tensor. (N, C*k*k, H*W)
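A minimal sketch of such a gather built on torch.gather, assuming dim=2 (the flattened spatial axis); the per-position index is expanded across channels so every channel reads the same reference position. This is an illustration, not the library's exact code:

```python
import torch

def hard_attention_gather(inputs, index):
    """Gather the most relevant reference patch for each query position.

    inputs: (N, C*k*k, H*W) unfolded reference features.
    index:  (N, H*W) hard-attention indices (int64).
    Returns a tensor of shape (N, C*k*k, H*W).
    """
    # Expand index to (N, C*k*k, H*W) so torch.gather selects the same
    # spatial position for every channel of the unfolded patch.
    index = index.unsqueeze(1).expand(-1, inputs.size(1), -1)
    return torch.gather(inputs, dim=2, index=index)
```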
- forward(img_lq, ref_lq, refs)
Texture transformer.
Q = LTE(img_lq)
K = LTE(ref_lq)
V = LTE(ref), from V_level_n to V_level_1
- Relevance embedding: embeds the relevance between the LQ and
Ref images by estimating the similarity between Q and K.
- Hard-Attention: only transfers features from the most relevant position
in V for each query.
- Soft-Attention: synthesizes features from the transferred GT texture
features T and the LQ features F from the backbone.
- Parameters
All arguments are features produced by the extractor. These features contain 3 levels; when upscale_factor=4, their size ratio is level3:level2:level1 = 1:2:4.
img_lq (Tensor) – Features of the 4x bicubic-upsampled LQ image. (N, C, H, W)
ref_lq (Tensor) – Features of the degraded reference, obtained by applying 4x bicubic down-sampling then up-sampling on ref. (N, C, H, W)
refs (Tuple[Tensor]) – Tuple of reference feature tensors, from level 3 to level 1. [(N, C, H, W), (N, C/2, 2H, 2W), …]
- Returns
- Tuple containing:
soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)
textures (Tuple[Tensor]): Transferred GT textures. [(N, C, H, W), (N, C/2, 2H, 2W), …]
- Return type
tuple
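A hedged usage sketch of the forward pass. The shapes follow the docstring above, while the channel count C = 64 and the 24x24 level-3 grid are assumptions chosen for illustration; in practice the inputs come from the LTE feature extractor rather than torch.rand:

```python
import torch
from mmagic.models.editors.ttsr.search_transformer import SearchTransformer

transformer = SearchTransformer()

N, C, H, W = 2, 64, 24, 24  # illustrative sizes only
img_lq = torch.rand(N, C, H, W)  # features of the 4x bicubic-upsampled LQ image
ref_lq = torch.rand(N, C, H, W)  # features of the degraded reference
refs = (
    torch.rand(N, C, H, W),               # level 3 (1x)
    torch.rand(N, C // 2, 2 * H, 2 * W),  # level 2 (2x)
    torch.rand(N, C // 4, 4 * H, 4 * W),  # level 1 (4x)
)

soft_attention, textures = transformer(img_lq, ref_lq, refs)
# soft_attention: (N, 1, H, W)
# textures: transferred GT textures, (N, C, H, W), (N, C/2, 2H, 2W), ...
```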