# cdnuts's picture
# Update app.py
# 233aa95 verified
# raw
# history blame
# No virus
# 1.56 kB
import json
import time
from PIL import Image
import torch
from torchvision.transforms import transforms
import gradio as gr
from huggingface_hub import snapshot_download
# Download the contents of the Hub repository into the current working
# directory; the checkpoint loaded below is expected to be among them.
snapshot_download(
    repo_id="Thouph/eva-vit-1b-224-8043",
    local_dir="./",
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point this at a repository you trust.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules; confirm against the pinned torch version.
model = torch.load("model.pth", map_location="cpu")

# Inference only: switch the loaded object to evaluation mode
# (assumes the checkpoint is a pickled nn.Module; confirm).
model.eval()
# Per-channel mean and standard deviation handed to Normalize below.
_NORMALIZE_MEAN = [0.48145466, 0.4578275, 0.40821073]
_NORMALIZE_STD = [0.26862954, 0.26130258, 0.27577711]

# Preprocessing applied to every input image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)
# Load the tag vocabulary from the working directory (presumably one of the
# files fetched by the snapshot download; confirm). The encoding is given
# explicitly so decoding does not depend on the host's locale settings.
with open("tags_8040.json", encoding="utf-8") as tags_file:
    tags = json.load(tags_file)

# Lookup table used to turn model output indices into tag names: the sorted
# vocabulary followed by the three rating labels, in this exact order.
allowed_tags = sorted(tags)
allowed_tags.extend(["explicit", "questionable", "safe"])
def create_tags(image):
    """Return a comma-separated tag string predicted for *image*.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and passed through the module-level ``model``. Every
    output whose sigmoid probability exceeds 0.3 is mapped to its name in
    ``allowed_tags``.

    Args:
        image: An object exposing ``convert("RGB")`` (the interface is
            configured to pass a PIL image), or ``None`` when nothing was
            submitted.

    Returns:
        The selected tag names joined with ``", "``, with underscores
        replaced by spaces and parentheses backslash-escaped. An empty
        string when *image* is ``None`` or no tag passes the threshold.
    """
    # Nothing submitted: report "no tags" instead of raising AttributeError.
    if image is None:
        return ""

    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add a leading batch axis

    with torch.no_grad():
        logits = model(batch)

    # One independent probability per tag for the single image in the batch.
    probabilities = torch.sigmoid(logits[0])
    selected = torch.where(probabilities > 0.3)[0].tolist()

    # Drop the placeholder entry by exact name before joining, so it is
    # removed even when it is the last selected tag and so that other tags
    # merely ending in the same text are left intact.
    names = [
        allowed_tags[index]
        for index in selected
        if allowed_tags[index] != "placeholder1"
    ]

    text = ", ".join(names).replace("_", " ")
    # Escape parentheses (presumably for prompt syntaxes where they carry
    # special meaning; confirm with the consumer of this text).
    return text.replace("(", "\\(").replace(")", "\\)")
# Web UI: a single image input (delivered to the callback as a PIL image)
# and a single text output showing the tag string from create_tags.
image_input = gr.Image(type="pil")
demo = gr.Interface(fn=create_tags, inputs=[image_input], outputs=["text"])

# Start serving the interface as soon as the script runs.
demo.launch()