mmagic.models.editors.edvr
Package Contents
Classes
EDVR | EDVR model for video super-resolution.
EDVRNet | EDVR network structure for video super-resolution.
- class mmagic.models.editors.edvr.EDVR(generator, pixel_loss, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None)
Bases: mmagic.models.BaseEditModel
EDVR model for video super-resolution.
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for pixel-wise loss.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor. Default: None.
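As a rough illustration of how these arguments fit together, the sketch below writes an EDVR model config as a plain Python dict in the MMEngine style. The registry names `'EDVR'`, `'EDVRNet'` and `'CharbonnierLoss'`, and the loss settings, are assumptions chosen for illustration; check them against the configs shipped with your mmagic version.

```python
# Hypothetical MMEngine-style config for the EDVR model wrapper.
# Registry names and loss settings are assumptions for illustration only.
model = dict(
    type='EDVR',
    generator=dict(
        type='EDVRNet',
        in_channels=3,               # RGB input frames
        out_channels=3,              # RGB output frame
        mid_channels=64,
        num_frames=5,
        deform_groups=8,
        num_blocks_extraction=5,
        num_blocks_reconstruction=10,
        center_frame_idx=2,          # middle of 5 frames, counting from 0
        with_tsa=True),
    # Assumed pixel loss; any pixel-wise loss config could be used here.
    pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
    train_cfg=None,
    test_cfg=None,
    data_preprocessor=None)

print(model['generator']['type'])
```

A dict like this would normally be handed to the library's model registry to build the model; that step is left out so the snippet runs without mmagic installed.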
- forward_train(inputs, data_samples=None)
Forward pass for training. Returns a dict of training losses.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Dict of losses.
- Return type
dict
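The training step can be pictured as: run the generator on the input frames, compare its output with the ground truth using the configured pixel loss, and return the result in a loss dict. The following is a minimal pure-Python sketch of that flow under stated assumptions: lists of floats stand in for tensors, the ground truth is passed in directly instead of being read from `data_samples`, `forward_train_sketch` and `charbonnier_loss` are hypothetical helpers (not mmagic functions), and a Charbonnier loss is assumed as the pixel loss. It does not reproduce any EDVR-specific training logic the real method may contain.

```python
import math
from typing import Callable, Dict, List

# Toy stand-ins (assumption): a frame is a flat list of pixel values,
# and a clip is a list of frames. The real method works on torch tensors.
Frame = List[float]
Clip = List[Frame]


def charbonnier_loss(pred: Frame, target: Frame, eps: float = 1e-3) -> float:
    """Mean Charbonnier loss sqrt((p - t)^2 + eps^2); eps is an arbitrary choice."""
    if len(pred) != len(target):
        raise ValueError('pred and target must have the same length')
    total = sum(math.sqrt((p - t) ** 2 + eps ** 2) for p, t in zip(pred, target))
    return total / len(pred)


def forward_train_sketch(generator: Callable[[Clip], Frame],
                         pixel_loss: Callable[[Frame, Frame], float],
                         inputs: Clip,
                         gt: Frame) -> Dict[str, float]:
    """Hypothetical training step: predict, score against ground truth, return losses."""
    output = generator(inputs)        # restored (center) frame
    loss = pixel_loss(output, gt)     # pixel-wise loss against the ground truth
    return dict(loss=loss)


# Usage: a toy "generator" that simply returns the center frame of 5.
losses = forward_train_sketch(lambda clip: clip[2], charbonnier_loss,
                              inputs=[[0.0, 0.5]] * 5, gt=[0.0, 1.0])
print(losses)
```

Returning a dict rather than a bare number mirrors the documented return type and leaves room for several named loss terms.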
- class mmagic.models.editors.edvr.EDVRNet(in_channels, out_channels, mid_channels=64, num_frames=5, deform_groups=8, num_blocks_extraction=5, num_blocks_reconstruction=10, center_frame_idx=2, with_tsa=True, init_cfg=None)
Bases: mmengine.model.BaseModule
EDVR network structure for video super-resolution.
Currently only an upsampling factor of x4 is supported. Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
- Parameters
in_channels (int) – Channel number of inputs.
out_channels (int) – Channel number of outputs.
mid_channels (int) – Channel number of intermediate features. Default: 64.
num_frames (int) – Number of input frames. Default: 5.
deform_groups (int) – Number of deformable groups. Default: 8.
num_blocks_extraction (int) – Number of blocks for feature extraction. Default: 5.
num_blocks_reconstruction (int) – Number of blocks for reconstruction. Default: 10.
center_frame_idx (int) – Index of the center frame, counting from 0. Default: 2.
with_tsa (bool) – Whether to use the TSA (Temporal and Spatial Attention) fusion module. Default: True.
init_cfg (dict, optional) – Initialization config dict. Default: None.
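Because `center_frame_idx` counts from 0, it must address one of the `num_frames` input frames; with the defaults it is `5 // 2 = 2`, the middle frame. The hypothetical `EDVRNetArgs` dataclass below mirrors the documented constructor arguments and defaults and checks that constraint. It is a sketch for illustration, not part of mmagic, and the real class may perform different (or no) validation.

```python
from dataclasses import dataclass


@dataclass
class EDVRNetArgs:
    """Hypothetical mirror of the documented EDVRNet arguments and defaults."""
    in_channels: int
    out_channels: int
    mid_channels: int = 64
    num_frames: int = 5
    deform_groups: int = 8
    num_blocks_extraction: int = 5
    num_blocks_reconstruction: int = 10
    center_frame_idx: int = 2
    with_tsa: bool = True

    def validate(self) -> None:
        # center_frame_idx counts from 0, so it must index one of the input frames.
        if not 0 <= self.center_frame_idx < self.num_frames:
            raise ValueError(
                f'center_frame_idx must be in [0, {self.num_frames}), '
                f'got {self.center_frame_idx}')


args = EDVRNetArgs(in_channels=3, out_channels=3)
args.validate()  # defaults pass: 2 is the middle of 5 frames
```

For a 7-frame clip, the middle frame would be `center_frame_idx=3`.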
- forward(x)
Forward function for EDVRNet.
- Parameters
x (Tensor) – Input tensor with shape (n, t, c, h, w).
- Returns
Super-resolved center frame with shape (n, c, h, w).
- Return type
Tensor
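The shape contract of `forward` can be summarized as a small function on shape tuples. The sketch below is hypothetical (`edvr_output_shape` is not a mmagic function) and assumes that the number of input frames `t` must equal `num_frames`, and that `h` and `w` must be multiples of 4 because the alignment stage works on a three-level feature pyramid; treat both checks as assumptions.

```python
from typing import Tuple

SCALE = 4  # the docs state that only x4 upsampling is supported


def edvr_output_shape(input_shape: Tuple[int, int, int, int, int],
                      out_channels: int,
                      num_frames: int = 5) -> Tuple[int, int, int, int]:
    """Hypothetical: map an (n, t, c, h, w) input shape to the SR output shape."""
    n, t, _c, h, w = input_shape
    # Assumption: the fusion stage is built for exactly num_frames frames.
    if t != num_frames:
        raise ValueError(f'expected {num_frames} frames, got {t}')
    # Assumption: the 3-level alignment pyramid halves h and w twice.
    if h % 4 != 0 or w % 4 != 0:
        raise ValueError(f'h and w should be multiples of 4, got {h} and {w}')
    # Only the center frame is restored, upsampled by SCALE in each spatial dim.
    return n, out_channels, h * SCALE, w * SCALE


print(edvr_output_shape((2, 5, 3, 64, 64), out_channels=3))  # (2, 3, 256, 256)
```

Note that the frame axis `t` disappears from the output: the network fuses all frames into a single restored center frame.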
- init_weights()
Initialize the weights of the model.