# Source code for mmagic.models.editors.ttsr.lte
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from torchvision import models
from mmagic.models.archs import ImgNormalize
from mmagic.registry import MODELS
@MODELS.register_module()
class LTE(BaseModule):
    """Learnable Texture Extractor.

    Based on pretrained VGG19. Generate features in 3 levels.

    The first 12 layers of the torchvision VGG19 ``features`` stack are split
    into three consecutive slices (layers 0-1, 2-6 and 7-11). Each layer is
    registered under its original VGG index, so parameter names look like
    ``slice2.5.weight``.

    Args:
        requires_grad (bool): Require grad or not. When False, the parameters
            of all three slices are frozen. Default: True.
        pixel_range (float): Pixel range of feature. Default: 1.
        load_pretrained_vgg (bool): Load pretrained VGG from torchvision.
            Default: True.
            Train: must load pretrained VGG.
            Eval: needn't load pretrained VGG, because we will load pretrained
            LTE.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 requires_grad=True,
                 pixel_range=1.,
                 load_pretrained_vgg=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # The std is scaled by ``pixel_range``; the mean is passed unscaled
        # because ``ImgNormalize`` also receives ``pixel_range`` itself.
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * pixel_range, 0.224 * pixel_range,
                   0.225 * pixel_range)
        self.img_normalize = ImgNormalize(
            pixel_range=pixel_range, img_mean=vgg_mean, img_std=vgg_std)

        # use vgg19 weights to initialize
        vgg_pretrained_features = self._build_vgg19_features(
            load_pretrained_vgg)

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        # (slice, first VGG layer index, one-past-last VGG layer index).
        # Layers keep their VGG index as module name so that checkpoint keys
        # stay stable.
        slice_ranges = (
            (self.slice1, 0, 2),
            (self.slice2, 2, 7),
            (self.slice3, 7, 12),
        )
        for vgg_slice, start, stop in slice_ranges:
            for idx in range(start, stop):
                vgg_slice.add_module(str(idx), vgg_pretrained_features[idx])

        if not requires_grad:
            for vgg_slice in (self.slice1, self.slice2, self.slice3):
                for param in vgg_slice.parameters():
                    param.requires_grad = False

    @staticmethod
    def _build_vgg19_features(load_pretrained_vgg):
        """Return the ``features`` module of a torchvision VGG19.

        Uses the ``weights=`` argument when the installed torchvision exposes
        ``VGG19_Weights`` and falls back to the legacy ``pretrained=`` keyword
        otherwise.

        Args:
            load_pretrained_vgg (bool): Whether to load the pretrained
                weights shipped by torchvision.

        Returns:
            The ``features`` attribute of the constructed VGG19 model.
        """
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is None:
            # Older torchvision: only the boolean ``pretrained`` flag exists.
            return models.vgg19(pretrained=load_pretrained_vgg).features
        # ``IMAGENET1K_V1`` is what ``pretrained=True`` maps to.
        weights = weights_enum.IMAGENET1K_V1 if load_pretrained_vgg else None
        return models.vgg19(weights=weights).features

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, 3, h, w).

        Returns:
            list[Tensor]: Forward results in 3 levels, deepest level first.
            x_level3: Forward results in level 3 (n, 256, h/4, w/4).
            x_level2: Forward results in level 2 (n, 128, h/2, w/2).
            x_level1: Forward results in level 1 (n, 64, h, w).
        """
        x = self.img_normalize(x)
        # Each slice consumes the output of the previous one.
        x_level1 = self.slice1(x)
        x_level2 = self.slice2(x_level1)
        x_level3 = self.slice3(x_level2)
        return [x_level3, x_level2, x_level1]