Spaces:
Sleeping
Sleeping
import gradio as gr | |
import spaces | |
import numpy as np | |
import torch | |
# Replace torch.jit.script with an identity function: any code that calls
# torch.jit.script (or uses it as a decorator) after this line gets the plain
# Python function back instead of a TorchScript-compiled one. Modules imported
# before this line, or that bound `script` directly, are not affected.
# NOTE(review): presumably a workaround for TorchScript compilation failures in
# one of the imports below (controlnet_aux / diffusers / NaRCan_model) — confirm.
torch.jit.script = lambda f: f
import cv2 | |
import os | |
import imageio | |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | |
from controlnet_aux import LineartDetector | |
from functools import partial | |
from PIL import Image | |
from torch.utils.data import DataLoader, Dataset | |
from torchvision.transforms import Compose, ToTensor, Normalize, Resize | |
from NaRCan_model import Homography, Siren | |
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting | |
# Module-wide compute device string used for every model and tensor below.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_example():
    """Return the example rows for ``gr.Examples``.

    Each row is a one-element list holding the path of an example video under
    ``examples/``; a new list is built on every call.
    """
    example_stems = (
        'bear',
        'boat',
        'woman-drink',
        'corgi',
        'yacht',
        'koolshooters',
        'overlook-the-ocean',
        'rotate',
        'shark-ocean',
        'surf',
        'cactus',
        'gold-fish',
    )
    return [[f'examples/{stem}.mp4'] for stem in example_stems]
def set_default_prompt(video_name):
    """Return the suggested edit prompt for an example video file name.

    Args:
        video_name: Bare file name of the video, e.g. ``'bear.mp4'``.

    Returns:
        The canned prompt for a known example, or ``''`` for any other name.
    """
    known_prompts = (
        ('bear.mp4', 'bear, Van Gogh Style'),
        ('boat.mp4', 'a burning boat sails on lava'),
        ('cactus.mp4', 'cactus, made of paper'),
        ('corgi.mp4', 'a hellhound'),
        ('gold-fish.mp4', 'Goldfish in the Milky Way'),
        ('koolshooters.mp4', 'Avatar'),
        ('overlook-the-ocean.mp4', 'ocean, pixel style'),
        ('rotate.mp4', 'turbine engine'),
        ('shark-ocean.mp4', 'A sleek shark, cartoon style'),
        ('surf.mp4', 'Sailing, The background is a large white cloud, sketch style'),
        ('woman-drink.mp4', 'a drinking zombie'),
        ('yacht.mp4', 'yacht, cyberpunk style'),
    )
    lookup = dict(known_prompts)
    return lookup.get(video_name, '')
def update_prompt(input_video):
    """Return the default prompt for the currently selected input video.

    Args:
        input_video: File path of the selected video as delivered by the
            ``gr.Video`` component; may be ``None`` (e.g. when the component
            is cleared).

    Returns:
        The suggested prompt for a known example video, otherwise ``''``.
    """
    if not input_video:
        # Previously ``None.split`` raised AttributeError here.
        return ''
    # basename strips the directory part, which the old split('/') did by hand;
    # it also honours the host OS path separator.
    return set_default_prompt(os.path.basename(input_video))
# Map each example video file name to its precomputed assets:
# [canonical image path, NaRCan checkpoint directory, video frames directory].
video_to_image = {
    f'{stem}.mp4': [
        f'canonical/{stem}.png',
        f'pth_file/{stem}',
        f'examples_frames/{stem}',
    ]
    for stem in (
        'bear',
        'boat',
        'cactus',
        'corgi',
        'gold-fish',
        'koolshooters',
        'overlook-the-ocean',
        'rotate',
        'shark-ocean',
        'surf',
        'woman-drink',
        'yacht',
    )
}
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode a sequence of frames to an H.264 video file.

    Args:
        image_list: Iterable of frames (PIL images or arrays); each is converted
            with ``np.asarray(...).astype(np.uint8)``.
        output_path: Destination file path passed to ``imageio.get_writer``.
        fps: Frames per second of the output video.
        max_frames: Write at most this many frames from the start of
            ``image_list``. The default of 20 keeps the previous behaviour;
            pass ``None`` to write every frame.
    """
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    try:
        for index, img in enumerate(image_list):
            # Stop before converting frames that would be discarded anyway.
            if max_frames is not None and index >= max_frames:
                break
            writer.append_data(np.asarray(img).astype(np.uint8))
    finally:
        # Always finalize the file, even if a frame fails to convert/encode.
        writer.close()
def _load_narcan_models(pth_path):
    """Build the homography net and the residual Siren net and load their weights from ``pth_path``."""
    # map_location lets checkpoints saved on a GPU load on a CPU-only host too.
    checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"), map_location=device)
    checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"), map_location=device)
    g_old = Homography(hidden_features=256, hidden_layers=2).to(device)
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True).to(device)
    g_old.load_state_dict(checkpoint_g_old)
    g.load_state_dict(checkpoint_g)
    g_old.eval()
    g.eval()
    return g_old, g


def _canonical_to_tensor(edit_canonical):
    """Convert a PIL canonical image into a (1, 3, H, W) float tensor on ``device`` holding 0-255 values."""
    rgb = np.array(edit_canonical.convert('RGB'))
    return torch.from_numpy(rgb).float().to(device).permute(2, 0, 1).unsqueeze(0)


def NaRCan_make_video(edit_canonical, pth_path, frames_path):
    """Warp an edited canonical image back into every frame and save it as a video.

    The trained NaRCan deformation (homography + residual Siren MLP) maps each
    pixel coordinate of each frame to a location in the canonical image, which
    is then bilinearly sampled from ``edit_canonical``.

    Args:
        edit_canonical: Edited canonical image (PIL image).
        pth_path: Directory containing ``homography_g.pth`` and ``mlp_g.pth``.
        frames_path: Directory of the original video frames, read by
            ``TestVideoFitting``.

    Returns:
        Path of the written video file.
    """
    g_old, g = _load_narcan_models(pth_path)

    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
    ])
    v = TestVideoFitting(frames_path, transform)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)
    model_input, _ = next(iter(videoloader))
    model_input = model_input[0].to(device)

    # The canonical image is the same for every step: convert and upload it once.
    canonical = _canonical_to_tensor(edit_canonical)

    # NOTE(review): counts every directory entry as a frame; assumes frames_path
    # contains only frame images.
    data_len = len(os.listdir(frames_path))
    batch_size = v.H * v.W  # one frame's worth of coordinates per step
    chunks = []
    with torch.no_grad():
        for step in range(data_len):
            start = (step * batch_size) % len(model_input)
            end = min(start + batch_size, len(model_input))
            xyt = model_input[start:end]
            xy, t = xyt[:, :-1], xyt[:, [-1]]
            # Deformation = global homography at time t + per-point residual.
            xy_ = apply_homography(xy, g_old(t)) + g(xyt)

            # Swap (x, y) into grid_sample's order and rescale into the
            # canonical image's [-1, 1] range.
            # NOTE(review): 1.5 / 2.0 presumably match the canonical extent
            # used in training — confirm before changing.
            grid = xy_.clone()
            grid[..., 1] = xy_[..., 0] / 1.5
            grid[..., 0] = xy_[..., 1] / 2.0

            sampled = torch.nn.functional.grid_sample(
                canonical,
                grid.unsqueeze(1).unsqueeze(0),
                mode='bilinear',
                padding_mode='border',
                align_corners=False)  # same as the default, stated explicitly
            chunks.append(sampled.squeeze().permute(1, 0))

    # Concatenate once (the old per-step torch.cat was quadratic), then lay out
    # as (frames, H, W, RGB).
    video = torch.cat(chunks).reshape(v.H, v.W, data_len, 3).permute(2, 0, 1, 3)
    video = video.detach().cpu().numpy().astype(np.float32)

    frames = [Image.fromarray(np.uint8(frame)).resize((512, 512)) for frame in video]

    edit_video_path = 'NaRCan_fps_10.mp4'
    images_to_video(frames, edit_video_path)
    return edit_video_path
def _load_controlnet_pipeline(controlnet_id):
    """Load Stable Diffusion v1.5 with the ControlNet ``controlnet_id`` and move it to ``device``."""
    # fp16 inference is poorly supported on CPU, so use fp32 when there is no GPU.
    dtype = torch.float16 if device == 'cuda' else torch.float32
    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=dtype
    )
    pipe.to(device)
    return pipe


def _lineart_control_map(image_path, size=768):
    """Compute a Lineart control map for ``image_path``, resized back to the image's original size."""
    processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
    with Image.open(image_path) as canonical_image:
        ori_size = canonical_image.size
        resized = canonical_image.resize((size, size))
    control_map = processor(resized, coarse=False, detect_resolution=size, image_resolution=size)
    return control_map.resize(ori_size, resample=Image.BILINEAR)


def _canny_control_map(image_path, low_threshold=100, high_threshold=200):
    """Compute a 3-channel Canny edge control map for ``image_path``."""
    bgr = cv2.imread(image_path)
    if bgr is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read canonical image: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    return Image.fromarray(np.stack([edges, edges, edges], axis=2))


def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type="Lineart"):
    """Edit an example video's canonical image with ControlNet and re-render the video.

    Args:
        input_video: Path of the selected example video; only its file name is
            used to look up precomputed assets in ``video_to_image``.
        prompt: Text prompt for the edit.
        num_steps: Number of diffusion inference steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed; ``-1`` means no explicit generator.
        n_prompt: Negative prompt.
        control_type: ``"Lineart"``; any other value uses Canny edges.

    Returns:
        ``(edit_video_path, gallery)`` where ``gallery`` is a list of
        ``(image_path, caption)`` pairs for the canonical image, the control
        map and the edited canonical image.

    Raises:
        gr.Error: If the video is not one of the precomputed examples.
    """
    video_name = os.path.basename(input_video or '')
    if video_name not in video_to_image:
        # Previously this returned a single None for two Gradio outputs, which
        # surfaced as an opaque output-count error.
        raise gr.Error(f"No precomputed NaRCan assets for '{video_name}'. Please pick one of the examples.")
    image_path, pth_path, frames_path = video_to_image[video_name]

    if control_type == "Lineart":
        pipe = _load_controlnet_pipeline("lllyasviel/control_v11p_sd15_lineart")
        control_map = _lineart_control_map(image_path)
    else:
        pipe = _load_controlnet_pipeline("lllyasviel/control_v11p_sd15_canny")
        control_map = _canny_control_map(image_path)
    control_map.save("control_map.png")

    generator = torch.manual_seed(seed) if seed != -1 else None
    edited_canonical = pipe(
        prompt=prompt,
        image=control_map,
        num_inference_steps=int(num_steps),
        guidance_scale=guidance_scale,
        negative_prompt=n_prompt,
        generator=generator
    ).images[0]
    edited_canonical.save("edited_canonical_image.png")

    edit_video_path = NaRCan_make_video(edited_canonical, pth_path, frames_path)
    edit_image_path = [
        (image_path, "canonical image"),
        ("control_map.png", "control map"),
        ("edited_canonical_image.png", "edited canonical image")
    ]
    # Return the rendered video and the gallery of intermediate images.
    return edit_video_path, edit_image_path
######## | |
# demo # | |
######## | |
# HTML header rendered at the top of the demo page.
# font-weight must be 1-1000 in CSS; the former value 1400 was invalid and
# therefore ignored by browsers. 900 is the heaviest valid weight.
intro = """
<div style="text-align:center">
<h1 style="font-weight: 900; text-align: center; margin-bottom: 7px;">
NaRCan - <small>Natural Refined Canonical Image</small>
</h1>
<span>[<a target="_blank" href="https://koi953215.github.io/NaRCan_page/">Project page</a>], [<a target="_blank" href="https://huggingface.co/papers/2406.06523">Paper</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Try selecting different control types (Canny or Lineart) in Advanced options!</div>
</div>
"""
# Gradio UI: input/output videos, prompt, gallery of intermediate images,
# advanced options, and the example picker.
_DEFAULT_VIDEO = 'examples/bear.mp4'

with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)

    with gr.Row():
        input_video = gr.Video(label="Input Video", interactive=False, elem_id="input_video",
                               value=_DEFAULT_VIDEO, height=365, width=365)
        output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video",
                                height=365, width=365)

    with gr.Row():
        prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1,
            # Derive the initial prompt from the initial video so they stay in sync.
            value=update_prompt(_DEFAULT_VIDEO)
        )

    with gr.Row():
        canonical_result = gr.Gallery(label="Edited Canonical Image", columns=3)

    with gr.Row():
        run_button = gr.Button("Edit your video!", visible=True)

    with gr.Accordion('Advanced options', open=False):
        control_type = gr.Dropdown(
            ["Canny", "Lineart"],
            label="Control Type",
            info="Canny or Lineart",
            value="Lineart"
        )
        num_steps = gr.Slider(label='Steps',
                              minimum=1,
                              maximum=100,
                              value=20,
                              step=1)
        guidance_scale = gr.Slider(label='Guidance Scale',
                                   minimum=0.1,
                                   maximum=30.0,
                                   value=9.0,
                                   step=0.1)
        seed = gr.Slider(label='Seed',
                         minimum=-1,
                         maximum=2147483647,
                         step=1,
                         randomize=True)
        n_prompt = gr.Textbox(
            label='Negative Prompt',
            value=""
        )

    # Selecting a different example video refreshes the suggested prompt.
    input_video.change(
        fn=update_prompt,
        inputs=[input_video],
        outputs=[prompt],
        queue=False)

    run_button.click(fn=edit_with_pnp,
                     inputs=[input_video,
                             prompt,
                             num_steps,
                             guidance_scale,
                             seed,
                             n_prompt,
                             control_type,
                             ],
                     outputs=[output_video, canonical_result]
                     )

    gr.Examples(
        examples=get_example(),
        label='Examples',
        inputs=[input_video],
        outputs=[output_video],
        examples_per_page=8
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()