mmagic.models.utils.tome_utils 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Any, Callable, Dict, Tuple, Type
import torch
[文档]def add_tome_cfg_hook(model: torch.nn.Module):
"""Add a forward pre hook to get the image size. This hook can be removed
with remove_patch.
Source: https://github.com/dbolya/tomesd/blob/main/tomesd/patch.py#L158 # noqa
"""
def hook(module, args):
module._tome_info['size'] = (args[0].shape[2], args[0].shape[3])
return None
model._tome_info['hooks'].append(model.register_forward_pre_hook(hook))
[文档]def build_mmagic_wrapper_tomesd_block(block_class: Type[torch.nn.Module]
) -> Type[torch.nn.Module]:
"""Make a patched class for a DiffusersWrapper model in mmagic. This patch
applies ToMe to the forward function of the block.
Refer to: https://github.com/dbolya/tomesd/blob/main/tomesd/patch.py#L67 # noqa
Args:
block_class (torch.nn.Module): original class need tome speedup.
Returns:
ToMeBlock (torch.nn.Module): patched class based on the original class.
"""
class ToMeBlock(block_class):
# Save for unpatching later
_parent = block_class
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None,
cross_attention_kwargs=None,
class_labels=None,
):
# -> (1) ToMeBlock
m_a, m_c, m_m, u_a, u_c, u_m = build_merge(hidden_states,
self._tome_info)
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa,\
shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states,
timestep,
class_labels,
hidden_dtype=hidden_states.dtype)
else:
norm_hidden_states = self.norm1(hidden_states)
# -> (2) ToMe m_a
norm_hidden_states = m_a(norm_hidden_states)
# 1. Self-Attention
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs
else:
cross_attention_kwargs = {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states
if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
# -> (3) ToMe u_a
hidden_states = u_a(attn_output) + hidden_states
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm else self.norm2(hidden_states))
# -> (4) ToMe m_c
norm_hidden_states = m_c(norm_hidden_states)
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
# -> (5) ToMe u_c
hidden_states = u_c(attn_output) + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (
1 + scale_mlp[:, None]) + shift_mlp[:, None]
# -> (6) ToMe m_m
norm_hidden_states = m_m(norm_hidden_states)
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
# -> (7) ToMe u_m
hidden_states = u_m(ff_output) + hidden_states
return hidden_states
return ToMeBlock
[文档]def build_mmagic_tomesd_block(block_class: Type[torch.nn.Module]
) -> Type[torch.nn.Module]:
"""Make a patched class for a mmagic StableDiffusion model. This patch
applies ToMe to the forward function of the block.
Refer to: https://github.com/dbolya/tomesd/blob/main/tomesd/patch.py#L67 # noqa
Args:
block_class (torch.nn.Module): original class need tome speedup.
Returns:
ToMeBlock (torch.nn.Module): patched class based on the original class.
"""
class ToMeBlock(block_class):
# Save for unpatching later
_parent = block_class
def forward(self, hidden_states, context=None, timestep=None):
# ->(1) ToMeBlock
m_a, m_c, m_m, u_a, u_c, u_m = build_merge(hidden_states,
self._tome_info)
# 1. Self-Attention
# ->(2) ToMe m_a
norm_hidden_states = (m_a(self.norm1(hidden_states)))
# ->(3) ToMe u_a
if self.only_cross_attention:
hidden_states = u_a(self.attn1(norm_hidden_states,
context)) + hidden_states
else:
hidden_states = u_a(
self.attn1(norm_hidden_states)) + hidden_states
# 2. Cross-Attention
# ->(4) ToMe m_c
norm_hidden_states = (m_c(self.norm2(hidden_states)))
# ->(5) ToMe u_c
hidden_states = u_c(
self.attn2(norm_hidden_states,
context=context)) + hidden_states
# 3. Feed-forward
# ->(6) ToMe m_m, u_m
hidden_states = u_m(self.ff(m_m(
self.norm3(hidden_states)))) + hidden_states
return hidden_states
return ToMeBlock
[文档]def isinstance_str(x: object, cls_name: str):
"""Checks whether `x` has any class *named* `cls_name` in its ancestry.
Doesn't require access to the class's implementation.
Source: https://github.com/dbolya/tomesd/blob/main/tomesd/utils.py#L3 # noqa
"""
for _cls in x.__class__.__mro__:
if _cls.__name__ == cls_name:
return True
return False
[文档]def do_nothing(x: torch.Tensor, mode: str = None):
"""Build identical mapping function.
Source: https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py#L5 # noqa
"""
return x
[文档]def mps_gather_workaround(input, dim, index):
"""Gather function specific for `mps` backend (Metal Performance Shaders).
Source: https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py#L9 # noqa
"""
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1), dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)).squeeze(-1)
else:
return torch.gather(input, dim, index)
[文档]def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int,
h: int,
sx: int,
sy: int,
r: int,
no_rand: bool = False
) -> Tuple[Callable, Callable]:
"""Partitions the tokens into src and dst and merges r tokens from src to
dst, dst tokens are partitioned by choosing one randomy in each (`sx`,
`sy`) region. More details refer to `Token Merging: Your ViT But Faster`,
paper link: <https://arxiv.org/abs/2210.09461>`_ # noqa.
Source: https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py#20 # noqa
Args:
metric (torch.Tensor): metric with size (B, N, C) for similarity computation.
w (int): image width in tokens.
h (int): image height in tokens.
sx (int): stride in the x dimension for dst, must divide `w`.
sy (int): stride in the y dimension for dst, must divide `h`.
r (int): number of tokens to remove (by merging).
no_rand (bool): if true, disable randomness (use top left corner only).
Returns:
merge (Callable): token merging function.
unmerge (Callable): token unmerging function.
"""
B, N, _ = metric.shape
if r <= 0:
return do_nothing, do_nothing
if metric.device.type == 'mps':
gather = mps_gather_workaround
else:
gather = torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to
# be dst and the rest src
if no_rand:
rand_idx = torch.zeros(
hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(
sy * sx, size=(hsy, wsx, 1), device=metric.device)
# The image might not divide sx and sy, so we need to work
# on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(
hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(
dim=2,
index=rand_idx,
src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(
1, 2).reshape(hsy * sy, wsx * sx)
# Image is not divisible by sx or sy so we need to move it
# into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(
h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort
# gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
def split(x):
C = x.shape[-1]
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)
# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
if not hasattr(torch.Tensor,
'scatter_reduce') or torch.__version__ < '1.12.1':
raise ImportError(
'Please upgrade torch >= 1.12.1 to enable \'scatter_reduce\'')
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1)
def unmerge(x: torch.Tensor) -> torch.Tensor:
unm_len = unm_idx.shape[1]
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(
dim=-2,
index=gather(
a_idx.expand(B, a_idx.shape[1], 1), dim=1,
index=unm_idx).expand(B, unm_len, c),
src=unm)
out.scatter_(
dim=-2,
index=gather(
a_idx.expand(B, a_idx.shape[1], 1), dim=1,
index=src_idx).expand(B, r, c),
src=src)
return out
return merge, unmerge
[文档]def build_merge(x: torch.Tensor, tome_info: Dict[str,
Any]) -> Tuple[Callable, ...]:
"""Build the merge and unmerge functions for a given setting from
`tome_info`.
Source: https://github.com/dbolya/tomesd/blob/main/tomesd/patch.py#L10 # noqa
"""
original_h, original_w = tome_info['size']
original_tokens = original_h * original_w
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
args = tome_info['args']
if downsample <= args['max_downsample']:
w = int(math.ceil(original_w / downsample))
h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * args['ratio'])
# If the batch size is odd, then it's not possible for promted and
# unprompted images to be in the same batch, which causes artifacts
# with use_rand, so force it to be off.
use_rand = False if x.shape[0] % 2 == 1 else args['use_rand']
m, u = bipartite_soft_matching_random2d(x, w, h, args['sx'],
args['sy'], r, not use_rand)
else:
m, u = (do_nothing, do_nothing)
m_a, u_a = (m, u) if args['merge_attn'] else (do_nothing, do_nothing)
m_c, u_c = (m, u) if args['merge_crossattn'] else (do_nothing, do_nothing)
m_m, u_m = (m, u) if args['merge_mlp'] else (do_nothing, do_nothing)
return m_a, m_c, m_m, u_a, u_c, u_m