
Source code for mmagic.models.editors.stable_diffusion.stable_diffusion_inpaint

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright 2023 The HuggingFace Team. All rights reserved.
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import MMLogger
from mmengine.runner import set_random_seed
from PIL import Image
from tqdm.auto import tqdm

from mmagic.registry import MODELS
from mmagic.utils.typing import SampleList
from .stable_diffusion import StableDiffusion

logger = MMLogger.get_current_instance()

ModelType = Union[Dict, nn.Module]

@MODELS.register_module('sd-inpaint')
@MODELS.register_module()
class StableDiffusionInpaint(StableDiffusion):

    def __init__(self, *args, **kwargs):
        """Initialize the current class with the same parameters as its
        parent, StableDiffusion.

        This constructor is primarily a pass-through to the parent class's
        constructor. All arguments and keyword arguments provided are
        passed directly to the parent class, StableDiffusion.
        """
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def infer(self,
              prompt: Union[str, List[str]],
              image: Union[torch.FloatTensor, Image.Image] = None,
              mask_image: Union[torch.FloatTensor, Image.Image] = None,
              height: Optional[int] = None,
              width: Optional[int] = None,
              num_inference_steps: int = 50,
              guidance_scale: float = 7.5,
              negative_prompt: Optional[Union[str, List[str]]] = None,
              num_images_per_prompt: Optional[int] = 1,
              eta: float = 0.0,
              generator: Optional[torch.Generator] = None,
              latents: Optional[torch.FloatTensor] = None,
              show_progress=True,
              seed=1,
              return_type='image'):
        """Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`Union[torch.FloatTensor, Image.Image]`):
                The image to inpaint.
            mask_image (`Union[torch.FloatTensor, Image.Image]`):
                The mask to apply to the image, i.e. regions to inpaint.
            height (`int`, *optional*,
                defaults to self.unet_sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*,
                defaults to self.unet_sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps
                usually lead to a higher quality image at the expense of
                slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://arxiv.org/abs/2207.12598).
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation.
                Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper:
                https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator] to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian
                distribution, to be used as inputs for image generation.
                Can be used to tweak the same generation with different
                prompts. If not provided, a latents tensor will be
                generated by sampling using the supplied random
                `generator`.
            return_type (str): The return type of the inference results.
                Supported types are 'image', 'numpy', 'tensor'. If 'image'
                is passed, a list of PIL images will be returned. If
                'numpy' is passed, a numpy array with shape [N, C, H, W]
                will be returned, and the value range will be same as
                decoder's output range. If 'tensor' is passed, the
                decoder's output will be returned. Defaults to 'image'.

        Returns:
            dict: A dict containing the generated images.
        """
        assert return_type in ['image', 'tensor', 'numpy']
        set_random_seed(seed=seed)

        # 0. Default height and width to unet
        height = height or self.unet_sample_size * self.vae_scale_factor
        width = width or self.unet_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.device
        img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \
            else self.vae.dtype
        latent_dtype = next(self.unet.parameters()).dtype
        # here `guidance_scale` is defined analogously to the guidance
        # weight `w` of equation (2) of the Imagen paper:
        # https://arxiv.org/pdf/2205.11487.pdf. `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(prompt, device,
                                              num_images_per_prompt,
                                              do_classifier_free_guidance,
                                              negative_prompt)

        # 4. Prepare timesteps
        self.test_scheduler.set_timesteps(num_inference_steps)
        timesteps = self.test_scheduler.timesteps

        # 5. Prepare mask and image
        mask, masked_image = prepare_mask_and_masked_image(
            image, mask_image, height, width)

        # 6. Prepare latent variables
        if hasattr(self.unet, 'module'):
            num_channels_latents = self.vae.module.latent_channels
            num_channels_unet = self.unet.module.in_channels
        else:
            num_channels_latents = self.vae.latent_channels
            num_channels_unet = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare masked image latents
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 8. Check that sizes of mask, masked image and latents match
        if num_channels_unet == 9:
            # default case for runwayml/stable-diffusion-inpainting
            num_channels_mask = mask.shape[1]
            num_channels_masked_image = masked_image_latents.shape[1]
            total_channels = num_channels_latents + \
                num_channels_masked_image + num_channels_mask
            if total_channels != self.unet.in_channels:
                raise ValueError(
                    'Incorrect configuration settings! The config of '
                    f'`pipeline.unet`: {self.unet.config} expects'
                    f' {self.unet.in_channels} but received '
                    f'`num_channels_latents`: {num_channels_latents} +'
                    f' `num_channels_mask`: {num_channels_mask} + '
                    '`num_channels_masked_image`: '
                    f'{num_channels_masked_image} = {total_channels}. '
                    'Please verify the config of `pipeline.unet` '
                    'or your `mask_image` or `image` input.')
        elif num_channels_unet != 4:
            raise ValueError(
                f'The unet {self.unet.__class__} should have either 4 or 9 '
                f'input channels, not {self.unet.config.in_channels}.')

        # 9. Prepare extra step kwargs.
        # TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 10. Denoising loop
        if show_progress:
            timesteps = tqdm(timesteps)
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.test_scheduler.scale_model_input(
                latent_model_input, t)

            # concat latents with mask
            if num_channels_unet == 9:
                latent_model_input = torch.cat(
                    [latent_model_input, mask, masked_image_latents], dim=1)

            latent_model_input = latent_model_input.to(latent_dtype)
            text_embeddings = text_embeddings.to(latent_dtype)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t,
                encoder_hidden_states=text_embeddings)['sample']

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.test_scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

            if num_channels_unet == 4:
                # inpainting with a vanilla 4-channel UNet is not supported
                raise NotImplementedError

        # 11. Post-processing
        image = self.decode_latents(latents.to(img_dtype))
        if return_type == 'image':
            image = self.output_to_pil(image)
        elif return_type == 'numpy':
            image = image.cpu().numpy()
        else:
            assert return_type == 'tensor', (
                'Only support \'image\', \'numpy\' and \'tensor\' for '
                f'return_type, but receive {return_type}')

        return {'samples': image}
    def prepare_mask_latents(self, mask, masked_image, batch_size,
                             num_channels_latents, height, width, dtype,
                             device, generator, do_classifier_free_guidance):
        """Prepare the mask and masked-image latents for diffusion to run
        in latent space.

        Args:
            mask (torch.Tensor): The mask to apply to the image,
                i.e. regions to inpaint.
            masked_image (torch.Tensor): The image with the mask applied.
            batch_size (int): batch size.
            num_channels_latents (int): latent channel nums.
            height (int): image height.
            width (int): image width.
            dtype (torch.dtype): float type.
            device (torch.device): torch device.
            generator (torch.Generator): generator for random functions,
                defaults to None.
            do_classifier_free_guidance (bool): Whether to apply
                classifier-free guidance.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The prepared mask and the
                masked image latents.
        """
        shape = (batch_size, num_channels_latents,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor)

        mask = F.interpolate(
            mask, size=shape[2:]).to(
                device=device, dtype=dtype)
        masked_image = masked_image.to(device=device, dtype=dtype)

        masked_image_latents = self.vae.encode(
            masked_image).latent_dist.sample(generator)
        masked_image_latents = self.vae.config.scaling_factor * \
            masked_image_latents

        # duplicate mask and masked_image_latents for each generation per
        # prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't "
                    'match. Masks are supposed to be duplicated to a total'
                    f' batch size of {batch_size}, but {mask.shape[0]} '
                    'masks were passed. Make sure the number of masks that '
                    'you pass is divisible by the total requested batch '
                    'size.')
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't "
                    'match. Images are supposed to be duplicated to a total'
                    f' batch size of {batch_size}, but '
                    f'{masked_image_latents.shape[0]} images were passed.'
                    ' Make sure the number of images that you pass is '
                    'divisible by the total requested batch size.')
            masked_image_latents = masked_image_latents.repeat(
                batch_size // masked_image_latents.shape[0], 1, 1, 1)

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2)
            if do_classifier_free_guidance else masked_image_latents)

        # aligning device to prevent device errors when concatenating it
        # with the latent model input
        masked_image_latents = masked_image_latents.to(
            device=device, dtype=dtype)
        return mask, masked_image_latents
    @torch.no_grad()
    def val_step(self, data: dict) -> SampleList:
        """Performs a validation step on the provided data.

        This method is decorated with `torch.no_grad()`, so no gradients
        are computed during its operations, ensuring efficient memory
        usage during validation.

        Args:
            data (dict): Dictionary containing input data for validation.

        Returns:
            SampleList: List of samples processed during the validation
                step.

        Raises:
            NotImplementedError: This method has not been implemented.
        """
        raise NotImplementedError
    @torch.no_grad()
    def test_step(self, data: dict) -> SampleList:
        """Performs a testing step on the provided data.

        This method is decorated with `torch.no_grad()`, so no gradients
        are computed during its operations, ensuring efficient memory
        usage during testing.

        Args:
            data (dict): Dictionary containing input data for testing.

        Returns:
            SampleList: List of samples processed during the testing step.

        Raises:
            NotImplementedError: This method has not been implemented.
        """
        raise NotImplementedError
    def train_step(self, data, optim_wrapper_dict):
        """Performs a training step on the provided data.

        Args:
            data: Input data for training.
            optim_wrapper_dict: Dictionary containing optimizer wrappers,
                which may hold optimizers, schedulers, etc. required for
                the training step.

        Raises:
            NotImplementedError: This method has not been implemented.
        """
        raise NotImplementedError

def prepare_mask_and_masked_image(image: torch.Tensor,
                                  mask: torch.Tensor,
                                  height: int = 512,
                                  width: int = 512,
                                  return_image: bool = False):
    """Prepare a binary mask and the corresponding masked image for
    inpainting.

    Args:
        image (torch.Tensor): The image to be masked.
        mask (torch.Tensor): The mask to apply to the image,
            i.e. regions to inpaint.
        height (int): Image height. Defaults to 512.
        width (int): Image width. Defaults to 512.
        return_image (bool): Whether to return the original image.
            Defaults to `False`.

    Returns:
        mask (torch.Tensor): A binary mask image.
        masked_image (torch.Tensor): The image with the mask applied.
    """
    if image is None:
        raise ValueError('`image` input cannot be undefined.')

    if mask is None:
        raise ValueError('`mask_image` input cannot be undefined.')

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError('`image` is a torch.Tensor but `mask` (type: '
                            f'{type(mask)}) is not')

        # Batch single image
        if image.ndim == 3:
            assert image.shape[
                0] == 3, 'Image outside a batch should be of shape (3, H, W)'
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask
            # not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)
            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert (image.ndim == 4 and
                mask.ndim == 4), 'Image and Mask must have 4 dimensions'
        assert image.shape[-2:] == mask.shape[
            -2:], 'Image and Mask must have the same spatial dimensions'
        assert image.shape[0] == mask.shape[
            0], 'Image and Mask must have the same batch size'

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError('Image should be in [-1, 1] range')

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError('Mask should be in [0, 1] range')

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(
            f'`mask` is a torch.Tensor but `image` (type: {type(image)}) '
            'is not')
    else:
        # preprocess image
        if isinstance(image, (Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], Image.Image):
            # resize all images w.r.t passed height and width
            image = [
                i.resize((width, height), resample=Image.LANCZOS)
                for i in image
            ]
            image = [np.array(i.convert('RGB'))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], Image.Image):
            mask = [
                i.resize((width, height), resample=Image.LANCZOS)
                for i in mask
            ]
            mask = np.concatenate(
                [np.array(m.convert('L'))[None, None, :] for m in mask],
                axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)

    # n.b. ensure backwards compatibility as old function does not
    # return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
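
The helper above can be exercised on its own. Below is a minimal, self-contained sketch (the 64x64 synthetic image and centered square mask are made up for illustration) showing the binarized mask and masked image it produces:

import numpy as np
from PIL import Image

from mmagic.models.editors.stable_diffusion.stable_diffusion_inpaint import \
    prepare_mask_and_masked_image

# synthetic 64x64 white image and a centered 32x32 white square mask
image = Image.fromarray(np.full((64, 64, 3), 255, dtype=np.uint8))
mask = Image.fromarray(np.pad(np.full((32, 32), 255, dtype=np.uint8), 16))

mask_t, masked_image = prepare_mask_and_masked_image(
    image, mask, height=64, width=64)
print(mask_t.shape)        # torch.Size([1, 1, 64, 64]), values in {0, 1}
print(masked_image.shape)  # torch.Size([1, 3, 64, 64]), zeroed where mask == 1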
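
For orientation, the channel bookkeeping in steps 6-8 of `infer` (and in `prepare_mask_latents`) reduces to the sketch below; the sizes assume a 512x512 image and the usual VAE scale factor of 8, both illustrative:

import torch

latents = torch.randn(1, 4, 64, 64)               # step 6: sampled noise
mask = torch.zeros(1, 1, 64, 64)                  # step 7: downsampled mask
masked_image_latents = torch.randn(1, 4, 64, 64)  # step 7: encoded masked image

unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
assert unet_input.shape[1] == 9  # 4 + 1 + 4, the inpainting UNet's in_channels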
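
Finally, a hedged end-to-end usage sketch of `StableDiffusionInpaint.infer`. It assumes `model` has already been built and loaded (for example via `MODELS.build` on one of mmagic's Stable Diffusion inpainting configs); the prompt and file names are placeholders:

from PIL import Image

init_image = Image.open('bench.png')       # image to inpaint (placeholder)
mask_image = Image.open('bench_mask.png')  # white regions get repainted

output = model.infer(
    prompt='Face of a yellow cat, high resolution',
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,
    guidance_scale=7.5,
    return_type='image')
output['samples'][0].save('inpainted.png')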