# Source code for base_deepfillv1 (base DeepFill v1 inpainting config).
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import MMSeparateDistributedDataParallel
from mmengine.optim import OptimWrapper
from mmagic.models import DataPreprocessor
from mmagic.models.archs import MultiLayerDiscriminator
from mmagic.models.editors import (ContextualAttentionNeck, DeepFillDecoder,
DeepFillEncoder, DeepFillEncoderDecoder,
DeepFillRefiner, DeepFillv1Discriminators,
DeepFillv1Inpaintor, GLDilationNeck,
GLEncoderDecoder)
from mmagic.models.losses import (DiscShiftLoss, GANLoss, GradientPenaltyLoss,
L1Loss)
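# DistributedDataParallel: MMSeparateDistributedDataParallel (imported above)
# is not referenced by the model/optimizer settings below; presumably it is
# used for a model-wrapper setting elsewhere in this file -- confirm.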
# DeepFill v1 inpainting model. The ``type`` fields reference the classes
# imported at the top of this file directly rather than registry name strings.
model = dict(
    type=DeepFillv1Inpaintor,
    # mean/std of 127.5 presumably rescale 0-255 pixel values to [-1, 1];
    # confirm against DataPreprocessor.
    data_preprocessor=dict(
        type=DataPreprocessor,
        mean=[127.5],
        std=[127.5],
    ),
    # Two-stage generator: stage1 is a GLEncoderDecoder, stage2 a
    # DeepFillRefiner with attention/conv encoders, a dilation neck and a
    # contextual-attention neck. Every sub-module here is configured with
    # padding_mode='reflect'.
    encdec=dict(
        type=DeepFillEncoderDecoder,
        stage1=dict(
            type=GLEncoderDecoder,
            encoder=dict(type=DeepFillEncoder, padding_mode='reflect'),
            decoder=dict(
                type=DeepFillDecoder, in_channels=128, padding_mode='reflect'),
            dilation_neck=dict(
                type=GLDilationNeck,
                in_channels=128,
                act_cfg=dict(type='ELU'),
                padding_mode='reflect')),
        stage2=dict(
            type=DeepFillRefiner,
            encoder_attention=dict(
                type=DeepFillEncoder,
                encoder_type='stage2_attention',
                padding_mode='reflect'),
            encoder_conv=dict(
                type=DeepFillEncoder,
                encoder_type='stage2_conv',
                padding_mode='reflect'),
            dilation_neck=dict(
                type=GLDilationNeck,
                in_channels=128,
                act_cfg=dict(type='ELU'),
                padding_mode='reflect'),
            contextual_attention=dict(
                type=ContextualAttentionNeck,
                in_channels=128,
                padding_mode='reflect'),
            # in_channels=256 is presumably the concatenation of the two
            # 128-channel branches (dilation_neck + contextual_attention);
            # confirm in DeepFillRefiner.
            decoder=dict(
                type=DeepFillDecoder, in_channels=256,
                padding_mode='reflect'))),
    # Global and local MultiLayerDiscriminators. fc_in_channels must match
    # the flattened size of the last conv feature map (max_channels * H * W),
    # so it depends on the input resolution each discriminator sees
    # (presumably full image vs. a local crop -- confirm before changing).
    disc=dict(
        type=DeepFillv1Discriminators,
        global_disc_cfg=dict(
            type=MultiLayerDiscriminator,
            in_channels=3,
            max_channels=256,
            fc_in_channels=256 * 16 * 16,
            fc_out_channels=1,
            num_convs=4,
            norm_cfg=None,
            act_cfg=dict(type='ELU'),
            out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
        local_disc_cfg=dict(
            type=MultiLayerDiscriminator,
            in_channels=3,
            max_channels=512,
            fc_in_channels=512 * 8 * 8,
            fc_out_channels=1,
            num_convs=4,
            norm_cfg=None,
            act_cfg=dict(type='ELU'),
            out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2))),
    # Select, by key name, which loss_* entries below are used for each
    # stage: stage1 uses only the L1 terms, stage2 adds the GAN loss.
    stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
    stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
    # WGAN-type adversarial loss with a small weight; used together with
    # the gradient-penalty term (loss_gp) below.
    loss_gan=dict(
        type=GANLoss,
        gan_type='wgan',
        loss_weight=0.0001,
    ),
    loss_l1_hole=dict(
        type=L1Loss,
        loss_weight=1.0,
    ),
    loss_l1_valid=dict(
        type=L1Loss,
        loss_weight=1.0,
    ),
    loss_gp=dict(type=GradientPenaltyLoss, loss_weight=10.),
    loss_disc_shift=dict(type=DiscShiftLoss, loss_weight=0.001))
# optimizer
# Separate OptimWrapper/Adam (lr=1e-4) configs for 'generator' and 'disc',
# combined via MultiOptimWrapperConstructor. The keys presumably name the
# model sub-modules each optimizer updates -- confirm against
# DeepFillv1Inpaintor.
optim_wrapper = dict(
    constructor='MultiOptimWrapperConstructor',
    generator=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0001)),
    disc=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0001)))
# learning policy: fixed (constant learning rate)