
mmagic.engine.optimizers.multi_optimizer_constructor

Module Contents

Classes

MultiOptimWrapperConstructor

OptimizerConstructor for GAN models. This class constructs optimizers for the submodules of the model separately.

Functions

get_params_by_names(→ Tuple[list, list])

Get submodule parameters by name; supports two kinds of name matching.

class mmagic.engine.optimizers.multi_optimizer_constructor.MultiOptimWrapperConstructor(optim_wrapper_cfg: dict, paramwise_cfg=None)

OptimizerConstructor for GAN models. This class constructs optimizers for the submodules of the model separately and returns an mmengine.optim.OptimWrapperDict or an mmengine.optim.OptimWrapper.

Example 1: Build multiple optimizers (e.g., GANs):

>>> from mmagic.registry import MODELS
>>> from mmagic.engine.optimizers.multi_optimizer_constructor import (
...     MultiOptimWrapperConstructor)
>>> # build GAN model
>>> model = dict(
...     type='GANModel',
...     num_classes=10,
...     generator=dict(type='Generator'),
...     discriminator=dict(type='Discriminator'))
>>> gan_model = MODELS.build(model)
>>> # build constructor
>>> optim_wrapper = dict(
...     generator=dict(
...         type='OptimWrapper',
...         accumulative_counts=1,
...         optimizer=dict(type='Adam', lr=0.0002,
...                        betas=(0.5, 0.999))),
...     discriminator=dict(
...         type='OptimWrapper',
...         accumulative_counts=1,
...         optimizer=dict(type='Adam', lr=0.0002,
...                        betas=(0.5, 0.999))))
>>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
>>> # build optim wrapper dict
>>> optim_wrapper_dict = optim_dict_builder(gan_model)

Example 2: Build multiple optimizers for specific submodules:

>>> import torch.nn as nn
>>> from mmagic.engine.optimizers.multi_optimizer_constructor import (
...     MultiOptimWrapperConstructor)
>>> # build model
>>> class GAN(nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.generator = nn.Conv2d(3, 3, 1)
...         self.discriminator = nn.Conv2d(3, 3, 1)
>>> class TextEncoder(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.embedding = nn.Embedding(100, 100)
>>> class ToyModel(nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.m1 = GAN()
...         self.m2 = nn.Conv2d(3, 3, 1)
...         self.m3 = nn.Linear(2, 2)
...         self.text_encoder = TextEncoder()
>>> model = ToyModel()
>>> # build constructor
>>> optim_wrapper = {
...     '.*embedding': {
...         'type': 'OptimWrapper',
...         'optimizer': {
...             'type': 'Adam',
...             'lr': 1e-4,
...             'betas': (0.9, 0.99)
...         }
...     },
...     'm1.generator': {
...         'type': 'OptimWrapper',
...         'optimizer': {
...             'type': 'Adam',
...             'lr': 1e-5,
...             'betas': (0.9, 0.99)
...         }
...     },
...     'm2': {
...         'type': 'OptimWrapper',
...         'optimizer': {
...             'type': 'Adam',
...             'lr': 1e-5,
...         }
...     }
... }
>>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
>>> # build optim wrapper dict
>>> optim_wrapper_dict = optim_dict_builder(model)

Example 3: Build a single optimizer for multiple modules (e.g., DreamBooth):

>>> from mmagic.registry import MODELS
>>> from mmagic.engine.optimizers.multi_optimizer_constructor import (
...     MultiOptimWrapperConstructor)
>>> # build StableDiffusion model
>>> model = dict(
...     type='StableDiffusion',
...     unet=dict(type='unet'),
...     vae=dict(type='vae'),
...     text_encoder=dict(type='text_encoder'))
>>> diffusion_model = MODELS.build(model)
>>> # build constructor
>>> optim_wrapper = dict(
...     modules=['unet', 'text_encoder'],
...     optimizer=dict(type='Adam', lr=0.0002),
...     accumulative_counts=1)
>>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
>>> # build optim wrapper dict
>>> optim_wrapper_dict = optim_dict_builder(diffusion_model)

Parameters
  • optim_wrapper_cfg (dict) – Config of the optimizer wrapper(s).

  • paramwise_cfg (dict) – Config of parameter-wise settings. Default: None.

__call__(module: torch.nn.Module) → Union[mmengine.optim.OptimWrapperDict, mmengine.optim.OptimWrapper]

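Build the optimizer wrapper(s) for module. Returns an OptimWrapperDict with one wrapper per config key, or a single OptimWrapper when one optimizer is shared by the modules listed under modules (Example 3).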
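The two config layouts in the examples can be told apart by their top-level keys. The following is a minimal sketch, assuming that a top-level `modules` key selects the single-optimizer layout of Example 3 and that any other config gets one wrapper per key (a submodule name or name pattern), as in Examples 1 and 2. It only plans which names each wrapper covers and builds no optimizers; `plan_optim_wrappers` and the `'optim_wrapper'` label are hypothetical, not part of mmagic.

```python
from typing import Dict, List, Union


def plan_optim_wrappers(optim_wrapper_cfg: dict) -> Dict[str, List[str]]:
    """Hypothetical helper: map each optimizer wrapper to the names it covers.

    Only the config is inspected; no optimizer or wrapper is built.
    """
    if 'modules' in optim_wrapper_cfg:
        modules: Union[str, List[str]] = optim_wrapper_cfg['modules']
        if isinstance(modules, str):
            modules = [modules]
        # Example 3 layout: one shared optimizer for all listed modules.
        # 'optim_wrapper' is an arbitrary placeholder label.
        return {'optim_wrapper': list(modules)}
    # Examples 1 and 2 layout: one optimizer wrapper per top-level key,
    # where the key is a submodule name or a name pattern.
    return {key: [key] for key in optim_wrapper_cfg}
```

For Example 3's config this yields `{'optim_wrapper': ['unet', 'text_encoder']}`, while Example 1's config yields one entry each for `generator` and `discriminator`.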

mmagic.engine.optimizers.multi_optimizer_constructor.get_params_by_names(module: torch.nn.Module, names: Union[str, list]) → Tuple[list, list]
Get the parameters of the submodules selected by names. Two kinds of name matching are supported:
  1. Matching the name against the first-level submodules of module.

  2. Matching the name against submodule names with re.fullmatch.

Parameters
  • module (nn.Module) – The module to get the parameters from.

  • names (Union[str, list]) – The name or a list of names of the submodule parameters.

Returns

A list of parameters and a list of the corresponding parameter names (used for logging).

Return type

Tuple[list, list]
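The two matching rules can be illustrated without PyTorch. This minimal sketch works on plain name lists instead of an nn.Module: `top_level` stands in for the first-level submodule names and `all_names` for the dotted names that `named_modules()` would yield. The helper `match_submodule_names` is hypothetical; it assumes an exact first-level match is tried before the regex fallback, and it omits the parameter collection that get_params_by_names also performs.

```python
import re
from typing import Iterable, List, Union


def match_submodule_names(top_level: Iterable[str],
                          all_names: Iterable[str],
                          names: Union[str, List[str]]) -> List[str]:
    """Hypothetical helper: resolve `names` to submodule names.

    A name equal to a first-level submodule is taken as-is; otherwise it is
    treated as a regex and compared against every dotted submodule name
    with re.fullmatch.
    """
    if isinstance(names, str):
        names = [names]
    first_level = set(top_level)
    candidates = list(all_names)
    matched: List[str] = []
    for name in names:
        if name in first_level:
            # Rule 1: exact match on a first-level submodule.
            matched.append(name)
        else:
            # Rule 2: the whole submodule name must match the pattern.
            matched.extend(n for n in candidates if re.fullmatch(name, n))
    return matched


# Submodule names of ToyModel from Example 2 ('' is the root module).
toy_names = ['', 'm1', 'm1.generator', 'm1.discriminator', 'm2', 'm3',
             'text_encoder', 'text_encoder.embedding']
toy_top = ['m1', 'm2', 'm3', 'text_encoder']
print(match_submodule_names(toy_top, toy_names, '.*embedding'))
```

With these names, `'.*embedding'` resolves to `['text_encoder.embedding']` and `'m2'` to `['m2']`. Because re.fullmatch anchors both ends, a pattern such as `'m1'` would not accidentally select `m1.generator`.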
