
mmagic.models.editors.ddpm.unet_blocks

Module Contents

Classes

UNetMidBlock2DCrossAttn

UNet mid block built from ResNet and cross-attention layers.

CrossAttnDownBlock2D

Down block built from ResNet and cross-attention layers.

DownBlock2D

Down block built from ResNet layers.

CrossAttnUpBlock2D

Up block built from ResNet and cross-attention layers.

UpBlock2D

Up block built from ResNet layers.

Functions

get_down_block(down_block_type, num_layers, ...[, ...])

Get a UNet down-path block.

get_up_block(up_block_type, num_layers, in_channels, ...)

Get a UNet up-path block.

mmagic.models.editors.ddpm.unet_blocks.get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_act_fn, attn_num_head_channels, resnet_eps=1e-05, resnet_groups=32, cross_attention_dim=1280, downsample_padding=1, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Get a UNet down-path block of the type named by down_block_type.
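get_down_block acts as a factory keyed on a type name. Below is a minimal pure-Python sketch of that dispatch pattern. It assumes the accepted names are the down-block class names listed on this page. DownBlockStub, CrossAttnDownBlockStub, and get_down_block_sketch are hypothetical placeholders, not mmagic code, and only a subset of the real parameters is modeled.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the real torch.nn.Module block classes.
@dataclass
class DownBlockStub:
    in_channels: int
    out_channels: int
    num_layers: int
    add_downsample: bool


@dataclass
class CrossAttnDownBlockStub(DownBlockStub):
    cross_attention_dim: int = 1280


def get_down_block_sketch(down_block_type, num_layers, in_channels,
                          out_channels, add_downsample,
                          cross_attention_dim=1280):
    """Map a block-type name to a block instance (illustrative only)."""
    if down_block_type == 'DownBlock2D':
        return DownBlockStub(in_channels, out_channels, num_layers,
                             add_downsample)
    if down_block_type == 'CrossAttnDownBlock2D':
        return CrossAttnDownBlockStub(in_channels, out_channels, num_layers,
                                      add_downsample, cross_attention_dim)
    # Unknown names are rejected rather than silently defaulted.
    raise ValueError(f'{down_block_type} does not exist.')


block = get_down_block_sketch('CrossAttnDownBlock2D', num_layers=2,
                              in_channels=320, out_channels=320,
                              add_downsample=True)
print(type(block).__name__, block.cross_attention_dim)
```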

mmagic.models.editors.ddpm.unet_blocks.get_up_block(up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_act_fn, attn_num_head_channels, resnet_eps=1e-05, resnet_groups=32, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Get a UNet up-path block of the type named by up_block_type.
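get_up_block takes both in_channels and prev_output_channel. The sketch below shows one way a UNet could compute these for each up block. It assumes a diffusers-style layout in which the up path walks the down path's block widths in reverse. up_block_channels and the example width list are hypothetical and not taken from mmagic.

```python
def up_block_channels(block_out_channels):
    """Return (in_channels, prev_output_channel, out_channels) per up block.

    Assumes a diffusers-style UNet: the up path visits the down-path
    widths in reverse, and each up block's in_channels is the width of
    the next-shallower down block, whose skip state it receives last.
    """
    reversed_channels = list(reversed(block_out_channels))
    # The first up block receives the mid block's output.
    output_channel = reversed_channels[0]
    specs = []
    for i, out_ch in enumerate(reversed_channels):
        prev_output_channel = output_channel
        output_channel = out_ch
        in_ch = reversed_channels[min(i + 1, len(reversed_channels) - 1)]
        specs.append((in_ch, prev_output_channel, output_channel))
    return specs


print(up_block_channels([320, 640, 1280, 1280]))
```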

class mmagic.models.editors.ddpm.unet_blocks.UNetMidBlock2DCrossAttn(in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, attention_type='default', output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False)

Bases: torch.nn.Module

UNet mid block built from ResNet and cross-attention layers.

set_attention_slice(slice_size)

Set the attention slice size (slice_size) used by the block's attention layers.
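In diffusers-style code, attention slicing computes attention for a few heads (or batch entries) at a time instead of all at once. This trades extra sequential steps for fewer score matrices held in memory at the same time. The sketch below illustrates that idea in pure Python and checks that slicing leaves the result unchanged. It is an assumption about what slice_size controls, not mmagic's implementation, and every function in it is hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_scores(q, k):
    """Scaled dot-product scores q @ k^T / sqrt(d) for one head."""
    scale = 1.0 / math.sqrt(len(q[0]))
    return [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
            for qi in q]


def apply_scores(scores, v):
    """Softmax each score row and take the weighted sum of the rows of v."""
    out = []
    for row in scores:
        weights = softmax(row)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


def sliced_attention(qs, ks, vs, slice_size=None):
    """Attention over a list of heads, processed slice_size heads at a time.

    Returns the per-head outputs and the largest number of score matrices
    that existed at the same time. slice_size=None means one chunk.
    """
    n_heads = len(qs)
    step = n_heads if slice_size is None else slice_size
    outputs, peak = [], 0
    for start in range(0, n_heads, step):
        heads = range(start, min(start + step, n_heads))
        # Materialize the full score matrix of every head in this chunk.
        chunk_scores = [attention_scores(qs[h], ks[h]) for h in heads]
        peak = max(peak, len(chunk_scores))
        outputs.extend(apply_scores(s, vs[h])
                       for s, h in zip(chunk_scores, heads))
        # Release this chunk's scores before building the next one.
        del chunk_scores
    return outputs, peak


# Four toy heads, each with 2 query/key/value vectors of width 2.
qs = [[[h + 1.0, 0.5], [0.0, h - 1.0]] for h in range(4)]
ks = [[[1.0, h * 0.5], [h - 2.0, 1.0]] for h in range(4)]
vs = [[[h * 1.0, 1.0], [2.0, -h * 1.0]] for h in range(4)]

full, peak_full = sliced_attention(qs, ks, vs)
sliced, peak_sliced = sliced_attention(qs, ks, vs, slice_size=2)
print(peak_full, peak_sliced)
```

Slicing changes only how many score matrices are materialized together (peak_full versus peak_sliced), not the per-head outputs.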

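forward(hidden_states, temb=None, encoder_hidden_states=None)

Run the forward pass on hidden_states, optionally conditioned on the timestep embedding temb and, through cross-attention, on encoder_hidden_states.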
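The sketch below traces the call order in forward with toy callables. It assumes the mid block follows a diffusers-style layout: one leading ResNet, then num_layers pairs of a cross-attention layer and a ResNet. Integers stand in for tensors. make_resnet, make_attn, and mid_block_forward are hypothetical names used only for this sketch.

```python
def make_resnet(idx, trace):
    """Toy ResNet stand-in: records its call and adds temb to the input."""
    def resnet(h, temb):
        trace.append(f'resnet{idx}')
        return h + (temb or 0)
    return resnet


def make_attn(idx, trace):
    """Toy cross-attention stand-in: records its call and adds the context."""
    def attn(h, encoder_hidden_states):
        trace.append(f'attn{idx}')
        return h + (encoder_hidden_states or 0)
    return attn


def mid_block_forward(h, temb, encoder_hidden_states, resnets, attentions):
    """Assumed order: resnet0, then (attn_i, resnet_{i+1}) for each layer."""
    h = resnets[0](h, temb)
    for attn, resnet in zip(attentions, resnets[1:]):
        h = attn(h, encoder_hidden_states)
        h = resnet(h, temb)
    return h


num_layers = 2
trace = []
resnets = [make_resnet(i, trace) for i in range(num_layers + 1)]
attentions = [make_attn(i, trace) for i in range(num_layers)]
out = mid_block_forward(0, temb=1, encoder_hidden_states=10,
                        resnets=resnets, attentions=attentions)
print(trace)  # ['resnet0', 'attn0', 'resnet1', 'attn1', 'resnet2']
print(out)  # 23
```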

class mmagic.models.editors.ddpm.unet_blocks.CrossAttnDownBlock2D(in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, attention_type='default', output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Bases: torch.nn.Module

Down block built from ResNet and cross-attention layers.

set_attention_slice(slice_size)

Set the attention slice size (slice_size) used by the block's attention layers.

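forward(hidden_states, temb=None, encoder_hidden_states=None)

Run the forward pass on hidden_states, optionally conditioned on the timestep embedding temb and, through cross-attention, on encoder_hidden_states.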
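A down block is expected to return its output together with the intermediate states that the up path later consumes as skip connections. The sketch below traces (C, H, W) shapes under diffusers-style assumptions: one saved state per ResNet (plus cross-attention) layer, and one more after the downsampler when add_downsample=True. down_block_shapes is a hypothetical helper, not part of mmagic.

```python
def down_block_shapes(in_channels, out_channels, height, width,
                      num_layers=1, add_downsample=True):
    """Trace (C, H, W) through a diffusers-style down block.

    Assumptions: the first ResNet maps in_channels -> out_channels, later
    ResNets and any attention layers keep the shape, and the downsampler
    halves even H and W. Returns the block output shape and the tuple of
    states saved for the up path's skip connections.
    """
    shape = (in_channels, height, width)
    output_states = []
    for _ in range(num_layers):
        # ResNet (followed by cross-attention in CrossAttnDownBlock2D).
        shape = (out_channels, shape[1], shape[2])
        output_states.append(shape)
    if add_downsample:
        shape = (out_channels, shape[1] // 2, shape[2] // 2)
        output_states.append(shape)
    return shape, tuple(output_states)


print(down_block_shapes(320, 640, 32, 32, num_layers=2))
```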

class mmagic.models.editors.ddpm.unet_blocks.DownBlock2D(in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1)

Bases: torch.nn.Module

Down block built from ResNet layers.

forward(hidden_states, temb=None)

Run the forward pass on hidden_states, optionally conditioned on the timestep embedding temb.
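downsample_padding sets the padding of the downsampling layer. Assume that layer is a 3x3 convolution with stride 2 (this page does not say so). Standard convolution arithmetic, out = floor((n + 2p - k) / s) + 1, then gives the spatial size after each down block. conv_out_size is a hypothetical helper.

```python
def conv_out_size(size, kernel=3, stride=2, padding=1):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed 3x3, stride-2 downsampler with the default downsample_padding=1.
size = 64
sizes = [size]
for _ in range(3):
    size = conv_out_size(size)
    sizes.append(size)
print(sizes)  # [64, 32, 16, 8]
print(conv_out_size(63))  # odd sizes round up: 32
```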

class mmagic.models.editors.ddpm.unet_blocks.CrossAttnUpBlock2D(in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, attention_type='default', output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False)

Bases: torch.nn.Module

Up block built from ResNet and cross-attention layers.

set_attention_slice(slice_size)

Set the attention slice size (slice_size) used by the block's attention layers.

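forward(hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None)

Run the forward pass on hidden_states together with the skip-connection states in res_hidden_states_tuple, optionally conditioned on the timestep embedding temb and, through cross-attention, on encoder_hidden_states.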
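An up block concatenates each skip-connection state with its running hidden states along the channel axis. Each ResNet's input width therefore depends on prev_output_channel, in_channels, and out_channels. The sketch below computes those widths, assuming the diffusers convention in which only the last ResNet receives a skip of width in_channels. up_resnet_in_channels and the example widths are hypothetical.

```python
def up_resnet_in_channels(in_channels, prev_output_channel, out_channels,
                          num_layers):
    """Input width of each ResNet in a diffusers-style up block."""
    widths = []
    for i in range(num_layers):
        # The last ResNet consumes the skip from the shallower down block.
        skip = in_channels if i == num_layers - 1 else out_channels
        # The first ResNet receives the previous up block's output.
        running = prev_output_channel if i == 0 else out_channels
        widths.append(running + skip)
    return widths


print(up_resnet_in_channels(640, 1280, 1280, num_layers=3))
```

With these example widths, the first two ResNets see 2560 input channels and the last one sees 1920.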

class mmagic.models.editors.ddpm.unet_blocks.UpBlock2D(in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-05, resnet_time_scale_shift: str = 'default', resnet_act_fn: str = 'swish', resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True)

Bases: torch.nn.Module

Up block built from ResNet layers.

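forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None)

Run the forward pass on hidden_states together with the skip-connection states in res_hidden_states_tuple, optionally conditioned on the timestep embedding temb.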
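res_hidden_states_tuple is consumed from the end, so the skip state saved last on the way down is used first on the way up. The sketch below shows that last-in, first-out pattern with integers standing in for tensors. It assumes one skip state is popped per ResNet, as in diffusers-style up blocks. up_block_trace is hypothetical.

```python
def up_block_trace(hidden_channels, res_hidden_states_tuple, out_channels,
                   num_layers):
    """Trace channel widths as an up block consumes its skip connections.

    Integers stand in for tensors: each entry of res_hidden_states_tuple is
    the channel count of one saved skip state, in the order it was saved.
    """
    steps = []
    for _ in range(num_layers):
        # Take the most recently saved skip state (last in, first out).
        skip = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        # Channel-wise concatenation adds the two widths.
        concat = hidden_channels + skip
        # The ResNet maps the concatenation back to out_channels.
        hidden_channels = out_channels
        steps.append((skip, concat))
    return hidden_channels, res_hidden_states_tuple, steps


# Skips saved on the way down, oldest first.
skips = (320, 640, 640)
print(up_block_trace(1280, skips, out_channels=640, num_layers=3))
```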
