mmagic.models.editors.basicvsr
Package Contents
Classes
BasicVSR | BasicVSR model for video super-resolution.
BasicVSRNet | BasicVSR network structure for video super-resolution.
- class mmagic.models.editors.basicvsr.BasicVSR(generator, pixel_loss, ensemble=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)
Bases:
mmagic.models.BaseEditModel
BasicVSR model for video super-resolution.
Note that this model is also used for IconVSR.
- Paper:
BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for pixel-wise loss.
ensemble (dict) – Config for ensemble. Default: None.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule.
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
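As a rough illustration of how these parameters fit together, the following is a minimal config sketch in the dict style the signature implies. Only the parameter names and the BasicVSRNet defaults come from this page; the `CharbonnierLoss` and `DataPreprocessor` type names and the loss settings are assumptions chosen for illustration and may differ from the configs shipped with the library.

```python
# Hypothetical model config sketch. Keys mirror the documented parameters of
# BasicVSR; the loss and data-preprocessor choices below are assumptions.
model = dict(
    type='BasicVSR',
    generator=dict(
        type='BasicVSRNet',
        mid_channels=64,         # documented default
        num_blocks=30,           # documented default
        spynet_pretrained=None,  # documented default; a SPyNet weight path may go here
    ),
    # Assumed pixel-wise loss; any pixel loss config could be substituted.
    pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
    ensemble=None,
    train_cfg=None,
    test_cfg=None,
    # Assumed preprocessor type, left with its own defaults.
    data_preprocessor=dict(type='DataPreprocessor'),
)
```

Nesting the generator config inside the model config keeps the network structure (BasicVSRNet) separate from the training wrapper (BasicVSR), which matches the two classes documented on this page.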
- check_if_mirror_extended(lrs)
Check whether the input is a mirror-extended sequence.
If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.
- Parameters
lrs (tensor) – Input LR images with shape (n, t, c, h, w)
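The mirror condition stated above can be sketched without any dependencies. This is not the library's implementation: plain nested lists stand in for tensors, the batch dimension n is omitted, and the helper name `is_mirror_extended` is hypothetical. The real method may apply additional constraints (for example on sequence length) that are not described here.

```python
def is_mirror_extended(frames):
    """Return True if frames[i] equals frames[t - 1 - i] for every i.

    `frames` is a sequence of t frames; each frame is any value that
    supports ==, e.g. a nested list of shape (c, h, w).
    """
    t = len(frames)
    # Comparing the first half against the mirrored second half is enough,
    # because the condition is symmetric in i and t - 1 - i.
    return all(frames[i] == frames[t - 1 - i] for i in range(t // 2))


# A 2-frame clip followed by its reversed copy is mirror-extended.
clip = [[[1, 2]], [[3, 4]]]
mirrored = clip + clip[::-1]
print(is_mirror_extended(mirrored))

# An ordinary sequence of distinct frames is not.
print(is_mirror_extended([[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]))
```

With tensors of shape (n, t, c, h, w), the same idea amounts to comparing the sequence with a copy flipped along the time dimension (dimension 1).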
- forward_train(inputs, data_samples=None, **kwargs)
Forward training. Returns a dict of training losses.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Dict of losses.
- Return type
dict
- forward_inference(inputs, data_samples=None, **kwargs)
Forward inference. Returns predictions for validation and testing.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
predictions.
- Return type
- class mmagic.models.editors.basicvsr.BasicVSRNet(mid_channels=64, num_blocks=30, spynet_pretrained=None)
Bases:
mmengine.model.BaseModule
BasicVSR network structure for video super-resolution.
Supports only x4 upsampling.
- Paper:
BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
- Parameters
mid_channels (int) – Channel number of the intermediate features. Default: 64.
num_blocks (int) – Number of residual blocks in each propagation branch. Default: 30.
spynet_pretrained (str) – Pre-trained model path of SPyNet. Default: None.
- check_if_mirror_extended(lrs)
Check whether the input is a mirror-extended sequence.
If mirror-extended, the i-th (i=0, …, t-1) frame is equal to the (t-1-i)-th frame.
- Parameters
lrs (tensor) – Input LR images with shape (n, t, c, h, w)
- compute_flow(lrs)
Compute optical flow using SPyNet for feature warping.
Note that if the input is a mirror-extended sequence, ‘flows_forward’ is not needed, since it is equal to ‘flows_backward.flip(1)’.
- Parameters
lrs (tensor) – Input LR images with shape (n, t, c, h, w)
- Returns
Optical flow. ‘flows_forward’ corresponds to the flows used for forward-time propagation (current to previous). ‘flows_backward’ corresponds to the flows used for backward-time propagation (current to next).
- Return type
tuple(Tensor)
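The note about mirror-extended inputs can be checked by looking only at which frame pairs each flow is computed from. The sketch below is an assumption-laden illustration, not the library code: frames are represented by labels, the helper `frame_pairs` is hypothetical, and the pair order (current frame first) is inferred from the "current to previous" and "current to next" wording above. Because a flow estimator is a deterministic function of its input pair, equal pairs imply equal flows.

```python
def frame_pairs(frames):
    """List the (current, other) frame pairs behind each flow direction."""
    t = len(frames)
    # Backward-time propagation: flow from frame i (current) to frame i + 1 (next).
    backward = [(frames[i], frames[i + 1]) for i in range(t - 1)]
    # Forward-time propagation: flow from frame i + 1 (current) to frame i (previous).
    forward = [(frames[i + 1], frames[i]) for i in range(t - 1)]
    return forward, backward


# Mirror-extended sequence: the forward pairs are the backward pairs
# in reverse temporal order, so the forward flows need not be recomputed.
mirror_forward, mirror_backward = frame_pairs(['a', 'b', 'c', 'c', 'b', 'a'])
print(mirror_forward == mirror_backward[::-1])

# Ordinary sequence: the shortcut does not hold.
plain_forward, plain_backward = frame_pairs(['a', 'b', 'c', 'd'])
print(plain_forward == plain_backward[::-1])
```

Each direction yields t - 1 pairs for a sequence of t frames, and reversing the list of backward pairs plays the role of flipping the flow tensor along its time dimension.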