mmagic.models.editors.vico.vico_utils

Module Contents

Classes

ViCoCrossAttnProcessor

Processor for implementing attention for the ViCo method.

ViCoTransformer2DModelOutput

Output for ViCoTransformer2DModel.

ViCoTransformer2D

New ViCo-Transformer2D to replace the original Transformer2D model.

ViCoBlockWrapper

Wrapper for ViCo blocks.

ViCoCrossAttnDownBlock2D

Wrapper for ViCo blocks.

ViCoUNetMidBlock2DCrossAttn

Wrapper for ViCo blocks.

ViCoCrossAttnUpBlock2D

Wrapper for ViCo blocks.

ViCoUNet2DConditionOutput

Output for ViCoUNet2DConditionModel.

ViCoUNet2DConditionModel

UNet2DConditionModel for the ViCo method.

Functions

replace_cross_attention(unet)

Replace Cross Attention processor in UNet.

otsu(mask_in)

Apply Otsu thresholding to a mask.

replace_transformer2d(module, have_image_cross)

Replace the Transformer2DModel in UNet.

set_vico_modules(unet, image_cross_layers)

Set all modules for the ViCo method after the UNet is initialized normally.

class mmagic.models.editors.vico.vico_utils.ViCoCrossAttnProcessor[source]

Processor for implementing attention for the ViCo method.

__call__(attn: diffusers.models.attention.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None)[source]
Parameters
  • attn (Attention) – Attention module.

  • hidden_states (torch.Tensor) – Input hidden states.

  • encoder_hidden_states (torch.Tensor) – Encoder hidden states.

  • attention_mask (torch.Tensor) – Attention mask.

Returns

Output hidden states.

Return type

torch.Tensor
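Diffusers attention processors are plain callables invoked by the Attention module. As a rough, framework-free illustration of the core computation only, the sketch below runs single-head dot-product cross-attention over plain Python lists. Both `cross_attention` and `softmax` here are hypothetical helpers, not part of mmagic; the real processor additionally applies the module's learned `to_q`/`to_k`/`to_v` projections and implements ViCo's specific attention handling.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]


def cross_attention(hidden, encoder_hidden, scale):
    """Single-head dot-product cross-attention over plain lists.

    hidden: query rows (n_q x d); encoder_hidden: key/value rows (n_k x d).
    Each output row is a softmax-weighted mixture of the key/value rows.
    """
    out = []
    for q in hidden:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k))
                  for k in encoder_hidden]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, encoder_hidden))
                    for j in range(len(q))])
    return out
```

Because the weights sum to one, every output row is a convex combination of the encoder rows, which is why attention maps can be read off as soft alignment scores.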

mmagic.models.editors.vico.vico_utils.replace_cross_attention(unet)[source]

Replace Cross Attention processor in UNet.

class mmagic.models.editors.vico.vico_utils.ViCoTransformer2DModelOutput[source]

Bases: diffusers.utils.BaseOutput

Output for ViCoTransformer2DModel.

sample: torch.FloatTensor[source]
loss_reg: torch.FloatTensor[source]
mmagic.models.editors.vico.vico_utils.otsu(mask_in)[source]

Apply Otsu thresholding to a mask.

Parameters

mask_in (torch.Tensor) – Input mask.
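For intuition: Otsu's method picks the threshold that maximizes the between-class variance of the two groups it induces. The sketch below is a pure-Python version over a flat list of values; the `otsu_threshold` helper is illustrative only, since mmagic's `otsu` operates on a `torch.Tensor` mask and its binning details may differ.

```python
def otsu_threshold(values, bins=64):
    """Return the threshold that maximizes between-class variance (Otsu)."""
    lo, hi = min(values), max(values)
    if lo == hi:
        return lo
    width = (hi - lo) / bins
    # Build a histogram of the values.
    hist = [0] * bins
    for v in values:
        idx = min(int((v - lo) / width), bins - 1)
        hist[idx] += 1
    total = len(values)
    centers = [lo + (i + 0.5) * width for i in range(bins)]
    sum_all = sum(h * c for h, c in zip(hist, centers))
    best_t, best_var = lo, -1.0
    w0, sum0 = 0, 0.0
    # Sweep candidate thresholds; track class weights and means incrementally.
    for i in range(bins - 1):
        w0 += hist[i]
        sum0 += hist[i] * centers[i]
        w1 = total - w0
        if w0 == 0 or w1 == 0:
            continue
        mu0 = sum0 / w0
        mu1 = (sum_all - sum0) / w1
        var = w0 * w1 * (mu0 - mu1) ** 2  # between-class variance (up to scale)
        if var > best_var:
            best_var, best_t = var, centers[i]
    return best_t
```

On a clearly bimodal mask this lands the threshold between the two modes, separating foreground from background without any manually chosen cutoff.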

class mmagic.models.editors.vico.vico_utils.ViCoTransformer2D(org_transformer2d: diffusers.Transformer2DModel, have_image_cross)[source]

Bases: torch.nn.Module

New ViCo-Transformer2D to replace the original Transformer2D model.

forward(hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, placeholder_position: list = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True)[source]
mmagic.models.editors.vico.vico_utils.replace_transformer2d(module: torch.nn.Module, have_image_cross: Dict[str, List[bool]])[source]

Replace the Transformer2DModel in UNet.

Parameters
  • module (nn.Module) – Parent module of Transformer2D.

  • have_image_cross (List) – List of flags indicating which Transformer2D modules have image_cross_attention modules.
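The replacement itself is an ordinary recursive traversal: walk a module's children and, wherever a Transformer2D is found, swap it for the ViCo wrapper while consuming one flag. The sketch below mimics that pattern with minimal stand-in classes; everything here (`Module`, `Transformer2D`, `ViCoTransformer2D`, `replace_transformer2d_sketch`) is illustrative, since the real function traverses `torch.nn` modules and its flag bookkeeping may differ.

```python
class Module:
    """Minimal stand-in for torch.nn.Module (illustration only)."""

    def __init__(self, **children):
        self.__dict__.update(children)

    def named_children(self):
        return [(n, c) for n, c in vars(self).items() if isinstance(c, Module)]


class Transformer2D(Module):
    """Stand-in for diffusers' Transformer2DModel."""


class ViCoTransformer2D(Module):
    """Stand-in wrapper holding the original block and its flag."""

    def __init__(self, org, have_image_cross):
        super().__init__()
        self.org = org
        self.have_image_cross = have_image_cross


def replace_transformer2d_sketch(module, flags):
    """Depth-first sweep: wrap each Transformer2D child, consuming one flag."""
    for name, child in module.named_children():
        if isinstance(child, Transformer2D):
            setattr(module, name, ViCoTransformer2D(child, flags.pop(0)))
        else:
            replace_transformer2d_sketch(child, flags)
```

Because the sweep is depth-first and in declaration order, the flag list must line up with the order in which the Transformer2D modules are encountered in the UNet.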

class mmagic.models.editors.vico.vico_utils.ViCoBlockWrapper[source]

Bases: torch.nn.Module

Wrapper for ViCo blocks.

apply_to(org_module)[source]
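`apply_to` has no docstring here. One plausible wrapper pattern (an assumption for illustration, not mmagic's exact implementation) is to keep a reference to the original block and mirror its attributes onto the wrapper, so that a custom `forward` can reach the original submodules directly:

```python
class BlockWrapperSketch:
    """Illustrative wrapper; the real ViCoBlockWrapper is a torch.nn.Module."""

    def apply_to(self, org_module):
        # Keep a handle to the wrapped block, then expose its attributes
        # on the wrapper itself so forward() can use them unchanged.
        self.org_module = org_module
        for name, value in vars(org_module).items():
            setattr(self, name, value)
```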
class mmagic.models.editors.vico.vico_utils.ViCoCrossAttnDownBlock2D[source]

Bases: ViCoBlockWrapper

Wrapper for ViCo blocks.

forward(hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, placeholder_position: torch.Tensor = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None)[source]
Parameters
  • hidden_states (torch.FloatTensor) – Hidden states.

  • temb (Optional[torch.FloatTensor]) – Time embedding.

  • encoder_hidden_states (Optional[torch.FloatTensor]) – Encoder hidden states.

  • placeholder_position (torch.Tensor) – Placeholder position.

  • attention_mask (Optional[torch.FloatTensor]) – Attention mask.

  • cross_attention_kwargs (Optional[Dict[str, Any]]) – Cross attention keyword arguments.

  • encoder_attention_mask (Optional[torch.FloatTensor]) – Encoder attention mask.

Returns

The output hidden states, a tuple of the output hidden states of each block, and the attention regularization loss.

Return type

Tuple[torch.FloatTensor, Tuple[torch.FloatTensor], torch.FloatTensor]

class mmagic.models.editors.vico.vico_utils.ViCoUNetMidBlock2DCrossAttn[source]

Bases: ViCoBlockWrapper

Wrapper for ViCo blocks.

forward(hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, placeholder_position: torch.Tensor = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None) → torch.FloatTensor[source]
Parameters
  • hidden_states (torch.FloatTensor) – Hidden states.

  • temb (Optional[torch.FloatTensor]) – Time embedding.

  • encoder_hidden_states (Optional[torch.FloatTensor]) – Encoder hidden states.

  • placeholder_position (torch.Tensor) – Placeholder position.

  • attention_mask (Optional[torch.FloatTensor]) – Attention mask.

  • cross_attention_kwargs (Optional[Dict[str, Any]]) – Cross attention keyword arguments.

  • encoder_attention_mask (Optional[torch.FloatTensor]) – Encoder attention mask.

Returns

The output hidden states and the attention regularization loss.

Return type

Tuple[torch.FloatTensor, torch.FloatTensor]

class mmagic.models.editors.vico.vico_utils.ViCoCrossAttnUpBlock2D[source]

Bases: ViCoBlockWrapper

Wrapper for ViCo blocks.

forward(hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, Ellipsis], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, placeholder_position: torch.Tensor = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None)[source]

Performs the forward pass through the ViCoCrossAttnUpBlock2D module.

Parameters
  • hidden_states (torch.FloatTensor) – Input hidden states.

  • res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]) – Tuple of residual hidden states.

  • temb (Optional[torch.FloatTensor], optional) – Temporal embeddings. Defaults to None.

  • encoder_hidden_states (Optional[torch.FloatTensor], optional) – Encoder hidden states. Defaults to None.

  • placeholder_position (torch.Tensor, optional) – Placeholder positions. Defaults to None.

  • cross_attention_kwargs (Optional[Dict[str, Any]], optional) – Keyword arguments for cross-attention. Defaults to None.

  • upsample_size (Optional[int], optional) – Upsample size.

  • attention_mask (Optional[torch.FloatTensor], optional) – Attention mask.

  • encoder_attention_mask (Optional[torch.FloatTensor], optional) – Encoder attention mask.

Returns

A tuple containing the output hidden states and the total regularization loss.

Return type

Tuple[torch.FloatTensor, torch.FloatTensor]
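The pattern shared by all three block wrappers is that each attention layer returns its hidden states together with a per-layer regularization loss, and the block sums the losses into the total it returns. A minimal, framework-free sketch with a hypothetical `run_block` helper (the real blocks additionally interleave resnets with attention and handle residuals and upsampling):

```python
def run_block(hidden, layers):
    """Apply layers in order; each layer returns (hidden, loss_reg).

    The block's output is the final hidden states plus the summed
    regularization loss collected across all layers.
    """
    total_loss = 0.0
    for layer in layers:
        hidden, loss = layer(hidden)
        total_loss += loss
    return hidden, total_loss
```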

class mmagic.models.editors.vico.vico_utils.ViCoUNet2DConditionOutput[source]

Bases: diffusers.utils.BaseOutput

Output for ViCoUNet2DConditionModel.

sample: torch.FloatTensor[source]
loss_reg: torch.FloatTensor[source]
class mmagic.models.editors.vico.vico_utils.ViCoUNet2DConditionModel[source]

Bases: ViCoBlockWrapper

UNet2DConditionModel for the ViCo method.

forward(sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, placeholder_position: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True) → Union[diffusers.models.unet_2d_condition.UNet2DConditionOutput, Tuple][source]

Performs the forward pass through the ViCoUNet2DConditionModel.

Parameters
  • sample (torch.FloatTensor) – Input sample.

  • timestep (Union[torch.Tensor, float, int]) – Timestep value.

  • encoder_hidden_states (torch.Tensor) – Encoder hidden states.

  • placeholder_position (torch.Tensor) – Placeholder positions.

  • class_labels (Optional[torch.Tensor], optional) – Class labels. Defaults to None.

  • timestep_cond (Optional[torch.Tensor], optional) – Timestep condition. Defaults to None.

  • attention_mask (Optional[torch.Tensor], optional) – Attention mask. Defaults to None.

  • cross_attention_kwargs (Optional[Dict[str, Any]], optional) – Keyword arguments for cross-attention. Defaults to None.

  • added_cond_kwargs (Optional[Dict[str, torch.Tensor]], optional) – Additional condition arguments. Defaults to None.

  • down_block_additional_residuals (Optional[Tuple[torch.Tensor]], optional) – Additional residuals for down-blocks. Defaults to None.

  • mid_block_additional_residual (Optional[torch.Tensor], optional) – Additional residual for mid-block. Defaults to None.

  • encoder_attention_mask (Optional[torch.Tensor], optional) – Encoder attention mask. Defaults to None.

  • return_dict (bool, optional) – Whether to return a dictionary or a tuple.

Returns

The output of the forward pass, which can be either a UNet2DConditionOutput object or a tuple of tensors.

Return type

Union[UNet2DConditionOutput, Tuple]

mmagic.models.editors.vico.vico_utils.set_vico_modules(unet, image_cross_layers)[source]

Set all modules for the ViCo method after the UNet is initialized normally.

Parameters
  • unet (nn.Module) – UNet model.

  • image_cross_layers (List) – List of flags indicating which Transformer2D modules have image_cross_attention modules.
