Source code for mmagic.models.editors.animatediff.resnet_3d
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/huggingface/diffusers/blob/main/
# src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
    """A 2D convolution applied independently to each frame of a 5D input.

    The input is laid out as ``(b, c, f, h, w)``. The ``f`` axis is folded
    into the batch axis, the parent ``nn.Conv2d`` is run on the resulting 4D
    tensor, and the ``f`` axis is then restored.
    """

    def forward(self, x):
        """Run the 2D convolution on every frame.

        Args:
            x (torch.Tensor): input laid out as ``(b, c, f, h, w)``.

        Returns:
            torch.Tensor: convolved tensor laid out as ``(b, c, f, h, w)``.
        """
        num_frames = x.shape[2]
        # Fold frames into the batch axis so the 2D kernel sees 4D input.
        frames = rearrange(x, 'b c f h w -> (b f) c h w')
        frames = super().forward(frames)
        # Split the merged axis back into separate batch and frame axes.
        return rearrange(frames, '(b f) c h w -> b c f h w', f=num_frames)
class InflatedGroupNorm(nn.GroupNorm):
    """A group norm applied independently to each frame of a 5D input.

    The input is laid out as ``(b, c, f, h, w)``. The ``f`` axis is folded
    into the batch axis before calling the parent ``nn.GroupNorm`` and is
    restored afterwards.
    """

    def forward(self, x):
        """Normalise every frame separately.

        Args:
            x (torch.Tensor): input laid out as ``(b, c, f, h, w)``.

        Returns:
            torch.Tensor: normalised tensor laid out as ``(b, c, f, h, w)``.
        """
        num_frames = x.shape[2]
        # Fold frames into the batch axis so each frame is normalised alone.
        frames = rearrange(x, 'b c f h w -> (b f) c h w')
        frames = super().forward(frames)
        # Split the merged axis back into separate batch and frame axes.
        return rearrange(frames, '(b f) c h w -> b c f h w', f=num_frames)
class Upsample3D(nn.Module):
    """A 3D upsampling layer with an optional convolution.

    When no explicit output size is requested, the last two dimensions of
    the 5D input are doubled with nearest-neighbour interpolation while the
    third dimension keeps its size.

    Args:
        channels (int): channels in the inputs and outputs.
        use_conv (bool): a bool determining if a convolution is applied.
        use_conv_transpose (bool): whether to use conv transpose. Not
            implemented; ``True`` raises ``NotImplementedError``.
        out_channels (int): output channels. Falls back to ``channels``
            when falsy.
        name (str): stored on the module as ``self.name``.
    """

    def __init__(self,
                 channels,
                 use_conv=False,
                 use_conv_transpose=False,
                 out_channels=None,
                 name='conv'):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            raise NotImplementedError

        # ``conv`` is always defined so that ``forward`` can skip it when the
        # layer was built without a convolution (the default).
        if use_conv:
            self.conv = InflatedConv3d(
                self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, hidden_states, output_size=None):
        """Upsample ``hidden_states`` and optionally convolve the result.

        Args:
            hidden_states (torch.Tensor): 5D input whose dim 1 must equal
                ``self.channels``.
            output_size (optional): when given, it is passed to
                ``F.interpolate`` as ``size`` and the fixed scale factor is
                not used.

        Returns:
            torch.Tensor: the upsampled (and, if configured, convolved)
            tensor.

        Raises:
            NotImplementedError: if ``use_conv_transpose`` is set.
        """
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not
        # support bfloat16.
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes.
        # see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # If `output_size` is passed we force the interpolation output size
        # and do not make use of the fixed scale factor.
        if output_size is None:
            hidden_states = F.interpolate(
                hidden_states, scale_factor=[1.0, 2.0, 2.0], mode='nearest')
        else:
            hidden_states = F.interpolate(
                hidden_states, size=output_size, mode='nearest')

        # If the input was bfloat16, cast back to bfloat16.
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # The convolution is optional: without this guard a layer built with
        # the default ``use_conv=False`` failed with AttributeError here.
        if self.conv is not None:
            hidden_states = self.conv(hidden_states)

        return hidden_states
class Downsample3D(nn.Module):
    """A 3D downsampling layer built around a stride-2 convolution.

    Only the convolutional variant is implemented: constructing the layer
    with ``use_conv=False`` raises ``NotImplementedError``.

    Args:
        channels (int): channels in the inputs and outputs.
        use_conv (bool): a bool determining if a convolution is applied.
        out_channels (int): output channels. Falls back to ``channels``
            when falsy.
        padding (int): padding passed to the convolution. ``0`` is accepted
            by the constructor but rejected in ``forward``.
        name (str): stored on the module as ``self.name``.
    """

    def __init__(self,
                 channels,
                 use_conv=False,
                 out_channels=None,
                 padding=1,
                 name='conv'):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        # Guard clause: there is no non-convolutional downsampling path.
        if not use_conv:
            raise NotImplementedError

        self.conv = InflatedConv3d(
            self.channels, self.out_channels, 3, stride=2, padding=padding)

    def forward(self, hidden_states):
        """Downsample ``hidden_states`` with the strided convolution.

        Args:
            hidden_states (torch.Tensor): input whose dim 1 must equal
                ``self.channels``.

        Returns:
            torch.Tensor: output of ``self.conv``.

        Raises:
            NotImplementedError: if the layer uses a convolution with
                ``padding == 0``.
        """
        assert hidden_states.shape[1] == self.channels

        # The zero-padding variant is not implemented.
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        return self.conv(hidden_states)
class ResnetBlock3D(nn.Module):
    """3D resnet block support down sample and up sample.

    The block computes ``norm1 -> act -> conv1``, optionally injects a
    projected time embedding, then ``norm2 -> act -> dropout -> conv2`` and
    finally adds the (optionally 1x1-convolved) input as a residual.

    Args:
        in_channels (int): input channels.
        out_channels (int): output channels. Defaults to ``in_channels``.
        conv_shortcut (bool): whether to use conv shortcut. Only stored as
            ``self.use_conv_shortcut``.
        dropout (float): dropout rate.
        temb_channels (int): time embedding channels. ``None`` disables the
            time embedding projection.
        groups (int): number of groups of the first group norm.
        groups_out (int): number of groups of the second group norm.
            Defaults to ``groups``.
        pre_norm (bool): whether to norm before conv. Ignored: the block
            always stores ``True``. Todo: remove.
        eps (float): eps for groupnorm.
        non_linearity (str): non linearity type, one of ``'swish'``,
            ``'mish'`` or ``'silu'``.
        time_embedding_norm (str): time embedding norm type, ``'default'``
            or ``'scale_shift'``.
        output_scale_factor (float): the block output is divided by this
            value.
        use_in_shortcut (bool): whether to use conv in shortcut. Defaults to
            ``in_channels != out_channels``.
        use_inflated_groupnorm (bool): use ``InflatedGroupNorm`` instead of
            ``torch.nn.GroupNorm`` for both norms.

    Raises:
        ValueError: if ``non_linearity`` is not supported, or if
            ``time_embedding_norm`` is unknown while ``temb_channels`` is
            not ``None``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity='swish',
        time_embedding_norm='default',
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=None,
    ):
        super().__init__()

        # Fail fast: an unsupported name used to leave ``self.nonlinearity``
        # unset and only surfaced as an AttributeError in ``forward``.
        if non_linearity not in ('swish', 'mish', 'silu'):
            raise ValueError(f'unknown non_linearity : {non_linearity} ')

        # ``pre_norm`` is accepted for interface compatibility only.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = self._build_norm(groups, in_channels, eps,
                                      use_inflated_groupnorm)
        self.conv1 = InflatedConv3d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == 'default':
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == 'scale_shift':
                # One half is used as scale, the other half as shift.
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f'unknown time_embedding_norm : '
                                 f'{self.time_embedding_norm} ')
            self.time_emb_proj = torch.nn.Linear(temb_channels,
                                                 time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = self._build_norm(groups_out, out_channels, eps,
                                      use_inflated_groupnorm)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == 'swish':
            # A module-level function instead of a lambda keeps the block
            # picklable.
            self.nonlinearity = F.silu
        elif non_linearity == 'mish':
            # NOTE(review): ``Mish`` is not defined in the visible part of
            # this file - presumably declared elsewhere in the module.
            self.nonlinearity = Mish()
        else:  # 'silu'
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels \
            if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _build_norm(num_groups, num_channels, eps, inflated):
        """Create the affine group norm, inflated or plain."""
        norm_cls = InflatedGroupNorm if inflated else torch.nn.GroupNorm
        return norm_cls(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            affine=True)

    def forward(self, input_tensor, temb):
        """Forward with hidden states and time embeddings.

        Args:
            input_tensor (torch.Tensor): 5D input features.
            temb (torch.Tensor | None): time embedding. When not ``None`` it
                is activated, projected by ``self.time_emb_proj`` and
                broadcast over the last three dimensions.

        Returns:
            torch.Tensor: ``(shortcut + residual) / output_scale_factor``.
        """
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))
            # Append three singleton dims so it broadcasts over the 5D
            # hidden states.
            temb = temb[:, :, None, None, None]
            if self.time_embedding_norm == 'default':
                hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == 'scale_shift':
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Project the input when a shortcut convolution was configured.
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor