Shortcuts

Source code for base_tof

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import MultiStepLR, OptimWrapper
from mmengine.runner import IterBasedTrainLoop

from mmagic.datasets.transforms import LoadImageFromFile, PackInputs
from mmagic.engine.runner import MultiTestLoop, MultiValLoop
from mmagic.evaluation import MAE, PSNR, SSIM

[文档]_base_ = '../default_runtime.py'
# Training pipeline: decode the input frames (``img``) and the ground-truth
# frame(s) (``gt``) as RGB with the Pillow backend, then pack them for the
# model. It performs no augmentation.
train_pipeline = [
    dict(
        type=LoadImageFromFile,
        key='img',
        channel_order='rgb',
        imdecode_backend='pillow'),
    dict(
        type=LoadImageFromFile,
        key='gt',
        channel_order='rgb',
        imdecode_backend='pillow'),
    dict(type=PackInputs)
]

# Demo pipeline: identical to ``train_pipeline`` except that only ``img`` is
# loaded (no ``gt`` key).
demo_pipeline = [
    dict(
        type=LoadImageFromFile,
        key='img',
        channel_order='rgb',
        imdecode_backend='pillow'),
    dict(type=PackInputs)
]

# dataset settings
train_dataset_type = 'BasicFramesDataset'
val_dataset_type = 'BasicFramesDataset'
data_root = 'data/vimeo_triplet'
# Output directory, also used as the checkpoint ``out_dir``.
save_dir = './work_dirs'

# Each sample uses im1/im3 as the inputs (``img``) and im2 as the target
# (``gt``). An InfiniteSampler is used because training is iteration-based.
train_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    sampler=dict(type=InfiniteSampler, shuffle=True),
    dataset=dict(
        type=train_dataset_type,
        ann_file='tri_trainlist.txt',
        metainfo=dict(dataset_type='vimeo90k', task_name='vfi'),
        data_root=data_root,
        data_prefix=dict(img='sequences', gt='sequences'),
        pipeline=train_pipeline,
        # NOTE(review): presumably the number of directory levels per sample
        # under ``sequences`` (e.g. 00001/0001); verify in BasicFramesDataset.
        depth=2,
        load_frames_list=dict(img=['im1.png', 'im3.png'], gt=['im2.png'])))

# Validation mirrors training but reads the test split in a fixed order.
# ``train_pipeline`` only loads and packs frames (no augmentation), so it is
# reused here unchanged.
val_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=val_dataset_type,
        ann_file='tri_testlist.txt',
        metainfo=dict(dataset_type='vimeo90k', task_name='vfi'),
        data_root=data_root,
        data_prefix=dict(img='sequences', gt='sequences'),
        pipeline=train_pipeline,
        depth=2,
        load_frames_list=dict(img=['im1.png', 'im3.png'], gt=['im2.png'])))

# Testing uses the same dataloader settings (the same dict object) as
# validation.
test_dataloader = val_dataloader
# Metrics reported on the validation set; the test set reuses the same list.
val_evaluator = [
    dict(type=MAE),
    dict(type=PSNR),
    dict(type=SSIM),
]
test_evaluator = val_evaluator
# 5000 iters == 1 epoch
# NOTE(review): this "epoch" is a nominal unit that sets the validation and
# checkpoint intervals. It is not derived from the dataset size in this file.
epoch_length = 5000

# Iteration-based training for 1M iterations, validating once per epoch.
train_cfg = dict(
    type=IterBasedTrainLoop, max_iters=1_000_000, val_interval=epoch_length)
val_cfg = dict(type=MultiValLoop)
test_cfg = dict(type=MultiTestLoop)
# optimizer
optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type=OptimWrapper,
    optimizer=dict(
        type='Adam',
        lr=5e-5,
        betas=(0.9, 0.99),
        weight_decay=1e-4,
    ),
)

# learning policy: multiply the LR by 0.5 at each milestone, i.e. every 200k
# iterations (the milestones are counted in iterations because by_epoch=False).
param_scheduler = dict(
    type=MultiStepLR,
    by_epoch=False,
    gamma=0.5,
    milestones=[200000, 400000, 600000, 800000])
# Runtime hooks. A checkpoint (including optimizer state) is saved every
# ``epoch_length`` iterations to ``save_dir``, and logs are written every
# 100 iterations.
default_hooks = dict(
    checkpoint=dict(
        type=CheckpointHook,
        interval=epoch_length,
        save_optimizer=True,
        by_epoch=False,
        out_dir=save_dir,
    ),
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=100),
    param_scheduler=dict(type=ParamSchedulerHook),
    sampler_seed=dict(type=DistSamplerSeedHook),
)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.