Jacobmadwed commited on
Commit
7f255da
·
verified ·
1 Parent(s): 4d248cf

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -40
handler.py CHANGED
@@ -3,7 +3,8 @@ import torch
3
  import numpy as np
4
  import PIL
5
  from PIL import Image
6
- from typing import Tuple
 
7
  import diffusers
8
  from diffusers.utils import load_image
9
  from diffusers.models import ControlNetModel
@@ -26,6 +27,22 @@ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
26
  STYLE_NAMES = list(styles.keys())
27
  DEFAULT_STYLE_NAME = "Spring Festival"
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class EndpointHandler:
30
  def __init__(self, model_dir):
31
  # Ensure the necessary files are downloaded
@@ -146,50 +163,33 @@ class EndpointHandler:
146
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
147
  return p.replace("{prompt}", positive), n + " " + negative
148
 
149
- face_image_path = data.pop("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
150
- pose_image_path = data.pop("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
151
- style_name = data.pop("style_name", DEFAULT_STYLE_NAME)
152
- prompt = data.pop("inputs", "a man flying in the sky in Mars")
153
- negative_prompt = data.pop("negative_prompt", "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green")
154
-
155
- identitynet_strength_ratio = data.pop("identitynet_strength_ratio", 0.8)
156
- adapter_strength_ratio = data.pop("adapter_strength_ratio", 0.8)
157
- pose_strength = data.pop("pose_strength", 0.5)
158
- canny_strength = data.pop("canny_strength", 0.3)
159
- num_steps = data.pop("num_steps", 20)
160
- guidance_scale = data.pop("guidance_scale", 5.0)
161
- controlnet_selection = data.pop("controlnet_selection", ["pose", "canny"])
162
- scheduler = data.pop("scheduler", "EulerDiscreteScheduler")
163
- enable_fast_inference = data.pop("enable_fast_inference", False)
164
- enhance_non_face_region = data.pop("enhance_non_face_region", False)
165
- seed = data.pop("seed", 42)
166
-
167
- # Ensure required fields are present
168
- data.setdefault("prompt", prompt)
169
- data.setdefault("style", style_name)
170
- data.setdefault("num_steps", num_steps)
171
- data.setdefault("enable_LCM", enable_fast_inference)
172
- data.setdefault("enhance_face_region", enhance_non_face_region)
173
-
174
- # Enable LCM if fast inference is enabled
175
- if enable_fast_inference:
176
- self.pipe.enable_lora()
177
- else:
178
- self.pipe.disable_lora()
179
 
180
- scheduler_class_name = scheduler.split("-")[0]
181
 
182
- add_kwargs = {}
183
- if len(scheduler.split("-")) > 1:
184
- add_kwargs["use_karras_sigmas"] = True
185
- if len(scheduler.split("-")) > 2:
186
- add_kwargs["algorithm_type"] = "sde-dpmsolver++"
187
- scheduler = getattr(diffusers, scheduler_class_name)
188
- self.pipe.scheduler = scheduler.from_config(self.pipe.scheduler.config, **add_kwargs)
189
 
190
  # apply the style template
191
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
192
 
 
 
 
193
  face_image = load_image(face_image_path)
194
  face_image = resize_img(face_image, max_side=1024)
195
  face_image_cv2 = convert_from_image_to_cv2(face_image)
@@ -246,7 +246,6 @@ class EndpointHandler:
246
 
247
  print("Start inference...")
248
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
249
- print(f"[Debug] Number of Inference Steps: {num_steps}")
250
 
251
  self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
252
  images = self.pipe(
@@ -261,6 +260,7 @@ class EndpointHandler:
261
  height=height,
262
  width=width,
263
  generator=generator,
 
264
  ).images
265
 
266
  # Convert the output image to base64
 
3
  import numpy as np
4
  import PIL
5
  from PIL import Image
6
+ from typing import Tuple, List
7
+ from pydantic import BaseModel
8
  import diffusers
9
  from diffusers.utils import load_image
10
  from diffusers.models import ControlNetModel
 
27
  STYLE_NAMES = list(styles.keys())
28
  DEFAULT_STYLE_NAME = "Spring Festival"
29
 
30
+ class GenerateImageRequest(BaseModel):
31
+ prompt: str
32
+ negative_prompt: str
33
+ style: str
34
+ num_steps: int
35
+ identitynet_strength_ratio: float
36
+ adapter_strength_ratio: float
37
+ pose_strength: float
38
+ canny_strength: float
39
+ depth_strength: float
40
+ controlnet_selection: List[str]
41
+ guidance_scale: float
42
+ seed: int
43
+ enable_LCM: bool
44
+ enhance_face_region: bool
45
+
46
  class EndpointHandler:
47
  def __init__(self, model_dir):
48
  # Ensure the necessary files are downloaded
 
163
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
164
  return p.replace("{prompt}", positive), n + " " + negative
165
 
166
+ request = GenerateImageRequest(**data)
167
+ prompt = request.prompt
168
+ negative_prompt = request.negative_prompt
169
+ style_name = request.style
170
+ identitynet_strength_ratio = request.identitynet_strength_ratio
171
+ adapter_strength_ratio = request.adapter_strength_ratio
172
+ pose_strength = request.pose_strength
173
+ canny_strength = request.canny_strength
174
+ num_steps = request.num_steps
175
+ guidance_scale = request.guidance_scale
176
+ controlnet_selection = request.controlnet_selection
177
+ seed = request.seed
178
+ enhance_face_region = request.enhance_face_region
179
+ enable_LCM = request.enable_LCM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ self.pipe.disable_lora() if not enable_LCM else self.pipe.enable_lora()
182
 
183
+ scheduler_class_name = "EulerDiscreteScheduler"
184
+
185
+ self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
 
 
 
 
186
 
187
  # apply the style template
188
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
189
 
190
+ face_image_path = data.get("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
191
+ pose_image_path = data.get("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
192
+
193
  face_image = load_image(face_image_path)
194
  face_image = resize_img(face_image, max_side=1024)
195
  face_image_cv2 = convert_from_image_to_cv2(face_image)
 
246
 
247
  print("Start inference...")
248
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
 
249
 
250
  self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
251
  images = self.pipe(
 
260
  height=height,
261
  width=width,
262
  generator=generator,
263
+ enhance_face_region=enhance_face_region
264
  ).images
265
 
266
  # Convert the output image to base64