Jacobmadwed committed on
Commit
79ae586
·
verified ·
1 Parent(s): 3fd28a6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -7
handler.py CHANGED
@@ -154,15 +154,26 @@ class EndpointHandler:
154
  image = image[np.newaxis, :, :, :].astype(np.float32) / 127.5 - 1.0 # Normalize to [-1, 1]
155
  return image
156
 
157
- def get_face_embedding(self, image):
158
  # Preprocess the image
159
  image = self.preprocess(image)
160
 
161
- # Run the ONNX model to get the face embedding
162
  input_name = self.ort_session.get_inputs()[0].name
163
- embedding = self.ort_session.run(None, {input_name: image})[0]
164
-
165
- return embedding
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  def __call__(self, data):
168
 
@@ -251,17 +262,30 @@ class EndpointHandler:
251
  height, width, _ = face_image_cv2.shape
252
 
253
  # Extract face features using the ONNX model
254
- face_emb = self.get_face_embedding(face_image_cv2)
 
 
 
255
 
 
 
 
256
  face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
257
  img_controlnet = face_image
 
258
  if pose_image:
259
  pose_image = resize_img(pose_image, max_side=1024)
260
  img_controlnet = pose_image
261
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
262
 
263
  # Extract face features from pose image using the ONNX model
264
- face_emb = self.get_face_embedding(pose_image_cv2)
 
 
 
 
 
 
265
  face_kps = draw_kps(pose_image, face_info["kps"])
266
 
267
  width, height = face_kps.size
 
154
  image = image[np.newaxis, :, :, :].astype(np.float32) / 127.5 - 1.0 # Normalize to [-1, 1]
155
  return image
156
 
157
+ def get_face_info(self, image):
158
  # Preprocess the image
159
  image = self.preprocess(image)
160
 
161
+ # Run the ONNX model to get the face detection results
162
  input_name = self.ort_session.get_inputs()[0].name
163
+ outputs = self.ort_session.run(None, {input_name: image})
164
+
165
+ # Process the output to extract face information
166
+ bboxes = outputs[0][0] # Adjust based on model output structure
167
+ face_info_list = []
168
+ for bbox in bboxes:
169
+ score = bbox[2]
170
+ if score > 0.5: # Confidence threshold
171
+ x1, y1, x2, y2 = bbox[3:7] * [320, 240, 320, 240] # Scale coordinates
172
+ face_info_list.append({
173
+ "bbox": [x1, y1, x2, y2],
174
+ "embedding": self.get_face_embedding(image[:, :, int(y1):int(y2), int(x1):int(x2)])
175
+ })
176
+ return face_info_list
177
 
178
  def __call__(self, data):
179
 
 
262
  height, width, _ = face_image_cv2.shape
263
 
264
  # Extract face features using the ONNX model
265
+ face_info_list = self.get_face_info(face_image_cv2)
266
+
267
+ if len(face_info_list) == 0:
268
+ return {"error": "No faces detected."}
269
 
270
+ # Use the largest face detected
271
+ face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
272
+ face_emb = face_info["embedding"]
273
  face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
274
  img_controlnet = face_image
275
+
276
  if pose_image:
277
  pose_image = resize_img(pose_image, max_side=1024)
278
  img_controlnet = pose_image
279
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
280
 
281
  # Extract face features from pose image using the ONNX model
282
+ face_info_list = self.get_face_info(pose_image_cv2)
283
+
284
+ if len(face_info_list) == 0:
285
+ return {"error": "No faces detected in pose image."}
286
+
287
+ face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
288
+ face_emb = face_info["embedding"]
289
  face_kps = draw_kps(pose_image, face_info["kps"])
290
 
291
  width, height = face_kps.size