import os
import subprocess
from typing import Literal, Union, List, Dict, Tuple

import cv2
import imageio
import numpy as np
import torch
import torchvision
import webp
from einops import rearrange
from PIL import Image
from tqdm import tqdm

from .. import logger


def save_videos_to_images(videos: np.ndarray, path: str, image_type="png") -> None:
    """Save a batch of frames as individual images of the given image_type.

    Args:
        videos (np.ndarray): sequence of frames, each of shape [h w c]
        path (str): image directory path
    """
    os.makedirs(path, exist_ok=True)
    for i, video in enumerate(videos):
        imageio.imsave(os.path.join(path, f"{i:04d}.{image_type}"), video)
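
# Example usage (a minimal sketch; the shape and dtype are assumptions based on
# how this function is called elsewhere in this module: a uint8 frame batch in
# [b h w c] order):
#
#     frames = (np.random.rand(16, 256, 256, 3) * 255).astype(np.uint8)
#     save_videos_to_images(frames, "./output/frames", image_type="png")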


def save_videos_grid(
    videos: torch.Tensor,
    path: str,
    rescale=False,
    n_rows=4,  # how many videos per grid row
    fps=8,
    save_type="webp",
) -> None:
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        if x.dtype != torch.uint8:
            x = (x * 255).numpy().astype(np.uint8)
        else:
            x = x.numpy()  # ensure numpy for PIL / imageio
        if save_type == "webp":
            outputs.append(Image.fromarray(x))
        else:
            outputs.append(x)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if "gif" in path or save_type == "gif":
        params = {
            "duration": int(1000 * 1.0 / fps),
            "loop": 0,
        }
    elif save_type == "mp4":
        params = {
            "quality": 9,
            "fps": fps,
            "pixelformat": "yuv420p",
        }
    else:
        params = {
            "quality": 9,
            "fps": fps,
        }
    if save_type == "webp":
        webp.save_images(outputs, path, fps=fps, lossless=True)
    else:
        imageio.mimsave(path, outputs, **params)
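
# Example usage (a minimal sketch; the tensor below is a stand-in for a float
# video batch in [0, 1] with layout b c t h w, which is what this function expects):
#
#     videos = torch.rand(2, 3, 8, 128, 128)  # b c t h w
#     save_videos_grid(videos, "./output/grid.webp", n_rows=2, fps=8, save_type="webp")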


def make_grid_with_opencv(
    batch: Union[torch.Tensor, np.ndarray],
    nrows: int,
    texts: List[str] = None,
    rescale: bool = False,
    font_size: float = 0.05,
    font_thickness: int = 1,
    font_color: Tuple[int] = (255, 0, 0),
    tensor_order: str = "b c h w",
    write_info: bool = False,
) -> np.ndarray:
    """Read a tensor batch and make an image grid with OpenCV.

    Args:
        batch (Union[torch.Tensor, np.ndarray]): 4-dim tensor, e.g. b c h w
        nrows (int): how many rows in the grid
        texts (List[str], optional): text to write on each video. Defaults to None.
        rescale (bool, optional): whether to rescale from [-1, 1] to [0, 1]. Defaults to False.
        font_size (float, optional): font size. Defaults to 0.05.
        font_thickness (int, optional): font thickness. Defaults to 1.
        font_color (Tuple[int], optional): text color. Defaults to (255, 0, 0).
        tensor_order (str, optional): batch channel order. Defaults to "b c h w".
        write_info (bool, optional): whether to write text onto the grid. Defaults to False.

    Returns:
        np.ndarray: h w c
    """
    if isinstance(batch, torch.Tensor):
        batch = batch.cpu().numpy()
    # batch: (B, C, H, W)
    batch = rearrange(batch, f"{tensor_order} -> b h w c")
    b, h, w, c = batch.shape
    ncols = int(np.ceil(b / nrows))
    grid = np.zeros((h * nrows, w * ncols, c), dtype=np.uint8)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, x in enumerate(batch):
        i_row, i_col = i // ncols, i % ncols
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).astype(np.uint8)
        # without this copy, cv2.putText raises an error
        # ref: https://stackoverflow.com/questions/72327137/opencv4-5-5-error-5bad-argument-in-function-puttext
        x = x.copy()
        if texts is not None and write_info:
            x = cv2.putText(
                x,
                texts[i],
                (5, 20),
                font,
                fontScale=font_size,
                color=font_color,
                thickness=font_thickness,
            )
        grid[i_row * h : (i_row + 1) * h, i_col * w : (i_col + 1) * w, :] = x
    return grid
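
# Example usage (a minimal sketch; shapes and labels are illustrative only):
#
#     batch = torch.rand(6, 3, 64, 64)  # b c h w in [0, 1]
#     texts = [f"clip_{i}" for i in range(6)]
#     grid = make_grid_with_opencv(batch, nrows=2, texts=texts, write_info=True)
#     # grid is a uint8 array of shape (2 * 64, 3 * 64, 3)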


def save_videos_grid_with_opencv(
    videos: Union[torch.Tensor, np.ndarray],
    path: str,
    n_cols: int,
    texts: List[str] = None,
    rescale: bool = False,
    fps: int = 8,
    font_size: float = 0.6,
    font_thickness: int = 1,
    font_color: Tuple[int] = (255, 0, 0),
    tensor_order: str = "b c t h w",
    batch_dim: int = 0,
    split_size_or_sections: int = None,  # split the batch to avoid overly large videos
    write_info: bool = False,
    save_filetype: Literal["gif", "mp4", "webp"] = "mp4",
    save_images: bool = False,
) -> None:
    """Save a video tensor as gif, mp4, etc.

    Args:
        videos (Union[torch.Tensor, np.ndarray]): 5-dim video tensor, e.g. b c t h w, values in [0, 1]
        path (str): output video path; the extension affects how it is stored
        n_cols (int): since b can be very large, the grid is laid out over several columns
        texts (List[str], optional): length b; written at the top-left corner of each video in the batch. Defaults to None.
        rescale (bool, optional): should be True when the input is in [-1, 1]. Defaults to False.
        fps (int, optional): fps of the saved video. Defaults to 8.
        font_size (float, optional): font size of the texts. Defaults to 0.6.
        font_thickness (int, optional): font thickness. Defaults to 1.
        font_color (Tuple[int], optional): font color. Defaults to (255, 0, 0).
        tensor_order (str, optional): order of the input tensor; if it is not `b c t h w`, it is rearranged to b c t h w. Defaults to "b c t h w".
        batch_dim (int, optional): when b is very large a single video becomes too big, so the batch can be split along this dim and stored as several videos. Defaults to 0.
        split_size_or_sections (int, optional): when not None, used together with batch_dim: the maximum number of sub-videos per stored video, rounded up to a multiple of n_cols. Defaults to None.
        write_info (bool, optional): whether to write hint text onto the video. Defaults to False.
    """
    if split_size_or_sections is not None:
        split_size_or_sections = int(np.ceil(split_size_or_sections / n_cols)) * n_cols
        if isinstance(videos, np.ndarray):
            videos = torch.from_numpy(videos)
        # torch.split fits better here than np.array_split
        videos_split = torch.split(videos, split_size_or_sections, dim=batch_dim)
        videos_split = [videos.cpu().numpy() for videos in videos_split]
    else:
        videos_split = [videos]
    n_videos_split = len(videos_split)
    dirname, basename = os.path.dirname(path), os.path.basename(path)
    filename, ext = os.path.splitext(basename)
    os.makedirs(dirname, exist_ok=True)
    for i_video, videos in enumerate(videos_split):
        videos = rearrange(videos, f"{tensor_order} -> t b c h w")
        outputs = []
        font = cv2.FONT_HERSHEY_SIMPLEX
        batch_size = videos.shape[1]
        n_rows = int(np.ceil(batch_size / n_cols))
        for t, x in enumerate(videos):
            x = make_grid_with_opencv(
                x,
                n_rows,
                texts,
                rescale,
                font_size,
                font_thickness,
                font_color,
                write_info=write_info,
            )
            h, w, c = x.shape
            x = x.copy()
            if write_info:
                x = cv2.putText(
                    x,
                    str(t),
                    (5, h - 20),
                    font,
                    fontScale=2,
                    color=font_color,
                    thickness=font_thickness,
                )
            outputs.append(x)
        logger.debug(f"outputs[0].shape: {outputs[0].shape}")
        # TODO: the implementation should be improved
        if i_video == 0 and n_videos_split == 1:
            pass
        else:
            path = os.path.join(dirname, "{}_{}{}".format(filename, i_video, ext))
        if save_filetype == "gif":
            params = {
                "duration": int(1000 * 1.0 / fps),
                "loop": 0,
            }
            imageio.mimsave(path, outputs, **params)
        elif save_filetype == "mp4":
            params = {
                "quality": 9,
                "fps": fps,
            }
            imageio.mimsave(path, outputs, **params)
        elif save_filetype == "webp":
            outputs = [Image.fromarray(x_tmp) for x_tmp in outputs]
            webp.save_images(outputs, path, fps=fps, lossless=True)
        else:
            raise ValueError(f"Unsupported file type: {save_filetype}")
        if save_images:
            images_path = os.path.join(dirname, filename)
            os.makedirs(images_path, exist_ok=True)
            save_videos_to_images(outputs, images_path)
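
# Example usage (a minimal sketch; shapes, path, and arguments are illustrative
# assumptions rather than values from the original code):
#
#     videos = torch.rand(4, 3, 8, 128, 128)  # b c t h w in [0, 1]
#     save_videos_grid_with_opencv(
#         videos,
#         "./output/batch.mp4",
#         n_cols=2,
#         texts=[f"sample_{i}" for i in range(4)],
#         fps=8,
#         write_info=True,
#         save_filetype="mp4",
#     )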


def export_to_video(videos: torch.Tensor, output_video_path: str, fps=8):
    tmp_path = output_video_path.replace(".mp4", "_tmp.mp4")
    videos = rearrange(videos, "b c t h w -> b t h w c")
    videos = videos.squeeze()
    videos = (videos * 255).cpu().detach().numpy().astype(np.uint8)  # tensor -> numpy
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = videos[0].shape
    video_writer = cv2.VideoWriter(
        tmp_path, fourcc, fps=fps, frameSize=(w, h), isColor=True
    )
    for i in range(len(videos)):
        img = cv2.cvtColor(videos[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    video_writer.release()  # release the writer, otherwise the file cannot be played
    cmd = f"ffmpeg -y -i {tmp_path} -c:v libx264 -c:a aac -strict -2 {output_video_path} -loglevel quiet"
    subprocess.run(cmd, shell=True)
    os.remove(tmp_path)
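
# Example usage (a minimal sketch; it assumes a single video, i.e. b == 1, since
# the function squeezes the batch dim, and that ffmpeg is available on PATH):
#
#     videos = torch.rand(1, 3, 16, 128, 128)  # b c t h w in [0, 1]
#     export_to_video(videos, "./output/video.mp4", fps=8)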


# DDIM Inversion
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt",
    )
    uncond_embeddings = pipeline.text_encoder(
        uncond_input.input_ids.to(pipeline.device)
    )[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context


def next_step(
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: int,
    sample: Union[torch.FloatTensor, np.ndarray],
    ddim_scheduler,
):
    timestep, next_timestep = (
        min(
            timestep
            - ddim_scheduler.config.num_train_timesteps
            // ddim_scheduler.num_inference_steps,
            999,
        ),
        timestep,
    )
    alpha_prod_t = (
        ddim_scheduler.alphas_cumprod[timestep]
        if timestep >= 0
        else ddim_scheduler.final_alpha_cumprod
    )
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (
        sample - beta_prod_t**0.5 * model_output
    ) / alpha_prod_t**0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = (
        alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
    )
    return next_sample
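
# Note: the update above is the deterministic DDIM step run in reverse. It first
# recovers the predicted clean sample x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t),
# then re-noises it to the next (larger) timestep via
# x_next = sqrt(a_next) * x0 + sqrt(1 - a_next) * eps,
# where a_t = alphas_cumprod[t] and eps is the model's noise prediction.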


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(
        pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt
    )
    return ddim_latents
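
# Example usage (a minimal sketch; `pipeline`, `model_path`, and `video_latent`
# are assumed: any diffusers-style pipeline exposing tokenizer, text_encoder,
# unet, and device should work, with a DDIMScheduler whose timesteps are set):
#
#     from diffusers import DDIMScheduler
#     ddim_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
#     ddim_scheduler.set_timesteps(50)
#     ddim_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps=50)
#     inverted = ddim_latents[-1]  # noisiest latent, reusable as a starting point for sampling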


def fn_recursive_search(
    name: str,
    module: torch.nn.Module,
    target: str,
    print_method=print,
    print_name: str = "data",
):
    """Recursively walk `module`; for every submodule that has an attribute named
    `target`, print the `print_name` field of that attribute (e.g. the `data` of
    a parameter) together with its fully qualified name."""
    if hasattr(module, target):
        print_method(
            [
                name + "." + target + "." + print_name,
                getattr(getattr(module, target), print_name)[0].cpu().detach().numpy(),
            ]
        )
    parent_name = name
    for name, child in module.named_children():
        fn_recursive_search(
            parent_name + "." + name, child, target, print_method, print_name
        )
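
# Example usage (a minimal sketch; the model and searched attribute are
# illustrative, dumping the first row of every `weight.data` found in the tree):
#
#     model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
#     fn_recursive_search("model", model, target="weight", print_name="data")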


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
    )
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = (
        guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    )
    return noise_cfg
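
# Example usage (a minimal sketch; it assumes the usual classifier-free guidance
# combination step, with `noise_pred_uncond` / `noise_pred_text` as hypothetical tensors):
#
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
#     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)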