
Source code for mmagic.models.editors.ddpm.denoising_unet

# Copyright (c) OpenMMLab. All rights reserved.
import math
from copy import deepcopy
from functools import partial
from typing import Tuple

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import build_norm_layer
from mmcv.cnn.bricks.conv_module import ConvModule
from mmengine.logging import MMLogger
from mmengine.model import BaseModule, constant_init
from mmengine.runner import load_checkpoint
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version

from mmagic.registry import MODELS
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block

logger = MMLogger.get_current_instance()

class EmbedSequential(nn.Sequential):
    """A sequential module that passes timestep embeddings to the children
    that support it as an extra input.

    Modified from
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
    """

    def forward(self, x, y, encoder_out=None):
        for layer in self:
            if isinstance(layer, DenoisingResBlock):
                x = layer(x, y)
            elif isinstance(
                    layer,
                    MultiHeadAttentionBlock) and encoder_out is not None:
                x = layer(x, encoder_out)
            else:
                x = layer(x)
        return x

@MODELS.register_module('GN32')
class GroupNorm32(nn.GroupNorm):

    def __init__(self, num_channels, num_groups=32, **kwargs):
        super().__init__(num_groups, num_channels, **kwargs)

    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

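Example (editor's note, not part of the module source): a minimal sketch of the mixed-precision behaviour, assuming a 64-channel fp16 feature map. The statistics are computed in fp32, but the output keeps the input dtype:

    norm = GroupNorm32(num_channels=64)   # 32 groups by default
    feat = torch.randn(2, 64, 8, 8).half()
    out = norm(feat)                      # normalized in fp32 internally
    assert out.dtype == torch.float16
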
def convert_module_to_f16(layer):
    """Convert primitive modules to float16."""
    if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        layer.weight.data = layer.weight.data.half()
        if layer.bias is not None:
            layer.bias.data = layer.bias.data.half()


def convert_module_to_f32(layer):
    """Convert primitive modules to float32, undoing
    convert_module_to_f16()."""
    if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        layer.weight.data = layer.weight.data.float()
        if layer.bias is not None:
            layer.bias.data = layer.bias.data.float()

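Example (editor's note, not part of the module source): these helpers are intended to be used with ``nn.Module.apply`` so that only convolution weights are cast while normalization layers stay in fp32. A small sketch with a hypothetical toy block:

    block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.GroupNorm(4, 8))
    block.apply(convert_module_to_f16)   # conv weights/bias -> fp16, GroupNorm unchanged
    block.apply(convert_module_to_f32)   # undo the conversion
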
@MODELS.register_module()
class SiLU(BaseModule):
    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. The
    SiLU function is also known as the swish function.

    Args:
        inplace (bool, optional): Whether to use the inplace operation.
            Defaults to `False`.
    """

    def __init__(self, inplace=False):
        super().__init__()
        if digit_version(TORCH_VERSION) <= digit_version('1.6.0') and inplace:
            mmengine.print_log(
                'Inplace version of \'SiLU\' is not supported for '
                f'torch < 1.6.0, found \'{torch.__version__}\'.')
        self.inplace = inplace

    def forward(self, x):
        """Forward function for SiLU.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Tensor after activation.
        """
        if digit_version(TORCH_VERSION) <= digit_version('1.6.0'):
            return x * torch.sigmoid(x)
        return F.silu(x, inplace=self.inplace)

@MODELS.register_module()
class MultiHeadAttention(BaseModule):
    """An attention block that allows spatial positions to attend to each
    other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.  # noqa

    Args:
        in_channels (int): Channels of the input feature map.
        num_heads (int, optional): Number of heads in the attention.
        norm_cfg (dict, optional): Config for normalization layer. Defaults
            to ``dict(type='GN', num_groups=32)``.
    """

    def __init__(self,
                 in_channels,
                 num_heads=1,
                 norm_cfg=dict(type='GN', num_groups=32)):
        super().__init__()
        self.num_heads = num_heads
        _, self.norm = build_norm_layer(norm_cfg, in_channels)
        self.qkv = nn.Conv1d(in_channels, in_channels * 3, 1)
        self.proj = nn.Conv1d(in_channels, in_channels, 1)
        self.init_weights()

    @staticmethod
    def QKVAttention(qkv):
        channel = qkv.shape[1] // 3
        q, k, v = torch.chunk(qkv, 3, dim=1)
        scale = 1 / np.sqrt(np.sqrt(channel))
        weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        weight = torch.einsum('bts,bcs->bct', weight, v)
        return weight

    def forward(self, x):
        """Forward function for multi head attention.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Feature map after attention.
        """
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.QKVAttention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj(h)
        return (h + x).reshape(b, c, *spatial)

    def init_weights(self):
        constant_init(self.proj, 0)

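Example (editor's note, not part of the module source): a shape sketch for the self-attention block, assuming a 4x4 feature map with 64 channels and 4 heads:

    attn = MultiHeadAttention(in_channels=64, num_heads=4)
    feat = torch.randn(2, 64, 4, 4)
    out = attn(feat)
    assert out.shape == feat.shape   # residual connection keeps the input shape
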
@MODELS.register_module()
class MultiHeadAttentionBlock(BaseModule):
    """An attention block that allows spatial positions to attend to each
    other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(self,
                 in_channels,
                 num_heads=1,
                 num_head_channels=-1,
                 use_new_attention_order=False,
                 norm_cfg=dict(type='GN32', num_groups=32),
                 encoder_channels=None):
        super().__init__()
        self.in_channels = in_channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (in_channels % num_head_channels == 0), (
                f'q,k,v channels {in_channels} is not divisible by '
                f'num_head_channels {num_head_channels}')
            self.num_heads = in_channels // num_head_channels
        _, self.norm = build_norm_layer(norm_cfg, in_channels)
        self.qkv = nn.Conv1d(in_channels, in_channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = nn.Conv1d(in_channels, in_channels, 1)

        if encoder_channels is not None:
            self.encoder_kv = nn.Conv1d(encoder_channels, in_channels * 2, 1)

    def forward(self, x, encoder_out=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        if encoder_out is not None:
            encoder_out = self.encoder_kv(encoder_out)
            h = self.attention(qkv, encoder_out)
        else:
            h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

@MODELS.register_module()
class QKVAttentionLegacy(BaseModule):
    """A module which performs QKV attention.

    Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None):
        """Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
            ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(
                ch, dim=1)
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            'bct,bcs->bts', q * scale,
            k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum('bts,bcs->bct', weight, v)
        return a.reshape(bs, -1, length)

@MODELS.register_module()
class QKVAttention(BaseModule):
    """A module which performs QKV attention and splits in a different
    order."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            'bct,bcs->bts',
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum('bts,bcs->bct', weight,
                         v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

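The two kernels differ only in how the packed ``qkv`` tensor is split: ``QKVAttentionLegacy`` folds the heads into the batch dimension first and then splits q/k/v per head, while ``QKVAttention`` chunks q/k/v first. Example (editor's note, not part of the module source), checking that both produce the documented output shape for hypothetical sizes:

    heads, ch, tokens = 4, 8, 16
    qkv = torch.randn(2, 3 * heads * ch, tokens)
    out_new = QKVAttention(heads)(qkv)
    out_legacy = QKVAttentionLegacy(heads)(qkv)
    assert out_new.shape == out_legacy.shape == (2, heads * ch, tokens)
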
@MODELS.register_module()
class TimeEmbedding(BaseModule):
    """Time embedding layer, reference to Two level embedding. First embedding
    time by an embedding function, then feed to neural networks.

    Args:
        in_channels (int): The channel number of the input feature map.
        embedding_channels (int): The channel number of the output embedding.
        embedding_mode (str, optional): Embedding mode for the time embedding.
            Defaults to 'sin'.
        embedding_cfg (dict, optional): Config for time embedding.
            Defaults to None.
        act_cfg (dict, optional): Config for activation layer. Defaults to
            ``dict(type='SiLU', inplace=False)``.
    """

    def __init__(self,
                 in_channels,
                 embedding_channels,
                 embedding_mode='sin',
                 embedding_cfg=None,
                 act_cfg=dict(type='SiLU', inplace=False)):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Linear(in_channels, embedding_channels), MODELS.build(act_cfg),
            nn.Linear(embedding_channels, embedding_channels))

        # add `dim` to embedding config
        embedding_cfg_ = dict(dim=in_channels)
        if embedding_cfg is not None:
            embedding_cfg_.update(embedding_cfg)
        if embedding_mode.upper() == 'SIN':
            self.embedding_fn = partial(self.sinusodial_embedding,
                                        **embedding_cfg_)
        else:
            raise ValueError('Only support `SIN` for time embedding, '
                             f'but receive {embedding_mode}.')

    @staticmethod
    def sinusodial_embedding(timesteps, dim, max_period=10000):
        """Create sinusoidal timestep embeddings.

        Args:
            timesteps (torch.Tensor): Timestep to embedding. 1-D tensor shape
                as ``[bz, ]``, one per batch element.
            dim (int): The dimension of the embedding.
            max_period (int, optional): Controls the minimum frequency of the
                embeddings. Defaults to ``10000``.

        Returns:
            torch.Tensor: Embedding results shape as `[bz, dim]`.
        """
        half = dim // 2
        freqs = torch.exp(
            -np.log(max_period) *
            torch.arange(start=0, end=half, dtype=torch.float32) /
            half).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Forward function for time embedding layer.

        Args:
            t (torch.Tensor): Input timesteps.

        Returns:
            torch.Tensor: Timesteps embedding.
        """
        return self.blocks(self.embedding_fn(t))

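Example (editor's note, not part of the module source): a minimal sketch of the two-level embedding, assuming 128 input channels and 512 embedding channels:

    t = torch.tensor([0, 10, 500, 999])
    sin_emb = TimeEmbedding.sinusodial_embedding(t, dim=128)     # [4, 128]
    time_mlp = TimeEmbedding(in_channels=128, embedding_channels=512)
    emb = time_mlp(t)                                            # [4, 512]
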
@MODELS.register_module()
class DenoisingResBlock(BaseModule):
    """Resblock for the denoising network. If `in_channels` does not equal
    `out_channels`, a learnable shortcut with conv layers will be added.

    Args:
        in_channels (int): Number of channels of the input feature map.
        embedding_channels (int): Number of channels of the input embedding.
        use_scale_shift_norm (bool): Whether to use scale-shift-norm in the
            `NormWithEmbedding` layer.
        dropout (float): Probability of the dropout layers.
        out_channels (int, optional): Number of output channels of the
            ResBlock. If not defined, the output channels will equal
            `in_channels`. Defaults to `None`.
        norm_cfg (dict, optional): The config for the normalization layers.
            Defaults to ``dict(type='GN', num_groups=32)``.
        act_cfg (dict, optional): The config for the activation layers.
            Defaults to ``dict(type='SiLU', inplace=False)``.
        shortcut_kernel_size (int, optional): The kernel size for the shortcut
            conv. Defaults to ``1``.
    """

    def __init__(self,
                 in_channels,
                 embedding_channels,
                 use_scale_shift_norm,
                 dropout,
                 out_channels=None,
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='SiLU', inplace=False),
                 shortcut_kernel_size=1,
                 up=False,
                 down=False):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        _norm_cfg = deepcopy(norm_cfg)

        _, norm_1 = build_norm_layer(_norm_cfg, in_channels)
        conv_1 = [
            norm_1,
            MODELS.build(act_cfg),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        ]
        self.conv_1 = nn.Sequential(*conv_1)

        norm_with_embedding_cfg = dict(
            in_channels=out_channels,
            embedding_channels=embedding_channels,
            use_scale_shift=use_scale_shift_norm,
            norm_cfg=_norm_cfg)
        self.norm_with_embedding = MODELS.build(
            dict(type='NormWithEmbedding'),
            default_args=norm_with_embedding_cfg)

        conv_2 = [
            MODELS.build(act_cfg),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        ]
        self.conv_2 = nn.Sequential(*conv_2)

        assert shortcut_kernel_size in [
            1, 3
        ], ('Only support `1` and `3` for `shortcut_kernel_size`, but '
            f'receive {shortcut_kernel_size}.')

        self.learnable_shortcut = out_channels != in_channels

        if self.learnable_shortcut:
            shortcut_padding = 1 if shortcut_kernel_size == 3 else 0
            self.shortcut = nn.Conv2d(
                in_channels,
                out_channels,
                shortcut_kernel_size,
                padding=shortcut_padding)

        self.updown = up or down
        if up:
            self.h_upd = DenoisingUpsample(in_channels, False)
            self.x_upd = DenoisingUpsample(in_channels, False)
        elif down:
            self.h_upd = DenoisingDownsample(in_channels, False)
            self.x_upd = DenoisingDownsample(in_channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.init_weights()

    def forward_shortcut(self, x):
        if self.learnable_shortcut:
            return self.shortcut(x)
        return x

    def forward(self, x, y):
        """Forward function.

        Args:
            x (torch.Tensor): Input feature map tensor.
            y (torch.Tensor): Shared time embedding or shared label embedding.

        Returns:
            torch.Tensor: Output feature map tensor.
        """
        if self.updown:
            in_rest, in_conv = self.conv_1[:-1], self.conv_1[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.conv_1(x)
        shortcut = self.forward_shortcut(x)
        h = self.norm_with_embedding(h, y)
        h = self.conv_2(h)
        return h + shortcut

    def init_weights(self):
        # apply zero init to last conv layer
        constant_init(self.conv_2[-1], 0)

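Example (editor's note, not part of the module source): a sketch of a channel-changing ResBlock, assuming a 512-dim time embedding. The channel change triggers the learnable 1x1 shortcut:

    block = DenoisingResBlock(
        in_channels=64, embedding_channels=512,
        use_scale_shift_norm=True, dropout=0.1, out_channels=128)
    x = torch.randn(2, 64, 32, 32)
    t_emb = torch.randn(2, 512)
    out = block(x, t_emb)            # [2, 128, 32, 32]
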
@MODELS.register_module()
class NormWithEmbedding(BaseModule):
    """Normalization with embedding layer. If `use_scale_shift == True`,
    embedding results will be chunked and used to re-shift and re-scale
    normalization results. Otherwise, embedding results will be directly
    added to the input of the normalization layer.

    Args:
        in_channels (int): Number of channels of the input feature map.
        embedding_channels (int): Number of channels of the input embedding.
        norm_cfg (dict, optional): Config for the normalization operation.
            Defaults to `dict(type='GN', num_groups=32)`.
        act_cfg (dict, optional): Config for the activation layer. Defaults
            to `dict(type='SiLU', inplace=False)`.
        use_scale_shift (bool): If True, the output of the embedding layer
            will be split into 'scale' and 'shift' and the output of the
            normalization layer will be mapped to
            ``out * (1 + scale) + shift``. Otherwise, the output of the
            embedding layer will be added to the input before the
            normalization operation. Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 embedding_channels,
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='SiLU', inplace=False),
                 use_scale_shift=True):
        super().__init__()
        self.use_scale_shift = use_scale_shift
        _, self.norm = build_norm_layer(norm_cfg, in_channels)

        embedding_output = in_channels * 2 if use_scale_shift else in_channels
        self.embedding_layer = nn.Sequential(
            MODELS.build(act_cfg),
            nn.Linear(embedding_channels, embedding_output))

    def forward(self, x, y):
        """Forward function.

        Args:
            x (torch.Tensor): Input feature map tensor.
            y (torch.Tensor): Shared time embedding or shared label embedding.

        Returns:
            torch.Tensor: Output feature map tensor.
        """
        embedding = self.embedding_layer(y).type(x.dtype)
        embedding = embedding[:, :, None, None]
        if self.use_scale_shift:
            scale, shift = torch.chunk(embedding, 2, dim=1)
            x = self.norm(x)
            x = x * (1 + scale) + shift
        else:
            x = self.norm(x + embedding)
        return x

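Example (editor's note, not part of the module source): with ``use_scale_shift=True`` the embedding acts as a per-channel affine on the normalized feature map, i.e. ``norm(x) * (1 + scale) + shift``:

    cond_norm = NormWithEmbedding(in_channels=64, embedding_channels=512)
    x = torch.randn(2, 64, 16, 16)
    emb = torch.randn(2, 512)
    out = cond_norm(x, emb)          # same shape as x
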
@MODELS.register_module()
class DenoisingDownsample(BaseModule):
    """Downsampling operation used in the denoising network. Support average
    pooling and convolution for downsample operation.

    Args:
        in_channels (int): Number of channels of the input feature map to be
            downsampled.
        with_conv (bool, optional): Whether use convolution operation for
            downsampling. Defaults to `True`.
    """

    def __init__(self, in_channels, with_conv=True):
        super().__init__()
        if with_conv:
            self.downsample = nn.Conv2d(in_channels, in_channels, 3, 2, 1)
        else:
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Forward function for downsampling operation.

        Args:
            x (torch.Tensor): Feature map to downsample.

        Returns:
            torch.Tensor: Feature map after downsampling.
        """
        return self.downsample(x)

@MODELS.register_module()
class DenoisingUpsample(BaseModule):
    """Upsampling operation used in the denoising network. Allows users to
    apply an additional convolution layer after the nearest interpolation
    operation.

    Args:
        in_channels (int): Number of channels of the input feature map to be
            upsampled.
        with_conv (bool, optional): Whether apply an additional convolution
            layer after upsampling. Defaults to `True`.
    """

    def __init__(self, in_channels, with_conv=True):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)

    def forward(self, x):
        """Forward function for upsampling operation.

        Args:
            x (torch.Tensor): Feature map to upsample.

        Returns:
            torch.Tensor: Feature map after upsampling.
        """
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.with_conv:
            x = self.conv(x)
        return x

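Example (editor's note, not part of the module source): both resampling blocks change the spatial size by a factor of two, so they round-trip on even resolutions:

    x = torch.randn(2, 64, 32, 32)
    down = DenoisingDownsample(64)    # strided 3x3 conv -> [2, 64, 16, 16]
    up = DenoisingUpsample(64)        # nearest interpolation + 3x3 conv
    assert up(down(x)).shape == x.shape
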
def build_down_block_resattn(resblocks_per_downsample, resblock_cfg,
                             in_channels_, out_channels_, attention_scale,
                             attention_cfg, in_channels_list, level,
                             channel_factor_list, embedding_channels,
                             use_scale_shift_norm, dropout, norm_cfg,
                             resblock_updown, downsample_cfg, scale):
    """Build unet down path blocks with resnet and attention."""
    in_blocks = nn.ModuleList()

    for _ in range(resblocks_per_downsample):
        layers = [
            MODELS.build(
                resblock_cfg,
                default_args={
                    'in_channels': in_channels_,
                    'out_channels': out_channels_
                })
        ]
        in_channels_ = out_channels_

        if scale in attention_scale:
            layers.append(
                MODELS.build(
                    attention_cfg,
                    default_args={'in_channels': in_channels_}))
        in_channels_list.append(in_channels_)
        in_blocks.append(EmbedSequential(*layers))

    if level != len(channel_factor_list) - 1:
        in_blocks.append(
            EmbedSequential(
                DenoisingResBlock(
                    out_channels_,
                    embedding_channels,
                    use_scale_shift_norm,
                    dropout,
                    norm_cfg=norm_cfg,
                    out_channels=out_channels_,
                    down=True) if resblock_updown else MODELS.build(
                        downsample_cfg,
                        default_args={'in_channels': in_channels_})))
        in_channels_list.append(in_channels_)
        scale *= 2

    return in_blocks, scale

def build_mid_blocks_resattn(resblock_cfg, attention_cfg, in_channels_):
    """Build unet mid blocks with resnet and attention."""
    return EmbedSequential(
        MODELS.build(
            resblock_cfg, default_args={'in_channels': in_channels_}),
        MODELS.build(
            attention_cfg, default_args={'in_channels': in_channels_}),
        MODELS.build(
            resblock_cfg, default_args={'in_channels': in_channels_}),
    )

def build_up_blocks_resattn(
    resblocks_per_downsample,
    resblock_cfg,
    in_channels_,
    in_channels_list,
    base_channels,
    factor,
    scale,
    attention_scale,
    attention_cfg,
    channel_factor_list,
    level,
    embedding_channels,
    use_scale_shift_norm,
    dropout,
    norm_cfg,
    resblock_updown,
    upsample_cfg,
):
    """Build up path blocks with resnet and attention."""
    out_blocks = nn.ModuleList()
    for idx in range(resblocks_per_downsample + 1):
        layers = [
            MODELS.build(
                resblock_cfg,
                default_args={
                    'in_channels': in_channels_ + in_channels_list.pop(),
                    'out_channels': int(base_channels * factor)
                })
        ]
        in_channels_ = int(base_channels * factor)
        if scale in attention_scale:
            layers.append(
                MODELS.build(
                    attention_cfg,
                    default_args={'in_channels': in_channels_}))
        if (level != len(channel_factor_list) - 1
                and idx == resblocks_per_downsample):
            out_channels_ = in_channels_
            layers.append(
                DenoisingResBlock(
                    in_channels_,
                    embedding_channels,
                    use_scale_shift_norm,
                    dropout,
                    norm_cfg=norm_cfg,
                    out_channels=out_channels_,
                    up=True) if resblock_updown else MODELS.build(
                        upsample_cfg,
                        default_args={'in_channels': in_channels_}))
            scale //= 2
        out_blocks.append(EmbedSequential(*layers))
    return out_blocks, in_channels_, scale

@MODELS.register_module()
class DenoisingUnet(BaseModule):
    """Denoising Unet. This network receives a diffused image ``x_t`` and
    current timestep ``t``, and returns a ``output_dict`` corresponding to
    the passed ``output_cfg``.

    ``output_cfg`` defines the number of channels and the meaning of the
    output. ``output_cfg`` mainly contains keys of ``mean`` and ``var``,
    denoting how the network outputs mean and variance required for the
    denoising process.

    For ``mean``:

    1. ``dict(mean='EPS')``: Model will predict noise added in the diffusion
       process, and the ``output_dict`` will contain a key named
       ``eps_t_pred``.
    2. ``dict(mean='START_X')``: Model will directly predict the mean of the
       original image `x_0`, and the ``output_dict`` will contain a key named
       ``x_0_pred``.
    3. ``dict(mean='X_TM1_PRED')``: Model will predict the mean of the
       diffused image at the `t-1` timestep, and the ``output_dict`` will
       contain a key named ``x_tm1_pred``.

    For ``var``:

    1. ``dict(var='FIXED_SMALL')`` or ``dict(var='FIXED_LARGE')``: Variance
       in the denoising process is regarded as a fixed value. Therefore only
       'mean' will be predicted, and the number of output channels will equal
       that of the input image (e.g., three channels for an RGB image).
    2. ``dict(var='LEARNED')``: Model will predict `log_variance` in the
       denoising process, and the ``output_dict`` will contain a key named
       ``log_var``.
    3. ``dict(var='LEARNED_RANGE')``: Model will predict an interpolation
       factor and the `log_variance` will be calculated as
       `factor * upper_bound + (1-factor) * lower_bound`. The ``output_dict``
       will contain a key named ``factor``.

    If ``var`` is not ``FIXED_SMALL`` or ``FIXED_LARGE``, the number of
    output channels will be double the input channels, where the first half
    contains the predicted mean values and the other half the predicted
    variance values. Otherwise, the number of output channels equals the
    input channels, containing only the predicted mean values.

    Args:
        image_size (int | list[int]): The size of image to denoise.
        in_channels (int, optional): The input channels of the input image.
            Defaults to ``3``.
        out_channels (int, optional): The output channels of the output
            prediction. Defaults to ``None`` for automatic assignment by
            ``var_mode``.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this
            number. Defaults to ``128``.
        resblocks_per_downsample (int, optional): Number of ResBlocks used
            between two downsample operations. The number of ResBlocks
            between upsample operations will be the same value to keep
            symmetry. Defaults to 3.
        num_timesteps (int, optional): The total timestep of the denoising
            process and the diffusion process. Defaults to ``1000``.
        use_rescale_timesteps (bool, optional): Whether to rescale the input
            timesteps to the range of [0, 1000]. Defaults to ``False``.
        dropout (float, optional): The probability of the dropout operation
            of each ResBlock. Pass ``0`` to not use dropout. Defaults to 0.
        embedding_channels (int, optional): The output channels of the time
            embedding layer and label embedding layer. If not passed (or
            passed ``-1``), the output channels of the embedding layers will
            be set to four times ``base_channels``. Defaults to ``-1``.
        num_classes (int, optional): The number of conditional classes. If
            set to 0, this model will be degraded to an unconditional model.
            Defaults to 0.
        channels_cfg (list | dict[list], optional): Config for the input
            channels of the intermediate blocks. If a list is passed, each
            element of the list indicates the scale factor for the input
            channels of the current block with regard to ``base_channels``.
            For block ``i``, the input and output channels should be
            ``channels_cfg[i] * base_channels`` and
            ``channels_cfg[i+1] * base_channels``. If a dict is provided, the
            key of the dict should be the output scale and the corresponding
            value should be a list to define channels. Default: please refer
            to ``_default_channels_cfg``.
        output_cfg (dict, optional): Config for output variables. Defaults to
            ``dict(mean='eps', var='learned_range')``.
        norm_cfg (dict, optional): The config for normalization layers.
            Defaults to ``dict(type='GN', num_groups=32)``.
        act_cfg (dict, optional): The config for activation layers. Defaults
            to ``dict(type='SiLU', inplace=False)``.
        shortcut_kernel_size (int, optional): The kernel size for shortcut
            conv in ResBlocks. The value of this argument will overwrite the
            default value of `resblock_cfg`. Defaults to `1`.
        use_scale_shift_norm (bool, optional): Whether to perform scale and
            shift after the normalization operation. Defaults to True.
        num_heads (int, optional): The number of attention heads.
            Defaults to 4.
        time_embedding_mode (str, optional): Embedding method of
            ``time_embedding``. Defaults to 'sin'.
        time_embedding_cfg (dict, optional): Config for ``time_embedding``.
            Defaults to None.
        resblock_cfg (dict, optional): Config for ResBlock. Defaults to
            ``dict(type='DenoisingResBlock')``.
        attention_cfg (dict, optional): Config for attention operation.
            Defaults to ``dict(type='MultiHeadAttention')``.
        upsample_conv (bool, optional): Whether to use conv in the upsample
            block. Defaults to ``True``.
        downsample_conv (bool, optional): Whether to use conv operation in
            the downsample block. Defaults to ``True``.
        upsample_cfg (dict, optional): Config for upsample blocks. Defaults
            to ``dict(type='DenoisingUpsample')``.
        downsample_cfg (dict, optional): Config for downsample blocks.
            Defaults to ``dict(type='DenoisingDownsample')``.
        attention_res (int | list[int], optional): Resolution of feature maps
            to apply attention operation. Defaults to ``[16, 8]``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to
            load the generator part from the whole state dict. Defaults to
            None.
    """
    _default_channels_cfg = {
        512: [0.5, 1, 1, 2, 2, 4, 4],
        256: [1, 1, 2, 2, 4, 4],
        128: [1, 1, 2, 3, 4],
        64: [1, 2, 3, 4],
        32: [1, 2, 2, 2]
    }

    def __init__(self,
                 image_size,
                 in_channels=3,
                 out_channels=None,
                 base_channels=128,
                 resblocks_per_downsample=3,
                 num_timesteps=1000,
                 use_rescale_timesteps=False,
                 dropout=0,
                 embedding_channels=-1,
                 num_classes=0,
                 use_fp16=False,
                 channels_cfg=None,
                 output_cfg=dict(mean='eps', var='learned_range'),
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='SiLU', inplace=False),
                 shortcut_kernel_size=1,
                 use_scale_shift_norm=False,
                 resblock_updown=False,
                 num_heads=4,
                 time_embedding_mode='sin',
                 time_embedding_cfg=None,
                 resblock_cfg=dict(type='DenoisingResBlock'),
                 attention_cfg=dict(type='MultiHeadAttention'),
                 encoder_channels=None,
                 downsample_conv=True,
                 upsample_conv=True,
                 downsample_cfg=dict(type='DenoisingDownsample'),
                 upsample_cfg=dict(type='DenoisingUpsample'),
                 attention_res=[16, 8],
                 pretrained=None,
                 unet_type='',
                 down_block_types: Tuple[str] = (),
                 up_block_types: Tuple[str] = (),
                 cross_attention_dim=768,
                 layers_per_block: int = 2):
        super().__init__()

        self.unet_type = unet_type
        self.num_classes = num_classes
        self.num_timesteps = num_timesteps
        self.base_channels = base_channels
        self.encoder_channels = encoder_channels
        self.use_rescale_timesteps = use_rescale_timesteps
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.output_cfg = deepcopy(output_cfg)
        self.mean_mode = self.output_cfg.get('mean', 'eps')
        self.var_mode = self.output_cfg.get('var', 'learned_range')
        self.in_channels = in_channels

        # double output_channels to output mean and var at same time
        if out_channels is None:
            out_channels = in_channels if 'FIXED' in self.var_mode.upper() \
                else 2 * in_channels
        self.out_channels = out_channels

        # check type of image_size
        if not isinstance(image_size, int) and not isinstance(
                image_size, list):
            raise TypeError(
                'Only support `int` and `list[int]` for `image_size`.')
        if isinstance(image_size, list):
            assert len(
                image_size) == 2, 'The length of `image_size` should be 2.'
            assert image_size[0] == image_size[
                1], 'Width and height of the image should be same.'
            image_size = image_size[0]
        self.image_size = image_size

        channels_cfg = deepcopy(self._default_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if image_size not in channels_cfg:
                raise KeyError(f'`image_size={image_size} is not found in '
                               '`channels_cfg`, only support configs for '
                               f'{[chn for chn in channels_cfg.keys()]}')
            self.channel_factor_list = channels_cfg[image_size]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'receive {type(channels_cfg)}')

        embedding_channels = base_channels * 4 \
            if embedding_channels == -1 else embedding_channels

        # init the channel scale factor
        scale = 1
        ch = int(base_channels * self.channel_factor_list[0])
        self.in_channels_list = [ch]

        if self.unet_type == 'stable':
            # time
            self.time_proj = Timesteps(ch)
            self.time_embedding = TimestepEmbedding(base_channels,
                                                    embedding_channels)
            self.conv_in = nn.Conv2d(
                in_channels, ch, kernel_size=3, padding=(1, 1))
        else:
            self.time_embedding = TimeEmbedding(
                base_channels,
                embedding_channels=embedding_channels,
                embedding_mode=time_embedding_mode,
                embedding_cfg=time_embedding_cfg,
                act_cfg=act_cfg)

            self.in_blocks = nn.ModuleList(
                [EmbedSequential(nn.Conv2d(in_channels, ch, 3, 1, padding=1))])

        if self.num_classes != 0:
            self.label_embedding = nn.Embedding(self.num_classes,
                                                embedding_channels)

        self.resblock_cfg = deepcopy(resblock_cfg)
        self.resblock_cfg.setdefault('dropout', dropout)
        self.resblock_cfg.setdefault('norm_cfg', norm_cfg)
        self.resblock_cfg.setdefault('act_cfg', act_cfg)
        self.resblock_cfg.setdefault('embedding_channels', embedding_channels)
        self.resblock_cfg.setdefault('use_scale_shift_norm',
                                     use_scale_shift_norm)
        self.resblock_cfg.setdefault('shortcut_kernel_size',
                                     shortcut_kernel_size)

        # get scales of ResBlock to apply attention
        attention_scale = [image_size // int(res) for res in attention_res]
        self.attention_cfg = deepcopy(attention_cfg)
        self.attention_cfg.setdefault('num_heads', num_heads)
        self.attention_cfg.setdefault('norm_cfg', norm_cfg)

        self.downsample_cfg = deepcopy(downsample_cfg)
        self.downsample_cfg.setdefault('with_conv', downsample_conv)
        self.upsample_cfg = deepcopy(upsample_cfg)
        self.upsample_cfg.setdefault('with_conv', upsample_conv)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        attention_head_dim = (num_heads, ) * len(down_block_types)

        # construct the encoder part of Unet
        for level, factor in enumerate(self.channel_factor_list):
            in_channels_ = ch if level == 0 \
                else int(base_channels * self.channel_factor_list[level - 1])
            out_channels_ = int(base_channels * factor)
            if self.unet_type == 'stable':
                is_final_block = level == len(self.channel_factor_list) - 1
                down_block_type = down_block_types[level]
                down_block = get_down_block(
                    down_block_type,
                    num_layers=layers_per_block,
                    in_channels=in_channels_,
                    out_channels=out_channels_,
                    temb_channels=embedding_channels,
                    cross_attention_dim=cross_attention_dim,
                    add_downsample=not is_final_block,
                    resnet_act_fn=act_cfg['type'],
                    resnet_groups=norm_cfg['num_groups'],
                    attn_num_head_channels=attention_head_dim[level],
                )
                self.down_blocks.append(down_block)
            else:
                in_blocks, scale = build_down_block_resattn(
                    resblocks_per_downsample=resblocks_per_downsample,
                    resblock_cfg=self.resblock_cfg,
                    in_channels_=in_channels_,
                    out_channels_=out_channels_,
                    attention_scale=attention_scale,
                    attention_cfg=self.attention_cfg,
                    in_channels_list=self.in_channels_list,
                    level=level,
                    channel_factor_list=self.channel_factor_list,
                    embedding_channels=embedding_channels,
                    use_scale_shift_norm=use_scale_shift_norm,
                    dropout=dropout,
                    norm_cfg=norm_cfg,
                    resblock_updown=resblock_updown,
                    downsample_cfg=self.downsample_cfg,
                    scale=scale)
                self.in_blocks.extend(in_blocks)

        # construct the bottom part of Unet
        block_out_channels = [
            times * base_channels for times in self.channel_factor_list
        ]
        in_channels_ = self.in_channels_list[-1]
        if self.unet_type == 'stable':
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=embedding_channels,
                cross_attention_dim=cross_attention_dim,
                resnet_act_fn=act_cfg['type'],
                resnet_time_scale_shift='default',
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_cfg['num_groups'],
            )
        else:
            self.mid_blocks = build_mid_blocks_resattn(self.resblock_cfg,
                                                       self.attention_cfg,
                                                       in_channels_)

        # stable up parameters
        self.num_upsamplers = 0
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channel = reversed_block_out_channels[0]

        # construct the decoder part of Unet
        in_channels_list = deepcopy(self.in_channels_list)
        if self.unet_type != 'stable':
            self.out_blocks = nn.ModuleList()
        for level, factor in enumerate(self.channel_factor_list[::-1]):
            if self.unet_type == 'stable':
                is_final_block = level == len(block_out_channels) - 1

                prev_output_channel = output_channel
                output_channel = reversed_block_out_channels[level]
                input_channel = reversed_block_out_channels[min(
                    level + 1,
                    len(block_out_channels) - 1)]

                # add upsample block for all BUT final layer
                if not is_final_block:
                    add_upsample = True
                    self.num_upsamplers += 1
                else:
                    add_upsample = False

                up_block_type = up_block_types[level]
                up_block = get_up_block(
                    up_block_type,
                    num_layers=layers_per_block + 1,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    temb_channels=embedding_channels,
                    cross_attention_dim=cross_attention_dim,
                    add_upsample=add_upsample,
                    resnet_act_fn=act_cfg['type'],
                    resnet_groups=norm_cfg['num_groups'],
                    attn_num_head_channels=reversed_attention_head_dim[level],
                )
                self.up_blocks.append(up_block)
                prev_output_channel = output_channel
            else:
                out_blocks, in_channels_, scale = build_up_blocks_resattn(
                    resblocks_per_downsample,
                    self.resblock_cfg,
                    in_channels_,
                    in_channels_list,
                    base_channels,
                    factor,
                    scale,
                    attention_scale,
                    self.attention_cfg,
                    self.channel_factor_list,
                    level,
                    embedding_channels,
                    use_scale_shift_norm,
                    dropout,
                    norm_cfg,
                    resblock_updown,
                    self.upsample_cfg,
                )
                self.out_blocks.extend(out_blocks)

        if self.unet_type == 'stable':
            # out
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0],
                num_groups=norm_cfg['num_groups'])
            if digit_version(TORCH_VERSION) > digit_version('1.6.0'):
                self.conv_act = nn.SiLU()
            else:
                mmengine.print_log(
                    '\'SiLU\' is not supported for '
                    f'torch < 1.6.0, found \'{torch.__version__}\'. '
                    'Use ReLU instead but the result may be wrong.')
                self.conv_act = nn.ReLU()
            self.conv_out = nn.Conv2d(
                block_out_channels[0],
                self.out_channels,
                kernel_size=3,
                padding=1)
        else:
            self.out = ConvModule(
                in_channels=in_channels_,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                bias=True,
                order=('norm', 'act', 'conv'))

        if self.unet_type == 'stable':
            self.sample_size = image_size // 8  # NOTE: hard code here

        self.init_weights(pretrained)

    def forward(self,
                x_t,
                t,
                encoder_hidden_states=None,
                label=None,
                return_noise=False):
        """Forward function.

        Args:
            x_t (torch.Tensor): Diffused image at timestep `t` to denoise.
            t (torch.Tensor): Current timestep.
            label (torch.Tensor | callable | None): You can directly give a
                batch of label through a ``torch.Tensor`` or offer a callable
                function to sample a batch of label data. Otherwise, the
                ``None`` indicates to use the default label sampler.
            return_noise (bool, optional): If True, inputted ``x_t`` and ``t``
                will be returned in a dict with output desired by
                ``output_cfg``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``
        """
        # By default samples have to be at least a multiple of
        # the overall upsampling factor.
        # The overall upsampling factor is equal
        # to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size
        # can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not
        # a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in x_t.shape[-2:]):
            logger.info(
                'Forward upsample size to force interpolation output size.')
            forward_upsample_size = True

        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x_t.device)
        elif torch.is_tensor(t) and len(t.shape) == 0:
            t = t[None].to(x_t.device)

        if self.unet_type == 'stable':
            # broadcast to batch dimension in a way that's
            # compatible with ONNX/Core ML
            t = t.expand(x_t.shape[0])
            t_emb = self.time_proj(t)

            # t does not contain any weights and will always return f32
            # tensors, but time_embedding might actually be running in fp16,
            # so we need to cast here.
            # there might be better ways to encapsulate this.
            t_emb = t_emb.to(dtype=self.dtype)
            embedding = self.time_embedding(t_emb)
        else:
            embedding = self.time_embedding(t)

        if label is not None:
            assert hasattr(self, 'label_embedding')
            embedding = self.label_embedding(label) + embedding

        if self.unet_type == 'stable':
            # 2. pre-process
            x_t = self.conv_in(x_t)

            # 3. down
            down_block_res_samples = (x_t, )
            for downsample_block in self.down_blocks:
                if hasattr(downsample_block, 'attentions'
                           ) and downsample_block.attentions is not None:
                    x_t, res_samples = downsample_block(
                        hidden_states=x_t,
                        temb=embedding,
                        encoder_hidden_states=encoder_hidden_states,
                    )
                else:
                    x_t, res_samples = downsample_block(
                        hidden_states=x_t, temb=embedding)

                down_block_res_samples += res_samples

            # 4. mid
            x_t = self.mid_block(
                x_t, embedding, encoder_hidden_states=encoder_hidden_states)

            # 5. up
            for i, upsample_block in enumerate(self.up_blocks):
                is_final_block = i == len(self.up_blocks) - 1

                res_samples = down_block_res_samples[-len(upsample_block.
                                                          resnets):]
                down_block_res_samples = down_block_res_samples[:-len(
                    upsample_block.resnets)]

                # if we have not reached the final block
                # and need to forward the upsample size, we do it here
                if not is_final_block and forward_upsample_size:
                    upsample_size = down_block_res_samples[-1].shape[2:]

                if hasattr(upsample_block, 'attentions'
                           ) and upsample_block.attentions is not None:
                    x_t = upsample_block(
                        hidden_states=x_t,
                        temb=embedding,
                        res_hidden_states_tuple=res_samples,
                        encoder_hidden_states=encoder_hidden_states,
                        upsample_size=upsample_size,
                    )
                else:
                    x_t = upsample_block(
                        hidden_states=x_t,
                        temb=embedding,
                        res_hidden_states_tuple=res_samples,
                        upsample_size=upsample_size)

            # 6. post-process
            x_t = self.conv_norm_out(x_t)
            x_t = self.conv_act(x_t)
            x_t = self.conv_out(x_t)
            outputs = x_t
        else:
            h, hs = x_t, []
            h = h.type(self.dtype)
            # forward downsample blocks
            for block in self.in_blocks:
                h = block(h, embedding)
                hs.append(h)

            # forward middle blocks
            h = self.mid_blocks(h, embedding)

            # forward upsample blocks
            for block in self.out_blocks:
                h = block(torch.cat([h, hs.pop()], dim=1), embedding)
            h = h.type(x_t.dtype)
            outputs = self.out(h)

        return {'sample': outputs}

    def init_weights(self, pretrained=None):
        """Init weights for models.

        We just use the initialization method proposed in the original paper.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # As in Improved-DDPM, we apply zero-initialization to
            # the second conv block in ResBlock (keywords: conv_2),
            # the output layer of the Unet (keywords: 'out' but
            # not 'out_blocks') and
            # the projection layer in the Attention layer (keywords: proj)
            for n, m in self.named_modules():
                if isinstance(m, nn.Conv2d) and ('conv_2' in n or
                                                 ('out' in n
                                                  and 'out_blocks' not in n)):
                    constant_init(m, 0)
                if isinstance(m, nn.Conv1d) and 'proj' in n:
                    constant_init(m, 0)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')

    def convert_to_fp16(self):
        """Convert the precision of the model to float16."""
        self.in_blocks.apply(convert_module_to_f16)
        self.mid_blocks.apply(convert_module_to_f16)
        self.out_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """Convert the precision of the model to float32."""
        self.in_blocks.apply(convert_module_to_f32)
        self.mid_blocks.apply(convert_module_to_f32)
        self.out_blocks.apply(convert_module_to_f32)

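Example (editor's note, not part of the module source): a minimal end-to-end sketch of the default (non-stable) UNet predicting ``eps`` and a learned variance range. The configuration values below are illustrative, not taken from a released mmagic config:

    unet = DenoisingUnet(
        image_size=64,
        base_channels=128,
        resblocks_per_downsample=2,
        attention_res=[16, 8],
        output_cfg=dict(mean='eps', var='learned_range'))
    x_t = torch.randn(1, 3, 64, 64)
    t = torch.randint(0, 1000, (1, ))
    out = unet(x_t, t)['sample']   # [1, 6, 64, 64]: 3 mean + 3 variance channels
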