import cv2
import torch
import numpy as np
import PIL
from PIL import Image
from typing import Tuple, List, Optional
from pydantic import BaseModel
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from insightface.app import FaceAnalysis
from style_template import styles
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
from controlnet_aux import OpenposeDetector
import torch.nn.functional as F
from torchvision.transforms import Compose
import os
from huggingface_hub import hf_hub_download
import base64
import io
import json
from transformers import CLIPProcessor, CLIPModel

# global variable
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if "cuda" in device else torch.float32
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Spring Festival"

# Download LCM-LoRA model if not already downloaded
lcm_lora_path = "./checkpoints/pytorch_lora_weights.safetensors"
if not os.path.exists(lcm_lora_path):
    hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir="./checkpoints")

class GenerateImageRequest(BaseModel):
    inputs: str
    negative_prompt: str
    style: str
    num_steps: int
    identitynet_strength_ratio: float
    adapter_strength_ratio: float
    pose_strength: float
    canny_strength: float
    depth_strength: float
    controlnet_selection: List[str]
    guidance_scale: float
    seed: int
    enable_LCM: bool
    enhance_face_region: bool
    face_image_base64: str
    pose_image_base64: Optional[str] = None
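
# Example payload accepted by the endpoint handler below (illustrative values
# only; the base64 fields are placeholders for real encoded images):
#
# {
#     "inputs": "a portrait photo of a person, cinematic lighting",
#     "negative_prompt": "lowres, blurry",
#     "style": "Spring Festival",
#     "num_steps": 30,
#     "identitynet_strength_ratio": 0.8,
#     "adapter_strength_ratio": 0.8,
#     "pose_strength": 0.4,
#     "canny_strength": 0.4,
#     "depth_strength": 0.4,
#     "controlnet_selection": ["pose", "canny"],
#     "guidance_scale": 5.0,
#     "seed": 42,
#     "enable_LCM": false,
#     "enhance_face_region": true,
#     "face_image_base64": "<base64-encoded JPEG/PNG>",
#     "pose_image_base64": null
# }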

class EndpointHandler:
    def __init__(self, model_dir):
        # Ensure the necessary files are downloaded
        controlnet_config = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=os.path.join(model_dir, "checkpoints"))
        controlnet_model = hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=os.path.join(model_dir, "checkpoints"))
        face_adapter = hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=os.path.join(model_dir, "checkpoints"))

        yunet_path = os.path.join(model_dir, "models", "face_detection_yunet_2023mar_int8.onnx")
        if not os.path.exists(yunet_path):
            raise RuntimeError(f"Model path {yunet_path} does not exist.")
        # YuNet is exposed through OpenCV's FaceDetectorYN API; the input size is
        # updated per image in detect_faces, 0.5 is the detection score threshold.
        self.face_detector = cv2.FaceDetectorYN.create(yunet_path, "", (320, 320), 0.5)

        self.app = FaceAnalysis(name='model', root=model_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))
        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

        # Path to InstantID models
        controlnet_path = os.path.join(model_dir, "checkpoints", "ControlNetModel")

        # Load pipeline face ControlNetModel
        self.controlnet_identitynet = ControlNetModel.from_pretrained(
            controlnet_path, torch_dtype=dtype
        )

        # controlnet-pose
        controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
        controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"

        controlnet_pose = ControlNetModel.from_pretrained(
            controlnet_pose_model, torch_dtype=dtype
        ).to(device)
        controlnet_canny = ControlNetModel.from_pretrained(
            controlnet_canny_model, torch_dtype=dtype
        ).to(device)

        def get_canny_image(image, t1=100, t2=200):
            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            edges = cv2.Canny(image, t1, t2)
            return Image.fromarray(edges, "L")

        self.controlnet_map = {
            "pose": controlnet_pose,
            "canny": controlnet_canny
        }

        self.controlnet_map_fn = {
            "pose": openpose,
            "canny": get_canny_image
        }

        pretrained_model_name_or_path = "wangqixun/YamerMIX_v8"

        self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
            pretrained_model_name_or_path,
            controlnet=[self.controlnet_identitynet],
            torch_dtype=dtype,
            safety_checker=None,
            feature_extractor=None,
        ).to(device)

        self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

        # load and disable LCM
        self.pipe.load_lora_weights(lcm_lora_path)
        self.pipe.fuse_lora()
        self.pipe.disable_lora()

        if device == "cuda":
            self.pipe.cuda()
        self.pipe.load_ip_adapter_instantid(face_adapter)
        self.pipe.image_proj_model.to(device)
        self.pipe.unet.to(device)

        # Load CLIP model for safety checking
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

    def is_nsfw(self, image: Image.Image) -> bool:
        """
        Check if an image contains NSFW content using CLIP model.

        Args:
            image (Image.Image): PIL image to check.

        Returns:
            bool: True if the image is NSFW, False otherwise.
        """
        inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.clip_model(**inputs)
        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        probs = logits_per_image.softmax(dim=1)  # we take the softmax to get the probabilities
        nsfw_prob = probs[0, 0].item()  # probability of "NSFW" label
        return nsfw_prob > 0.8  # Adjusted threshold for NSFW detection

    def detect_faces(self, image: np.ndarray):
        """
        Detect faces with the YuNet detector.

        Returns a list of tuples (x, y, x1, y1, kps, face_region) where
        (x, y, x1, y1) is the bounding box, kps holds the five facial
        landmarks (eyes, nose tip, mouth corners) and face_region is the
        cropped face in BGR.
        """
        h, w = image.shape[:2]
        self.face_detector.setInputSize((w, h))
        _, detections = self.face_detector.detect(image)

        faces = []
        if detections is None:
            return faces
        for det in detections:
            # det layout: [x, y, w, h, 5 x (lm_x, lm_y), score]
            x, y, bw, bh = det[:4].astype("int")
            x1, y1 = x + bw, y + bh
            kps = det[4:14].reshape(5, 2)
            face = image[max(y, 0):y1, max(x, 0):x1]
            faces.append((x, y, x1, y1, kps, face))
        return faces

    def __call__(self, data):

        def convert_from_cv2_to_image(img: np.ndarray) -> Image:
            return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        def convert_from_image_to_cv2(img: Image) -> np.ndarray:
            return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

        def resize_img(
            input_image,
            max_side=1280,
            min_side=1024,
            size=None,
            pad_to_max_side=False,
            mode=PIL.Image.BILINEAR,
            base_pixel_number=64,
        ):
            w, h = input_image.size
            if size is not None:
                w_resize_new, h_resize_new = size
            else:
                ratio = min_side / min(h, w)
                w, h = round(ratio * w), round(ratio * h)
                ratio = max_side / max(h, w)
                input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
                w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
                h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
            input_image = input_image.resize([w_resize_new, h_resize_new], mode)

            if pad_to_max_side:
                res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
                offset_x = (max_side - w_resize_new) // 2
                offset_y = (max_side - h_resize_new) // 2
                res[
                    offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
                ] = np.array(input_image)
                input_image = Image.fromarray(res)
            return input_image

        def apply_style(
            style_name: str, positive: str, negative: str = ""
        ) -> Tuple[str, str]:
            p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
            return p.replace("{prompt}", positive), n + " " + negative

        request = GenerateImageRequest(**data)
        inputs = request.inputs
        negative_prompt = request.negative_prompt
        style_name = request.style
        identitynet_strength_ratio = request.identitynet_strength_ratio
        adapter_strength_ratio = request.adapter_strength_ratio
        pose_strength = request.pose_strength
        canny_strength = request.canny_strength
        num_steps = request.num_steps
        guidance_scale = request.guidance_scale
        # Only the pose and canny ControlNets are loaded; ignore any other selections (e.g. depth)
        controlnet_selection = [s for s in request.controlnet_selection if s in self.controlnet_map]
        seed = request.seed
        enhance_face_region = request.enhance_face_region
        enable_LCM = request.enable_LCM

        if enable_LCM:
            self.pipe.enable_lora()
            self.pipe.scheduler = diffusers.LCMScheduler.from_config(self.pipe.scheduler.config)
            guidance_scale = min(max(guidance_scale, 0), 1)
        else:
            self.pipe.disable_lora()
            self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)

        # apply the style template
        inputs, negative_prompt = apply_style(style_name, inputs, negative_prompt)

        # Decode base64 image
        face_image_base64 = data.get("face_image_base64")
        face_image_data = base64.b64decode(face_image_base64)
        face_image = Image.open(io.BytesIO(face_image_data)).convert("RGB")

        pose_image_base64 = data.get("pose_image_base64")
        pose_image = None
        if pose_image_base64:
            pose_image_data = base64.b64decode(pose_image_base64)
            pose_image = Image.open(io.BytesIO(pose_image_data)).convert("RGB")

        face_image = resize_img(face_image, max_side=1024)
        face_image_cv2 = convert_from_image_to_cv2(face_image)
        height, width, _ = face_image_cv2.shape

        # Detect faces using Yunet model
        faces = self.detect_faces(face_image_cv2)
        if not faces:
            return {"error": "No faces detected."}
        
        # Use only the first detected face; YuNet also returns the five facial landmarks
        x, y, x1, y1, face_kps_xy, face_region = faces[0]
        face_kps = draw_kps(face_image, face_kps_xy)

        # Analyze the face using InsightFace
        face_info = self.app.get(face_image_cv2)

        if not face_info:
            return {"error": "Face analysis failed."}

        face_info = face_info[0]  # Assume we are interested in the first face detected
        face_emb = face_info["embedding"]

        img_controlnet = face_image
        if pose_image:
            pose_image = resize_img(pose_image, max_side=1024)
            img_controlnet = pose_image

            pose_image_cv2 = convert_from_image_to_cv2(pose_image)
            faces = self.detect_faces(pose_image_cv2)
            if faces:
                x, y, x1, y1, pose_kps_xy, _ = faces[0]
                face_kps = draw_kps(pose_image, pose_kps_xy)

            width, height = face_kps.size

        control_mask = np.zeros([height, width, 3])
        x1, y1, x2, y2 = int(x), int(y), int(x1), int(y1)
        control_mask[y1:y2, x1:x2] = 255
        control_mask = Image.fromarray(control_mask.astype(np.uint8))

        controlnet_scales = {
            "pose": pose_strength,
            "canny": canny_strength
        }
        self.pipe.controlnet = MultiControlNetModel(
            [self.controlnet_identitynet]
            + [self.controlnet_map[s] for s in controlnet_selection]
        )
        control_scales = [float(identitynet_strength_ratio)] + [
            controlnet_scales[s] for s in controlnet_selection
        ]
        control_images = [face_kps] + [
            self.controlnet_map_fn[s](img_controlnet).resize((width, height))
            for s in controlnet_selection
        ]

        generator = torch.Generator(device=device).manual_seed(seed)

        print("Start inference...")
        print(f"[Debug] Prompt: {inputs}, \n[Debug] Neg Prompt: {negative_prompt}")

        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
        outputs = self.pipe(
            prompt=inputs,
            negative_prompt=negative_prompt,
            image_embeds=face_emb,
            image=control_images,
            control_mask=control_mask,
            controlnet_conditioning_scale=control_scales,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=generator,
            enhance_face_region=enhance_face_region
        )
       
        images = outputs.images

        # Check for NSFW content
        if self.is_nsfw(images[0]):
            return {"error": "Generated image contains NSFW content and was discarded."}

        # Convert the output image to base64
        buffered = io.BytesIO()
        images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"generated_image_base64": img_str}