Jacobmadwed commited on
Commit
013f209
·
verified ·
1 Parent(s): 852fbd3

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -4
handler.py CHANGED
@@ -9,7 +9,6 @@ import diffusers
9
  from diffusers.utils import load_image
10
  from diffusers.models import ControlNetModel
11
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
12
- from diffusers import StableDiffusionSafetyChecker
13
  from insightface.app import FaceAnalysis
14
  from style_template import styles
15
  from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
@@ -21,6 +20,7 @@ from huggingface_hub import hf_hub_download
21
  import base64
22
  import io
23
  import json
 
24
 
25
  # global variable
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -109,7 +109,7 @@ class EndpointHandler:
109
  pretrained_model_name_or_path,
110
  controlnet=[self.controlnet_identitynet],
111
  torch_dtype=dtype,
112
- safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
113
  feature_extractor=None,
114
  ).to(device)
115
 
@@ -127,6 +127,10 @@ class EndpointHandler:
127
  self.pipe.image_proj_model.to("cuda")
128
  self.pipe.unet.to("cuda")
129
 
 
 
 
 
130
  def __call__(self, data):
131
 
132
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
@@ -172,6 +176,13 @@ class EndpointHandler:
172
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
173
  return p.replace("{prompt}", positive), n + " " + negative
174
 
 
 
 
 
 
 
 
175
  request = GenerateImageRequest(**data)
176
  inputs = request.inputs
177
  negative_prompt = request.negative_prompt
@@ -281,10 +292,9 @@ class EndpointHandler:
281
  )
282
 
283
  images = outputs.images
284
- nsfw_detected = outputs.nsfw_content_detected
285
 
286
  # Check for NSFW content
287
- if nsfw_detected[0]:
288
  return {"error": "Generated image contains NSFW content and was discarded."}
289
 
290
  # Convert the output image to base64
 
9
  from diffusers.utils import load_image
10
  from diffusers.models import ControlNetModel
11
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 
12
  from insightface.app import FaceAnalysis
13
  from style_template import styles
14
  from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
 
20
  import base64
21
  import io
22
  import json
23
+ from transformers import CLIPProcessor, CLIPModel
24
 
25
  # global variable
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
109
  pretrained_model_name_or_path,
110
  controlnet=[self.controlnet_identitynet],
111
  torch_dtype=dtype,
112
+ safety_checker=None, # We will use an external safety checker
113
  feature_extractor=None,
114
  ).to(device)
115
 
 
127
  self.pipe.image_proj_model.to("cuda")
128
  self.pipe.unet.to("cuda")
129
 
130
+ # Load safety checker
131
+ self.safety_checker = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
132
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
133
+
134
  def __call__(self, data):
135
 
136
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
 
176
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
177
  return p.replace("{prompt}", positive), n + " " + negative
178
 
179
+ def is_nsfw(image: Image) -> bool:
180
+ inputs = self.processor(images=image, return_tensors="pt")
181
+ outputs = self.safety_checker(**inputs)
182
+ logits_per_image = outputs.logits_per_image
183
+ probs = logits_per_image.softmax(dim=1) # We assume the probability for NSFW content is stored in the first position
184
+ return probs[0, 0] > 0.5 # This threshold may need to be adjusted
185
+
186
  request = GenerateImageRequest(**data)
187
  inputs = request.inputs
188
  negative_prompt = request.negative_prompt
 
292
  )
293
 
294
  images = outputs.images
 
295
 
296
  # Check for NSFW content
297
+ if is_nsfw(images[0]):
298
  return {"error": "Generated image contains NSFW content and was discarded."}
299
 
300
  # Convert the output image to base64