Jacobmadwed committed on
Commit
4d248cf
·
verified ·
1 Parent(s): 9e431fc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -40
handler.py CHANGED
@@ -3,8 +3,7 @@ import torch
3
  import numpy as np
4
  import PIL
5
  from PIL import Image
6
- from typing import Tuple, List
7
- from pydantic import BaseModel
8
  import diffusers
9
  from diffusers.utils import load_image
10
  from diffusers.models import ControlNetModel
@@ -27,22 +26,6 @@ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
27
  STYLE_NAMES = list(styles.keys())
28
  DEFAULT_STYLE_NAME = "Spring Festival"
29
 
30
- class GenerateImageRequest(BaseModel):
31
- prompt: str
32
- negative_prompt: str
33
- style: str
34
- num_steps: int
35
- identitynet_strength_ratio: float
36
- adapter_strength_ratio: float
37
- pose_strength: float
38
- canny_strength: float
39
- depth_strength: float
40
- controlnet_selection: List[str]
41
- guidance_scale: float
42
- seed: int
43
- enable_LCM: bool
44
- enhance_face_region: bool
45
-
46
  class EndpointHandler:
47
  def __init__(self, model_dir):
48
  # Ensure the necessary files are downloaded
@@ -163,33 +146,50 @@ class EndpointHandler:
163
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
164
  return p.replace("{prompt}", positive), n + " " + negative
165
 
166
- request = GenerateImageRequest(**data)
167
- prompt = request.prompt
168
- negative_prompt = request.negative_prompt
169
- style_name = request.style
170
- identitynet_strength_ratio = request.identitynet_strength_ratio
171
- adapter_strength_ratio = request.adapter_strength_ratio
172
- pose_strength = request.pose_strength
173
- canny_strength = request.canny_strength
174
- num_steps = request.num_steps
175
- guidance_scale = request.guidance_scale
176
- controlnet_selection = request.controlnet_selection
177
- seed = request.seed
178
- enhance_face_region = request.enhance_face_region
179
- enable_LCM = request.enable_LCM
180
-
181
- self.pipe.disable_lora() if not enable_LCM else self.pipe.enable_lora()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- scheduler_class_name = "EulerDiscreteScheduler"
184
 
185
- self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
 
 
 
 
 
 
186
 
187
  # apply the style template
188
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
189
 
190
- face_image_path = data.get("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
191
- pose_image_path = data.get("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
192
-
193
  face_image = load_image(face_image_path)
194
  face_image = resize_img(face_image, max_side=1024)
195
  face_image_cv2 = convert_from_image_to_cv2(face_image)
@@ -246,6 +246,7 @@ class EndpointHandler:
246
 
247
  print("Start inference...")
248
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
 
249
 
250
  self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
251
  images = self.pipe(
@@ -260,7 +261,6 @@ class EndpointHandler:
260
  height=height,
261
  width=width,
262
  generator=generator,
263
- enhance_face_region=enhance_face_region
264
  ).images
265
 
266
  # Convert the output image to base64
 
3
  import numpy as np
4
  import PIL
5
  from PIL import Image
6
+ from typing import Tuple
 
7
  import diffusers
8
  from diffusers.utils import load_image
9
  from diffusers.models import ControlNetModel
 
26
  STYLE_NAMES = list(styles.keys())
27
  DEFAULT_STYLE_NAME = "Spring Festival"
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class EndpointHandler:
30
  def __init__(self, model_dir):
31
  # Ensure the necessary files are downloaded
 
146
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
147
  return p.replace("{prompt}", positive), n + " " + negative
148
 
149
+ face_image_path = data.pop("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
150
+ pose_image_path = data.pop("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
151
+ style_name = data.pop("style_name", DEFAULT_STYLE_NAME)
152
+ prompt = data.pop("inputs", "a man flying in the sky in Mars")
153
+ negative_prompt = data.pop("negative_prompt", "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green")
154
+
155
+ identitynet_strength_ratio = data.pop("identitynet_strength_ratio", 0.8)
156
+ adapter_strength_ratio = data.pop("adapter_strength_ratio", 0.8)
157
+ pose_strength = data.pop("pose_strength", 0.5)
158
+ canny_strength = data.pop("canny_strength", 0.3)
159
+ num_steps = data.pop("num_steps", 20)
160
+ guidance_scale = data.pop("guidance_scale", 5.0)
161
+ controlnet_selection = data.pop("controlnet_selection", ["pose", "canny"])
162
+ scheduler = data.pop("scheduler", "EulerDiscreteScheduler")
163
+ enable_fast_inference = data.pop("enable_fast_inference", False)
164
+ enhance_non_face_region = data.pop("enhance_non_face_region", False)
165
+ seed = data.pop("seed", 42)
166
+
167
+ # Ensure required fields are present
168
+ data.setdefault("prompt", prompt)
169
+ data.setdefault("style", style_name)
170
+ data.setdefault("num_steps", num_steps)
171
+ data.setdefault("enable_LCM", enable_fast_inference)
172
+ data.setdefault("enhance_face_region", enhance_non_face_region)
173
+
174
+ # Enable LCM if fast inference is enabled
175
+ if enable_fast_inference:
176
+ self.pipe.enable_lora()
177
+ else:
178
+ self.pipe.disable_lora()
179
 
180
+ scheduler_class_name = scheduler.split("-")[0]
181
 
182
+ add_kwargs = {}
183
+ if len(scheduler.split("-")) > 1:
184
+ add_kwargs["use_karras_sigmas"] = True
185
+ if len(scheduler.split("-")) > 2:
186
+ add_kwargs["algorithm_type"] = "sde-dpmsolver++"
187
+ scheduler = getattr(diffusers, scheduler_class_name)
188
+ self.pipe.scheduler = scheduler.from_config(self.pipe.scheduler.config, **add_kwargs)
189
 
190
  # apply the style template
191
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
192
 
 
 
 
193
  face_image = load_image(face_image_path)
194
  face_image = resize_img(face_image, max_side=1024)
195
  face_image_cv2 = convert_from_image_to_cv2(face_image)
 
246
 
247
  print("Start inference...")
248
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
249
+ print(f"[Debug] Number of Inference Steps: {num_steps}")
250
 
251
  self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
252
  images = self.pipe(
 
261
  height=height,
262
  width=width,
263
  generator=generator,
 
264
  ).images
265
 
266
  # Convert the output image to base64