Update app.py

app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import gc
 import torch
-import cv2
+import cv2
 import gradio as gr
 import numpy as np
 import matplotlib.cm as cm
@@ -42,9 +42,11 @@ model_name = encoder2name[encoder]
 
 # Initialize the model
 video_depth_anything = VideoDepthAnything(**model_configs[encoder])
-filepath = hf_hub_download(
-
-
+filepath = hf_hub_download(
+    repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
+    filename=f"video_depth_anything_{encoder}.pth",
+    repo_type="model"
+)
 video_depth_anything.load_state_dict(torch.load(filepath, map_location='cpu'))
 video_depth_anything = video_depth_anything.to(DEVICE).eval()
 
@@ -52,8 +54,6 @@ title = "# Video Depth Anything"
 description = """Official demo for **Video Depth Anything**.
 Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""
 
-
-@gr.processing_utils.threaded  # alternatively, spaces.GPU can be used if available
 def infer_video_depth(
     input_video: str,
     max_len: int = -1,
@@ -67,14 +67,14 @@ def infer_video_depth(
 ):
     # Read input video frames
     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-    # Infer
+    # Infer depth maps using the model
     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
 
     video_name = os.path.basename(input_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    # Save the processed (RGB) video and the depth visualization
+    # Save the processed (RGB) video and the depth visualization (using the default color mapping)
     processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
     depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
     save_video(frames, processed_video_path, fps=fps)
@@ -82,7 +82,7 @@ def infer_video_depth(
 
     stitched_video_path = ""
     if stitch:
-        # Create a stitched video
+        # Create a stitched video: left side is the RGB video, right side is the depth video
         d_min, d_max = depths.min(), depths.max()
         stitched_frames = []
         for i in range(min(len(frames), len(depths))):
@@ -90,18 +90,18 @@ def infer_video_depth(
             depth_frame = depths[i]
             # Normalize the depth frame to [0, 255]
             depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-            #
+            # Use grayscale or colored mapping for the depth channel
            if grayscale:
                depth_vis = np.stack([depth_norm] * 3, axis=-1)
            else:
                cmap = cm.get_cmap("inferno")
-                # cmap returns RGBA
+                # cmap returns RGBA – we use the first 3 channels and scale to 255
                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-            # Apply Gaussian blur if requested
+            # Apply Gaussian blur if requested (blur factor > 0)
            if blur > 0:
                kernel_size = int(blur * 20) * 2 + 1  # ensures odd kernel size
                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-            # Concatenate side-by-side
+            # Concatenate side-by-side: RGB frame on the left, processed depth on the right
            stitched = cv2.hconcat([rgb_frame, depth_vis])
            stitched_frames.append(stitched)
        stitched_frames = np.array(stitched_frames)
@@ -111,17 +111,15 @@ def infer_video_depth(
     gc.collect()
     torch.cuda.empty_cache()
 
-    # Return
-    # If stitch is not enabled, an empty string is returned.
+    # Return processed RGB video, depth visualization, and (if created) stitched video.
     return [processed_video_path, depth_vis_path, stitched_video_path]
 
-
 def construct_demo():
     with gr.Blocks(analytics_enabled=False) as demo:
         gr.Markdown(title)
         gr.Markdown(description)
         gr.Markdown("### If you find this work useful, please help ⭐ the [Github Repo](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for your attention!")
-
+
         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
                 input_video = gr.Video(label="Input Video", source="upload", type="filepath")
@@ -130,6 +128,7 @@ def construct_demo():
                 processed_video = gr.Video(label="Preprocessed Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
                 depth_vis_video = gr.Video(label="Generated Depth Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
                 stitched_video = gr.Video(label="Stitched RGBD Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
+
         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
                 with gr.Accordion("Advanced Settings", open=False):
@@ -142,7 +141,7 @@ def construct_demo():
                 generate_btn = gr.Button("Generate")
             with gr.Column(scale=2):
                 pass
-
+
         gr.Examples(
             examples=examples,
             inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
@@ -150,16 +149,16 @@ def construct_demo():
             fn=infer_video_depth,
             cache_examples="lazy",
         )
-
+
         generate_btn.click(
             fn=infer_video_depth,
             inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
             outputs=[processed_video, depth_vis_video, stitched_video],
         )
-
+
    return demo

if __name__ == "__main__":
    demo = construct_demo()
-    demo.queue()
+    demo.queue()  # enable asynchronous processing
    demo.launch(share=True)
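For reference, a minimal standalone sketch of the checkpoint download that this commit makes explicit. The repo id pattern, filename pattern, and the `encoder2name` mapping are taken from the diff above; fixing the encoder to "vits" with model name "Small" and running this outside the Space are assumptions for illustration only.

import torch
from huggingface_hub import hf_hub_download

# Hypothetical fixed choice; app.py derives model_name via encoder2name[encoder]
encoder, model_name = "vits", "Small"

# Download the checkpoint from the Hub model repo named in the diff
filepath = hf_hub_download(
    repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
    filename=f"video_depth_anything_{encoder}.pth",
    repo_type="model",
)

# Load the weights on CPU first, as app.py does; model construction is omitted here
state_dict = torch.load(filepath, map_location="cpu")
# video_depth_anything.load_state_dict(state_dict)  # see app.py for the full setup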