mtwohey2 committed
Commit 6916b74 · verified · Parent(s): 45b851e

Update app.py


Fix for infer_frames_depth

Files changed (1):
  1. app.py (+134 -151)
app.py CHANGED
@@ -105,37 +105,65 @@ def process_frame(frame, max_res):
     frame = cv2.resize(frame, (new_w, new_h))
     return frame
 
-def frame_generator(video_path, max_len=-1, target_fps=-1, max_res=-1, skip_frames=0):
-    """Generate frames from a video file one at a time."""
+def read_video_frames_chunked(video_path, max_len=-1, target_fps=-1, max_res=-1, chunk_size=32):
+    """Read video frames in chunks to avoid loading the entire video into memory."""
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise ValueError(f"Could not open video file: {video_path}")
 
     original_fps = cap.get(cv2.CAP_PROP_FPS)
-    frame_count = 0
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Determine actual number of frames to process
+    if max_len > 0 and max_len < total_frames:
+        frame_count = max_len
+    else:
+        frame_count = total_frames
 
-    # Calculate frame skip if target_fps is specified
-    if target_fps > 0 and target_fps < original_fps:
-        skip = int(round(original_fps / target_fps)) - 1
+    # Use target_fps if specified
+    if target_fps > 0:
+        fps = target_fps
+        # Calculate frame skip if downsampling fps
+        if target_fps < original_fps:
+            skip = int(round(original_fps / target_fps)) - 1
+        else:
+            skip = 0
     else:
-        skip = skip_frames
+        fps = original_fps
+        skip = 0
 
     frame_idx = 0
-    while True:
-        ret, frame = cap.read()
-        if not ret or (max_len > 0 and frame_count >= max_len):
-            break
-
-        # Process frame if we're not skipping it
-        if frame_idx % (skip + 1) == 0:
-            # Convert from BGR to RGB
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            # Resize if necessary
-            processed_frame = process_frame(frame, max_res)
-            yield processed_frame
-            frame_count += 1
-
-        frame_idx += 1
+    processed_count = 0
+
+    while processed_count < frame_count:
+        frames_chunk = []
+        # Read frames up to chunk size or remaining frames
+        chunk_limit = min(chunk_size, frame_count - processed_count)
+
+        while len(frames_chunk) < chunk_limit:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Process frame if we're not skipping it
+            if frame_idx % (skip + 1) == 0:
+                # Convert from BGR to RGB
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                # Resize if necessary
+                frame = process_frame(frame, max_res)
+                frames_chunk.append(frame)
+                processed_count += 1
+
+                if processed_count >= frame_count:
+                    break
+
+            frame_idx += 1
+
+        if frames_chunk:
+            yield frames_chunk, fps
+
+        if processed_count >= frame_count or len(frames_chunk) < chunk_limit:
+            break
 
     cap.release()
@@ -164,159 +192,110 @@ def infer_video_depth(
     fps = video_info['fps']
     frame_count = video_info['frame_count']
 
-    # Set up VideoWriters
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-
-    # Setup for processing batches of frames
-    batch_size = 8  # Process frames in small batches to balance efficiency and memory usage
-    processed_frames = []
-    depth_frames = []
-    stitched_frames = []
-
-    # Initialize min/max depth values for depth normalization
-    d_min, d_max = float('inf'), float('-inf')
-    depth_values = []
-
-    # First pass: Process frames for depth inference and collect min/max depth values
     print(f"Processing video: {input_video}, {frame_count} frames at {fps} fps")
 
-    # Process frames in batches for depth inference
-    frame_gen = frame_generator(input_video, max_len, target_fps, max_res)
-    batch_count = 0
+    # Process the video in chunks to manage memory
+    chunk_size = 32  # Adjust based on available memory
 
-    for i, frame in enumerate(frame_gen):
-        if i % 10 == 0:
-            print(f"Processing frame {i+1}/{frame_count}")
-
-        processed_frames.append(frame)
-
-        # When we have a full batch or reached the end, process it
-        if len(processed_frames) == batch_size or i == frame_count - 1:
-            # Process the batch for depth
-            with torch.no_grad():
-                batch_depths = video_depth_anything.infer_frames_depth(
-                    processed_frames,
-                    input_size=input_size,
-                    device=DEVICE
-                )
-
-            # Collect depth statistics and frames
-            for depth in batch_depths:
-                d_min = min(d_min, depth.min())
-                d_max = max(d_max, depth.max())
-                depth_values.append(depth)
-
-            # Clear batch for next iteration
-            processed_frames = []
-            batch_count += 1
-
-            # Free up memory
-            torch.cuda.empty_cache()
-            gc.collect()
+    # We'll collect depths as we go to calculate global min/max
+    all_depths = []
+    all_processed_frames = []
 
-    # Save the processed video
-    height, width = depth_values[0].shape[:2] if depth_values else (0, 0)
-    video_writer = cv2.VideoWriter(processed_video_path, fourcc, fps, (width, height))
-
-    # Reprocess frames to save original and depth videos
-    frame_gen = frame_generator(input_video, max_len, target_fps, max_res)
-    depth_writer = cv2.VideoWriter(depth_vis_path, fourcc, fps, (width, height))
-
-    for i, (frame, depth) in enumerate(zip(frame_gen, depth_values)):
-        # Save original frame (convert RGB to BGR for OpenCV)
-        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+    # First pass to collect frames and depths
+    frame_idx = 0
+    for frames_chunk, fps in read_video_frames_chunked(input_video, max_len, target_fps, max_res, chunk_size):
+        print(f"Processing chunk: frames {frame_idx+1}-{frame_idx+len(frames_chunk)}/{frame_count}")
 
-        # Normalize and visualize depth
-        depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-        if grayscale:
-            if convert_from_color:
-                cmap = matplotlib.colormaps.get_cmap("inferno")
-                depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-                depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
-                depth_vis = np.stack([depth_gray] * 3, axis=-1)
-            else:
-                depth_vis = np.stack([depth_norm] * 3, axis=-1)
-        else:
-            cmap = matplotlib.colormaps.get_cmap("inferno")
-            depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+        # Process this chunk of frames
+        depths, _ = video_depth_anything.infer_video_depth(frames_chunk, fps, input_size=input_size, device=DEVICE)
+
+        # Store results (we'll need both for the output videos)
+        all_processed_frames.extend(frames_chunk)
+        all_depths.extend(depths)
 
-        # Apply blur if requested
-        if blur > 0:
-            kernel_size = int(blur * 20) * 2 + 1  # Ensures an odd kernel size
-            depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
+        frame_idx += len(frames_chunk)
 
-        # Save depth visualization (convert RGB to BGR for OpenCV)
-        depth_writer.write(cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR))
+        # Free memory
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    # Calculate global min/max for depth normalization
+    depths_array = np.array(all_depths)
+    d_min, d_max = depths_array.min(), depths_array.max()
 
-    video_writer.release()
-    depth_writer.release()
+    # Save the preprocessed video and depth visualization
+    save_video(all_processed_frames, processed_video_path, fps=fps)
+    save_video(all_depths, depth_vis_path, fps=fps, is_depths=True)
+
+    # Free some memory before stitching
+    del all_processed_frames
+    gc.collect()
 
     # Process stitched video if requested
     stitched_video_path = None
     if stitch:
-        # For stitching: read the original video in full resolution
-        video_info_full = get_video_info(input_video, max_len, target_fps)
-        original_frame_gen = frame_generator(input_video, max_len, target_fps, -1)  # No resizing
-
-        # Create a new writer for the stitched video
+        # Use only the first 20 characters of the base name for the output filename
        base_name = os.path.splitext(video_name)[0]
        short_name = base_name[:20]
        stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
 
-        # Get dimensions of the first frame to setup the video writer
-        first_frame = next(frame_generator(input_video, 1, -1, -1))
-        H_full, W_full = first_frame.shape[:2]
+        # For stitching: read the original video in full resolution and stitch frames one by one
+        stitched_frames = []
 
-        # Set up the stitched video writer
-        stitched_writer = cv2.VideoWriter(
-            stitched_video_path,
-            fourcc,
-            fps,
-            (W_full * 2, H_full)  # Width is doubled for side-by-side
-        )
-
-        # Reset frame generator
-        original_frame_gen = frame_generator(input_video, max_len, target_fps, -1)
-
-        # Process each frame
-        for i, (rgb_full, depth) in enumerate(zip(original_frame_gen, depth_values)):
-            if i % 10 == 0:
-                print(f"Stitching frame {i+1}/{frame_count}")
+        # Process in chunks for memory efficiency
+        frame_idx = 0
+        for frames_chunk, _ in read_video_frames_chunked(input_video, max_len, target_fps, -1, chunk_size):  # No max_res for original resolution
+            print(f"Stitching chunk: frames {frame_idx+1}-{frame_idx+len(frames_chunk)}/{frame_count}")
 
-            # Normalize and visualize depth
-            depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-
-            # Generate depth visualization
-            if grayscale:
-                if convert_from_color:
-                    cmap = matplotlib.colormaps.get_cmap("inferno")
-                    depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-                    depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
-                    depth_vis = np.stack([depth_gray] * 3, axis=-1)
-                else:
-                    depth_vis = np.stack([depth_norm] * 3, axis=-1)
-            else:
-                cmap = matplotlib.colormaps.get_cmap("inferno")
-                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-
-            # Apply blur if requested
-            if blur > 0:
-                kernel_size = int(blur * 20) * 2 + 1
-                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-
-            # Resize depth to match original frame
-            H_full, W_full = rgb_full.shape[:2]
-            depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
+            # Process each frame in the chunk
+            for i, rgb_full in enumerate(frames_chunk):
+                depth_idx = frame_idx + i
+                if depth_idx >= len(all_depths):
+                    break
+
+                depth_frame = all_depths[depth_idx]
+
+                # Normalize the depth frame
+                depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
+
+                # Generate depth visualization
+                if grayscale:
+                    if convert_from_color:
+                        # Convert from color to grayscale
+                        cmap = matplotlib.colormaps.get_cmap("inferno")
+                        depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+                        depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
+                        depth_vis = np.stack([depth_gray] * 3, axis=-1)
+                    else:
+                        # Directly use grayscale
+                        depth_vis = np.stack([depth_norm] * 3, axis=-1)
+                else:
+                    # Use color visualization
+                    cmap = matplotlib.colormaps.get_cmap("inferno")
+                    depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+
+                # Apply blur if requested
+                if blur > 0:
+                    kernel_size = int(blur * 20) * 2 + 1  # Ensures odd kernel size
+                    depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
+
+                # Resize depth visualization to match original resolution
+                H_full, W_full = rgb_full.shape[:2]
+                depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
 
-            # Concatenate RGB and depth
-            stitched = cv2.hconcat([rgb_full, depth_vis_resized])
+                # Concatenate RGB and depth
+                stitched = cv2.hconcat([rgb_full, depth_vis_resized])
+                stitched_frames.append(stitched)
 
-            # Write to video (convert RGB to BGR for OpenCV)
-            stitched_writer.write(cv2.cvtColor(stitched, cv2.COLOR_RGB2BGR))
+            frame_idx += len(frames_chunk)
 
-        stitched_writer.release()
+            # Free memory after processing each chunk
+            gc.collect()
 
-        # Merge audio from the input video into the stitched video
+        # Save the stitched video
+        save_video(stitched_frames, stitched_video_path, fps=fps)
+
+        # Merge audio from the input video
         temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
         cmd = [
             "ffmpeg",
@@ -332,8 +311,12 @@ def infer_video_depth(
         ]
         subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         os.replace(temp_audio_path, stitched_video_path)
+
+        # Free memory
+        del stitched_frames
 
     # Clean up
+    del all_depths
     gc.collect()
     torch.cuda.empty_cache()
 
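Note on usage: the new read_video_frames_chunked is a generator that yields (frames_chunk, fps) tuples rather than single frames, so callers iterate over chunks. A minimal consumption sketch, assuming app.py's imports are in scope; the file name and parameter values here are hypothetical, chosen only for illustration:

    # Sketch: consume the chunked reader added in this commit
    total = 0
    for frames_chunk, fps in read_video_frames_chunked(
            "input.mp4", max_len=-1, target_fps=15, max_res=1280, chunk_size=32):
        # Each chunk is a list of RGB numpy arrays, at most chunk_size long
        total += len(frames_chunk)
        print(f"got {len(frames_chunk)} frames (running total: {total}) at {fps} fps")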
 
 
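The first pass deliberately defers normalization until every chunk has been inferred, so d_min and d_max span the whole clip rather than a single batch; a per-batch range would make depth brightness flicker between chunks. A small self-contained sketch of that global normalization, assuming the depth maps are float NumPy arrays of equal shape (the function name is ours, not app.py's):

    import numpy as np

    def normalize_depths_globally(all_depths):
        # One global range for the whole clip keeps frames comparable
        depths_array = np.array(all_depths)
        d_min, d_max = depths_array.min(), depths_array.max()
        return [((d - d_min) / (d_max - d_min) * 255).astype(np.uint8)
                for d in all_depths]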
 
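For reference, the per-frame visualization logic inside the stitching loop can be read as one standalone helper. This sketch mirrors the diff's code paths (inferno colormap, optional grayscale output, optional Gaussian blur); the function name and signature are ours, not app.py's:

    import cv2
    import matplotlib
    import numpy as np

    def visualize_depth(depth_norm, grayscale=False, convert_from_color=False, blur=0.0):
        # depth_norm: uint8 depth map already normalized to [0, 255]
        if grayscale:
            if convert_from_color:
                # Round-trip through the colormap, then collapse to gray
                cmap = matplotlib.colormaps.get_cmap("inferno")
                depth_color = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
                depth_gray = cv2.cvtColor(depth_color, cv2.COLOR_RGB2GRAY)
                depth_vis = np.stack([depth_gray] * 3, axis=-1)
            else:
                depth_vis = np.stack([depth_norm] * 3, axis=-1)
        else:
            cmap = matplotlib.colormaps.get_cmap("inferno")
            depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
        if blur > 0:
            kernel_size = int(blur * 20) * 2 + 1  # GaussianBlur requires an odd kernel
            depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
        return depth_vis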
 
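The body of cmd is elided between the last two hunks, so the exact ffmpeg arguments are not visible on this page. For readers following along, a typical way to copy the audio track from the input video into the stitched output looks like the sketch below; this is an illustrative pattern under that assumption, not necessarily the exact command app.py builds:

    import subprocess

    def merge_audio(stitched_path, source_path, output_path):
        # Video stream from the stitched file, audio stream from the original input
        cmd = [
            "ffmpeg", "-y",
            "-i", stitched_path,
            "-i", source_path,
            "-map", "0:v:0",
            "-map", "1:a:0?",  # trailing '?' tolerates sources with no audio track
            "-c", "copy",
            output_path,
        ]
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)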