mtwohey2 committed
Commit d542641 · verified · 1 Parent(s): 6916b74

Update app.py

Files changed (1): app.py (+60 -221)

app.py CHANGED
@@ -5,16 +5,18 @@ import cv2
  import gradio as gr
  import numpy as np
  import matplotlib.cm as cm
- import matplotlib
  import subprocess
  import sys
  import spaces

  from video_depth_anything.video_depth import VideoDepthAnything
- from utils.dc_utils import save_video
  from huggingface_hub import hf_hub_download

  # Examples for the Gradio Demo.
  examples = [
      ['assets/example_videos/octopus_01.mp4', -1, -1, 1280, True, True, True, 0.3],
      ['assets/example_videos/chicken_01.mp4', -1, -1, 1280, True, True, True, 0.3],
@@ -60,114 +62,8 @@ title = "# Video Depth Anything + RGBD sbs output"
  description = """**Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays.
  Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""

- def get_video_info(video_path, max_len=-1, target_fps=-1):
-     """Extract video information without loading all frames into memory."""
-     cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened():
-         raise ValueError(f"Could not open video file: {video_path}")
-
-     # Get video properties
-     original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     original_fps = cap.get(cv2.CAP_PROP_FPS)
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-     # Adjust based on max_len parameter
-     if max_len > 0 and max_len < total_frames:
-         frame_count = max_len
-     else:
-         frame_count = total_frames
-
-     # Use target_fps if specified
-     if target_fps > 0:
-         fps = target_fps
-     else:
-         fps = original_fps
-
-     cap.release()
-
-     return {
-         'width': original_width,
-         'height': original_height,
-         'fps': fps,
-         'original_fps': original_fps,
-         'frame_count': frame_count,
-         'total_frames': total_frames
-     }
-
- def process_frame(frame, max_res):
-     """Process a single frame to the desired resolution."""
-     if max_res > 0:
-         h, w = frame.shape[:2]
-         scale = min(max_res / w, max_res / h)
-         if scale < 1:
-             new_w, new_h = int(w * scale), int(h * scale)
-             frame = cv2.resize(frame, (new_w, new_h))
-     return frame
-
- def read_video_frames_chunked(video_path, max_len=-1, target_fps=-1, max_res=-1, chunk_size=32):
-     """Read video frames in chunks to avoid loading the entire video into memory."""
-     cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened():
-         raise ValueError(f"Could not open video file: {video_path}")
-
-     original_fps = cap.get(cv2.CAP_PROP_FPS)
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-     # Determine actual number of frames to process
-     if max_len > 0 and max_len < total_frames:
-         frame_count = max_len
-     else:
-         frame_count = total_frames
-
-     # Use target_fps if specified
-     if target_fps > 0:
-         fps = target_fps
-         # Calculate frame skip if downsampling fps
-         if target_fps < original_fps:
-             skip = int(round(original_fps / target_fps)) - 1
-         else:
-             skip = 0
-     else:
-         fps = original_fps
-         skip = 0
-
-     frame_idx = 0
-     processed_count = 0
-
-     while processed_count < frame_count:
-         frames_chunk = []
-         # Read frames up to chunk size or remaining frames
-         chunk_limit = min(chunk_size, frame_count - processed_count)
-
-         while len(frames_chunk) < chunk_limit:
-             ret, frame = cap.read()
-             if not ret:
-                 break
-
-             # Process frame if we're not skipping it
-             if frame_idx % (skip + 1) == 0:
-                 # Convert from BGR to RGB
-                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                 # Resize if necessary
-                 frame = process_frame(frame, max_res)
-                 frames_chunk.append(frame)
-                 processed_count += 1
-
-                 if processed_count >= frame_count:
-                     break
-
-             frame_idx += 1
-
-         if frames_chunk:
-             yield frames_chunk, fps
-
-         if processed_count >= frame_count or len(frames_chunk) < chunk_limit:
-             break
-
-     cap.release()
-
  @spaces.GPU(enable_queue=True)
  def infer_video_depth(
      input_video: str,
      max_len: int = -1,
@@ -180,122 +76,67 @@ def infer_video_depth(
      output_dir: str = './outputs',
      input_size: int = 518,
  ):
      if not os.path.exists(output_dir):
          os.makedirs(output_dir)
-
-     video_name = os.path.basename(input_video)
      processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
      depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
-
-     # Get video info first
-     video_info = get_video_info(input_video, max_len, target_fps)
-     fps = video_info['fps']
-     frame_count = video_info['frame_count']
-
-     print(f"Processing video: {input_video}, {frame_count} frames at {fps} fps")
-
-     # Process the video in chunks to manage memory
-     chunk_size = 32  # Adjust based on available memory
-
-     # We'll collect depths as we go to calculate global min/max
-     all_depths = []
-     all_processed_frames = []
-
-     # First pass to collect frames and depths
-     frame_idx = 0
-     for frames_chunk, fps in read_video_frames_chunked(input_video, max_len, target_fps, max_res, chunk_size):
-         print(f"Processing chunk: frames {frame_idx+1}-{frame_idx+len(frames_chunk)}/{frame_count}")
-
-         # Process this chunk of frames
-         depths, _ = video_depth_anything.infer_video_depth(frames_chunk, fps, input_size=input_size, device=DEVICE)
-
-         # Store results (we'll need both for the output videos)
-         all_processed_frames.extend(frames_chunk)
-         all_depths.extend(depths)
-
-         frame_idx += len(frames_chunk)
-
-         # Free memory
-         gc.collect()
-         torch.cuda.empty_cache()
-
-     # Calculate global min/max for depth normalization
-     depths_array = np.array(all_depths)
-     d_min, d_max = depths_array.min(), depths_array.max()
-
-     # Save the preprocessed video and depth visualization
-     save_video(all_processed_frames, processed_video_path, fps=fps)
-     save_video(all_depths, depth_vis_path, fps=fps, is_depths=True)
-
-     # Free some memory before stitching
-     del all_processed_frames
-     gc.collect()
-
-     # Process stitched video if requested
      stitched_video_path = None
      if stitch:
-         # Use only the first 20 characters of the base name for the output filename
          base_name = os.path.splitext(video_name)[0]
          short_name = base_name[:20]
          stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
-
-         # For stitching: read the original video in full resolution and stitch frames one by one
-         stitched_frames = []
-
-         # Process in chunks for memory efficiency
-         frame_idx = 0
-         for frames_chunk, _ in read_video_frames_chunked(input_video, max_len, target_fps, -1, chunk_size):  # No max_res for original resolution
-             print(f"Stitching chunk: frames {frame_idx+1}-{frame_idx+len(frames_chunk)}/{frame_count}")
-
-             # Process each frame in the chunk
-             for i, rgb_full in enumerate(frames_chunk):
-                 depth_idx = frame_idx + i
-                 if depth_idx >= len(all_depths):
-                     break
-
-                 depth_frame = all_depths[depth_idx]
-
-                 # Normalize the depth frame
-                 depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-
-                 # Generate depth visualization
-                 if grayscale:
-                     if convert_from_color:
-                         # Convert from color to grayscale
-                         cmap = matplotlib.colormaps.get_cmap("inferno")
-                         depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-                         depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
-                         depth_vis = np.stack([depth_gray] * 3, axis=-1)
-                     else:
-                         # Directly use grayscale
-                         depth_vis = np.stack([depth_norm] * 3, axis=-1)
-                 else:
-                     # Use color visualization
-                     cmap = matplotlib.colormaps.get_cmap("inferno")
-                     depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-
-                 # Apply blur if requested
-                 if blur > 0:
-                     kernel_size = int(blur * 20) * 2 + 1  # Ensures odd kernel size
-                     depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-
-                 # Resize depth visualization to match original resolution
-                 H_full, W_full = rgb_full.shape[:2]
-                 depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
-
-                 # Concatenate RGB and depth
-                 stitched = cv2.hconcat([rgb_full, depth_vis_resized])
-                 stitched_frames.append(stitched)
-
-             frame_idx += len(frames_chunk)
-
-             # Free memory after processing each chunk
-             gc.collect()
-
-         # Save the stitched video
          save_video(stitched_frames, stitched_video_path, fps=fps)

-         # Merge audio from the input video
          temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
          cmd = [
              "ffmpeg",
@@ -311,15 +152,11 @@ def infer_video_depth(
          ]
          subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
          os.replace(temp_audio_path, stitched_video_path)
-
-         # Free memory
-         del stitched_frames
-
-     # Clean up
-     del all_depths
      gc.collect()
      torch.cuda.empty_cache()
-
      return [processed_video_path, depth_vis_path, stitched_video_path]

  def construct_demo():
@@ -371,4 +208,6 @@ def construct_demo():

  if __name__ == "__main__":
      demo = construct_demo()
      demo.queue(max_size=2).launch()
 
  import gradio as gr
  import numpy as np
  import matplotlib.cm as cm
+ import matplotlib # New import for the updated colormap API
  import subprocess
  import sys
  import spaces

  from video_depth_anything.video_depth import VideoDepthAnything
+ from utils.dc_utils import read_video_frames, save_video
  from huggingface_hub import hf_hub_download

  # Examples for the Gradio Demo.
+ # Each example now contains 8 parameters:
+ # [video_path, max_len, target_fps, max_res, stitch, grayscale, convert_from_color, blur]
  examples = [
      ['assets/example_videos/octopus_01.mp4', -1, -1, 1280, True, True, True, 0.3],
      ['assets/example_videos/chicken_01.mp4', -1, -1, 1280, True, True, True, 0.3],
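A note on the new `import matplotlib` line above: the code now looks colormaps up through the `matplotlib.colormaps` registry rather than the older `matplotlib.cm.get_cmap` helper, which is deprecated in recent Matplotlib releases. A minimal sketch of the registry-style lookup, assuming Matplotlib 3.6 or newer (the exact minimum version is an assumption, not stated in this diff):

    import matplotlib
    import numpy as np

    # Registry-style access, the "updated colormap API" the import comment refers to.
    cmap = matplotlib.colormaps.get_cmap("inferno")  # equivalently: matplotlib.colormaps["inferno"]

    # A colormap maps normalized values in [0, 1] to RGBA; the app keeps only the RGB channels.
    vals = np.linspace(0.0, 1.0, 5)
    rgb = (cmap(vals)[..., :3] * 255).astype(np.uint8)
    print(rgb.shape)  # (5, 3)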
 
  description = """**Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays.
  Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""

  @spaces.GPU(enable_queue=True)
+
  def infer_video_depth(
      input_video: str,
      max_len: int = -1,
 
      output_dir: str = './outputs',
      input_size: int = 518,
  ):
+     # 1. Read input video frames for inference (downscaled to max_res).
+     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
+     # 2. Perform depth inference using the model.
+     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
+
+     video_name = os.path.basename(input_video)
      if not os.path.exists(output_dir):
          os.makedirs(output_dir)
+
+     # Save the preprocessed (RGB) video and the generated depth visualization.
      processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
      depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
+     save_video(frames, processed_video_path, fps=fps)
+     save_video(depths, depth_vis_path, fps=fps, is_depths=True)
+
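For orientation, the eight values in each `examples` row map positionally onto the leading parameters of `infer_video_depth`. A hypothetical direct call mirroring the first example row might look like the sketch below; the keyword names for the middle parameters are taken from the example comment and the function body, so treat the exact signature as an assumption:

    # ['assets/example_videos/octopus_01.mp4', -1, -1, 1280, True, True, True, 0.3]
    outputs = infer_video_depth(
        'assets/example_videos/octopus_01.mp4',
        max_len=-1,              # -1: process the whole clip
        target_fps=-1,           # -1: keep the source frame rate
        max_res=1280,            # cap the inference resolution at 1280 px
        stitch=True,             # also produce the side-by-side RGBD video
        grayscale=True,
        convert_from_color=True,
        blur=0.3,
    )
    processed_video_path, depth_vis_path, stitched_video_path = outputs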
      stitched_video_path = None
      if stitch:
+         # For stitching: read the original video in full resolution (without downscaling).
+         full_frames, _ = read_video_frames(input_video, max_len, target_fps, max_res=-1)
+         # For each frame, create a visual depth image from the inferenced depths.
+         d_min, d_max = depths.min(), depths.max()
+         stitched_frames = []
+         for i in range(min(len(full_frames), len(depths))):
+             rgb_full = full_frames[i]  # Full-resolution RGB frame.
+             depth_frame = depths[i]
+             # Normalize the depth frame to the range [0, 255].
+             depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
+             # Generate depth visualization:
+             if grayscale:
+                 if convert_from_color:
+                     # First, generate a color depth image using the inferno colormap,
+                     # then convert that color image to grayscale.
+                     cmap = matplotlib.colormaps.get_cmap("inferno")
+                     depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+                     depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
+                     depth_vis = np.stack([depth_gray] * 3, axis=-1)
+                 else:
+                     # Directly generate a grayscale image from the normalized depth values.
+                     depth_vis = np.stack([depth_norm] * 3, axis=-1)
+             else:
+                 # Generate a color depth image using the inferno colormap.
+                 cmap = matplotlib.colormaps.get_cmap("inferno")
+                 depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+             # Apply Gaussian blur if requested.
+             if blur > 0:
+                 kernel_size = int(blur * 20) * 2 + 1  # Ensures an odd kernel size.
+                 depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
+             # Resize the depth visualization to match the full-resolution RGB frame.
+             H_full, W_full = rgb_full.shape[:2]
+             depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
+             # Concatenate the full-resolution RGB frame (left) and the resized depth visualization (right).
+             stitched = cv2.hconcat([rgb_full, depth_vis_resized])
+             stitched_frames.append(stitched)
+         stitched_frames = np.array(stitched_frames)
+         # Use only the first 20 characters of the base name for the output filename and append '_RGBD.mp4'
          base_name = os.path.splitext(video_name)[0]
          short_name = base_name[:20]
          stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
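The `int(blur * 20) * 2 + 1` expression above exists because `cv2.GaussianBlur` expects odd kernel dimensions when a kernel size is given explicitly. A small standalone sketch of how the blur value maps to kernel sizes (the 0–1 slider range is inferred from the example rows, so treat it as an assumption):

    import cv2
    import numpy as np

    def blur_kernel_size(blur: float) -> int:
        # 2*k + 1 is always odd, which is what cv2.GaussianBlur requires for explicit sizes.
        return int(blur * 20) * 2 + 1

    for blur in (0.1, 0.3, 0.5, 1.0):
        print(blur, "->", blur_kernel_size(blur))  # 5, 13, 21, 41

    img = np.zeros((64, 64, 3), dtype=np.uint8)
    blurred = cv2.GaussianBlur(img, (blur_kernel_size(0.3), blur_kernel_size(0.3)), 0)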
          save_video(stitched_frames, stitched_video_path, fps=fps)

+         # Merge audio from the input video into the stitched video using ffmpeg.
          temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
          cmd = [
              "ffmpeg",
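The middle of the ffmpeg argument list is hidden by the diff context here, so the exact flags are not shown and stay that way. Purely as an illustration of the pattern (re-muxing the stitched video while pulling the audio track from the original upload), a command of this general shape could be built as below; none of these flags or placeholder paths are confirmed to match app.py's actual arguments:

    # Hypothetical sketch only -- not the command hidden by the diff above.
    input_video = "input.mp4"                # original upload (placeholder path)
    stitched_video_path = "clip_RGBD.mp4"    # stitched video written by save_video (placeholder)
    temp_audio_path = "clip_RGBD_audio.mp4"  # temporary file that later replaces the stitched video

    cmd = [
        "ffmpeg", "-y",
        "-i", stitched_video_path,           # take the video stream from the stitched output
        "-i", input_video,                   # take the audio stream from the original upload
        "-map", "0:v", "-map", "1:a?",       # "?" tolerates inputs with no audio track
        "-c:v", "copy",
        "-c:a", "aac",
        "-shortest",
        temp_audio_path,
    ]
    # subprocess.run(cmd, ...) would then execute it, as in the surrounding code.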
          ]
          subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
          os.replace(temp_audio_path, stitched_video_path)
+
      gc.collect()
      torch.cuda.empty_cache()
+
+     # Return the preprocessed RGB video, depth visualization, and (if created) the stitched video.
      return [processed_video_path, depth_vis_path, stitched_video_path]

  def construct_demo():

  if __name__ == "__main__":
      demo = construct_demo()
+     #demo.queue() # Enable asynchronous processing.
+     #demo.launch(share=True)
      demo.queue(max_size=2).launch()
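On the final launch line: `queue(max_size=2)` caps how many requests may wait in Gradio's queue, with additional submissions turned away until a slot frees up. A minimal standalone sketch, where the placeholder UI stands in for whatever `construct_demo()` actually builds:

    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown("placeholder UI")  # construct_demo() builds the real interface

    # At most two requests wait in the queue at a time.
    demo.queue(max_size=2).launch()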