Shortcuts

Source code for biggan_2xb25-500kiters_cifar10-32x32

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

from mmagic.engine import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors.biggan import (BigGAN, BigGANDiscriminator,
                                          BigGANGenerator)

# Inherit the CIFAR-10 (no augmentation) dataset settings and the default
# generation runtime; the assignments below override / extend them.
with read_base():
    from .._base_.datasets.cifar10_noaug import *
    from .._base_.gen_default_runtime import *

# define model
# Exponential-moving-average settings, handed to the model via `ema_config`
# below. NOTE(review): `start_iter=1000` presumably delays EMA updates until
# iteration 1000 — confirm against mmagic's ExponentialMovingAverage.
ema_config = dict(
    type='ExponentialMovingAverage',
    interval=1,
    momentum=0.0001,
    start_iter=1000)

# Class-conditional BigGAN; num_classes=10 matches the CIFAR-10 base dataset
# inherited above, and both networks operate at 32x32 resolution.
model = dict(
    type=BigGAN,
    num_classes=10,
    data_preprocessor=dict(type=DataPreprocessor, output_channel_order='BGR'),
    generator=dict(
        type=BigGANGenerator,
        output_scale=32,
        noise_size=128,
        num_classes=10,
        base_channels=64,
        with_shared_embedding=False,
        sn_eps=1e-8,
        sn_style='torch',
        split_noise=False,
        auto_sync_bn=False,
        init_cfg=dict(type='N02')),
    discriminator=dict(
        type=BigGANDiscriminator,
        input_scale=32,
        num_classes=10,
        base_channels=64,
        sn_eps=1e-8,
        sn_style='torch',
        with_spectral_norm=True,
        init_cfg=dict(type='N02')),
    # Alternating schedule: 4 discriminator steps per generator step.
    generator_steps=1,
    discriminator_steps=4,
    ema_config=ema_config)
# define dataset
# Only batch size and worker count are overridden here; every other dataloader
# field is merged in from the inherited CIFAR-10 base config.
# NOTE(review): "2xb25" in the config name presumably means 2 GPUs x 25
# samples per GPU — confirm against the launch setup.
train_dataloader = dict(batch_size=25, num_workers=8)
val_dataloader = dict(batch_size=25, num_workers=8)
test_dataloader = dict(batch_size=25, num_workers=8)
# VIS_HOOK
# Periodically visualizes generated samples from a fixed noise input.
# NOTE(review): `interval=5000` is presumably measured in training
# iterations — confirm against VisualizationHook.
custom_hooks = [
    dict(
        type=VisualizationHook,
        interval=5000,
        fixed_input=True,
        # vis ema and orig at the same time
        vis_kwargs_list=dict(
            type='Noise',
            name='fake_img',
            sample_model='ema/orig',
            target_keys=['ema', 'orig'])),
]
# One Adam optimizer per sub-network, with identical hyper-parameters.
# betas[0] = 0.0 disables Adam's first-moment (momentum) averaging.
optim_wrapper = dict(
    generator=dict(
        optimizer=dict(type='Adam', lr=0.0002, betas=(0.0, 0.999))),
    discriminator=dict(
        optimizer=dict(type='Adam', lr=0.0002, betas=(0.0, 0.999))))
# Total training length: 500k iterations ("500kiters" in the config name).
train_cfg = dict(max_iters=500000)
# Evaluation metrics: FID and Inception Score, each computed on 50k images
# sampled from the EMA generator with StyleGAN-style Inception weights.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='ema'),
    dict(
        type='IS',
        prefix='IS-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='ema')
]

# save multi best checkpoints
# Each `save_best` key is '<metric prefix>/<metric name>' and is paired with
# the `rule` at the same index: lower FID is better, higher IS is better.
default_hooks = dict(
    checkpoint=dict(
        save_best=['FID-Full-50k/fid', 'IS-50k/is'],
        rule=['less', 'greater']))
# Validation and testing evaluate the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Read the Docs v: latest
Versions
latest
stable
0.x
Downloads
pdf
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.