Jacobmadwed committed on
Commit 3b27112 · verified · 1 Parent(s): 5f710c2

Update handler.py

Files changed (1)
  1. handler.py +23 -12
handler.py CHANGED
@@ -109,7 +109,7 @@ class EndpointHandler:
             pretrained_model_name_or_path,
             controlnet=[self.controlnet_identitynet],
             torch_dtype=dtype,
-            safety_checker=None,  # We will use an external safety checker
+            safety_checker=None,
             feature_extractor=None,
         ).to(device)
 
@@ -127,9 +127,27 @@ class EndpointHandler:
         self.pipe.image_proj_model.to("cuda")
         self.pipe.unet.to("cuda")
 
-        # Load safety checker
-        self.safety_checker = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        # Load CLIP model for safety checking
+        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+
+    def is_nsfw(self, image: Image.Image) -> bool:
+        """
+        Check if an image contains NSFW content using CLIP model.
+
+        Args:
+            image (Image.Image): PIL image to check.
+
+        Returns:
+            bool: True if the image is NSFW, False otherwise.
+        """
+        inputs = self.clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        outputs = self.clip_model(**inputs)
+        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        probs = logits_per_image.softmax(dim=1)  # we take the softmax to get the probabilities
+        nsfw_prob = probs[0, 0].item()  # probability of "NSFW" label
+        return nsfw_prob > 0.5
 
     def __call__(self, data):
 
@@ -176,13 +194,6 @@ class EndpointHandler:
             p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
             return p.replace("{prompt}", positive), n + " " + negative
 
-        def is_nsfw(image: Image) -> bool:
-            inputs = self.processor(images=image, return_tensors="pt")
-            outputs = self.safety_checker(**inputs)
-            logits_per_image = outputs.logits_per_image
-            probs = logits_per_image.softmax(dim=1)  # We assume the probability for NSFW content is stored in the first position
-            return probs[0, 1] > 0.5  # This threshold may need to be adjusted
-
         request = GenerateImageRequest(**data)
         inputs = request.inputs
         negative_prompt = request.negative_prompt
@@ -294,7 +305,7 @@ class EndpointHandler:
         images = outputs.images
 
         # Check for NSFW content
-        if is_nsfw(images[0]):
+        if self.is_nsfw(images[0]):
             return {"error": "Generated image contains NSFW content and was discarded."}
 
         # Convert the output image to base64
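
For context, the check this commit introduces is a zero-shot CLIP classification between the two text labels "NSFW" and "SFW". The sketch below mirrors the new `is_nsfw` method outside the handler; the model name and labels come from the diff, while the `threshold` argument and the "example.png" path are illustrative assumptions, not part of the commit.

```python
# Minimal standalone sketch of the CLIP-based NSFW gate added in this commit.
# Assumes `torch`, `transformers`, and `Pillow` are installed; "example.png" is a placeholder path.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)


def is_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Zero-shot check: score the image against the text labels "NSFW" and "SFW"."""
    inputs = clip_processor(text=["NSFW", "SFW"], images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = clip_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)  # shape (1, 2): probabilities for ["NSFW", "SFW"]
    return probs[0, 0].item() > threshold  # index 0 corresponds to the "NSFW" label


if __name__ == "__main__":
    image = Image.open("example.png").convert("RGB")  # placeholder input image
    print("NSFW" if is_nsfw(image) else "SFW")
```

As the comment removed in this diff already noted, a fixed 0.5 cut-off on two generic labels is a coarse filter; the threshold may need tuning, or a dedicated safety-checker model may be preferable in production.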