# Source code for base_glean
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.model import MMSeparateDistributedDataParallel
from mmengine.optim import CosineAnnealingLR, OptimWrapper
from mmengine.runner import IterBasedTrainLoop
from mmagic.engine.runner import MultiTestLoop, MultiValLoop
from mmagic.evaluation import MAE, PSNR, SSIM
# DistributedDataParallel
# Wrap the model with mmengine's MMSeparateDistributedDataParallel for
# distributed training. find_unused_parameters=True tells DDP to tolerate
# parameters that receive no gradient in a given backward pass. Presumably
# this is needed because the generator and discriminator (see optim_wrapper)
# are updated in separate steps -- confirm against the model's train_step.
model_wrapper_cfg = dict(
    type=MMSeparateDistributedDataParallel,
    find_unused_parameters=True,
)
# optimizer
# One OptimWrapper per sub-network, built by MultiOptimWrapperConstructor.
# The 'generator' / 'discriminator' keys presumably match the model's
# submodule names -- confirm against the model definition.
# Both use Adam(lr=1e-4, betas=(0.9, 0.99)). The two optimizer dicts are
# written out separately on purpose, so that mutating one config entry
# cannot silently alter the other.
optim_wrapper = dict(
    constructor='MultiOptimWrapperConstructor',
    generator=dict(
        type=OptimWrapper,
        optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)),
    ),
    discriminator=dict(
        type=OptimWrapper,
        optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)),
    ),
)
# default hooks (checkpointing, timing, logging, param scheduling, sampler seeding)
# Default runtime hooks.
default_hooks = dict(
    # Checkpointing. by_epoch=False makes `interval` count iterations, so a
    # checkpoint (including optimizer state) is written every 5000 iterations.
    # The best checkpoint is also kept for each validation metric. `rule`
    # pairs position-by-position with `save_best`: MAE is better when lower,
    # while PSNR and SSIM are better when higher.
    checkpoint=dict(
        type=CheckpointHook,
        interval=5000,
        save_optimizer=True,
        by_epoch=False,
        # NOTE(review): `save_dir` is not defined in this listing. It must be
        # supplied elsewhere in the config (e.g. an inherited base config or a
        # line missing from this view), otherwise loading the config raises
        # NameError -- confirm.
        out_dir=save_dir,
        save_best=['MAE', 'PSNR', 'SSIM'],
        rule=['less', 'greater', 'greater'],
    ),
    # Per-iteration timing.
    timer=dict(type=IterTimerHook),
    # Logging with interval=100.
    logger=dict(type=LoggerHook, interval=100),
    # Steps any configured parameter schedulers.
    param_scheduler=dict(type=ParamSchedulerHook),
    # Seeds the distributed sampler.
    sampler_seed=dict(type=DistSamplerSeedHook),
)