Dref360 committed
Commit fbb0b68 · 1 Parent(s): afb0729

First commit

Files changed (3):
  1. .gitignore +162 -0
  2. README.md +6 -6
  3. app.py +170 -111
.gitignore ADDED
@@ -0,0 +1,162 @@
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ .idea/
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
README.md CHANGED
@@ -1,14 +1,14 @@
  ---
- title: Vit Pose Playground
- emoji: 🏆
- colorFrom: green
- colorTo: yellow
+ title: Human Interaction Demo
+ emoji: 📊
+ colorFrom: gray
+ colorTo: blue
  sdk: gradio
- sdk_version: 5.5.0
+ sdk_version: 5.6.0
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: Small Space to test ViTPose
+ short_description: Uses pose estimation to determine what you are touching.
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
Old version (removed lines are marked with "-"):

@@ -1,50 +1,52 @@
  import gradio as gr
- import torch
  import numpy as np
- import cv2
- from PIL import Image
  import supervision as sv
  from transformers import (
      RTDetrForObjectDetection,
      RTDetrImageProcessor,
-     VitPoseConfig,
      VitPoseForPoseEstimation,
      VitPoseImageProcessor,
  )

-
- KEYPOINT_LABEL_MAP = {
-     0: "Nose",
-     1: "L_Eye",
-     2: "R_Eye",
-     3: "L_Ear",
-     4: "R_Ear",
-     5: "L_Shoulder",
-     6: "R_Shoulder",
-     7: "L_Elbow",
-     8: "R_Elbow",
-     9: "L_Wrist",
-     10: "R_Wrist",
-     11: "L_Hip",
-     12: "R_Hip",
-     13: "L_Knee",
-     14: "R_Knee",
-     15: "L_Ankle",
-     16: "R_Ankle",
- }
-
-
- class KeypointDetector:
      def __init__(self):
          self.person_detector = None
          self.person_processor = None
          self.pose_model = None
          self.pose_processor = None
          self.load_models()

      def load_models(self):
          """Load all required models"""
-         # Object detection model
          self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
          self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")

@@ -52,21 +54,35 @@ class KeypointDetector:
          self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
          self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")

-     @staticmethod
-     def pascal_voc_to_coco(bboxes: np.ndarray) -> np.ndarray:
-         """Convert Pascal VOC format to COCO format"""
-         bboxes = bboxes.copy()  # Create a copy to avoid modifying the input
-         bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
-         bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
-         return bboxes
-
-     @staticmethod
-     def coco_to_xyxy(bboxes: np.ndarray) -> np.ndarray:
-         """Convert COCO format (x,y,w,h) to xyxy format (x1,y1,x2,y2)"""
-         bboxes = bboxes.copy()
-         bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
-         bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
-         return bboxes

      def detect_persons(self, image: Image.Image):
          """Detect persons in the image"""
@@ -80,70 +96,105 @@ class KeypointDetector:
              threshold=0.3
          )

-         # Get boxes and scores for human class (index 0 in COCO dataset)
          boxes = results[0]["boxes"][results[0]["labels"] == 0]
          scores = results[0]["scores"][results[0]["labels"] == 0]
          return boxes.cpu().numpy(), scores.cpu().numpy()

      def detect_keypoints(self, image: Image.Image):
          """Detect keypoints in the image"""
-         # Detect persons first
          boxes, scores = self.detect_persons(image)
-         boxes_coco = [self.pascal_voc_to_coco(boxes)]

-         # Detect pose keypoints
-         pixel_values = self.pose_processor(image, boxes=boxes_coco, return_tensors="pt").pixel_values
          with torch.no_grad():
              outputs = self.pose_model(pixel_values)

-         pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=boxes_coco)[0]
          return pose_results, boxes, scores

-     def visualize_detections(self, image: Image.Image, pose_results, boxes, scores):
-         """Visualize both bounding boxes and keypoints on the image"""
-         # Convert image to numpy array if needed
-         image_array = np.array(image)

-         # Setup detections for bounding boxes
-         detections = sv.Detections(
-             xyxy=self.coco_to_xyxy(boxes),
-             confidence=scores,
-             class_id=np.array([0]*len(scores))
-         )

-         # Create box annotator
-         box_annotator = sv.BoxAnnotator(
-             color=sv.ColorPalette.DEFAULT,
-             thickness=2
-         )

-         # Create edge annotator for keypoints
-         edge_annotator = sv.EdgeAnnotator(
-             color=sv.Color.GREEN,
-             thickness=3
-         )

-         # Convert keypoints to supervision format
          key_points = sv.KeyPoints(
              xy=torch.cat([pose_result['keypoints'].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
          )

-         # Annotate image with boxes first
-         annotated_frame = box_annotator.annotate(
-             scene=image_array.copy(),
-             detections=detections
-         )

-         # Then add keypoints
-         annotated_frame = edge_annotator.annotate(
-             scene=annotated_frame,
-             key_points=key_points
-         )

-         return Image.fromarray(annotated_frame)

      def process_image(self, input_image):
-         """Process image and return visualization"""
          if input_image is None:
              return None, ""

@@ -153,69 +204,77 @@ class KeypointDetector:
          else:
              image = input_image

-         # Detect keypoints and boxes
-         pose_results, boxes, scores = self.detect_keypoints(image)

          # Visualize results
-         result_image = self.visualize_detections(image, pose_results, boxes, scores)

-         # Create detection information text
          info_text = []

-         # Box information
-         for i, (box, score) in enumerate(zip(boxes, scores)):
-             info_text.append(f"\nPerson {i + 1} (confidence: {score:.2f})")
-             info_text.append(f"Bounding Box: x1={box[0]:.1f}, y1={box[1]:.1f}, x2={box[2]:.1f}, y2={box[3]:.1f}")

-             # Add keypoint information for this person
-             pose_result = pose_results[i]
-             for j, keypoint in enumerate(pose_result["keypoints"]):
-                 x, y, confidence = keypoint
-                 info_text.append(f"Keypoint {KEYPOINT_LABEL_MAP[j]}: x={x:.1f}, y={y:.1f}, confidence={confidence:.2f}")

-         return result_image, "\n".join(info_text)


  def create_gradio_interface():
      """Create Gradio interface"""
-     detector = KeypointDetector()

      with gr.Blocks() as interface:
-         gr.Markdown("# Human Detection and Keypoint Estimation using VitPose")
-         gr.Markdown("Upload an image to detect people and their keypoints. The model will:")
-         gr.Markdown("1. Detect people in the image (shown as bounding boxes)")
-         gr.Markdown("2. Identify keypoints for each detected person (shown as connected green lines)")
-         gr.Markdown("Huge shoutout to @NielsRogge and @SangbumChoi for this work!")

          with gr.Row():
              with gr.Column():
                  input_image = gr.Image(label="Input Image")
-                 process_button = gr.Button("Detect People & Keypoints")

              with gr.Column():
                  output_image = gr.Image(label="Detection Results")
-                 detection_info = gr.Textbox(
-                     label="Detection Information",
                      lines=10,
-                     placeholder="Detection details will appear here..."
                  )

          process_button.click(
              fn=detector.process_image,
              inputs=input_image,
-             outputs=[output_image, detection_info]
          )

          gr.Examples(
              examples=[
-                 "http://images.cocodataset.org/val2017/000000000139.jpg"
              ],
              inputs=input_image
          )

      return interface

-
  if __name__ == "__main__":
-     interface = create_gradio_interface()
-     interface.launch()

New version (added lines are marked with "+"):

+ import cv2
  import gradio as gr
  import numpy as np
  import supervision as sv
+ import torch
+ from PIL import Image
  from transformers import (
      RTDetrForObjectDetection,
      RTDetrImageProcessor,
      VitPoseForPoseEstimation,
      VitPoseImageProcessor,
+     pipeline,
  )

+ KEYPOINT_LABEL_MAP = {
+     0: "Nose",
+     1: "L_Eye",
+     2: "R_Eye",
+     3: "L_Ear",
+     4: "R_Ear",
+     5: "L_Shoulder",
+     6: "R_Shoulder",
+     7: "L_Elbow",
+     8: "R_Elbow",
+     9: "L_Wrist",
+     10: "R_Wrist",
+     11: "L_Hip",
+     12: "R_Hip",
+     13: "L_Knee",
+     14: "R_Knee",
+     15: "L_Ankle",
+     16: "R_Ankle",
+ }
+
+
+ class InteractionDetector:
      def __init__(self):
          self.person_detector = None
          self.person_processor = None
          self.pose_model = None
          self.pose_processor = None
+         self.depth_model = None
+         self.segmentation_model = None
+         self.interaction_threshold = 2
          self.load_models()

      def load_models(self):
          """Load all required models"""
+         # Person detection model
          self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
          self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")

          self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
          self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")

+         # Depth estimation model
+         self.depth_model = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
+
+         # Semantic segmentation model
+         self.segmentation_model = pipeline("image-segmentation", model="facebook/maskformer-swin-base-ade")
+         self.segmentation_id2label = self.segmentation_model.model.config.id2label
+         self.segmentation_label2id = {v: k for k, v in self.segmentation_model.model.config.id2label.items()}
+
+     def get_nearest_pixel_class(self, joint, depth_map, segmentation_map):
+         """
+         Find the nearest pixel of a specific class to a given joint coordinate
+         Args:
+             joint: (x, y) coordinates of the joint
+             depth_map: Depth map
+             segmentation_map: Semantic segmentation results
+         Returns:
+             tuple: class_name of nearest pixel, distance to that pixel
+         """
+         PERSON_ID = 12
+         grid_x, grid_y = np.meshgrid(np.arange(depth_map.shape[0]), np.arange(depth_map.shape[1]))
+         dist_x = np.abs(grid_x.T - joint[1])
+         dist_y = np.abs(grid_y.T - joint[0])
+         dist_coord = dist_x + dist_y
+
+         depth_dist = np.abs(depth_map - depth_map[joint[1], joint[0]])
+         depth_dist[(segmentation_map == PERSON_ID) | (dist_coord > 50)] = 255
+         min_dist = np.unravel_index(np.argmin(depth_dist), depth_dist.shape)
+         return segmentation_map[min_dist], depth_dist[min_dist]

      def detect_persons(self, image: Image.Image):
          """Detect persons in the image"""

              threshold=0.3
          )

          boxes = results[0]["boxes"][results[0]["labels"] == 0]
          scores = results[0]["scores"][results[0]["labels"] == 0]
          return boxes.cpu().numpy(), scores.cpu().numpy()

      def detect_keypoints(self, image: Image.Image):
          """Detect keypoints in the image"""
          boxes, scores = self.detect_persons(image)

+         pixel_values = self.pose_processor(image, boxes=[boxes], return_tensors="pt").pixel_values
          with torch.no_grad():
              outputs = self.pose_model(pixel_values)

+         pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=[boxes])[0]
          return pose_results, boxes, scores

+     def estimate_depth(self, image: Image.Image):
+         """Estimate depth for the image"""
+         with torch.no_grad():
+             depth_map = np.array(self.depth_model(image)['depth'])
+         return depth_map

+     def segment_image(self, image: Image.Image):
+         """Perform semantic segmentation on the image"""
+         with torch.no_grad():
+             segmentation_map = self.segmentation_model(image)
+             result = np.zeros(np.array(image).shape[:2], dtype=np.uint8)
+             print("Found", [l['label'] for l in segmentation_map])
+             for cls_item in sorted(segmentation_map, key=lambda l: np.sum(l['mask']), reverse=True):
+                 result[np.array(cls_item['mask']) > 0] = self.segmentation_label2id[cls_item['label']]
+
+         return result

+     def detect_wall_interaction(self, image: Image.Image):
+         """Detect if hands are touching walls"""
+         # Get all necessary information
+         pose_results, boxes, scores = self.detect_keypoints(image)
+         depth_map = self.estimate_depth(image)
+         segmentation_map = self.segment_image(image)
+
+         interactions = []

+         for person_idx, pose_result in enumerate(pose_results):
+             # Get hand keypoints
+             right_hand = pose_result["keypoints"][10].numpy().astype(int)
+             left_hand = pose_result["keypoints"][9].numpy().astype(int)
+
+             # Find nearest anything pixels
+             right_cls, r_distance = self.get_nearest_pixel_class(right_hand[:2], depth_map, segmentation_map)
+             left_cls, l_distance = self.get_nearest_pixel_class(left_hand[:2], depth_map, segmentation_map)
+
+             # Check for interactions
+             right_touching = r_distance < self.interaction_threshold
+             left_touching = l_distance < self.interaction_threshold
+
+             interactions.append({
+                 "person_id": person_idx,
+                 "right_hand_touching_object": self.segmentation_id2label[right_cls],
+                 "left_hand_touching_object": self.segmentation_id2label[left_cls],
+                 "right_hand_touching": right_touching,
+                 "left_hand_touching": left_touching,
+                 "right_hand_distance": r_distance,
+                 "left_hand_distance": l_distance
+             })
+
+         return interactions, pose_results, segmentation_map, depth_map
+
+     def visualize_results(self, image: Image.Image, interactions, pose_results):
+         """Visualize detection results"""
+         # Create base visualization from original image
+         vis_image = np.array(image).copy()
+
+         # Add pose keypoints
+         edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
          key_points = sv.KeyPoints(
              xy=torch.cat([pose_result['keypoints'].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
          )
+         vis_image = edge_annotator.annotate(scene=vis_image, key_points=key_points)

+         # Add interaction indicators
+         for interaction in interactions:
+             person_id = interaction["person_id"]
+             pose_result = pose_results[person_id]
+
+             # Draw indicators for touching hands
+             if interaction["right_hand_touching"]:
+                 cv2.circle(vis_image,
+                            tuple(map(int, pose_result["keypoints"][10][:2])),
+                            10, (0, 0, 255), -1)

+             if interaction["left_hand_touching"]:
+                 cv2.circle(vis_image,
+                            tuple(map(int, pose_result["keypoints"][9][:2])),
+                            10, (0, 0, 255), -1)
+
+         return Image.fromarray(vis_image)

      def process_image(self, input_image):
+         """Process image and return visualization with interaction detection"""
          if input_image is None:
              return None, ""

          else:
              image = input_image

+         image = image.resize((1280, 720))
+
+         # Detect interactions
+         interactions, pose_results, segmentation_map, depth_map = self.detect_wall_interaction(image)

          # Visualize results
+         result_image = self.visualize_results(image, interactions, pose_results)

+         # Create interaction information text
          info_text = []
+         for interaction in interactions:
+             info_text.append(f"\nPerson {interaction['person_id'] + 1}:")
+             if interaction["right_hand_touching"]:
+                 info_text.append(f"Right hand is touching {interaction['right_hand_touching_object']}")
+             if interaction["left_hand_touching"]:
+                 info_text.append(f"Left hand is touching {interaction['left_hand_touching_object']}")
+             info_text.append(f"Right hand distance to wall: {interaction['right_hand_distance']:.2f}")
+             info_text.append(f"Left hand distance to wall: {interaction['left_hand_distance']:.2f}")

+         # Add color to segmentation
+         mask = np.zeros((*segmentation_map.shape, 3), dtype=np.uint8)
+         colors = np.random.randint(0, 255, size=(100, 3))

+         for cl_id in np.unique(segmentation_map):
+             mask_array = np.array(segmentation_map == cl_id)
+             color = colors[cl_id % len(colors)]
+             mask[mask_array] = color

+         return result_image, mask, depth_map, "\n".join(info_text)


  def create_gradio_interface():
      """Create Gradio interface"""
+     detector = InteractionDetector()

      with gr.Blocks() as interface:
+         gr.Markdown("# Object Interaction Detection")
+         gr.Markdown("Upload an image to detect when people are touching objects.")

          with gr.Row():
              with gr.Column():
                  input_image = gr.Image(label="Input Image")
+                 process_button = gr.Button("Detect Interactions")

              with gr.Column():
                  output_image = gr.Image(label="Detection Results")
+                 interaction_info = gr.Textbox(
+                     label="Interaction Information",
                      lines=10,
+                     placeholder="Interaction details will appear here..."
                  )
+                 segmentation_im = gr.Image(label="Segmentation Results")
+                 depth_im = gr.Image(label="Depth Results")

          process_button.click(
              fn=detector.process_image,
              inputs=input_image,
+             outputs=[output_image, segmentation_im, depth_im, interaction_info]
          )

          gr.Examples(
              examples=[
+                 "https://img.freepik.com/premium-photo/happy-black-man-opening-door-gesturing-okay-approving-new-home_116547-23954.jpg?w=1800",
+                 "https://static3.bigstockphoto.com/6/7/2/large1500/276757975.jpg"
              ],
              inputs=input_image
          )

      return interface

+ interface = create_gradio_interface()
  if __name__ == "__main__":
+     interface.launch(debug=True)
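For quick local testing, here is a minimal sketch of how the new `InteractionDetector` could be exercised outside the Gradio UI. It assumes `app.py` from this commit is importable from the working directory and that the model weights can be downloaded; it reuses one of the commit's own example image URLs, and the output file name is illustrative.

```python
# Minimal sketch: run this commit's InteractionDetector on a single image without Gradio.
# Assumes app.py from this commit is on the import path and the models can be downloaded.
import requests
from PIL import Image

from app import InteractionDetector  # class introduced in this commit

# One of the example URLs wired into the Space's gr.Examples block
URL = "https://img.freepik.com/premium-photo/happy-black-man-opening-door-gesturing-okay-approving-new-home_116547-23954.jpg?w=1800"

image = Image.open(requests.get(URL, stream=True).raw).convert("RGB")

detector = InteractionDetector()

# process_image returns the annotated image, a colored segmentation mask,
# the depth map, and a text summary of what each hand appears to be touching.
result_image, seg_mask, depth_map, info = detector.process_image(image)

result_image.save("annotated.png")  # illustrative output path
print(info)
```

This mirrors what the `process_button.click` handler does in the Space, minus the UI plumbing.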