mmagic.models.editors.vico.vico_utils
Module Contents
Classes

ViCoCrossAttnProcessor
    Processor for implementing attention for the ViCo method.
ViCoTransformer2DModelOutput
    Output for ViCoTransformer2DModel.
ViCoTransformer2D
    New ViCo-Transformer2D to replace the original Transformer2D model.
ViCoBlockWrapper
    Wrapper for ViCo blocks.
ViCoCrossAttnDownBlock2D
    Wrapper for ViCo blocks.
ViCoUNetMidBlock2DCrossAttn
    Wrapper for ViCo blocks.
ViCoCrossAttnUpBlock2D
    Wrapper for ViCo blocks.
ViCoUNet2DConditionOutput
    Output for ViCoUNet2DConditionModel.
ViCoUNet2DConditionModel
    UNet2DConditionModel for the ViCo method.
Functions

replace_cross_attention
    Replace the cross-attention processors in the UNet.
otsu
    Apply Otsu thresholding to a mask.
replace_transformer2d
    Replace the Transformer2DModel blocks in the UNet.
set_vico_modules
    Set all modules for the ViCo method after the UNet is initialized normally.
- class mmagic.models.editors.vico.vico_utils.ViCoCrossAttnProcessor
Processor for implementing attention for the ViCo method.
- __call__(attn: diffusers.models.attention.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None)
- Parameters
attn (Attention) – Attention module.
hidden_states (torch.Tensor) – Input hidden states.
encoder_hidden_states (torch.Tensor) – Encoder hidden states.
attention_mask (torch.Tensor) – Attention mask.
- Returns
Output hidden states.
- Return type
torch.Tensor
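A minimal usage sketch, assuming the processor is a drop-in replacement for diffusers' default attention processor (Attention.forward delegates to whatever processor is installed). The import path of Attention varies across diffusers versions, and the shapes below are illustrative SD v1.x sizes:

```python
import torch
from diffusers.models.attention_processor import Attention

from mmagic.models.editors.vico.vico_utils import ViCoCrossAttnProcessor

# One cross-attention module with illustrative SD v1.x widths:
# 320-dim image tokens attending over 768-dim text tokens.
attn = Attention(query_dim=320, cross_attention_dim=768)
attn.set_processor(ViCoCrossAttnProcessor())

hidden_states = torch.randn(2, 64, 320)          # (batch, tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # text embeddings
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```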
- mmagic.models.editors.vico.vico_utils.replace_cross_attention(unet)
Replace the cross-attention processors in the UNet.
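This helper takes only the UNet and mutates it in place. A rough sketch of the kind of substitution it performs, via diffusers' attn_processors registry; the '*.attn2' naming for cross-attention is a diffusers convention, and this approximates rather than reproduces the actual implementation:

```python
from diffusers import UNet2DConditionModel

from mmagic.models.editors.vico.vico_utils import ViCoCrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet')

# Swap the processor of every cross-attention module ('attn2' in
# diffusers naming); self-attention ('attn1') keeps its default.
processors = {
    name: (ViCoCrossAttnProcessor()
           if name.endswith('attn2.processor') else proc)
    for name, proc in unet.attn_processors.items()
}
unet.set_attn_processor(processors)
```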
- class mmagic.models.editors.vico.vico_utils.ViCoTransformer2DModelOutput
Bases: diffusers.utils.BaseOutput
Output for ViCoTransformer2DModel.
- mmagic.models.editors.vico.vico_utils.otsu(mask_in)
Apply Otsu thresholding to a mask.
- Parameters
mask_in (torch.Tensor) – Input mask.
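Otsu's method picks the threshold that maximizes the between-class variance of the mask's value histogram, turning a soft attention mask into a binary one. A minimal torch sketch of the idea (a hypothetical otsu_threshold helper, not this module's exact implementation):

```python
import torch

def otsu_threshold(mask_in: torch.Tensor, bins: int = 256) -> float:
    """Return the threshold maximizing between-class variance."""
    flat = mask_in.flatten().float()
    lo, hi = float(flat.min()), float(flat.max())
    hist = torch.histc(flat, bins=bins, min=lo, max=hi)
    p = hist / hist.sum()                    # bin probabilities
    centers = torch.linspace(lo, hi, bins)   # approximate bin centers
    omega = torch.cumsum(p, dim=0)           # class-0 probability
    mu = torch.cumsum(p * centers, dim=0)    # class-0 cumulative mean
    mu_t = mu[-1]                            # global mean
    # Between-class variance for every candidate split point.
    sigma_b = (mu_t * omega - mu) ** 2 / (omega * (1 - omega) + 1e-12)
    return float(centers[int(torch.argmax(sigma_b))])

mask = torch.rand(1, 64, 64)
binary = (mask > otsu_threshold(mask)).float()
```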
- class mmagic.models.editors.vico.vico_utils.ViCoTransformer2D(org_transformer2d: diffusers.Transformer2DModel, have_image_cross)
Bases: torch.nn.Module
New ViCo-Transformer2D to replace the original Transformer2D model.
- forward(hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, placeholder_position: list = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True)
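A construction sketch with illustrative SD v1.x block sizes; have_image_cross toggles ViCo's image cross-attention for this block, and forward additionally threads placeholder_position through to it:

```python
from diffusers import Transformer2DModel

from mmagic.models.editors.vico.vico_utils import ViCoTransformer2D

# Sizes mirror an SD v1.x down-block at 320 channels (illustrative).
org = Transformer2DModel(num_attention_heads=8, attention_head_dim=40,
                         in_channels=320, cross_attention_dim=768)
vico_block = ViCoTransformer2D(org, have_image_cross=True)
```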
- mmagic.models.editors.vico.vico_utils.replace_transformer2d(module: torch.nn.Module, have_image_cross: Dict[str, List[bool]])
Replace the Transformer2DModel blocks in the UNet.
- Parameters
module (nn.Module) – Parent module of Transformer2D.
have_image_cross (List) – List of flags indicating which Transformer2D modules have image_cross_attention modules.
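A sketch of the recursive module-swap pattern such a helper typically uses. The hypothetical swap below simplifies the flag bookkeeping to a flat list consumed in traversal order; the documented signature takes the flags per module:

```python
import torch.nn as nn
from diffusers import Transformer2DModel

from mmagic.models.editors.vico.vico_utils import ViCoTransformer2D

def swap(module: nn.Module, flags: list) -> None:
    """Replace every Transformer2DModel child with a ViCo wrapper,
    consuming one image-cross flag per block in traversal order."""
    for name, child in module.named_children():
        if isinstance(child, Transformer2DModel):
            setattr(module, name,
                    ViCoTransformer2D(child, have_image_cross=flags.pop(0)))
        else:
            swap(child, flags)
```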
- class mmagic.models.editors.vico.vico_utils.ViCoBlockWrapper
Bases: torch.nn.Module
Wrapper for ViCo blocks.
- class mmagic.models.editors.vico.vico_utils.ViCoCrossAttnDownBlock2D
Bases: ViCoBlockWrapper
Wrapper for ViCo blocks.
- forward(hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, placeholder_position: torch.Tensor = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None)
- Parameters
hidden_states (torch.FloatTensor) – Hidden states.
temb (Optional[torch.FloatTensor]) – Time embedding.
encoder_hidden_states (Optional[torch.FloatTensor]) – Encoder hidden states.
placeholder_position (torch.Tensor) – Placeholder position.
attention_mask (Optional[torch.FloatTensor]) – Attention mask.
cross_attention_kwargs (Optional[Dict[str, Any]]) – Cross attention keyword arguments.
encoder_attention_mask (Optional[torch.FloatTensor]) – Encoder attention mask.
- Returns
torch.FloatTensor: Output hidden states.
Tuple[torch.FloatTensor]: Output hidden states of each block.
torch.FloatTensor: Attention regularization loss.
- Return type
Tuple[torch.FloatTensor, Tuple[torch.FloatTensor], torch.FloatTensor]
- class mmagic.models.editors.vico.vico_utils.ViCoUNetMidBlock2DCrossAttn
Bases: ViCoBlockWrapper
Wrapper for ViCo blocks.
- forward(hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, placeholder_position: torch.Tensor = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None) → torch.FloatTensor
- Parameters
hidden_states (torch.FloatTensor) – Hidden states.
temb (Optional[torch.FloatTensor]) – Time embedding.
encoder_hidden_states (Optional[torch.FloatTensor]) – Encoder hidden states.
placeholder_position (torch.Tensor) – Placeholder position.
attention_mask (Optional[torch.FloatTensor]) – Attention mask.
cross_attention_kwargs (Optional[Dict[str, Any]]) – Cross attention keyword arguments.
encoder_attention_mask (Optional[torch.FloatTensor]) – Encoder attention mask.
- Returns
torch.FloatTensor: Output hidden states.
torch.FloatTensor: Attention regularization loss.
- Return type
Tuple[torch.FloatTensor, torch.FloatTensor]
- class mmagic.models.editors.vico.vico_utils.ViCoCrossAttnUpBlock2D
Bases: ViCoBlockWrapper
Wrapper for ViCo blocks.
- forward(hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, placeholder_position: torch.Tensor = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None)
Performs the forward pass through the ViCoCrossAttnUpBlock2D module.
- Parameters
hidden_states (torch.FloatTensor) – Input hidden states.
res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]) – Tuple of residual hidden states.
temb (Optional[torch.FloatTensor], optional) – Time embeddings. Defaults to None.
encoder_hidden_states (Optional[torch.FloatTensor], optional) – Encoder hidden states. Defaults to None.
placeholder_position (torch.Tensor, optional) – Placeholder positions. Defaults to None.
cross_attention_kwargs (Optional[Dict[str, Any]], optional) – Keyword arguments for cross-attention. Defaults to None.
upsample_size (Optional[int], optional) – Upsample size.
attention_mask (Optional[torch.FloatTensor], optional) – Attention mask.
encoder_attention_mask (Optional[torch.FloatTensor], optional) – Encoder attention mask.
- Returns
A tuple containing the output hidden states and the total regularization loss.
- Return type
Tuple[torch.FloatTensor, torch.FloatTensor]
- class mmagic.models.editors.vico.vico_utils.ViCoUNet2DConditionOutput
Bases: diffusers.utils.BaseOutput
Output for ViCoUNet2DConditionModel.
- class mmagic.models.editors.vico.vico_utils.ViCoUNet2DConditionModel
Bases: ViCoBlockWrapper
UNet2DConditionModel for the ViCo method.
- forward(sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, placeholder_position: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True) → Union[diffusers.models.unet_2d_condition.UNet2DConditionOutput, Tuple]
Performs the forward pass through the ViCoUNet2DConditionModel module.
- Parameters
sample (torch.FloatTensor) – Input sample.
timestep (Union[torch.Tensor, float, int]) – Timestep value.
encoder_hidden_states (torch.Tensor) – Encoder hidden states.
placeholder_position (torch.Tensor) – Placeholder positions.
class_labels (Optional[torch.Tensor], optional) – Class labels. Defaults to None.
timestep_cond (Optional[torch.Tensor], optional) – Timestep condition. Defaults to None.
attention_mask (Optional[torch.Tensor], optional) – Attention mask. Defaults to None.
cross_attention_kwargs (Optional[Dict[str, Any]], optional) – Keyword arguments for cross-attention. Defaults to None.
added_cond_kwargs (Optional[Dict[str, torch.Tensor]], optional) – Additional condition arguments. Defaults to None.
down_block_additional_residuals (Optional[Tuple[torch.Tensor]], optional) – Additional residuals for down-blocks. Defaults to None.
mid_block_additional_residual (Optional[torch.Tensor], optional) – Additional residual for mid-block. Defaults to None.
encoder_attention_mask (Optional[torch.Tensor], optional) – Encoder attention mask. Defaults to None.
return_dict (bool, optional) – Whether to return a dictionary or a tuple.
- Returns
The output of the forward pass, which can be either a UNet2DConditionOutput object or a tuple of tensors.
- Return type
Union[UNet2DConditionOutput, Tuple]
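A denoising-step sketch. It assumes set_vico_modules (documented below) rebinds a pretrained SD v1.x UNet's forward to this ViCo variant; all shapes are illustrative SD v1.x sizes, and the placeholder_position value, which should index the learned placeholder token in the tokenized prompt, is hypothetical:

```python
import torch
from diffusers import UNet2DConditionModel

from mmagic.models.editors.vico.vico_utils import set_vico_modules

unet = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet')
set_vico_modules(unet, image_cross_layers=[1] * 16)  # flags illustrative

sample = torch.randn(1, 4, 64, 64)               # noisy latents
timestep = torch.tensor(50)                      # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder states
placeholder_position = torch.tensor([[5]])       # hypothetical S* index

out = unet(sample, timestep, encoder_hidden_states,
           placeholder_position, return_dict=True)
noise_pred = out.sample                          # UNet2DConditionOutput
```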
- mmagic.models.editors.vico.vico_utils.set_vico_modules(unet, image_cross_layers)
Set all modules for the ViCo method after the UNet is initialized normally.
- Parameters
unet (nn.Module) – UNet model.
image_cross_layers (List) – List of flags indicating which Transformer2D modules have image_cross_attention modules.
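An end-to-end setup sketch. The flag list carries one entry per Transformer2D block in module-traversal order (an SD v1.x UNet has 16 of them); deriving the length from the model, as below, avoids hard-coding, and the all-ones pattern is illustrative rather than the ViCo paper's setting:

```python
from diffusers import Transformer2DModel, UNet2DConditionModel

from mmagic.models.editors.vico.vico_utils import set_vico_modules

unet = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet')

# One flag per Transformer2D block, in traversal order.
n_blocks = sum(isinstance(m, Transformer2DModel) for m in unet.modules())
image_cross_layers = [1] * n_blocks  # 16 for SD v1.x
set_vico_modules(unet, image_cross_layers)
```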