Source code for mmagic.engine.optimizers.singan_optimizer_constructor

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapperDict

from mmagic.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class SinGANOptimWrapperConstructor:
    """OptimizerConstructor for SinGAN models.

    Set an optimizer for each submodule of SinGAN. All submodules must be
    contained in a :class:`torch.nn.ModuleList` named 'blocks', and each
    submodule is accessed as ``MODEL.blocks[SCALE]``, where ``MODEL`` is the
    generator or discriminator and ``SCALE`` is the index of the resolution
    scale.

    For more details about the resolution scales and the naming rule, please
    refer to :class:`~mmagic.models.editors.singan.SinGANMultiScaleGenerator`
    and :class:`~mmagic.models.editors.singan.SinGANMultiScaleDiscriminator`.

    Example:
        >>> # build SinGAN model
        >>> model = dict(
        >>>     type='SinGAN',
        >>>     data_preprocessor=dict(
        >>>         type='GANDataPreprocessor',
        >>>         non_image_keys=['input_sample']),
        >>>     generator=dict(
        >>>         type='SinGANMultiScaleGenerator',
        >>>         in_channels=3,
        >>>         out_channels=3,
        >>>         num_scales=2),
        >>>     discriminator=dict(
        >>>         type='SinGANMultiScaleDiscriminator',
        >>>         in_channels=3,
        >>>         num_scales=3))
        >>> singan = MODELS.build(model)
        >>> # build constructor
        >>> optim_wrapper = dict(
        >>>     generator=dict(optimizer=dict(type='Adam', lr=0.0005,
        >>>                                   betas=(0.5, 0.999))),
        >>>     discriminator=dict(
        >>>         optimizer=dict(type='Adam', lr=0.0005,
        >>>                        betas=(0.5, 0.999))))
        >>> optim_wrapper_dict_builder = SinGANOptimWrapperConstructor(
        >>>     optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_wrapper_dict_builder(singan)

    Args:
        optim_wrapper_cfg (dict): Config of the optimizer wrapper.
        paramwise_cfg (Optional[dict]): Parameter-wise options. Must be
            ``None`` here; set ``paramwise_cfg`` inside each optimizer's
            config instead.
    """

    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None):
        if not isinstance(optim_wrapper_cfg, dict):
            raise TypeError('optim_wrapper_cfg should be a dict, '
                            f'but got {type(optim_wrapper_cfg)}')
        assert paramwise_cfg is None, (
            'paramwise_cfg should be set in each optimizer separately')

        self.optim_cfg = optim_wrapper_cfg
        self.constructors = {}
        # Build one DefaultOptimWrapperConstructor per submodule key
        # (e.g. 'generator' and 'discriminator').
        for key, cfg in self.optim_cfg.items():
            cfg_ = cfg.copy()
            paramwise_cfg_ = cfg_.pop('paramwise_cfg', None)
            self.constructors[key] = DefaultOptimWrapperConstructor(
                cfg_, paramwise_cfg_)
    def __call__(self, module: nn.Module) -> OptimWrapperDict:
        """Build optimizers and return an :class:`OptimWrapperDict`."""
        optimizers = {}
        # Unwrap distributed wrappers (e.g. DistributedDataParallel).
        if hasattr(module, 'module'):
            module = module.module
        num_scales = module.num_scales
        for key, constructor in self.constructors.items():
            # Build one optimizer wrapper per resolution scale
            # (scales 0 to num_scales inclusive).
            for idx in range(num_scales + 1):
                submodule = module._modules[key]
                if hasattr(submodule, 'module'):
                    submodule = submodule.module
                optimizers[f'{key}_{idx}'] = constructor(
                    submodule.blocks[idx])
        optimizers = OptimWrapperDict(**optimizers)
        return optimizers
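
As a quick usage sketch (a hypothetical snippet, not part of the module above): it assumes mmagic is installed and that a SinGAN model ``singan`` has already been built as in the class docstring example, so the model's ``num_scales`` attribute determines how many scales receive an optimizer.

# A minimal usage sketch under the assumptions stated above.
from mmagic.engine.optimizers.singan_optimizer_constructor import \
    SinGANOptimWrapperConstructor

optim_wrapper_cfg = dict(
    generator=dict(
        optimizer=dict(type='Adam', lr=0.0005, betas=(0.5, 0.999))),
    discriminator=dict(
        optimizer=dict(type='Adam', lr=0.0005, betas=(0.5, 0.999))))

builder = SinGANOptimWrapperConstructor(optim_wrapper_cfg)
optim_wrapper_dict = builder(singan)  # `singan` built as in the docstring

# Keys follow the '<submodule>_<scale>' naming rule for scales
# 0..num_scales, e.g. 'generator_0', 'generator_1', ...,
# 'discriminator_0', 'discriminator_1', ...
print(sorted(optim_wrapper_dict.keys()))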