Krokodilpirat committed
Commit 033fb37 · verified · 1 Parent(s): 4423b71

Update app.py

Files changed (1):
1. app.py +23 -20
app.py CHANGED
@@ -10,8 +10,7 @@ from video_depth_anything.video_depth import VideoDepthAnything
 from utils.dc_utils import read_video_frames, save_video
 from huggingface_hub import hf_hub_download
 
-# Examples for the Gradio Demo
-# The additional parameters (stitch, grayscale, blur) were added here with default values.
+# Examples for the Gradio Demo (the additional parameters: stitch, grayscale, blur are appended)
 examples = [
     ['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280, False, False, 0],
     ['assets/example_videos/Tokyo-Walk_rgb.mp4', -1, -1, 1280, False, False, 0],
@@ -25,19 +24,18 @@ examples = [
     ['assets/example_videos/sheep_cut1.mp4', -1, -1, 1280, False, False, 0],
 ]
 
+# Determine the device: use GPU if available, else CPU.
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# Model configuration
+# Model configuration for different encoder variants.
 model_configs = {
     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
 }
-
 encoder2name = {
     'vits': 'Small',
     'vitl': 'Large',
 }
-
 encoder = 'vitl'
 model_name = encoder2name[encoder]
 
@@ -63,47 +61,52 @@ def infer_video_depth(
     stitch: bool = False,
     grayscale: bool = False,
     blur: float = 0.0,
-    *,  # Keyword-only parameters follow from here:
+    *,  # The following parameters are keyword-only and cannot be overridden by UI input.
     output_dir: str = './outputs',
     input_size: int = 518,
 ):
-    # Read input video frames
+    # Read input video frames with the given maximum resolution (max_res) for inference.
     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-    # Infer depth maps using the model
+    # Perform depth inference using the model.
     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
 
     video_name = os.path.basename(input_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    # Save the processed (RGB) video and the depth visualization (using the default color mapping)
+    # Save the preprocessed (RGB) video and the depth visualization (using the default color mapping)
     processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
     depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
     save_video(frames, processed_video_path, fps=fps)
     save_video(depths, depth_vis_path, fps=fps, is_depths=True)
 
-    stitched_video_path = ""
+    stitched_video_path = None
     if stitch:
-        # Create a stitched video: left side is the RGB video, right side is the depth video
+        # For stitching: read the original video at full resolution (without downscaling)
+        full_frames, _ = read_video_frames(input_video, max_len, target_fps, max_res=-1)
+        # For each frame, create a visual depth image from the inferred depth maps (which are at the downscaled resolution)
        d_min, d_max = depths.min(), depths.max()
        stitched_frames = []
-        for i in range(min(len(frames), len(depths))):
-            rgb_frame = frames[i]
+        for i in range(min(len(full_frames), len(depths))):
+            rgb_full = full_frames[i]  # Full-resolution RGB frame
            depth_frame = depths[i]
-            # Normalize the depth frame to [0, 255]
+            # Normalize the depth frame to the range [0, 255]
            depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
-            # Use grayscale or colored mapping for the depth channel
+            # Create either a grayscale image or apply the inferno colormap, depending on the setting.
            if grayscale:
                depth_vis = np.stack([depth_norm] * 3, axis=-1)
            else:
                cmap = cm.get_cmap("inferno")
                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
-            # Apply Gaussian blur if requested (blur factor > 0)
+            # Apply Gaussian blur if requested (if blur factor > 0)
            if blur > 0:
                kernel_size = int(blur * 20) * 2 + 1  # ensures an odd kernel size
                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
-            # Concatenate side-by-side: RGB frame on the left, processed depth on the right
-            stitched = cv2.hconcat([rgb_frame, depth_vis])
+            # Resize the depth visualization to match the full-resolution RGB frame.
+            H_full, W_full = rgb_full.shape[:2]
+            depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))
+            # Concatenate the full-resolution RGB frame (left) and the resized depth visualization (right) side-by-side.
+            stitched = cv2.hconcat([rgb_full, depth_vis_resized])
            stitched_frames.append(stitched)
        stitched_frames = np.array(stitched_frames)
        stitched_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_stitched.mp4')
@@ -112,7 +115,7 @@ def infer_video_depth(
     gc.collect()
     torch.cuda.empty_cache()
 
-    # Return processed RGB video, depth visualization, and (if created) stitched video.
+    # Return the processed RGB video, the depth visualization, and (if created) the stitched video.
     return [processed_video_path, depth_vis_path, stitched_video_path]
 
 def construct_demo():
@@ -123,7 +126,7 @@ def construct_demo():
 
     with gr.Row(equal_height=True):
        with gr.Column(scale=1):
-            # Here we use the Video component without the 'source' parameter.
+            # Use the Video component for file upload (without specifying 'source')
            input_video = gr.Video(label="Input Video")
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
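A few notes on the changes, each with a small runnable sketch (the names in the sketches are illustrative, not the app's API). First, the bare * added to the signature: every parameter after it is keyword-only, so the Gradio click handler, which passes UI inputs positionally, cannot override output_dir or input_size.

# Parameters after the bare * must be passed by name.
def demo_infer(stitch: bool = False, *, output_dir: str = './outputs'):
    return stitch, output_dir

print(demo_infer(True, output_dir='./out'))  # OK -> (True, './out')
# demo_infer(True, './out')  # TypeError: too many positional arguments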
 
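The depth visualization in the stitch branch can be exercised in isolation. A sketch of the normalize-then-colormap step, with a random array standing in for one inferred depth map (note that the app computes d_min and d_max over the whole clip, not per frame):

import numpy as np
from matplotlib import cm

depth_frame = np.random.rand(240, 320).astype(np.float32)  # stand-in depth map
d_min, d_max = depth_frame.min(), depth_frame.max()
# Normalize to [0, 255] as in the diff.
depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
# Grayscale: replicate the single channel; colored: apply the inferno colormap.
gray = np.stack([depth_norm] * 3, axis=-1)
colored = (cm.get_cmap("inferno")(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
print(gray.shape, colored.shape)  # (240, 320, 3) (240, 320, 3)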
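The kernel-size formula maps the blur factor to an odd window width, which cv2.GaussianBlur requires; since the code only enters this branch when blur > 0, the degenerate size of 1 never occurs there. A sketch, assuming blur comes from a 0-1 slider:

import cv2
import numpy as np

img = np.zeros((64, 64, 3), dtype=np.uint8)  # dummy frame
for blur in (0.0, 0.25, 0.5, 1.0):
    kernel_size = int(blur * 20) * 2 + 1  # 1, 11, 21, 41: always odd, >= 1
    blurred = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
    print(blur, kernel_size, blurred.shape)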
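Finally, the core fix in the stitch branch: depth maps come back at the (possibly downscaled) inference resolution, so the visualization is resized to the full-resolution frame before concatenation. cv2.resize takes its target size as (width, height), hence the (W_full, H_full) order. A self-contained sketch with dummy frames:

import cv2
import numpy as np

rgb_full = np.zeros((480, 640, 3), dtype=np.uint8)   # dummy full-resolution frame
depth_vis = np.zeros((240, 320, 3), dtype=np.uint8)  # dummy downscaled depth visualization
H_full, W_full = rgb_full.shape[:2]
depth_vis_resized = cv2.resize(depth_vis, (W_full, H_full))  # dsize is (width, height)
stitched = cv2.hconcat([rgb_full, depth_vis_resized])
print(stitched.shape)  # (480, 1280, 3)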