Jacobmadwed committed
Commit bb03b0c · verified · 1 Parent(s): b7bf383

Update handler.py

Files changed (1)
  1. handler.py +101 -182
handler.py CHANGED
@@ -1,9 +1,8 @@
1
  import cv2
2
  import torch
3
  import numpy as np
4
- import PIL
5
  from PIL import Image
6
- from typing import Tuple, List, Optional
7
  from pydantic import BaseModel
8
  import diffusers
9
  from diffusers.utils import load_image
@@ -18,11 +17,10 @@ import os
18
  from huggingface_hub import hf_hub_download
19
  import base64
20
  import io
21
- import json
22
  from transformers import CLIPProcessor, CLIPModel
23
  import onnxruntime as ort
24
 
25
- # global variable
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
28
  STYLE_NAMES = list(styles.keys())
@@ -52,7 +50,7 @@ class GenerateImageRequest(BaseModel):
52
  pose_image_base64: Optional[str] = None
53
 
54
  class EndpointHandler:
55
- def __init__(self, model_dir):
56
  # Ensure the necessary files are downloaded
57
  controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
58
  controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
@@ -60,47 +58,34 @@ class EndpointHandler:
60
 
61
  # Load the ONNX model
62
  onnx_model_path = os.path.join(model_dir, "models", "version-RFB-320.onnx")
63
- if not os.path.exists(onnx_model_path):
64
  print(f"Model path {onnx_model_path} does not exist. Please ensure the model is available.")
65
  self.ort_session = ort.InferenceSession(onnx_model_path)
66
 
67
- openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
68
 
69
  # Path to InstantID models
70
  controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")
71
 
72
  # Load pipeline face ControlNetModel
73
- self.controlnet_identitynet = ControlNetModel.from_pretrained(
74
- controlnet_path, torch_dtype=dtype
75
- )
76
-
77
- # controlnet-pose
78
- controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
79
- controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
80
-
81
- controlnet_pose = ControlNetModel.from_pretrained(
82
- controlnet_pose_model, torch_dtype=dtype
83
- ).to(device)
84
- controlnet_canny = ControlNetModel.from_pretrained(
85
- controlnet_canny_model, torch_dtype=dtype
86
- ).to(device)
87
 
88
- def get_canny_image(image, t1=100, t2=200):
89
- image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
90
- edges = cv2.Canny(image, t1, t2)
91
- return Image.fromarray(edges, "L")
92
 
 
93
  self.controlnet_map = {
94
- "pose": controlnet_pose,
95
- "canny": controlnet_canny
96
  }
97
 
98
  self.controlnet_map_fn = {
99
- "pose": openpose,
100
- "canny": get_canny_image
101
  }
102
 
103
- pretrained_model_name_or_path = "wangqixun/YamerMIX_v8"
104
 
105
  self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
106
  pretrained_model_name_or_path,
@@ -114,7 +99,7 @@ class EndpointHandler:
114
  self.pipe.scheduler.config
115
  )
116
 
117
- # load and disable LCM
118
  self.pipe.load_lora_weights(lcm_lora_path)
119
  self.pipe.fuse_lora()
120
  self.pipe.disable_lora()
@@ -128,226 +113,160 @@ class EndpointHandler:
128
  self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
129
  self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
130
 
131
- def is_nsfw(self, image: Image.Image) -> bool:
132
- """
133
- Check if an image contains NSFW content using CLIP model.
134
-
135
- Args:
136
- image (Image.Image): PIL image to check.
137
 
138
- Returns:
139
- bool: True if the image is NSFW, False otherwise.
140
- """
141
  inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
142
  inputs = {k: v.to(device) for k, v in inputs.items()}
143
  outputs = self.clip_model(**inputs)
144
- logits_per_image = outputs.logits_per_image # this is the image-text similarity score
145
- probs = logits_per_image.softmax(dim=1) # we take the softmax to get the probabilities
146
  nsfw_prob = probs[0, 0].item() # probability of "NSFW" label
147
- return nsfw_prob > 0.8 # Adjusted threshold for NSFW detection
148
 
149
  def preprocess(self, image):
150
- # Preprocess the image for ONNX model
151
- image = cv2.resize(image, (320, 240)) # Adjust based on model input size
152
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
153
- image = np.transpose(image, (2, 0, 1))
154
- image = image[np.newaxis, :, :, :].astype(np.float32) / 127.5 - 1.0 # Normalize to [-1, 1]
155
  return image
156
 
157
  def get_face_info(self, image):
158
- # Preprocess the image
159
- image = self.preprocess(image)
160
-
161
- # Run the ONNX model to get the face detection results
162
  input_name = self.ort_session.get_inputs()[0].name
163
- outputs = self.ort_session.run(None, {input_name: image})
164
 
165
- # Process the output to extract face information
166
- bboxes = outputs[0][0] # Adjust based on model output structure
167
  face_info_list = []
168
- for bbox in bboxes:
169
- score = bbox[2]
170
- if score > 0.5: # Confidence threshold
171
- x1, y1, x2, y2 = bbox[3:7] * [320, 240, 320, 240] # Scale coordinates
172
- face_info_list.append({
173
- "bbox": [x1, y1, x2, y2],
174
- "embedding": self.get_face_embedding(image[:, :, int(y1):int(y2), int(x1):int(x2)])
175
- })
176
  return face_info_list
177
 
178
- def get_face_embedding(self, image):
179
- # Extract features for the face image region
180
- # Implement the logic to extract face embeddings
181
- # For now, returning a placeholder value
182
- return np.random.rand(512) # Replace with actual embedding extraction logic
183
-
184
- def __call__(self, data):
185
-
186
- def convert_from_cv2_to_image(img: np.ndarray) -> Image:
187
- return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
188
-
189
- def convert_from_image_to_cv2(img: Image) -> np.ndarray:
190
- return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
191
-
192
- def resize_img(
193
- input_image,
194
- max_side=1280,
195
- min_side=1024,
196
- size=None,
197
- pad_to_max_side=False,
198
- mode=PIL.Image.BILINEAR,
199
- base_pixel_number=64,
200
- ):
201
- w, h = input_image.size
202
- if size is not None:
203
- w_resize_new, h_resize_new = size
204
- else:
205
- ratio = min_side / min(h, w)
206
- w, h = round(ratio * w), round(ratio * h)
207
- ratio = max_side / max(h, w)
208
- input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
209
- w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
210
- h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
211
- input_image = input_image.resize([w_resize_new, h_resize_new], mode)
212
-
213
- if pad_to_max_side:
214
- res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
215
- offset_x = (max_side - w_resize_new) // 2
216
- offset_y = (max_side - h_resize_new) // 2
217
- res[
218
- offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
219
- ] = np.array(input_image)
220
- input_image = Image.fromarray(res)
221
- return input_image
222
-
223
- def apply_style(
224
- style_name: str, positive: str, negative: str = ""
225
- ) -> Tuple[str, str]:
226
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
227
- return p.replace("{prompt}", positive), n + " " + negative
228
-
229
  request = GenerateImageRequest(**data)
230
- inputs = request.inputs
231
- negative_prompt = request.negative_prompt
232
- style_name = request.style
233
- identitynet_strength_ratio = request.identitynet_strength_ratio
234
- adapter_strength_ratio = request.adapter_strength_ratio
235
- pose_strength = request.pose_strength
236
- canny_strength = request.canny_strength
237
- num_steps = request.num_steps
238
- guidance_scale = request.guidance_scale
239
- controlnet_selection = request.controlnet_selection
240
- seed = request.seed
241
- enhance_face_region = request.enhance_face_region
242
- enable_LCM = request.enable_LCM
243
-
244
- if enable_LCM:
245
  self.pipe.enable_lora()
246
  self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
247
- guidance_scale = min(max(guidance_scale, 0), 1)
248
  else:
249
  self.pipe.disable_lora()
250
  self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
251
 
252
- # apply the style template
253
- inputs, negative_prompt = apply_style(style_name, inputs, negative_prompt)
254
-
255
- # Decode base64 image
256
- face_image_base64 = data.get("face_image_base64")
257
- face_image_data = base64.b64decode(face_image_base64)
258
- face_image = Image.open(io.BytesIO(face_image_data))
259
 
260
- pose_image_base64 = data.get("pose_image_base64")
261
- pose_image = None
262
- if pose_image_base64:
263
- pose_image_data = base64.b64decode(pose_image_base64)
264
- pose_image = Image.open(io.BytesIO(pose_image_data))
265
 
266
- face_image = resize_img(face_image, max_side=1024)
267
- face_image_cv2 = convert_from_image_to_cv2(face_image)
268
  height, width, _ = face_image_cv2.shape
269
 
270
- # Extract face features using the ONNX model
271
  face_info_list = self.get_face_info(face_image_cv2)
272
-
273
  if len(face_info_list) == 0:
274
  return {"error": "No faces detected."}
275
 
276
- # Use the largest face detected
277
  face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
278
- face_emb = face_info["embedding"]
279
- face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["bbox"])
280
  img_controlnet = face_image
281
 
282
  if pose_image:
283
- pose_image = resize_img(pose_image, max_side=1024)
284
  img_controlnet = pose_image
285
- pose_image_cv2 = convert_from_image_to_cv2(pose_image)
286
 
287
- # Extract face features from pose image using the ONNX model
288
  face_info_list = self.get_face_info(pose_image_cv2)
289
-
290
  if len(face_info_list) == 0:
291
  return {"error": "No faces detected in pose image."}
292
 
293
  face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
294
- face_emb = face_info["embedding"]
295
  face_kps = draw_kps(pose_image, face_info["bbox"])
296
-
297
  width, height = face_kps.size
298
 
299
- control_mask = np.zeros([height, width, 3], dtype=np.uint8) # Ensure dtype is uint8
300
- x1, y1, x2, y2 = face_info["bbox"]
301
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
302
  control_mask[y1:y2, x1:x2] = 255
303
  control_mask = Image.fromarray(control_mask)
304
 
305
- controlnet_scales = {
306
- "pose": pose_strength,
307
- "canny": canny_strength
308
- }
309
  self.pipe.controlnet = MultiControlNetModel(
310
- [self.controlnet_identitynet]
311
- + [self.controlnet_map[s] for s in controlnet_selection]
312
  )
313
- control_scales = [float(identitynet_strength_ratio)] + [
314
- controlnet_scales[s] for s in controlnet_selection
315
- ]
316
- control_images = [face_kps] + [
317
- self.controlnet_map_fn[s](img_controlnet).resize((width, height))
318
- for s in controlnet_selection
319
- ]
320
-
321
- generator = torch.Generator(device=device).manual_seed(seed)
322
 
323
- print("Start inference...")
324
- print(f"[Debug] Prompt: {inputs}, \n[Debug] Neg Prompt: {negative_prompt}")
325
 
326
- self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
327
  outputs = self.pipe(
328
  prompt=inputs,
329
  negative_prompt=negative_prompt,
330
- image_embeds=face_emb,
331
  image=control_images,
332
  control_mask=control_mask,
333
  controlnet_conditioning_scale=control_scales,
334
- num_inference_steps=num_steps,
335
- guidance_scale=guidance_scale,
336
  height=height,
337
  width=width,
338
  generator=generator,
339
- enhance_face_region=enhance_face_region
340
  )
341
-
342
  images = outputs.images
343
 
344
- # Check for NSFW content
345
  if self.is_nsfw(images[0]):
346
  return {"error": "Generated image contains NSFW content and was discarded."}
347
 
348
- # Convert the output image to base64
349
  buffered = io.BytesIO()
350
  images[0].save(buffered, format="JPEG")
351
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
352
 
353
  return {"generated_image_base64": img_str}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import torch
3
  import numpy as np
 
4
  from PIL import Image
5
+ from typing import Tuple, List, Optional, Dict, Any
6
  from pydantic import BaseModel
7
  import diffusers
8
  from diffusers.utils import load_image
 
17
  from huggingface_hub import hf_hub_download
18
  import base64
19
  import io
 
20
  from transformers import CLIPProcessor, CLIPModel
21
  import onnxruntime as ort
22
 
23
+ # Global variables
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
26
  STYLE_NAMES = list(styles.keys())
 
50
  pose_image_base64: Optional[str] = None
51
 
52
  class EndpointHandler:
53
+ def __init__(self, model_dir=""):
54
  # Ensure the necessary files are downloaded
55
  controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
56
  controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
 
58
 
59
  # Load the ONNX model
60
  onnx_model_path = os.path.join(model_dir, "models", "version-RFB-320.onnx")
61
+ if not os.path.exists(onnx_model_path):
62
  print(f"Model path {onnx_model_path} does not exist. Please ensure the model is available.")
63
  self.ort_session = ort.InferenceSession(onnx_model_path)
64
 
65
+ self.openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
66
 
67
  # Path to InstantID models
68
  controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")
69
 
70
  # Load pipeline face ControlNetModel
71
+ self.controlnet_identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype)
72
 
73
+ # Load custom ControlNet models
74
+ self.controlnet_pose = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=dtype).to(device)
75
+ self.controlnet_canny = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=dtype).to(device)
 
76
 
77
+ # ControlNet map
78
  self.controlnet_map = {
79
+ "pose": self.controlnet_pose,
80
+ "canny": self.controlnet_canny
81
  }
82
 
83
  self.controlnet_map_fn = {
84
+ "pose": self.openpose,
85
+ "canny": self.get_canny_image
86
  }
87
 
88
+ pretrained_model_name_or_path = "stablediffusionapi/protovision-xl-high-fidel"
89
 
90
  self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
91
  pretrained_model_name_or_path,
 
99
  self.pipe.scheduler.config
100
  )
101
 
102
+ # Load and disable LCM
103
  self.pipe.load_lora_weights(lcm_lora_path)
104
  self.pipe.fuse_lora()
105
  self.pipe.disable_lora()
 
113
  self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
114
  self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
115
 
116
+ def get_canny_image(self, image, t1=100, t2=200):
117
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
118
+ edges = cv2.Canny(image, t1, t2)
119
+ return Image.fromarray(edges, "L")
 
 
120
 
121
+ def is_nsfw(self, image: Image.Image) -> bool:
 
 
122
  inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
123
  inputs = {k: v.to(device) for k, v in inputs.items()}
124
  outputs = self.clip_model(**inputs)
125
+ logits_per_image = outputs.logits_per_image # image-text similarity score
126
+ probs = logits_per_image.softmax(dim=1) # probabilities
127
  nsfw_prob = probs[0, 0].item() # probability of "NSFW" label
128
+ return nsfw_prob > 0.9 # threshold for NSFW detection
129
 
130
  def preprocess(self, image):
 
 
131
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
132
+ image = cv2.resize(image, (320, 240))
133
+ image_mean = np.array([127, 127, 127])
134
+ image = (image - image_mean) / 128
135
+ image = np.transpose(image, [2, 0, 1])
136
+ image = np.expand_dims(image, axis=0)
137
+ image = image.astype(np.float32)
138
  return image
139
 
140
  def get_face_info(self, image):
141
+ preprocessed_image = self.preprocess(image)
142
  input_name = self.ort_session.get_inputs()[0].name
143
+ confidences, boxes = self.ort_session.run(None, {input_name: preprocessed_image})
144
 
 
 
145
  face_info_list = []
146
+ for i in range(boxes.shape[1]):
147
+ box = boxes[0, i, :]
148
+ conf = confidences[0, i, 1]
149
+ if conf > 0.7:
150
+ x1, y1, x2, y2 = box[0] * 320, box[1] * 240, box[2] * 320, box[3] * 240
151
+ face_info_list.append({"bbox": [x1, y1, x2, y2]})
 
 
152
  return face_info_list
153
 
154
+ def __call__(self, data: Any) -> Dict[str, Any]:
155
  request = GenerateImageRequest(**data)
156
+
157
+ if request.enable_LCM:
158
  self.pipe.enable_lora()
159
  self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
160
+ guidance_scale = min(max(request.guidance_scale, 0), 1)
161
  else:
162
  self.pipe.disable_lora()
163
  self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
164
 
165
+ # Apply style
166
+ inputs, negative_prompt = self.apply_style(request.style, request.inputs, request.negative_prompt)
167
 
168
+ # Decode base64 images
169
+ face_image = self.decode_base64_image(request.face_image_base64)
170
+ pose_image = self.decode_base64_image(request.pose_image_base64) if request.pose_image_base64 else None
 
 
171
 
172
+ face_image = self.resize_img(face_image, max_side=1024)
173
+ face_image_cv2 = self.convert_from_image_to_cv2(face_image)
174
  height, width, _ = face_image_cv2.shape
175
 
176
+ # Extract face features
177
  face_info_list = self.get_face_info(face_image_cv2)
 
178
  if len(face_info_list) == 0:
179
  return {"error": "No faces detected."}
180
 
 
181
  face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
182
+ face_kps = draw_kps(self.convert_from_cv2_to_image(face_image_cv2), face_info["bbox"])
 
183
  img_controlnet = face_image
184
 
185
  if pose_image:
186
+ pose_image = self.resize_img(pose_image, max_side=1024)
187
  img_controlnet = pose_image
188
+ pose_image_cv2 = self.convert_from_image_to_cv2(pose_image)
189
 
 
190
  face_info_list = self.get_face_info(pose_image_cv2)
 
191
  if len(face_info_list) == 0:
192
  return {"error": "No faces detected in pose image."}
193
 
194
  face_info = max(face_info_list, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))
 
195
  face_kps = draw_kps(pose_image, face_info["bbox"])
 
196
  width, height = face_kps.size
197
 
198
+ control_mask = np.zeros([height, width, 3], dtype=np.uint8)
199
+ x1, y1, x2, y2 = map(int, face_info["bbox"])
 
200
  control_mask[y1:y2, x1:x2] = 255
201
  control_mask = Image.fromarray(control_mask)
202
 
203
+ controlnet_scales = {"pose": request.pose_strength, "canny": request.canny_strength}
204
  self.pipe.controlnet = MultiControlNetModel(
205
+ [self.controlnet_identitynet] + [self.controlnet_map[s] for s in request.controlnet_selection]
 
206
  )
207
+ control_scales = [float(request.identitynet_strength_ratio)] + [controlnet_scales[s] for s in request.controlnet_selection]
208
+ control_images = [face_kps] + [self.controlnet_map_fn[s](img_controlnet).resize((width, height)) for s in request.controlnet_selection]
209
 
210
+ generator = torch.Generator(device=device).manual_seed(request.seed)
 
211
 
 
212
  outputs = self.pipe(
213
  prompt=inputs,
214
  negative_prompt=negative_prompt,
 
215
  image=control_images,
216
  control_mask=control_mask,
217
  controlnet_conditioning_scale=control_scales,
218
+ num_inference_steps=request.num_steps,
219
+ guidance_scale=request.guidance_scale,
220
  height=height,
221
  width=width,
222
  generator=generator,
223
+ enhance_face_region=request.enhance_face_region,
224
  )
225
+
226
  images = outputs.images
227
 
 
228
  if self.is_nsfw(images[0]):
229
  return {"error": "Generated image contains NSFW content and was discarded."}
230
 
231
+ # Convert the image to base64
232
  buffered = io.BytesIO()
233
  images[0].save(buffered, format="JPEG")
234
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
235
 
236
  return {"generated_image_base64": img_str}
237
+
238
+ def decode_base64_image(self, image_string):
239
+ base64_image = base64.b64decode(image_string)
240
+ buffer = io.BytesIO(base64_image)
241
+ return Image.open(buffer)
242
+
243
+ def convert_from_cv2_to_image(self, img: np.ndarray) -> Image:
244
+ return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
245
+
246
+ def convert_from_image_to_cv2(self, img: Image) -> np.ndarray:
247
+ return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
248
+
249
+ def resize_img(self, input_image, max_side=1280, min_side=1024, size=None, pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
250
+ w, h = input_image.size
251
+ if size is not None:
252
+ w_resize_new, h_resize_new = size
253
+ else:
254
+ ratio = min_side / min(h, w)
255
+ w, h = round(ratio * w), round(ratio * h)
256
+ ratio = max_side / max(h, w)
257
+ input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
258
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
259
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
260
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
261
+
262
+ if pad_to_max_side:
263
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
264
+ offset_x = (max_side - w_resize_new) // 2
265
+ offset_y = (max_side - h_resize_new) // 2
266
+ res[offset_y: offset_y + h_resize_new, offset_x: offset_x + w_resize_new] = np.array(input_image)
267
+ input_image = Image.fromarray(res)
268
+ return input_image
269
+
270
+ def apply_style(self, style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
271
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
272
+ return p.replace("{prompt}", positive), n + " " + negative
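
For reference, a minimal client-side sketch of how this handler might be invoked after the change. The field names mirror the attributes read in __call__; the full GenerateImageRequest definition (and its defaults) is not part of this diff, so the payload values, file names, and the helper below are illustrative assumptions rather than the repository's own test code.

import base64
import io

from PIL import Image

from handler import EndpointHandler


def encode_image(path: str) -> str:
    # Read a local image file and return it as a base64 string, as expected
    # by face_image_base64 / pose_image_base64.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


handler = EndpointHandler(model_dir="")  # model_dir defaults to "" in this commit

payload = {
    "inputs": "portrait photo of a person, studio lighting",
    "negative_prompt": "lowres, blurry",
    "style": "Watercolor",              # unknown names fall back to DEFAULT_STYLE_NAME
    "identitynet_strength_ratio": 0.8,
    "pose_strength": 0.4,
    "canny_strength": 0.4,
    "num_steps": 30,
    "guidance_scale": 5.0,
    "controlnet_selection": ["pose", "canny"],
    "seed": 42,
    "enhance_face_region": True,
    "enable_LCM": False,
    "face_image_base64": encode_image("face.jpg"),  # hypothetical input file
    "pose_image_base64": None,                      # optional pose reference
}

result = handler(payload)
if "error" in result:
    print(result["error"])
else:
    image_bytes = base64.b64decode(result["generated_image_base64"])
    Image.open(io.BytesIO(image_bytes)).save("generated.jpg")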