mmagic.models.editors.fastcomposer.fastcomposer_util

Module Contents

Classes

FastComposerModel

FastComposerModel is based on the Stable Diffusion model and the CLIP model.

FastComposerTextEncoder

TextEncoder for FastComposerModel.

FastComposerCLIPImageEncoder

CLIPImageEncoder for FastComposerModel.

FastComposerPostfuseModule

Postfuse Module for FastComposerModel.

BalancedL1Loss

BalancedL1Loss for object localization.

RandomZoomIn

Random zoom-in transform for object images.

PadToSquare

If the height of the image is greater than the width, padding will be added on both sides of the image to make it a square.

CropTopSquare

If the height of the image is greater than the width, the image will be cropped into a square starting from the top of the image.

MLP

Multilayer Perceptron.

Functions

get_object_transforms(cfg)

Get the object image transforms.

unet_store_cross_attention_scores(unet, attention_scores)

Make the UNet store its cross-attention scores in attention_scores.

get_object_localization_loss(cross_attention_scores, ...)

Average the object localization loss over the cross-attention layers.

get_object_localization_loss_for_one_layer(...)

Get object localization loss for one layer.

fuse_object_embeddings(inputs_embeds, ...[, fuse_fn])

Fuse object embeddings.

build_causal_attention_mask(bsz, seq_len, dtype[, device])

The function originally belonged to CLIPTextTransformer, but it has been removed in versions of transformers after 4.25.1.

Attributes

_expand_mask

mmagic.models.editors.fastcomposer.fastcomposer_util._expand_mask[source]
class mmagic.models.editors.fastcomposer.fastcomposer_util.FastComposerModel(text_encoder, image_encoder, vae, unet, cfg)[source]

Bases: torch.nn.Module

FastComposerModel is based on the Stable Diffusion model and the CLIP model.

_clear_cross_attention_scores()[source]

Delete cross attention scores.

static from_pretrained(cfg, vae, unet)[source]

Initialize the FastComposerTextEncoder and FastComposerCLIPImageEncoder.

forward(batch, noise_scheduler)[source]

Forward function.

Parameters
  • batch – Input training batch.

  • noise_scheduler – Noise scheduler used to add noise during training.

Returns

Dict

class mmagic.models.editors.fastcomposer.fastcomposer_util.FastComposerTextEncoder(text_model)[source]

Bases: transformers.CLIPPreTrainedModel

TextEncoder for FastComposerModel.

static from_pretrained(model_name_or_path, **kwargs)[source]

Initialize the text encoder from a Stable Diffusion model name or path.

forward(input_ids, image_token_mask=None, object_embeds=None, num_objects=None, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) Union[Tuple, transformers.modeling_outputs.BaseModelOutputWithPooling][source]

Forward function.

Parameters
  • input_ids (torch.Tensor) – Token ids of the input prompt.

  • image_token_mask (torch.Tensor) – Boolean mask marking the image-token positions in the prompt.

  • object_embeds (torch.Tensor) – Object embeddings to fuse into the text embeddings.

  • num_objects (torch.Tensor) – Number of objects in each sample.

  • attention_mask (torch.Tensor) – Attention mask. Defaults to None.

  • output_attentions (bool) – Defaults to None.

  • output_hidden_states (bool) – Defaults to None.

  • return_dict (bool) – Defaults to None.

Returns

Union[Tuple, BaseModelOutputWithPooling]

class mmagic.models.editors.fastcomposer.fastcomposer_util.FastComposerCLIPImageEncoder(vision_model, visual_projection, vision_processor)[source]

Bases: transformers.CLIPPreTrainedModel

CLIPImageEncoder for FastComposerModel.

static from_pretrained(global_model_name_or_path)[source]

Initialize the CLIP image encoder from a CLIP model name or path.

forward(object_pixel_values)[source]

Forward function.

Parameters

object_pixel_values (torch.Tensor) – Pixel values of the object reference images.

Returns

The object image embeddings.

Return type

torch.Tensor
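
A hedged usage sketch: the CLIP checkpoint name and the 5-D (batch, num_objects, C, H, W) input layout are assumptions based on the class purpose, not taken from the source.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import FastComposerCLIPImageEncoder

    # Hypothetical CLIP checkpoint name; any CLIP vision checkpoint fits the documented signature.
    image_encoder = FastComposerCLIPImageEncoder.from_pretrained('openai/clip-vit-large-patch14')

    # Dummy reference-object crops; the (batch, num_objects, C, H, W) layout is an assumption.
    object_pixel_values = torch.randn(1, 2, 3, 224, 224)
    object_embeds = image_encoder(object_pixel_values)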

mmagic.models.editors.fastcomposer.fastcomposer_util.get_object_transforms(cfg)[source]

Get the object image transforms.

class mmagic.models.editors.fastcomposer.fastcomposer_util.FastComposerPostfuseModule(embed_dim)[source]

Bases: torch.nn.Module

Postfuse Module for FastComposerModel.

fuse_fn(text_embeds, object_embeds)[source]

Fuse function.

Parameters
  • text_embeds (torch.Tensor) – Text embeddings.

  • object_embeds (torch.Tensor) – Object embeddings.

Returns

The fused embeddings.

Return type

torch.Tensor

forward(text_embeds, object_embeds, image_token_mask, num_objects) torch.Tensor[source]

Forward function.

Parameters
  • text_embeds (torch.Tensor) – Text embeddings.

  • object_embeds (torch.Tensor) – Object embeddings.

  • image_token_mask (torch.Tensor) – Boolean mask marking the image-token positions.

  • num_objects (torch.Tensor) – Number of objects in each sample.

Returns

The fused text embeddings.

Return type

torch.Tensor
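
A sketch of calling the postfuse module; the embedding width, the tensor layouts, and the convention that the number of True entries in image_token_mask equals the total object count are assumptions for illustration.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import FastComposerPostfuseModule

    postfuse = FastComposerPostfuseModule(embed_dim=768)  # 768 = CLIP ViT-L text width (assumed)

    text_embeds = torch.randn(1, 77, 768)       # prompt token embeddings
    object_embeds = torch.randn(1, 2, 1, 768)   # per-object embeddings; layout assumed
    image_token_mask = torch.zeros(1, 77, dtype=torch.bool)
    image_token_mask[0, 5] = True               # two image-token positions,
    image_token_mask[0, 9] = True               # one per object
    num_objects = torch.tensor([2])

    fused = postfuse(text_embeds, object_embeds, image_token_mask, num_objects)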

mmagic.models.editors.fastcomposer.fastcomposer_util.unet_store_cross_attention_scores(unet, attention_scores, layers=5)[source]

Make the UNet store its cross-attention scores in attention_scores.
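
A hedged wiring sketch based only on the documented signature: the diffusers UNet construction and the idea that attention_scores is a dict populated during the UNet forward pass are assumptions.

    from diffusers import UNet2DConditionModel

    from mmagic.models.editors.fastcomposer.fastcomposer_util import unet_store_cross_attention_scores

    # Hypothetical Stable Diffusion checkpoint path.
    unet = UNet2DConditionModel.from_pretrained('path/to/stable-diffusion', subfolder='unet')

    cross_attention_scores = {}  # assumed to be filled in layer by layer after this call
    unet_store_cross_attention_scores(unet, cross_attention_scores, layers=5)
    # After a denoising forward pass, cross_attention_scores is expected to map the
    # last `layers` cross-attention layers to their score tensors.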

class mmagic.models.editors.fastcomposer.fastcomposer_util.BalancedL1Loss(threshold=1.0, normalize=False)[source]

Bases: torch.nn.Module

BalancedL1Loss for object localization.

forward(object_token_attn_prob, object_segmaps)[source]

Forward function.

Parameters
  • object_token_attn_prob (torch.Tensor) – Cross-attention probabilities of the object tokens.

  • object_segmaps (torch.Tensor) – Object segmentation maps.

Returns

The balanced L1 loss value.

Return type

float
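
A small sketch of calling the loss; the matching (batch, num_object_tokens, num_spatial_positions) layout of the two inputs is an assumption.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import BalancedL1Loss

    loss_fn = BalancedL1Loss(threshold=1.0, normalize=False)

    # Dummy attention probabilities and binary segmentation maps with matching shapes (assumed).
    object_token_attn_prob = torch.rand(1, 2, 64)
    object_segmaps = torch.randint(0, 2, (1, 2, 64)).float()

    loss = loss_fn(object_token_attn_prob, object_segmaps)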

mmagic.models.editors.fastcomposer.fastcomposer_util.get_object_localization_loss(cross_attention_scores, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn)[source]

Average the object localization loss over the cross-attention layers.

mmagic.models.editors.fastcomposer.fastcomposer_util.get_object_localization_loss_for_one_layer(cross_attention_scores, object_segmaps, object_token_idx, object_token_idx_mask, loss_fn)[source]

Get object localization loss for one layer.

class mmagic.models.editors.fastcomposer.fastcomposer_util.RandomZoomIn(min_zoom=1.0, max_zoom=1.5)[source]

Bases: torch.nn.Module

Random zoom-in transform for object images.

forward(image: torch.Tensor)[source]

Forward function.

Parameters

image (torch.Tensor) – Input image tensor.

Returns

The zoomed-in image.

Return type

torch.Tensor
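
A minimal usage sketch; the (C, H, W) image layout is an assumption.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import RandomZoomIn

    zoom = RandomZoomIn(min_zoom=1.0, max_zoom=1.5)
    image = torch.rand(3, 256, 256)  # (C, H, W) layout assumed
    zoomed = zoom(image)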

class mmagic.models.editors.fastcomposer.fastcomposer_util.PadToSquare(fill=0, padding_mode='constant')[source]

Bases: torch.nn.Module

If the height of the image is greater than the width, padding will be added on both sides of the image to make it a square.

forward(image: torch.Tensor)[source]

Forward function.

Parameters

image (torch.Tensor) – Input image tensor.

Returns

The padded, square image.

Return type

torch.Tensor
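
A sketch that illustrates the documented behaviour on a tall image; the (C, H, W) layout and the exact output size are assumptions.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import PadToSquare

    pad = PadToSquare(fill=0, padding_mode='constant')
    tall = torch.rand(3, 300, 200)  # taller than wide; (C, H, W) layout assumed
    square = pad(tall)
    print(square.shape)  # expected to be (3, 300, 300): 50 px of padding on each side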

class mmagic.models.editors.fastcomposer.fastcomposer_util.CropTopSquare[source]

Bases: torch.nn.Module

If the height of the image is greater than the width, the image will be cropped into a square starting from the top of the image.

forward(image: torch.Tensor)[source]

Forward function.

Parameters

image (torch.Tensor) – Input image tensor.

Returns

The cropped, square image.

Return type

torch.Tensor
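
A matching sketch for the crop variant; again the (C, H, W) layout and exact output size are assumptions.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import CropTopSquare

    crop = CropTopSquare()
    tall = torch.rand(3, 300, 200)  # taller than wide; (C, H, W) layout assumed
    square = crop(tall)
    print(square.shape)  # expected to be (3, 200, 200), taken from the top of the image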

class mmagic.models.editors.fastcomposer.fastcomposer_util.MLP(in_dim, out_dim, hidden_dim, use_residual=True)[source]

Bases: torch.nn.Module

Multilayer Perceptron.

forward(x)[source]

Forward function.

Parameters

x (torch.Tensor) – Input tensor.

Returns

The output tensor.

Return type

torch.Tensor
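
A brief usage sketch; picking in_dim == out_dim when use_residual=True is an assumption (a residual add normally requires matching widths).

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import MLP

    mlp = MLP(in_dim=768, out_dim=768, hidden_dim=768, use_residual=True)
    x = torch.randn(4, 768)
    y = mlp(x)  # same shape as x when in_dim == out_dim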

mmagic.models.editors.fastcomposer.fastcomposer_util.fuse_object_embeddings(inputs_embeds, image_token_mask, object_embeds, num_objects, fuse_fn=torch.add)[source]

Fuse object embeddings into the text embeddings at the image-token positions.
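
The call pattern mirrors the postfuse example above, but with the documented default fuse_fn=torch.add; the tensor layouts and the one-mask-entry-per-object convention remain assumptions.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import fuse_object_embeddings

    inputs_embeds = torch.randn(1, 77, 768)
    object_embeds = torch.randn(1, 2, 1, 768)   # layout assumed
    image_token_mask = torch.zeros(1, 77, dtype=torch.bool)
    image_token_mask[0, 5] = True
    image_token_mask[0, 9] = True
    num_objects = torch.tensor([2])

    # The default fusion simply adds each object embedding onto its image-token embedding.
    fused = fuse_object_embeddings(inputs_embeds, image_token_mask, object_embeds, num_objects, fuse_fn=torch.add)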

mmagic.models.editors.fastcomposer.fastcomposer_util.build_causal_attention_mask(bsz, seq_len, dtype, device=None)[source]

The function originally belonged to CLIPTextTransformer, but it has been removed in versions of transformers after 4.25.1.
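
A small sketch; the (bsz, 1, seq_len, seq_len) output shape is what the removed CLIPTextTransformer helper produced and is assumed to carry over here.

    import torch

    from mmagic.models.editors.fastcomposer.fastcomposer_util import build_causal_attention_mask

    mask = build_causal_attention_mask(bsz=2, seq_len=77, dtype=torch.float32)
    print(mask.shape)  # expected: (2, 1, 77, 77), with large negative values above the diagonal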
