# Copyright (c) OpenMMLab. All rights reserved.
from logging import WARNING
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapperDict
from PIL import Image
from torch import Tensor
from tqdm import tqdm

from mmagic.models.archs import AttentionInjection
from mmagic.models.utils import build_module
from mmagic.registry import MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import SampleList
from ..stable_diffusion import StableDiffusion
from .controlnet_utils import change_base_model

# A model component may be given either as a registry config dict or as an
# already-built module.
ModelType = Union[Dict, nn.Module]
class ControlStableDiffusion(StableDiffusion):
    """Implementation of `ControlNet with Stable Diffusion.

    <https://arxiv.org/abs/2302.05543>`_ (ControlNet).

    Args:
        vae (Union[dict, nn.Module]): The config or module for VAE model.
        text_encoder (Union[dict, nn.Module]): The config or module for text
            encoder.
        tokenizer (str): The **name** for CLIP tokenizer.
        unet (Union[dict, nn.Module]): The config or module for Unet model.
        controlnet (Union[dict, nn.Module]): The config or module for
            ControlNet.
        scheduler (Union[dict, nn.Module]): The config or module for diffusion
            scheduler.
        test_scheduler (Union[dict, nn.Module], optional): The config or
            module for diffusion scheduler in test stage (`self.infer`). If
            not passed, will use the same scheduler as `scheduler`. Defaults
            to None.
        dtype (str, optional): The dtype for the model. Defaults to 'fp32'.
        enable_xformers (bool, optional): Whether to use xformers.
            Defaults to True.
        noise_offset_weight (float, optional): The weight of noise offset
            introduced in
            https://www.crosslabs.org/blog/diffusion-with-offset-noise
            Defaults to 0.
        tomesd_cfg (dict, optional): Forwarded unchanged to
            :class:`StableDiffusion`. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`. Defaults to
            dict(type='DataPreprocessor').
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`. Defaults to None.
        attention_injection (bool, optional): Whether to wrap the Unet with
            :class:`AttentionInjection`. Defaults to False.
    """

    def __init__(self,
                 vae: ModelType,
                 text_encoder: ModelType,
                 tokenizer: str,
                 unet: ModelType,
                 controlnet: ModelType,
                 scheduler: ModelType,
                 test_scheduler: Optional[ModelType] = None,
                 dtype: str = 'fp32',
                 enable_xformers: bool = True,
                 noise_offset_weight: float = 0,
                 tomesd_cfg: Optional[dict] = None,
                 data_preprocessor=dict(type='DataPreprocessor'),
                 init_cfg: Optional[dict] = None,
                 attention_injection=False):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         test_scheduler, dtype, enable_xformers,
                         noise_offset_weight, tomesd_cfg, data_preprocessor,
                         init_cfg)

        # NOTE: the ControlNet keeps the dtype it was built with; remember it
        # so that `train()` can restore it after an eval-mode cast.
        self.controlnet = build_module(controlnet, MODELS)
        self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype
        print_log(
            'Set ControlNetModel dtype to '
            f'\'{self._controlnet_ori_dtype}\'.', 'current')
        self.set_xformers(self.controlnet)

        # Only the ControlNet is optimized; freeze everything else.
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)

        if attention_injection:
            self.unet = AttentionInjection(self.unet)

    @staticmethod
    def _unwrap_module(module: nn.Module) -> nn.Module:
        """Return the inner module when ``module`` is a model wrapper."""
        return module.module if is_model_wrapper(module) else module

    def init_weights(self):
        """Initialize the weights. Noted that this function will only be called
        at train. If you want to inference with a different unet model, you can
        call this function manually or use
        `mmagic.models.editors.controlnet.controlnet_utils.change_base_model`
        to convert the weight of ControlNet manually.

        Example:
        >>> 1. init controlnet from unet
        >>> init_cfg = dict(type='init_from_unet')

        >>> 2. switch controlnet weight from unet
        >>> # base model is not defined, use `runwayml/stable-diffusion-v1-5`
        >>> # as default
        >>> init_cfg = dict(type='convert_from_unet')
        >>> # base model is defined
        >>> init_cfg = dict(
        >>>     type='convert_from_unet',
        >>>     base_model=dict(
        >>>         type='UNet2DConditionModel',
        >>>         from_pretrained='REPO_ID',
        >>>         subfolder='unet'))
        """
        init_type = None
        if self.init_cfg is not None:
            init_type = self.init_cfg.get('type', None)

        if init_type == 'init_from_unet':
            controlnet = self._unwrap_module(self.controlnet)
            unet = self._unwrap_module(self.unet)

            if controlnet._from_pretrained is not None:
                print_log(
                    'ControlNet has initialized from pretrained '
                    f'weight \'{controlnet._from_pretrained}\'.'
                    ' Re-initialize ControlNet from Unet.', 'current',
                    WARNING)

            # Copy every Unet parameter that also exists in the ControlNet.
            # The state dict is fetched once; its tensors share storage with
            # the live parameters, so `copy_` updates the model in place.
            log_template = 'Initialize weight ControlNet from Unet: {}'
            controlnet_state = controlnet.state_dict()
            for n, p in unet.named_parameters():
                if n in controlnet_state:
                    print_log(log_template.format(n), 'current')
                    controlnet_state[n].copy_(p.data)

            # The zero-conv blocks must start at zero; force it if they are
            # not.
            zero_conv_blocks = controlnet.controlnet_down_blocks
            for n, p in zero_conv_blocks.named_parameters():
                if not (p == 0).all():
                    print_log(f'{n} in ControlNet is not initialized with '
                              'zero. Set to zero manually.')
                    p.data.zero_()

        elif init_type == 'convert_from_unet':
            controlnet = self._unwrap_module(self.controlnet)
            unet = self._unwrap_module(self.unet)

            # use sd-v15 as base model by default
            base_model_default_cfg = dict(
                type='UNet2DConditionModel',
                from_pretrained='runwayml/stable-diffusion-v1-5',
                subfolder='unet')
            base_model_cfg = self.init_cfg.get('base_model',
                                               base_model_default_cfg)
            # Build the base model from its config before converting; the
            # config itself is not a module.
            base_model = build_module(base_model_cfg, MODELS)
            change_base_model(controlnet, unet, base_model)

        else:
            assert init_type is None, (
                'Only support \'init_from_unet\', \'convert_from_unet\' or '
                f'None. But receive {init_type}.')

    def train_step(self, data: dict,
                   optim_wrapper: OptimWrapperDict) -> Dict[str, Tensor]:
        """Train step for ControlNet model.

        Args:
            data (dict): Data sampled from dataloader.
            optim_wrapper (OptimWrapperDict): OptimWrapperDict instance that
                contains an OptimWrapper under the key ``'controlnet'``.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
        """
        data = self.data_preprocessor(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        optimizer = optim_wrapper['controlnet']

        with optimizer.optim_context(self.controlnet):
            target = inputs['target']
            control = (inputs['source'] + 1) / 2  # [-1, 1] -> [0, 1]
            prompt = data_samples.prompt

            num_batches = target.shape[0]

            latents = self.vae.encode(target).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                self.scheduler.num_train_timesteps, (num_batches, ),
                device=self.device)
            timesteps = timesteps.long()

            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

            input_ids = self.tokenizer(
                prompt,
                max_length=self.tokenizer.model_max_length,
                return_tensors='pt',
                padding='max_length',
                truncation=True)['input_ids'].to(self.device)

            encoder_hidden_states = self.text_encoder(input_ids)[0]

            prediction_type = self.scheduler.config.prediction_type
            if prediction_type == 'epsilon':
                gt = noise
            elif prediction_type == 'v_prediction':
                gt = self.scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError('Unknown prediction type '
                                 f'{self.scheduler.config.prediction_type}')

            # forward control
            # NOTE: we train controlnet in fp32, convert to float manually
            down_block_res_samples, mid_block_res_sample = self.controlnet(
                noisy_latents.float(),
                timesteps,
                encoder_hidden_states=encoder_hidden_states.float(),
                controlnet_cond=control.float(),
                return_dict=False,
            )

            # Predict the noise residual and compute loss
            # NOTE: we train unet in fp32, convert to float manually
            model_output = self.unet(
                noisy_latents.float(),
                timesteps,
                encoder_hidden_states=encoder_hidden_states.float(),
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample)
            model_pred = model_output['sample']

            loss = F.mse_loss(model_pred.float(), gt.float(), reduction='mean')

            optimizer.update_params(loss)

        return dict(loss=loss)

    def _generate_from_batch(self, data: dict) -> SampleList:
        """Shared implementation of ``val_step`` and ``test_step``.

        Preprocesses the batch, runs ``self.infer`` with the batch prompts
        and control images, and packs the results into data samples.
        """
        data = self.data_preprocessor(data)
        data_samples = data['data_samples']
        prompt = data_samples.prompt
        control = data['inputs']['source']

        # The preprocessed control is in [-1, 1]; `infer` expects [0, 1].
        output = self.infer(
            prompt, control=((control + 1) / 2), return_type='tensor')
        samples = output['samples']

        samples = self.data_preprocessor.destruct(
            samples, data_samples, key='target')
        control = self.data_preprocessor.destruct(
            control, data_samples, key='source')

        data_sample = DataSample(
            fake_img=samples, control=control, prompt=data_samples.prompt)
        return data_sample.split()

    def val_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data. Calls
        ``self.data_preprocessor`` and ``self.infer`` in order. Return the
        generated results which will be passed to evaluator or visualizer.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            SampleList: Generated image or image dict.
        """
        return self._generate_from_batch(data)

    def test_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data. Calls
        ``self.data_preprocessor`` and ``self.infer`` in order. Return the
        generated results which will be passed to evaluator or visualizer.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            SampleList: Generated image or image dict.
        """
        return self._generate_from_batch(data)

    # NOTE: maybe we should do this in a controlnet preprocessor
    @staticmethod
    def prepare_control(image: Union[Image.Image, List[Image.Image], Tensor,
                                     List[Tensor]], width: int, height: int,
                        batch_size: int, num_images_per_prompt: int,
                        device: Union[str, torch.device],
                        dtype: torch.dtype) -> Tensor:
        """A helper function to prepare single control images.

        Args:
            image (Union[Image.Image, List[Image.Image], Tensor, List[Tensor]]):  # noqa
                The input image for control. PIL images are resized to
                ``(width, height)`` and scaled to [0, 1]; tensors are used
                as given (no resizing or rescaling).
            width (int): The target width, only applied to PIL images.
            height (int): The target height, only applied to PIL images.
            batch_size (int): The number of the prompt. The control will
                be repeated for `batch_size` times.
            num_images_per_prompt (int): The number images generate for one
                prompt.
            device (str | torch.device): The device of the control.
            dtype (torch.dtype): The dtype of the control.

        Returns:
            Tensor: The control in torch.tensor.
        """
        if not isinstance(image, torch.Tensor):
            if isinstance(image, Image.Image):
                image = [image]

            if isinstance(image[0], Image.Image):
                resized = [
                    img.resize((width, height), resample=Image.LANCZOS)
                    for img in image
                ]
                # [N, H, W, C] uint8 -> [N, C, H, W] float32 in [0, 1]
                image = np.stack([np.array(img) for img in resized], axis=0)
                image = image.astype(np.float32) / 255.0
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size * num_images_per_prompt
        else:
            assert image_batch_size == batch_size, (
                'The number of Control condition must be 1 or equal to the '
                'number of prompt.')
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)

        return image

    def train(self, mode: bool = True):
        """Set train/eval mode.

        In train mode the ControlNet is cast back to the dtype it was built
        with; in eval mode it is cast to ``self.dtype``.

        Args:
            mode (bool, optional): Whether set train mode. Defaults to True.
        """
        if mode:
            if next(self.controlnet.parameters()
                    ).dtype != self._controlnet_ori_dtype:
                print_log(
                    'Set ControlNetModel dtype to '
                    f'\'{self._controlnet_ori_dtype}\' in the train mode.',
                    'current')
            self.controlnet.to(self._controlnet_ori_dtype)
        else:
            self.controlnet.to(self.dtype)
            print_log(
                f'Set ControlNetModel dtype to \'{self.dtype}\' '
                'in the eval mode.', 'current')
        return super().train(mode)

    @torch.no_grad()
    def infer(self,
              prompt: Union[str, List[str]],
              height: Optional[int] = None,
              width: Optional[int] = None,
              control: Optional[Union[str, np.ndarray, torch.Tensor]] = None,
              controlnet_conditioning_scale: float = 1.0,
              num_inference_steps: int = 20,
              guidance_scale: float = 7.5,
              negative_prompt: Optional[Union[str, List[str]]] = None,
              num_images_per_prompt: Optional[int] = 1,
              eta: float = 0.0,
              generator: Optional[torch.Generator] = None,
              latents: Optional[torch.FloatTensor] = None,
              return_type='image',
              show_progress=True):
        """Function invoked when calling the pipeline for generation.

        Args:
            prompt (str or List[str]): The prompt or prompts to guide
                the image generation.
            height (int, Optional): The height in pixels of the generated
                image. If not passed, the height will be
                `self.unet_sample_size * self.vae_scale_factor`
                Defaults to None.
            width (int, Optional): The width in pixels of the generated image.
                If not passed, the width will be
                `self.unet_sample_size * self.vae_scale_factor`
                Defaults to None.
            control: The control condition, forwarded to
                :meth:`prepare_control`. Defaults to None.
            controlnet_conditioning_scale (float): Multiplier applied to all
                ControlNet residuals before they are added to the Unet.
                Defaults to 1.0.
            num_inference_steps (int): The number of denoising steps.
                More denoising steps usually lead to a higher quality image at
                the expense of slower inference. Defaults to 20.
            guidance_scale (float): Guidance scale as defined in Classifier-
                Free Diffusion Guidance (https://arxiv.org/abs/2207.12598).
                Defaults to 7.5
            negative_prompt (str or List[str], optional): The prompt or
                prompts not to guide the image generation. Ignored when not
                using guidance (i.e., ignored if `guidance_scale` is not
                greater than 1). Defaults to None.
            num_images_per_prompt (int): The number of images to generate
                per prompt. Defaults to 1.
            eta (float): Corresponds to parameter eta (η) in the DDIM paper:
                https://arxiv.org/abs/2010.02502. Only applies to
                DDIMScheduler, will be ignored for others. Defaults to 0.0.
            generator (torch.Generator, optional): A torch generator to make
                generation deterministic. Defaults to None.
            latents (torch.FloatTensor, optional): Pre-generated noisy latents,
                sampled from a Gaussian distribution, to be used as inputs for
                image generation. Can be used to tweak the same generation with
                different prompts. If not provided, a latents tensor will be
                generated by sampling using the supplied random `generator`.
                Defaults to None.
            return_type (str): The return type of the inference results.
                Supported types are 'image', 'numpy', 'tensor'. If 'image'
                is passed, a list of PIL images will be returned. If 'numpy'
                is passed, a numpy array with shape [N, C, H, W] will be
                returned, and the value range will be same as decoder's
                output range. If 'tensor' is passed, the decoder's output
                will be returned. Defaults to 'image'.
            show_progress (bool): Whether to wrap the denoising loop in a
                tqdm progress bar. Defaults to True.

        Returns:
            dict: A dict containing the generated images and Control image.
        """
        assert return_type in ['image', 'tensor', 'numpy']

        # 0. Default height and width to unet
        height = height or self.unet_sample_size * self.vae_scale_factor
        width = width or self.unet_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.device
        # here `guidance_scale` is defined analog to the
        # guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
        # `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \
            else self.vae.dtype
        control_dtype = self._unwrap_module(self.controlnet).dtype

        controls = self.prepare_control(
            control,
            width,
            height,
            batch_size,
            num_images_per_prompt,
            device,
            dtype=control_dtype)
        if do_classifier_free_guidance:
            # One copy for the unconditional and one for the text branch.
            controls = torch.cat([controls] * 2)

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(prompt, device,
                                              num_images_per_prompt,
                                              do_classifier_free_guidance,
                                              negative_prompt)

        # 4. Prepare timesteps
        self.test_scheduler.set_timesteps(num_inference_steps)
        timesteps = self.test_scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        # The embeddings do not change between steps, so cast them once.
        text_embeddings = text_embeddings.to(control_dtype)
        if show_progress:
            timesteps = tqdm(timesteps)
        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.test_scheduler.scale_model_input(
                latent_model_input, t)
            latent_model_input = latent_model_input.to(control_dtype)

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=controls,
                return_dict=False,
            )
            down_block_res_samples = [
                down_block_res_sample * controlnet_conditioning_scale
                for down_block_res_sample in down_block_res_samples
            ]
            mid_block_res_sample *= controlnet_conditioning_scale

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            )['sample']

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.test_scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

        # 8. Post-processing
        image = self.decode_latents(latents.to(img_dtype))

        if do_classifier_free_guidance:
            # Drop the duplicated half that was added for guidance.
            controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]

        if return_type == 'image':
            image = self.output_to_pil(image)
            # `output_to_pil` receives the control mapped from [0, 1] to
            # [-1, 1].
            controls = self.output_to_pil(controls * 2 - 1)
        elif return_type == 'numpy':
            image = image.cpu().numpy()
            controls = controls.cpu().numpy()
        else:
            assert return_type == 'tensor', (
                'Only support \'image\', \'numpy\' and \'tensor\' for '
                f'return_type, but receive {return_type}')

        return {'samples': image, 'controls': controls}

    def forward(self, *args, **kwargs):
        """forward is not implemented now."""
        raise NotImplementedError(
            'Forward is not implemented now, please use infer.')
class ControlStableDiffusionImg2Img(ControlStableDiffusion):
    """Image-to-image variant of :class:`ControlStableDiffusion`.

    Instead of starting from pure noise, ``infer`` encodes ``latent_image``
    with the VAE, noises it according to ``strength`` and denoises from
    there. An optional ``latent_mask`` blends the result with the encoded
    input, and an optional ``reference_img`` is passed to the Unet as
    ``ref_x``.
    """

    def _default_height_width(self, height, width, image):
        """Fill in a missing height/width from the control image.

        Args:
            height (int | None): Requested height, or None to infer it.
            width (int | None): Requested width, or None to infer it.
            image: The control image; a PIL image, a tensor laid out as
                [N, C, H, W], or a list of either (the first item is used).

        Returns:
            Tuple[int, int]: ``(height, width)``. Inferred values are
            rounded down to a multiple of 8; given values are returned
            unchanged.

        Raises:
            TypeError: If a size must be inferred but ``image`` is neither
                a PIL image nor a tensor.
        """
        if isinstance(image, list):
            image = image[0]

        if height is None or width is None:
            if isinstance(image, Image.Image):
                image_height, image_width = image.height, image.width
            elif isinstance(image, torch.Tensor):
                # Control tensors are [N, C, H, W] (see `prepare_control`):
                # height is dim 2 and width is dim 3.
                image_height, image_width = image.shape[2], image.shape[3]
            else:
                raise TypeError(
                    'Cannot infer `height`/`width`: `control` must be a '
                    '`PIL.Image.Image`, a `torch.Tensor` or a list of them, '
                    f'but is {type(image)}.')

            if height is None:
                # round down to nearest multiple of 8
                height = (image_height // 8) * 8
            if width is None:
                # round down to nearest multiple of 8
                width = (image_width // 8) * 8

        return height, width

    def get_timesteps(self, num_inference_steps, strength, device):
        """Select the tail of the test schedule that ``strength`` asks for.

        Args:
            num_inference_steps (int): Number of steps the test scheduler
                was configured with.
            strength (float): Fraction of the schedule to run; the first
                ``num_inference_steps - int(num_inference_steps * strength)``
                timesteps are skipped.
            device: Unused.

        Returns:
            Tuple: The remaining timesteps and how many there are.
        """
        # get the original timestep using init_timestep
        init_timestep = min(
            int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.test_scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self,
                        image,
                        timestep,
                        batch_size,
                        num_images_per_prompt,
                        dtype,
                        device,
                        generator=None,
                        noise=None):
        """Encode ``image`` with the VAE and noise it to ``timestep``.

        Args:
            image (torch.Tensor): The image batch to encode.
            timestep (torch.Tensor): Timesteps passed to
                ``self.scheduler.add_noise``.
            batch_size (int): The number of prompts.
            num_images_per_prompt (int): Images generated per prompt; the
                effective batch size is
                ``batch_size * num_images_per_prompt``.
            dtype (torch.dtype): The dtype of the image and noise.
            device (str | torch.device): The device of the image and noise.
            generator (torch.Generator | List[torch.Generator], optional):
                Generator(s) for sampling. A list must have one entry per
                effective batch item. Defaults to None.
            noise (torch.Tensor, optional): Pre-sampled noise. Sampled here
                when None. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The noised latents and the
            clean (scaled) VAE latents, both with the effective batch size
            when the image batch had to be duplicated.

        Raises:
            ValueError: If ``image`` has an unsupported type, if the
                generator list length does not match the effective batch
                size, or if the image batch cannot be duplicated evenly to
                the effective batch size.
        """
        if not isinstance(image, (torch.Tensor, Image.Image, list)):
            raise ValueError(
                f'`image` has to be of type `torch.Tensor`, '
                f' `PIL.Image.Image` or list but is {type(image)}')

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f'You have passed a list of generators of '
                f' length {len(generator)}, but requested an effective batch'
                f' size of {batch_size}. Make sure the batch size '
                f' matches the length of the generators.')

        # Fewer images than requested samples: each image is duplicated
        # `repeat_by` times, which only works for an even split.
        num_images = image.shape[0]
        repeat_by = 1
        if batch_size > num_images:
            if batch_size % num_images != 0:
                raise ValueError(
                    f'Cannot duplicate `image` of batch size'
                    f' {num_images} to {batch_size} text prompts.')
            repeat_by = batch_size // num_images

        if isinstance(generator, list):
            # One posterior sample per generator; consecutive samples share
            # a source image, matching the interleaved duplication below.
            init_latents = [
                self.vae.encode(
                    image[i // repeat_by:i // repeat_by +
                          1]).latent_dist.sample(generator[i])
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)
            if repeat_by > 1:
                init_latents = init_latents.repeat_interleave(
                    repeat_by, dim=0)

        init_latents = self.vae.config.scaling_factor * init_latents
        vae_encode_latents = init_latents

        shape = init_latents.shape
        if noise is None:
            if isinstance(generator, list):
                # `torch.randn` accepts a single generator only.
                noise = torch.cat([
                    torch.randn((1, *shape[1:]),
                                generator=gen,
                                device=device,
                                dtype=dtype) for gen in generator
                ],
                                  dim=0)
            else:
                noise = torch.randn(
                    shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        return init_latents, vae_encode_latents

    def prepare_latent_image(self, image, dtype):
        """Convert an input image to a batched tensor of the given dtype.

        Tensors are only batched (3-D -> 4-D) and cast; their values are
        not rescaled. PIL images and arrays are stacked, moved to
        [N, C, H, W] and mapped to [-1, 1] via ``x / 127.5 - 1``.

        Args:
            image: A tensor, a PIL image, a numpy array, or a list of PIL
                images / numpy arrays.
            dtype (torch.dtype): The dtype of the returned tensor.

        Returns:
            torch.Tensor: The prepared image batch.
        """
        if isinstance(image, torch.Tensor):
            # Batch single image
            if image.ndim == 3:
                image = image.unsqueeze(0)

            image = image.to(dtype=dtype)
        else:
            # preprocess image
            if isinstance(image, (Image.Image, np.ndarray)):
                image = [image]

            if isinstance(image, list) and isinstance(image[0], Image.Image):
                image = [np.array(i.convert('RGB'))[None, :] for i in image]
                image = np.concatenate(image, axis=0)
            elif isinstance(image, list) and isinstance(image[0], np.ndarray):
                image = np.concatenate([i[None, :] for i in image], axis=0)

            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image).to(dtype=dtype) / 127.5 - 1.0

        return image

    @torch.no_grad()
    def infer(
        self,
        prompt: Union[str, List[str]],
        latent_image: Union[torch.FloatTensor, Image.Image,
                            List[torch.FloatTensor],
                            List[Image.Image]] = None,
        latent_mask: torch.FloatTensor = None,
        strength: float = 1.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        control: Optional[Union[str, np.ndarray, torch.Tensor]] = None,
        controlnet_conditioning_scale: float = 1.0,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        return_type='image',
        show_progress=True,
        reference_img: Union[torch.FloatTensor, Image.Image,
                             List[torch.FloatTensor],
                             List[Image.Image]] = None,
    ):
        """Function invoked when calling the pipeline for generation.

        Args:
            prompt (str or List[str]): The prompt or prompts to guide
                the image generation.
            latent_image: The image to start from. It is converted by
                :meth:`prepare_latent_image`, encoded by the VAE and noised.
                Required.
            latent_mask (torch.FloatTensor, optional): If given, after each
                denoising step the latents are blended as
                ``latents * mask + encoded_image * (1 - mask)``.
                Defaults to None.
            strength (float): Fraction of the schedule to run, see
                :meth:`get_timesteps`. Defaults to 1.0.
            height (int, Optional): The height in pixels of the generated
                image. If not passed, it is taken from `control`.
                Defaults to None.
            width (int, Optional): The width in pixels of the generated image.
                If not passed, it is taken from `control`.
                Defaults to None.
            control: The control condition, forwarded to
                :meth:`prepare_control`. Defaults to None.
            controlnet_conditioning_scale (float): Multiplier applied to all
                ControlNet residuals before they are added to the Unet.
                Defaults to 1.0.
            num_inference_steps (int): The number of denoising steps.
                More denoising steps usually lead to a higher quality image at
                the expense of slower inference. Defaults to 20.
            guidance_scale (float): Guidance scale as defined in Classifier-
                Free Diffusion Guidance (https://arxiv.org/abs/2207.12598).
                Defaults to 7.5
            negative_prompt (str or List[str], optional): The prompt or
                prompts not to guide the image generation. Ignored when not
                using guidance (i.e., ignored if `guidance_scale` is not
                greater than 1). Defaults to None.
            num_images_per_prompt (int): The number of images to generate
                per prompt. Defaults to 1.
            eta (float): Corresponds to parameter eta (η) in the DDIM paper:
                https://arxiv.org/abs/2010.02502. Only applies to
                DDIMScheduler, will be ignored for others. Defaults to 0.0.
            generator (torch.Generator, optional): A torch generator to make
                generation deterministic. Defaults to None.
            latents (torch.FloatTensor, optional): Pre-sampled noise that is
                added to the encoded `latent_image`. If not provided, noise
                is sampled using the supplied random `generator`.
                Defaults to None.
            return_type (str): The return type of the inference results.
                Supported types are 'image', 'numpy', 'tensor'. If 'image'
                is passed, a list of PIL images will be returned. If 'numpy'
                is passed, a numpy array with shape [N, C, H, W] will be
                returned, and the value range will be same as decoder's
                output range. If 'tensor' is passed, the decoder's output
                will be returned. Defaults to 'image'.
            show_progress (bool): Whether to wrap the denoising loop in a
                tqdm progress bar. Defaults to True.
            reference_img: Optional reference image. It is encoded like
                `latent_image`, re-noised at every step and passed to the
                Unet as ``ref_x``. Defaults to None.

        Returns:
            dict: A dict containing the generated images and Control image.

        Raises:
            ValueError: If `latent_image` is not given.
        """
        assert return_type in ['image', 'tensor', 'numpy']
        if latent_image is None:
            raise ValueError(
                '`latent_image` must be provided for image-to-image '
                'inference.')

        # 0. Default height and width to the control image
        height, width = self._default_height_width(height, width, control)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.device
        # here `guidance_scale` is defined analog to the
        # guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
        # `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        img_dtype = self.vae.module.dtype if hasattr(
            self.vae, 'module') else self.vae.dtype
        if is_model_wrapper(self.controlnet):
            control_dtype = self.controlnet.module.dtype
        else:
            control_dtype = self.controlnet.dtype

        controls = self.prepare_control(
            control,
            width,
            height,
            batch_size,
            num_images_per_prompt,
            device,
            dtype=control_dtype)
        if do_classifier_free_guidance:
            # One copy for the unconditional and one for the text branch.
            controls = torch.cat([controls] * 2)

        # Use the unwrapped dtype here as well; a wrapped controlnet does
        # not expose `.dtype` itself.
        latent_image = self.prepare_latent_image(latent_image, control_dtype)
        if reference_img is not None:
            reference_img = self.prepare_latent_image(reference_img,
                                                      control_dtype)

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(prompt, device,
                                              num_images_per_prompt,
                                              do_classifier_free_guidance,
                                              negative_prompt)
        text_embeddings = text_embeddings.to(control_dtype)

        # 4. Prepare timesteps
        self.test_scheduler.set_timesteps(num_inference_steps)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size *
                                               num_images_per_prompt)

        # 5. Prepare latent variables
        latents, vae_encode_latents = self.prepare_latents(
            latent_image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            text_embeddings.dtype,
            device,
            generator,
            noise=latents)

        if reference_img is not None:
            # Only the clean VAE latents of the reference are needed; it is
            # re-noised per step inside the loop.
            _, ref_img_vae_latents = self.prepare_latents(
                reference_img,
                latent_timestep,
                batch_size,
                num_images_per_prompt,
                text_embeddings.dtype,
                device,
                generator,
                noise=latents)

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_test_scheduler_extra_step_kwargs(
            generator, eta)

        # 7. Denoising loop
        if show_progress:
            timesteps = tqdm(timesteps)
        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.test_scheduler.scale_model_input(
                latent_model_input, t)
            latent_model_input = latent_model_input.to(control_dtype)

            if reference_img is not None:
                ref_img_vae_latents_t = self.scheduler.add_noise(
                    ref_img_vae_latents, torch.randn_like(ref_img_vae_latents),
                    t)
                ref_img_vae_latents_model_input = torch.cat(
                    [ref_img_vae_latents_t] * 2) if \
                    do_classifier_free_guidance else ref_img_vae_latents_t
                ref_img_vae_latents_model_input = \
                    self.test_scheduler.scale_model_input(
                        ref_img_vae_latents_model_input, t)
                ref_img_vae_latents_model_input = \
                    ref_img_vae_latents_model_input.to(control_dtype)

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=controls,
                return_dict=False,
            )
            down_block_res_samples = [
                down_block_res_sample * controlnet_conditioning_scale
                for down_block_res_sample in down_block_res_samples
            ]
            mid_block_res_sample *= controlnet_conditioning_scale

            # predict the noise residual
            if reference_img is not None:
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    ref_x=ref_img_vae_latents_model_input)['sample']
            else:
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                )['sample']

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.test_scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

            if latent_mask is not None:
                # Keep the encoded input wherever the mask is 0.
                latents = latents * latent_mask + \
                    vae_encode_latents * (1.0 - latent_mask)

        # 8. Post-processing
        image = self.decode_latents(latents.to(img_dtype))

        if do_classifier_free_guidance:
            # Drop the duplicated half that was added for guidance.
            controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]

        if return_type == 'image':
            image = self.output_to_pil(image)
            # `output_to_pil` receives the control mapped from [0, 1] to
            # [-1, 1].
            controls = self.output_to_pil(controls * 2 - 1)
        elif return_type == 'numpy':
            image = image.cpu().numpy()
            controls = controls.cpu().numpy()
        else:
            assert return_type == 'tensor', (
                'Only support \'image\', \'numpy\' and \'tensor\' for '
                f'return_type, but receive {return_type}')

        return {'samples': image, 'controls': controls}
