import torch

from svd import StableVideoDiffusionPipeline  # project-local pipeline wrapper (supports denoise_beg / denoise_end and extra output types)
from diffusers import DDIMScheduler

from PIL import Image
import numpy as np

import torch.nn as nn
import torch.nn.functional as F


class StableVideoDiffusion:
    def __init__(
        self,
        device,
        fp16=True,
        t_range=[0.02, 0.98],
    ):
        super().__init__()

        # Guidance mode: 0 = 'sds', 1 = 'pixel reconstruction', 2 = 'latent reconstruction'
        self.guidance_type = [
            'sds',
            'pixel reconstruction',
            'latent reconstruction'
        ][1]

        self.device = device
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
        )
        pipe.to(device)
        self.pipe = pipe

        # SDS samples from the full training schedule; the reconstruction modes only need the 25 inference steps.
        self.num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps if self.guidance_type == 'sds' else 25
        self.pipe.scheduler.set_timesteps(self.num_train_timesteps, device=device)  # set sigmas for Euler discrete scheduling
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = None
        self.image = None
        self.target_cache = None

    def get_img_embeds(self, image):
        # Cache the conditioning image (HxWx3 float array in [0, 1]) as a PIL image for the pipeline.
        self.image = Image.fromarray(np.uint8(image * 255))

    def encode_image(self, image):
        # Map RGB from [0, 1] to [-1, 1], then encode to scaled VAE latents.
        image = image * 2 - 1
        latents = self.pipe._encode_vae_image(image, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=False)
        latents = self.pipe.vae.config.scaling_factor * latents
        return latents
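
    # SDEdit-style refinement: encode the rendered frames to latents, noise them to an
    # intermediate step chosen by `strength`, then let the SVD pipeline denoise from
    # that step to the end of the schedule and return the decoded frames.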
    def refine(self,
        pred_rgb,
        steps=25, strength=0.8,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
    ):
        # strength = 0.8
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)

        # encode image into latents with vae, requires grad!
        # latents = []
        # for i in range(batch_size):
        #     latent = self.encode_image(pred_rgb_512[i:i+1])
        #     latents.append(latent)
        # latents = torch.cat(latents, 0)
        latents = self.encode_image(pred_rgb_512)
        latents = latents.unsqueeze(0)  # add the video batch dimension: (1, num_frames, C, H, W)

        if strength == 0:
            # Start from pure noise and denoise the full schedule.
            init_step = 0
            latents = torch.randn_like(latents)
        else:
            # Noise the latents up to an intermediate step and denoise from there.
            init_step = int(steps * strength)
            latents = self.pipe.scheduler.add_noise(latents, torch.randn_like(latents), self.pipe.scheduler.timesteps[init_step:init_step+1])

        target = self.pipe(
            image=self.image,
            height=512,
            width=512,
            latents=latents,
            denoise_beg=init_step,
            denoise_end=steps,
            output_type='frame',
            num_frames=batch_size,
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            num_inference_steps=steps,
            decode_chunk_size=1
        ).frames[0]

        # Map frames from [-1, 1] back to [0, 1] and reorder to (num_frames, C, H, W).
        target = (target + 1) * 0.5
        target = target.permute(1, 0, 2, 3)
        return target

        # Debugging aid: generate a full video from the conditioning image and export it.
        # frames = self.pipe(
        #     image=self.image,
        #     height=512,
        #     width=512,
        #     latents=latents,
        #     denoise_beg=init_step,
        #     denoise_end=steps,
        #     num_frames=batch_size,
        #     min_guidance_scale=min_guidance_scale,
        #     max_guidance_scale=max_guidance_scale,
        #     num_inference_steps=steps,
        #     decode_chunk_size=1
        # ).frames[0]
        # export_to_gif(frames, f"tmp.gif")
        # raise
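
    # Guidance loss used during optimization. The mode chosen in __init__ selects:
    #   'sds':                   score distillation on the video latents (noise-residual gradient)
    #   'pixel reconstruction':  MSE against a cached SVD video decoded to pixel space
    #   'latent reconstruction': MSE against a cached SVD video kept in latent space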
    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
    ):
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)

        # encode image into latents with vae, requires grad!
        # latents = self.pipe._encode_image(pred_rgb_512, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True)
        latents = self.encode_image(pred_rgb_512)
        latents = latents.unsqueeze(0)

        if step_ratio is not None:
            # dreamtime-like: anneal the timestep from max_step down to min_step as training progresses
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((1,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=self.device)
        # print(t)

        # w(t): SDS weighting term, larger for noisier timesteps.
        w = (1 - self.alphas[t]).view(1, 1, 1, 1)

        if self.guidance_type == 'sds':
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                # timesteps are stored in descending order, so convert the sampled t into an index
                t = self.num_train_timesteps - t.item()

                # add noise
                noise = torch.randn_like(latents)
                latents_noisy = self.pipe.scheduler.add_noise(latents, noise, self.pipe.scheduler.timesteps[t:t+1])  # index 0 = most noise; index 999 = clean

                noise_pred = self.pipe(
                    image=self.image,
                    # image_embeddings=self.embeddings,
                    height=512,
                    width=512,
                    latents=latents_noisy,
                    output_type='noise',
                    denoise_beg=t,
                    denoise_end=t + 1,
                    min_guidance_scale=min_guidance_scale,
                    max_guidance_scale=max_guidance_scale,
                    num_frames=batch_size,
                    num_inference_steps=self.num_train_timesteps
                ).frames[0]

            # SDS gradient, folded into an MSE loss against a detached target.
            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)
            target = (latents - grad).detach()
            loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[1]
            print(loss.item())
            return loss
        elif self.guidance_type == 'pixel reconstruction':
            # pixel space reconstruction: generate the target video once and cache it
            if self.target_cache is None:
                with torch.no_grad():
                    self.target_cache = self.pipe(
                        image=self.image,
                        height=512,
                        width=512,
                        output_type='frame',
                        num_frames=batch_size,
                        num_inference_steps=self.num_train_timesteps,
                        decode_chunk_size=1
                    ).frames[0]
                    # Map frames from [-1, 1] to [0, 1] and reorder to (num_frames, C, H, W).
                    self.target_cache = (self.target_cache + 1) * 0.5
                    self.target_cache = self.target_cache.permute(1, 0, 2, 3)

            loss = 0.5 * F.mse_loss(pred_rgb_512.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
            print(loss.item())
            return loss
        elif self.guidance_type == 'latent reconstruction':
            # latent space reconstruction: generate the target latents once and cache them
            if self.target_cache is None:
                with torch.no_grad():
                    self.target_cache = self.pipe(
                        image=self.image,
                        height=512,
                        width=512,
                        output_type='latent',
                        num_frames=batch_size,
                        num_inference_steps=self.num_train_timesteps,
                    ).frames[0]

            loss = 0.5 * F.mse_loss(latents.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1]
            print(loss.item())
            return loss
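

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): it assumes the
# project-local `svd` pipeline is importable, a CUDA device with enough memory
# to run SVD in fp16, and the tensor conventions used above (conditioning image
# as an HxWx3 float array in [0, 1]; rendered frames as a (num_frames, 3, H, W)
# tensor in [0, 1]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda")
    guidance = StableVideoDiffusion(device, fp16=True)

    # Conditioning image for SVD (e.g. a rendered reference view).
    cond_image = np.random.rand(512, 512, 3).astype(np.float32)
    guidance.get_img_embeds(cond_image)

    # Differentiable renders to optimize; requires_grad so the guidance loss can
    # backpropagate into them.
    pred_rgb = torch.rand(14, 3, 512, 512, device=device, requires_grad=True)

    loss = guidance.train_step(pred_rgb, step_ratio=0.5)
    loss.backward()

    # One-off refinement pass over the same frames (no gradients needed).
    with torch.no_grad():
        refined = guidance.refine(pred_rgb.detach(), steps=25, strength=0.8)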