# Source code for mmagic.models.archs.upsample
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .sr_backbone import default_init_weights
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    A convolution first expands the channel dimension to
    ``out_channels * scale_factor ** 2``; ``F.pixel_shuffle`` then folds
    those extra channels into the spatial dimensions.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of Conv layer to expand channels.
    """

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
                 upsample_kernel: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel

        # Pixel shuffle consumes scale_factor**2 channels per output channel,
        # so the conv has to produce that many.
        expanded_channels = out_channels * scale_factor * scale_factor
        # With the default stride of 1 this padding keeps the spatial size
        # unchanged for odd kernel sizes.
        padding = (upsample_kernel - 1) // 2
        self.upsample_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            upsample_kernel,
            padding=padding)
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights for PixelShufflePack.

        Delegates to ``default_init_weights`` with ``1`` as its second
        positional argument.
        """
        # NOTE(review): the second argument is presumably a weight scale;
        # confirm against ``sr_backbone.default_init_weights``.
        default_init_weights(self, 1)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for PixelShufflePack.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Upsampled feature map, i.e. the convolution output
            rearranged by ``F.pixel_shuffle`` with ``self.scale_factor``.
        """
        expanded = self.upsample_conv(x)
        return F.pixel_shuffle(expanded, self.scale_factor)