Update handler.py
Browse files- handler.py +40 -40
handler.py
CHANGED
@@ -3,8 +3,7 @@ import torch
|
|
3 |
import numpy as np
|
4 |
import PIL
|
5 |
from PIL import Image
|
6 |
-
from typing import Tuple
|
7 |
-
from pydantic import BaseModel
|
8 |
import diffusers
|
9 |
from diffusers.utils import load_image
|
10 |
from diffusers.models import ControlNetModel
|
@@ -27,22 +26,6 @@ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
|
|
27 |
STYLE_NAMES = list(styles.keys())
|
28 |
DEFAULT_STYLE_NAME = "Spring Festival"
|
29 |
|
30 |
-
class GenerateImageRequest(BaseModel):
|
31 |
-
prompt: str
|
32 |
-
negative_prompt: str
|
33 |
-
style: str
|
34 |
-
num_steps: int
|
35 |
-
identitynet_strength_ratio: float
|
36 |
-
adapter_strength_ratio: float
|
37 |
-
pose_strength: float
|
38 |
-
canny_strength: float
|
39 |
-
depth_strength: float
|
40 |
-
controlnet_selection: List[str]
|
41 |
-
guidance_scale: float
|
42 |
-
seed: int
|
43 |
-
enable_LCM: bool
|
44 |
-
enhance_face_region: bool
|
45 |
-
|
46 |
class EndpointHandler:
|
47 |
def __init__(self, model_dir):
|
48 |
# Ensure the necessary files are downloaded
|
@@ -163,33 +146,50 @@ class EndpointHandler:
|
|
163 |
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
164 |
return p.replace("{prompt}", positive), n + " " + negative
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
-
scheduler_class_name = "
|
184 |
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
# apply the style template
|
188 |
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
189 |
|
190 |
-
face_image_path = data.get("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
|
191 |
-
pose_image_path = data.get("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
|
192 |
-
|
193 |
face_image = load_image(face_image_path)
|
194 |
face_image = resize_img(face_image, max_side=1024)
|
195 |
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
@@ -246,6 +246,7 @@ class EndpointHandler:
|
|
246 |
|
247 |
print("Start inference...")
|
248 |
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
|
|
249 |
|
250 |
self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
251 |
images = self.pipe(
|
@@ -260,7 +261,6 @@ class EndpointHandler:
|
|
260 |
height=height,
|
261 |
width=width,
|
262 |
generator=generator,
|
263 |
-
enhance_face_region=enhance_face_region
|
264 |
).images
|
265 |
|
266 |
# Convert the output image to base64
|
|
|
3 |
import numpy as np
|
4 |
import PIL
|
5 |
from PIL import Image
|
6 |
+
from typing import Tuple
|
|
|
7 |
import diffusers
|
8 |
from diffusers.utils import load_image
|
9 |
from diffusers.models import ControlNetModel
|
|
|
26 |
STYLE_NAMES = list(styles.keys())
|
27 |
DEFAULT_STYLE_NAME = "Spring Festival"
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
class EndpointHandler:
|
30 |
def __init__(self, model_dir):
|
31 |
# Ensure the necessary files are downloaded
|
|
|
146 |
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
147 |
return p.replace("{prompt}", positive), n + " " + negative
|
148 |
|
149 |
+
face_image_path = data.pop("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
|
150 |
+
pose_image_path = data.pop("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
|
151 |
+
style_name = data.pop("style_name", DEFAULT_STYLE_NAME)
|
152 |
+
prompt = data.pop("inputs", "a man flying in the sky in Mars")
|
153 |
+
negative_prompt = data.pop("negative_prompt", "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green")
|
154 |
+
|
155 |
+
identitynet_strength_ratio = data.pop("identitynet_strength_ratio", 0.8)
|
156 |
+
adapter_strength_ratio = data.pop("adapter_strength_ratio", 0.8)
|
157 |
+
pose_strength = data.pop("pose_strength", 0.5)
|
158 |
+
canny_strength = data.pop("canny_strength", 0.3)
|
159 |
+
num_steps = data.pop("num_steps", 20)
|
160 |
+
guidance_scale = data.pop("guidance_scale", 5.0)
|
161 |
+
controlnet_selection = data.pop("controlnet_selection", ["pose", "canny"])
|
162 |
+
scheduler = data.pop("scheduler", "EulerDiscreteScheduler")
|
163 |
+
enable_fast_inference = data.pop("enable_fast_inference", False)
|
164 |
+
enhance_non_face_region = data.pop("enhance_non_face_region", False)
|
165 |
+
seed = data.pop("seed", 42)
|
166 |
+
|
167 |
+
# Ensure required fields are present
|
168 |
+
data.setdefault("prompt", prompt)
|
169 |
+
data.setdefault("style", style_name)
|
170 |
+
data.setdefault("num_steps", num_steps)
|
171 |
+
data.setdefault("enable_LCM", enable_fast_inference)
|
172 |
+
data.setdefault("enhance_face_region", enhance_non_face_region)
|
173 |
+
|
174 |
+
# Enable LCM if fast inference is enabled
|
175 |
+
if enable_fast_inference:
|
176 |
+
self.pipe.enable_lora()
|
177 |
+
else:
|
178 |
+
self.pipe.disable_lora()
|
179 |
|
180 |
+
scheduler_class_name = scheduler.split("-")[0]
|
181 |
|
182 |
+
add_kwargs = {}
|
183 |
+
if len(scheduler.split("-")) > 1:
|
184 |
+
add_kwargs["use_karras_sigmas"] = True
|
185 |
+
if len(scheduler.split("-")) > 2:
|
186 |
+
add_kwargs["algorithm_type"] = "sde-dpmsolver++"
|
187 |
+
scheduler = getattr(diffusers, scheduler_class_name)
|
188 |
+
self.pipe.scheduler = scheduler.from_config(self.pipe.scheduler.config, **add_kwargs)
|
189 |
|
190 |
# apply the style template
|
191 |
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
192 |
|
|
|
|
|
|
|
193 |
face_image = load_image(face_image_path)
|
194 |
face_image = resize_img(face_image, max_side=1024)
|
195 |
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
|
|
246 |
|
247 |
print("Start inference...")
|
248 |
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
249 |
+
print(f"[Debug] Number of Inference Steps: {num_steps}")
|
250 |
|
251 |
self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
252 |
images = self.pipe(
|
|
|
261 |
height=height,
|
262 |
width=width,
|
263 |
generator=generator,
|
|
|
264 |
).images
|
265 |
|
266 |
# Convert the output image to base64
|