Shortcuts

Source code for adm_ddim250_8xb32_imagenet_256x256

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
    from .._base_.datasets.imagenet_64 import *
    from .._base_.gen_default_runtime import *

from mmagic.engine.hooks.visualization_hook import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet,
                                                       MultiHeadAttentionBlock)
from mmagic.models.editors.guided_diffusion.adm import AblatedDiffusionModel

# Model config: an AblatedDiffusionModel wrapping a DenoisingUnet (1000
# classes, 256-pixel images) and an EditDDIMScheduler.
# NOTE(review): the dataset base imported at the top of this file is
# `imagenet_64`, while the UNet below is built with image_size=256 — confirm
# that the intended dataset config is being used.
model = dict(
    type=AblatedDiffusionModel,
    data_preprocessor=dict(type=DataPreprocessor),
    unet=dict(
        type=DenoisingUnet,
        image_size=256,
        in_channels=3,
        base_channels=256,
        resblocks_per_downsample=2,
        attention_res=(32, 16, 8),
        norm_cfg=dict(type='GN32', num_groups=32),
        dropout=0.1,
        num_classes=1000,
        use_fp16=False,
        resblock_updown=True,
        attention_cfg=dict(
            type=MultiHeadAttentionBlock,
            num_heads=4,
            num_head_channels=64,
            use_new_attention_order=False),
        use_scale_shift_norm=True),
    diffusion_scheduler=dict(
        type=EditDDIMScheduler,
        variance_type='learned_range',
        beta_schedule='linear'),
    rgb2bgr=True,
    use_fp16=False)
# Override the test dataloader settings in place: batch size 32, 8 workers.
# NOTE(review): `test_dataloader` is not defined in this file; presumably it
# comes from one of the starred base-config imports above — confirm.
test_dataloader.update(dict(batch_size=32, num_workers=8))
# Training loop settings: cap training at 100000 iterations.
train_cfg = dict(max_iters=100000)
# Evaluation metrics: a single FrechetInceptionDistance entry configured for
# 50000 generated samples, sampled with 250 inference steps.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full-50k',
        fake_nums=50000,
        inception_style='StyleGAN',
        sample_model='orig',
        sample_kwargs=dict(
            num_inference_steps=250,
            show_progress=True,
            classifier_scale=1.))
]
# Validation and testing share the same metric list.
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
# VIS_HOOK
# Register a VisualizationHook with interval=5000 and fixed_input enabled.
custom_hooks = [
    dict(type=VisualizationHook, interval=5000, fixed_input=True),
]