Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
class DDPMSampler:
    """Ancestral sampler for DDPM (https://arxiv.org/pdf/2006.11239.pdf).

    Naming follows the DDPM paper: ``betas`` are beta_t, ``alphas`` are
    alpha_t = 1 - beta_t and ``alphas_cumprod`` is alpha_bar_t, the running
    product alpha_1 * ... * alpha_t.

    Intended call order (inferred from the methods below; confirm against the
    calling pipeline): optionally ``set_inference_timesteps``, optionally
    ``set_strength`` (image-to-image), then ``step`` for each value in
    ``timesteps``.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        """Precompute the noise schedule.

        Args:
            generator: RNG used for every noise draw in ``step`` and ``add_noise``.
            num_training_steps: length of the training noise schedule.
            beta_start, beta_end: first/last beta. Betas are spaced linearly in
                sqrt-space between them and then squared.
        """
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)  # alpha_bar_t
        self.one = torch.tensor(1.0)  # alpha_bar for a timestep before 0
        self.generator = generator
        self.num_train_timesteps = num_training_steps
        # Until set_inference_timesteps() is called, sample over every training
        # step (stride 1). This matches the default timesteps below and keeps
        # step()/set_strength() usable without that call.
        self.num_inference_steps = num_training_steps
        self.start_step = 0  # number of timesteps skipped by set_strength()
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())  # [999, 998, ..., 0]

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select evenly spaced training timesteps to run at inference time.

        With 1000 training steps and 50 inference steps the stride is
        1000 // 50 = 20, giving timesteps [980, 960, ..., 20, 0].

        Raises:
            ValueError: if num_inference_steps is not in [1, num_train_timesteps].
                Outside that range the stride would be 0 (all timesteps collapse
                to 0) or the stride division would fail.
        """
        if not 0 < num_inference_steps <= self.num_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be in [1, {self.num_train_timesteps}], got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep that follows ``timestep`` in the inference schedule.

        The result is ``timestep`` minus the stride, e.g. 960 -> 940 for a stride
        of 20. It is negative after the last step.
        """
        return timestep - self.num_train_timesteps // self.num_inference_steps

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Return the posterior variance beta_tilde_t used when sampling x_{t-1}.

        Formulas (6)/(7) of the DDPM paper:
        beta_tilde_t = (1 - alpha_bar_{t_prev}) / (1 - alpha_bar_t) * beta_t,
        where beta_t = 1 - alpha_bar_t / alpha_bar_{t_prev} covers the whole
        jump from t to the previous inference timestep.
        """
        prev_t = self._get_previous_timestep(timestep)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp to a tiny positive floor so the variance is never exactly zero.
        return torch.clamp(variance, min=1e-20)

    def set_strength(self, strength=1):
        """Set how much noise to add to the input image (image-to-image).

        More noise (strength ~ 1) means that the output will be further from
        the input image. Less noise (strength ~ 0) means that the output will
        be closer to the input image.

        The first ``num_inference_steps - int(num_inference_steps * strength)``
        timesteps are dropped from ``self.timesteps``. With 50 inference steps:
        strength=1 skips 0 steps and denoises from pure noise; strength=0 skips
        all 50. The slice applies to the current ``self.timesteps``, so calling
        this twice compounds the skip.

        Raises:
            ValueError: if strength is outside [0, 1]. A negative skip would
                otherwise slice from the end of the schedule.
        """
        if not 0 <= strength <= 1:
            raise ValueError(f"strength must be in [0, 1], got {strength}")
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """Run one reverse-diffusion step, mapping x_t to a sample of x_{t-1}.

        Args:
            timestep: current timestep t (int or 0-dim integer tensor).
            latents: current sample x_t.
            model_output: the model's noise prediction epsilon for x_t.

        Returns:
            The posterior mean (formula (7) of the DDPM paper) plus
            sqrt(variance) * N(0, I) noise, drawn from ``self.generator``.
            No noise is added when t == 0.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. alphas / betas at t and at the previous inference timestep
        alpha_prod_t = self.alphas_cumprod[t]  # alpha_bar_t
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one  # alpha_bar_{t_prev}
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_t / beta_t for the (possibly multi-step) jump t -> t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. "predicted x_0" from the predicted noise, formula (15)
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # 3. coefficients of x_0 and x_t in the posterior mean, formula (7)
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t

        # 4. posterior mean mu_t(x_t, x_0)
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 5. sample X ~ N(mu, sigma^2) as mu + sigma * N(0, 1); skipped at the final step
        if t > 0:
            noise = torch.randn(
                model_output.shape,
                generator=self.generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            pred_prev_sample = pred_prev_sample + (self._get_variance(t) ** 0.5) * noise

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I), eq. (4).

        Args:
            original_samples: clean samples x_0; the first dimension is matched
                against ``timesteps`` when ``timesteps`` has several entries.
            timesteps: integer timestep(s). Accepts a tensor or a plain Python int.

        Returns:
            Noisy samples with the same shape, dtype and device as
            ``original_samples``. Noise is drawn from ``self.generator``.
        """
        device = original_samples.device
        alphas_cumprod = self.alphas_cumprod.to(device=device, dtype=original_samples.dtype)
        timesteps = torch.as_tensor(timesteps, device=device)

        alpha_bar_t = alphas_cumprod[timesteps].flatten()
        sqrt_alpha_prod = alpha_bar_t ** 0.5  # mean scale
        sqrt_one_minus_alpha_prod = (1 - alpha_bar_t) ** 0.5  # std dev
        # Append trailing singleton dims so the per-timestep scalars broadcast over each sample.
        while sqrt_alpha_prod.dim() < original_samples.dim():
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # N(mu, sigma^2) = mu + sigma * N(0, 1) with mu = sqrt(alpha_bar_t) * x_0
        noise = torch.randn(original_samples.shape, generator=self.generator, device=device, dtype=original_samples.dtype)
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise