katsuchi commited on
Commit
ad23810
·
verified ·
1 Parent(s): ff54c9e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from peft import PeftModel
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ import json
6
+
7
+ # Load model and tokenizer
8
+ base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=6)
9
+ model = PeftModel.from_pretrained(base_model, "katsuchi/bert-dair-ai-emotion")
10
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
11
+
12
+ def predict_emotion(text):
13
+ # Tokenize input
14
+ tokens = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
15
+
16
+ # Get model prediction
17
+ with torch.no_grad():
18
+ outputs = model(tokens['input_ids'])
19
+ probs = torch.softmax(outputs.logits, dim=-1)
20
+
21
+ # Convert probabilities to percentages
22
+ percentages = (probs * 100).squeeze().tolist()
23
+
24
+ # Create emotion-percentage mapping
25
+ emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
26
+ emotion_probs = {
27
+ emotion: f"{percentage:.1f}%"
28
+ for emotion, percentage in zip(emotions, percentages)
29
+ }
30
+
31
+ # Sort by probability in descending order
32
+ sorted_emotions = dict(
33
+ sorted(emotion_probs.items(),
34
+ key=lambda x: float(x[1].rstrip('%')),
35
+ reverse=True)
36
+ )
37
+
38
+ # Format output
39
+ return json.dumps(sorted_emotions, indent=2)
40
+
41
+ # Create Gradio interface
42
+ iface = gr.Interface(
43
+ fn=predict_emotion,
44
+ inputs=gr.Textbox(
45
+ lines=3,
46
+ placeholder="Enter text here..."
47
+ ),
48
+ outputs=gr.JSON(),
49
+ title="Emotion Classifier",
50
+ description="Predict emotions in text with confidence percentages",
51
+ examples=[
52
+ ["I am so happy to see you!"],
53
+ ["I'm really disappointed with the results."],
54
+ ["That's absolutely terrifying!"],
55
+ ["I love spending time with my family."]
56
+ ]
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ iface.launch()