mmagic.models.editors.disco_diffusion.guider 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import math
import lpips
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from mmengine.utils import digit_version
from resize_right import resize
from torchvision import __version__ as TORCHVISION_VERSION
from mmagic.models.losses import tv_loss
from mmagic.utils import try_import
from .secondary_model import alpha_sigma_to_t
[文档]normalize = T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
[文档]def sinc(x):
"""
Sinc function.
If x equal to 0,
sinc(x) = 1
else:
sinc(x) = sin(x)/ x
Args:
x (torch.Tensor): Input Tensor
Returns:
torch.Tensor: Function output.
"""
return torch.where(x != 0,
torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
[文档]def lanczos(x, a):
"""Lanczos filter's reconstruction kernel L(x)."""
cond = torch.logical_and(-a < x, x < a)
out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
return out / out.sum()
[文档]def ramp(ratio, width):
"""_summary_
Args:
ratio (_type_): _description_
width (_type_): _description_
Returns:
_type_: _description_
"""
n = math.ceil(width / ratio + 1)
out = torch.empty([n])
cur = 0
for i in range(out.shape[0]):
out[i] = cur
cur += ratio
return torch.cat([-out[1:].flip([0]), out])[1:-1]
[文档]def resample(input, size, align_corners=True):
"""Lanczos resampling image.
Args:
input (torch.Tensor): Input image tensor.
size (Tuple[int, int]): Output image size.
align_corners (bool): align_corners argument of F.interpolate.
Defaults to True.
Returns:
torch.Tensor: Resampling results.
"""
n, c, h, w = input.shape
dh, dw = size
input = input.reshape([n * c, 1, h, w])
if dh < h:
kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
pad_h = (kernel_h.shape[0] - 1) // 2
input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
input = F.conv2d(input, kernel_h[None, None, :, None])
if dw < w:
kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
pad_w = (kernel_w.shape[0] - 1) // 2
input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
input = F.conv2d(input, kernel_w[None, None, None, :])
input = input.reshape([n, c, h, w])
return F.interpolate(
input, size, mode='bicubic', align_corners=align_corners)
[文档]def range_loss(input):
"""range loss."""
return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
[文档]def spherical_dist_loss(x, y):
"""spherical distance loss."""
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
[文档]class MakeCutouts(nn.Module):
"""Each iteration, the AI cuts the image into smaller pieces known as cuts.
, and compares each cut to the prompt to decide how to guide the next
diffusion step.
This classes will randomly cut patches and perform image augmentation to
these patches.
Args:
cut_size (int): Size of the patches.
cutn (int): Number of patches to cut.
"""
def __init__(self, cut_size, cutn):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.augs = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomPerspective(distortion_scale=0.4, p=0.7),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.15),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
])
[文档] def forward(self, input, skip_augs=False):
input = T.Pad(input.shape[2] // 4, fill=0)(input)
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
cutouts = []
for ch in range(self.cutn):
if ch > self.cutn - self.cutn // 4:
cutout = input.clone()
else:
size = int(max_size * torch.zeros(1, ).normal_(
mean=.8, std=.3).clip(float(self.cut_size / max_size), 1.))
offsetx = torch.randint(0, abs(sideX - size + 1), ())
offsety = torch.randint(0, abs(sideY - size + 1), ())
cutout = input[:, :, offsety:offsety + size,
offsetx:offsetx + size]
if not skip_augs:
cutout = self.augs(cutout)
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
del cutout
cutouts = torch.cat(cutouts, dim=0)
return cutouts
[文档]class MakeCutoutsDango(nn.Module):
"""Dango233(https://github.com/Dango233)'s version of MakeCutouts.
The improvement compared to ``MakeCutouts`` is that it use partial
greyscale augmentation to capture structure, and partial rotation
augmentation to capture whole frames.
Args:
cut_size (int): Size of the patches.
Overview (int): The total number of overview cuts.
In details,
Overview=1, Add whole frame;
Overview=2, Add grayscaled frame;
Overview=3, Add horizontal flip frame;
Overview=4, Add grayscaled horizontal flip frame;
Overview>4, Repeat add frame Overview times.
Defaults to 4.
InnerCrop (int): The total number of inner cuts.
Defaults to 0.
IC_Size_Pow (float): This sets the size of the border
used for inner cuts. High values have larger borders,
and therefore the cuts themselves will be smaller and
provide finer details. Defaults to 0.5.
IC_Grey_P (float): The portion of the inner cuts can be set to be
grayscale instead of color. This may help with improved
definition of shapes and edges, especially in the early
diffusion steps where the image structure is being defined.
Defaults to 0.2.
"""
def __init__(self,
cut_size,
Overview=4,
InnerCrop=0,
IC_Size_Pow=0.5,
IC_Grey_P=0.2):
super().__init__()
self.cut_size = cut_size
self.Overview = Overview
self.InnerCrop = InnerCrop
self.IC_Size_Pow = IC_Size_Pow
self.IC_Grey_P = IC_Grey_P
random_affine_args = dict(degrees=10, translate=(0.05, 0.05))
if digit_version(TORCHVISION_VERSION) >= digit_version('0.9.0'):
random_affine_args['interpolation'] = T.InterpolationMode.BILINEAR
else:
from PIL import Image
random_affine_args['resample'] = Image.NEAREST
self.augs = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(**random_affine_args),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.1),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.ColorJitter(
brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
])
[文档] def forward(self, input, skip_augs=False):
"""Forward function."""
cutouts = []
gray = T.Grayscale(3)
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
output_shape = [1, 3, self.cut_size, self.cut_size]
pad_input = F.pad(input,
((sideY - max_size) // 2, (sideY - max_size) // 2,
(sideX - max_size) // 2, (sideX - max_size) // 2))
cutout = resize(pad_input, out_shape=output_shape)
if self.Overview > 0:
if self.Overview <= 4:
if self.Overview >= 1:
cutouts.append(cutout)
if self.Overview >= 2:
cutouts.append(gray(cutout))
if self.Overview >= 3:
cutouts.append(TF.hflip(cutout))
if self.Overview == 4:
cutouts.append(gray(TF.hflip(cutout)))
else:
cutout = resize(pad_input, out_shape=output_shape)
for _ in range(self.Overview):
cutouts.append(cutout)
if self.InnerCrop > 0:
for i in range(self.InnerCrop):
size = int(
torch.rand([])**self.IC_Size_Pow * (max_size - min_size) +
min_size)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety:offsety + size,
offsetx:offsetx + size]
if i <= int(self.IC_Grey_P * self.InnerCrop):
cutout = gray(cutout)
cutout = resize(cutout, out_shape=output_shape)
cutouts.append(cutout)
cutouts = torch.cat(cutouts)
if not skip_augs:
cutouts = self.augs(cutouts)
return cutouts
[文档]def parse_prompt(prompt):
"""Parse prompt, return text and text weight."""
if prompt.startswith('http://') or prompt.startswith('https://'):
vals = prompt.rsplit(':', 2)
vals = [vals[0] + ':' + vals[1], *vals[2:]]
else:
vals = prompt.rsplit(':', 1)
vals = vals + ['', '1'][len(vals):]
return vals[0], float(vals[1])
[文档]def split_prompts(prompts, max_frames=1):
"""Split prompts to a list of prompts."""
prompt_series = pd.Series([np.nan for a in range(max_frames)])
for i, prompt in prompts.items():
prompt_series[i] = prompt
# prompt_series = prompt_series.astype(str)
prompt_series = prompt_series.ffill().bfill()
return prompt_series
[文档]class ImageTextGuider(nn.Module):
"""Disco-Diffusion uses text and images to guide image generation. We will
use the clip models to extract text and image features as prompts, and then
during the iteration, the features of the image patches are computed, and
the similarity loss between the prompts features and the generated features
is computed. Other losses also include RGB Range loss, total variation
loss. Using these losses we can guide the image generation towards the
desired target.
Args:
clip_models (List[Dict]): List of clip model settings.
"""
def __init__(self, clip_models):
super().__init__()
assert clip is not None, (
"Cannot import 'clip'. Please install 'clip' via "
"\"pip install git+https://github.com/openai/CLIP.git\".")
self.clip_models = clip_models
self.lpips_model = lpips.LPIPS(net='vgg')
[文档] def frame_prompt_from_text(self, text_prompts, frame_num=0):
"""Get current frame prompt."""
prompts_series = split_prompts(text_prompts)
if prompts_series is not None and frame_num >= len(prompts_series):
frame_prompt = prompts_series[-1]
elif prompts_series is not None:
frame_prompt = prompts_series[frame_num]
else:
frame_prompt = []
return frame_prompt
[文档] def compute_prompt_stats(self,
text_prompts=[],
image_prompt=None,
fuzzy_prompt=False,
rand_mag=0.05):
"""Compute prompts statistics.
Args:
text_prompts (list): Text prompts. Defaults to [].
image_prompt (list): Image prompts. Defaults to None.
fuzzy_prompt (bool, optional): Controls whether to add multiple
noisy prompts to the prompt losses. If True, can increase
variability of image output. Defaults to False.
rand_mag (float, optional): Controls the magnitude of the
random noise added by fuzzy_prompt. Defaults to 0.05.
"""
model_stats = []
frame_prompt = self.frame_prompt_from_text(text_prompts)
for clip_model in self.clip_models:
model_stat = {
'clip_model': None,
'target_embeds': [],
'make_cutouts': None,
'weights': []
}
model_stat['clip_model'] = clip_model
for prompt in frame_prompt:
txt, weight = parse_prompt(prompt)
txt = clip_model.model.encode_text(
clip.tokenize(prompt).to(self.device)).float()
if fuzzy_prompt:
for i in range(25):
model_stat['target_embeds'].append(
(txt +
torch.randn(txt.shape).cuda() * rand_mag).clamp(
0, 1))
model_stat['weights'].append(weight)
else:
model_stat['target_embeds'].append(txt)
model_stat['weights'].append(weight)
model_stat['target_embeds'] = torch.cat(
model_stat['target_embeds'])
model_stat['weights'] = torch.tensor(
model_stat['weights'], device=self.device)
if model_stat['weights'].sum().abs() < 1e-3:
raise RuntimeError('The weights must not sum to 0.')
model_stat['weights'] /= model_stat['weights'].sum().abs()
model_stats.append(model_stat)
return model_stats
[文档] def cond_fn(self,
model,
diffusion_scheduler,
x,
t,
beta_prod_t,
model_stats,
secondary_model=None,
init_image=None,
clamp_grad=True,
clamp_max=0.05,
clip_guidance_scale=5000,
init_scale=1000,
tv_scale=0.,
sat_scale=0.,
range_scale=150,
cut_overview=[12] * 400 + [4] * 600,
cut_innercut=[4] * 400 + [12] * 600,
cut_ic_pow=[1] * 1000,
cut_icgray_p=[0.2] * 400 + [0] * 600,
cutn_batches=4):
"""Clip guidance function.
Args:
model (nn.Module): _description_
diffusion_scheduler (object): _description_
x (torch.Tensor): _description_
t (int): _description_
beta_prod_t (torch.Tensor): _description_
model_stats (List[torch.Tensor]): _description_
secondary_model (nn.Module): A smaller secondary diffusion model
trained by Katherine Crowson to remove noise from intermediate
timesteps to prepare them for CLIP.
Ref: https://twitter.com/rivershavewings/status/1462859669454536711 # noqa
Defaults to None.
init_image (torch.Tensor): Initial image for denoising.
Defaults to None.
clamp_grad (bool, optional): Whether clamp gradient. Defaults to True.
clamp_max (float, optional): Clamp max values. Defaults to 0.05.
clip_guidance_scale (int, optional): The scale of influence of
clip guidance on image generation. Defaults to 5000.
"""
with torch.enable_grad():
x_is_NaN = False
x = x.detach().requires_grad_()
n = x.shape[0]
if secondary_model is not None:
alpha = torch.tensor(
diffusion_scheduler.alphas_cumprod[t]**0.5,
dtype=torch.float32)
sigma = torch.tensor(
(1 - diffusion_scheduler.alphas_cumprod[t])**0.5,
dtype=torch.float32)
cosine_t = alpha_sigma_to_t(alpha, sigma).to(x.device)
model_output = secondary_model(
x, cosine_t[None].repeat([x.shape[0]]))
pred_original_sample = model_output['pred']
else:
model_output = model(x, t)['sample']
model_output, predicted_variance = torch.split(
model_output, x.shape[1], dim=1)
alpha_prod_t = 1 - beta_prod_t
pred_original_sample = (x - beta_prod_t**(0.5) *
model_output) / alpha_prod_t**(0.5)
# fac = diffusion_scheduler_output['beta_prod_t']** (0.5)
# x_in = diffusion_scheduler_output['original_sample'] * fac + x * (1 - fac) # noqa
fac = beta_prod_t**(0.5)
x_in = pred_original_sample * fac + x * (1 - fac)
x_in_grad = torch.zeros_like(x_in)
for model_stat in model_stats:
for i in range(cutn_batches):
t_int = int(t.item()) + 1
try:
input_resolution = model_stat[
'clip_model'].model.visual.input_resolution
except AttributeError:
input_resolution = 224
cuts = MakeCutoutsDango(
input_resolution,
Overview=cut_overview[1000 - t_int],
InnerCrop=cut_innercut[1000 - t_int],
IC_Size_Pow=cut_ic_pow[1000 - t_int],
IC_Grey_P=cut_icgray_p[1000 - t_int])
clip_in = normalize(cuts(x_in.add(1).div(2)))
image_embeds = model_stat['clip_model'].model.encode_image(
clip_in).float()
dists = spherical_dist_loss(
image_embeds.unsqueeze(1),
model_stat['target_embeds'].unsqueeze(0))
dists = dists.view([
cut_overview[1000 - t_int] +
cut_innercut[1000 - t_int], n, -1
])
losses = dists.mul(model_stat['weights']).sum(2).mean(0)
x_in_grad += torch.autograd.grad(
losses.sum() * clip_guidance_scale,
x_in)[0] / cutn_batches
tv_losses = tv_loss(x_in)
range_losses = range_loss(pred_original_sample)
sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()
loss = tv_losses.sum() * tv_scale + range_losses.sum(
) * range_scale + sat_losses.sum() * sat_scale
if init_image is not None and init_scale:
init_losses = self.lpips_model(x_in, init_image)
loss = loss + init_losses.sum() * init_scale
x_in_grad += torch.autograd.grad(loss, x_in)[0]
if not torch.isnan(x_in_grad).any():
grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
else:
x_is_NaN = True
grad = torch.zeros_like(x)
if clamp_grad and not x_is_NaN:
magnitude = grad.square().mean().sqrt()
return grad * magnitude.clamp(max=clamp_max) / magnitude
return grad
@property
[文档] def device(self):
"""Get current device of the model.
Returns:
torch.device: The current device of the model.
"""
return next(self.parameters()).device
[文档] def forward(self, x):
"""forward function."""
raise NotImplementedError('No forward function for disco guider')