# Uploaded by cdnuts; last commit: "Rename run.py to app.py" (2fffe3f, verified).
# File size: 1.62 kB.
import argparse
import json
import time
from PIL import Image
import torch
from torchvision.transforms import transforms
import gradio as gr
# Command-line interface.
# --image_path is optional. Nothing below reads args.image_path, because the
# Gradio app gets its images from the web UI. With required=True, a plain
# `python app.py` exited with an argparse error before the UI could start.
# The option is still accepted so that invocations passing -i (presumably
# left over from the earlier run.py CLI) keep working.
# parse_known_args ignores extra arguments a launcher may append, instead of
# aborting startup.
parser = argparse.ArgumentParser(description="Image Classification")
parser.add_argument("-i", "--image_path", required=False, default=None, help="Path to the image file")
args = parser.parse_known_args()[0]
# Load the serialized model onto the CPU and put it in evaluation mode
# (affects layers such as dropout / batch-norm, if the network has any).
# NOTE(review): torch.load unpickles arbitrary objects, so model.pth must come
# from a trusted source. The loaded object is called and .eval()'d directly,
# so it appears to be a whole pickled nn.Module rather than a state_dict.
# PyTorch >= 2.6 defaults to weights_only=True, which may refuse such a
# file. Confirm the pinned torch version.
model = torch.load('model.pth', map_location="cpu")
model.eval()
# Per-channel (R, G, B) normalization statistics.
# NOTE(review): these exact values match OpenAI CLIP's image mean/std.
# Presumably the backbone was trained with them, so keep them in sync with
# training.
_CHANNEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
_CHANNEL_STD = [0.26862954, 0.26130258, 0.27577711]

# Preprocessing applied to every input image, in order:
#   1. resize to a fixed 448x448 (aspect ratio is not preserved);
#   2. convert to a float tensor;
#   3. normalize each channel with the statistics above.
transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])
# Label vocabulary. create_tags maps output index i to allowed_tags[i], so
# this order must line up with the model's output layer. The order is:
#   1. a leading placeholder;
#   2. the tags from the JSON file, sorted alphabetically;
#   3. a second placeholder;
#   4. the three rating labels.
# Presumably this mirrors the training setup; confirm before changing it.
# JSON text is UTF-8 by specification. Decode it explicitly as UTF-8 rather
# than with the platform's locale encoding, which can mangle non-ASCII tag
# names (e.g. cp1252 on Windows).
with open("tags_8041.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
allowed_tags = [
    "placeholder0",
    *sorted(tags),
    "placeholder1",
    "explicit",
    "questionable",
    "safe",
]
def create_tags(image, threshold: float = 0.3) -> str:
    """Tag an image and return the predicted tags as one comma-separated string.

    Args:
        image: The picture to tag. May be:
            - a PIL image;
            - an array that ``PIL.Image.fromarray`` accepts (a bare
              ``"image"`` Gradio input delivers a NumPy array);
            - a file-path string;
            - ``None``.
        threshold: A tag is reported only when its sigmoid probability is
            strictly greater than this value. Defaults to 0.3, the value that
            was previously hard-coded.

    Returns:
        The selected tags, ordered from most to least probable (equal
        probabilities keep vocabulary order) and joined with ``", "``.
        Returns an empty string when ``image`` is ``None`` or no tag clears
        ``threshold``.
    """
    # Gradio passes None when the user submits without an image.
    if image is None:
        return ""
    if isinstance(image, str):
        # Context manager releases the file handle as soon as it is decoded.
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")
    else:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        rgb = image.convert("RGB")

    batch = transform(rgb).unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    # One independent sigmoid per tag (multi-label), for the single batch item.
    probabilities = torch.sigmoid(logits[0])
    indices = torch.where(probabilities > threshold)[0]
    # sorted() is stable, so ties stay in ascending index order.
    ranked = sorted(
        zip(indices.tolist(), probabilities[indices].tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    # ", " between tags, no trailing separator.
    return ", ".join(allowed_tags[index] for index, _ in ranked)
# Gradio UI: one image input, one text output holding the predicted tags.
# type="pil" makes Gradio hand create_tags a PIL image. A bare "image" input
# uses gr.Image's default type="numpy", i.e. a NumPy array, which has no
# .convert() method.
demo = gr.Interface(
    fn=create_tags,
    inputs=[gr.Image(type="pil")],
    outputs=["text"],
)
# Start the web app. This sits at module level, not behind a __main__ guard,
# so it also runs when the file is imported.
demo.launch()