
Source code for mmagic.models.diffusion_schedulers.ddim_scheduler

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import numpy as np
import torch

from mmagic.models.utils.diffusion_utils import betas_for_alpha_bar
from mmagic.registry import DIFFUSION_SCHEDULERS


@DIFFUSION_SCHEDULERS.register_module()
class EditDDIMScheduler:
    """``EditDDIMScheduler`` supports the diffusion and reverse process
    formulated in https://arxiv.org/abs/2010.02502.

    The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. # noqa
    The difference is that we integrate gradient-guided sampling into the
    ``step`` function.

    Args:
        num_train_timesteps (int, optional): Number of diffusion steps used
            to train the model. Defaults to 1000.
        beta_start (float, optional): Starting beta value of the noise
            schedule. Defaults to 0.0001.
        beta_end (float, optional): Final beta value of the noise schedule.
            Defaults to 0.02.
        beta_schedule (str, optional): How beta is spread over the training
            timesteps, one of 'linear', 'scaled_linear' or
            'squaredcos_cap_v2'. Defaults to "linear".
        variance_type (str, optional): Type of the predicted variance; when
            set to 'learned' or 'learned_range', a learned variance channel
            is split off the model output in ``step``.
            Defaults to 'learned_range'.
        timestep_values (list, optional): Custom values for the timesteps.
            Defaults to None.
        clip_sample (bool, optional): Whether to clip the predicted original
            sample to [-1, 1]. Defaults to True.
        set_alpha_to_one (bool, optional): Whether to fix the alpha cumprod
            of the final (t=0) step to 1; otherwise the first value of
            ``alphas_cumprod`` is used. Defaults to True.
    """

    def __init__(self,
                 num_train_timesteps=1000,
                 beta_start=0.0001,
                 beta_end=0.02,
                 beta_schedule='linear',
                 variance_type='learned_range',
                 timestep_values=None,
                 clip_sample=True,
                 set_alpha_to_one=True):
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.variance_type = variance_type
        self.timestep_values = timestep_values
        self.clip_sample = clip_sample
        self.set_alpha_to_one = set_alpha_to_one

        if beta_schedule == 'linear':
            self.betas = np.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == 'scaled_linear':
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(
                beta_start**0.5,
                beta_end**0.5,
                num_train_timesteps,
                dtype=np.float32)**2
        elif beta_schedule == 'squaredcos_cap_v2':
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(
                f'{beta_schedule} is not implemented for {self.__class__}')

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        # At every step in ddim, we are looking into the
        # previous alphas_cumprod. For the final step,
        # there is no previous alphas_cumprod because we are already
        # at 0. `set_alpha_to_one` decides whether we set this parameter
        # simply to one or whether we use the final alpha of the
        # "non-previous" one.
        self.final_alpha_cumprod = np.array(
            1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
    def set_timesteps(self, num_inference_steps, offset=0):
        """set time steps."""
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(
            0, self.num_train_timesteps,
            self.num_train_timesteps // self.num_inference_steps)[::-1].copy()
        self.timesteps += offset
    def scale_model_input(self,
                          sample: torch.FloatTensor,
                          timestep: Optional[int] = None
                          ) -> torch.FloatTensor:
        """Ensures interchangeability with schedulers that need to scale the
        denoising model input depending on the current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample
    def _get_variance(self, timestep, prev_timestep):
        """get variance."""
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev /
                    beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance
    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        cond_fn=None,
        cond_kwargs={},
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
    ):
        """step forward."""
        output = {}
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', "
                "you need to run 'set_timesteps' "
                'after creating the scheduler')

        pred = None
        if isinstance(model_output, dict):
            pred = model_output['pred']
            model_output = model_output['eps']
        elif model_output.shape[1] == sample.shape[
                1] * 2 and self.variance_type in ['learned', 'learned_range']:
            model_output, _ = torch.split(model_output, sample.shape[1], dim=1)
        else:
            if not model_output.shape == sample.shape:
                raise TypeError

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # noqa
        # Ideally, read the DDIM paper in detail to understand the notation
        # used below.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = (
            timestep - self.num_train_timesteps // self.num_inference_steps)

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also
        # called "predicted x_0" of formula (12) from
        # https://arxiv.org/pdf/2010.02502.pdf # noqa
        pred_original_sample = (sample - (
            (beta_prod_t)**(0.5)) * model_output) / alpha_prod_t**(0.5)
        if pred is not None:
            pred_original_sample = pred

        gradient = 0.
        if cond_fn is not None:
            if cond_fn.__name__ == 'classifier_grad':
                y = cond_kwargs['y']
                classifier = cond_kwargs['classifier']
                classifier_scale = cond_kwargs['classifier_scale']
                gradient = cond_fn(
                    classifier,
                    sample,
                    timestep,
                    y=y,
                    classifier_scale=classifier_scale)
            else:
                gradient = cond_fn(
                    cond_kwargs.pop('unet'), self, sample, timestep,
                    beta_prod_t, cond_kwargs.pop('model_stats'),
                    **cond_kwargs)
            model_output = model_output - (beta_prod_t**0.5) * gradient
            pred_original_sample = (
                sample -
                (beta_prod_t**(0.5)) * model_output) / alpha_prod_t**(0.5)

        # 4. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance**(0.5)
        output.update(dict(sigma=std_dev_t))

        if use_clipped_model_output:
            # the model_output is always
            # re-derived from the clipped x_0 in Glide
            model_output = (sample - (alpha_prod_t**(0.5)) *
                            pred_original_sample) / beta_prod_t**(0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from
        # https://arxiv.org/pdf/2010.02502.pdf # noqa
        pred_sample_direction = (1 - alpha_prod_t_prev -
                                 std_dev_t**2)**(0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from
        # https://arxiv.org/pdf/2010.02502.pdf
        prev_mean = alpha_prod_t_prev**(
            0.5) * pred_original_sample + pred_sample_direction
        output.update(dict(mean=prev_mean, prev_sample=prev_mean))

        if eta > 0:
            device = model_output.device if torch.is_tensor(
                model_output) else 'cpu'
            noise = torch.randn(
                model_output.shape, generator=generator).to(device)
            variance = std_dev_t * noise
            if not torch.is_tensor(model_output):
                variance = variance.numpy()
            prev_sample = prev_mean + variance
            output.update({'prev_sample': prev_sample})

        # NOTE: this x0 is twice computed
        output.update({
            'original_sample': pred_original_sample,
            'beta_prod_t': beta_prod_t
        })

        return output
    def add_noise(self, original_samples, noise, timesteps):
        """add noise."""
        sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
        noisy_samples = (
            sqrt_alpha_prod * original_samples +
            sqrt_one_minus_alpha_prod * noise)
        return noisy_samples
    def __len__(self):
        return self.num_train_timesteps
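
The listing above is only the scheduler, not a full pipeline, so the following is a minimal, hypothetical usage sketch rather than mmagic's actual inference code: the `denoising_model` stand-in (which just returns random noise so the loop runs) and the shape and parameter values are illustrative assumptions. It shows the intended call order: `set_timesteps`, then repeated calls to `step` walking `scheduler.timesteps` from high to low, with `eta=0.0` giving the deterministic DDIM update of steps 5-7 in `step`.

import torch

from mmagic.models.diffusion_schedulers.ddim_scheduler import \
    EditDDIMScheduler


def denoising_model(x, t):
    """Hypothetical stand-in for a trained noise-prediction network.

    It returns random 'noise' purely so the sketch runs end to end.
    """
    return torch.randn_like(x)


scheduler = EditDDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 64, 64)  # start from pure Gaussian noise
for t in scheduler.timesteps:  # 980, 960, ..., 0
    eps = denoising_model(sample, t)  # predicted noise for timestep t
    out = scheduler.step(eps, int(t), sample, eta=0.0)
    sample = out['prev_sample']  # x_{t-1}; mean-only update since eta == 0

In a gradient-guided setting, `cond_fn` and `cond_kwargs` would additionally be passed to `step`, which is the part that distinguishes this class from the plain diffusers DDIM scheduler.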