Krokodilpirat committed
Commit 28da247 · verified · 1 Parent(s): 891bd26

Update app.py

Files changed (1)
  1. app.py +70 -86
app.py CHANGED
@@ -1,29 +1,16 @@
-# Copyright (2025) Bytedance Ltd. and/or its affiliates
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import spaces
-import gradio as gr
-import gc
-
-import numpy as np
 import os
+import gc
 import torch
+import cv2  # needed for image processing (e.g. hconcat, GaussianBlur)
+import gradio as gr
+import numpy as np
+import matplotlib.cm as cm
 
 from video_depth_anything.video_depth import VideoDepthAnything
 from utils.dc_utils import read_video_frames, save_video
-
 from huggingface_hub import hf_hub_download
 
+# Examples for the Gradio demo
 examples = [
     ['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280],
     ['assets/example_videos/Tokyo-Walk_rgb.mp4', -1, -1, 1280],
@@ -39,6 +26,7 @@ examples = [
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
+# Model configuration
 model_configs = {
     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
@@ -49,21 +37,23 @@ encoder2name = {
     'vitl': 'Large',
 }
 
-encoder='vitl'
+encoder = 'vitl'
 model_name = encoder2name[encoder]
 
+# Initialize the model
 video_depth_anything = VideoDepthAnything(**model_configs[encoder])
-filepath = hf_hub_download(repo_id=f"depth-anything/Video-Depth-Anything-{model_name}", filename=f"video_depth_anything_{encoder}.pth", repo_type="model")
+filepath = hf_hub_download(repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
+                           filename=f"video_depth_anything_{encoder}.pth",
+                           repo_type="model")
 video_depth_anything.load_state_dict(torch.load(filepath, map_location='cpu'))
 video_depth_anything = video_depth_anything.to(DEVICE).eval()
 
-
 title = "# Video Depth Anything"
 description = """Official demo for **Video Depth Anything**.
 Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""
 
 
-@spaces.GPU(duration=240)
+# Note: @spaces.GPU(duration=240) can be applied here instead when the spaces package is available
 def infer_video_depth(
     input_video: str,
     max_len: int = -1,
@@ -71,106 +61,100 @@ def infer_video_depth(
     max_res: int = 1280,
     output_dir: str = './outputs',
     input_size: int = 518,
+    stitch: bool = False,
+    grayscale: bool = False,
+    blur: float = 0.0,
 ):
+    # Read input video frames
     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
+    # Infer depths using the model
     depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
 
     video_name = os.path.basename(input_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
-    depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
+    # Save the processed (RGB) video and the depth visualization
+    processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
+    depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
     save_video(frames, processed_video_path, fps=fps)
     save_video(depths, depth_vis_path, fps=fps, is_depths=True)
 
+    stitched_video_path = ""
+    if stitch:
+        # Create a stitched video (side by side): left the processed RGB, right the depth
+        d_min, d_max = depths.min(), depths.max()
+        stitched_frames = []
+        for i in range(min(len(frames), len(depths))):
+            rgb_frame = frames[i]
+            depth_frame = depths[i]
+            # Normalize the depth frame to [0, 255]
+            depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
+            # Choose grayscale or colored mapping
+            if grayscale:
+                depth_vis = np.stack([depth_norm] * 3, axis=-1)
+            else:
+                cmap = cm.get_cmap("inferno")
+                # cmap returns RGBA; only the first three channels are used here
+                depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
+            # Apply Gaussian blur if requested
+            if blur > 0:
+                kernel_size = int(blur * 20) * 2 + 1  # ensures an odd kernel size
+                depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
+            # Concatenate side by side
+            stitched = cv2.hconcat([rgb_frame, depth_vis])
+            stitched_frames.append(stitched)
+        stitched_frames = np.array(stitched_frames)
+        stitched_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_stitched.mp4')
+        save_video(stitched_frames, stitched_video_path, fps=fps)
+
     gc.collect()
     torch.cuda.empty_cache()
 
-    return [processed_video_path, depth_vis_path]
+    # Return three outputs: processed RGB video, depth visualization, and (optionally) the stitched video.
+    # If stitching is not enabled, an empty string is returned for the third output.
+    return [processed_video_path, depth_vis_path, stitched_video_path]
 
 
 def construct_demo():
     with gr.Blocks(analytics_enabled=False) as demo:
         gr.Markdown(title)
         gr.Markdown(description)
-        gr.Markdown("### If you find this work useful, please help ⭐ the [\[Github Repo\]](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for your attention!")
+        gr.Markdown("### If you find this work useful, please help ⭐ the [Github Repo](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for your attention!")
 
         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
-                input_video = gr.Video(label="Input Video")
-
-            # with gr.Tab(label="Output"):
+                input_video = gr.Video(label="Input Video", source="upload", type="filepath")
             with gr.Column(scale=2):
                 with gr.Row(equal_height=True):
-                    processed_video = gr.Video(
-                        label="Preprocessed video",
-                        interactive=False,
-                        autoplay=True,
-                        loop=True,
-                        show_share_button=True,
-                        scale=5,
-                    )
-                    depth_vis_video = gr.Video(
-                        label="Generated Depth Video",
-                        interactive=False,
-                        autoplay=True,
-                        loop=True,
-                        show_share_button=True,
-                        scale=5,
-                    )
-
+                    processed_video = gr.Video(label="Preprocessed Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
+                    depth_vis_video = gr.Video(label="Generated Depth Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
+                    stitched_video = gr.Video(label="Stitched RGBD Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
-                with gr.Row(equal_height=False):
-                    with gr.Accordion("Advanced Settings", open=False):
-                        max_len = gr.Slider(
-                            label="max process length",
-                            minimum=-1,
-                            maximum=1000,
-                            value=500,
-                            step=1,
-                        )
-                        target_fps = gr.Slider(
-                            label="target FPS",
-                            minimum=-1,
-                            maximum=30,
-                            value=15,
-                            step=1,
-                        )
-                        max_res = gr.Slider(
-                            label="max side resolution",
-                            minimum=480,
-                            maximum=1920,
-                            value=1280,
-                            step=1,
-                        )
-                    generate_btn = gr.Button("Generate")
+                with gr.Accordion("Advanced Settings", open=False):
+                    max_len = gr.Slider(label="Max process length", minimum=-1, maximum=1000, value=500, step=1)
+                    target_fps = gr.Slider(label="Target FPS", minimum=-1, maximum=30, value=15, step=1)
+                    max_res = gr.Slider(label="Max side resolution", minimum=480, maximum=1920, value=1280, step=1)
+                    stitch_option = gr.Checkbox(label="Stitch RGB & Depth Videos", value=False)
+                    grayscale_option = gr.Checkbox(label="Output Depth as Grayscale", value=False)
+                    blur_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Depth Blur Factor", value=0)
+                generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass
 
         gr.Examples(
             examples=examples,
-            inputs=[
-                input_video,
-                max_len,
-                target_fps,
-                max_res
-            ],
-            outputs=[processed_video, depth_vis_video],
+            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
+            outputs=[processed_video, depth_vis_video, stitched_video],
             fn=infer_video_depth,
             cache_examples="lazy",
         )
 
         generate_btn.click(
             fn=infer_video_depth,
-            inputs=[
-                input_video,
-                max_len,
-                target_fps,
-                max_res
-            ],
-            outputs=[processed_video, depth_vis_video],
+            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, blur_slider],
+            outputs=[processed_video, depth_vis_video, stitched_video],
         )
 
     return demo
@@ -178,4 +162,4 @@ def construct_demo():
 if __name__ == "__main__":
     demo = construct_demo()
     demo.queue()
-    demo.launch(share=True)
+    demo.launch(share=True)
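
The heart of this change is the optional stitching step inside infer_video_depth. Below is a minimal standalone sketch of that logic for reference, assuming frames is a uint8 RGB array of shape (N, H, W, 3) and depths a float array of shape (N, H, W), as produced by read_video_frames and infer_video_depth; the helper name stitch_rgb_depth is hypothetical. Note that cm.get_cmap is deprecated in recent matplotlib releases, where matplotlib.colormaps["inferno"] is the drop-in replacement.

import numpy as np
import cv2
import matplotlib.cm as cm

def stitch_rgb_depth(frames, depths, grayscale=False, blur=0.0):
    # Normalize with the global min/max over the whole clip so the
    # colormap stays temporally consistent across frames.
    d_min, d_max = depths.min(), depths.max()
    stitched_frames = []
    for rgb_frame, depth_frame in zip(frames, depths):
        depth_norm = ((depth_frame - d_min) / (d_max - d_min) * 255).astype(np.uint8)
        if grayscale:
            depth_vis = np.stack([depth_norm] * 3, axis=-1)
        else:
            cmap = cm.get_cmap("inferno")  # returns RGBA; keep only RGB
            depth_vis = (cmap(depth_norm / 255.0)[..., :3] * 255).astype(np.uint8)
        if blur > 0:
            # blur in [0, 1] maps to an odd kernel size in {1, 3, ..., 41}
            kernel_size = int(blur * 20) * 2 + 1
            depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
        stitched_frames.append(cv2.hconcat([rgb_frame, depth_vis]))
    return np.array(stitched_frames)

# Quick shape check with dummy data: two 64x64 frames side by side -> width 128.
frames = np.random.randint(0, 255, (2, 64, 64, 3), dtype=np.uint8)
depths = np.random.rand(2, 64, 64).astype(np.float32)
assert stitch_rgb_depth(frames, depths, blur=0.5).shape == (2, 64, 128, 3)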
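
With the extended signature, the function can also be invoked directly in Python; the values below are illustrative and use one of the bundled example clips:

# Returns the paths of the three rendered videos;
# stitched_path is an empty string when stitch=False.
rgb_path, depth_path, stitched_path = infer_video_depth(
    'assets/example_videos/davis_rollercoaster.mp4',
    max_len=500,
    target_fps=15,
    max_res=1280,
    stitch=True,
    grayscale=False,
    blur=0.25,
)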
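
The commit also drops the ZeroGPU decorator together with import spaces. If the Space runs on (Zero)GPU hardware, the pattern from the previous revision can be restored; a minimal sketch, assuming the spaces package is available in the runtime:

import spaces  # provided on Hugging Face (Zero)GPU Spaces

@spaces.GPU(duration=240)  # request a GPU for up to 240 seconds per call
def infer_video_depth(input_video, max_len=-1, target_fps=-1, max_res=1280,
                      output_dir='./outputs', input_size=518,
                      stitch=False, grayscale=False, blur=0.0):
    ...  # body as in the diff above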