Krokodilpirat committed · verified
Commit 19217bc · 1 parent: 83f9422

Update app.py

Files changed (1)
  app.py (+20 −21)
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import gc
 import torch
-import cv2  # needed for image processing (e.g. hconcat, GaussianBlur)
+import cv2
 import gradio as gr
 import numpy as np
 import matplotlib.cm as cm
@@ -42,9 +42,11 @@ model_name = encoder2name[encoder]
 
 # Initialize the model
 video_depth_anything = VideoDepthAnything(**model_configs[encoder])
-filepath = hf_hub_download(repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
-                           filename=f"video_depth_anything_{encoder}.pth",
-                           repo_type="model")
+filepath = hf_hub_download(
+    repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
+    filename=f"video_depth_anything_{encoder}.pth",
+    repo_type="model"
+)
 video_depth_anything.load_state_dict(torch.load(filepath, map_location='cpu'))
 video_depth_anything = video_depth_anything.to(DEVICE).eval()
 
@@ -52,8 +54,6 @@ title = "# Video Depth Anything"
 description = """Official demo for **Video Depth Anything**.
 Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""
 
-
-@gr.processing_utils.threaded  # alternatively, spaces.GPU can be used if available
 def infer_video_depth(
     input_video: str,
     max_len: int = -1,
@@ -67,14 +67,14 @@ def infer_video_depth(
 ):
     # Read input video frames
     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-    # Infer depths using the model
+    # Infer depth maps using the model
     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
 
     video_name = os.path.basename(input_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    # Save the processed (RGB) video and the depth visualization
+    # Save the processed (RGB) video and the depth visualization (using the default color mapping)
     processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
     depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
     save_video(frames, processed_video_path, fps=fps)
@@ -82,7 +82,7 @@ def infer_video_depth(
 
     stitched_video_path = ""
     if stitch:
-        # Create a stitched video (side-by-side): left: processed RGB, right: depth
+        # Create a stitched video: left side is the RGB video, right side is the depth video
        d_min, d_max = depths.min(), depths.max()
        stitched_frames = []
        for i in range(min(len(frames), len(depths))):
@@ -90,18 +90,18 @@ def infer_video_depth(
            depth_frame = depths[i]
            # Normalize the depth frame to [0, 255]
            depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-           # Choose grayscale or colored mapping
+           # Use grayscale or colored mapping for the depth channel
            if grayscale:
                depth_vis = np.stack([depth_norm] * 3, axis=-1)
            else:
                cmap = cm.get_cmap("inferno")
-               # cmap returns RGBA; here we use only the first three channels
+               # cmap returns RGBA; we use the first 3 channels and scale to 255
                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-           # Apply Gaussian blur if requested
+           # Apply Gaussian blur if requested (blur factor > 0)
            if blur > 0:
                kernel_size = int(blur * 20) * 2 + 1  # ensures odd kernel size
                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-           # Concatenate side-by-side
+           # Concatenate side-by-side: RGB frame on the left, processed depth on the right
            stitched = cv2.hconcat([rgb_frame, depth_vis])
            stitched_frames.append(stitched)
        stitched_frames = np.array(stitched_frames)
@@ -111,17 +111,15 @@ def infer_video_depth(
    gc.collect()
    torch.cuda.empty_cache()
 
-   # Return three outputs: processed RGB video, depth visualization, and (optionally) stitched video.
-   # If stitch is not enabled, an empty string is returned.
+   # Return processed RGB video, depth visualization, and (if created) stitched video.
    return [processed_video_path, depth_vis_path, stitched_video_path]
 
-
 def construct_demo():
    with gr.Blocks(analytics_enabled=False) as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown("### If you find this work useful, please help ⭐ the [Github Repo](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for your attention!")
-
+
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video", source="upload", type="filepath")
@@ -130,6 +128,7 @@ def construct_demo():
            processed_video = gr.Video(label="Preprocessed Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
            depth_vis_video = gr.Video(label="Generated Depth Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
            stitched_video = gr.Video(label="Stitched RGBD Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
+
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Accordion("Advanced Settings", open=False):
@@ -142,7 +141,7 @@ def construct_demo():
                generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass
-
+
        gr.Examples(
            examples=examples,
            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
@@ -150,16 +149,16 @@ def construct_demo():
            fn=infer_video_depth,
            cache_examples="lazy",
        )
-
+
        generate_btn.click(
            fn=infer_video_depth,
            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
            outputs=[processed_video, depth_vis_video, stitched_video],
        )
-
+
    return demo
 
 if __name__ == "__main__":
    demo = construct_demo()
-    demo.queue()
+    demo.queue()  # enable asynchronous processing
    demo.launch(share=True)
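
For reference, the per-frame visualization added in this commit can be exercised on its own. The sketch below mirrors the loop body from app.py; the helper name visualize_depth_frame, its signature, and the random test data are ours for illustration and are not part of the commit:

import cv2
import numpy as np
import matplotlib.cm as cm

def visualize_depth_frame(rgb_frame, depth_frame, d_min, d_max, grayscale=False, blur=0.0):
    # Hypothetical helper mirroring the stitching loop in app.py.
    # Normalize depth to [0, 255] with the clip-wide min/max so brightness stays stable across frames
    depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
    if grayscale:
        # Replicate the single channel to get a 3-channel grayscale image
        depth_vis = np.stack([depth_norm] * 3, axis=-1)
    else:
        # cm.get_cmap returns RGBA floats in [0, 1]; keep RGB and rescale to uint8
        depth_vis = (cm.get_cmap("inferno")(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
    if blur > 0:
        # Map blur in (0, 1] to an odd kernel size (1, 3, ..., 41); GaussianBlur requires odd kernels
        kernel_size = int(blur * 20) * 2 + 1
        depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
    # RGB on the left, depth on the right; both must share height and dtype
    return cv2.hconcat([rgb_frame, depth_vis])

# Example on random data (shapes assumed: frames (N, H, W, 3) uint8, depths (N, H, W) float)
frames = np.random.randint(0, 255, (2, 120, 160, 3), dtype=np.uint8)
depths = np.random.rand(2, 120, 160).astype(np.float32)
d_min, d_max = depths.min(), depths.max()
stitched = [visualize_depth_frame(f, d, d_min, d_max, blur=0.5) for f, d in zip(frames, depths)]
print(stitched[0].shape)  # (120, 320, 3)

Normalizing with the clip-wide d_min/d_max (rather than per frame) is what keeps the depth video from flickering as the per-frame depth range changes.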