Shortcuts

Source code for mmagic.models.editors.fastcomposer.fastcomposer_util

# Copyright (c) OpenMMLab. All rights reserved.
import gc
import types
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn
from torch.nn import Linear
from transformers import (CLIPModel, CLIPPreTrainedModel, CLIPTextModel,
                          CLIPVisionConfig, CLIPVisionModel)
from transformers.modeling_outputs import BaseModelOutputWithPooling

from mmagic.utils import try_import

# Resolve the CLIP attention-mask expansion helper once, at import time.
#
# The helper is looked up as an *attribute* of the CLIP modelling module:
# the previous code bound ``_expand_mask`` to the module object itself (which
# is not callable) and its fallback used a misspelled package name
# ('ransformers...'), so the fallback could never succeed.
# NOTE(review): assumes ``try_import`` returns the imported module, or None
# when the import fails -- confirm against ``mmagic.utils.try_import``.
_clip_modeling = try_import('transformers.models.clip.modeling_clip')

# Two helper names are probed, in the same order the original code intended:
# ``_expand_mask`` first, then ``_prepare_4d_attention_mask``. ``getattr``
# with a default also copes with ``_clip_modeling`` being None, in which case
# ``_expand_mask`` stays None.
_expand_mask = getattr(_clip_modeling, '_expand_mask', None)
if _expand_mask is None:
    _expand_mask = getattr(_clip_modeling, '_prepare_4d_attention_mask', None)
class FastComposerModel(nn.Module):
    """Training wrapper combining Stable Diffusion parts with CLIP encoders.

    The module owns a text encoder, an image encoder, a VAE, a UNet and a
    post-fuse module that merges object (image) embeddings into the text
    embeddings. When ``cfg['object_localization']`` is truthy the UNet is
    patched so that selected cross-attention layers record their attention
    probabilities, which are then used for an extra localization loss.

    Args:
        text_encoder: Text encoder; ``text_encoder.config.hidden_size`` sets
            the embedding width of the post-fuse module.
        image_encoder: Encoder applied to ``batch['object_pixel_values']``.
        vae: Autoencoder used to turn pixels into latents.
        unet: Denoising network.
        cfg (dict): Configuration. The keys read here are
            ``pretrained_model_name_or_path``, ``revision``,
            ``non_ema_revision``, ``object_localization``,
            ``object_localization_weight``, ``localization_layers``,
            ``mask_loss``, ``mask_loss_prob`` and, when object localization
            is enabled, ``object_localization_threshold`` and
            ``object_localization_normalize``.
    """

    # Config entries that are copied verbatim onto the instance.
    _CFG_ATTRIBUTES = (
        'pretrained_model_name_or_path',
        'revision',
        'non_ema_revision',
        'object_localization',
        'object_localization_weight',
        'localization_layers',
        'mask_loss',
        'mask_loss_prob',
    )

    def __init__(self, text_encoder, image_encoder, vae, unet, cfg):
        super().__init__()

        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.vae = vae
        self.unet = unet

        self.use_ema = False
        self.ema_param = None

        for attr_name in self._CFG_ATTRIBUTES:
            setattr(self, attr_name, cfg[attr_name])

        self.postfuse_module = FastComposerPostfuseModule(
            text_encoder.config.hidden_size)

        if self.object_localization:
            # The patched UNet writes into this dict on every forward pass;
            # the same dict object must be kept alive and emptied in place.
            self.cross_attention_scores = {}
            self.unet = unet_store_cross_attention_scores(
                self.unet,
                self.cross_attention_scores,
                self.localization_layers,
            )
            self.object_localization_loss_fn = BalancedL1Loss(
                cfg['object_localization_threshold'],
                cfg['object_localization_normalize'],
            )

    def _clear_cross_attention_scores(self):
        """Empty the recorded cross attention scores and run the GC."""
        if hasattr(self, 'cross_attention_scores'):
            # Clear in place: the UNet hooks hold a reference to this dict.
            self.cross_attention_scores.clear()

        gc.collect()

    @staticmethod
    def from_pretrained(cfg, vae, unet):
        """Build the text/image encoders and wrap everything in a model.

        Args:
            cfg (dict): Configuration. ``cfg['image_encoder']`` is either a
                dict (turned into a ``CLIPVisionConfig`` for a freshly
                constructed vision model) or a value handed to
                ``FastComposerCLIPImageEncoder.from_pretrained``.
            vae: Autoencoder passed through to the model.
            unet: Denoising network passed through to the model.

        Returns:
            FastComposerModel: The assembled model.
        """
        text_encoder = FastComposerTextEncoder.from_pretrained(
            cfg['pretrained_model_name_or_path'],
            subfolder='text_encoder',
            revision=cfg['revision'],
        )

        encoder_spec = cfg['image_encoder']
        if isinstance(encoder_spec, dict):
            # Build the CLIP image tower from a config dict instead of
            # loading pretrained weights.
            vision_model = CLIPVisionModel(
                CLIPVisionConfig.from_dict(encoder_spec))
            visual_projection = Linear(
                in_features=1024, out_features=768, bias=False)
            vision_processor = T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            )
            image_encoder = FastComposerCLIPImageEncoder(
                vision_model,
                visual_projection,
                vision_processor,
            )
        else:
            image_encoder = FastComposerCLIPImageEncoder.from_pretrained(
                encoder_spec)

        return FastComposerModel(text_encoder, image_encoder, vae, unet, cfg)

    def forward(self, batch, noise_scheduler):
        """Run one training step and return the losses.

        Args:
            batch (dict): Training batch. Always read: ``pixel_values``,
                ``input_ids``, ``image_token_mask``, ``object_pixel_values``
                and ``num_objects``. ``object_segmaps`` is read when the mask
                loss is applied or object localization is enabled;
                ``image_token_idx`` and ``image_token_idx_mask`` are read
                when object localization is enabled.
            noise_scheduler: Scheduler providing ``num_train_timesteps``,
                ``add_noise``, ``get_velocity`` and
                ``config.prediction_type``.

        Returns:
            dict: ``denoise_loss`` and ``loss``, plus ``localization_loss``
            when object localization is enabled.

        Raises:
            ValueError: If ``noise_scheduler.config.prediction_type`` is
                neither ``'epsilon'`` nor ``'v_prediction'``.
        """
        images = batch['pixel_values']
        input_ids = batch['input_ids']
        image_token_mask = batch['image_token_mask']
        object_pixel_values = batch['object_pixel_values']
        num_objects = batch['num_objects']

        # Encode the images with the VAE in the VAE's own parameter dtype.
        vae_dtype = next(self.vae.parameters()).dtype
        latents = self.vae.encode(images.to(vae_dtype)).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        batch_size = latents.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            noise_scheduler.num_train_timesteps, (batch_size, ),
            device=latents.device).long()

        # Add noise to the latents according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # (bsz, max_num_objects, num_image_tokens, dim)
        object_embeds = self.image_encoder(object_pixel_values)

        # (bsz, seq_len, dim)
        text_hidden_states = self.text_encoder(input_ids, image_token_mask,
                                               object_embeds, num_objects)[0]
        encoder_hidden_states = self.postfuse_module(
            text_hidden_states,
            object_embeds,
            image_token_mask,
            num_objects,
        )

        # Get the target for loss depending on the prediction type
        prediction_type = noise_scheduler.config.prediction_type
        if prediction_type == 'epsilon':
            target = noise
        elif prediction_type == 'v_prediction':
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError('Unknown prediction type '
                             f'{noise_scheduler.config.prediction_type}')

        pred = self.unet(noisy_latents, timesteps,
                         encoder_hidden_states).sample

        # With probability ``mask_loss_prob`` restrict the denoising loss to
        # the union of the object segmentation maps. The random draw only
        # happens when ``mask_loss`` is enabled.
        if self.mask_loss and torch.rand(1) < self.mask_loss_prob:
            segmaps = batch['object_segmaps']
            foreground = (segmaps.sum(dim=1) > 0).float()
            foreground = F.interpolate(
                foreground.unsqueeze(1),
                size=(pred.shape[-2], pred.shape[-1]),
                mode='bilinear',
                align_corners=False,
            )
            pred = pred * foreground
            target = target * foreground

        denoise_loss = F.mse_loss(
            pred.float(), target.float(), reduction='mean')
        return_dict = {'denoise_loss': denoise_loss}

        if not self.object_localization:
            return_dict['loss'] = denoise_loss
            return return_dict

        localization_loss = get_object_localization_loss(
            self.cross_attention_scores,
            batch['object_segmaps'],
            batch['image_token_idx'],
            batch['image_token_idx_mask'],
            self.object_localization_loss_fn,
        )
        return_dict['localization_loss'] = localization_loss

        # Accumulate in place so the dtype of the weighted localization term
        # is retained for the total.
        loss = self.object_localization_weight * localization_loss
        loss += denoise_loss
        self._clear_cross_attention_scores()

        return_dict['loss'] = loss
        return return_dict
class FastComposerTextEncoder(CLIPPreTrainedModel):
    """Text encoder of FastComposerModel, built from a CLIP text model."""

    @staticmethod
    def from_pretrained(model_name_or_path, **kwargs):
        """Load a ``CLIPTextModel`` and wrap its inner text model.

        Args:
            model_name_or_path: Forwarded to
                ``CLIPTextModel.from_pretrained``.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            FastComposerTextEncoder: Encoder sharing the loaded sub-modules.
        """
        clip_text_model = CLIPTextModel.from_pretrained(
            model_name_or_path, **kwargs)
        return FastComposerTextEncoder(clip_text_model.text_model)

    def __init__(self, text_model):
        super().__init__(text_model.config)
        self.config = text_model.config

        # Re-use the sub-modules of the given text model directly.
        self.final_layer_norm = text_model.final_layer_norm
        self.embeddings = text_model.embeddings
        self.encoder = text_model.encoder

        self._build_causal_attention_mask = build_causal_attention_mask

    def forward(
        self,
        input_ids,
        image_token_mask=None,
        object_embeds=None,
        num_objects=None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode token ids into hidden states and a pooled embedding.

        Args:
            input_ids (torch.Tensor): Token ids; must be two-dimensional
                because its size is unpacked into ``(bsz, seq_len)``.
            image_token_mask: Accepted for interface compatibility; not used
                in this method.
            object_embeds: Accepted for interface compatibility; not used in
                this method.
            num_objects: Accepted for interface compatibility; not used in
                this method.
            attention_mask (torch.Tensor, optional): Mask that is expanded
                with ``_expand_mask`` before being given to the encoder.
            output_attentions (bool, optional): Falls back to
                ``config.output_attentions`` when None.
            output_hidden_states (bool, optional): Falls back to
                ``config.output_hidden_states`` when None.
            return_dict (bool, optional): Falls back to
                ``config.use_return_dict`` when None.

        Returns:
            Union[Tuple, BaseModelOutputWithPooling]: A tuple starting with
            ``(last_hidden_state, pooled_output)`` when ``return_dict`` is
            falsy, otherwise a ``BaseModelOutputWithPooling``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids)

        bsz, seq_len = input_shape
        causal_attention_mask = self._build_causal_attention_mask(
            bsz, seq_len, hidden_states.dtype)
        causal_attention_mask = causal_attention_mask.to(hidden_states.device)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility:
        # argmax doesn't support int64 inputs with opset 14
        target_device = last_hidden_state.device
        row_index = torch.arange(
            last_hidden_state.shape[0], device=target_device)
        eot_index = input_ids.to(
            dtype=torch.int, device=target_device).argmax(dim=-1)
        pooled_output = last_hidden_state[row_index, eot_index]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class FastComposerCLIPImageEncoder(CLIPPreTrainedModel):
    """Image encoder of FastComposerModel, built from a CLIP vision tower."""

    @staticmethod
    def from_pretrained(global_model_name_or_path):
        """Load a ``CLIPModel`` and wrap its vision tower and projection.

        Args:
            global_model_name_or_path: Forwarded to
                ``CLIPModel.from_pretrained``.

        Returns:
            FastComposerCLIPImageEncoder: Encoder re-using the loaded vision
            model and visual projection, with a fixed normalisation step.
        """
        clip = CLIPModel.from_pretrained(global_model_name_or_path)
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        return FastComposerCLIPImageEncoder(
            clip.vision_model,
            clip.visual_projection,
            T.Normalize(channel_mean, channel_std),
        )

    def __init__(
        self,
        vision_model,
        visual_projection,
        vision_processor,
    ):
        super().__init__(vision_model.config)
        self.vision_model = vision_model
        self.visual_projection = visual_projection
        self.vision_processor = vision_processor

        # Side length the inputs are resized to before encoding.
        self.image_size = vision_model.config.image_size

    def forward(self, object_pixel_values):
        """Encode object crops into one projected embedding per object.

        Args:
            object_pixel_values (torch.Tensor): Five-dimensional tensor laid
                out as ``(b, num_objects, c, h, w)``.

        Returns:
            torch.Tensor: Embeddings reshaped to ``(b, num_objects, 1, -1)``.
        """
        b, num_objects, c, h, w = object_pixel_values.shape

        # Fold the object axis into the batch axis for the vision model.
        flat_pixels = object_pixel_values.view(b * num_objects, c, h, w)

        side = self.image_size
        if h != side or w != side:
            flat_pixels = F.interpolate(
                flat_pixels, (side, side), mode='bilinear')

        flat_pixels = self.vision_processor(flat_pixels)

        # Index 1 of the vision model output is projected and returned.
        pooled = self.vision_model(flat_pixels)[1]
        projected = self.visual_projection(pooled)
        return projected.view(b, num_objects, 1, -1)
def get_object_transforms(cfg):
    """Assemble the transform pipeline applied to object images.

    The pipeline always pads to a square, resizes to
    ``cfg['object_resolution']`` and converts to ``torch.float32``. Unless
    ``cfg['no_object_augmentation']`` is truthy, a random zoom-in is placed
    before the padding and a set of random augmentations after the resize.

    Args:
        cfg (dict): Must provide ``no_object_augmentation`` and
            ``object_resolution``.

    Returns:
        torch.nn.Sequential: Named transform steps in application order.
    """
    if cfg['no_object_augmentation']:
        leading_steps = []
        trailing_steps = []
    else:
        zoom_in = T.RandomApply(
            [RandomZoomIn(min_zoom=1.0, max_zoom=2.0)], p=0.5)
        leading_steps = [('zoomin', zoom_in)]

        rotate = T.RandomApply(
            [
                T.RandomAffine(
                    degrees=30, interpolation=T.InterpolationMode.BILINEAR)
            ],
            p=0.75,
        )
        jitter = T.RandomApply([T.ColorJitter(0.5, 0.5, 0.5, 0.5)], p=0.5)
        blur = T.RandomApply([T.GaussianBlur(5, sigma=(0.1, 2.0))], p=0.5)
        elastic = T.RandomApply([T.ElasticTransform()], p=0.5)
        trailing_steps = [
            ('rotate', rotate),
            ('jitter', jitter),
            ('blur', blur),
            ('gray', T.RandomGrayscale(p=0.1)),
            ('flip', T.RandomHorizontalFlip()),
            ('elastic', elastic),
        ]

    pad_step = ('pad_to_square', PadToSquare(fill=0, padding_mode='constant'))
    resize_step = (
        'resize',
        T.Resize(
            (cfg['object_resolution'], cfg['object_resolution']),
            interpolation=T.InterpolationMode.BILINEAR,
        ),
    )
    to_float_step = ('convert_to_float', T.ConvertImageDtype(torch.float32))

    ordered_steps = (
        leading_steps + [pad_step, resize_step] + trailing_steps +
        [to_float_step])
    return torch.nn.Sequential(OrderedDict(ordered_steps))
class FastComposerPostfuseModule(nn.Module):
    """Post-fuse module merging object embeddings into text embeddings.

    Args:
        embed_dim (int): Width of the text and object embeddings.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # mlp1 consumes the concatenation of a text and an object embedding,
        # hence the doubled input width.
        self.mlp1 = MLP(
            embed_dim * 2, embed_dim, embed_dim, use_residual=False)
        self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def fuse_fn(self, text_embeds, object_embeds):
        """Fuse paired text and object embeddings.

        Args:
            text_embeds (torch.Tensor): Text embeddings to be fused.
            object_embeds (torch.Tensor): Object embeddings, concatenated
                with ``text_embeds`` along the last dimension.

        Returns:
            torch.Tensor: ``layer_norm(mlp2(mlp1(concat) + text_embeds))``.
        """
        paired = torch.cat([text_embeds, object_embeds], dim=-1)
        # Explicit skip connection around mlp1 (which has no residual).
        fused = self.mlp1(paired) + text_embeds
        return self.layer_norm(self.mlp2(fused))

    def forward(
        self,
        text_embeds,
        object_embeds,
        image_token_mask,
        num_objects,
    ) -> torch.Tensor:
        """Replace image-token embeddings by their fused counterparts.

        Args:
            text_embeds (torch.Tensor): Text encoder hidden states.
            object_embeds (torch.Tensor): Object embeddings.
            image_token_mask (torch.Tensor): Marks the image-token positions
                in ``text_embeds``.
            num_objects (torch.Tensor): Number of valid objects per sample.

        Returns:
            torch.Tensor: Result of ``fuse_object_embeddings`` using
            :meth:`fuse_fn` as the fusion function.
        """
        return fuse_object_embeddings(
            text_embeds,
            image_token_mask,
            object_embeds,
            num_objects,
            self.fuse_fn,
        )
def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
    """Patch cross-attention modules of a UNet to record their scores.

    Every ``Attention`` module whose name contains ``'attn2'`` and one of the
    selected block names gets its ``get_attention_scores`` replaced by a
    wrapper. The wrapper calls the original method (kept on the module as
    ``old_get_attention_scores``) and stores the result in
    ``attention_scores`` under the module name.

    Args:
        unet: Model whose ``named_modules()`` are scanned and patched.
        attention_scores (dict): Filled in place on later forward passes.
        layers (int): How many entries are taken from the centre of the
            candidate block list. Defaults to 5.

    Returns:
        The same ``unet`` object, patched in place.
    """
    from diffusers.models.attention_processor import (Attention,
                                                      AttnProcessor,
                                                      AttnProcessor2_0)

    candidate_blocks = [
        'down_blocks.0',
        'down_blocks.1',
        'down_blocks.2',
        'mid_block',
        'up_blocks.1',
        'up_blocks.2',
        'up_blocks.3',
    ]

    # Take a window of ``layers`` names centred in the candidate list.
    first = (len(candidate_blocks) - layers) // 2
    selected_blocks = candidate_blocks[first:first + layers]

    def build_recording_fn(layer_name):
        """Create the replacement method for the layer ``layer_name``."""

        def new_get_attention_scores(module, query, key, attention_mask=None):
            """Compute the attention scores and remember them."""
            probs = module.old_get_attention_scores(query, key,
                                                    attention_mask)
            attention_scores[layer_name] = probs
            return probs

        return new_get_attention_scores

    for name, module in unet.named_modules():
        if not (isinstance(module, Attention) and 'attn2' in name):
            continue
        if not any(block in name for block in selected_blocks):
            continue

        # Modules using AttnProcessor2_0 are switched to AttnProcessor
        # before their score method is wrapped.
        if isinstance(module.processor, AttnProcessor2_0):
            module.set_processor(AttnProcessor())

        module.old_get_attention_scores = module.get_attention_scores
        module.get_attention_scores = types.MethodType(
            build_recording_fn(name), module)

    return unet
class BalancedL1Loss(nn.Module):
    """Localization loss: mean attention outside minus inside the object.

    Args:
        threshold (float): Stored on the instance; not used by ``forward``.
        normalize (bool): If True, attention values are divided by their
            maximum along dimension 2 before the loss is computed.
    """

    def __init__(self, threshold=1.0, normalize=False):
        super().__init__()
        self.threshold = threshold
        self.normalize = normalize

    def forward(self, object_token_attn_prob, object_segmaps):
        """Compute the balanced loss by reducing over dimension 2.

        Args:
            object_token_attn_prob (torch.Tensor): Attention probabilities.
            object_segmaps (torch.Tensor): Segmentation maps combined
                elementwise with ``object_token_attn_prob``.

        Returns:
            torch.Tensor: Background term minus object term, with dimension
            2 reduced away.
        """
        eps = 1e-5  # guards the divisions against empty regions
        attn = object_token_attn_prob

        if self.normalize:
            peak = attn.max(dim=2, keepdim=True)[0]
            attn = attn / (peak + eps)

        background_segmaps = 1 - object_segmaps

        # Average attention that falls on the background ...
        background_mass = (attn * background_segmaps).sum(dim=2)
        background_loss = background_mass / (
            background_segmaps.sum(dim=2) + eps)

        # ... and average attention that falls on the object.
        object_mass = (attn * object_segmaps).sum(dim=2)
        object_loss = object_mass / (object_segmaps.sum(dim=2) + eps)

        return background_loss - object_loss
def get_object_localization_loss(
    cross_attention_scores,
    object_segmaps,
    image_token_idx,
    image_token_idx_mask,
    loss_fn,
):
    """Average the object localization loss over all recorded layers.

    Args:
        cross_attention_scores (dict): Attention scores per layer name; only
            the values are used.
        object_segmaps (torch.Tensor): Object segmentation maps.
        image_token_idx (torch.Tensor): Text-token index of each object.
        image_token_idx_mask (torch.Tensor): Marks the valid entries of
            ``image_token_idx``.
        loss_fn (callable): Loss applied to the attention/segmap pairs.

    Returns:
        Sum of the per-layer losses divided by the number of layers.
    """
    total = 0
    for layer_scores in cross_attention_scores.values():
        total += get_object_localization_loss_for_one_layer(
            layer_scores,
            object_segmaps,
            image_token_idx,
            image_token_idx_mask,
            loss_fn,
        )
    return total / len(cross_attention_scores)
def get_object_localization_loss_for_one_layer(
    cross_attention_scores,
    object_segmaps,
    object_token_idx,
    object_token_idx_mask,
    loss_fn,
):
    """Compute the object localization loss of a single attention layer.

    Args:
        cross_attention_scores (torch.Tensor): Scores laid out as
            ``(b * num_heads, num_noise_latents, num_text_tokens)``.
        object_segmaps (torch.Tensor): Maps laid out as
            ``(b, max_num_objects, height, width)``.
        object_token_idx (torch.Tensor): Text-token index of each object,
            viewable as ``(b, 1, 1, max_num_objects)``.
        object_token_idx_mask (torch.Tensor): Validity of each entry of
            ``object_token_idx``, viewable as ``(b, 1, max_num_objects)``.
        loss_fn (callable): Called with the gathered attention and the
            expanded segmentation maps.

    Returns:
        torch.Tensor: Masked loss, normalised per sample by the number of
        valid objects and averaged over the remaining entries.
    """
    bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
    b, max_num_objects, _, _ = object_segmaps.shape
    num_heads = bxh // b

    # The latents are treated as a square grid of side ``grid``.
    grid = int(num_noise_latents**0.5)
    full_shape = (b, num_heads, num_noise_latents, max_num_objects)

    # Resize the object segmentation maps to the size of the cross attention
    # scores, then flatten the spatial grid:
    # (b, max_num_objects, size, size) -> (b, max_num_objects, latents)
    resized_segmaps = F.interpolate(
        object_segmaps, size=(grid, grid), mode='bilinear')
    flat_segmaps = resized_segmaps.view(b, max_num_objects, -1)

    per_head_scores = cross_attention_scores.view(b, num_heads,
                                                  num_noise_latents,
                                                  num_text_tokens)

    # Pick, for every object, the attention column of its text token:
    # (b, num_heads, num_noise_latents, max_num_objects)
    gather_index = object_token_idx.view(b, 1, 1,
                                         max_num_objects).expand(*full_shape)
    object_token_attn_prob = torch.gather(
        per_head_scores, dim=3, index=gather_index)

    # Bring the segmentation maps into the same layout as the attention.
    expanded_segmaps = flat_segmaps.permute(0, 2,
                                            1).unsqueeze(1).expand(*full_shape)

    loss = loss_fn(object_token_attn_prob, expanded_segmaps)

    # Drop padded objects and average over the valid ones of each sample.
    loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
    valid_objects = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
    return (loss.sum(dim=2) / valid_objects).mean()
class RandomZoomIn(nn.Module):
    """Randomly enlarge an image and crop the top square of the result.

    Args:
        min_zoom (float): Lower bound of the zoom factor. Defaults to 1.0.
        max_zoom (float): Upper bound of the zoom factor. Defaults to 1.5.
    """

    def __init__(self, min_zoom=1.0, max_zoom=1.5):
        super().__init__()
        self.min_zoom = min_zoom
        self.max_zoom = max_zoom

    def forward(self, image: torch.Tensor):
        """Resize ``image`` by a random factor, then crop the top square.

        Args:
            image (torch.Tensor): Image whose dimensions 1 and 2 are scaled.

        Returns:
            torch.Tensor: The zoomed image after ``CropTopSquare``.
        """
        # One factor drawn from [min_zoom, max_zoom) for both dimensions.
        zoom = torch.rand(1) * (self.max_zoom - self.min_zoom) + self.min_zoom
        zoomed_height = int(zoom * image.shape[1])
        zoomed_width = int(zoom * image.shape[2])

        zoomed = T.functional.resize(
            image,
            (zoomed_height, zoomed_width),
            interpolation=T.InterpolationMode.BILINEAR,
        )

        # crop top square
        return CropTopSquare()(zoomed)
class PadToSquare(nn.Module):
    """Pad the shorter spatial side of an image so that it becomes square.

    A tall image (height > width) is padded on the left and right, a wide
    image (width > height) on the top and bottom. When the size difference is
    odd the extra pixel goes to the trailing side (right / bottom), so the
    result is always exactly square.

    Args:
        fill (int | float): Value passed to ``torch.nn.functional.pad``.
            Defaults to 0.
        padding_mode (str): Mode passed to ``torch.nn.functional.pad``.
            Defaults to ``'constant'``.
    """

    def __init__(self, fill=0, padding_mode='constant'):
        super().__init__()
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, image: torch.Tensor):
        """Pad ``image`` to a square.

        Args:
            image (torch.Tensor): Three-dimensional image; the last two
                dimensions are read as height and width.

        Returns:
            torch.Tensor: The input itself when it is already square,
            otherwise a padded copy whose last two dimensions are equal.
        """
        _, h, w = image.shape
        if h == w:
            return image

        # Split the missing pixels over both sides. Using ``gap // 2`` on
        # both sides would leave odd gaps one pixel short of a square.
        gap = abs(h - w)
        leading = gap // 2
        trailing = gap - leading

        if h > w:
            # Pad order is (left, right, top, bottom): widen the image.
            pad = (leading, trailing, 0, 0)
        else:
            # Heighten the image.
            pad = (0, 0, leading, trailing)

        return torch.nn.functional.pad(
            image,
            pad,
            self.padding_mode,
            self.fill,
        )
class CropTopSquare(nn.Module):
    """Crop a tall image to a square, keeping its topmost rows.

    Images whose height does not exceed their width are returned unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, image: torch.Tensor):
        """Crop ``image`` to its top square when it is taller than wide.

        Args:
            image (torch.Tensor): Three-dimensional image; the last two
                dimensions are read as height and width.

        Returns:
            torch.Tensor: ``image[:, :w, :]`` when ``h > w``, otherwise the
            input itself.
        """
        _, height, width = image.shape
        is_tall = height > width
        return image[:, :width, :] if is_tall else image
class MLP(nn.Module):
    """Layer-normalised two-layer perceptron with an optional residual.

    Args:
        in_dim (int): Input width (also the width of the layer norm).
        out_dim (int): Output width.
        hidden_dim (int): Width of the hidden layer.
        use_residual (bool): If True the input is added to the output, which
            requires ``in_dim == out_dim``. Defaults to True.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
        super().__init__()
        if use_residual:
            # The skip connection adds input and output elementwise.
            assert in_dim == out_dim

        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        """Apply layer norm, linear, GELU and linear, plus optional skip.

        Args:
            x (torch.Tensor): Input whose last dimension is ``in_dim``.

        Returns:
            torch.Tensor: Output whose last dimension is ``out_dim``.
        """
        hidden = self.act_fn(self.fc1(self.layernorm(x)))
        out = self.fc2(hidden)
        if not self.use_residual:
            return out
        return out + x
def fuse_object_embeddings(
    inputs_embeds,
    image_token_mask,
    object_embeds,
    num_objects,
    fuse_fn=torch.add,
):
    """Fuse object embeddings into the image-token positions of a sequence.

    The embeddings at the positions marked by ``image_token_mask`` are
    combined with the valid object embeddings through ``fuse_fn`` and written
    back. The write happens in place on a view of ``inputs_embeds``.

    Args:
        inputs_embeds (torch.Tensor): Token embeddings laid out as
            ``(batch, seq_len, dim)``.
        image_token_mask (torch.Tensor): Boolean mask over the flattened
            ``batch * seq_len`` positions.
        object_embeds (torch.Tensor): Object embeddings whose first two
            dimensions are ``(batch, max_num_objects)``.
        num_objects (torch.Tensor): Number of valid objects per sample; slots
            at or beyond this count are ignored.
        fuse_fn (callable): Combines the image-token embeddings with the
            valid object embeddings. Defaults to ``torch.add``.

    Returns:
        torch.Tensor: The updated embeddings viewed as
        ``(batch, seq_len, -1)``.
    """
    object_embeds = object_embeds.to(inputs_embeds.dtype)
    batch_size, max_num_objects = object_embeds.shape[:2]
    seq_length = inputs_embeds.shape[1]

    # One row per (sample, object slot).
    per_slot_embeds = object_embeds.view(-1, object_embeds.shape[-2],
                                         object_embeds.shape[-1])

    # A slot is valid when its index is below the sample's object count.
    slot_index = torch.arange(max_num_objects, device=per_slot_embeds.device)
    slot_is_valid = slot_index[None, :] < num_objects[:, None]
    kept_embeds = per_slot_embeds[slot_is_valid.flatten()]

    token_embeds = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    token_is_image = image_token_mask.view(-1)
    kept_embeds = kept_embeds.view(-1, kept_embeds.shape[-1])

    # slice out the image token embeddings
    image_token_embeds = token_embeds[token_is_image]
    fused_embeds = fuse_fn(image_token_embeds, kept_embeds)

    # Scatter the fused rows back into the image-token positions.
    token_embeds.masked_scatter_(token_is_image[:, None], fused_embeds)
    return token_embeds.view(batch_size, seq_length, -1)
def build_causal_attention_mask(bsz, seq_len, dtype, device=None):
    """Build an additive causal attention mask.

    The function originally belonged to CLIPTextTransformer, but it has been
    removed in versions of transformers after 4.25.1.

    Args:
        bsz (int): Batch size.
        seq_len (int): Sequence length.
        dtype (torch.dtype): Floating point dtype of the mask.
        device (optional): Device the mask is created on.

    Returns:
        torch.Tensor: Mask of shape ``(bsz, 1, seq_len, seq_len)`` that holds
        the fill value strictly above the diagonal and zero elsewhere.
    """
    # pytorch uses additive attention mask; blocked positions receive the
    # most negative value, wrapped in a tensor before filling.
    blocked_value = torch.tensor(torch.finfo(dtype).min)

    causal = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
    causal.fill_(blocked_value)

    # Keep only the strict upper triangle (future positions); the diagonal
    # and everything below it become zero.
    causal.triu_(1)

    # Insert the dimension at index 1, between batch and sequence axes.
    return causal.unsqueeze(1)