import torch
import numpy as np
import pytorch_lightning as pl
def extract_into_tensor(a, t, x_shape):
    # Gather the entries of the 1-D tensor `a` at timestep indices `t`, then reshape
    # to (batch, 1, ..., 1) so the result broadcasts against tensors of shape `x_shape`.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
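
# Illustrative example (not part of the original file): if `a` has shape (1000,)
# and t = torch.tensor([3, 7]), then extract_into_tensor(a, t, x_shape=(2, 4, 64, 64))
# returns a[t] reshaped to (2, 1, 1, 1), ready to broadcast over a batch of latents.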
class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling parameters: pick `ddim_timesteps` evenly spaced steps
        # out of the full `timesteps`-step training schedule.
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        # alpha_cumprod of the previous DDIM step; the first step falls back to alpha_cumprods[0]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self
    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        # Deterministic DDIM update (eta = 0): x_prev = sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * eps
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev.to(pred_x0.device), timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev
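

# --- Usage sketch (illustrative, not from the original file) ---
# Assumes a standard DDPM-style noise schedule; the linear beta schedule and the
# latent shapes below are example choices, not requirements of DDIMSolver.
if __name__ == "__main__":
    betas = np.linspace(1e-4, 0.02, 1000)          # hypothetical linear beta schedule
    alpha_cumprods = np.cumprod(1.0 - betas)       # cumulative product of alphas
    solver = DDIMSolver(alpha_cumprods, timesteps=1000, ddim_timesteps=50).to("cpu")

    # Fake model outputs for a batch of 2 latents of shape (4, 8, 8).
    pred_x0 = torch.randn(2, 4, 8, 8)
    pred_noise = torch.randn(2, 4, 8, 8)
    # One index into the 50 DDIM steps per batch element.
    timestep_index = torch.tensor([10, 25])

    x_prev = solver.ddim_step(pred_x0, pred_noise, timestep_index)
    print(x_prev.shape)  # torch.Size([2, 4, 8, 8])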