Update app.py
app.py CHANGED
@@ -2,6 +2,9 @@ import gradio as gr
 import torch
 import spaces
 from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.image_processor import VaeImageProcessor
+from transformers import CLIPImageProcessor
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 from PIL import Image
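Note: the checker class is imported from diffusers' Stable Diffusion 1.x pipeline module even though this app drives an SDXL pipeline, and its CLIP preprocessing comes from transformers. A quick smoke test of the import paths (valid on recent diffusers/transformers releases; the versions are an assumption, this commit pins nothing):

import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPImageProcessor

# The safety checker is an ordinary torch model wrapping a CLIP vision tower.
print(issubclass(StableDiffusionSafetyChecker, torch.nn.Module))  # True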
@@ -27,6 +30,11 @@ unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0])))
 pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
+# Safety checker.
+safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
+feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+image_processor = VaeImageProcessor(vae_scale_factor=8)
+
 with open("filter.txt") as f:
     filter_words = {word for word in f.read().split("\n") if word}
 
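Note: these three objects rebuild the safety stack that StableDiffusionPipeline normally bundles, since StableDiffusionXLPipeline ships without a safety checker. A minimal sketch of how they combine on a single PIL image (the helper name and the numpy conversion are illustrative, not part of this commit):

import numpy as np

def is_nsfw(pil_image):
    # CLIP-preprocess the frame the same way generate() does below.
    clip_input = feature_extractor([pil_image], return_tensors="pt")
    # Give the checker a float array in [0, 1] of shape (batch, H, W, 3);
    # on array input it also blacks out any flagged entries in place.
    np_image = np.asarray(pil_image, dtype=np.float32)[None] / 255.0
    _, has_nsfw = safety_checker(
        images=np_image, clip_input=clip_input.pixel_values.to(device, dtype)
    )
    return bool(has_nsfw[0])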
@@ -37,7 +45,7 @@ def generate(prompt, option, progress=gr.Progress()):
     print(prompt, option)
     ckpt, step = opts[option]
     if any(word in prompt for word in filter_words):
-        gr.Warning("Safety checker triggered.")
+        gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
         print(f"Safety checker triggered on prompt: {prompt}")
         return Image.new("RGB", (512, 512))
     progress((0, step))
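Note: the prompt filter above is a plain substring test, so every entry in filter.txt matches anywhere inside the prompt, including inside harmless words. A quick illustration (the filter entry here is hypothetical):

filter_words = {"rape"}  # hypothetical entry, for illustration only
prompt = "a grapefruit on a wooden table"
print(any(word in prompt for word in filter_words))  # True: matches inside "grapefruit"

Matching on word boundaries instead (e.g. checking against prompt.split()) would avoid this kind of false positive, at the cost of missing spaced or punctuated variants.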
@@ -49,17 +57,18 @@ def generate(prompt, option, progress=gr.Progress()):
     def inference_callback(p, i, t, kwargs):
         progress((i+1, step))
         return kwargs
-    results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback)
-
-
-
-
+    results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pil")
+
+    # Safety check
+    feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
+    safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
+    images, has_nsfw_concept = safety_checker(
+        images=results.images, clip_input=safety_checker_input.pixel_values.to(device, dtype)
+    )
-    if
-        gr.Warning("Safety checker triggered.")
+    if has_nsfw_concept[0]:
+        gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
         print(f"Safety checker triggered on prompt: {prompt}")
-
-    return results.images[0]
+    return images[0]
 
 with gr.Blocks(css="style.css") as demo:
     gr.HTML(
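Note on the surrounding context: callback_on_step_end is diffusers' per-step hook; it is invoked as f(pipeline, step_index, timestep, callback_kwargs) and must return the kwargs dict, which is exactly what inference_callback does to drive the progress bar. A standalone sketch of the same hook used for logging (the function name and prompt are placeholders):

def log_step(pipeline, step_index, timestep, callback_kwargs):
    # Called once per denoising step; diffusers expects the kwargs back.
    print(f"finished step {step_index + 1} (t={int(timestep)})")
    return callback_kwargs

# image = pipe("an astronaut riding a horse", num_inference_steps=4,
#              guidance_scale=0, callback_on_step_end=log_step).images[0]

guidance_scale=0 keeps classifier-free guidance disabled, the intended setting for the few-step distilled checkpoints this demo loads, and generate() now returns images[0] from the checker's output rather than the raw results.images[0].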
|