selmee commited on
Commit
e512885
·
verified ·
1 Parent(s): 7ec279a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -7,7 +7,7 @@ from typing import Union
7
  from pathlib import Path
8
  import os
9
 
10
- def predict_depth(image: Image.Image, auto_rotate: bool, remove_alpha: bool, model, transform):
11
  # Convert the PIL image to a temporary file path if needed
12
  image_path = "temp_image.jpg"
13
  image.save(image_path)
@@ -30,14 +30,20 @@ def predict_depth(image: Image.Image, auto_rotate: bool, remove_alpha: bool, mod
30
 
31
  focallength = prediction["focallength_px"].cpu().numpy()
32
 
33
- # Normalize and colorize depth map
34
- cmap = plt.get_cmap("turbo_r")
35
- color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(np.uint8)
 
 
 
 
 
 
36
 
37
  # Clean up temporary image
38
  os.remove(image_path)
39
 
40
- return Image.fromarray(color_depth), focallength # Return depth map and f_px
41
 
42
  def main():
43
  # Load model and preprocessing transform
@@ -46,11 +52,12 @@ def main():
46
 
47
  # Set up Gradio interface
48
  iface = gr.Interface(
49
- fn=lambda image, auto_rotate, remove_alpha: predict_depth(image, auto_rotate, remove_alpha, model, transform),
50
  inputs=[
51
  gr.Image(type="pil", label="Upload Image"), # Use image browser for input
52
  gr.Checkbox(label="Auto Rotate", value=True), # Checkbox for auto_rotate
53
- gr.Checkbox(label="Remove Alpha", value=True) # Checkbox for remove_alpha
 
54
  ],
55
  outputs=[
56
  gr.Image(label="Depth Map"), # Use PIL image output
@@ -65,4 +72,4 @@ def main():
65
  iface.launch()
66
 
67
  if __name__ == "__main__":
68
- main()
 
7
  from pathlib import Path
8
  import os
9
 
10
+ def predict_depth(image: Image.Image, auto_rotate: bool, remove_alpha: bool, grayscale: bool, model, transform):
11
  # Convert the PIL image to a temporary file path if needed
12
  image_path = "temp_image.jpg"
13
  image.save(image_path)
 
30
 
31
  focallength = prediction["focallength_px"].cpu().numpy()
32
 
33
+ if grayscale:
34
+ # Normalize the inverse depth map to 0-255 and convert to grayscale
35
+ grayscale_depth = (inverse_depth_normalized * 255).astype(np.uint8)
36
+ depth_image = Image.fromarray(grayscale_depth, mode="L")
37
+ else:
38
+ # Normalize and colorize depth map
39
+ cmap = plt.get_cmap("turbo_r")
40
+ color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(np.uint8)
41
+ depth_image = Image.fromarray(color_depth)
42
 
43
  # Clean up temporary image
44
  os.remove(image_path)
45
 
46
+ return depth_image, focallength # Return depth map and f_px
47
 
48
  def main():
49
  # Load model and preprocessing transform
 
52
 
53
  # Set up Gradio interface
54
  iface = gr.Interface(
55
+ fn=lambda image, auto_rotate, remove_alpha, grayscale: predict_depth(image, auto_rotate, remove_alpha, grayscale, model, transform),
56
  inputs=[
57
  gr.Image(type="pil", label="Upload Image"), # Use image browser for input
58
  gr.Checkbox(label="Auto Rotate", value=True), # Checkbox for auto_rotate
59
+ gr.Checkbox(label="Remove Alpha", value=True), # Checkbox for remove_alpha
60
+ gr.Checkbox(label="Grayscale Depth", value=False) # Checkbox for grayscale
61
  ],
62
  outputs=[
63
  gr.Image(label="Depth Map"), # Use PIL image output
 
72
  iface.launch()
73
 
74
  if __name__ == "__main__":
75
+ main()