Shane922 committed
Commit bddb8a1 · 1 Parent(s): dd6503b

update video writer
app.py CHANGED
@@ -13,14 +13,14 @@
 # limitations under the License.
 import spaces
 import gradio as gr
-
+import gc
 
 import numpy as np
 import os
 import torch
 
 from video_depth_anything.video_depth import VideoDepthAnything
-from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video
+from utils.dc_utils import read_video_frames, save_video
 
 from huggingface_hub import hf_hub_download
 
@@ -73,9 +73,8 @@ def infer_video_depth(
     input_size: int = 518,
 ):
     frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-    depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
-    depth_list = np.stack(depth_list, axis=0)
-    vis = vis_sequence_depth(depth_list)
+    depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
+
     video_name = os.path.basename(input_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
@@ -83,7 +82,10 @@ def infer_video_depth(
     processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
     depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
     save_video(frames, processed_video_path, fps=fps)
-    save_video(vis, depth_vis_path, fps=fps)
+    save_video(depths, depth_vis_path, fps=fps, is_depths=True)
+
+    gc.collect()
+    torch.cuda.empty_cache()
 
     return [processed_video_path, depth_vis_path]
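Note: the dropped np.stack / vis_sequence_depth steps are not lost; stacking now happens inside VideoDepthAnything.infer_video_depth (see the video_depth.py hunk below) and colorization happens inside save_video via is_depths=True. A minimal sketch of the array contract app.py now relies on, using an illustrative stand-in array rather than real model output:

import numpy as np

# Sketch only: stand-in for the (N, H, W) float array that infer_video_depth
# is expected to return after this commit (real values come from the model).
depths = np.random.rand(8, 266, 518).astype(np.float32)

assert depths.ndim == 3                    # one depth map per frame
d_min, d_max = depths.min(), depths.max()  # save_video normalizes over the whole clip
first_frame = depths[0]                    # and indexes per-frame as frames[i]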
 
requirements.txt CHANGED
@@ -7,7 +7,8 @@ opencv-python
 matplotlib
 huggingface_hub
 pillow
-mediapy
+imageio
+imageio-ffmpeg
 decord
 xformers
 einops
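Note: mediapy relies on an ffmpeg binary already being installed on the system, while imageio-ffmpeg ships one with the wheel, so the Space no longer needs a separate ffmpeg install. A quick post-install sanity check (a sketch, not part of the commit):

# Confirms that imageio can find an ffmpeg backend for the libx264 writer
# used by utils/dc_utils.save_video.
import imageio_ffmpeg
print(imageio_ffmpeg.get_ffmpeg_exe())  # path to the bundled ffmpeg executable
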
utils/dc_utils.py CHANGED
@@ -3,82 +3,84 @@
 #
 # This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
 # Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].
-from typing import Union, List
-import tempfile
 import numpy as np
-import PIL.Image
 import matplotlib.cm as cm
-import mediapy
-import torch
-from decord import VideoReader, cpu
+import imageio
+try:
+    from decord import VideoReader, cpu
+    DECORD_AVAILABLE = True
+except:
+    import cv2
+    DECORD_AVAILABLE = False
 
+def ensure_even(value):
+    return value if value % 2 == 0 else value + 1
 
-def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"):
-    vid = VideoReader(video_path, ctx=cpu(0))
-    print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
-    original_height, original_width = vid.get_batch([0]).shape[1:3]
-    height = original_height
-    width = original_width
-    if max_res > 0 and max(height, width) > max_res:
-        scale = max_res / max(original_height, original_width)
-        height = round(original_height * scale)
-        width = round(original_width * scale)
-
-    vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
-
-    fps = vid.get_avg_fps() if target_fps == -1 else target_fps
-    stride = round(vid.get_avg_fps() / fps)
-    stride = max(stride, 1)
-    frames_idx = list(range(0, len(vid), stride))
-    if process_length != -1 and process_length < len(frames_idx):
-        frames_idx = frames_idx[:process_length]
-    print(f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}")
-    frames = vid.get_batch(frames_idx).asnumpy()
-
-    return frames, fps
-
-
-def save_video(
-    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
-    output_video_path: str = None,
-    fps: int = 10,
-    crf: int = 18,
-) -> str:
-    if output_video_path is None:
-        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
-
-    if isinstance(video_frames[0], np.ndarray):
-        video_frames = [frame.astype(np.uint8) for frame in video_frames]
-
-    elif isinstance(video_frames[0], PIL.Image.Image):
-        video_frames = [np.array(frame) for frame in video_frames]
-    mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
-    return output_video_path
-
-
-class ColorMapper:
-    # a color mapper to map depth values to a certain colormap
-    def __init__(self, colormap: str = "inferno"):
-        self.colormap = torch.tensor(cm.get_cmap(colormap).colors)
-
-    def apply(self, image: torch.Tensor, v_min=None, v_max=None):
-        # assert len(image.shape) == 2
-        if v_min is None:
-            v_min = image.min()
-        if v_max is None:
-            v_max = image.max()
-        image = (image - v_min) / (v_max - v_min)
-        image = (image * 255).long()
-        image = self.colormap[image] * 255
-        return image
-
-
-def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
-    visualizer = ColorMapper()
-    if v_min is None:
-        v_min = depths.min()
-    if v_max is None:
-        v_max = depths.max()
-    res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
-    return res
+def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1):
+    if DECORD_AVAILABLE:
+        vid = VideoReader(video_path, ctx=cpu(0))
+        original_height, original_width = vid.get_batch([0]).shape[1:3]
+        height = original_height
+        width = original_width
+        if max_res > 0 and max(height, width) > max_res:
+            scale = max_res / max(original_height, original_width)
+            height = ensure_even(round(original_height * scale))
+            width = ensure_even(round(original_width * scale))
+
+        vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
+
+        fps = vid.get_avg_fps() if target_fps == -1 else target_fps
+        stride = round(vid.get_avg_fps() / fps)
+        stride = max(stride, 1)
+        frames_idx = list(range(0, len(vid), stride))
+        if process_length != -1 and process_length < len(frames_idx):
+            frames_idx = frames_idx[:process_length]
+        frames = vid.get_batch(frames_idx).asnumpy()
+    else:
+        cap = cv2.VideoCapture(video_path)
+        original_fps = cap.get(cv2.CAP_PROP_FPS)
+        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+
+        if max_res > 0 and max(original_height, original_width) > max_res:
+            scale = max_res / max(original_height, original_width)
+            height = round(original_height * scale)
+            width = round(original_width * scale)
+
+        fps = original_fps if target_fps < 0 else target_fps
+
+        stride = max(round(original_fps / fps), 1)
+
+        frames = []
+        frame_count = 0
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret or (process_length > 0 and frame_count >= process_length):
+                break
+            if frame_count % stride == 0:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
+                if max_res > 0 and max(original_height, original_width) > max_res:
+                    frame = cv2.resize(frame, (width, height))  # Resize frame
+                frames.append(frame)
+            frame_count += 1
+        cap.release()
+        frames = np.stack(frames, axis=0)
+
+    return frames, fps
+
+
+def save_video(frames, output_video_path, fps=10, is_depths=False):
+    writer = imageio.get_writer(output_video_path, fps=fps, macro_block_size=1, codec='libx264', ffmpeg_params=['-crf', '18'])
+    if is_depths:
+        colormap = np.array(cm.get_cmap("inferno").colors)
+        d_min, d_max = frames.min(), frames.max()
+        for i in range(frames.shape[0]):
+            depth = frames[i]
+            depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
+            depth_vis = (colormap[depth_norm] * 255).astype(np.uint8)
+            writer.append_data(depth_vis)
+    else:
+        for i in range(frames.shape[0]):
+            writer.append_data(frames[i])
+
+    writer.close()
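A minimal usage sketch of the rewritten helpers, assuming the repo root is on the import path and a short test clip exists at the hypothetical path sample.mp4; the decord or cv2 branch is picked automatically by the try/except import above:

import numpy as np
from utils.dc_utils import read_video_frames, save_video

# Hypothetical input file; any short mp4 works for a smoke test.
frames, fps = read_video_frames("sample.mp4", process_length=100, target_fps=-1, max_res=1280)
print(frames.shape, frames.dtype, fps)  # (N, H, W, 3) uint8 frames

# Round-trip the RGB frames through the new imageio/libx264 writer.
save_video(frames, "sample_copy.mp4", fps=fps)

# A single-channel float stack takes the is_depths branch and gets the inferno colormap.
fake_depths = frames[..., 0].astype(np.float32) / 255.0
save_video(fake_depths, "sample_depth_vis.mp4", fps=fps, is_depths=True)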
 
 
 
 
 
 
 
video_depth_anything/video_depth.py CHANGED
@@ -152,5 +152,5 @@ class VideoDepthAnything(nn.Module):
 
         depth_list = depth_list_aligned
 
-        return depth_list[:org_video_len], target_fps
+        return np.stack(depth_list[:org_video_len], axis=0), target_fps
 
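This return-type change is what lets the new save_video treat its input as a single ndarray: it calls frames.min(), frames.max(), and frames.shape[0] instead of iterating a Python list. A small illustrative check with random arrays rather than model output:

import numpy as np

depth_list = [np.random.rand(64, 64).astype(np.float32) for _ in range(5)]
stacked = np.stack(depth_list, axis=0)      # what infer_video_depth now returns

assert stacked.shape == (5, 64, 64)         # (N, H, W): one depth map per frame
assert stacked.shape[0] == len(depth_list)  # frames.shape[0] as used by save_video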