Source code for mmagic.models.editors.stable_diffusion.stable_diffusion_inpaint
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright 2023 The HuggingFace Team. All rights reserved.
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import MMLogger
from mmengine.runner import set_random_seed
from PIL import Image
from tqdm.auto import tqdm
from mmagic.registry import MODELS
from mmagic.utils.typing import SampleList
from .stable_diffusion import StableDiffusion
@MODELS.register_module('sd-inpaint')
@MODELS.register_module()
class StableDiffusionInpaint(StableDiffusion):
def __init__(self, *args, **kwargs):
"""Initializes the current class using the same parameters as its
parent, StableDiffusion.
This constructor is primarily a pass-through to the parent class's
constructor. All arguments and keyword arguments provided are directly
passed to the parent class, StableDiffusion.
"""
super().__init__(*args, **kwargs)
@torch.no_grad()
    def infer(self,
prompt: Union[str, List[str]],
              image: Optional[Union[torch.FloatTensor, Image.Image]] = None,
              mask_image: Optional[Union[torch.FloatTensor,
                                         Image.Image]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
show_progress=True,
seed=1,
return_type='image'):
"""Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`Union[torch.FloatTensor, Image.Image]`):
The image to inpaint.
mask_image (`Union[torch.FloatTensor, Image.Image]`):
The mask to apply to the image, i.e. regions to inpaint.
height (`int`, *optional*,
defaults to self.unet_sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*,
defaults to self.unet_sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
More denoising steps usually lead to a higher
quality image at the expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in
[Classifier-Free Diffusion Guidance]
(https://arxiv.org/abs/2207.12598).
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation.
Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper:
https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
                A `torch.Generator` to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents,
sampled from a Gaussian distribution,
to be used as inputs for image generation.
Can be used to tweak the same generation
with different prompts.
If not provided, a latents tensor will be
generated by sampling using the supplied random `generator`.
            show_progress (bool): Whether to show a progress bar during
                the denoising loop. Defaults to True.
            seed (int): Random seed passed to `set_random_seed` before
                sampling. Defaults to 1.
            return_type (str): The return type of the inference results.
Supported types are 'image', 'numpy', 'tensor'. If 'image'
is passed, a list of PIL images will be returned. If 'numpy'
is passed, a numpy array with shape [N, C, H, W] will be
                returned, and the value range will be the same as the
                decoder's output range. If 'tensor' is passed, the decoder's
                output will be returned. Defaults to 'image'.
Returns:
dict: A dict containing the generated images.
"""
assert return_type in ['image', 'tensor', 'numpy']
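        # fix all random seeds so repeated calls with the same arguments
        # produce the same sample; `seed` defaults to 1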
set_random_seed(seed=seed)
# 0. Default height and width to unet
height = height or self.unet_sample_size * self.vae_scale_factor
width = width or self.unet_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self.device
img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \
else self.vae.dtype
latent_dtype = next(self.unet.parameters()).dtype
        # here `guidance_scale` is defined analogously to the
# guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
# `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(prompt, device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt)
# 4. Prepare timesteps
self.test_scheduler.set_timesteps(num_inference_steps)
timesteps = self.test_scheduler.timesteps
# 5. Prepare mask and image
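        # `mask` comes back binarized to {0, 1}; `masked_image` keeps the
        # pixels where mask < 0.5 and zeroes out the regions to inpaint
        # (see `prepare_mask_and_masked_image` at the bottom of this file)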
mask, masked_image = prepare_mask_and_masked_image(
image, mask_image, height, width)
# 6. Prepare latent variables
        if hasattr(self.vae, 'module'):
            num_channels_latents = self.vae.module.latent_channels
        else:
            num_channels_latents = self.vae.latent_channels
        if hasattr(self.unet, 'module'):
            num_channels_unet = self.unet.module.in_channels
        else:
            num_channels_unet = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 7. Prepare masked image latents
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
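            # for this setup the 9 input channels are the concatenation of
            # 4 latent channels + 1 mask channel + 4 masked-image latent
            # channels (see the `torch.cat` along dim=1 in the loop below)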
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
total_channels = num_channels_latents + \
num_channels_masked_image + num_channels_mask
            if total_channels != num_channels_unet:
                raise ValueError(
                    'Incorrect configuration settings! '
                    f'`pipeline.unet` expects {num_channels_unet} input '
                    'channels but received '
                    f'`num_channels_latents`: {num_channels_latents} + '
                    f'`num_channels_mask`: {num_channels_mask} + '
                    '`num_channels_masked_image`: '
                    f'{num_channels_masked_image} = {total_channels}. '
                    'Please verify the config of `pipeline.unet` '
                    'or your `mask_image` or `image` input.')
        elif num_channels_unet != 4:
            raise ValueError(
                f'The unet {self.unet.__class__} should have either 4 or 9 '
                f'input channels, not {num_channels_unet}.')
# 9. Prepare extra step kwargs.
# TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Denoising loop
if show_progress:
timesteps = tqdm(timesteps)
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat(
[latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.test_scheduler.scale_model_input(
latent_model_input, t)
# concat latents with mask
if num_channels_unet == 9:
latent_model_input = torch.cat(
[latent_model_input, mask, masked_image_latents], dim=1)
latent_model_input = latent_model_input.to(latent_dtype)
text_embeddings = text_embeddings.to(latent_dtype)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t,
encoder_hidden_states=text_embeddings)['sample']
# perform guidance
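            # combine the unconditional and text-conditioned predictions:
            # eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)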
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.test_scheduler.step(
noise_pred, t, latents, **extra_step_kwargs)['prev_sample']
            if num_channels_unet == 4:
                raise NotImplementedError(
                    'Inpainting with a 4-channel (text-to-image) unet is '
                    'not implemented in this pipeline.')
        # 11. Post-processing
image = self.decode_latents(latents.to(img_dtype))
if return_type == 'image':
image = self.output_to_pil(image)
elif return_type == 'numpy':
image = image.cpu().numpy()
else:
assert return_type == 'tensor', (
'Only support \'image\', \'numpy\' and \'tensor\' for '
f'return_type, but receive {return_type}')
return {'samples': image}
    def prepare_mask_latents(self, mask, masked_image, batch_size,
num_channels_latents, height, width, dtype,
device, generator, do_classifier_free_guidance):
"""prepare latents for diffusion to run in latent space.
Args:
mask (torch.Tensor): The mask to apply to the image, i.e. regions
to inpaint.
image (torch.Tensor): The image to be masked.
batch_size (int): batch size.
num_channels_latents (int): latent channel nums.
height (int): image height.
width (int): image width.
dtype (torch.dtype): float type.
device (torch.device): torch device.
generator (torch.Generator):
generator for random functions, defaults to None.
latents (torch.Tensor):
Pre-generated noisy latents, defaults to None.
do_classifier_free_guidance (bool): Whether to apply
classifier-free guidance.
Return:
latents (torch.Tensor): prepared latents.
"""
shape = (batch_size, num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor)
mask = F.interpolate(
mask, size=shape[2:]).to(
device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self.vae.encode(
masked_image).latent_dist.sample(generator)
masked_image_latents = self.vae.config.scaling_factor * \
masked_image_latents
# duplicate mask and masked_image_latents for each generation per
# prompt, using mps friendly method
if mask.shape[0] < batch_size:
            if batch_size % mask.shape[0] != 0:
                raise ValueError(
                    "The passed mask and the required batch size don't "
                    'match. Masks are supposed to be duplicated to a total '
                    f'batch size of {batch_size}, but {mask.shape[0]} masks '
                    'were passed. Make sure the number of masks that you '
                    'pass is divisible by the total requested batch size.')
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
            if batch_size % masked_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't "
                    'match. Images are supposed to be duplicated to a total '
                    f'batch size of {batch_size}, but '
                    f'{masked_image_latents.shape[0]} images were passed. '
                    'Make sure the number of images that you pass is '
                    'divisible by the total requested batch size.')
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2)
if do_classifier_free_guidance else masked_image_latents)
# aligning device to prevent device errors when concatenating it with
# the latent model input
masked_image_latents = masked_image_latents.to(
device=device, dtype=dtype)
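        # illustrative shapes, assuming the standard Stable Diffusion setup
        # (vae_scale_factor == 8, 4 latent channels): for batch_size == 2,
        # height == width == 512 and classifier-free guidance enabled,
        # `mask` is (4, 1, 64, 64) and `masked_image_latents` is
        # (4, 4, 64, 64)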
return mask, masked_image_latents
@torch.no_grad()
    def val_step(self, data: dict) -> SampleList:
"""Performs a validation step on the provided data.
This method is decorated with `torch.no_grad()` which indicates no
gradients will be computed during the operations. This ensures
efficient memory usage during testing.
Args:
data (dict): Dictionary containing input data for testing.
Returns:
SampleList: List of samples processed during the testing step.
Raises:
NotImplementedError: This method has not been implemented.
"""
raise NotImplementedError
@torch.no_grad()
    def test_step(self, data: dict) -> SampleList:
"""Performs a testing step on the provided data.
        This method is decorated with `torch.no_grad()`, which means no
        gradients will be computed during the operations. This ensures
        efficient memory usage during testing.
Args:
data (dict): Dictionary containing input data for testing.
Returns:
SampleList: List of samples processed during the testing step.
Raises:
NotImplementedError: This method has not been implemented.
"""
raise NotImplementedError
    def train_step(self, data, optim_wrapper_dict):
"""Performs a training step on the provided data.
Args:
data: Input data for training.
optim_wrapper_dict: Dictionary containing optimizer wrappers
which may contain optimizers, schedulers, etc. required
for the training step.
Raises:
NotImplementedError: This method has not been implemented.
"""
raise NotImplementedError
def prepare_mask_and_masked_image(image: torch.Tensor,
mask: torch.Tensor,
height: int = 512,
width: int = 512,
return_image: bool = False):
"""Prepare latents for diffusion to run in latent space.
Args:
image (torch.Tensor): The image to be masked.
mask (torch.Tensor): The mask to apply to the image, i.e. regions
to inpaint.
height (int): Image height.
width (int): Image width.
return_image (bool): Whether to return the original image.
Default to `False`.
Returns:
mask (torch.Tensor): A binary mask image.
masked_image (torch.Tensor): An image that applied mask.
"""
if image is None:
raise ValueError('`image` input cannot be undefined.')
if mask is None:
raise ValueError('`mask_image` input cannot be undefined.')
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
            raise TypeError('`image` is a torch.Tensor but `mask` (type: '
                            f'{type(mask)}) is not')
# Batch single image
if image.ndim == 3:
assert image.shape[
0] == 3, 'Image outside a batch should be of shape (3, H, W)'
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask
# not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert (image.ndim == 4
and mask.ndim == 4), 'Image and Mask must have 4 dimensions'
assert image.shape[-2:] == mask.shape[
-2:], 'Image and Mask must have the same spatial dimensions'
assert image.shape[0] == mask.shape[
0], 'Image and Mask must have the same batch size'
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError('Image should be in [-1, 1] range')
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError('Mask should be in [0, 1] range')
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
        raise TypeError('`mask` is a torch.Tensor but `image` (type: '
                        f'{type(image)}) is not')
else:
# preprocess image
if isinstance(image, (Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], Image.Image):
            # resize all images w.r.t. the passed height and width
image = [
i.resize((width, height), resample=Image.LANCZOS)
for i in image
]
image = [np.array(i.convert('RGB'))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
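        # map uint8 pixel values from [0, 255] into the [-1, 1] range the
        # VAE encoder expects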
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
mask = [
i.resize((width, height), resample=Image.LANCZOS) for i in mask
]
mask = np.concatenate(
[np.array(m.convert('L'))[None, None, :] for m in mask],
axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
    # n.b. kept for backwards compatibility: the old version of this
    # function did not return the image
if return_image:
return mask, masked_image, image
return mask, masked_image
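

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): how the pipeline
# might be driven once a model has been built. The config dict passed to
# `MODELS.build` is elided because its exact fields depend on the checkpoint
# and are assumptions outside this file.
#
#     from mmagic.registry import MODELS
#     from PIL import Image
#
#     model = MODELS.build(dict(type='sd-inpaint', ...))  # fill in real cfg
#     image = Image.open('photo.png')   # RGB image to edit
#     mask = Image.open('mask.png')     # white (>= 0.5) = region to inpaint
#     result = model.infer(
#         'a vase of flowers on the table',
#         image=image,
#         mask_image=mask,
#         num_inference_steps=50,
#         guidance_scale=7.5)
#     result['samples'][0].save('inpainted.png')
# ---------------------------------------------------------------------------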