# Config: stylegan2_c2_8xb4_lsun_car_384x512
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler, InfiniteSampler
from torch.optim import Adam
from mmagic.datasets.transforms import (Flip, LoadImageFromFile, NumpyPad,
PackInputs)
from mmagic.engine import VisualizationHook
from mmagic.evaluation import (FrechetInceptionDistance, PerceptualPathLength,
PrecisionAndRecall)
from mmagic.models import BaseGAN
with read_base():
from .._base_.gen_default_runtime import * # noqa: F403,F405
from .._base_.models.base_styleganv2 import * # noqa: F403,F405
# Training schedule override: total number of training iterations.
train_cfg.update(max_iters=1800002)

# Model overrides on top of the StyleGAN2 base model: 512-px generator and
# discriminator, an EMA copy configuration, and regularization settings.
# NOTE(review): `ema_half_life`, `d_reg_interval` and `g_reg_interval` are not
# assigned in this file; presumably they come from the base configs pulled in
# via `read_base()` above -- confirm.
model.update(
    generator=dict(out_size=512),
    discriminator=dict(in_size=512),
    ema_config=dict(
        type=ExponentialMovingAverage,
        interval=1,
        # Momentum derived from a half-life expressed in thousands of images;
        # the 32 presumably is the total batch size (8 GPUs x 4) -- confirm.
        momentum=1.0 - 0.5 ** (32.0 / (ema_half_life * 1000.0)),
    ),
    loss_config=dict(
        # R1 weight (10 / 2) scaled by how often the R1 term is applied.
        r1_loss_weight=10.0 / 2.0 * d_reg_interval,
        r1_interval=d_reg_interval,
        norm_mode='HWC',
        g_reg_interval=g_reg_interval,
        # Generator regularization weight scaled by its interval as well.
        g_reg_weight=2.0 * g_reg_interval,
        pl_batch_shrink=2,
    ),
)

# Adam optimizers whose learning rate and beta2 are rescaled by the
# regularization ratios (`g_reg_ratio` / `d_reg_ratio`, defined outside
# this file).
optim_wrapper.update(
    generator=dict(
        optimizer=dict(
            type=Adam,
            lr=0.002 * g_reg_ratio,
            betas=(0, 0.99 ** g_reg_ratio),
        ),
    ),
    discriminator=dict(
        optimizer=dict(
            type=Adam,
            lr=0.002 * d_reg_ratio,
            betas=(0, 0.99 ** d_reg_ratio),
        ),
    ),
)
# DATA
# Training pipeline: load each image under the 'gt' key, pad it, randomly
# flip it horizontally, then pack it for the model.
train_pipeline = [
    dict(type=LoadImageFromFile, key='gt'),
    # Pad 64 on both sides of the first array axis. Assuming HWC arrays of
    # 384x512 (HxW) as the config name suggests, this yields 512x512, matching
    # generator `out_size=512` / discriminator `in_size=512`.
    # The key must be 'gt': it is the key LoadImageFromFile writes and Flip
    # reads; no transform in this pipeline produces an 'img' key.
    dict(
        type=NumpyPad,
        keys='gt',
        padding=((64, 64), (0, 0), (0, 0)),
    ),
    dict(type=Flip, keys=['gt'], direction='horizontal'),
    dict(type=PackInputs),
]
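# NOTE: adjust `batch_size` (per GPU) and `data_root` to your setup. `data_root`, `dataset_type` and `val_pipeline` are not assigned in this file and must be provided elsewhere (e.g. by a base config).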
# Training loader: shuffled, never-ending sampling over the training set.
# `batch_size` is per process/GPU in MMEngine.
train_dataloader = dict(
    batch_size=4,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type=InfiniteSampler, shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        pipeline=train_pipeline,
    ),
)

# Validation loader: deterministic (unshuffled) pass using `val_pipeline`.
# `dataset_type`, `data_root` and `val_pipeline` must be supplied by the user
# or a base config.
val_dataloader = dict(
    batch_size=4,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        pipeline=val_pipeline,
    ),
)

# Test loader: same settings as the validation loader.
test_dataloader = dict(
    batch_size=4,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        pipeline=val_pipeline,
    ),
)
# VIS_HOOK
# Periodically visualize generated samples (every 5000 steps, presumably
# iterations since training is iteration-based) from a fixed input.
custom_hooks = [
    dict(
        type=VisualizationHook,
        interval=5000,
        fixed_input=True,
        # NOTE(review): `type` receives the BaseGAN class object here; confirm
        # that VisualizationHook's `vis_kwargs_list` accepts a class rather
        # than a preset name string such as 'GAN'.
        vis_kwargs_list=dict(type=BaseGAN, name='fake_img'))
]
# METRICS
# Metric list shared by the val and test evaluators. Every metric uses 50k
# generated samples; each `prefix` becomes the leading part of the logged
# metric name (e.g. 'FID-50k/fid').
metrics = [
    # FID against 50k real images with StyleGAN-style Inception features,
    # sampling from the model selected by sample_model='ema'.
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-50k',
        fake_nums=50000,
        real_nums=50000,
        inception_style='StyleGAN',
        sample_model='ema',
    ),
    dict(
        type=PrecisionAndRecall,
        prefix='PR-50K',
        fake_nums=50000,
    ),
    dict(
        type=PerceptualPathLength,
        prefix='ppl-w',
        fake_nums=50000,
    ),
]
# NOTE: example config for saving multiple "best" checkpoints. The metric keys
# in this example ('FID-Full-50k/fid', 'IS-50k/is') are NOT produced by the
# `metrics` list in this file (its FID prefix is 'FID-50k' and it has no IS
# metric); replace them with keys this config actually logs before enabling.
# default_hooks.update(
#     checkpoint=dict(
#         save_best=['FID-Full-50k/fid', 'IS-50k/is'],
#         rule=['less', 'greater']))
# Track the best checkpoint by the FID metric logged under the 'FID-50k'
# prefix. No `rule` is given here, so the comparison direction presumably
# comes from the base checkpoint hook settings -- confirm.
default_hooks.update(checkpoint=dict(save_best='FID-50k/fid'))
# Validation and test both evaluate the same metric list.
val_evaluator.update(metrics=metrics)
test_evaluator.update(metrics=metrics)