# Source code for eg3d_cvt-official-rgb_shapenet-128x128
# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmagic.datasets import BasicConditionalDataset
from mmagic.datasets.transforms import LoadImageFromFile, PackInputs
from mmagic.engine import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors.eg3d import EG3D, TriplaneGenerator, UniformCamera
# Import every name from the shared base runtime config. The star import
# must sit inside the `read_base()` context, so the body is indented.
with read_base():
    from .._base_.gen_default_runtime import *  # noqa: F401,F403
# EG3D model settings: a TriplaneGenerator (out_size=128, sr_in_size=64)
# together with a UniformCamera used to sample camera poses.
model = dict(
    type=EG3D,
    data_preprocessor=dict(type=DataPreprocessor),
    generator=dict(
        type=TriplaneGenerator,
        out_size=128,
        zero_cond_input=True,
        cond_scale=0,
        sr_in_size=64,
        renderer_cfg=dict(
            # The official implementation sets ray_start, ray_end and
            # box_warp to 0.1, 2.6 and 1.6 respectively, and FID is 7.2441.
            # The official values are kept below for reference:
            # ray_start=0.1,
            # ray_end=2.6,
            # box_warp=1.6,
            ray_start=0.4,
            ray_end=2.0,
            box_warp=1.7,
            depth_resolution=64,
            depth_resolution_importance=64,
            white_back=True,
        ),
        rgb2bgr=True),
    camera=dict(
        type=UniformCamera,
        # NOTE(review): 3.141 looks like a truncated pi; the literal is kept
        # unchanged so the sampled camera distribution stays the same.
        horizontal_mean=3.141,
        horizontal_std=3.141,
        vertical_mean=3.141 / 2,
        vertical_std=3.141 / 2,
        focal=1.025390625,
        up=[0, 0, 1],
        radius=1.2),
)
# Every training- and validation-related setting is explicitly set to None;
# this file only configures the test-time pipeline, dataset and dataloader.
train_cfg = None
train_dataloader = None
optim_wrapper = None

val_cfg = None
val_dataloader = None
val_evaluator = None
# Two FrechetInceptionDistance metrics over 50000 generated samples each.
# They share the same `inception_pkl` and differ only in their prefix and in
# whether `need_cond_input=True` is passed.
# NOTE(review): `inception_pkl` is not assigned anywhere in the visible part
# of this file -- confirm it is provided by the base config or define it
# before this list, otherwise this raises NameError.
# NOTE(review): `metrics` is not referenced again in the visible part of this
# file -- confirm it is attached to the test evaluator.
metrics = [
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Full',
        fake_nums=50000,
        inception_pkl=inception_pkl,
        need_cond_input=True,
        sample_model='orig'),
    dict(
        type=FrechetInceptionDistance,
        prefix='FID-Random-Camera',
        fake_nums=50000,
        inception_pkl=inception_pkl,
        sample_model='orig'),
]
# Test-time transforms: load the image stored under key 'img' as a color
# image, then pack it with PackInputs.
test_pipeline = [
    dict(type=LoadImageFromFile, key='img', color_type='color'),
    dict(type=PackInputs),
]
# Test dataset: a BasicConditionalDataset rooted at ./data/eg3d/shapenet-car
# with annotation file shapenet.json, processed by `test_pipeline`.
test_dataset = dict(
    type=BasicConditionalDataset,
    data_root='./data/eg3d/shapenet-car',
    ann_file='shapenet.json',
    pipeline=test_pipeline)
# Test dataloader: iterates `test_dataset` in order (no shuffling) and keeps
# the final partial batch (drop_last=False).
test_dataloader = dict(
    # NOTE: `batch_size=16` costs nearly **12GB** of GPU memory; adjust this
    # parameter to suit your own GPU.
    batch_size=16,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type=DefaultSampler, shuffle=False),
    num_workers=9,
    dataset=test_dataset)
# Custom hooks: a single VisualizationHook configured with interval=5000 and
# fixed_input=True, visualizing the 'fake_img' output with type 'GAN'.
custom_hooks = [
    dict(
        type=VisualizationHook,
        interval=5000,
        fixed_input=True,
        # save_at_test=False,
        vis_kwargs_list=dict(type='GAN', name='fake_img')),
]