Jacobmadwed committed on
Commit 032b71d · verified · 1 Parent(s): 087f163

Update handler.py

Files changed (1)
  1. handler.py +173 -130
handler.py CHANGED
@@ -1,13 +1,15 @@
 import cv2
 import torch
 import numpy as np
+import PIL
 from PIL import Image
-from typing import Tuple, List, Optional, Dict, Any
+from typing import Tuple, List, Optional
 from pydantic import BaseModel
 import diffusers
 from diffusers.utils import load_image
 from diffusers.models import ControlNetModel
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from insightface.app import FaceAnalysis
 from style_template import styles
 from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
 from controlnet_aux import OpenposeDetector
@@ -17,10 +19,10 @@ import os
 from huggingface_hub import hf_hub_download
 import base64
 import io
+import json
 from transformers import CLIPProcessor, CLIPModel
-import onnxruntime as ort

-# Global variables
+# global variable
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
 STYLE_NAMES = list(styles.keys())
@@ -50,42 +52,58 @@ class GenerateImageRequest(BaseModel):
     pose_image_base64: Optional[str] = None

 class EndpointHandler:
-    def __init__(self, model_dir=""):
+    def __init__(self, model_dir):
         # Ensure the necessary files are downloaded
         controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
         controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
         face_adapter = hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=os.path.join(model_dir, "checkpoints"))

-        # Load the ONNX model
-        onnx_model_path = os.path.join(model_dir, "models", "version-RFB-320.onnx")
-        if not os.path.exists(onnx_model_path):
-            print(f"Model path {onnx_model_path} does not exist. Please ensure the model is available.")
-        self.ort_session = ort.InferenceSession(onnx_model_path)
+        dir_path = os.path.join(model_dir, "models", "face_detection_yunet_2023mar_int8.onnx")
+        if not os.path.exists(dir_path):
+            print(f"Model path {dir_path} does not exist. Attempting to download.")
+            self.app = FaceAnalysis(name='antelopev2', root=model_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        else:
+            print(f"Model path {dir_path} exists. Skipping download.")
+            self.app = FaceAnalysis(name='antelopev2', root=model_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

-        self.openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+        self.app.prepare(ctx_id=0, det_size=(640, 640))
+        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

         # Path to InstantID models
         controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")

         # Load pipeline face ControlNetModel
-        self.controlnet_identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype)
+        self.controlnet_identitynet = ControlNetModel.from_pretrained(
+            controlnet_path, torch_dtype=dtype
+        )
+
+        # controlnet-pose
+        controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
+        controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
+
+        controlnet_pose = ControlNetModel.from_pretrained(
+            controlnet_pose_model, torch_dtype=dtype
+        ).to(device)
+        controlnet_canny = ControlNetModel.from_pretrained(
+            controlnet_canny_model, torch_dtype=dtype
+        ).to(device)

-        # Load custom ControlNet models
-        self.controlnet_pose = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=dtype).to(device)
-        self.controlnet_canny = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=dtype).to(device)
+        def get_canny_image(image, t1=100, t2=200):
+            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+            edges = cv2.Canny(image, t1, t2)
+            return Image.fromarray(edges, "L")

-        # ControlNet map
         self.controlnet_map = {
-            "pose": self.controlnet_pose,
-            "canny": self.controlnet_canny
+            "pose": controlnet_pose,
+            "canny": controlnet_canny
         }

         self.controlnet_map_fn = {
-            "pose": self.openpose,
-            "canny": self.get_canny_image
+            "pose": openpose,
+            "canny": get_canny_image
         }

-        pretrained_model_name_or_path = "stablediffusionapi/protovision-xl-high-fidel"
+        pretrained_model_name_or_path = "wangqixun/YamerMIX_v8"

         self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
             pretrained_model_name_or_path,
@@ -99,7 +117,7 @@ class EndpointHandler:
             self.pipe.scheduler.config
         )

-        # Load and disable LCM
+        # load and disable LCM
         self.pipe.load_lora_weights(lcm_lora_path)
         self.pipe.fuse_lora()
         self.pipe.disable_lora()
@@ -113,161 +131,186 @@ class EndpointHandler:
         self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

-    def get_canny_image(self, image, t1=100, t2=200):
-        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-        edges = cv2.Canny(image, t1, t2)
-        return Image.fromarray(edges, "L")
-
     def is_nsfw(self, image: Image.Image) -> bool:
+        """
+        Check if an image contains NSFW content using the CLIP model.
+
+        Args:
+            image (Image.Image): PIL image to check.
+
+        Returns:
+            bool: True if the image is NSFW, False otherwise.
+        """
         inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
         inputs = {k: v.to(device) for k, v in inputs.items()}
         outputs = self.clip_model(**inputs)
-        logits_per_image = outputs.logits_per_image  # image-text similarity score
-        probs = logits_per_image.softmax(dim=1)  # probabilities
+        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        probs = logits_per_image.softmax(dim=1)  # we take the softmax to get the probabilities
         nsfw_prob = probs[0, 0].item()  # probability of "NSFW" label
-        return nsfw_prob > 0.9  # threshold for NSFW detection
-
-    def preprocess(self, image):
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        image = cv2.resize(image, (320, 240))
-        image_mean = np.array([127, 127, 127])
-        image = (image - image_mean) / 128
-        image = np.transpose(image, [2, 0, 1])
-        image = np.expand_dims(image, axis=0)
-        image = image.astype(np.float32)
-        return image
-
-    def get_face_info(self, image):
-        preprocessed_image = self.preprocess(image)
-        input_name = self.ort_session.get_inputs()[0].name
-        confidences, boxes = self.ort_session.run(None, {input_name: preprocessed_image})
-
-        print(f"Confidences shape: {confidences.shape}, Boxes shape: {boxes.shape}")
-
-        face_info_list = []
-        for i in range(len(boxes)):
-            box = boxes[i]
-            conf = confidences[i]
-            if conf[0] > 0.7:  # Fixing the out-of-bounds issue
-                x1, y1, x2, y2 = box[0] * 320, box[1] * 240, box[2] * 320, box[3] * 240
-                face_info_list.append({"bbox": [x1, y1, x2, y2]})
-        return face_info_list
-
-    def __call__(self, data: Any) -> Dict[str, Any]:
-        request = GenerateImageRequest(**data)
+        return nsfw_prob > 0.8  # Adjusted threshold for NSFW detection
+
+    def __call__(self, data):
+
+        def convert_from_cv2_to_image(img: np.ndarray) -> Image:
+            return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+        def convert_from_image_to_cv2(img: Image) -> np.ndarray:
+            return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+
+        def resize_img(
+            input_image,
+            max_side=1280,
+            min_side=1024,
+            size=None,
+            pad_to_max_side=False,
+            mode=PIL.Image.BILINEAR,
+            base_pixel_number=64,
+        ):
+            w, h = input_image.size
+            if size is not None:
+                w_resize_new, h_resize_new = size
+            else:
+                ratio = min_side / min(h, w)
+                w, h = round(ratio * w), round(ratio * h)
+                ratio = max_side / max(h, w)
+                input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
+                w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+                h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+            input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+            if pad_to_max_side:
+                res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+                offset_x = (max_side - w_resize_new) // 2
+                offset_y = (max_side - h_resize_new) // 2
+                res[
+                    offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
+                ] = np.array(input_image)
+                input_image = Image.fromarray(res)
+            return input_image
+
+        def apply_style(
+            style_name: str, positive: str, negative: str = ""
+        ) -> Tuple[str, str]:
+            p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+            return p.replace("{prompt}", positive), n + " " + negative

-        if request.enable_LCM:
+        request = GenerateImageRequest(**data)
+        inputs = request.inputs
+        negative_prompt = request.negative_prompt
+        style_name = request.style
+        identitynet_strength_ratio = request.identitynet_strength_ratio
+        adapter_strength_ratio = request.adapter_strength_ratio
+        pose_strength = request.pose_strength
+        canny_strength = request.canny_strength
+        num_steps = request.num_steps
+        guidance_scale = request.guidance_scale
+        controlnet_selection = request.controlnet_selection
+        seed = request.seed
+        enhance_face_region = request.enhance_face_region
+        enable_LCM = request.enable_LCM
+
+        if enable_LCM:
             self.pipe.enable_lora()
             self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
-            guidance_scale = min(max(request.guidance_scale, 0), 1)
+            guidance_scale = min(max(guidance_scale, 0), 1)
         else:
             self.pipe.disable_lora()
             self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)

-        # Apply style
-        inputs, negative_prompt = self.apply_style(request.style, request.inputs, request.negative_prompt)
+        # apply the style template
+        inputs, negative_prompt = apply_style(style_name, inputs, negative_prompt)
+
+        # Decode base64 image
+        face_image_base64 = data.get("face_image_base64")
+        face_image_data = base64.b64decode(face_image_base64)
+        face_image = Image.open(io.BytesIO(face_image_data))

-        # Decode base64 images
-        face_image = self.decode_base64_image(request.face_image_base64)
-        pose_image = self.decode_base64_image(request.pose_image_base64) if request.pose_image_base64 else None
+        pose_image_base64 = data.get("pose_image_base64")
+        pose_image = None
+        if pose_image_base64:
+            pose_image_data = base64.b64decode(pose_image_base64)
+            pose_image = Image.open(io.BytesIO(pose_image_data))

-        face_image = self.resize_img(face_image, max_side=1024)
-        face_image_cv2 = self.convert_from_image_to_cv2(face_image)
+        face_image = resize_img(face_image, max_side=1024)
+        face_image_cv2 = convert_from_image_to_cv2(face_image)
         height, width, _ = face_image_cv2.shape

         # Extract face features
-        face_info_list = self.get_face_info(face_image_cv2)
-        if len(face_info_list) == 0:
-            return {"error": "No faces detected."}
-
-        face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
-        face_kps = draw_kps(self.convert_from_cv2_to_image(face_image_cv2), face_info["bbox"])
+        face_info = self.app.get(face_image_cv2)
+
+        face_info = sorted(
+            face_info,
+            key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
+        )[-1]  # only use the maximum face (largest bounding-box area)
+        face_emb = face_info["embedding"]
+        face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
         img_controlnet = face_image
-
         if pose_image:
-            pose_image = self.resize_img(pose_image, max_side=1024)
+            pose_image = resize_img(pose_image, max_side=1024)
             img_controlnet = pose_image
-            pose_image_cv2 = self.convert_from_image_to_cv2(pose_image)
-            face_info_list = self.get_face_info(pose_image_cv2)
-            if len(face_info_list) == 0:
-                return {"error": "No faces detected in pose image."}
+            pose_image_cv2 = convert_from_image_to_cv2(pose_image)
+
+            face_info = self.app.get(pose_image_cv2)
+
+            face_info = face_info[-1]
+            face_kps = draw_kps(pose_image, face_info["kps"])

-            face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
-            face_kps = draw_kps(pose_image, face_info["bbox"])
             width, height = face_kps.size

-        control_mask = np.zeros([height, width, 3], dtype=np.uint8)
-        x1, y1, x2, y2 = map(int, face_info["bbox"])
+        control_mask = np.zeros([height, width, 3])
+        x1, y1, x2, y2 = face_info["bbox"]
+        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
         control_mask[y1:y2, x1:x2] = 255
-        control_mask = Image.fromarray(control_mask)
+        control_mask = Image.fromarray(control_mask.astype(np.uint8))

-        controlnet_scales = {"pose": request.pose_strength, "canny": request.canny_strength}
+        controlnet_scales = {
+            "pose": pose_strength,
+            "canny": canny_strength
+        }
         self.pipe.controlnet = MultiControlNetModel(
-            [self.controlnet_identitynet] + [self.controlnet_map[s] for s in request.controlnet_selection]
+            [self.controlnet_identitynet]
+            + [self.controlnet_map[s] for s in controlnet_selection]
         )
-        control_scales = [float(request.identitynet_strength_ratio)] + [controlnet_scales[s] for s in request.controlnet_selection]
-        control_images = [face_kps] + [self.controlnet_map_fn[s](img_controlnet).resize((width, height)) for s in request.controlnet_selection]
+        control_scales = [float(identitynet_strength_ratio)] + [
+            controlnet_scales[s] for s in controlnet_selection
+        ]
+        control_images = [face_kps] + [
+            self.controlnet_map_fn[s](img_controlnet).resize((width, height))
+            for s in controlnet_selection
+        ]
+
+        generator = torch.Generator(device=device).manual_seed(seed)

-        generator = torch.Generator(device=device).manual_seed(request.seed)
+        print("Start inference...")
+        print(f"[Debug] Prompt: {inputs}, \n[Debug] Neg Prompt: {negative_prompt}")

+        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
         outputs = self.pipe(
             prompt=inputs,
             negative_prompt=negative_prompt,
+            image_embeds=face_emb,
             image=control_images,
             control_mask=control_mask,
             controlnet_conditioning_scale=control_scales,
-            num_inference_steps=request.num_steps,
-            guidance_scale=request.guidance_scale,
+            num_inference_steps=num_steps,
+            guidance_scale=guidance_scale,
             height=height,
             width=width,
             generator=generator,
-            enhance_face_region=request.enhance_face_region,
+            enhance_face_region=enhance_face_region
         )

         images = outputs.images

+        # Check for NSFW content
         if self.is_nsfw(images[0]):
             return {"error": "Generated image contains NSFW content and was discarded."}

-        # Convert the image to base64
+        # Convert the output image to base64
         buffered = io.BytesIO()
         images[0].save(buffered, format="JPEG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

         return {"generated_image_base64": img_str}
-
-    def decode_base64_image(self, image_string):
-        base64_image = base64.b64decode(image_string)
-        buffer = io.BytesIO(base64_image)
-        return Image.open(buffer)
-
-    def convert_from_cv2_to_image(self, img: np.ndarray) -> Image:
-        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
-    def convert_from_image_to_cv2(self, img: Image) -> np.ndarray:
-        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
-    def resize_img(self, input_image, max_side=1280, min_side=1024, size=None, pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
-        w, h = input_image.size
-        if size is not None:
-            w_resize_new, h_resize_new = size
-        else:
-            ratio = min_side / min(h, w)
-            w, h = round(ratio * w), round(ratio * h)
-            ratio = max_side / max(h, w)
-            input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
-            w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
-            h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
-        input_image = input_image.resize([w_resize_new, h_resize_new], mode)
-
-        if pad_to_max_side:
-            res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
-            offset_x = (max_side - w_resize_new) // 2
-            offset_y = (max_side - h_resize_new) // 2
-            res[offset_y: offset_y + h_resize_new, offset_x: offset_x + w_resize_new] = np.array(input_image)
-            input_image = Image.fromarray(res)
-        return input_image
-
-    def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
-        p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-        return p.replace("{prompt}", positive), n + " " + negative