Update handler.py

handler.py CHANGED (+101 -182)
Removed lines, grouped by hunk (unchanged context is omitted here; the updated code follows below; some removed lines appear truncated in this view):

@@ -1,9 +1,8 @@
-import PIL
-from typing import Tuple, List, Optional

@@ -18,11 +17,10 @@ import os
-import json
-#

@@ -52,7 +50,7 @@ class GenerateImageRequest(BaseModel):
-    def __init__(self, model_dir):

@@ -60,47 +58,34 @@ class EndpointHandler:
-        if not os.path.exists(onnx_model_path):
-        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
-        self.controlnet_identitynet = ControlNetModel.from_pretrained(
-            controlnet_path, torch_dtype=dtype
-        )
-
-        # controlnet-pose
-        controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
-        controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
-
-        controlnet_pose = ControlNetModel.from_pretrained(
-            controlnet_pose_model, torch_dtype=dtype
-        ).to(device)
-        controlnet_canny = ControlNetModel.from_pretrained(
-            controlnet_canny_model, torch_dtype=dtype
-        ).to(device)
-            return Image.fromarray(edges, "L")
-            "pose": controlnet_pose,
-            "canny": controlnet_canny
-            "pose": openpose,
-            "canny": get_canny_image
-        pretrained_model_name_or_path = "

@@ -114,7 +99,7 @@
-        #

@@ -128,226 +113,160 @@
-    def
-        Args:
-            image (Image.Image): PIL image to check.
-            bool: True if the image is NSFW, False otherwise.
-        """
-        logits_per_image = outputs.logits_per_image  #
-        probs = logits_per_image.softmax(dim=1)  #
-        return nsfw_prob > 0.
-        # Preprocess the image for ONNX model
-        image = cv2.resize(image, (320, 240))  # Adjust based on model input size
-        image =
-        image = self.preprocess(image)
-        # Run the ONNX model to get the face detection results
-        # Process the output to extract face information
-        bboxes = outputs[0][0]  # Adjust based on model output structure
-        for
-            "embedding": self.get_face_embedding(image[:, :, int(y1):int(y2), int(x1):int(x2)])
-            })
-    def
-        # Extract features for the face image region
-        # Implement the logic to extract face embeddings
-        # For now, returning a placeholder value
-        return np.random.rand(512)  # Replace with actual embedding extraction logic
-
-    def __call__(self, data):
-        def convert_from_cv2_to_image(img: np.ndarray) -> Image:
-            return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
-        def convert_from_image_to_cv2(img: Image) -> np.ndarray:
-            return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
-        def resize_img(
-            input_image,
-            max_side=1280,
-            min_side=1024,
-            size=None,
-            pad_to_max_side=False,
-            mode=PIL.Image.BILINEAR,
-            base_pixel_number=64,
-        ):
-            w, h = input_image.size
-            if size is not None:
-                w_resize_new, h_resize_new = size
-            else:
-                ratio = min_side / min(h, w)
-                w, h = round(ratio * w), round(ratio * h)
-                ratio = max_side / max(h, w)
-                input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
-                w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
-                h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
-                input_image = input_image.resize([w_resize_new, h_resize_new], mode)
-
-            if pad_to_max_side:
-                res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
-                offset_x = (max_side - w_resize_new) // 2
-                offset_y = (max_side - h_resize_new) // 2
-                res[
-                    offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
-                ] = np.array(input_image)
-                input_image = Image.fromarray(res)
-            return input_image
-
-        def apply_style(
-            style_name: str, positive: str, negative: str = ""
-        ) -> Tuple[str, str]:
-            p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-            return p.replace("{prompt}", positive), n + " " + negative
-
-        style_name = request.style
-        identitynet_strength_ratio = request.identitynet_strength_ratio
-        adapter_strength_ratio = request.adapter_strength_ratio
-        pose_strength = request.pose_strength
-        canny_strength = request.canny_strength
-        num_steps = request.num_steps
-        guidance_scale = request.guidance_scale
-        controlnet_selection = request.controlnet_selection
-        seed = request.seed
-        enhance_face_region = request.enhance_face_region
-        enable_LCM = request.enable_LCM
-
-        if enable_LCM:
-            guidance_scale = min(max(guidance_scale, 0), 1)
-        #
-        inputs, negative_prompt = apply_style(
-        # Decode base64 image
-        face_image_base64 = data.get("face_image_base64")
-        face_image_data = base64.b64decode(face_image_base64)
-        face_image = Image.open(io.BytesIO(face_image_data))
-
-        if pose_image_base64
-            pose_image_data = base64.b64decode(pose_image_base64)
-            pose_image = Image.open(io.BytesIO(pose_image_data))
-
-        face_image = resize_img(face_image, max_side=1024)
-        face_image_cv2 = convert_from_image_to_cv2(face_image)
-        # Extract face features
-        # Use the largest face detected
-        face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["bbox"])
-            pose_image = resize_img(pose_image, max_side=1024)
-            pose_image_cv2 = convert_from_image_to_cv2(pose_image)
-            # Extract face features from pose image using the ONNX model
-            face_emb = face_info["embedding"]
-        control_mask = np.zeros([height, width, 3], dtype=np.uint8)
-        x1, y1, x2, y2 = face_info["bbox"]
-        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
-        controlnet_scales = {
-            "pose": pose_strength,
-            "canny": canny_strength
-        }
-            [self.controlnet_identitynet]
-            + [self.controlnet_map[s] for s in controlnet_selection]
-        control_scales = [float(identitynet_strength_ratio)] + [
-        ]
-        control_images = [face_kps] + [
-            self.controlnet_map_fn[s](img_controlnet).resize((width, height))
-            for s in controlnet_selection
-        ]
-        generator = torch.Generator(device=device).manual_seed(seed)
-        print(f"[Debug] Prompt: {inputs}, \n[Debug] Neg Prompt: {negative_prompt}")
-        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
-            image_embeds=face_emb,
-            num_inference_steps=num_steps,
-            guidance_scale=guidance_scale,
-            enhance_face_region=enhance_face_region
-        # Check for NSFW content
-        # Convert the
The updated regions of handler.py (added or changed lines are marked +, unchanged context lines are unmarked, and ... marks unchanged code omitted between hunks):

 import cv2
 import torch
 import numpy as np
 from PIL import Image
+from typing import Tuple, List, Optional, Dict, Any
 from pydantic import BaseModel
 import diffusers
 from diffusers.utils import load_image
 ...
 from huggingface_hub import hf_hub_download
 import base64
 import io
 from transformers import CLIPProcessor, CLIPModel
 import onnxruntime as ort

+# Global variables
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
 STYLE_NAMES = list(styles.keys())
 ...
     pose_image_base64: Optional[str] = None

 class EndpointHandler:
+    def __init__(self, model_dir=""):
         # Ensure the necessary files are downloaded
         controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
         controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
 ...

         # Load the ONNX model
         onnx_model_path = os.path.join(model_dir, "models", "version-RFB-320.onnx")
+        if not os.path.exists(onnx_model_path):
             print(f"Model path {onnx_model_path} does not exist. Please ensure the model is available.")
         self.ort_session = ort.InferenceSession(onnx_model_path)

+        self.openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

         # Path to InstantID models
         controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")

         # Load pipeline face ControlNetModel
+        self.controlnet_identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype)

+        # Load custom ControlNet models
+        self.controlnet_pose = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=dtype).to(device)
+        self.controlnet_canny = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=dtype).to(device)

+        # ControlNet map
         self.controlnet_map = {
+            "pose": self.controlnet_pose,
+            "canny": self.controlnet_canny
         }

         self.controlnet_map_fn = {
+            "pose": self.openpose,
+            "canny": self.get_canny_image
         }

+        pretrained_model_name_or_path = "stablediffusionapi/protovision-xl-high-fidel"

         self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
             pretrained_model_name_or_path,
 ...
             self.pipe.scheduler.config
         )

+        # Load and disable LCM
         self.pipe.load_lora_weights(lcm_lora_path)
         self.pipe.fuse_lora()
         self.pipe.disable_lora()
 ...
         self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

+    def get_canny_image(self, image, t1=100, t2=200):
+        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+        edges = cv2.Canny(image, t1, t2)
+        return Image.fromarray(edges, "L")

+    def is_nsfw(self, image: Image.Image) -> bool:
         inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
         inputs = {k: v.to(device) for k, v in inputs.items()}
         outputs = self.clip_model(**inputs)
+        logits_per_image = outputs.logits_per_image  # image-text similarity score
+        probs = logits_per_image.softmax(dim=1)  # probabilities
         nsfw_prob = probs[0, 0].item()  # probability of "NSFW" label
+        return nsfw_prob > 0.9  # threshold for NSFW detection
def preprocess(self, image):
|
|
|
|
|
131 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
132 |
+
image = cv2.resize(image, (320, 240))
|
133 |
+
image_mean = np.array([127, 127, 127])
|
134 |
+
image = (image - image_mean) / 128
|
135 |
+
image = np.transpose(image, [2, 0, 1])
|
136 |
+
image = np.expand_dims(image, axis=0)
|
137 |
+
image = image.astype(np.float32)
|
138 |
return image
|
139 |
|
140 |
def get_face_info(self, image):
|
141 |
+
preprocessed_image = self.preprocess(image)
|
|
|
|
|
|
|
142 |
input_name = self.ort_session.get_inputs()[0].name
|
143 |
+
confidences, boxes = self.ort_session.run(None, {input_name: preprocessed_image})
|
144 |
|
|
|
|
|
145 |
face_info_list = []
|
146 |
+
for i in range(boxes.shape[1]):
|
147 |
+
box = boxes[0, i, :]
|
148 |
+
conf = confidences[0, i, 1]
|
149 |
+
if conf > 0.7:
|
150 |
+
x1, y1, x2, y2 = box[0] * 320, box[1] * 240, box[2] * 320, box[3] * 240
|
151 |
+
face_info_list.append({"bbox": [x1, y1, x2, y2]})
|
|
|
|
|
152 |
return face_info_list
|
153 |
|
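get_face_info assumes an RFB-320 face detector (from the Ultra-Light face detector family): a 320x240 RGB input normalized to roughly [-1, 1], and two outputs, per-box class confidences and boxes in normalized [x1, y1, x2, y2] coordinates. A minimal sketch for probing that model outside the handler; the model path, test image, and output shapes are assumptions:

import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("models/version-RFB-320.onnx")  # assumed path, as in the handler
frame = cv2.imread("face.jpg")                                  # assumed BGR test image

# Same preprocessing as EndpointHandler.preprocess above.
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (320, 240))
img = (img - np.array([127, 127, 127])) / 128
img = np.transpose(img, [2, 0, 1])[np.newaxis].astype(np.float32)

input_name = session.get_inputs()[0].name
confidences, boxes = session.run(None, {input_name: img})
print(confidences.shape, boxes.shape)  # the handler assumes (1, N, 2) and (1, N, 4)

# Keep confident detections, scaled to the 320x240 detector frame as get_face_info does.
for i in range(boxes.shape[1]):
    if confidences[0, i, 1] > 0.7:
        x1, y1, x2, y2 = boxes[0, i, :] * np.array([320, 240, 320, 240])
        print("face bbox:", x1, y1, x2, y2)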
+    def __call__(self, data: Any) -> Dict[str, Any]:
         request = GenerateImageRequest(**data)
+
+        if request.enable_LCM:
             self.pipe.enable_lora()
             self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
+            guidance_scale = min(max(request.guidance_scale, 0), 1)
         else:
             self.pipe.disable_lora()
             self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)

+        # Apply style
+        inputs, negative_prompt = self.apply_style(request.style, request.inputs, request.negative_prompt)

+        # Decode base64 images
+        face_image = self.decode_base64_image(request.face_image_base64)
+        pose_image = self.decode_base64_image(request.pose_image_base64) if request.pose_image_base64 else None

+        face_image = self.resize_img(face_image, max_side=1024)
+        face_image_cv2 = self.convert_from_image_to_cv2(face_image)
         height, width, _ = face_image_cv2.shape

+        # Extract face features
         face_info_list = self.get_face_info(face_image_cv2)
         if len(face_info_list) == 0:
             return {"error": "No faces detected."}

         face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
+        face_kps = draw_kps(self.convert_from_cv2_to_image(face_image_cv2), face_info["bbox"])
         img_controlnet = face_image

         if pose_image:
+            pose_image = self.resize_img(pose_image, max_side=1024)
             img_controlnet = pose_image
+            pose_image_cv2 = self.convert_from_image_to_cv2(pose_image)

             face_info_list = self.get_face_info(pose_image_cv2)
             if len(face_info_list) == 0:
                 return {"error": "No faces detected in pose image."}

             face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
             face_kps = draw_kps(pose_image, face_info["bbox"])
             width, height = face_kps.size

+        control_mask = np.zeros([height, width, 3], dtype=np.uint8)
+        x1, y1, x2, y2 = map(int, face_info["bbox"])
         control_mask[y1:y2, x1:x2] = 255
         control_mask = Image.fromarray(control_mask)

+        controlnet_scales = {"pose": request.pose_strength, "canny": request.canny_strength}
         self.pipe.controlnet = MultiControlNetModel(
+            [self.controlnet_identitynet] + [self.controlnet_map[s] for s in request.controlnet_selection]
         )
+        control_scales = [float(request.identitynet_strength_ratio)] + [controlnet_scales[s] for s in request.controlnet_selection]
+        control_images = [face_kps] + [self.controlnet_map_fn[s](img_controlnet).resize((width, height)) for s in request.controlnet_selection]

+        generator = torch.Generator(device=device).manual_seed(request.seed)

         outputs = self.pipe(
             prompt=inputs,
             negative_prompt=negative_prompt,
             image=control_images,
             control_mask=control_mask,
             controlnet_conditioning_scale=control_scales,
+            num_inference_steps=request.num_steps,
+            guidance_scale=request.guidance_scale,
             height=height,
             width=width,
             generator=generator,
+            enhance_face_region=request.enhance_face_region,
         )
+
         images = outputs.images

         if self.is_nsfw(images[0]):
             return {"error": "Generated image contains NSFW content and was discarded."}

+        # Convert the image to base64
         buffered = io.BytesIO()
         images[0].save(buffered, format="JPEG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

         return {"generated_image_base64": img_str}
+
def decode_base64_image(self, image_string):
|
239 |
+
base64_image = base64.b64decode(image_string)
|
240 |
+
buffer = io.BytesIO(base64_image)
|
241 |
+
return Image.open(buffer)
|
242 |
+
|
243 |
+
def convert_from_cv2_to_image(self, img: np.ndarray) -> Image:
|
244 |
+
return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
245 |
+
|
246 |
+
def convert_from_image_to_cv2(self, img: Image) -> np.ndarray:
|
247 |
+
return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
248 |
+
|
249 |
+
def resize_img(self, input_image, max_side=1280, min_side=1024, size=None, pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
|
250 |
+
w, h = input_image.size
|
251 |
+
if size is not None:
|
252 |
+
w_resize_new, h_resize_new = size
|
253 |
+
else:
|
254 |
+
ratio = min_side / min(h, w)
|
255 |
+
w, h = round(ratio * w), round(ratio * h)
|
256 |
+
ratio = max_side / max(h, w)
|
257 |
+
input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
|
258 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
259 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
260 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
261 |
+
|
262 |
+
if pad_to_max_side:
|
263 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
264 |
+
offset_x = (max_side - w_resize_new) // 2
|
265 |
+
offset_y = (max_side - h_resize_new) // 2
|
266 |
+
res[offset_y: offset_y + h_resize_new, offset_x: offset_x + w_resize_new] = np.array(input_image)
|
267 |
+
input_image = Image.fromarray(res)
|
268 |
+
return input_image
|
269 |
+
|
270 |
+
def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
|
271 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
272 |
+
return p.replace("{prompt}", positive), n + " " + negative
|
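apply_style only substitutes the positive prompt into a style template and appends the style's negative prompt to the caller's negative prompt. A small self-contained sketch with a hypothetical styles mapping (the real styles and DEFAULT_STYLE_NAME are imported from the style-template module and are not shown in this diff):

# Hypothetical stand-in for the imported style templates.
styles = {
    "(No style)": ("{prompt}", ""),
    "Cinematic": ("cinematic still, {prompt}, shallow depth of field", "cartoon, painting"),
}
DEFAULT_STYLE_NAME = "(No style)"

def apply_style(style_name, positive, negative=""):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive), n + " " + negative

print(apply_style("Cinematic", "a portrait of an astronaut", "lowres"))
# -> ('cinematic still, a portrait of an astronaut, shallow depth of field', 'cartoon, painting lowres')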