Update handler.py
handler.py  CHANGED  +23 -12
@@ -109,7 +109,7 @@ class EndpointHandler:
             pretrained_model_name_or_path,
             controlnet=[self.controlnet_identitynet],
             torch_dtype=dtype,
-            safety_checker=None,
+            safety_checker=None,
             feature_extractor=None,
         ).to(device)

@@ -127,9 +127,27 @@
         self.pipe.image_proj_model.to("cuda")
         self.pipe.unet.to("cuda")

-        # Load safety
-        self.
-        self.
+        # Load CLIP model for safety checking
+        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+
+    def is_nsfw(self, image: Image.Image) -> bool:
+        """
+        Check if an image contains NSFW content using the CLIP model.
+
+        Args:
+            image (Image.Image): PIL image to check.
+
+        Returns:
+            bool: True if the image is NSFW, False otherwise.
+        """
+        inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        outputs = self.clip_model(**inputs)
+        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        probs = logits_per_image.softmax(dim=1)  # we take the softmax to get the probabilities
+        nsfw_prob = probs[0, 0].item()  # probability of the "NSFW" label
+        return nsfw_prob > 0.5

     def __call__(self, data):

@@ -176,13 +194,6 @@
         p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
         return p.replace("{prompt}", positive), n + " " + negative

-        def is_nsfw(image: Image) -> bool:
-            inputs = self.processor(images=image, return_tensors="pt")
-            outputs = self.safety_checker(**inputs)
-            logits_per_image = outputs.logits_per_image
-            probs = logits_per_image.softmax(dim=1)  # We assume the probability for NSFW content is stored in the first position
-            return probs[0, 1] > 0.5  # This threshold may need to be adjusted
-
         request = GenerateImageRequest(**data)
         inputs = request.inputs
         negative_prompt = request.negative_prompt
@@ -294,7 +305,7 @@
         images = outputs.images

         # Check for NSFW content
-        if is_nsfw(images[0]):
+        if self.is_nsfw(images[0]):
             return {"error": "Generated image contains NSFW content and was discarded."}

         # Convert the output image to base64
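For context, a minimal standalone sketch of the zero-shot CLIP check that the new is_nsfw method performs is shown below, so it can be tried outside the endpoint. It assumes transformers, torch, and Pillow are installed; the model name and the ["NSFW", "SFW"] prompt pair come from the diff, while the threshold argument and the sample.png path are illustrative only.

# Standalone sketch of the CLIP-based zero-shot NSFW check used by the handler.
# Assumption: the 0.5 threshold mirrors the diff and may need tuning for real traffic.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)


def is_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if CLIP scores the "NSFW" prompt above `threshold` for this image."""
    inputs = clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image has shape (1, 2): similarity of the image to each text prompt.
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs[0, 0].item() > threshold  # index 0 corresponds to the "NSFW" prompt


if __name__ == "__main__":
    print(is_nsfw(Image.open("sample.png")))  # hypothetical test image

In the handler itself the same logic runs as self.is_nsfw(images[0]) immediately after generation, returning an error payload instead of the base64-encoded image when the check trips.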