Update handler.py
Browse files- handler.py +14 -4
handler.py
CHANGED
@@ -9,7 +9,6 @@ import diffusers
|
|
9 |
from diffusers.utils import load_image
|
10 |
from diffusers.models import ControlNetModel
|
11 |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
12 |
-
from diffusers import StableDiffusionSafetyChecker
|
13 |
from insightface.app import FaceAnalysis
|
14 |
from style_template import styles
|
15 |
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
|
@@ -21,6 +20,7 @@ from huggingface_hub import hf_hub_download
|
|
21 |
import base64
|
22 |
import io
|
23 |
import json
|
|
|
24 |
|
25 |
# global variable
|
26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -109,7 +109,7 @@ class EndpointHandler:
|
|
109 |
pretrained_model_name_or_path,
|
110 |
controlnet=[self.controlnet_identitynet],
|
111 |
torch_dtype=dtype,
|
112 |
-
safety_checker=
|
113 |
feature_extractor=None,
|
114 |
).to(device)
|
115 |
|
@@ -127,6 +127,10 @@ class EndpointHandler:
|
|
127 |
self.pipe.image_proj_model.to("cuda")
|
128 |
self.pipe.unet.to("cuda")
|
129 |
|
|
|
|
|
|
|
|
|
130 |
def __call__(self, data):
|
131 |
|
132 |
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
@@ -172,6 +176,13 @@ class EndpointHandler:
|
|
172 |
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
173 |
return p.replace("{prompt}", positive), n + " " + negative
|
174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
request = GenerateImageRequest(**data)
|
176 |
inputs = request.inputs
|
177 |
negative_prompt = request.negative_prompt
|
@@ -281,10 +292,9 @@ class EndpointHandler:
|
|
281 |
)
|
282 |
|
283 |
images = outputs.images
|
284 |
-
nsfw_detected = outputs.nsfw_content_detected
|
285 |
|
286 |
# Check for NSFW content
|
287 |
-
if
|
288 |
return {"error": "Generated image contains NSFW content and was discarded."}
|
289 |
|
290 |
# Convert the output image to base64
|
|
|
9 |
from diffusers.utils import load_image
|
10 |
from diffusers.models import ControlNetModel
|
11 |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
|
|
12 |
from insightface.app import FaceAnalysis
|
13 |
from style_template import styles
|
14 |
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
|
|
|
20 |
import base64
|
21 |
import io
|
22 |
import json
|
23 |
+
from transformers import CLIPProcessor, CLIPModel
|
24 |
|
25 |
# global variable
|
26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
109 |
pretrained_model_name_or_path,
|
110 |
controlnet=[self.controlnet_identitynet],
|
111 |
torch_dtype=dtype,
|
112 |
+
safety_checker=None, # We will use an external safety checker
|
113 |
feature_extractor=None,
|
114 |
).to(device)
|
115 |
|
|
|
127 |
self.pipe.image_proj_model.to("cuda")
|
128 |
self.pipe.unet.to("cuda")
|
129 |
|
130 |
+
# Load safety checker
|
131 |
+
self.safety_checker = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
132 |
+
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
133 |
+
|
134 |
def __call__(self, data):
|
135 |
|
136 |
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
|
|
176 |
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
177 |
return p.replace("{prompt}", positive), n + " " + negative
|
178 |
|
179 |
+
def is_nsfw(image: Image) -> bool:
|
180 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
181 |
+
outputs = self.safety_checker(**inputs)
|
182 |
+
logits_per_image = outputs.logits_per_image
|
183 |
+
probs = logits_per_image.softmax(dim=1) # We assume the probability for NSFW content is stored in the first position
|
184 |
+
return probs[0, 0] > 0.5 # This threshold may need to be adjusted
|
185 |
+
|
186 |
request = GenerateImageRequest(**data)
|
187 |
inputs = request.inputs
|
188 |
negative_prompt = request.negative_prompt
|
|
|
292 |
)
|
293 |
|
294 |
images = outputs.images
|
|
|
295 |
|
296 |
# Check for NSFW content
|
297 |
+
if is_nsfw(images[0]):
|
298 |
return {"error": "Generated image contains NSFW content and was discarded."}
|
299 |
|
300 |
# Convert the output image to base64
|