# dreamgaussian4d/guidance/svd_utils.py
import torch
from svd import StableVideoDiffusionPipeline  # customized pipeline shipped in this repo's guidance/ folder (adds denoise_beg/denoise_end and extra output types)
from diffusers import DDIMScheduler
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class StableVideoDiffusion:
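    """Image-to-video guidance wrapper around Stable Video Diffusion.

    Exposes three interchangeable guidance signals for a batch of rendered
    frames (SDS, pixel-space reconstruction, latent-space reconstruction)
    plus an SDEdit-style `refine` pass.
    """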
def __init__(
self,
device,
fp16=True,
t_range=[0.02, 0.98],
):
super().__init__()
self.guidance_type = [
'sds',
'pixel reconstruction',
'latent reconstruction'
        ][1]  # hard-coded choice: 'pixel reconstruction'
self.device = device
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid",
            torch_dtype=self.dtype,
            variant="fp16" if fp16 else None,
        )
pipe.to(device)
self.pipe = pipe
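        # SDS indexes the scheduler's full training schedule; the reconstruction
        # modes instead use a 25-step sampling schedule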
self.num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps if self.guidance_type == 'sds' else 25
self.pipe.scheduler.set_timesteps(self.num_train_timesteps, device=device) # set sigma for euler discrete scheduling
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = None
self.image = None
self.target_cache = None

    @torch.no_grad()
def get_img_embeds(self, image):
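        """Cache the conditioning image (HxWx3 float array in [0, 1]) for the SVD pipeline."""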
self.image = Image.fromarray(np.uint8(image*255))

    def encode_image(self, image):
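        """Encode [0, 1] frames into scaled VAE latents."""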
        image = image * 2 - 1  # [0, 1] -> [-1, 1], as expected by the VAE
latents = self.pipe._encode_vae_image(image, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=False)
latents = self.pipe.vae.config.scaling_factor * latents
return latents

    def refine(self,
pred_rgb,
steps=25, strength=0.8,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
):
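        """Refine rendered frames with a partial SVD denoising pass.

        The frames are VAE-encoded, re-noised at schedule position
        int(steps * strength), and denoised over the remaining steps;
        the result is returned in [0, 1], laid out to match pred_rgb.
        """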
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
        # encode rendered frames into latents with the VAE; gradients flow back through the encoder
        # (frames could also be encoded one at a time here to reduce peak memory)
latents = self.encode_image(pred_rgb_512)
        latents = latents.unsqueeze(0)  # add a batch dim: (num_frames, 4, h, w) -> (1, num_frames, 4, h, w)
if strength == 0:
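            # start from pure noise and run the whole sampling schedule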
init_step = 0
latents = torch.randn_like(latents)
else:
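            # SDEdit-style: re-noise the rendered latents partway into the schedule
            # and only denoise the remaining steps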
init_step = int(steps * strength)
latents = self.pipe.scheduler.add_noise(latents, torch.randn_like(latents), self.pipe.scheduler.timesteps[init_step:init_step+1])
target = self.pipe(
image=self.image,
height=512,
width=512,
latents=latents,
denoise_beg=init_step,
denoise_end=steps,
output_type='frame',
num_frames=batch_size,
min_guidance_scale=min_guidance_scale,
max_guidance_scale=max_guidance_scale,
num_inference_steps=steps,
decode_chunk_size=1
).frames[0]
        target = (target + 1) * 0.5  # [-1, 1] -> [0, 1]
        target = target.permute(1, 0, 2, 3)  # rearrange to (num_frames, 3, H, W) to match pred_rgb
return target

    def train_step(
self,
pred_rgb,
step_ratio=None,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
):
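        """Compute the guidance loss for a batch of rendered frames.

        Depending on self.guidance_type this is an SDS loss, a pixel-space
        reconstruction loss against a cached SVD rollout, or a latent-space
        reconstruction loss; step_ratio (0 -> 1 over training) anneals the
        noise timestep from high to low.
        """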
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
        # encode rendered frames into latents with the VAE; gradients flow back through the encoder
latents = self.encode_image(pred_rgb_512)
latents = latents.unsqueeze(0)
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((1,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=self.device)
        w = (1 - self.alphas[t]).view(1, 1, 1, 1)  # SDS weighting w(t) = 1 - alpha_bar_t
if self.guidance_type == 'sds':
# predict the noise residual with unet, NO grad!
with torch.no_grad():
                t = self.num_train_timesteps - t.item()  # convert the sampled timestep into an index into the (descending) scheduler timesteps
# add noise
noise = torch.randn_like(latents)
                latents_noisy = self.pipe.scheduler.add_noise(latents, noise, self.pipe.scheduler.timesteps[t:t+1])  # index 0 = noisiest timestep, index 999 = cleanest
noise_pred = self.pipe(
image=self.image,
# image_embeddings=self.embeddings,
height=512,
width=512,
latents=latents_noisy,
output_type='noise',
denoise_beg=t,
denoise_end=t + 1,
min_guidance_scale=min_guidance_scale,
max_guidance_scale=max_guidance_scale,
num_frames=batch_size,
num_inference_steps=self.num_train_timesteps
).frames[0]
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[1]
print(loss.item())
return loss
elif self.guidance_type == 'pixel reconstruction':
# pixel space reconstruction
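            # generate the SVD video from the conditioning image once and cache it
            # as a fixed regression target for the rendered frames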
if self.target_cache is None:
with torch.no_grad():
self.target_cache = self.pipe(
image=self.image,
height=512,
width=512,
output_type='frame',
num_frames=batch_size,
num_inference_steps=self.num_train_timesteps,
decode_chunk_size=1
).frames[0]
self.target_cache = (self.target_cache + 1) * 0.5
self.target_cache = self.target_cache.permute(1,0,2,3)
loss = 0.5 * F.mse_loss(pred_rgb_512.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
print(loss.item())
return loss
elif self.guidance_type == 'latent reconstruction':
# latent space reconstruction
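            # same idea as above, but the cached rollout stays in latent space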
if self.target_cache is None:
with torch.no_grad():
self.target_cache = self.pipe(
image=self.image,
height=512,
width=512,
output_type='latent',
num_frames=batch_size,
num_inference_steps=self.num_train_timesteps,
).frames[0]
loss = 0.5 * F.mse_loss(latents.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
print(loss.item())
return loss
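

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original file).
    # Assumptions: a CUDA device, a conditioning image as an HxWx3 float array in
    # [0, 1], and a differentiable (num_frames, 3, H, W) batch of renders in [0, 1].
    device = torch.device("cuda")
    guidance = StableVideoDiffusion(device, fp16=True)

    image = np.random.rand(512, 512, 3).astype(np.float32)  # placeholder conditioning image
    guidance.get_img_embeds(image)

    # 14 frames matches the default SVD clip length; requires_grad lets the loss drive a renderer
    pred_rgb = torch.rand(14, 3, 512, 512, device=device, requires_grad=True)
    loss = guidance.train_step(pred_rgb, step_ratio=0.5)
    loss.backward()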