
Source code for mmagic.models.editors.stable_diffusion_xl.stable_diffusion_xl

# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapperDict
from mmengine.runner import set_random_seed
from PIL import Image
from tqdm.auto import tqdm

from mmagic.models.archs import TokenizerWrapper, set_lora
from mmagic.models.utils import build_module, set_tomesd, set_xformers
from mmagic.registry import DIFFUSION_SCHEDULERS, MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import SampleList

logger = MMLogger.get_current_instance()
ModelType = Union[Dict, nn.Module]
@MODELS.register_module('sdxl')
@MODELS.register_module()
class StableDiffusionXL(BaseModel):
    """Class for Stable Diffusion XL. Refers to
    https://github.com/Stability-AI/generative-models and
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py  # noqa

    Args:
        unet (Union[dict, nn.Module]): The config or module for the UNet
            model.
        text_encoder_one (Union[dict, nn.Module]): The config or module for
            the first text encoder.
        tokenizer_one (str): The **name** for the first CLIP tokenizer.
        text_encoder_two (Union[dict, nn.Module]): The config or module for
            the second text encoder.
        tokenizer_two (str): The **name** for the second CLIP tokenizer.
        vae (Union[dict, nn.Module]): The config or module for the VAE model.
        scheduler (Union[dict, nn.Module]): The config or module for the
            diffusion scheduler.
        test_scheduler (Union[dict, nn.Module], optional): The config or
            module for the diffusion scheduler used in the test stage
            (`self.infer`). If not passed, the same scheduler as `scheduler`
            will be used. Defaults to None.
        dtype (str, optional): The dtype for the model. This argument will
            not work when dtype is defined for submodels. Defaults to None.
        enable_xformers (bool, optional): Whether to use xformers.
            Defaults to True.
        noise_offset_weight (float, optional): The weight of the noise offset
            introduced in
            https://www.crosslabs.org/blog/diffusion-with-offset-noise
            Defaults to 0.
        tomesd_cfg (dict, optional): The config for TOMESD. Please refer to
            https://github.com/dbolya/tomesd and
            https://github.com/open-mmlab/mmagic/blob/main/mmagic/models/utils/tome_utils.py for details.  # noqa
            Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        lora_config (dict, optional): The config for LoRA finetuning.
            Defaults to None.
        val_prompts (Union[str, List[str]], optional): The prompts for
            validation. Defaults to None.
        finetune_text_encoder (bool, optional): Whether to fine-tune the text
            encoders. Defaults to False.
        force_zeros_for_empty_prompt (bool): Whether the negative prompt
            embeddings shall always be forced to 0. Defaults to True.
        init_cfg (dict, optional): The weight initialization config for
            :class:`BaseModule`.
    """
""" def __init__(self, vae: ModelType, text_encoder_one: ModelType, tokenizer_one: str, text_encoder_two: ModelType, tokenizer_two: str, unet: ModelType, scheduler: ModelType, test_scheduler: Optional[ModelType] = None, dtype: Optional[str] = None, enable_xformers: bool = True, noise_offset_weight: float = 0, tomesd_cfg: Optional[dict] = None, data_preprocessor: Optional[ModelType] = dict( type='DataPreprocessor'), lora_config: Optional[dict] = None, val_prompts: Union[str, List[str]] = None, finetune_text_encoder: bool = False, force_zeros_for_empty_prompt: bool = True, init_cfg: Optional[dict] = None): # TODO: support `from_pretrained` for this class super().__init__(data_preprocessor, init_cfg) default_args = dict() if dtype is not None: default_args['dtype'] = dtype self.dtype = torch.float32 if dtype in ['float16', 'fp16', 'half']: self.dtype = torch.float16 elif dtype == 'bf16': self.dtype = torch.bfloat16 else: assert dtype in [ 'fp32', None ], ('dtype must be one of \'fp32\', \'fp16\', \'bf16\' or None.') self.vae = build_module(vae, MODELS, default_args=default_args) self.unet = build_module(unet, MODELS) # NOTE: initialize unet as fp32 self._unet_ori_dtype = next(self.unet.parameters()).dtype print_log(f'Set UNet dtype to \'{self._unet_ori_dtype}\'.', 'current') self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS) if test_scheduler is None: self.test_scheduler = deepcopy(self.scheduler) else: self.test_scheduler = build_module(test_scheduler, DIFFUSION_SCHEDULERS) self.text_encoder_one = build_module(text_encoder_one, MODELS) if not isinstance(tokenizer_one, str): self.tokenizer_one = tokenizer_one else: # NOTE: here we assume tokenizer is an string self.tokenizer_one = TokenizerWrapper( tokenizer_one, subfolder='tokenizer') self.text_encoder_two = build_module(text_encoder_two, MODELS) if not isinstance(tokenizer_two, str): self.tokenizer_two = tokenizer_two else: # NOTE: here we assume tokenizer is an string self.tokenizer_two = TokenizerWrapper( tokenizer_two, subfolder='tokenizer_2') self.unet_sample_size = self.unet.sample_size self.vae_scale_factor = 2**(len(self.vae.block_out_channels) - 1) self.enable_noise_offset = noise_offset_weight > 0 self.noise_offset_weight = noise_offset_weight self.finetune_text_encoder = finetune_text_encoder self.val_prompts = val_prompts self.lora_config = deepcopy(lora_config) self.force_zeros_for_empty_prompt = force_zeros_for_empty_prompt self.prepare_model() self.set_lora() self.enable_xformers = enable_xformers self.set_xformers() self.tomesd_cfg = tomesd_cfg self.set_tomesd()
    def prepare_model(self):
        """Prepare model for training.

        Move model to target dtype and disable gradient for some models.
        """
        self.vae.requires_grad_(False)
        print_log('Set VAE untrainable.', 'current')
        self.vae.to(self.dtype)
        print_log(f'Move VAE to {self.dtype}.', 'current')
        if not self.finetune_text_encoder or self.lora_config:
            self.text_encoder_one.requires_grad_(False)
            self.text_encoder_two.requires_grad_(False)
            print_log('Set Text Encoder untrainable.', 'current')
            self.text_encoder_one.to(self.dtype)
            self.text_encoder_two.to(self.dtype)
            print_log(f'Move Text Encoder to {self.dtype}.', 'current')
        if self.lora_config:
            self.unet.requires_grad_(False)
            print_log('Set Unet untrainable.', 'current')

    def set_lora(self):
        """Set LoRA for the model."""
        if self.lora_config:
            set_lora(self.unet, self.lora_config)

    def set_xformers(self, module: Optional[nn.Module] = None) -> nn.Module:
        """Set xformers for the model.

        Returns:
            nn.Module: The model with xformers.
        """
        if self.enable_xformers:
            if module is None:
                set_xformers(self)
            else:
                set_xformers(module)

    def set_tomesd(self) -> nn.Module:
        """Set ToMe for the stable diffusion model.

        Returns:
            nn.Module: The model with ToMe.
        """
        if self.tomesd_cfg is not None:
            set_tomesd(self, **self.tomesd_cfg)

    @property
    def device(self):
        return next(self.parameters()).device

    def train(self, mode: bool = True):
        """Set train/eval mode.

        Args:
            mode (bool, optional): Whether to set train mode.
                Defaults to True.
        """
        if mode:
            if next(self.unet.parameters()).dtype != self._unet_ori_dtype:
                print_log(
                    f'Set UNet dtype to \'{self._unet_ori_dtype}\' '
                    'in the train mode.', 'current')
            self.unet.to(self._unet_ori_dtype)
        else:
            self.unet.to(self.dtype)
            print_log(
                f'Set UNet dtype to \'{self.dtype}\' in the eval mode.',
                'current')
        return super().train(mode)

    @torch.no_grad()
    def infer(self,
              prompt: Union[str, List[str]],
              prompt_2: Optional[Union[str, List[str]]] = None,
              height: Optional[int] = None,
              width: Optional[int] = None,
              num_inference_steps: int = 50,
              denoising_end: Optional[float] = None,
              guidance_scale: float = 7.5,
              negative_prompt: Optional[Union[str, List[str]]] = None,
              negative_prompt_2: Optional[Union[str, List[str]]] = None,
              num_images_per_prompt: Optional[int] = 1,
              eta: float = 0.0,
              generator: Optional[torch.Generator] = None,
              latents: Optional[torch.FloatTensor] = None,
              show_progress: bool = True,
              seed: int = 1,
              original_size: Optional[Tuple[int, int]] = None,
              crops_coords_top_left: Tuple[int, int] = (0, 0),
              target_size: Optional[Tuple[int, int]] = None,
              negative_original_size: Optional[Tuple[int, int]] = None,
              negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
              negative_target_size: Optional[Tuple[int, int]] = None,
              return_type='image'):
        """Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_two` and
                `text_encoder_two`. If not defined, `prompt` is used in both
                text-encoders. Defaults to None.
            height (`int`, *optional*, defaults to
                self.unet_sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to
                self.unet_sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually
                lead to a higher quality image at the expense of slower
                inference.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0)
                of the total denoising process to be completed before it is
                intentionally prematurely terminated. As a result, the
                returned sample will still retain a substantial amount of
                noise as determined by the discrete timesteps selected by the
                scheduler. The `denoising_end` parameter should ideally be
                utilized when this pipeline forms part of a "Mixture of
                Denoisers" multi-pipeline setup, as elaborated in
                [**Refining the Image Output**](
                https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://arxiv.org/abs/2207.12598).
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation.
                Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The `negative_prompt` to be sent to `tokenizer_two` and
                `text_encoder_two`. If not defined, `negative_prompt` is used
                in both text-encoders. Defaults to None.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper:
                https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator] to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian
                distribution, to be used as inputs for image generation. Can
                be used to tweak the same generation with different prompts.
                If not provided, a latents tensor will be generated by
                sampling using the supplied random `generator`.
            show_progress (bool): Whether to show progress.
                Defaults to True.
            seed (int): Seed to be used. Defaults to 1.
            original_size (`Tuple[int]`, *optional*):
                If `original_size` is not the same as `target_size` the image
                will appear to be down- or upsampled. Defaults to
                `(height, width)` if not specified. Defaults to None.
            crops_coords_top_left (`Tuple[int]`, *optional*):
                `crops_coords_top_left` can be used to generate an image that
                appears to be "cropped" from the position. Favorable,
                well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Defaults to (0, 0).
            target_size (`Tuple[int]`, *optional*):
                For most cases, `target_size` should be set to the desired
                height and width of the generated image. If not specified it
                will be `(height, width)`. Defaults to None.
            negative_original_size (`Tuple[int]`, *optional*):
                To negatively condition the generation process based on a
                specific image resolution. For more information, refer to
                this issue thread:
                https://github.com/huggingface/diffusers/issues/4208.
                Defaults to None.
            negative_crops_coords_top_left (`Tuple[int]`, *optional*):
                To negatively condition the generation process based on
                specific crop coordinates. For more information, refer to
                this issue thread:
                https://github.com/huggingface/diffusers/issues/4208.
                Defaults to (0, 0).
            negative_target_size (`Tuple[int]`, *optional*):
                To negatively condition the generation process based on a
                target image resolution. It should be the same as
                `target_size` for most cases. For more information, refer to
                this issue thread:
                https://github.com/huggingface/diffusers/issues/4208.
                Defaults to None.
            return_type (str): The return type of the inference results.
                Supported types are 'image', 'numpy', 'tensor'. If 'image' is
                passed, a list of PIL images will be returned. If 'numpy' is
                passed, a numpy array with shape [N, C, H, W] will be
                returned, and the value range will be the same as the
                decoder's output range. If 'tensor' is passed, the decoder's
                output will be returned. Defaults to 'image'.

        Returns:
            dict: A dict containing the generated images.
        """
        assert return_type in ['image', 'tensor', 'numpy']
        set_random_seed(seed=seed)

        # 0. Default height and width to unet
        height = height or self.unet_sample_size * self.vae_scale_factor
        width = width or self.unet_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.device
        img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \
            else self.vae.dtype
        latent_dtype = next(self.unet.parameters()).dtype
        # here `guidance_scale` is defined analog to the guidance weight `w`
        # of equation (2) of the Imagen paper:
        # https://arxiv.org/pdf/2205.11487.pdf .
        # `guidance_scale = 1` corresponds to doing no classifier free
        # guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self._encode_prompt(prompt, prompt_2, device,
                                num_images_per_prompt,
                                do_classifier_free_guidance, negative_prompt,
                                negative_prompt_2)

        # 4. Prepare timesteps
        self.test_scheduler.set_timesteps(num_inference_steps)
        timesteps = self.test_scheduler.timesteps

        # 5. Prepare latent variables
        if hasattr(self.unet, 'module'):
            num_channels_latents = self.unet.module.in_channels
        else:
            num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        # TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype)
        if (negative_original_size is not None) and (negative_target_size
                                                     is not None):
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
            )
        else:
            negative_add_time_ids = add_time_ids

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds],
                                      dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids],
                                     dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1)

        # 8. Apply denoising_end
        if denoising_end is not None and isinstance(
                denoising_end,
                float) and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(self.scheduler.config.num_train_timesteps -
                      (denoising_end *
                       self.scheduler.config.num_train_timesteps)))
            num_inference_steps = len(
                list(
                    filter(lambda ts: ts >= discrete_timestep_cutoff,
                           timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 9. Denoising loop
        if show_progress:
            timesteps = tqdm(timesteps)
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.test_scheduler.scale_model_input(
                latent_model_input, t)

            latent_model_input = latent_model_input.to(latent_dtype)
            prompt_embeds = prompt_embeds.to(latent_dtype)

            # predict the noise residual
            added_cond_kwargs = {
                'text_embeds': add_text_embeds,
                'time_ids': add_time_ids
            }
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.test_scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

        # 10. Post-processing
        image = self.decode_latents(latents.to(img_dtype))
        if return_type == 'image':
            image = self.output_to_pil(image)
        elif return_type == 'numpy':
            image = image.cpu().numpy()
        else:
            assert return_type == 'tensor', (
                'Only support \'image\', \'numpy\' and \'tensor\' for '
                f'return_type, but receive {return_type}')

        return {'samples': image}

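    # ------------------------------------------------------------------
    # Inference sketch (illustration only): assuming `sd_xl` is a built and
    # initialised StableDiffusionXL instance on the right device, `infer`
    # returns a dict whose 'samples' entry is a list of PIL images when
    # return_type='image'. The prompt strings and file name are made up.
    #
    #     result = sd_xl.infer(
    #         'an astronaut riding a horse, photorealistic',
    #         negative_prompt='low quality, blurry',
    #         num_inference_steps=30,
    #         guidance_scale=7.5,  # > 1.0 enables classifier-free guidance
    #         seed=42,
    #         return_type='image')
    #     result['samples'][0].save('sample.png')
    # ------------------------------------------------------------------
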
    def _get_add_time_ids(self, original_size: Optional[Tuple[int, int]],
                          crops_coords_top_left: Tuple[int, int],
                          target_size: Optional[Tuple[int, int]], dtype):
        """Get `add_time_ids`.

        Args:
            original_size (`Tuple[int]`, *optional*):
                If `original_size` is not the same as `target_size` the image
                will appear to be down- or upsampled. Defaults to
                `(height, width)` if not specified. Defaults to None.
            crops_coords_top_left (`Tuple[int]`, *optional*):
                `crops_coords_top_left` can be used to generate an image that
                appears to be "cropped" from the position. Favorable,
                well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Defaults to (0, 0).
            target_size (`Tuple[int]`, *optional*):
                For most cases, `target_size` should be set to the desired
                height and width of the generated image. If not specified it
                will be `(height, width)`. Defaults to None.
            dtype (str, optional): The dtype for the embeddings.

        Returns:
            add_time_ids (torch.Tensor): Time ids for the time embedding
                layer.
        """
        add_time_ids = list(original_size + crops_coords_top_left +
                            target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) +
            self.text_encoder_two.config.projection_dim)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                'Model expects an added time embedding vector of length '
                f'{expected_add_embed_dim}, but a vector of '
                f'{passed_add_embed_dim} was created. The model has an '
                'incorrect config. Please check '
                '`unet.config.time_embedding_type` and '
                '`text_encoder_2.config.projection_dim`.')

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

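    # ------------------------------------------------------------------
    # Worked example (illustration only): for the SDXL base UNet the relevant
    # numbers are believed to be addition_time_embed_dim=256 and a
    # text_encoder_two projection_dim of 1280; with the six micro-conditioning
    # values (original_size + crops_coords_top_left + target_size) this gives
    #
    #     passed_add_embed_dim = 256 * 6 + 1280 = 2816
    #
    # which must match `unet.add_embedding.linear_1.in_features`, otherwise
    # the ValueError above is raised.
    # ------------------------------------------------------------------
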
    def output_to_pil(self, image) -> List[Image.Image]:
        """Convert output tensor to PIL image. The output tensor will be
        de-normed to [0, 255] by `DataPreprocessor.destruct`. Since no
        `data_samples` is passed, color order conversion will not be
        performed.

        Args:
            image (torch.Tensor): The output tensor of the decoder.

        Returns:
            List[Image.Image]: The list of processed PIL images.
        """
        image = self.data_preprocessor.destruct(image)
        image = image.permute(0, 2, 3, 1).to(torch.uint8).cpu().numpy()
        image = [Image.fromarray(img) for img in image]
        return image

    def _encode_prompt(self, prompt, prompt_2, device, num_images_per_prompt,
                       do_classifier_free_guidance, negative_prompt,
                       negative_prompt_2):
        """Encodes the prompt into text encoder hidden states.

        Args:
            prompt (str or list(int)): The prompt to be encoded.
            prompt_2 (str or list(int)): The prompt to be encoded and sent to
                `tokenizer_two` and `text_encoder_two`. If not defined,
                `prompt` is used in both text-encoders.
            device: (torch.device): torch device.
            num_images_per_prompt (int): Number of images that should be
                generated per prompt.
            do_classifier_free_guidance (`bool`): Whether to use classifier
                free guidance or not.
            negative_prompt (str or List[str]): The prompt or prompts not to
                guide the image generation. Ignored when not using guidance
                (i.e., ignored if `guidance_scale` is less than `1`).
            negative_prompt_2 (str or List[str]): The prompt or prompts not
                to guide the image generation, sent to `tokenizer_two` and
                `text_encoder_two`. Ignored when not using guidance. If not
                defined, `negative_prompt` is used in both text-encoders.

        Returns:
            text_embeddings (torch.Tensor): Text embeddings generated by the
                CLIP text encoders.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        prompt_2 = prompt_2 or prompt
        tokenizers = [self.tokenizer_one, self.tokenizer_two]
        text_encoders = [self.text_encoder_one, self.text_encoder_two]
        prompts = [prompt, prompt_2]
        prompt_embeds_list = []
        for tokenizer, text_encoder, prompt in zip(tokenizers, text_encoders,
                                                   prompts):
            text_inputs = tokenizer(
                prompt,
                padding='max_length',
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt',
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(
                prompt, padding='max_length', return_tensors='pt').input_ids

            if not torch.equal(text_input_ids, untruncated_ids):
                removed_text = tokenizer.batch_decode(
                    untruncated_ids[:, tokenizer.model_max_length - 1:-1])
                logger.warning(
                    'The following part of your input was truncated because'
                    ' CLIP can only handle sequences up to'
                    f' {tokenizer.model_max_length} tokens: {removed_text}')

            text_encoder = text_encoder.module if hasattr(
                text_encoder, 'module') else text_encoder
            text_embeddings = text_encoder(
                text_input_ids.to(device),
                output_hidden_states=True,
            )
            pooled_prompt_embeds = text_embeddings.pooler_output
            text_embeddings = text_embeddings.hidden_states[-2]
            prompt_embeds_list.append(text_embeddings)

        text_embeddings = torch.concat(prompt_embeds_list, dim=-1)

        # duplicate text embeddings for each generation per prompt
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(
            bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and self.force_zeros_for_empty_prompt:
            negative_prompt_embeds = torch.zeros_like(text_embeddings)
            negative_pooled_prompt_embeds = torch.zeros_like(
                pooled_prompt_embeds)
        elif do_classifier_free_guidance:
            negative_prompt = negative_prompt or ''
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(
                    negative_prompt):
                raise TypeError(
                    '`negative_prompt` should be the same type to `prompt`, '
                    f'but got {type(negative_prompt)} != {type(prompt)}.')
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt, negative_prompt_2]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f'`negative_prompt`: {negative_prompt} has batch size '
                    f'{len(negative_prompt)}, but `prompt`: {prompt} has '
                    f'batch size {batch_size}. Please make sure that passed '
                    '`negative_prompt` matches the batch size of `prompt`.')
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(
                    uncond_tokens, tokenizers, text_encoders):
                max_length = text_embeddings.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding='max_length',
                    max_length=max_length,
                    truncation=True,
                    return_tensors='pt',
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the
                # final text encoder
                negative_pooled_prompt_embeds = (
                    negative_prompt_embeds.pooler_output)
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[
                    -2]
                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(
                negative_prompt_embeds_list, dim=-1)

        bs_embed, seq_len, _ = text_embeddings.shape
        # duplicate text embeddings for each generation per prompt, using mps
        # friendly method
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(
            bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per
            # prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(
            1, num_images_per_prompt).view(bs_embed * num_images_per_prompt,
                                           -1)
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = (
                negative_pooled_prompt_embeds.repeat(
                    1, num_images_per_prompt).view(
                        bs_embed * num_images_per_prompt, -1))

        return (text_embeddings, negative_prompt_embeds, pooled_prompt_embeds,
                negative_pooled_prompt_embeds)

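    # ------------------------------------------------------------------
    # Shape note (illustration only): `_encode_prompt` concatenates the
    # penultimate hidden states of the two text encoders along the last
    # dimension. For the standard SDXL text encoders this is assumed to give
    # per-token embeddings of width 768 + 1280 = 2048, while the pooled
    # embeddings (taken from the second encoder) have width 1280, e.g.
    #
    #     prompt_embeds:        [batch * num_images_per_prompt, 77, 2048]
    #     pooled_prompt_embeds: [batch * num_images_per_prompt, 1280]
    # ------------------------------------------------------------------
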
    def decode_latents(self, latents):
        """Use the VAE to decode latents.

        Args:
            latents (torch.Tensor): The latents to decode.

        Returns:
            image (torch.Tensor): The decoded image.
        """
        latents = 1 / 0.18215 * latents
        if hasattr(self.vae, 'module'):
            image = self.vae.module.decode(latents)['sample']
        else:
            image = self.vae.decode(latents)['sample']
        # we always cast to float32 as this does not cause significant
        # overhead and is compatible with bfloat16
        return image.float()

    def prepare_extra_step_kwargs(self, generator, eta):
        """Prepare extra kwargs for the scheduler step.

        Args:
            generator (torch.Generator): The generator for random functions.
            eta (float): eta (η) is only used with the DDIMScheduler and will
                be ignored for other schedulers. It corresponds to η in the
                DDIM paper (https://arxiv.org/abs/2010.02502) and should be
                in [0, 1].

        Return:
            extra_step_kwargs (dict): A dict containing 'generator' and
                'eta'.
        """
        accepts_eta = 'eta' in set(
            inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs['eta'] = eta

        # check if the scheduler accepts generator
        accepts_generator = 'generator' in set(
            inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs['generator'] = generator
        return extra_step_kwargs

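    # ------------------------------------------------------------------
    # Sketch of the introspection trick used above (illustration only): any
    # callable can be probed for an optional keyword argument before calling
    # it, so schedulers with different `step` signatures stay
    # interchangeable. The helper name is hypothetical.
    #
    #     import inspect
    #
    #     def supports_kwarg(fn, name):
    #         # True if `fn` accepts a parameter called `name`
    #         return name in inspect.signature(fn).parameters
    # ------------------------------------------------------------------
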
    def prepare_test_scheduler_extra_step_kwargs(self, generator, eta):
        """Prepare extra kwargs for the test scheduler step.

        Args:
            generator (torch.Generator): The generator for random functions.
            eta (float): eta (η) is only used with the DDIMScheduler and will
                be ignored for other schedulers. It corresponds to η in the
                DDIM paper (https://arxiv.org/abs/2010.02502) and should be
                in [0, 1].

        Return:
            extra_step_kwargs (dict): A dict containing 'generator' and
                'eta'.
        """
        accepts_eta = 'eta' in set(
            inspect.signature(self.test_scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs['eta'] = eta

        # check if the scheduler accepts generator
        accepts_generator = 'generator' in set(
            inspect.signature(self.test_scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs['generator'] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width):
        """Check whether inputs are in suitable format or not."""
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f'`prompt` has to be of '
                             f'type `str` or `list` but is {type(prompt)}')

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f'`height` and `width` have to be divisible '
                             f'by 8 but are {height} and {width}.')

    def prepare_latents(self,
                        batch_size,
                        num_channels_latents,
                        height,
                        width,
                        dtype,
                        device,
                        generator,
                        latents=None):
        """Prepare latents for diffusion to run in latent space.

        Args:
            batch_size (int): The batch size.
            num_channels_latents (int): The number of latent channels.
            height (int): Image height.
            width (int): Image width.
            dtype (torch.dtype): Float type.
            device (torch.device): torch device.
            generator (torch.Generator): The generator for random functions,
                defaults to None.
            latents (torch.Tensor): Pre-generated noisy latents,
                defaults to None.

        Return:
            latents (torch.Tensor): The prepared latents.
        """
        shape = (batch_size, num_channels_latents,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor)
        if latents is None:
            latents = torch.randn(
                shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f'Unexpected latents shape, '
                                 f'got {latents.shape}, expected {shape}')
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the
        # scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

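    # ------------------------------------------------------------------
    # Worked example (illustration only): with the usual SDXL settings of
    # num_channels_latents=4 and vae_scale_factor=8, a 1024x1024 request with
    # batch_size=1 yields
    #
    #     shape = (1, 4, 1024 // 8, 1024 // 8) = (1, 4, 128, 128)
    #
    # i.e. the denoising loop runs on a latent that is 8x smaller per spatial
    # side than the decoded image.
    # ------------------------------------------------------------------
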
    @torch.no_grad()
    def val_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data.

        Args:
            data (dict): Data sampled from metric specific sampler. More
                details in `Metrics` and `Evaluator`.

        Returns:
            SampleList: Generated image or image dict.
        """
        if self.val_prompts is None:
            data = self.data_preprocessor(data)
            data_samples = data['data_samples']
            prompt = data_samples.prompt
        else:
            prompt = self.val_prompts
            # construct a fake data_sample for destruct
            data_samples = DataSample.stack(data['data_samples'] *
                                            len(prompt))

        output = self.infer(prompt, return_type='tensor')
        samples = output['samples']
        samples = self.data_preprocessor.destruct(samples, data_samples)

        if self.val_prompts is None:
            gt_img = self.data_preprocessor.destruct(data['inputs'],
                                                     data_samples)
            out_data_sample = DataSample(
                fake_img=samples, gt_img=gt_img, prompt=prompt)
        else:
            out_data_sample = DataSample(fake_img=samples, prompt=prompt)
        data_sample_list = out_data_sample.split()
        return data_sample_list

    @torch.no_grad()
    def test_step(self, data: dict) -> SampleList:
        """Gets the generated image of given data. Same as :meth:`val_step`.

        Args:
            data (dict): Data sampled from metric specific sampler. More
                details in `Metrics` and `Evaluator`.

        Returns:
            SampleList: A list of ``DataSample`` containing generated
                results.
        """
        if self.val_prompts is None:
            data = self.data_preprocessor(data)
            data_samples = data['data_samples']
            prompt = data_samples.prompt
        else:
            prompt = self.val_prompts
            # construct a fake data_sample for destruct
            data_samples = DataSample.stack(data['data_samples'] *
                                            len(prompt))

        output = self.infer(prompt, return_type='tensor')
        samples = output['samples']
        samples = self.data_preprocessor.destruct(samples, data_samples)

        if self.val_prompts is None:
            gt_img = self.data_preprocessor.destruct(data['inputs'],
                                                     data_samples)
            out_data_sample = DataSample(
                fake_img=samples, gt_img=gt_img, prompt=prompt)
        else:
            out_data_sample = DataSample(fake_img=samples, prompt=prompt)
        data_sample_list = out_data_sample.split()
        return data_sample_list

    def encode_prompt_train(self, text_one, text_two):
        """Encode prompt for training.

        Args:
            text_one (torch.Tensor): Input ids from tokenizer_one.
            text_two (torch.Tensor): Input ids from tokenizer_two.

        Returns:
            prompt_embeds (torch.Tensor): Prompt embeddings.
            pooled_prompt_embeds (torch.Tensor): Pooled prompt embeddings.
        """
        prompt_embeds_list = []

        text_encoders = [self.text_encoder_one, self.text_encoder_two]
        texts = [text_one, text_two]
        for text_encoder, text in zip(text_encoders, texts):

            prompt_embeds = text_encoder(
                text,
                output_hidden_states=True,
            )

            # We are only ALWAYS interested in the pooled output of the
            # final text encoder
            pooled_prompt_embeds = prompt_embeds.pooler_output
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds

    def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict):
        """Train step function.

        Args:
            data (List[dict]): Batch of data as input.
            optim_wrapper (OptimWrapperDict): Dict with optimizers for
                generator and discriminator (if have).

        Returns:
            dict: Dict with loss, information for logger, the number of
                samples and results for visualization.
        """
        data = self.data_preprocessor(data)
        inputs, data_samples = data['inputs'], data['data_samples']

        vae = self.vae.module if hasattr(self.vae, 'module') else self.vae

        with optim_wrapper.optim_context(self.unet):
            image = inputs
            prompt = data_samples.prompt
            num_batches = image.shape[0]

            image = image.to(self.dtype)
            latents = vae.encode(image).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            noise = torch.randn_like(latents)

            if self.enable_noise_offset:
                noise = noise + self.noise_offset_weight * torch.randn(
                    latents.shape[0],
                    latents.shape[1],
                    1,
                    1,
                    device=noise.device)

            timesteps = torch.randint(
                0,
                self.scheduler.num_train_timesteps, (num_batches, ),
                device=self.device)
            timesteps = timesteps.long()

            noisy_latents = self.scheduler.add_noise(latents, noise,
                                                     timesteps)

            input_ids_one = self.tokenizer_one(
                prompt,
                max_length=self.tokenizer_one.model_max_length,
                return_tensors='pt',
                padding='max_length',
                truncation=True)['input_ids'].to(self.device)
            input_ids_two = self.tokenizer_two(
                prompt,
                max_length=self.tokenizer_two.model_max_length,
                return_tensors='pt',
                padding='max_length',
                truncation=True)['input_ids'].to(self.device)
            (encoder_hidden_states,
             pooled_prompt_embeds) = self.encode_prompt_train(
                 input_ids_one, input_ids_two)
            unet_added_conditions = {
                'time_ids': data['time_ids'],
                'text_embeds': pooled_prompt_embeds
            }

            if self.scheduler.config.prediction_type == 'epsilon':
                gt = noise
            elif self.scheduler.config.prediction_type == 'v_prediction':
                gt = self.scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError('Unknown prediction type '
                                 f'{self.scheduler.config.prediction_type}')

            # NOTE: we train unet in fp32, convert to float manually
            model_output = self.unet(
                noisy_latents.float(),
                timesteps,
                encoder_hidden_states=encoder_hidden_states.float(),
                added_cond_kwargs=unet_added_conditions)
            model_pred = model_output['sample']

            loss_dict = dict()
            # calculate loss in FP32
            loss_mse = F.mse_loss(model_pred.float(), gt.float())
            loss_dict['loss_mse'] = loss_mse

            parsed_loss, log_vars = self.parse_losses(loss_dict)
            optim_wrapper.update_params(parsed_loss)

        return log_vars

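    # ------------------------------------------------------------------
    # Training-target note (illustration only): the regression target `gt`
    # above depends on the scheduler's prediction type. For 'epsilon' the
    # UNet regresses the added noise directly; for 'v_prediction' it
    # regresses the velocity, which for a noisy latent built as
    # x_t = alpha_t * x_0 + sigma_t * noise is commonly defined as
    #
    #     v_t = alpha_t * noise - sigma_t * x_0
    #
    # (see the scheduler's `get_velocity` implementation for the exact form).
    # ------------------------------------------------------------------
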
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """forward is not implemented now."""
        raise NotImplementedError(
            'Forward is not implemented now, please use infer.')