import json
import random
random.seed(999)
import torch
from transformers import Qwen2ForSequenceClassification, AutoTokenizer
import gradio as gr
from datetime import datetime
torch.set_grad_enabled(False)  # inference only; no gradients needed
# Multi-label classification head (9940 tags) on top of Qwen2-0.5B.
model = Qwen2ForSequenceClassification.from_pretrained(
    "Thouph/prompt2tag-qwen2-0.5b-v0.1", num_labels=9940
).to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
with open("tags_9940.json", "r") as file:
allowed_tags = json.load(file)
allowed_tags = sorted(allowed_tags)
def create_tags(prompt, threshold):
    inputs = tokenizer(
        prompt,
        padding="do_not_pad",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    for k in inputs.keys():
        inputs[k] = inputs[k].to("cuda")

    # Forward pass: sigmoid over the multi-label logits gives per-tag probabilities.
    output = model(**inputs).logits
    output = torch.nn.functional.sigmoid(output)

    # Keep every tag whose probability exceeds the threshold.
    indices = torch.where(output > threshold)
    values = output[indices]  # already a 1-D tensor of the selected scores
    indices = indices[1]

    tag_names = []
    tag_score = dict()
    for i in range(indices.size(0)):
        tag = allowed_tags[indices[i]]
        tag_names.append(tag)
        tag_score[tag] = values[i].item()
    text_no_impl = " ".join(tag_names)

    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print(f"{current_datetime}: finished.")
    return text_no_impl, tag_score
demo = gr.Interface(
    create_tags,
    inputs=[
        gr.TextArea(label="Prompt"),
        gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="Threshold"),
    ],
    outputs=[
        gr.Textbox(label="Tag String"),
        gr.Label(label="Tag Predictions", num_top_classes=200),
    ],
    allow_flagging="never",
)
demo.launch()
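
# A minimal local sanity check (a sketch, not part of the original app): call
# create_tags directly with an example prompt and the default 0.40 threshold.
# tag_string, tag_scores = create_tags("a red fox sitting in a snowy forest", 0.40)
# print(tag_string)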