haodongli committed on
Commit a47128f · 1 Parent(s): 4943e69

update for video depth

Files changed (6):
  1. app.py +100 -39
  2. files/videos/00.mp4 +3 -0
  3. files/videos/01.mp4 +3 -0
  4. infer.py +69 -79
  5. pipeline.py +1 -1
  6. utils/image_utils.py +5 -2
app.py CHANGED
@@ -31,18 +31,19 @@ def infer(path_input, seed):
     return [path_input, g_save_path], [path_input, d_save_path]
 
 def infer_video(path_input, seed):
-    frames_g, frames_d = lotus_video(path_input, 'depth', seed, device)
+    frames_g, frames_d, fps = lotus_video(path_input, 'depth', seed, device)
     if not os.path.exists("files/output"):
         os.makedirs("files/output")
     name_base, _ = os.path.splitext(os.path.basename(path_input))
     g_save_path = os.path.join("files/output", f"{name_base}_g.mp4")
     d_save_path = os.path.join("files/output", f"{name_base}_d.mp4")
-    imageio.mimsave(g_save_path, frames_g)
-    imageio.mimsave(d_save_path, frames_d)
+    imageio.mimsave(g_save_path, frames_g, fps=fps)
+    imageio.mimsave(d_save_path, frames_d, fps=fps)
     return [g_save_path, d_save_path]
 
 def run_demo_server():
     infer_gpu = spaces.GPU(functools.partial(infer))
+    infer_video_gpu = spaces.GPU(functools.partial(infer_video))
     gradio_theme = gr.themes.Default()
 
     with gr.Blocks(
@@ -113,49 +114,96 @@ def run_demo_server():
             """
         )
         with gr.Tabs(elem_classes=["tabs"]):
-            with gr.Row():
-                with gr.Column():
-                    image_input = gr.Image(
-                        label="Input Image",
-                        type="filepath",
-                    )
-                    seed = gr.Number(
-                        label="Seed (only for Generative mode)",
-                        minimum=0,
-                        maximum=999999999,
-                    )
-                    with gr.Row():
-                        image_submit_btn = gr.Button(
-                            value="Predict Depth!", variant="primary"
+            with gr.Tab("IMAGE"):
+                with gr.Row():
+                    with gr.Column():
+                        image_input = gr.Image(
+                            label="Input Image",
+                            type="filepath",
+                        )
+                        seed = gr.Number(
+                            label="Seed (only for Generative mode)",
+                            minimum=0,
+                            maximum=999999999,
                         )
-                        image_reset_btn = gr.Button(value="Reset")
-                with gr.Column():
-                    image_output_g = ImageSlider(
-                        label="Output (Generative)",
-                        type="filepath",
-                        interactive=False,
-                        elem_classes="slider",
-                        position=0.25,
-                    )
-                    with gr.Row():
-                        image_output_d = ImageSlider(
-                            label="Output (Discriminative)",
+                        with gr.Row():
+                            image_submit_btn = gr.Button(
+                                value="Predict Depth!", variant="primary"
+                            )
+                            image_reset_btn = gr.Button(value="Reset")
+                    with gr.Column():
+                        image_output_g = ImageSlider(
+                            label="Output (Generative)",
                             type="filepath",
                             interactive=False,
                             elem_classes="slider",
                             position=0.25,
                         )
+                        with gr.Row():
+                            image_output_d = ImageSlider(
+                                label="Output (Discriminative)",
+                                type="filepath",
+                                interactive=False,
+                                elem_classes="slider",
+                                position=0.25,
+                            )
+
+                gr.Examples(
+                    fn=infer_gpu,
+                    examples=sorted([
+                        [os.path.join("files", "images", name), 0]
+                        for name in os.listdir(os.path.join("files", "images"))
+                    ]),
+                    inputs=[image_input, seed],
+                    outputs=[image_output_g, image_output_d],
+                    cache_examples=False,
+                )
+
+            with gr.Tab("VIDEO"):
+                with gr.Row():
+                    with gr.Column():
+                        input_video = gr.Video(
+                            label="Input Video",
+                            autoplay=True,
+                            loop=True,
+                        )
+                        seed = gr.Number(
+                            label="Seed (only for Generative mode)",
+                            minimum=0,
+                            maximum=999999999,
+                        )
+                        with gr.Row():
+                            video_submit_btn = gr.Button(
+                                value="Predict Depth!", variant="primary"
+                            )
+                            video_reset_btn = gr.Button(value="Reset")
+                    with gr.Column():
+                        video_output_g = gr.Video(
+                            label="Output (Generative)",
+                            interactive=False,
+                            autoplay=True,
+                            loop=True,
+                            show_share_button=True,
+                        )
+                        with gr.Row():
+                            video_output_d = gr.Video(
+                                label="Output (Discriminative)",
+                                interactive=False,
+                                autoplay=True,
+                                loop=True,
+                                show_share_button=True,
+                            )
 
-            gr.Examples(
-                fn=infer_gpu,
-                examples=sorted([
-                    [os.path.join("files", "images", name), 0]
-                    for name in os.listdir(os.path.join("files", "images"))
-                ]),
-                inputs=[image_input, seed],
-                outputs=[image_output_g, image_output_d],
-                cache_examples=False,
-            )
+                gr.Examples(
+                    fn=infer_video_gpu,
+                    examples=sorted([
+                        [os.path.join("files", "videos", name), 0]
+                        for name in os.listdir(os.path.join("files", "videos"))
+                    ]),
+                    inputs=[input_video, seed],
+                    outputs=[video_output_g, video_output_d],
+                    cache_examples=False,
+                )
 
         ### Image
         image_submit_btn.click(
@@ -175,6 +223,19 @@ def run_demo_server():
             queue=False,
         )
 
+        ### Video
+        video_submit_btn.click(
+            fn=infer_video_gpu,
+            inputs=[input_video, seed],
+            outputs=[video_output_g, video_output_d],
+            queue=True,
+        )
+        video_reset_btn.click(
+            fn=lambda: (None, None, None),
+            inputs=[],
+            outputs=[video_output_g, video_output_d],
+        )
+
     ### Server launch
     demo.queue(
         api_open=False,
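
Note on the frame-rate handling above: lotus_video now returns the source clip's fps alongside the two frame lists, and app.py passes it to imageio.mimsave so the saved predictions play back at the input's speed. A minimal, standalone sketch of that round trip (illustrative only, not part of the commit; assumes OpenCV and imageio's ffmpeg plugin are installed):

import os
import cv2
import imageio

def copy_at_source_fps(src_path, dst_path):
    # Read every frame and the container's reported frame rate with OpenCV.
    cap = cv2.VideoCapture(src_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 24.0  # fall back if the container reports 0
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # OpenCV decodes as BGR
    cap.release()
    # Write the frames back out at the same rate, as app.py now does for predictions.
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    imageio.mimsave(dst_path, frames, fps=fps)

copy_at_source_fps("files/videos/00.mp4", "files/output/00_copy.mp4")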
files/videos/00.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddb5e80168634ef46cdd5bb45178573a34001f147c7e96eb6220c09bfc0c4649
+size 3774878
files/videos/01.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a532ba2738716dbb244e0d7172cf681879218cbbdad09980404fa08ef6b9ecc
+size 3095352
infer.py CHANGED
@@ -19,7 +19,7 @@ import cv2
 
 check_min_version('0.28.0.dev0')
 
-def infer_pipe(pipe, image_input, task_name, seed, device):
+def infer_pipe(pipe, test_image, task_name, seed, device, video_depth=False):
     if seed is None:
         generator = None
     else:
@@ -31,7 +31,8 @@ def infer_pipe(pipe, image_input, task_name, seed, device):
         autocast_ctx = torch.autocast(pipe.device.type)
     with autocast_ctx:
 
-        test_image = Image.open(image_input).convert('RGB')
+        if video_depth == False:
+            test_image = Image.open(test_image).convert('RGB')
         test_image = np.array(test_image).astype(np.float16)
         test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
         test_image = test_image / 127.5 - 1.0
@@ -55,17 +56,57 @@ def infer_pipe(pipe, image_input, task_name, seed, device):
     # Post-process the prediction
     if task_name == 'depth':
         output_npy = pred.mean(axis=-1)
-        output_color = colorize_depth_map(output_npy)
+        output_color = colorize_depth_map(output_npy, reverse_color=True)
     else:
         output_npy = pred
         output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
 
     return output_color
 
-def lotus_video(input_video, task_name, seed, device):
+def infer_pipe_video(pipe, test_image, task_name, generator, device, latents=None):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(pipe.device.type)
+    with autocast_ctx:
+        test_image = np.array(test_image).astype(np.float16)
+        test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
+        test_image = test_image / 127.5 - 1.0
+        test_image = test_image.to(device)
+
+        task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
+        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
+
+        # Run
+        output = pipe(
+            rgb_in=test_image,
+            prompt='',
+            num_inference_steps=1,
+            generator=generator,
+            latents=latents,
+            # guidance_scale=0,
+            output_type='np',
+            timesteps=[999],
+            task_emb=task_emb,
+            return_dict=False
+        )
+        pred = output[0][0]
+        last_frame_latent = output[2]
+
+    # Post-process the prediction
+    if task_name == 'depth':
+        output_npy = pred.mean(axis=-1)
+        output_color = colorize_depth_map(output_npy, reverse_color=True)
+    else:
+        output_npy = pred
+        output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
+
+    return output_color, last_frame_latent
+
+def load_pipe(task_name, device):
     if task_name == 'depth':
-        model_g = 'jingheya/lotus-depth-g-v1-0'
-        model_d = 'jingheya/lotus-depth-d-v1-0'
+        model_g = 'jingheya/lotus-depth-g-v2-0-disparity'
+        model_d = 'jingheya/lotus-depth-d-v2-0-disparity'
     else:
         model_g = 'jingheya/lotus-normal-g-v1-0'
         model_d = 'jingheya/lotus-normal-d-v1-0'
@@ -84,9 +125,17 @@ def lotus_video(input_video, task_name, seed, device):
     pipe_g.set_progress_bar_config(disable=True)
     pipe_d.set_progress_bar_config(disable=True)
     logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
+    return pipe_g, pipe_d
+
+def lotus_video(input_video, task_name, seed, device):
+    pipe_g, pipe_d = load_pipe(task_name, device)
 
     # load the video and split it into frames
     cap = cv2.VideoCapture(input_video)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
     frames = []
     while True:
         ret, frame = cap.read()
@@ -94,91 +143,32 @@ def lotus_video(input_video, task_name, seed, device):
             break
         frames.append(frame)
     cap.release()
-    logging.info(f"There are {len(frames)} frames in the video.")
 
+    # generate latents_common for lotus-g
     if seed is None:
         generator = None
     else:
         generator = torch.Generator(device=device).manual_seed(seed)
-
-    task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
-    task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
+    last_frame_latent = None
+    latent_common = torch.randn(
+        (1, 4, height // pipe_g.vae_scale_factor, width // pipe_g.vae_scale_factor), generator=generator, dtype=pipe_g.dtype, device=device
+    )
 
     output_g = []
     output_d = []
     for frame in frames:
-        if torch.backends.mps.is_available():
-            autocast_ctx = nullcontext()
-        else:
-            autocast_ctx = torch.autocast(pipe_g.device.type)
-        with autocast_ctx:
-            test_image = frame
-            test_image = np.array(test_image).astype(np.float16)
-            test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
-            test_image = test_image / 127.5 - 1.0
-            test_image = test_image.to(device)
+        latents = latent_common
+        if last_frame_latent is not None:
+            latents = 0.9 * latents + 0.1 * last_frame_latent
+        output_frame_g, last_frame_latent = infer_pipe_video(pipe_g, frame, task_name, seed, device, latents)
+        output_frame_d = infer_pipe(pipe_d, frame, task_name, seed, device, video_depth=True)
+        output_g.append(output_frame_g)
+        output_d.append(output_frame_d)
 
-            # Run
-            pred_g = pipe_g(
-                rgb_in=test_image,
-                prompt='',
-                num_inference_steps=1,
-                generator=generator,
-                # guidance_scale=0,
-                output_type='np',
-                timesteps=[999],
-                task_emb=task_emb,
-            ).images[0]
-            pred_d = pipe_d(
-                rgb_in=test_image,
-                prompt='',
-                num_inference_steps=1,
-                generator=generator,
-                # guidance_scale=0,
-                output_type='np',
-                timesteps=[999],
-                task_emb=task_emb,
-            ).images[0]
-
-            # Post-process the prediction
-            if task_name == 'depth':
-                output_npy_g = pred_g.mean(axis=-1)
-                output_color_g = colorize_depth_map(output_npy_g)
-                output_npy_d = pred_d.mean(axis=-1)
-                output_color_d = colorize_depth_map(output_npy_d)
-            else:
-                output_npy_g = pred_g
-                output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
-                output_npy_d = pred_d
-                output_color_d = Image.fromarray((output_npy_d * 255).astype(np.uint8))
-
-            output_g.append(output_color_g)
-            output_d.append(output_color_d)
-
-    return output_g, output_d
+    return output_g, output_d, fps
 
 def lotus(image_input, task_name, seed, device):
-    if task_name == 'depth':
-        model_g = 'jingheya/lotus-depth-g-v1-0'
-        model_d = 'jingheya/lotus-depth-d-v1-1'
-    else:
-        model_g = 'jingheya/lotus-normal-g-v1-0'
-        model_d = 'jingheya/lotus-normal-d-v1-0'
-
-    dtype = torch.float16
-    pipe_g = LotusGPipeline.from_pretrained(
-        model_g,
-        torch_dtype=dtype,
-    )
-    pipe_d = LotusDPipeline.from_pretrained(
-        model_d,
-        torch_dtype=dtype,
-    )
-    pipe_g.to(device)
-    pipe_d.to(device)
-    pipe_g.set_progress_bar_config(disable=True)
-    pipe_d.set_progress_bar_config(disable=True)
-    logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
+    pipe_g, pipe_d = load_pipe(task_name, device)
     output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
     output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
     return output_g, output_d
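
The heart of the reworked lotus_video above is a light temporal smoothing of the starting noise: every frame is denoised from one shared random latent, mixed 9:1 with the latent returned for the previous frame, which damps frame-to-frame flicker in the one-step generative predictions. A minimal sketch of just that loop (illustrative names only; run_generative_pipe stands in for the LotusGPipeline call that returns a prediction and its latent):

import torch

def predict_video_frames(frames, height, width, vae_scale_factor, run_generative_pipe,
                         device="cpu", seed=0):
    # One latent shared by every frame keeps the per-frame noise identical...
    generator = torch.Generator(device=device).manual_seed(seed)
    latent_common = torch.randn(
        (1, 4, height // vae_scale_factor, width // vae_scale_factor),
        generator=generator, device=device,
    )
    outputs, last_frame_latent = [], None
    for frame in frames:
        latents = latent_common
        if last_frame_latent is not None:
            # ...and blending in a little of the previous frame's latent damps flicker.
            latents = 0.9 * latents + 0.1 * last_frame_latent
        prediction, last_frame_latent = run_generative_pipe(frame, latents)
        outputs.append(prediction)
    return outputs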
pipeline.py CHANGED
@@ -1279,6 +1279,6 @@ class LotusGPipeline(DirectDiffusionPipeline):
         self.maybe_free_model_hooks()
 
         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (image, has_nsfw_concept, latents)
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
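
This one-line change is what lets infer_pipe_video read the latents back: with return_dict=False the generative pipeline now returns them as a third tuple element. A hedged usage sketch (mirrors the call in infer.py; pipe, test_image, generator, latents, and task_emb are assumed to be prepared as they are there, so this is not a standalone script):

# Sketch only: unpack the tuple returned when return_dict=False.
image_batch, has_nsfw_concept, new_latents = pipe(
    rgb_in=test_image,          # normalized RGB tensor on the target device
    prompt='',
    num_inference_steps=1,      # single-step prediction
    generator=generator,
    latents=latents,            # blended starting noise for this frame
    output_type='np',
    timesteps=[999],
    task_emb=task_emb,
    return_dict=False,          # tuple output; latents are the new third element
)
pred = image_batch[0]           # first (only) image in the batch
last_frame_latent = new_latents # carried over to the next frame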
utils/image_utils.py CHANGED
@@ -44,12 +44,15 @@ def concatenate_images(*image_lists):
     return new_image
 
 
-def colorize_depth_map(depth, mask=None):
+def colorize_depth_map(depth, mask=None, reverse_color=False):
     cm = matplotlib.colormaps["Spectral"]
     # normalize
     depth = ((depth - depth.min()) / (depth.max() - depth.min()))
     # colorize
-    img_colored_np = cm(depth, bytes=False)[:, :, 0:3] # (h,w,3)
+    if reverse_color:
+        img_colored_np = cm(1 - depth, bytes=False)[:, :, 0:3] # Invert the depth values before applying colormap
+    else:
+        img_colored_np = cm(depth, bytes=False)[:, :, 0:3] # (h,w,3)
     depth_colored = (img_colored_np * 255).astype(np.uint8)
     if mask is not None:
         masked_image = np.zeros_like(depth_colored)
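
The new reverse_color flag only flips which end of the normalized depth range gets which end of the "Spectral" colormap. A small standalone illustration of the same behaviour (an independent re-implementation for demonstration, not the repository function itself):

import numpy as np
import matplotlib

def colorize(depth, reverse_color=False):
    # Normalize to [0, 1], optionally invert, then map through "Spectral".
    cm = matplotlib.colormaps["Spectral"]
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    if reverse_color:
        depth = 1.0 - depth
    return (cm(depth, bytes=False)[:, :, :3] * 255).astype(np.uint8)

demo = np.linspace(0.0, 1.0, 64 * 64).reshape(64, 64)  # synthetic depth ramp
rgb_default = colorize(demo)                        # default colormap orientation
rgb_reversed = colorize(demo, reverse_color=True)   # inverted, as the updated infer.py requests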