Update handler.py
Browse files- handler.py +20 -18
handler.py
CHANGED
@@ -3,7 +3,7 @@ import torch
|
|
3 |
import numpy as np
|
4 |
import PIL
|
5 |
from PIL import Image
|
6 |
-
from typing import Tuple, List
|
7 |
from pydantic import BaseModel
|
8 |
import diffusers
|
9 |
from diffusers.utils import load_image
|
@@ -47,6 +47,8 @@ class GenerateImageRequest(BaseModel):
|
|
47 |
seed: int
|
48 |
enable_LCM: bool
|
49 |
enhance_face_region: bool
|
|
|
|
|
50 |
|
51 |
class EndpointHandler:
|
52 |
def __init__(self, model_dir):
|
@@ -195,38 +197,38 @@ class EndpointHandler:
|
|
195 |
# apply the style template
|
196 |
inputs, negative_prompt = apply_style(style_name, inputs, negative_prompt)
|
197 |
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
200 |
|
201 |
-
face_image = load_image(face_image_path)
|
202 |
face_image = resize_img(face_image, max_side=1024)
|
203 |
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
204 |
height, width, _ = face_image_cv2.shape
|
205 |
|
206 |
-
#
|
207 |
-
|
208 |
-
|
209 |
-
face_info = sorted(
|
210 |
-
face_info,
|
211 |
-
key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
|
212 |
-
)[
|
213 |
-
-1
|
214 |
-
] # only use the maximum face
|
215 |
-
face_emb = face_info["embedding"]
|
216 |
-
face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
|
217 |
-
img_controlnet = face_image
|
218 |
-
if pose_image_path is not None:
|
219 |
pose_image = load_image(pose_image_path)
|
220 |
pose_image = resize_img(pose_image, max_side=1024)
|
221 |
img_controlnet = pose_image
|
222 |
pose_image_cv2 = convert_from_image_to_cv2(pose_image)
|
223 |
|
224 |
face_info = self.app.get(pose_image_cv2)
|
225 |
-
|
226 |
face_info = face_info[-1]
|
227 |
face_kps = draw_kps(pose_image, face_info["kps"])
|
228 |
|
229 |
width, height = face_kps.size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
control_mask = np.zeros([height, width, 3])
|
232 |
x1, y1, x2, y2 = face_info["bbox"]
|
|
|
3 |
import numpy as np
|
4 |
import PIL
|
5 |
from PIL import Image
|
6 |
+
from typing import Tuple, List, Optional
|
7 |
from pydantic import BaseModel
|
8 |
import diffusers
|
9 |
from diffusers.utils import load_image
|
|
|
47 |
seed: int
|
48 |
enable_LCM: bool
|
49 |
enhance_face_region: bool
|
50 |
+
face_image_path: Optional[str] = None
|
51 |
+
pose_image_path: Optional[str] = None
|
52 |
|
53 |
class EndpointHandler:
|
54 |
def __init__(self, model_dir):
|
|
|
197 |
# apply the style template
|
198 |
inputs, negative_prompt = apply_style(style_name, inputs, negative_prompt)
|
199 |
|
200 |
+
# Load face image
|
201 |
+
face_image_path = data.get("face_image_path", "raw/input.jpg")
|
202 |
+
if os.path.exists(face_image_path):
|
203 |
+
face_image = load_image(face_image_path)
|
204 |
+
else:
|
205 |
+
raise FileNotFoundError(f"Face image not found at path: {face_image_path}")
|
206 |
|
|
|
207 |
face_image = resize_img(face_image, max_side=1024)
|
208 |
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
209 |
height, width, _ = face_image_cv2.shape
|
210 |
|
211 |
+
# Load pose image if provided
|
212 |
+
pose_image_path = data.get("pose_image_path")
|
213 |
+
if pose_image_path and os.path.exists(pose_image_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
pose_image = load_image(pose_image_path)
|
215 |
pose_image = resize_img(pose_image, max_side=1024)
|
216 |
img_controlnet = pose_image
|
217 |
pose_image_cv2 = convert_from_image_to_cv2(pose_image)
|
218 |
|
219 |
face_info = self.app.get(pose_image_cv2)
|
|
|
220 |
face_info = face_info[-1]
|
221 |
face_kps = draw_kps(pose_image, face_info["kps"])
|
222 |
|
223 |
width, height = face_kps.size
|
224 |
+
else:
|
225 |
+
img_controlnet = face_image
|
226 |
+
|
227 |
+
# Extract face features
|
228 |
+
face_info = self.app.get(face_image_cv2)
|
229 |
+
face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1])[-1] # only use the maximum face
|
230 |
+
face_emb = face_info["embedding"]
|
231 |
+
face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
|
232 |
|
233 |
control_mask = np.zeros([height, width, 3])
|
234 |
x1, y1, x2, y2 = face_info["bbox"]
|