Krokodilpirat committed
Commit 9c84c70 · verified · 1 Parent(s): 033fb37

Update app.py

Files changed (1)
  1. app.py +75 -41
app.py CHANGED
@@ -5,26 +5,29 @@ import cv2
 import gradio as gr
 import numpy as np
 import matplotlib.cm as cm
+import subprocess
 
 from video_depth_anything.video_depth import VideoDepthAnything
 from utils.dc_utils import read_video_frames, save_video
 from huggingface_hub import hf_hub_download
 
-# Examples for the Gradio Demo (the additional parameters: stitch, grayscale, blur are appended)
+# Examples for the Gradio Demo.
+# Each example now contains 8 parameters:
+# [video_path, max_len, target_fps, max_res, stitch, grayscale, blur, convert_from_color]
 examples = [
-    ['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/Tokyo-Walk_rgb.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/4158877-uhd_3840_2160_30fps_rgb.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/4511004-uhd_3840_2160_24fps_rgb.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/1753029-hd_1920_1080_30fps.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/davis_burnout.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/example_5473765-l.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/Istanbul-26920.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/obj_1.mp4', -1, -1, 1280, False, False, 0],
-    ['assets/example_videos/sheep_cut1.mp4', -1, -1, 1280, False, False, 0],
+    ['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/Tokyo-Walk_rgb.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/4158877-uhd_3840_2160_30fps_rgb.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/4511004-uhd_3840_2160_24fps_rgb.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/1753029-hd_1920_1080_30fps.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/davis_burnout.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/example_5473765-l.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/Istanbul-26920.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/obj_1.mp4', -1, -1, 1280, True, True, 0, True],
+    ['assets/example_videos/sheep_cut1.mp4', -1, -1, 1280, True, True, 0, True],
 ]
 
-# Determine the device: use GPU if available, else CPU.
+# Use GPU if available; otherwise, use CPU.
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # Model configuration for different encoder variants.
@@ -39,7 +42,7 @@ encoder2name = {
 encoder = 'vitl'
 model_name = encoder2name[encoder]
 
-# Initialize the model
+# Initialize the model.
 video_depth_anything = VideoDepthAnything(**model_configs[encoder])
 filepath = hf_hub_download(
     repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
@@ -49,8 +52,8 @@ filepath = hf_hub_download(
 video_depth_anything.load_state_dict(torch.load(filepath, map_location='cpu'))
 video_depth_anything = video_depth_anything.to(DEVICE).eval()
 
-title = "# Video Depth Anything"
-description = """Official demo for **Video Depth Anything**.
+title = "# Video Depth Anything + RGBD sbs output"
+description = """Official demo for **Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays.
 Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""
 
 def infer_video_depth(
@@ -58,23 +61,24 @@ def infer_video_depth(
     max_len: int = -1,
     target_fps: int = -1,
     max_res: int = 1280,
-    stitch: bool = False,
-    grayscale: bool = False,
+    stitch: bool = True,
+    grayscale: bool = True,
     blur: float = 0.0,
-    *,  # The following parameters are keyword-only and cannot be overridden by UI input.
+    *,  # The following parameters are keyword-only (not overridden by UI input)
     output_dir: str = './outputs',
     input_size: int = 518,
+    convert_from_color: bool = True,
 ):
-    # Read input video frames with the given maximum resolution (max_res) for inference.
+    # 1. Read input video frames for inference (downscaled to max_res).
     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-    # Perform depth inference using the model.
+    # 2. Perform depth inference using the model.
     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
 
     video_name = os.path.basename(input_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    # Save the preprocessed (RGB) video and the depth visualization (using the default color mapping)
+    # Save the preprocessed (RGB) video and the generated depth visualization.
     processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
     depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
     save_video(frames, processed_video_path, fps=fps)
@@ -82,40 +86,69 @@ def infer_video_depth(
 
     stitched_video_path = None
     if stitch:
-        # For stitching: read the original video in full resolution (without downscaling)
+        # For stitching: read the original video in full resolution (without downscaling).
         full_frames, _ = read_video_frames(input_video, max_len, target_fps, max_res=-1)
-        # For each frame, create a visual depth image from the inferenced depth maps (which are in the downscaled resolution)
+        # For each frame, create a visual depth image from the inferenced depths.
        d_min, d_max = depths.min(), depths.max()
        stitched_frames = []
        for i in range(min(len(full_frames), len(depths))):
-            rgb_full = full_frames[i]  # Full-resolution RGB frame
+            rgb_full = full_frames[i]  # Full-resolution RGB frame.
            depth_frame = depths[i]
-            # Normalize the depth frame to the range [0, 255]
+            # Normalize the depth frame to [0, 255].
            depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-            # Create either a grayscale image or apply the inferno colormap, depending on the setting.
+            # Generate depth visualization:
            if grayscale:
-                depth_vis = np.stack([depth_norm] * 3, axis=-1)
+                if convert_from_color:
+                    # Generate a color depth image first, then convert it to grayscale.
+                    cmap = cm.get_cmap("inferno")
+                    depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+                    depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
+                    depth_vis = np.stack([depth_gray] * 3, axis=-1)
+                else:
+                    # Directly generate a grayscale image from the normalized depth values.
+                    depth_vis = np.stack([depth_norm] * 3, axis=-1)
            else:
+                # Generate a color depth image using the inferno colormap.
                cmap = cm.get_cmap("inferno")
                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-            # Apply Gaussian blur if requested (if blur factor > 0)
+            # Apply Gaussian blur if requested.
            if blur > 0:
-                kernel_size = int(blur * 20) * 2 + 1  # ensures an odd kernel size
+                kernel_size = int(blur * 20) * 2 + 1  # ensures an odd kernel size.
                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-            # Resize the depth visual image to match the full-resolution RGB frame.
+            # Resize the depth visualization to match the full-resolution RGB frame.
            H_full, W_full = rgb_full.shape[:2]
            depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
-            # Concatenate the full-resolution RGB frame (left) and the resized depth visual (right) side-by-side.
+            # Concatenate full-resolution RGB (left) and resized depth visualization (right).
            stitched = cv2.hconcat([rgb_full, depth_vis_resized])
            stitched_frames.append(stitched)
        stitched_frames = np.array(stitched_frames)
-        stitched_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_stitched.mp4')
+        # Limit the video name to the first 20 characters and append '_RGBD.mp4'
+        base_name = os.path.splitext(video_name)[0]
+        short_name = base_name[:20]
+        stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
        save_video(stitched_frames, stitched_video_path, fps=fps)
+
+        # Merge audio from the input video into the stitched video using ffmpeg.
+        temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
+        cmd = [
+            "ffmpeg",
+            "-y",
+            "-i", stitched_video_path,
+            "-i", input_video,
+            "-c:v", "copy",
+            "-c:a", "aac",
+            "-map", "0:v:0",
+            "-map", "1:a:0?",
+            "-shortest",
+            temp_audio_path
+        ]
+        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        os.replace(temp_audio_path, stitched_video_path)
 
    gc.collect()
    torch.cuda.empty_cache()
 
-    # Return the processed RGB video, depth visualization, and (if created) the stitched video.
+    # Return the preprocessed RGB video, depth visualization, and (if created) the stitched video.
    return [processed_video_path, depth_vis_path, stitched_video_path]
 
 def construct_demo():
@@ -126,7 +159,7 @@ def construct_demo():
 
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
-                # Use the Video component for file upload (without specifying 'source')
+                # Video input component for file upload.
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
@@ -137,19 +170,20 @@ def construct_demo():
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Accordion("Advanced Settings", open=False):
-                    max_len = gr.Slider(label="Max process length", minimum=-1, maximum=1000, value=500, step=1)
-                    target_fps = gr.Slider(label="Target FPS", minimum=-1, maximum=30, value=15, step=1)
+                    max_len = gr.Slider(label="Max process length", minimum=-1, maximum=1000, value=-1, step=1)
+                    target_fps = gr.Slider(label="Target FPS", minimum=-1, maximum=30, value=-1, step=1)
                    max_res = gr.Slider(label="Max side resolution", minimum=480, maximum=1920, value=1280, step=1)
-                    stitch_option = gr.Checkbox(label="Stitch RGB & Depth Videos", value=False)
-                    grayscale_option = gr.Checkbox(label="Output Depth as Grayscale", value=False)
+                    stitch_option = gr.Checkbox(label="Stitch RGB & Depth Videos", value=True)
+                    grayscale_option = gr.Checkbox(label="Output Depth as Grayscale", value=True)
                    blur_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Depth Blur Factor", value=0)
+                    convert_from_color_option = gr.Checkbox(label="Convert Grayscale from Color", value=True)
                generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass
 
        gr.Examples(
            examples=examples,
-            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
+            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider, convert_from_color_option],
            outputs=[processed_video, depth_vis_video, stitched_video],
            fn=infer_video_depth,
            cache_examples=True,
@@ -158,7 +192,7 @@ def construct_demo():
 
        generate_btn.click(
            fn=infer_video_depth,
-            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
+            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider, convert_from_color_option],
            outputs=[processed_video, depth_vis_video, stitched_video],
        )
 
@@ -166,5 +200,5 @@ def construct_demo():
 
 if __name__ == "__main__":
    demo = construct_demo()
-    demo.queue()  # Enable asynchronous processing
+    demo.queue()  # Enable asynchronous processing.
    demo.launch(share=True)
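
A few reviewer-style notes on the changes above.

First, the new grayscale path: with "Convert Grayscale from Color" enabled, depth is routed through the inferno colormap and then collapsed back to gray, instead of being taken directly from the normalized values. A minimal standalone sketch of the three visualization paths, with a hypothetical random depth map standing in for the model output (assumes only numpy, matplotlib, and OpenCV, as in app.py):

import numpy as np
import matplotlib.cm as cm
import cv2

# Hypothetical stand-in for one frame of the model's depth output.
depth_frame = np.random.rand(240, 320).astype(np.float32)

# Normalize to [0, 255], exactly as the per-frame loop in the diff does.
d_min, d_max = depth_frame.min(), depth_frame.max()
depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)

# Path 1: color output via the inferno colormap (RGBA floats -> RGB uint8).
cmap = cm.get_cmap("inferno")
depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)

# Path 2: "Convert Grayscale from Color" -- colorize first, then take the luma
# and replicate it to 3 channels so it can be stitched next to the RGB frame.
depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
vis_from_color = np.stack([depth_gray] * 3, axis=-1)

# Path 3: direct grayscale -- replicate the normalized depth to 3 channels.
vis_direct = np.stack([depth_norm] * 3, axis=-1)

Since inferno's lightness increases roughly monotonically, path 2 yields a smoother perceptual ramp than the raw normalized values, which is presumably why the new checkbox defaults to True.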
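
The blur-factor mapping is unchanged but worth spelling out: kernel_size = int(blur * 20) * 2 + 1 always yields an odd width, which cv2.GaussianBlur requires. A quick check of the mapping (blur = 0 is skipped by the `if blur > 0` guard anyway):

for blur in (0.0, 0.25, 0.5, 1.0):
    kernel_size = int(blur * 20) * 2 + 1
    print(blur, kernel_size)  # 0.0 -> 1, 0.25 -> 11, 0.5 -> 21, 1.0 -> 41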
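
The stitch itself rests on two shape contracts: cv2.resize takes its target size as (width, height), not (height, width), and cv2.hconcat needs both inputs to share height, dtype, and channel count. A small sketch with hypothetical frame sizes:

import numpy as np
import cv2

rgb_full = np.zeros((1080, 1920, 3), dtype=np.uint8)  # hypothetical full-res RGB frame
depth_vis = np.zeros((518, 920, 3), dtype=np.uint8)   # hypothetical low-res depth visual

H_full, W_full = rgb_full.shape[:2]
depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))  # note the (width, height) order
stitched = cv2.hconcat([rgb_full, depth_vis_resized])
assert stitched.shape == (1080, 2 * 1920, 3)  # side-by-side, double width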
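
The audio merge writes to a temporary file and then os.replace()s it over the stitched video, since ffmpeg cannot edit a file in place. The same command as a hypothetical helper, with the stream mapping commented; check=True is an addition here so a failed mux raises instead of going unnoticed, as it would with the piped stdout/stderr in the diff:

import subprocess

def mux_audio(stitched_path: str, source_path: str, out_path: str) -> None:
    """Copy the stitched video stream as-is and add the source clip's audio."""
    cmd = [
        "ffmpeg", "-y",       # -y: overwrite the output file if it exists
        "-i", stitched_path,  # input 0: stitched RGBD video (no audio)
        "-i", source_path,    # input 1: the original upload (audio source)
        "-c:v", "copy",       # pass the video stream through untouched
        "-c:a", "aac",        # re-encode audio to AAC for the mp4 container
        "-map", "0:v:0",      # take the video from input 0
        "-map", "1:a:0?",     # take the audio from input 1; the trailing '?' makes
                              # the mapping optional, so inputs without an audio
                              # track don't abort the mux
        "-shortest",          # stop when the shorter stream ends
        out_path,
    ]
    subprocess.run(cmd, check=True, capture_output=True)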
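
One interaction worth double-checking: convert_from_color is declared after the bare * in infer_video_depth (keyword-only), yet convert_from_color_option is appended to both inputs lists, and Gradio hands input components to fn as positional arguments. Eight positional values against seven positional parameters would raise a TypeError at call time, so either the * belongs after convert_from_color or the parameter should not be wired into the UI. The underlying Python behavior, in miniature:

def f(a, b=0, *, c=True):
    return a, b, c

f(1, 2)        # fine: c keeps its default True
f(1, 2, True)  # TypeError: f() takes from 1 to 2 positional arguments but 3 were given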