mmagic.models.editors.ddpm.attention

Module Contents

Classes

  • Transformer2DModel – Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.

  • BasicTransformerBlock – A basic Transformer block.

  • CrossAttention – A cross attention layer.

  • FeedForward – A feed-forward layer.

  • GEGLU – A variant of the gated linear unit activation function.

  • ApproximateGELU – The approximate form of the Gaussian Error Linear Unit (GELU).

class mmagic.models.editors.ddpm.attention.Transformer2DModel(num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, activation_fn: str = 'geglu', use_linear_projection: bool = False, only_cross_attention: bool = False)[source]

Bases: torch.nn.Module

Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.

When the input is continuous: first, project the input (aka embedding) and reshape it to (b, t, d). Then apply standard transformer action. Finally, reshape back to an image.

When the input is discrete: first, the input (classes of latent pixels) is converted to embeddings and positional embeddings are applied; see ImagePositionalEmbeddings. Then apply standard transformer action. Finally, predict classes of the unnoised image.

Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image do not contain a prediction for the masked pixel, as the unnoised image cannot be masked.

Parameters
  • num_attention_heads (int, optional, defaults to 16) – The number of heads to use for multi-head attention.

  • attention_head_dim (int, optional, defaults to 88) – The number of channels in each head.

  • in_channels (int, optional) – Pass if the input is continuous. The number of channels in the input and output.

  • num_layers (int, optional, defaults to 1) – The number of layers of Transformer blocks to use.

  • dropout (float, optional, defaults to 0.0) – The dropout probability to use.

  • norm_num_groups (int) – Norm group num, defaults to 32.

  • cross_attention_dim (int, optional) – The number of context dimensions to use.

  • attention_bias (bool, optional) – Configure if the TransformerBlocks’ attention should contain a bias parameter.

  • sample_size (int, optional) – Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See ImagePositionalEmbeddings.

  • num_vector_embeds (int, optional) – Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel.

  • activation_fn (str, optional, defaults to “geglu”) – Activation function to be used in feed-forward.

  • use_linear_projection (bool) – Whether to use linear projection, defaults to False.

  • only_cross_attention (bool) – Whether to use only cross attention, defaults to False.

_set_attention_slice(slice_size)[source]

set attention slice.

forward(hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True)[source]

forward function.

Parameters
  • hidden_states – When discrete, torch.LongTensor of shape (batch size, num latent pixels). When continuous, torch.FloatTensor of shape (batch size, channel, height, width). Input hidden_states.

  • encoder_hidden_states (torch.LongTensor of shape (batch size, context dim), optional) – Conditional embeddings for the cross attention layer. If not given, cross-attention defaults to self-attention.

  • timestep (torch.long, optional) – Optional timestep to be applied as an embedding in AdaLayerNorm’s. Used to indicate denoising step.

  • return_dict (bool, optional, defaults to True) – Whether or not to return a [models.unet_2d_condition.UNet2DConditionOutput] instead of a plain tuple.

Returns

Dict if return_dict is True, otherwise a tuple. When returning a tuple, the first element is the sample tensor.
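
Example – a minimal usage sketch for the continuous (image-like) path, not taken from the source. The channel and head sizes are arbitrary illustrative values (in_channels is kept divisible by the default norm_num_groups of 32), and return_dict=False is used so that, as documented above, the first tuple element is the sample tensor.

    import torch
    from mmagic.models.editors.ddpm.attention import Transformer2DModel

    # Continuous input: pass in_channels; 64 is divisible by the default
    # norm_num_groups of 32. Inner dim = num_attention_heads * attention_head_dim.
    model = Transformer2DModel(
        num_attention_heads=8,
        attention_head_dim=8,
        in_channels=64,
        num_layers=1,
    )

    # (batch, channel, height, width) feature map.
    hidden_states = torch.randn(2, 64, 16, 16)

    # With return_dict=False, the first element of the returned tuple is the
    # sample tensor, reshaped back to an image of the input's spatial size.
    sample = model(hidden_states, return_dict=False)[0]
    print(sample.shape)  # torch.Size([2, 64, 16, 16])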

class mmagic.models.editors.ddpm.attention.BasicTransformerBlock(dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = 'geglu', attention_bias: bool = False, only_cross_attention: bool = False)[source]

Bases: torch.nn.Module

A basic Transformer block.

Parameters
  • dim (int) – The number of channels in the input and output.

  • num_attention_heads (int) – The number of heads to use for multi-head attention.

  • attention_head_dim (int) – The number of channels in each head.

  • dropout (float, optional, defaults to 0.0) – The dropout probability to use.

  • cross_attention_dim (int, optional) – The size of the context vector for cross attention.

  • activation_fn (str, optional, defaults to “geglu”) – Activation function to be used in feed-forward.

  • attention_bias (bool, optional, defaults to False) – Configure if the attentions should contain a bias parameter.

  • only_cross_attention (bool, defaults to False) – Whether to use cross attention only.

_set_attention_slice(slice_size)[source]

set attention slice.

forward(hidden_states, context=None, timestep=None)[source]

forward with hidden states, context and timestep.
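
Example – a minimal sketch of running the block on a flattened token sequence, not taken from the source; the shapes are illustrative, and with context left as None the cross-attention layer falls back to self-attention, as described for Transformer2DModel above.

    import torch
    from mmagic.models.editors.ddpm.attention import BasicTransformerBlock

    # dim is typically num_attention_heads * attention_head_dim.
    block = BasicTransformerBlock(
        dim=64, num_attention_heads=8, attention_head_dim=8)

    # (batch, sequence length, dim) tokens.
    tokens = torch.randn(2, 256, 64)
    out = block(tokens)   # context=None and timestep=None are the defaults
    print(out.shape)      # torch.Size([2, 256, 64])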

class mmagic.models.editors.ddpm.attention.CrossAttention(query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias=False)[source]

Bases: torch.nn.Module

A cross attention layer.

Parameters
  • query_dim (int) – The number of channels in the query.

  • cross_attention_dim (int, optional) – The number of channels in the context. If not given, defaults to query_dim.

  • heads (int, optional, defaults to 8) – The number of heads to use for multi-head attention.

  • dim_head (int, optional, defaults to 64) – The number of channels in each head.

  • dropout (float, optional, defaults to 0.0) – The dropout probability to use.

  • bias (bool, optional, defaults to False) – Set to True for the query, key, and value linear layers to contain a bias parameter.

reshape_heads_to_batch_dim(tensor)[source]

reshape heads num to batch dim.

reshape_batch_dim_to_heads(tensor)[source]

reshape batch dim to heads num.

forward(hidden_states, context=None, mask=None)[source]

forward with hidden states, context and mask.

_attention(query, key, value)[source]

attention calculation.

_sliced_attention(query, key, value, sequence_length, dim)[source]

sliced attention calculation.
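
Example – a minimal sketch of attending from a query sequence to a conditioning context, not taken from the source. The sequence lengths are arbitrary, and the assumption that the output is projected back to query_dim (so the query shape is preserved) reflects the layer's usual design rather than anything stated above.

    import torch
    from mmagic.models.editors.ddpm.attention import CrossAttention

    attn = CrossAttention(
        query_dim=64, cross_attention_dim=128, heads=8, dim_head=8)

    query_tokens = torch.randn(2, 256, 64)  # (batch, query length, query_dim)
    context = torch.randn(2, 77, 128)       # (batch, context length, cross_attention_dim)

    # Passing a context performs cross attention conditioned on it.
    out = attn(query_tokens, context=context)
    print(out.shape)  # torch.Size([2, 256, 64])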

class mmagic.models.editors.ddpm.attention.FeedForward(dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = 'geglu')[source]

Bases: torch.nn.Module

A feed-forward layer.

Parameters
  • dim (int) – The number of channels in the input.

  • dim_out (int, optional) – The number of channels in the output. If not given, defaults to dim.

  • mult (int, optional, defaults to 4) – The multiplier to use for the hidden dimension.

  • dropout (float, optional, defaults to 0.0) – The dropout probability to use.

  • activation_fn (str, optional, defaults to “geglu”) – Activation function to be used in feed-forward.

forward(hidden_states)[source]

forward with hidden states.
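
Example – a minimal sketch, not taken from the source; dim_out is left unset, so per the parameter description above the output keeps dim channels, while mult only affects the hidden width.

    import torch
    from mmagic.models.editors.ddpm.attention import FeedForward

    ff = FeedForward(dim=64, mult=4, activation_fn='geglu')

    tokens = torch.randn(2, 256, 64)  # (batch, sequence length, dim)
    out = ff(tokens)                  # dim_out defaults to dim, so the shape is kept
    print(out.shape)                  # torch.Size([2, 256, 64])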

class mmagic.models.editors.ddpm.attention.GEGLU(dim_in: int, dim_out: int)[source]

Bases: torch.nn.Module

A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

Parameters
  • dim_in (int) – The number of channels in the input.

  • dim_out (int) – The number of channels in the output.

gelu(gate)[source]

gelu activation.

forward(hidden_states)[source]

forward with hidden states.
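
Example – a minimal sketch, not taken from the source. Per the paper cited above, GEGLU gates one linear projection by the GELU of another, roughly (x W) * gelu(x V); the only shape fact relied on here is that the output has dim_out channels, as the parameter description states.

    import torch
    from mmagic.models.editors.ddpm.attention import GEGLU

    geglu = GEGLU(dim_in=64, dim_out=256)

    tokens = torch.randn(2, 256, 64)  # (batch, sequence length, dim_in)
    out = geglu(tokens)               # gated output with dim_out channels
    print(out.shape)                  # torch.Size([2, 256, 256])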

class mmagic.models.editors.ddpm.attention.ApproximateGELU(dim_in: int, dim_out: int)[source]

Bases: torch.nn.Module

The approximate form of the Gaussian Error Linear Unit (GELU).

For more details, see section 2 of https://arxiv.org/abs/1606.08415.

forward(x)[source]

forward function.
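
Example – a short sketch, not taken from the source. Section 2 of the cited paper approximates GELU as x * sigmoid(1.702 * x); the assumption that the module also applies a linear projection from dim_in to dim_out before the activation is inferred from its constructor arguments, not something stated above.

    import torch
    from mmagic.models.editors.ddpm.attention import ApproximateGELU

    act = ApproximateGELU(dim_in=64, dim_out=256)

    x = torch.randn(2, 256, 64)
    out = act(x)                      # assumed output shape: (2, 256, 256)

    # The approximation itself, from section 2 of the paper:
    reference = x * torch.sigmoid(1.702 * x)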
