Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from transformers import CLIPFeatureExtractor | |
from safety_checker import StableDiffusionSafetyChecker | |
import torch | |
from PIL import Image | |
import gradio as gr | |
from pathlib import Path | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained safety-checker model, moved once onto the selected device at
# import time so every request reuses it.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
safety_checker = safety_checker.to(device)

# CLIP image preprocessor; its `pixel_values` output is what gets passed to the
# checker as `clip_input`.
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)
import gradio as gr | |
def image_classifier(files):
    """Run the Stable Diffusion safety checker over a batch of uploaded images.

    Args:
        files: Value of the multi-file ``gr.File`` input. Its items are either
            path strings or tempfile-like objects exposing ``.name``; which one
            depends on the Gradio version. ``None`` or an empty list is
            accepted and yields no results.

    Returns:
        ``{"results": [{"has_nsfw": ..., "file": str}, ...]}``, with one entry
        per input file in upload order. ``"file"`` is the file's base name and
        ``"has_nsfw"`` is the checker's per-image flag.
    """
    if not files:
        # Nothing uploaded: there is no batch to preprocess or classify.
        return {"results": []}

    # Normalise to filesystem paths: older Gradio passes tempfile wrappers
    # (which `Path()` rejects), newer Gradio passes plain strings.
    paths = [getattr(file, "name", file) for file in files]

    images = []
    for path in paths:
        # The context manager guarantees the underlying file handle is closed;
        # convert()/resize() return new, fully loaded images.
        with Image.open(path) as img:
            images.append(img.convert("RGB").resize((512, 512)))

    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)

    # Feed the input in the same dtype as the model's weights. The checker is
    # loaded via from_pretrained without torch_dtype (float32 by default), so
    # the previous hard-coded float16 cast mismatched the weights on CPU.
    model_dtype = next(safety_checker.parameters()).dtype

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        has_nsfw_concepts = safety_checker(
            # NOTE(review): this wraps the batch in an extra list. It is kept
            # as-is because the local `safety_checker` module's handling of
            # `images` is not visible here; confirm whether `images=images`
            # was intended.
            images=[images],
            clip_input=safety_checker_input.pixel_values.to(model_dtype),
        )

    results = [
        {"has_nsfw": nsfw, "file": Path(path).name}
        for nsfw, path in zip(has_nsfw_concepts, paths)
    ]
    return {"results": results}
# Single-function Gradio app: a multi-file image upload goes in, and the JSON
# report produced by `image_classifier` comes out. `api_name` sets the name of
# the endpoint under which this function is exposed in the app's API.
demo = gr.Interface(
    fn=image_classifier,
    inputs=gr.File(
        file_count="multiple",
        file_types=["image"],
    ),
    outputs="json",
    title="Stable Diffusion Safety Checker API",
    api_name="classify",
)

# Start the server as soon as the module runs.
demo.launch()