NightRaven109 commited on
Commit
1e5c329
·
verified ·
1 Parent(s): adfc685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -16
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import cv2
2
  import torch
3
  import numpy as np
@@ -11,6 +12,7 @@ model_configs = {
11
  'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
12
  }
13
 
 
14
  def initialize_model():
15
  encoder = 'vitl'
16
  max_depth = 1
@@ -33,38 +35,49 @@ def initialize_model():
33
  my_state_dict[new_key] = state_dict[key]
34
 
35
  model.load_state_dict(my_state_dict)
36
- model.eval()
37
  return model
38
 
39
- # Initialize model globally
40
  MODEL = initialize_model()
41
 
 
42
  def process_image(input_image):
43
  """
44
  Process the input image and return depth maps
45
  """
46
- # Convert from RGB to BGR (since original code uses cv2.imread which loads in BGR)
47
  if input_image is None:
48
  return None, None
49
-
50
- input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB_BGR)
51
 
52
- # Get depth map
53
- depth = MODEL.infer_image(input_image)
 
54
 
55
- # Normalize depth for visualization (0-255)
56
- depth_normalized = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Apply colormap for better visualization
59
- depth_colormap = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)
60
- depth_colormap = cv2.cvtColor(depth_colormap, cv2.COLOR_BGR2RGB) # Convert back to RGB for Gradio
61
 
62
  return depth_normalized, depth_colormap
63
 
64
- # Create Gradio interface
65
  def gradio_interface(input_img):
66
- depth_raw, depth_colored = process_image(input_img)
67
- return [input_img, depth_raw, depth_colored]
 
 
 
 
68
 
69
  # Define interface
70
  iface = gr.Interface(
@@ -77,7 +90,7 @@ iface = gr.Interface(
77
  ],
78
  title="Depth Estimation",
79
  description="Upload an image to generate its depth map.",
80
- examples=["image.jpg"] # Add example images here
81
  )
82
 
83
  # Launch the app
 
1
+ import spaces # Import spaces first
2
  import cv2
3
  import torch
4
  import numpy as np
 
12
  'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
13
  }
14
 
15
+ # Initialize model globally
16
  def initialize_model():
17
  encoder = 'vitl'
18
  max_depth = 1
 
35
  my_state_dict[new_key] = state_dict[key]
36
 
37
  model.load_state_dict(my_state_dict)
 
38
  return model
39
 
 
40
  MODEL = initialize_model()
41
 
42
+ @spaces.GPU
43
  def process_image(input_image):
44
  """
45
  Process the input image and return depth maps
46
  """
 
47
  if input_image is None:
48
  return None, None
 
 
49
 
50
+ # Move model to GPU for processing
51
+ MODEL.to('cuda')
52
+ MODEL.eval()
53
 
54
+ # Convert from RGB to BGR
55
+ input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
56
+
57
+ with torch.no_grad():
58
+ # Get depth map
59
+ depth = MODEL.infer_image(input_image)
60
+
61
+ # Normalize depth for visualization (0-255)
62
+ depth_normalized = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
63
+
64
+ # Apply colormap for better visualization
65
+ depth_colormap = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_INFERNO)
66
+ depth_colormap = cv2.cvtColor(depth_colormap, cv2.COLOR_BGR2RGB)
67
 
68
+ # Move model back to CPU after processing
69
+ MODEL.to('cpu')
 
70
 
71
  return depth_normalized, depth_colormap
72
 
73
+ @spaces.GPU
74
  def gradio_interface(input_img):
75
+ try:
76
+ depth_raw, depth_colored = process_image(input_img)
77
+ return [input_img, depth_raw, depth_colored]
78
+ except Exception as e:
79
+ print(f"Error processing image: {str(e)}")
80
+ return [input_img, None, None]
81
 
82
  # Define interface
83
  iface = gr.Interface(
 
90
  ],
91
  title="Depth Estimation",
92
  description="Upload an image to generate its depth map.",
93
+ examples=["image.jpg"]
94
  )
95
 
96
  # Launch the app