Thouph commited on
Commit
27757e7
1 Parent(s): 1cee72b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ random.seed(999)
4
+ import torch
5
+ from transformers import Qwen2ForSequenceClassification, AutoTokenizer
6
+ import gradio as gr
7
+ from datetime import datetime
8
+ torch.set_grad_enabled(False)
9
+
10
+ model = Qwen2ForSequenceClassification.from_pretrained("Thouph/prompt2tag-qwen2-0.5b-v0.1", num_labels = 9940).to("cuda")
11
+ model.eval()
12
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
13
+
14
+ with open("tags_9940.json", "r") as file:
15
+ allowed_tags = json.load(file)
16
+
17
+ allowed_tags = sorted(allowed_tags)
18
+
19
+ def create_tags(prompt, threshold):
20
+ inputs = tokenizer(
21
+ prompt,
22
+ padding="do_not_pad",
23
+ max_length=512,
24
+ truncation=True,
25
+ return_tensors="pt",
26
+ )
27
+
28
+ for k in inputs.keys():
29
+ inputs[k] = inputs[k].to("cuda")
30
+ # Generate
31
+ output = model(**inputs).logits
32
+ output = torch.nn.functional.sigmoid(output)
33
+ indices = torch.where(output > threshold)
34
+ values = output[indices]
35
+ indices = indices[1]
36
+ values = values.squeeze()
37
+
38
+ temp = []
39
+ tag_score = dict()
40
+ for i in range(indices.size(0)):
41
+ temp.append([allowed_tags[indices[i]], values[i].item()])
42
+ tag_score[allowed_tags[indices[i]]] = values[i].item()
43
+ temp = [t[0] for t in temp]
44
+ text_no_impl = " ".join(temp)
45
+ current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
46
+ print(f"{current_datetime}: finished.")
47
+ return text_no_impl, tag_score
48
+
49
+ demo = gr.Interface(
50
+ create_tags,
51
+ inputs=[
52
+ gr.TextArea(label="Prompt",),
53
+ gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="Threshold")
54
+ ],
55
+ outputs=[
56
+ gr.Textbox(label="Tag String"),
57
+ gr.Label(label="Tag Predictions", num_top_classes=200),
58
+ ],
59
+ allow_flagging="never",
60
+ )
61
+
62
+ demo.launch()