Update handler.py
handler.py +173 -130
CHANGED

Old version (removed lines are prefixed with "-"; several removed lines are truncated in this view):
@@ -1,13 +1,15 @@
import cv2
import torch
import numpy as np
from PIL import Image
-from typing import Tuple, List, Optional
from pydantic import BaseModel
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from style_template import styles
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
from controlnet_aux import OpenposeDetector
@@ -17,10 +19,10 @@ import os
from huggingface_hub import hf_hub_download
import base64
import io
from transformers import CLIPProcessor, CLIPModel
-import onnxruntime as ort

-#
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
STYLE_NAMES = list(styles.keys())
@@ -50,42 +52,58 @@ class GenerateImageRequest(BaseModel):
    pose_image_base64: Optional[str] = None

class EndpointHandler:
-    def __init__(self, model_dir
        # Ensure the necessary files are downloaded
        controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
        controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
        face_adapter = hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=os.path.join(model_dir, "checkpoints"))

-
-
-
-
-

-        self.

        # Path to InstantID models
        controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")

        # Load pipeline face ControlNetModel
-        self.controlnet_identitynet = ControlNetModel.from_pretrained(

-
-
-

-        # ControlNet map
        self.controlnet_map = {
-            "pose":
-            "canny":
        }

        self.controlnet_map_fn = {
-            "pose":
-            "canny":
        }

-        pretrained_model_name_or_path = "

        self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
            pretrained_model_name_or_path,
@@ -99,7 +117,7 @@ class EndpointHandler:
            self.pipe.scheduler.config
        )

-        #
        self.pipe.load_lora_weights(lcm_lora_path)
        self.pipe.fuse_lora()
        self.pipe.disable_lora()
@@ -113,161 +131,186 @@ class EndpointHandler:
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

-    def get_canny_image(self, image, t1=100, t2=200):
-        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-        edges = cv2.Canny(image, t1, t2)
-        return Image.fromarray(edges, "L")
-
    def is_nsfw(self, image: Image.Image) -> bool:
        inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.clip_model(**inputs)
-        logits_per_image = outputs.logits_per_image  # image-text similarity score
-        probs = logits_per_image.softmax(dim=1)  # probabilities
        nsfw_prob = probs[0, 0].item()  # probability of "NSFW" label
-        return nsfw_prob > 0.
-
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

-
            self.pipe.enable_lora()
            self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
-            guidance_scale = min(max(
        else:
            self.pipe.disable_lora()
            self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)

-        #
-        inputs, negative_prompt =

-
-
-

-        face_image =
-        face_image_cv2 =
        height, width, _ = face_image_cv2.shape

        # Extract face features
-
-
-
-
-
-
        img_controlnet = face_image
-
        if pose_image:
-            pose_image =
            img_controlnet = pose_image
-            pose_image_cv2 =
-
-
-

-            face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
-            face_kps = draw_kps(pose_image, face_info["bbox"])
        width, height = face_kps.size

-        control_mask = np.zeros([height, width, 3]
-        x1, y1, x2, y2 =
        control_mask[y1:y2, x1:x2] = 255
-        control_mask = Image.fromarray(control_mask)

-        controlnet_scales = {
        self.pipe.controlnet = MultiControlNetModel(
-            [self.controlnet_identitynet]
        )
-        control_scales = [float(
-

-

        outputs = self.pipe(
            prompt=inputs,
            negative_prompt=negative_prompt,
            image=control_images,
            control_mask=control_mask,
            controlnet_conditioning_scale=control_scales,
-            num_inference_steps=
-            guidance_scale=
            height=height,
            width=width,
            generator=generator,
-            enhance_face_region=
        )
-
        images = outputs.images

        if self.is_nsfw(images[0]):
            return {"error": "Generated image contains NSFW content and was discarded."}

-        # Convert the image to base64
        buffered = io.BytesIO()
        images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"generated_image_base64": img_str}
-
-    def decode_base64_image(self, image_string):
-        base64_image = base64.b64decode(image_string)
-        buffer = io.BytesIO(base64_image)
-        return Image.open(buffer)
-
-    def convert_from_cv2_to_image(self, img: np.ndarray) -> Image:
-        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
-    def convert_from_image_to_cv2(self, img: Image) -> np.ndarray:
-        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
-    def resize_img(self, input_image, max_side=1280, min_side=1024, size=None, pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
-        w, h = input_image.size
-        if size is not None:
-            w_resize_new, h_resize_new = size
-        else:
-            ratio = min_side / min(h, w)
-            w, h = round(ratio * w), round(ratio * h)
-            ratio = max_side / max(h, w)
-            input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
-            w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
-            h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
-            input_image = input_image.resize([w_resize_new, h_resize_new], mode)
-
-        if pad_to_max_side:
-            res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
-            offset_x = (max_side - w_resize_new) // 2
-            offset_y = (max_side - h_resize_new) // 2
-            res[offset_y: offset_y + h_resize_new, offset_x: offset_x + w_resize_new] = np.array(input_image)
-            input_image = Image.fromarray(res)
-        return input_image
-
-    def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
-        p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-        return p.replace("{prompt}", positive), n + " " + negative
New version (added lines are prefixed with "+"):

@@ -1,13 +1,15 @@
import cv2
import torch
import numpy as np
+import PIL
from PIL import Image
+from typing import Tuple, List, Optional
from pydantic import BaseModel
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from insightface.app import FaceAnalysis
from style_template import styles
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
from controlnet_aux import OpenposeDetector

@@ -17,10 +19,10 @@ import os
from huggingface_hub import hf_hub_download
import base64
import io
+import json
from transformers import CLIPProcessor, CLIPModel

+# global variable
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
STYLE_NAMES = list(styles.keys())

@@ -50,42 +52,58 @@ class GenerateImageRequest(BaseModel):
    pose_image_base64: Optional[str] = None

class EndpointHandler:
+    def __init__(self, model_dir):
        # Ensure the necessary files are downloaded
        controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
        controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
        face_adapter = hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=os.path.join(model_dir, "checkpoints"))

+        dir_path = os.path.join(model_dir, "models", "face_detection_yunet_2023mar_int8.onnx")
+        if not os.path.exists(dir_path):
+            print(f"Model path {dir_path} does not exist. Attempting to download.")
+            self.app = FaceAnalysis(name='antelopev2', root=model_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        else:
+            print(f"Model path {dir_path} exists. Skipping download.")
+            self.app = FaceAnalysis(name='antelopev2', root=model_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

+        self.app.prepare(ctx_id=0, det_size=(640, 640))
+        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

        # Path to InstantID models
        controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")

        # Load pipeline face ControlNetModel
+        self.controlnet_identitynet = ControlNetModel.from_pretrained(
+            controlnet_path, torch_dtype=dtype
+        )
+
+        # controlnet-pose
+        controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
+        controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
+
+        controlnet_pose = ControlNetModel.from_pretrained(
+            controlnet_pose_model, torch_dtype=dtype
+        ).to(device)
+        controlnet_canny = ControlNetModel.from_pretrained(
+            controlnet_canny_model, torch_dtype=dtype
+        ).to(device)

+        def get_canny_image(image, t1=100, t2=200):
+            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+            edges = cv2.Canny(image, t1, t2)
+            return Image.fromarray(edges, "L")

        self.controlnet_map = {
+            "pose": controlnet_pose,
+            "canny": controlnet_canny
        }

        self.controlnet_map_fn = {
+            "pose": openpose,
+            "canny": get_canny_image
        }

+        pretrained_model_name_or_path = "wangqixun/YamerMIX_v8"

        self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
            pretrained_model_name_or_path,

@@ -99,7 +117,7 @@ class EndpointHandler:
            self.pipe.scheduler.config
        )

+        # load and disable LCM
        self.pipe.load_lora_weights(lcm_lora_path)
        self.pipe.fuse_lora()
        self.pipe.disable_lora()

@@ -113,161 +131,186 @@ class EndpointHandler:
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

    def is_nsfw(self, image: Image.Image) -> bool:
+        """
+        Check if an image contains NSFW content using CLIP model.
+
+        Args:
+            image (Image.Image): PIL image to check.
+
+        Returns:
+            bool: True if the image is NSFW, False otherwise.
+        """
        inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.clip_model(**inputs)
+        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        probs = logits_per_image.softmax(dim=1)  # we take the softmax to get the probabilities
        nsfw_prob = probs[0, 0].item()  # probability of "NSFW" label
+        return nsfw_prob > 0.8  # Adjusted threshold for NSFW detection
+
+    def __call__(self, data):
+
+        def convert_from_cv2_to_image(img: np.ndarray) -> Image:
+            return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+        def convert_from_image_to_cv2(img: Image) -> np.ndarray:
+            return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+
+        def resize_img(
+            input_image,
+            max_side=1280,
+            min_side=1024,
+            size=None,
+            pad_to_max_side=False,
+            mode=PIL.Image.BILINEAR,
+            base_pixel_number=64,
+        ):
+            w, h = input_image.size
+            if size is not None:
+                w_resize_new, h_resize_new = size
+            else:
+                ratio = min_side / min(h, w)
+                w, h = round(ratio * w), round(ratio * h)
+                ratio = max_side / max(h, w)
+                input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
+                w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+                h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+                input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+            if pad_to_max_side:
+                res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+                offset_x = (max_side - w_resize_new) // 2
+                offset_y = (max_side - h_resize_new) // 2
+                res[
+                    offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
+                ] = np.array(input_image)
+                input_image = Image.fromarray(res)
+            return input_image
+
+        def apply_style(
+            style_name: str, positive: str, negative: str = ""
+        ) -> Tuple[str, str]:
+            p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+            return p.replace("{prompt}", positive), n + " " + negative

+        request = GenerateImageRequest(**data)
+        inputs = request.inputs
+        negative_prompt = request.negative_prompt
+        style_name = request.style
+        identitynet_strength_ratio = request.identitynet_strength_ratio
+        adapter_strength_ratio = request.adapter_strength_ratio
+        pose_strength = request.pose_strength
+        canny_strength = request.canny_strength
+        num_steps = request.num_steps
+        guidance_scale = request.guidance_scale
+        controlnet_selection = request.controlnet_selection
+        seed = request.seed
+        enhance_face_region = request.enhance_face_region
+        enable_LCM = request.enable_LCM
+
+        if enable_LCM:
            self.pipe.enable_lora()
            self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
+            guidance_scale = min(max(guidance_scale, 0), 1)
        else:
            self.pipe.disable_lora()
            self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)

+        # apply the style template
+        inputs, negative_prompt = apply_style(style_name, inputs, negative_prompt)
+
+        # Decode base64 image
+        face_image_base64 = data.get("face_image_base64")
+        face_image_data = base64.b64decode(face_image_base64)
+        face_image = Image.open(io.BytesIO(face_image_data))

+        pose_image_base64 = data.get("pose_image_base64")
+        pose_image = None
+        if pose_image_base64:
+            pose_image_data = base64.b64decode(pose_image_base64)
+            pose_image = Image.open(io.BytesIO(pose_image_data))

+        face_image = resize_img(face_image, max_side=1024)
+        face_image_cv2 = convert_from_image_to_cv2(face_image)
        height, width, _ = face_image_cv2.shape

        # Extract face features
+        face_info = self.app.get(face_image_cv2)
+
+        face_info = sorted(
+            face_info,
+            key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
+        )[
+            -1
+        ]  # only use the maximum face
+        face_emb = face_info["embedding"]
+        face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
        img_controlnet = face_image
        if pose_image:
+            pose_image = resize_img(pose_image, max_side=1024)
            img_controlnet = pose_image
+            pose_image_cv2 = convert_from_image_to_cv2(pose_image)
+
+            face_info = self.app.get(pose_image_cv2)
+
+            face_info = face_info[-1]
+            face_kps = draw_kps(pose_image, face_info["kps"])

        width, height = face_kps.size

+        control_mask = np.zeros([height, width, 3])
+        x1, y1, x2, y2 = face_info["bbox"]
+        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        control_mask[y1:y2, x1:x2] = 255
+        control_mask = Image.fromarray(control_mask.astype(np.uint8))

+        controlnet_scales = {
+            "pose": pose_strength,
+            "canny": canny_strength
+        }
        self.pipe.controlnet = MultiControlNetModel(
+            [self.controlnet_identitynet]
+            + [self.controlnet_map[s] for s in controlnet_selection]
        )
+        control_scales = [float(identitynet_strength_ratio)] + [
+            controlnet_scales[s] for s in controlnet_selection
+        ]
+        control_images = [face_kps] + [
+            self.controlnet_map_fn[s](img_controlnet).resize((width, height))
+            for s in controlnet_selection
+        ]
+
+        generator = torch.Generator(device=device).manual_seed(seed)

+        print("Start inference...")
+        print(f"[Debug] Prompt: {inputs}, \n[Debug] Neg Prompt: {negative_prompt}")

+        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
        outputs = self.pipe(
            prompt=inputs,
            negative_prompt=negative_prompt,
+            image_embeds=face_emb,
            image=control_images,
            control_mask=control_mask,
            controlnet_conditioning_scale=control_scales,
+            num_inference_steps=num_steps,
+            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=generator,
+            enhance_face_region=enhance_face_region
        )
+
        images = outputs.images

+        # Check for NSFW content
        if self.is_nsfw(images[0]):
            return {"error": "Generated image contains NSFW content and was discarded."}

+        # Convert the output image to base64
        buffered = io.BytesIO()
        images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"generated_image_base64": img_str}
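
For context, a minimal sketch of how the updated handler could be exercised once deployed. The payload keys mirror GenerateImageRequest as shown in the diff; the concrete values, the style name, and which fields carry defaults are assumptions, since they are not visible here.

import base64
from handler import EndpointHandler

# Hypothetical local model directory and input photo
handler = EndpointHandler(model_dir=".")
with open("face.jpg", "rb") as f:
    face_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": "analog film photo of a person in a garden",  # illustrative prompt
    "negative_prompt": "lowres, blurry, deformed",
    "style": "Watercolor",                      # assumed to be a key in style_template.styles
    "identitynet_strength_ratio": 0.8,
    "adapter_strength_ratio": 0.8,
    "pose_strength": 0.4,
    "canny_strength": 0.3,
    "num_steps": 30,
    "guidance_scale": 5.0,
    "controlnet_selection": ["pose", "canny"],  # keys of self.controlnet_map
    "seed": 42,
    "enhance_face_region": True,
    "enable_LCM": False,
    "face_image_base64": face_b64,
}

result = handler(payload)
# On success: {"generated_image_base64": "..."}; if the NSFW check trips: {"error": "..."}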