Update handler.py
Browse files- handler.py +40 -40
handler.py
CHANGED
@@ -3,7 +3,8 @@ import torch
|
|
3 |
import numpy as np
|
4 |
import PIL
|
5 |
from PIL import Image
|
6 |
-
from typing import Tuple
|
|
|
7 |
import diffusers
|
8 |
from diffusers.utils import load_image
|
9 |
from diffusers.models import ControlNetModel
|
@@ -26,6 +27,22 @@ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
|
|
26 |
STYLE_NAMES = list(styles.keys())
|
27 |
DEFAULT_STYLE_NAME = "Spring Festival"
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
class EndpointHandler:
|
30 |
def __init__(self, model_dir):
|
31 |
# Ensure the necessary files are downloaded
|
@@ -146,50 +163,33 @@ class EndpointHandler:
|
|
146 |
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
147 |
return p.replace("{prompt}", positive), n + " " + negative
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
enable_fast_inference = data.pop("enable_fast_inference", False)
|
164 |
-
enhance_non_face_region = data.pop("enhance_non_face_region", False)
|
165 |
-
seed = data.pop("seed", 42)
|
166 |
-
|
167 |
-
# Ensure required fields are present
|
168 |
-
data.setdefault("prompt", prompt)
|
169 |
-
data.setdefault("style", style_name)
|
170 |
-
data.setdefault("num_steps", num_steps)
|
171 |
-
data.setdefault("enable_LCM", enable_fast_inference)
|
172 |
-
data.setdefault("enhance_face_region", enhance_non_face_region)
|
173 |
-
|
174 |
-
# Enable LCM if fast inference is enabled
|
175 |
-
if enable_fast_inference:
|
176 |
-
self.pipe.enable_lora()
|
177 |
-
else:
|
178 |
-
self.pipe.disable_lora()
|
179 |
|
180 |
-
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
if len(scheduler.split("-")) > 2:
|
186 |
-
add_kwargs["algorithm_type"] = "sde-dpmsolver++"
|
187 |
-
scheduler = getattr(diffusers, scheduler_class_name)
|
188 |
-
self.pipe.scheduler = scheduler.from_config(self.pipe.scheduler.config, **add_kwargs)
|
189 |
|
190 |
# apply the style template
|
191 |
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
192 |
|
|
|
|
|
|
|
193 |
face_image = load_image(face_image_path)
|
194 |
face_image = resize_img(face_image, max_side=1024)
|
195 |
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
@@ -246,7 +246,6 @@ class EndpointHandler:
|
|
246 |
|
247 |
print("Start inference...")
|
248 |
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
249 |
-
print(f"[Debug] Number of Inference Steps: {num_steps}")
|
250 |
|
251 |
self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
252 |
images = self.pipe(
|
@@ -261,6 +260,7 @@ class EndpointHandler:
|
|
261 |
height=height,
|
262 |
width=width,
|
263 |
generator=generator,
|
|
|
264 |
).images
|
265 |
|
266 |
# Convert the output image to base64
|
|
|
3 |
import numpy as np
|
4 |
import PIL
|
5 |
from PIL import Image
|
6 |
+
from typing import Tuple, List
|
7 |
+
from pydantic import BaseModel
|
8 |
import diffusers
|
9 |
from diffusers.utils import load_image
|
10 |
from diffusers.models import ControlNetModel
|
|
|
27 |
STYLE_NAMES = list(styles.keys())
|
28 |
DEFAULT_STYLE_NAME = "Spring Festival"
|
29 |
|
30 |
+
class GenerateImageRequest(BaseModel):
|
31 |
+
prompt: str
|
32 |
+
negative_prompt: str
|
33 |
+
style: str
|
34 |
+
num_steps: int
|
35 |
+
identitynet_strength_ratio: float
|
36 |
+
adapter_strength_ratio: float
|
37 |
+
pose_strength: float
|
38 |
+
canny_strength: float
|
39 |
+
depth_strength: float
|
40 |
+
controlnet_selection: List[str]
|
41 |
+
guidance_scale: float
|
42 |
+
seed: int
|
43 |
+
enable_LCM: bool
|
44 |
+
enhance_face_region: bool
|
45 |
+
|
46 |
class EndpointHandler:
|
47 |
def __init__(self, model_dir):
|
48 |
# Ensure the necessary files are downloaded
|
|
|
163 |
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
164 |
return p.replace("{prompt}", positive), n + " " + negative
|
165 |
|
166 |
+
request = GenerateImageRequest(**data)
|
167 |
+
prompt = request.prompt
|
168 |
+
negative_prompt = request.negative_prompt
|
169 |
+
style_name = request.style
|
170 |
+
identitynet_strength_ratio = request.identitynet_strength_ratio
|
171 |
+
adapter_strength_ratio = request.adapter_strength_ratio
|
172 |
+
pose_strength = request.pose_strength
|
173 |
+
canny_strength = request.canny_strength
|
174 |
+
num_steps = request.num_steps
|
175 |
+
guidance_scale = request.guidance_scale
|
176 |
+
controlnet_selection = request.controlnet_selection
|
177 |
+
seed = request.seed
|
178 |
+
enhance_face_region = request.enhance_face_region
|
179 |
+
enable_LCM = request.enable_LCM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
+
self.pipe.disable_lora() if not enable_LCM else self.pipe.enable_lora()
|
182 |
|
183 |
+
scheduler_class_name = "EulerDiscreteScheduler"
|
184 |
+
|
185 |
+
self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
|
|
|
|
|
|
|
|
|
186 |
|
187 |
# apply the style template
|
188 |
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
189 |
|
190 |
+
face_image_path = data.get("face_image_path", "https://i.ibb.co/GQzm527/examples-musk-resize.jpg")
|
191 |
+
pose_image_path = data.get("pose_image_path", "https://i.ibb.co/TRCK4MS/examples-poses-pose2.jpg")
|
192 |
+
|
193 |
face_image = load_image(face_image_path)
|
194 |
face_image = resize_img(face_image, max_side=1024)
|
195 |
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
|
|
246 |
|
247 |
print("Start inference...")
|
248 |
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
|
|
249 |
|
250 |
self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
251 |
images = self.pipe(
|
|
|
260 |
height=height,
|
261 |
width=width,
|
262 |
generator=generator,
|
263 |
+
enhance_face_region=enhance_face_region
|
264 |
).images
|
265 |
|
266 |
# Convert the output image to base64
|