mtwohey2 committed on
Commit 45b851e · verified · 1 Parent(s): 8265d02

Update app.py


Claude recommendation for reducing memory usage.
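The memory saving comes from streaming frames instead of materializing the whole clip as a single array before inference. A minimal sketch of that pattern, simplified from the diff below (the helper name iter_rgb_frames is illustrative, not the function the commit adds):

import cv2

def iter_rgb_frames(path, max_frames=-1):
    # Yield RGB frames one at a time so peak memory no longer scales with clip length.
    cap = cv2.VideoCapture(path)
    try:
        count = 0
        while True:
            ok, frame_bgr = cap.read()
            if not ok or (max_frames > 0 and count >= max_frames):
                break
            yield cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            count += 1
    finally:
        cap.release()

The frame_generator added in the diff follows this shape, additionally handling target_fps-based frame skipping and max_res downscaling, and the inference loop then buffers only a small batch (batch_size = 8) of frames at a time.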

Files changed (1)
  1. app.py +226 -48
app.py CHANGED
@@ -5,18 +5,16 @@ import cv2
import gradio as gr
import numpy as np
import matplotlib.cm as cm
- import matplotlib # New import for the updated colormap API
+ import matplotlib
import subprocess
import sys
import spaces

from video_depth_anything.video_depth import VideoDepthAnything
- from utils.dc_utils import read_video_frames, save_video
+ from utils.dc_utils import save_video
from huggingface_hub import hf_hub_download

# Examples for the Gradio Demo.
- # Each example now contains 8 parameters:
- # [video_path, max_len, target_fps, max_res, stitch, grayscale, convert_from_color, blur]
examples = [
    ['assets/example_videos/octopus_01.mp4', -1, -1, 1280, True, True, True, 0.3],
    ['assets/example_videos/chicken_01.mp4', -1, -1, 1280, True, True, True, 0.3],
@@ -62,8 +60,86 @@ title = "# Video Depth Anything + RGBD sbs output"
description = """**Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays.
Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""

- @spaces.GPU(enable_queue=True)
+ def get_video_info(video_path, max_len=-1, target_fps=-1):
+     """Extract video information without loading all frames into memory."""
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Could not open video file: {video_path}")
+
+     # Get video properties
+     original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     original_fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     # Adjust based on max_len parameter
+     if max_len > 0 and max_len < total_frames:
+         frame_count = max_len
+     else:
+         frame_count = total_frames
+
+     # Use target_fps if specified
+     if target_fps > 0:
+         fps = target_fps
+     else:
+         fps = original_fps
+
+     cap.release()
+
+     return {
+         'width': original_width,
+         'height': original_height,
+         'fps': fps,
+         'original_fps': original_fps,
+         'frame_count': frame_count,
+         'total_frames': total_frames
+     }
+
+ def process_frame(frame, max_res):
+     """Process a single frame to the desired resolution."""
+     if max_res > 0:
+         h, w = frame.shape[:2]
+         scale = min(max_res / w, max_res / h)
+         if scale < 1:
+             new_w, new_h = int(w * scale), int(h * scale)
+             frame = cv2.resize(frame, (new_w, new_h))
+     return frame

+ def frame_generator(video_path, max_len=-1, target_fps=-1, max_res=-1, skip_frames=0):
+     """Generate frames from a video file one at a time."""
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Could not open video file: {video_path}")
+
+     original_fps = cap.get(cv2.CAP_PROP_FPS)
+     frame_count = 0
+
+     # Calculate frame skip if target_fps is specified
+     if target_fps > 0 and target_fps < original_fps:
+         skip = int(round(original_fps / target_fps)) - 1
+     else:
+         skip = skip_frames
+
+     frame_idx = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret or (max_len > 0 and frame_count >= max_len):
+             break
+
+         # Process frame if we're not skipping it
+         if frame_idx % (skip + 1) == 0:
+             # Convert from BGR to RGB
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             # Resize if necessary
+             processed_frame = process_frame(frame, max_res)
+             yield processed_frame
+             frame_count += 1
+
+         frame_idx += 1
+
+     cap.release()
+
+ @spaces.GPU(enable_queue=True)
def infer_video_depth(
    input_video: str,
    max_len: int = -1,
@@ -76,67 +152,171 @@ def infer_video_depth(
    output_dir: str = './outputs',
    input_size: int = 518,
):
-     # 1. Read input video frames for inference (downscaled to max_res).
-     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-     # 2. Perform depth inference using the model.
-     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
-
-     video_name = os.path.basename(input_video)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
-
-     # Save the preprocessed (RGB) video and the generated depth visualization.
+
+     video_name = os.path.basename(input_video)
    processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
    depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
-     save_video(frames, processed_video_path, fps=fps)
-     save_video(depths, depth_vis_path, fps=fps, is_depths=True)
-
+
+     # Get video info first
+     video_info = get_video_info(input_video, max_len, target_fps)
+     fps = video_info['fps']
+     frame_count = video_info['frame_count']
+
+     # Set up VideoWriters
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+
+     # Setup for processing batches of frames
+     batch_size = 8 # Process frames in small batches to balance efficiency and memory usage
+     processed_frames = []
+     depth_frames = []
+     stitched_frames = []
+
+     # Initialize min/max depth values for depth normalization
+     d_min, d_max = float('inf'), float('-inf')
+     depth_values = []
+
+     # First pass: Process frames for depth inference and collect min/max depth values
+     print(f"Processing video: {input_video}, {frame_count} frames at {fps} fps")
+
+     # Process frames in batches for depth inference
+     frame_gen = frame_generator(input_video, max_len, target_fps, max_res)
+     batch_count = 0
+
+     for i, frame in enumerate(frame_gen):
+         if i % 10 == 0:
+             print(f"Processing frame {i+1}/{frame_count}")
+
+         processed_frames.append(frame)
+
+         # When we have a full batch or reached the end, process it
+         if len(processed_frames) == batch_size or i == frame_count - 1:
+             # Process the batch for depth
+             with torch.no_grad():
+                 batch_depths = video_depth_anything.infer_frames_depth(
+                     processed_frames,
+                     input_size=input_size,
+                     device=DEVICE
+                 )
+
+             # Collect depth statistics and frames
+             for depth in batch_depths:
+                 d_min = min(d_min, depth.min())
+                 d_max = max(d_max, depth.max())
+                 depth_values.append(depth)
+
+             # Clear batch for next iteration
+             processed_frames = []
+             batch_count += 1
+
+             # Free up memory
+             torch.cuda.empty_cache()
+             gc.collect()
+
+     # Save the processed video
+     height, width = depth_values[0].shape[:2] if depth_values else (0, 0)
+     video_writer = cv2.VideoWriter(processed_video_path, fourcc, fps, (width, height))
+
+     # Reprocess frames to save original and depth videos
+     frame_gen = frame_generator(input_video, max_len, target_fps, max_res)
+     depth_writer = cv2.VideoWriter(depth_vis_path, fourcc, fps, (width, height))
+
+     for i, (frame, depth) in enumerate(zip(frame_gen, depth_values)):
+         # Save original frame (convert RGB to BGR for OpenCV)
+         video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+
+         # Normalize and visualize depth
+         depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
+         if grayscale:
+             if convert_from_color:
+                 cmap = matplotlib.colormaps.get_cmap("inferno")
+                 depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+                 depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
+                 depth_vis = np.stack([depth_gray] * 3, axis=-1)
+             else:
+                 depth_vis = np.stack([depth_norm] * 3, axis=-1)
+         else:
+             cmap = matplotlib.colormaps.get_cmap("inferno")
+             depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+
+         # Apply blur if requested
+         if blur > 0:
+             kernel_size = int(blur * 20) * 2 + 1 # Ensures an odd kernel size
+             depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
+
+         # Save depth visualization (convert RGB to BGR for OpenCV)
+         depth_writer.write(cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR))
+
+     video_writer.release()
+     depth_writer.release()
+
+     # Process stitched video if requested
    stitched_video_path = None
    if stitch:
-         # For stitching: read the original video in full resolution (without downscaling).
-         full_frames, _ = read_video_frames(input_video, max_len, target_fps, max_res=-1)
-         # For each frame, create a visual depth image from the inferenced depths.
-         d_min, d_max = depths.min(), depths.max()
-         stitched_frames = []
-         for i in range(min(len(full_frames), len(depths))):
-             rgb_full = full_frames[i] # Full-resolution RGB frame.
-             depth_frame = depths[i]
-             # Normalize the depth frame to the range [0, 255].
-             depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-             # Generate depth visualization:
+         # For stitching: read the original video in full resolution
+         video_info_full = get_video_info(input_video, max_len, target_fps)
+         original_frame_gen = frame_generator(input_video, max_len, target_fps, -1) # No resizing
+
+         # Create a new writer for the stitched video
+         base_name = os.path.splitext(video_name)[0]
+         short_name = base_name[:20]
+         stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
+
+         # Get dimensions of the first frame to setup the video writer
+         first_frame = next(frame_generator(input_video, 1, -1, -1))
+         H_full, W_full = first_frame.shape[:2]
+
+         # Set up the stitched video writer
+         stitched_writer = cv2.VideoWriter(
+             stitched_video_path,
+             fourcc,
+             fps,
+             (W_full * 2, H_full) # Width is doubled for side-by-side
+         )
+
+         # Reset frame generator
+         original_frame_gen = frame_generator(input_video, max_len, target_fps, -1)
+
+         # Process each frame
+         for i, (rgb_full, depth) in enumerate(zip(original_frame_gen, depth_values)):
+             if i % 10 == 0:
+                 print(f"Stitching frame {i+1}/{frame_count}")
+
+             # Normalize and visualize depth
+             depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
+
+             # Generate depth visualization
            if grayscale:
                if convert_from_color:
-                     # First, generate a color depth image using the inferno colormap,
-                     # then convert that color image to grayscale.
                    cmap = matplotlib.colormaps.get_cmap("inferno")
                    depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
                    depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
                    depth_vis = np.stack([depth_gray] * 3, axis=-1)
                else:
-                     # Directly generate a grayscale image from the normalized depth values.
                    depth_vis = np.stack([depth_norm] * 3, axis=-1)
            else:
-                 # Generate a color depth image using the inferno colormap.
                cmap = matplotlib.colormaps.get_cmap("inferno")
                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-             # Apply Gaussian blur if requested.
+
+             # Apply blur if requested
            if blur > 0:
-                 kernel_size = int(blur * 20) * 2 + 1 # Ensures an odd kernel size.
+                 kernel_size = int(blur * 20) * 2 + 1
                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-             # Resize the depth visualization to match the full-resolution RGB frame.
+
+             # Resize depth to match original frame
            H_full, W_full = rgb_full.shape[:2]
            depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
-             # Concatenate the full-resolution RGB frame (left) and the resized depth visualization (right).
+
+             # Concatenate RGB and depth
            stitched = cv2.hconcat([rgb_full, depth_vis_resized])
-             stitched_frames.append(stitched)
-         stitched_frames = np.array(stitched_frames)
-         # Use only the first 20 characters of the base name for the output filename and append '_RGBD.mp4'
-         base_name = os.path.splitext(video_name)[0]
-         short_name = base_name[:20]
-         stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
-         save_video(stitched_frames, stitched_video_path, fps=fps)
+
+             # Write to video (convert RGB to BGR for OpenCV)
+             stitched_writer.write(cv2.cvtColor(stitched, cv2.COLOR_RGB2BGR))

-         # Merge audio from the input video into the stitched video using ffmpeg.
+         stitched_writer.release()
+
+         # Merge audio from the input video into the stitched video
        temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
        cmd = [
            "ffmpeg",
@@ -152,11 +332,11 @@ def infer_video_depth(
        ]
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        os.replace(temp_audio_path, stitched_video_path)
-
+
+     # Clean up
    gc.collect()
    torch.cuda.empty_cache()
-
-     # Return the preprocessed RGB video, depth visualization, and (if created) the stitched video.
+
    return [processed_video_path, depth_vis_path, stitched_video_path]

def construct_demo():
@@ -208,6 +388,4 @@ def construct_demo():

if __name__ == "__main__":
    demo = construct_demo()
-     #demo.queue() # Enable asynchronous processing.
-     #demo.launch(share=True)
    demo.queue(max_size=2).launch()
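For reference, the depth visualization used for both the _vis.mp4 output and the stitched _RGBD.mp4 output boils down to clip-wide min-max normalization, an optional inferno colormap, and an optional Gaussian blur. A self-contained sketch assembled from the code above (the convert_from_color branch is omitted for brevity, and d_min/d_max are assumed to have been gathered over the whole clip in the first pass, as the new code does; it relies on the matplotlib.colormaps registry the diff already uses):

import cv2
import numpy as np
import matplotlib

def visualize_depth(depth, d_min, d_max, grayscale=False, blur=0.0):
    # Normalize with clip-wide extrema so brightness stays consistent across frames.
    depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
    if grayscale:
        vis = np.stack([depth_norm] * 3, axis=-1)
    else:
        cmap = matplotlib.colormaps.get_cmap("inferno")
        vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
    if blur > 0:
        kernel_size = int(blur * 20) * 2 + 1  # Gaussian kernels need an odd size
        vis = cv2.GaussianBlur(vis, (kernel_size, kernel_size), 0)
    return vis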