Shortcuts

Source code for mmagic.models.utils.model_utils

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine import print_log
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        normal_init, update_init_info,
                                        xavier_init)
from mmengine.registry import Registry
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import Tensor
from torch.nn import init

from mmagic.structures import DataSample
from mmagic.utils.typing import ForwardInputs
from .tome_utils import (add_tome_cfg_hook, build_mmagic_tomesd_block,
                         build_mmagic_wrapper_tomesd_block, isinstance_str)


[docs]def default_init_weights(module, scale=1): """Initialize network weights. Args: modules (nn.Module): Modules to be initialized. scale (float): Scale initialized weights, especially for residual blocks. Default: 1. """ for m in module.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale elif isinstance(m, nn.Linear): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale elif isinstance(m, _BatchNorm): constant_init(m.weight, val=1, bias=0)
[docs]def make_layer(block, num_blocks, **kwarg): """Make layers by stacking the same blocks. Args: block (nn.module): nn.module class for basic block. num_blocks (int): number of blocks. Returns: nn.Sequential: Stacked blocks in nn.Sequential. """ layers = [] for _ in range(num_blocks): layers.append(block(**kwarg)) return nn.Sequential(*layers)
[docs]def get_module_device(module): """Get the device of a module. Args: module (nn.Module): A module contains the parameters. Returns: torch.device: The device of the module. """ try: next(module.parameters()) except StopIteration: raise ValueError('The input module should contain parameters.') if next(module.parameters()).is_cuda: return next(module.parameters()).get_device() else: return torch.device('cpu')
[docs]def set_requires_grad(nets, requires_grad=False): """Set requires_grad for all the networks. Args: nets (nn.Module | list[nn.Module]): A list of networks or a single network. requires_grad (bool): Whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad
[docs]def generation_init_weights(module, init_type='normal', init_gain=0.02): """Default initialization of network weights for image generation. By default, we use normal init, but xavier and kaiming might work better for some applications. Args: module (nn.Module): Module to be initialized. init_type (str): The name of an initialization method: normal | xavier | kaiming | orthogonal. Default: 'normal'. init_gain (float): Scaling factor for normal, xavier and orthogonal. Default: 0.02. """ def init_func(m): """Initialization function. Args: m (nn.Module): Module to be initialized. """ classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': normal_init(m, 0.0, init_gain) elif init_type == 'xavier': xavier_init(m, gain=init_gain, distribution='normal') elif init_type == 'kaiming': kaiming_init( m, a=0, mode='fan_in', nonlinearity='leaky_relu', distribution='normal') elif init_type == 'orthogonal': init.orthogonal_(m.weight, gain=init_gain) init.constant_(m.bias.data, 0.0) else: raise NotImplementedError( f"Initialization method '{init_type}' is not implemented") init_info = (f'Initialize {m.__class__.__name__} by \'init_type\' ' f'{init_type}.') elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; # only normal distribution applies. normal_init(m, 1.0, init_gain) init_info = (f'{m.__class__.__name__} is BatchNorm2d, initialize ' 'by Norm initialization with mean=1, ' f'std={init_gain}') if hasattr(m, '_params_init_info'): update_init_info(module, init_info) module.apply(init_func)
[docs]def get_valid_noise_size(noise_size: Optional[int], generator: Union[Dict, nn.Module]) -> Optional[int]: """Get the value of `noise_size` from input, `generator` and check the consistency of these values. If no conflict is found, return that value. Args: noise_size (Optional[int]): `noise_size` passed to `BaseGAN_refactor`'s initialize function. generator (ModelType): The config or the model of generator. Returns: int | None: The noise size feed to generator. """ if isinstance(generator, dict): model_noise_size = generator.get('noise_size', None) else: model_noise_size = getattr(generator, 'noise_size', None) # get noise_size if noise_size is not None and model_noise_size is not None: assert noise_size == model_noise_size, ( 'Input \'noise_size\' is inconsistent with ' f'\'generator.noise_size\'. Receive \'{noise_size}\' and ' f'\'{model_noise_size}\'.') else: noise_size = noise_size or model_noise_size return noise_size
[docs]def get_valid_num_batches(batch_inputs: Optional[ForwardInputs] = None, data_samples: List[DataSample] = None) -> int: """Try get the valid batch size from inputs. - If some values in `batch_inputs` are `Tensor` and 'num_batches' is in `batch_inputs`, we check whether the value of 'num_batches' and the the length of first dimension of all tensors are same. If the values are not same, `AssertionError` will be raised. If all values are the same, return the value. - If no values in `batch_inputs` is `Tensor`, 'num_batches' must be contained in `batch_inputs`. And this value will be returned. - If some values are `Tensor` and 'num_batches' is not contained in `batch_inputs`, we check whether all tensor have the same length on the first dimension. If the length are not same, `AssertionError` will be raised. If all length are the same, return the length as batch size. - If batch_inputs is a `Tensor`, directly return the length of the first dimension as batch size. Args: batch_inputs (ForwardInputs): Inputs passed to :meth:`forward`. Returns: int: The batch size of samples to generate. """ # attempt to infer num_batches from batch_inputs if batch_inputs is not None: if isinstance(batch_inputs, Tensor): return batch_inputs.shape[0] # get num_batches from batch_inputs num_batches_dict = { k: v.shape[0] for k, v in batch_inputs.items() if isinstance(v, Tensor) } if 'num_batches' in batch_inputs: num_batches_dict['num_batches'] = batch_inputs['num_batches'] if num_batches_dict: num_batches_inputs = list(num_batches_dict.values())[0] # ensure all num_batches are same assert all([ bz == num_batches_inputs for bz in num_batches_dict.values() ]), ('\'num_batches\' is inconsistency among the preprocessed ' f'input. \'num_batches\' parsed results: {num_batches_dict}') else: num_batches_inputs = None else: num_batches_inputs = None # attempt to infer num_batches from data_samples if data_samples is not None: num_batches_samples = len(data_samples) else: num_batches_samples = None if not (num_batches_inputs or num_batches_samples): print_log( 'Cannot get \'num_batches\' from both \'inputs\' and ' '\'data_samples\', automatically set \'num_batches\' as 1. ' 'This may leads to potential error.', 'current', logging.WARNING) return 1 elif num_batches_inputs and num_batches_samples: assert num_batches_inputs == num_batches_samples, ( '\'num_batches\' inferred from \'inputs\' and \'data_samples\' ' f'are different, ({num_batches_inputs} vs. {num_batches_samples}).' ' Please check your input carefully.') return num_batches_inputs or num_batches_samples
[docs]def build_module(module: Union[dict, nn.Module], builder: Registry, *args, **kwargs) -> Any: """Build module from config or return the module itself. Args: module (Union[dict, nn.Module]): The module to build. builder (Registry): The registry to build module. *args, **kwargs: Arguments passed to build function. Returns: Any: The built module. """ if isinstance(module, dict): return builder.build(module, *args, **kwargs) elif isinstance(module, nn.Module): return module else: raise TypeError( f'Only support dict and nn.Module, but got {type(module)}.')
[docs]def xformers_is_enable(verbose: bool = False) -> bool: """Check whether xformers is installed. Args: verbose (bool): Whether to print the log. Returns: bool: Whether xformers is installed. """ from mmagic.utils import try_import xformers = try_import('xformers') if xformers is None and verbose: print_log('Do not support Xformers.', 'current') return xformers is not None
[docs]def set_xformers(module: nn.Module, prefix: str = '') -> nn.Module: """Set xformers' efficient Attention for attention modules. Args: module (nn.Module): The module to set xformers. prefix (str): The prefix of the module name. Returns: nn.Module: The module with xformers' efficient Attention. """ if not xformers_is_enable(): print_log('Do not support Xformers. Please install Xformers first. ' 'The program will run without Xformers.') return for n, m in module.named_children(): if hasattr(m, 'set_use_memory_efficient_attention_xformers'): # set xformers for Diffusers' Cross Attention m.set_use_memory_efficient_attention_xformers(True) module_name = f'{prefix}.{n}' if prefix else n print_log( 'Enable Xformers for HuggingFace Diffusers\' ' f'module \'{module_name}\'.', 'current') else: set_xformers(m, prefix=n) return module
[docs]def set_tomesd(model: torch.nn.Module, ratio: float = 0.5, max_downsample: int = 1, sx: int = 2, sy: int = 2, use_rand: bool = True, merge_attn: bool = True, merge_crossattn: bool = False, merge_mlp: bool = False): """Patches a stable diffusion model with ToMe. Apply this to the highest level stable diffusion object. Refer to: https://github.com/dbolya/tomesd/blob/main/tomesd/patch.py#L173 # noqa Args: model (torch.nn.Module): A top level Stable Diffusion module to patch in place. ratio (float): The ratio of tokens to merge. I.e., 0.4 would reduce the total number of tokens by 40%.The maximum value for this is 1-(1/(`sx` * `sy`)). By default, the max ratio is 0.75 (usually <= 0.5 is recommended). Higher values result in more speed-up, but with more visual quality loss. max_downsample (int): Apply ToMe to layers with at most this amount of downsampling. E.g., 1 only applies to layers with no downsampling, while 8 applies to all layers. Should be chosen from [1, 2, 4, or 8]. 1 and 2 are recommended. sx, sy (int, int): The stride for computing dst sets. A higher stride means you can merge more tokens, default setting of (2, 2) works well in most cases. `sx` and `sy` do not need to divide image size. use_rand (bool): Whether or not to allow random perturbations when computing dst sets. By default: True, but if you're having weird artifacts you can try turning this off. merge_attn (bool): Whether or not to merge tokens for attention (recommended). merge_crossattn (bool): Whether or not to merge tokens for cross attention (not recommended). merge_mlp (bool): Whether or not to merge tokens for the mlp layers (particular not recommended). Returns: model (torch.nn.Module): Model patched by ToMe. """ # Make sure the module is not currently patched remove_tomesd(model) is_mmagic = isinstance_str(model, 'StableDiffusion') or isinstance_str( model, 'BaseModel') if is_mmagic: # Supports "StableDiffusion.unet" and "unet" diffusion_model = model.unet if hasattr(model, 'unet') else model if isinstance_str(diffusion_model, 'DenoisingUnet'): is_wrapper = False else: is_wrapper = True else: if not hasattr(model, 'model') or not hasattr(model.model, 'diffusion_model'): # Provided model not supported print('Expected a Stable Diffusion / Latent Diffusion model.') raise RuntimeError('Provided model was not supported.') diffusion_model = model.model.diffusion_model # TODO: can support more diffusion models, like Stability AI is_wrapper = None diffusion_model._tome_info = { 'size': None, 'hooks': [], 'args': { 'ratio': ratio, 'max_downsample': max_downsample, 'sx': sx, 'sy': sy, 'use_rand': use_rand, 'merge_attn': merge_attn, 'merge_crossattn': merge_crossattn, 'merge_mlp': merge_mlp } } add_tome_cfg_hook(diffusion_model) for _, module in diffusion_model.named_modules(): if isinstance_str(module, 'BasicTransformerBlock'): # TODO: can support more stable diffusion based models if is_mmagic: if is_wrapper is None: raise NotImplementedError( 'Specific ToMe block not implemented') elif not is_wrapper: make_tome_block_fn = build_mmagic_tomesd_block elif is_wrapper: make_tome_block_fn = build_mmagic_wrapper_tomesd_block else: raise TypeError( 'Currently `tome` only support *stable-diffusion* model!') module.__class__ = make_tome_block_fn(module.__class__) module._tome_info = diffusion_model._tome_info return model
[docs]def remove_tomesd(model: torch.nn.Module): """Removes a patch from a ToMe Diffusion module if it was already patched. Refer to: https://github.com/dbolya/tomesd/blob/main/tomesd/patch.py#L251 # noqa """ # For mmagic Stable Diffusion models model = model.unet if hasattr(model, 'unet') else model for _, module in model.named_modules(): if hasattr(module, '_tome_info'): for hook in module._tome_info['hooks']: hook.remove() module._tome_info['hooks'].clear() if module.__class__.__name__ == 'ToMeBlock': module.__class__ = module._parent return model