Shortcuts

Source code for mmagic.models.editors.vico.vico_utils

# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import Transformer2DModel
from diffusers.models.attention import Attention, BasicTransformerBlock
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import BaseOutput, is_torch_version


class ViCoCrossAttnProcessor:
    """Attention processor for ViCo that also records the attention map.

    The processor runs the usual query/key/value attention and, as a side
    effect, stores the attention probabilities on the attention module as
    ``attn.attn_probs`` so that they can be read back after the call.
    """

    def __call__(self,
                 attn: Attention,
                 hidden_states,
                 encoder_hidden_states=None,
                 attention_mask=None):
        """Compute attention and keep the probabilities on ``attn``.

        Args:
            attn (Attention): Attention module providing the projections.
            hidden_states (torch.Tensor): Input hidden states.
            encoder_hidden_states (torch.Tensor): Encoder hidden states.
                When ``None`` the layer attends to ``hidden_states``.
            attention_mask (torch.Tensor): Attention mask.

        Returns:
            torch.Tensor: Output hidden states.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size=batch_size)

        # Fall back to self-attention when no context is supplied.
        context = hidden_states
        if encoder_hidden_states is not None:
            context = encoder_hidden_states

        projected = (
            attn.to_q(hidden_states),
            attn.to_k(context),
            attn.to_v(context),
        )
        query, key, value = (attn.head_to_batch_dim(t) for t in projected)

        probs = attn.get_attention_scores(query, key, attention_mask)
        # Bookkeeping: expose the attention probabilities to the caller.
        attn.attn_probs = probs

        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        linear_proj, dropout = attn.to_out[0], attn.to_out[1]
        return dropout(linear_proj(attended))
def replace_cross_attention(unet):
    """Install a ViCoCrossAttnProcessor on every ``attn2`` module.

    Args:
        unet (nn.Module): Model whose sub-modules are scanned. Every module
            whose qualified name ends with ``'attn2'`` receives its own
            processor instance through ``set_processor``.
    """
    cross_attention_layers = (layer
                              for layer_name, layer in unet.named_modules()
                              if layer_name.endswith('attn2'))
    for layer in cross_attention_layers:
        layer.set_processor(ViCoCrossAttnProcessor())
@dataclass
class ViCoTransformer2DModelOutput(BaseOutput):
    """Output for ViCoTransformer2DModel.

    Returned by ``ViCoTransformer2D.forward`` when ``return_dict`` is true;
    with ``return_dict=False`` the same two values are returned as a plain
    ``(sample, loss_reg)`` tuple instead.
    """

    # Transformed hidden states produced by the forward pass.
    sample: torch.FloatTensor
    # Attention regularization loss. ``ViCoTransformer2D.forward`` leaves
    # this as ``None`` unless the block has image cross-attention and that
    # module is in training mode.
    loss_reg: torch.FloatTensor
def otsu(mask_in):
    """Binarize each row of a mask with an Otsu-style threshold search.

    Every row is min-max normalized to ``[0, 1]``. The candidate thresholds
    ``0.0, 0.1, ..., 0.9`` are then scored by the between-class variance of
    the resulting split, and the row is binarized with the first threshold
    that reaches the highest score: values below it become 0, values at or
    above it become 1.

    Args:
        mask_in (torch.Tensor): Input mask; rows are indexed along dim 0 and
            normalized along the last dim. The input is not modified.

    Returns:
        torch.Tensor: Mask with the same shape as ``mask_in`` containing
        only 0 and 1. Rows for which no candidate threshold separates the
        values into two non-empty groups (for example constant rows) are
        returned as all zeros.
    """
    row_min = mask_in.min(-1, keepdim=True)[0]
    row_max = mask_in.max(-1, keepdim=True)[0]
    value_range = row_max - row_min
    # A constant row has zero range; dividing by it would fill the row with
    # NaN, so divide by one instead and let the row fall through to the
    # "no valid split" case below.
    safe_range = torch.where(value_range > 0, value_range,
                             torch.ones_like(value_range))
    mask_norm = (mask_in - row_min) / safe_range

    num_elements = mask_in.shape[1]
    candidate_thresholds = [t * 0.1 for t in range(10)]

    masks = []
    for row in mask_norm:
        best_threshold = None
        best_variance = 0.
        for threshold in candidate_thresholds:
            below = row[row < threshold]
            above = row[row >= threshold]
            # An empty class has an undefined mean, so the split carries no
            # information and is skipped.
            if below.numel() == 0 or above.numel() == 0:
                continue
            weight_below = below.shape[0] / num_elements
            weight_above = above.shape[0] / num_elements
            variance = weight_below * weight_above * (
                (below.mean() - above.mean())**2)
            if variance > best_variance:
                best_variance = variance
                best_threshold = threshold

        if best_threshold is None:
            masks.append(torch.zeros_like(row))
        else:
            # ">=" mirrors the class definition used during the search, so a
            # value equal to the threshold is foreground rather than being
            # left at its normalized value.
            masks.append((row >= best_threshold).to(row.dtype))

    return torch.stack(masks, dim=0)
class ViCoTransformer2D(nn.Module):
    """New ViCo-Transformer2D to replace the original Transformer2D model.

    The module reuses the sub-modules of an existing ``Transformer2DModel``
    (they are referenced, not copied) and optionally appends one extra
    ``BasicTransformerBlock`` in which the first half of the batch
    cross-attends to the second half of the batch.
    """

    def __init__(self, org_transformer2d: Transformer2DModel,
                 have_image_cross) -> None:
        """
        Args:
            org_transformer2d (Transformer2DModel): Original
                Transformer2DModel whose layers and flags are reused.
            have_image_cross (bool): Flag indicating if the model has
                image_cross_attention modules.
        """
        super().__init__()
        # Share the layers and configuration flags of the original model.
        self.transformer_blocks = org_transformer2d.transformer_blocks
        self.is_input_continuous = org_transformer2d.is_input_continuous
        self.norm = org_transformer2d.norm
        self.use_linear_projection = org_transformer2d.use_linear_projection
        self.proj_in = org_transformer2d.proj_in
        self.proj_out = org_transformer2d.proj_out
        self.is_input_vectorized = org_transformer2d.is_input_vectorized
        self.is_input_patches = org_transformer2d.is_input_patches

        num_attention_heads = org_transformer2d.num_attention_heads
        attention_head_dim = org_transformer2d.attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.have_image_cross = have_image_cross
        if self.have_image_cross:
            # Extra block whose cross-attention context has the same width
            # as the hidden states (cross_attention_dim=inner_dim); it is
            # moved to the device and dtype of the original model.
            image_cross_attention = BasicTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                cross_attention_dim=inner_dim)
            self.image_cross_attention = image_cross_attention.to(
                org_transformer2d.device, dtype=org_transformer2d.dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        placeholder_position: list = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Run the transformer blocks and the optional image cross-attention.

        Args:
            hidden_states (torch.Tensor): Input hidden states.
            encoder_hidden_states (Optional[torch.Tensor]): Encoder hidden
                states passed to every transformer block.
            timestep (Optional[torch.LongTensor]): Forwarded to the blocks.
            placeholder_position (list): Two index objects; element 0 is
                used to pick the placeholder-token attention and element 1
                the end-of-text attention. Only read when
                ``have_image_cross`` is set.
            class_labels (Optional[torch.LongTensor]): Forwarded to the
                blocks.
            cross_attention_kwargs (Dict[str, Any]): Forwarded to the
                blocks.
            attention_mask (Optional[torch.Tensor]): Attention mask.
            encoder_attention_mask (Optional[torch.Tensor]): Encoder
                attention mask.
            return_dict (bool): Return a ViCoTransformer2DModelOutput when
                true, otherwise a ``(output, loss_reg)`` tuple.

        Returns:
            ViCoTransformer2DModelOutput or tuple: The output sample and the
            attention regularization loss (``None`` unless image
            cross-attention is enabled and in training mode).
        """
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (
                1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # Same keep/discard -> additive-bias conversion for the encoder mask.
        if encoder_attention_mask is not None and (encoder_attention_mask.ndim
                                                   == 2):
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            # The projection is applied either before or after flattening
            # the spatial dims, depending on use_linear_projection.
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
                    batch, height * width, inner_dim)
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
                    batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        # NOTE(review): the two branches below read attributes
        # (latent_image_embedding, pos_embed) that __init__ does not set,
        # and they leave `batch` unbound although it is used further down.
        # They look unreachable in practice - confirm before relying on
        # vectorized or patch inputs.
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )
            # Cross-attention probabilities saved on attn2 by the attention
            # processor; only the second half of the batch is kept.
            attention_probs = block.attn2.attn_probs[batch // 2:batch, ...]

        loss_reg = None
        if self.have_image_cross:
            # 2.5. image cross attention
            ph_idx, eot_idx = placeholder_position[
                0], placeholder_position[1]
            attn = attention_probs.transpose(1, 2)
            attn_ph = attn[ph_idx].squeeze(1)  # bs, n_patch
            # The end-of-text attention is detached, so it acts as a fixed
            # target for the regularization below.
            attn_eot = attn[eot_idx].squeeze(1).detach()
            # ########################
            # attention reg
            if self.image_cross_attention.training:
                # MSE between the two attention maps, each scaled by its
                # own maximum along the last dim.
                loss_reg = F.mse_loss(
                    attn_ph / attn_ph.max(-1, keepdim=True)[0],
                    attn_eot / attn_eot.max(-1, keepdim=True)[0])
            # ########################
            # Binarize the (detached) placeholder attention and use it as
            # the encoder attention mask of the image cross-attention.
            mask = attn_ph.detach()
            mask = otsu(mask)
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)

            # First half of the batch is processed further; the second half
            # serves as the reference it attends to and passes through
            # unchanged.
            hidden_states, image_reference = hidden_states[:batch // 2], \
                hidden_states[batch // 2:]
            hidden_states = self.image_cross_attention(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=image_reference,
                encoder_attention_mask=mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )
            hidden_states = torch.cat([hidden_states, image_reference], dim=0)

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width,
                                                      inner_dim).permute(
                                                          0, 3, 1,
                                                          2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width,
                                                      inner_dim).permute(
                                                          0, 3, 1,
                                                          2).contiguous()

            output = hidden_states + residual
        # NOTE(review): norm_out, out, proj_out_1, proj_out_2, patch_size and
        # out_channels used below are not set in __init__ - see the note in
        # the input section.
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype)
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(
                2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (
                1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            height = width = int(hidden_states.shape[1]**0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size,
                       self.out_channels))
            hidden_states = torch.einsum('nhwpqc->nchpwq', hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size,
                       width * self.patch_size))

        if not return_dict:
            return (output, loss_reg)

        return ViCoTransformer2DModelOutput(sample=output, loss_reg=loss_reg)
def replace_transformer2d(module: nn.Module, have_image_cross: List[bool]):
    """Replace every Transformer2DModel inside ``module``.

    Each ``Transformer2DModel`` found by ``module.named_modules()`` is
    swapped, on its parent, for a ``ViCoTransformer2D`` wrapping it.

    Args:
        module (nn.Module): Module that (directly or indirectly) contains
            the Transformer2D modules.
        have_image_cross (List): Flags indexed by the position of each
            Transformer2DModel in ``module.named_modules()`` order; the
            i-th flag tells whether the i-th replacement gets an
            image_cross_attention module.

    Raises:
        ValueError: If ``module`` itself is a Transformer2DModel; it has no
            parent inside ``module`` on which it could be replaced.
    """
    # Collect first: the module tree must not change while it is traversed.
    transformer2d_modules = [(name, child)
                             for name, child in module.named_modules()
                             if isinstance(child, Transformer2DModel)]

    for index, (name, transformer2d) in enumerate(transformer2d_modules):
        if not name:
            raise ValueError(
                'Cannot replace the root module itself; pass the parent of '
                'the Transformer2DModel instead.')
        # rpartition also copes with direct children: their name has no
        # dot, which yields an empty parent path, and get_submodule('')
        # returns ``module`` itself.
        parent_name, _, attr_name = name.rpartition('.')
        parent = module.get_submodule(parent_name)
        setattr(parent, attr_name,
                ViCoTransformer2D(transformer2d, have_image_cross[index]))
class ViCoBlockWrapper(nn.Module):
    """Base class of the ViCo wrappers that take over a module's forward."""

    def apply_to(self, org_module):
        """Attach this wrapper to ``org_module``.

        The wrapped module is stored as ``self.org_module`` and its
        ``forward`` attribute is overwritten with this wrapper's bound
        ``forward``, so calling ``org_module`` runs the wrapper's logic.

        Args:
            org_module (nn.Module): Module whose forward is replaced.
        """
        wrapper_forward = self.forward
        self.org_module = org_module
        org_module.forward = wrapper_forward
class ViCoCrossAttnDownBlock2D(ViCoBlockWrapper):
    """ViCo forward for a cross-attention down block.

    Runs the resnet/attention pairs of the wrapped block and additionally
    collects the attention regularization loss of every attention layer.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        placeholder_position: torch.Tensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run the wrapped down block.

        Args:
            hidden_states (torch.FloatTensor): Hidden states.
            temb (Optional[torch.FloatTensor]): Time embedding.
            encoder_hidden_states (Optional[torch.FloatTensor]): Encoder
                hidden states.
            placeholder_position (torch.Tensor): Placeholder position.
            attention_mask (Optional[torch.FloatTensor]): Attention mask.
            cross_attention_kwargs (Optional[Dict[str, Any]]): Cross
                attention keyword arguments.
            encoder_attention_mask (Optional[torch.FloatTensor]): Encoder
                attention mask.

        Returns:
            torch.FloatTensor: Output hidden states.
            Tuple[torch.FloatTensor]: Output hidden states of each block.
            torch.FloatTensor: Attention regularization loss.
        """
        wrapped = self.org_module

        def create_custom_forward(module, return_dict=None):
            """Adapt ``module`` to the positional-only checkpoint API."""

            def custom_forward(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return custom_forward

        output_states = ()
        loss_reg_all = 0.0

        for resnet, attn in zip(wrapped.resnets, wrapped.attentions):
            attn: ViCoTransformer2D
            use_checkpoint = (
                wrapped.training and wrapped.gradient_checkpointing)

            if not use_checkpoint:
                hidden_states = resnet(hidden_states, temb)
                hidden_states, loss_reg = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    placeholder_position=placeholder_position,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )
            else:
                ckpt_kwargs: Dict[str, Any] = {}
                if is_torch_version('>=', '1.11.0'):
                    ckpt_kwargs = {'use_reentrant': False}

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Positional order follows ViCoTransformer2D.forward.
                hidden_states, loss_reg = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    placeholder_position,
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )

            output_states = output_states + (hidden_states, )
            if loss_reg is not None:
                loss_reg_all += loss_reg

        downsamplers = wrapped.downsamplers
        if downsamplers is not None:
            for downsampler in downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states, )

        return hidden_states, output_states, loss_reg_all
class ViCoUNetMidBlock2DCrossAttn(ViCoBlockWrapper):
    """ViCo forward for the cross-attention mid block.

    Applies the first resnet, then alternates attention and resnet layers
    while summing the attention regularization losses.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        placeholder_position: torch.Tensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the wrapped mid block.

        Args:
            hidden_states (torch.FloatTensor): Hidden states.
            temb (Optional[torch.FloatTensor]): Time embedding.
            encoder_hidden_states (Optional[torch.FloatTensor]): Encoder
                hidden states.
            placeholder_position (torch.Tensor): Placeholder position.
            attention_mask (Optional[torch.FloatTensor]): Attention mask.
            cross_attention_kwargs (Optional[Dict[str, Any]]): Cross
                attention keyword arguments.
            encoder_attention_mask (Optional[torch.FloatTensor]): Encoder
                attention mask.

        Returns:
            torch.FloatTensor: Output hidden states.
            torch.FloatTensor: Attention regularization loss.
        """
        wrapped = self.org_module
        resnets = wrapped.resnets

        loss_reg_all = 0.0
        hidden_states = resnets[0](hidden_states, temb)

        # Every attention layer is followed by one of the remaining resnets.
        for attn, resnet in zip(wrapped.attentions, resnets[1:]):
            hidden_states, loss_reg = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                placeholder_position=placeholder_position,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )
            hidden_states = resnet(hidden_states, temb)
            if loss_reg is None:
                continue
            loss_reg_all += loss_reg

        return hidden_states, loss_reg_all
class ViCoCrossAttnUpBlock2D(ViCoBlockWrapper):
    """ViCo forward for a cross-attention up block.

    Concatenates the skip connections, runs the resnet/attention pairs of
    the wrapped block and sums the attention regularization losses.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        placeholder_position: torch.Tensor = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Performs the forward pass through the ViCoCrossAttnUpBlock2D
        module.

        Args:
            hidden_states (torch.FloatTensor): Input hidden states.
            res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): Tuple
                of residual hidden states; consumed from the end.
            temb (Optional[torch.FloatTensor], optional): Temporal
                embeddings. Defaults to None.
            encoder_hidden_states (Optional[torch.FloatTensor], optional):
                Encoder hidden states. Defaults to None.
            placeholder_position (torch.Tensor, optional): Placeholder
                positions. Defaults to None.
            cross_attention_kwargs (Optional[Dict[str, Any]], optional):
                Keyword arguments for cross-attention. Defaults to None.
            upsample_size (Optional[int], optional): Upsample size passed
                to the upsamplers. Defaults to None.
            attention_mask (Optional[torch.FloatTensor], optional):
                Attention mask. Defaults to None.
            encoder_attention_mask (Optional[torch.FloatTensor], optional):
                Encoder attention mask. Defaults to None.

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor]: A tuple containing
            the output hidden states and the total regularization loss.
        """
        wrapped = self.org_module

        def create_custom_forward(module, return_dict=None):
            """Adapt ``module`` to the positional-only checkpoint API."""

            def custom_forward(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return custom_forward

        loss_reg_all = 0.0

        for resnet, attn in zip(wrapped.resnets, wrapped.attentions):
            attn: ViCoTransformer2D
            # Pop the most recent skip connection and concatenate it along
            # dim 1.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            use_checkpoint = (
                wrapped.training and wrapped.gradient_checkpointing)

            if not use_checkpoint:
                hidden_states = resnet(hidden_states, temb)
                hidden_states, loss_reg = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    placeholder_position=placeholder_position,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )
            else:
                ckpt_kwargs: Dict[str, Any] = {}
                if is_torch_version('>=', '1.11.0'):
                    ckpt_kwargs = {'use_reentrant': False}

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Positional order follows ViCoTransformer2D.forward.
                hidden_states, loss_reg = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    placeholder_position,
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )

            if loss_reg is not None:
                loss_reg_all += loss_reg

        upsamplers = wrapped.upsamplers
        if upsamplers is not None:
            for upsampler in upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states, loss_reg_all
# NOTE(review): unlike ViCoTransformer2DModelOutput, this class is not
# decorated with @dataclass, so `sample` and `loss_reg` below are bare
# annotations. It is constructed with keyword arguments only
# (``sample=..., loss_reg=...``). Confirm whether the missing decorator is
# intentional before adding it.
class ViCoUNet2DConditionOutput(BaseOutput):
    """Output for ViCoUNet2DConditionModel.

    Returned by ``ViCoUNet2DConditionModel.forward`` when ``return_dict`` is
    true; with ``return_dict=False`` the same two values are returned as a
    plain ``(sample, loss_reg)`` tuple instead.
    """

    # Output of the final convolution of the UNet.
    sample: torch.FloatTensor
    # Attention regularization loss accumulated over the down, mid and up
    # blocks; the forward pass initializes it to ``None`` when its training
    # flag is off.
    loss_reg: torch.FloatTensor
class ViCoUNet2DConditionModel(ViCoBlockWrapper):
    """UNet2DConditionModel for ViCo Method.

    Replacement forward for the wrapped UNet (``self.org_module``). Besides
    the denoised sample it returns the attention regularization loss summed
    over all cross-attention down, mid and up blocks.
    """

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        placeholder_position: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Performs the forward pass through the wrapped UNet.

        Args:
            sample (torch.FloatTensor): Input sample.
            timestep (Union[torch.Tensor, float, int]): Timestep value.
            encoder_hidden_states (torch.Tensor): Encoder hidden states.
            placeholder_position (torch.Tensor): Placeholder positions,
                forwarded to every cross-attention block.
            class_labels (Optional[torch.Tensor], optional): Class labels.
                Defaults to None.
            timestep_cond (Optional[torch.Tensor], optional): Timestep
                condition. Defaults to None.
            attention_mask (Optional[torch.Tensor], optional): Attention
                mask. Defaults to None.
            cross_attention_kwargs (Optional[Dict[str, Any]], optional):
                Keyword arguments for cross-attention. Defaults to None.
            added_cond_kwargs (Optional[Dict[str, torch.Tensor]], optional):
                Additional condition arguments. Defaults to None.
            down_block_additional_residuals (Optional[Tuple[torch.Tensor]],
                optional): Additional residuals for down-blocks.
                Defaults to None.
            mid_block_additional_residual (Optional[torch.Tensor],
                optional): Additional residual for mid-block.
                Defaults to None.
            encoder_attention_mask (Optional[torch.Tensor], optional):
                Encoder attention mask. Defaults to None.
            return_dict (bool, optional): Whether to return a dictionary or
                a tuple. Defaults to True.

        Returns:
            Union[ViCoUNet2DConditionOutput, Tuple]: The output sample and
            the accumulated regularization loss, either as a
            ViCoUNet2DConditionOutput or as a ``(sample, loss_reg)`` tuple.
            The loss is ``None`` when the wrapped UNet is not in training
            mode.

        Raises:
            ValueError: If class embeddings are configured but
                ``class_labels`` is missing, or if a configuration that
                needs ``image_embeds`` does not find it in
                ``added_cond_kwargs``.
        """
        # This wrapper is never registered as a sub-module of the UNet, so
        # its own ``training`` flag does not follow ``unet.train()`` /
        # ``unet.eval()``. Read the mode from the wrapped UNet instead, as
        # the block wrappers do.
        is_training = self.org_module.training

        default_overall_up_factor = 2**self.org_module.num_upsamplers

        forward_upsample_size = False
        upsample_size = None

        # Inputs whose spatial size is not a multiple of the overall
        # upsampling factor need explicit target sizes in the up path.
        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.org_module.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # 64-bit dtypes are avoided on the 'mps' device.
            is_mps = sample.device.type == 'mps'
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps],
                                     dtype=dtype,
                                     device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        t_emb = self.org_module.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.org_module.time_embedding(t_emb, timestep_cond)

        if self.org_module.class_embedding is not None:
            if class_labels is None:
                raise ValueError('class_labels should be provided \
                    when num_class_embeds > 0')

            if self.org_module.config.class_embed_type == 'timestep':
                class_labels = self.org_module.time_proj(class_labels)
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.org_module.class_embedding(class_labels).to(
                dtype=sample.dtype)

            if self.org_module.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.org_module.config.addition_embed_type == 'text':
            aug_emb = self.org_module.add_embedding(encoder_hidden_states)
            emb = emb + aug_emb
        elif self.org_module.config.addition_embed_type == 'text_image':
            # Kadinsky 2.1 - style
            if 'image_embeds' not in added_cond_kwargs:
                raise ValueError(
                    f"{self.org_module.__class__} has the config param \
                        `addition_embed_type` set to 'text_image' which \
                        requires the keyword argument `image_embeds` \
                        to be passed in `added_cond_kwargs`")

            image_embs = added_cond_kwargs.get('image_embeds')
            text_embs = added_cond_kwargs.get('text_embeds',
                                              encoder_hidden_states)

            aug_emb = self.org_module.add_embedding(text_embs, image_embs)
            emb = emb + aug_emb

        if self.org_module.time_embed_act is not None:
            emb = self.org_module.time_embed_act(emb)

        if self.org_module.encoder_hid_proj is not None and (
                self.org_module.config.encoder_hid_dim_type == 'text_proj'):
            encoder_hidden_states = self.org_module.encoder_hid_proj(
                encoder_hidden_states)
        elif self.org_module.encoder_hid_proj is not None and (
                self.org_module.config.encoder_hid_dim_type
                == 'text_image_proj'):
            # Kadinsky 2.1 - style
            if 'image_embeds' not in added_cond_kwargs:
                raise ValueError(
                    f"{self.org_module.__class__} has the config param \
                        `encoder_hid_dim_type` set to 'text_image_proj' which \
                        requires the keyword argument `image_embeds` to be \
                        passed in  `added_conditions`")

            image_embeds = added_cond_kwargs.get('image_embeds')
            encoder_hidden_states = self.org_module.encoder_hid_proj(
                encoder_hidden_states, image_embeds)

        # 2. pre-process
        sample = self.org_module.conv_in(sample)

        loss_reg_all = 0.0 if is_training else None

        # 3. down
        down_block_res_samples = (sample, )
        for downsample_block in self.org_module.down_blocks:
            if hasattr(downsample_block, 'has_cross_attention'
                       ) and downsample_block.has_cross_attention:
                sample, res_samples, loss_reg = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    placeholder_position=placeholder_position,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
                if is_training:
                    loss_reg_all += loss_reg
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                    down_block_res_samples, down_block_additional_residuals):
                down_block_res_sample = down_block_res_sample + \
                    down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (
                    down_block_res_sample, )

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.org_module.mid_block is not None:
            # placeholder_position is forwarded like for the down and up
            # blocks; without it a mid-block transformer that has image
            # cross-attention would try to index ``None``.
            sample, loss_reg = self.org_module.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                placeholder_position=placeholder_position,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            if is_training:
                loss_reg_all += loss_reg

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.org_module.up_blocks):
            is_final_block = i == len(self.org_module.up_blocks) - 1

            # Each up block consumes as many skip connections as it has
            # resnets, taken from the end of the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[:-len(
                upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, 'has_cross_attention'
                       ) and upsample_block.has_cross_attention:
                sample, loss_reg = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    placeholder_position=placeholder_position,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
                if is_training:
                    loss_reg_all += loss_reg
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size)

        # 6. post-process
        if self.org_module.conv_norm_out:
            sample = self.org_module.conv_norm_out(sample)
            sample = self.org_module.conv_act(sample)
        sample = self.org_module.conv_out(sample)

        if not return_dict:
            return (sample, loss_reg_all)

        return ViCoUNet2DConditionOutput(sample=sample, loss_reg=loss_reg_all)
def set_vico_modules(unet, image_cross_layers):
    """Turn a normally initialized UNet into its ViCo variant.

    Three steps are applied in order: the Transformer2D modules are
    replaced, the cross-attention processors are replaced, and the
    ``forward`` of the UNet and of its cross-attention blocks is taken over
    by the matching ViCo wrapper.

    Args:
        unet (nn.Module): UNet model.
        image_cross_layers (List): List of flag indicating which
            transformer2D modules have image_cross_attention modules.
    """
    # replace transformer2d blocks
    replace_transformer2d(unet, image_cross_layers)
    # replace cross attention layer
    replace_cross_attention(unet)

    # replace forward; modules are matched by the name of their class
    block_wrappers = {
        'CrossAttnDownBlock2D': ViCoCrossAttnDownBlock2D,
        'UNetMidBlock2DCrossAttn': ViCoUNetMidBlock2DCrossAttn,
        'CrossAttnUpBlock2D': ViCoCrossAttnUpBlock2D,
    }
    for _, layer in unet.named_modules():
        class_name = layer.__class__.__name__
        if class_name == 'UNet2DConditionModel':
            # The UNet wrapper is attached to the ``unet`` argument itself,
            # not to the module that matched.
            ViCoUNet2DConditionModel().apply_to(unet)
        elif class_name in block_wrappers:
            block_wrappers[class_name]().apply_to(layer)