# Source code for base_gl
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import MMSeparateDistributedDataParallel
from mmengine.optim import OptimWrapper
from mmagic.models import DataPreprocessor
from mmagic.models.editors import (GLDecoder, GLDilationNeck, GLEncoder,
GLEncoderDecoder)
from mmagic.models.editors.global_local import GLDiscs, GLInpaintor
from mmagic.models.losses import GANLoss, L1Loss
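# DistributedDataParallel
# NOTE(review): MMSeparateDistributedDataParallel is imported above but no
# `model_wrapper_cfg` using it is defined in this file as shown; confirm
# whether `model_wrapper_cfg = dict(type=MMSeparateDistributedDataParallel)`
# was intended here.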
# Global&Local (GL) inpainting model.
# The generator is an encoder -> dilation neck -> decoder stack.  It is
# trained with a GAN loss against a pair of discriminators (one global,
# one local) plus an L1 reconstruction loss.
model = dict(
    type=GLInpaintor,
    data_preprocessor=dict(
        type=DataPreprocessor,
        # Presumably normalises as (x - 127.5) / 127.5, i.e. 8-bit [0, 255]
        # input -> [-1, 1].  The single-element lists are assumed to be
        # broadcast over all channels -- TODO confirm.
        mean=[127.5],
        std=[127.5],
    ),
    encdec=dict(
        type=GLEncoderDecoder,
        # Every generator stage uses SyncBN (batch-norm statistics
        # synchronised across distributed processes).
        encoder=dict(type=GLEncoder, norm_cfg=dict(type='SyncBN')),
        decoder=dict(type=GLDecoder, norm_cfg=dict(type='SyncBN')),
        dilation_neck=dict(
            type=GLDilationNeck, norm_cfg=dict(type='SyncBN'))),
    disc=dict(
        type=GLDiscs,
        global_disc_cfg=dict(
            in_channels=3,
            max_channels=512,
            # Presumably max_channels (512) x a 4x4 final feature map.  This
            # must agree with the input size and num_convs -- TODO confirm.
            fc_in_channels=512 * 4 * 4,
            fc_out_channels=1024,
            num_convs=6,
            norm_cfg=dict(type='SyncBN'),
        ),
        local_disc_cfg=dict(
            in_channels=3,
            max_channels=512,
            fc_in_channels=512 * 4 * 4,
            fc_out_channels=1024,
            # One conv fewer than the global branch.  Presumably the local
            # discriminator sees a smaller crop -- confirm against the
            # data pipeline.
            num_convs=5,
            norm_cfg=dict(type='SyncBN'),
        ),
    ),
    # The adversarial term is weighted 1000x lower than the L1 term below.
    loss_gan=dict(
        type=GANLoss,
        gan_type='vanilla',
        loss_weight=0.001,
    ),
    # L1 reconstruction loss.  The name suggests it is applied to the hole
    # (masked) region -- verify in GLInpaintor.
    loss_l1_hole=dict(
        type=L1Loss,
        loss_weight=1.0,
    ))
# optimizer
# MultiOptimWrapperConstructor builds one OptimWrapper per key.  The
# `generator` and `disc` keys presumably name GLInpaintor submodules --
# verify.  Both sides use Adam with lr=4e-4.
optim_wrapper = dict(
    constructor='MultiOptimWrapperConstructor',
    generator=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0004)),
    disc=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0004)))
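# learning policy
# Fixed: no param_scheduler is defined in this file, so the learning rate
# presumably stays constant unless a derived config adds one -- confirm.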