Source code for mmagic.engine.optimizers.multi_optimizer_constructor

# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Tuple, Union

import torch.nn as nn
from mmengine import print_log
from mmengine.optim import (DefaultOptimWrapperConstructor, OptimWrapper,
                            OptimWrapperDict)

from mmagic.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MultiOptimWrapperConstructor:
    """OptimizerConstructor for GAN models. This class constructs optimizers
    for the submodules of the model separately, and returns a
    :class:`mmengine.optim.OptimWrapperDict` or
    :class:`mmengine.optim.OptimWrapper`.

    Example 1: Build multi optimizers (e.g., GANs):
        >>> # build GAN model
        >>> model = dict(
        >>>     type='GANModel',
        >>>     num_classes=10,
        >>>     generator=dict(type='Generator'),
        >>>     discriminator=dict(type='Discriminator'))
        >>> gan_model = MODELS.build(model)
        >>> # build constructor
        >>> optim_wrapper = dict(
        >>>     generator=dict(
        >>>         type='OptimWrapper',
        >>>         accumulative_counts=1,
        >>>         optimizer=dict(type='Adam', lr=0.0002,
        >>>                        betas=(0.5, 0.999))),
        >>>     discriminator=dict(
        >>>         type='OptimWrapper',
        >>>         accumulative_counts=1,
        >>>         optimizer=dict(type='Adam', lr=0.0002,
        >>>                        betas=(0.5, 0.999))))
        >>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_dict_builder(gan_model)

    Example 2: Build multi optimizers for specific submodules:
        >>> # build model
        >>> class GAN(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.generator = nn.Conv2d(3, 3, 1)
        >>>         self.discriminator = nn.Conv2d(3, 3, 1)
        >>> class TextEncoder(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.embedding = nn.Embedding(100, 100)
        >>> class ToyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.m1 = GAN()
        >>>         self.m2 = nn.Conv2d(3, 3, 1)
        >>>         self.m3 = nn.Linear(2, 2)
        >>>         self.text_encoder = TextEncoder()
        >>> model = ToyModel()
        >>> # build constructor
        >>> optim_wrapper = {
        >>>     '.*embedding': {
        >>>         'type': 'OptimWrapper',
        >>>         'optimizer': {
        >>>             'type': 'Adam',
        >>>             'lr': 1e-4,
        >>>             'betas': (0.9, 0.99)
        >>>         }
        >>>     },
        >>>     'm1.generator': {
        >>>         'type': 'OptimWrapper',
        >>>         'optimizer': {
        >>>             'type': 'Adam',
        >>>             'lr': 1e-5,
        >>>             'betas': (0.9, 0.99)
        >>>         }
        >>>     },
        >>>     'm2': {
        >>>         'type': 'OptimWrapper',
        >>>         'optimizer': {
        >>>             'type': 'Adam',
        >>>             'lr': 1e-5,
        >>>         }
        >>>     }
        >>> }
        >>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_dict_builder(model)

    Example 3: Build a single optimizer for multi modules (e.g., DreamBooth):
        >>> # build StableDiffusion model
        >>> model = dict(
        >>>     type='StableDiffusion',
        >>>     unet=dict(type='unet'),
        >>>     vae=dict(type='vae'),
        >>>     text_encoder=dict(type='text_encoder'))
        >>> diffusion_model = MODELS.build(model)
        >>> # build constructor
        >>> optim_wrapper = dict(
        >>>     modules=['unet', 'text_encoder'],
        >>>     optimizer=dict(type='Adam', lr=0.0002),
        >>>     accumulative_counts=1)
        >>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_dict_builder(diffusion_model)

    Args:
        optim_wrapper_cfg (dict): Config of the optimizer wrapper.
        paramwise_cfg (dict): Config of parameter-wise settings.
            Default: None.
""" def __init__(self, optim_wrapper_cfg: dict, paramwise_cfg=None): if not isinstance(optim_wrapper_cfg, dict): raise TypeError('optimizer_cfg should be a dict', f'but got {type(optim_wrapper_cfg)}') assert paramwise_cfg is None, ( 'paramwise_cfg should be set in each optimizer separately') self.optim_cfg = optim_wrapper_cfg if 'modules' in optim_wrapper_cfg: # single optimizer with multi param groups cfg_ = optim_wrapper_cfg.copy() self.modules = cfg_.pop('modules') paramwise_cfg_ = cfg_.pop('paramwise_cfg', None) self.constructors = DefaultOptimWrapperConstructor( cfg_, paramwise_cfg_) else: self.constructors = {} self.modules = {} for key, cfg in self.optim_cfg.items(): cfg_ = cfg.copy() if 'modules' in cfg_: self.modules[key] = cfg_.pop('modules') paramwise_cfg_ = cfg_.pop('paramwise_cfg', None) self.constructors[key] = DefaultOptimWrapperConstructor( cfg_, paramwise_cfg_)

    def __call__(self, module: nn.Module) -> Union[OptimWrapperDict,
                                                    OptimWrapper]:
        """Build the optimizers and return an ``OptimWrapperDict`` or a
        single ``OptimWrapper``."""
        optimizers = {}
        if hasattr(module, 'module'):
            module = module.module
        if isinstance(self.constructors, dict):
            for key, constructor in self.constructors.items():
                module_names = self.modules[key] if self.modules else key
                if (isinstance(module_names, str)
                        and module_names in module._modules):
                    optimizers[key] = constructor(
                        module._modules[module_names])
                    optim_wrapper_cfg = constructor.optimizer_cfg
                    print_log(
                        f'Add to optimizer \'{key}\' '
                        f'({optim_wrapper_cfg}): \'{key}\'.', 'current')
                else:
                    assert not constructor.paramwise_cfg, (
                        'Do not support paramwise_cfg for multi module '
                        'optimizer.')
                    params, found_names = get_params_by_names(
                        module, module_names)
                    # build optimizer
                    optimizer_cfg = constructor.optimizer_cfg.copy()
                    optimizer_cfg['params'] = params
                    optimizer = OPTIMIZERS.build(optimizer_cfg)
                    # build optimizer wrapper
                    optim_wrapper_cfg = constructor.optim_wrapper_cfg.copy()
                    optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
                    optim_wrapper = OPTIM_WRAPPERS.build(
                        optim_wrapper_cfg,
                        default_args=dict(optimizer=optimizer))
                    for name in found_names:
                        print_log(
                            f'Add to optimizer \'{key}\' '
                            f'({constructor.optimizer_cfg}): \'{name}\'.',
                            'current')
                    optimizers[key] = optim_wrapper
            return OptimWrapperDict(**optimizers)
        else:
            params, found_names = get_params_by_names(module, self.modules)
            constructor = self.constructors
            assert not constructor.paramwise_cfg, (
                'Do not support paramwise_cfg for multi parameters')
            optimizer_cfg = constructor.optimizer_cfg.copy()
            optimizer_cfg['params'] = params
            optimizer = OPTIMIZERS.build(optimizer_cfg)
            for name in found_names:
                print_log(
                    f'Add to optimizer ({constructor.optimizer_cfg}): '
                    f'\'{name}\'.', 'current')
            # build optimizer wrapper
            optim_wrapper_cfg = constructor.optim_wrapper_cfg.copy()
            optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
            optim_wrapper = OPTIM_WRAPPERS.build(
                optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
            return optim_wrapper
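
# Illustrative sketch, not part of the original module: a top-level 'modules'
# key makes ``__call__`` return a single ``OptimWrapper`` covering the listed
# submodules; any other config shape yields an ``OptimWrapperDict`` keyed by
# submodule name. Assuming ``model`` is the ``ToyModel`` from Example 2 above:
#
# >>> single = MultiOptimWrapperConstructor(
# ...     dict(modules=['m2', 'm3'],
# ...          optimizer=dict(type='Adam', lr=1e-4)))(model)
# >>> isinstance(single, OptimWrapper)
# True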

def get_params_by_names(module: nn.Module,
                        names: Union[str, list]) -> Tuple[list, list]:
    """Support two kinds of name matching:
        1. matching names of **first-level** submodules.
        2. matching names by ``re.fullmatch``.

    Args:
        module (nn.Module): The module to get parameters from.
        names (Union[str, list]): The name or a list of names of the
            submodule parameters.

    Returns:
        Tuple[list, list]: A list of parameters and the corresponding names
            for logging.
    """
    if not isinstance(names, list):
        names = [names]
    params = []
    found_names = []
    for name in names:
        if name in module._modules:
            params.extend(module._modules[name].parameters())
            found_names.append(name)
        else:
            for n, m in module.named_modules():
                if re.fullmatch(name, n):
                    params.extend(m.parameters())
                    found_names.append(n)
    return params, found_names
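
# Illustrative sketch, not part of the original module: exact first-level
# names are looked up in ``module._modules`` directly, and any name that is
# not a first-level submodule is treated as a regular expression tested with
# ``re.fullmatch`` against every name from ``named_modules()``.
#
# >>> toy = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
# >>> _, names = get_params_by_names(toy, ['0', '.*1'])
# >>> names
# ['0', '1']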