Shortcuts

Source code for mmagic.models.editors.liif.liif

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch

from mmagic.models.base_models import BaseEditModel
from mmagic.registry import MODELS
from mmagic.structures import DataSample


@MODELS.register_module()
class LIIF(BaseEditModel):
    """LIIF model for single image super-resolution.

    Paper: Learning Continuous Image Representation with Local Implicit
    Image Function

    Args:
        generator (dict): Config for the generator.
        pixel_loss (dict): Config for the pixel loss.
        pretrained (str): Path for pretrained model. Default: None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
    """

    def forward_tensor(self, inputs, data_samples=None, **kwargs):
        """Forward tensor. Returns result of simple forward.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (DataSample): stacked data samples collated by
                :attr:`data_preprocessor`. Its ``metainfo`` must hold the
                per-sample ``'coord'`` and ``'cell'`` tensors as lists,
                which are stacked along a new batch dimension here.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                generator (e.g. ``test_mode``).

        Returns:
            Tensor: result of simple forward.
        """
        # Stack the per-sample query coordinates / cell sizes into batch
        # tensors and move them to the same device and dtype as ``inputs``.
        coord = torch.stack(data_samples.metainfo['coord']).to(inputs)
        cell = torch.stack(data_samples.metainfo['cell']).to(inputs)
        return self.generator(inputs, coord, cell, **kwargs)

    def forward_inference(self, inputs, data_samples=None, **kwargs):
        """Forward inference. Returns predictions of validation, testing, and
        simple inference.

        The generator output (one feature vector per queried coordinate) is
        reshaped back into an image. The output height and width are
        recovered from the number of queried coordinates, assuming the same
        scale factor on both spatial axes.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (DataSample): stacked data samples collated by
                :attr:`data_preprocessor`.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            DataSample: a stacked data sample whose ``pred_img`` holds the
            predicted images on CPU, shaped ``[bz, C, H, W]``.

        Raises:
            RuntimeError: If the number of queried coordinates cannot be
                laid out as an ``H x W`` grid at a uniform scale of the
                input size.
        """
        # feats: [bz, N, C], one C-dim prediction per queried coordinate.
        feats = self.forward_tensor(inputs, data_samples, test_mode=True)

        ih, iw = inputs.shape[-2:]
        # metainfo in a stacked data sample is a list; all samples were
        # stacked together, so they share the same number of coordinates.
        coord_count = data_samples.metainfo['coord'][0].shape[0]
        # N = (ih * s) * (iw * s)  =>  s = sqrt(N / (ih * iw))
        scale = math.sqrt(coord_count / (ih * iw))
        out_h, out_w = round(ih * scale), round(iw * scale)
        if out_h * out_w != coord_count:
            # ``view`` below would fail anyway; report the sizes involved.
            raise RuntimeError(
                f'Cannot reshape {coord_count} queried coordinates into an '
                f'image of size {out_h}x{out_w} inferred from input size '
                f'{ih}x{iw}; the target must be a uniform rescale of the '
                'input.')

        # Use the actual channel count instead of assuming RGB.
        shape = [len(data_samples), out_h, out_w, feats.size(-1)]
        # [bz, N, C] -> [bz, H, W, C] -> [bz, C, H, W]
        feats = feats.view(shape).permute(0, 3, 1, 2).contiguous()
        feats = self.data_preprocessor.destruct(feats, data_samples)

        return DataSample(pred_img=feats.cpu())
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.