Shortcuts

mmagic.models.editors.duf.duf 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModule


class DynamicUpsamplingFilter(BaseModule):
    """Dynamic upsampling filter used in DUF.

    Ref: https://github.com/yhjo09/VSR-DUF.

    It only supports input with 3 channels. And it applies the same filters
    to 3 channels.

    Args:
        filter_size (tuple): Filter size of generated filters.
            The shape is (kh, kw). Default: (5, 5).

    Raises:
        TypeError: If ``filter_size`` is not a tuple.
        ValueError: If ``filter_size`` does not contain exactly 2 elements.
    """

    def __init__(self, filter_size=(5, 5)):
        super().__init__()
        if not isinstance(filter_size, tuple):
            # Report the offending type (not the value) so the message
            # matches what it claims to describe.
            raise TypeError('The type of filter_size must be tuple, '
                            f'but got {type(filter_size)}.')
        if len(filter_size) != 2:
            raise ValueError('The length of filter size must be 2, '
                             f'but got {len(filter_size)}.')

        self.filter_size = filter_size
        # np.prod returns a NumPy scalar; convert it once to a plain int so
        # that both torch.eye and Tensor.view receive a native Python int.
        filter_prod = int(np.prod(filter_size))

        # Generate a local expansion filter, similar to im2col: filter ``p``
        # is one-hot at flattened kernel position ``p``.
        expansion_filter = torch.eye(filter_prod).view(
            filter_prod, 1, *filter_size)  # (kh*kw, 1, kh, kw)
        # Repeat for all the 3 channels -> (3*kh*kw, 1, kh, kw).
        # Stored as a plain attribute; forward() moves it to the input's
        # device/dtype on every call.
        self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1)

    def forward(self, x, filters):
        """Forward function for DynamicUpsamplingFilter.

        Args:
            x (Tensor): Input image with 3 channels. The shape is
                (n, 3, h, w).
            filters (Tensor): Generated dynamic filters. The shape is
                (n, filter_prod, upsampling_square, h, w).
                filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
                upsampling_square: similar to pixel shuffle,
                upsampling_square = upsampling * upsampling.
                e.g., for x 4 upsampling, upsampling_square = 4*4 = 16.

        Returns:
            Tensor: Filtered image with shape
            (n, 3*upsampling_square, h, w).
        """
        n, filter_prod, upsampling_square, h, w = filters.size()
        kh, kw = self.filter_size

        # Gather each pixel's kh*kw neighbourhood into the channel dimension,
        # one group per input channel.
        expanded_input = F.conv2d(
            x,
            self.expansion_filter.to(x),
            padding=(kh // 2, kw // 2),
            groups=3)  # (n, 3*filter_prod, h, w)
        expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(
            0, 3, 4, 1, 2)  # (n, h, w, 3, filter_prod)
        filters = filters.permute(
            0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)

        # Per-pixel matrix product applies the same dynamic filters to each
        # of the 3 channels.
        out = torch.matmul(expanded_input,
                           filters)  # (n, h, w, 3, upsampling_square)
        return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.