Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,6 @@ from huggingface_hub import hf_hub_download
|
|
9 |
import torch
|
10 |
import pickle
|
11 |
import numpy as np
|
12 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
13 |
|
14 |
# Load models and tokenizers
|
15 |
models = {
|
@@ -76,13 +75,12 @@ def predict_with_bert_multilingual(text):
|
|
76 |
def predict_with_tinybert(text):
|
77 |
tokenizer = models["TinyBERT"]["tokenizer"]
|
78 |
model = models["TinyBERT"]["model"]
|
79 |
-
encodings = tokenizer([text], padding=True, truncation=True, max_length=
|
80 |
with torch.no_grad():
|
81 |
outputs = model(**encodings)
|
82 |
logits = outputs.logits
|
83 |
predictions = logits.argmax(axis=-1).cpu().numpy()
|
84 |
-
return int(predictions[0])
|
85 |
-
|
86 |
|
87 |
# Unified function for sentiment analysis and statistics
|
88 |
def analyze_sentiment_and_statistics(text):
|
@@ -95,22 +93,28 @@ def analyze_sentiment_and_statistics(text):
|
|
95 |
|
96 |
# Calculate statistics
|
97 |
scores = list(results.values())
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
return results, statistics
|
108 |
|
109 |
# Gradio Interface
|
110 |
with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
|
111 |
gr.Markdown("# Sentiment Analysis App")
|
112 |
gr.Markdown(
|
113 |
-
"This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides
|
114 |
)
|
115 |
|
116 |
with gr.Row():
|
@@ -150,7 +154,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
|
|
150 |
with gr.Column():
|
151 |
distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
|
152 |
log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
|
153 |
-
bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
|
154 |
tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
|
155 |
|
156 |
with gr.Column():
|
@@ -159,13 +163,22 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
|
|
159 |
# Button to analyze sentiment and show statistics
|
160 |
def process_input_and_analyze(text_input):
|
161 |
results, statistics = analyze_sentiment_and_statistics(text_input)
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
analyze_button.click(
|
171 |
process_input_and_analyze,
|
@@ -173,7 +186,5 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
|
|
173 |
outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
|
174 |
)
|
175 |
|
176 |
-
|
177 |
-
|
178 |
# Launch the app
|
179 |
demo.launch()
|
|
|
9 |
import torch
|
10 |
import pickle
|
11 |
import numpy as np
|
|
|
12 |
|
13 |
# Load models and tokenizers
|
14 |
models = {
|
|
|
75 |
def predict_with_tinybert(text):
|
76 |
tokenizer = models["TinyBERT"]["tokenizer"]
|
77 |
model = models["TinyBERT"]["model"]
|
78 |
+
encodings = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
|
79 |
with torch.no_grad():
|
80 |
outputs = model(**encodings)
|
81 |
logits = outputs.logits
|
82 |
predictions = logits.argmax(axis=-1).cpu().numpy()
|
83 |
+
return int(predictions[0] + 1)
|
|
|
84 |
|
85 |
# Unified function for sentiment analysis and statistics
|
86 |
def analyze_sentiment_and_statistics(text):
|
|
|
93 |
|
94 |
# Calculate statistics
|
95 |
scores = list(results.values())
|
96 |
+
if all(score == scores[0] for score in scores): # Check if all predictions are the same
|
97 |
+
statistics = {
|
98 |
+
"Message": "All models predict the same score.",
|
99 |
+
"Average Score": f"{scores[0]:.2f}",
|
100 |
+
}
|
101 |
+
else:
|
102 |
+
min_score_model = min(results, key=results.get)
|
103 |
+
max_score_model = max(results, key=results.get)
|
104 |
+
average_score = np.mean(scores)
|
105 |
+
|
106 |
+
statistics = {
|
107 |
+
"Lowest Score": f"{results[min_score_model]} (Model: {min_score_model})",
|
108 |
+
"Highest Score": f"{results[max_score_model]} (Model: {max_score_model})",
|
109 |
+
"Average Score": f"{average_score:.2f}",
|
110 |
+
}
|
111 |
return results, statistics
|
112 |
|
113 |
# Gradio Interface
|
114 |
with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
|
115 |
gr.Markdown("# Sentiment Analysis App")
|
116 |
gr.Markdown(
|
117 |
+
"This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides basic statistics."
|
118 |
)
|
119 |
|
120 |
with gr.Row():
|
|
|
154 |
with gr.Column():
|
155 |
distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
|
156 |
log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
|
157 |
+
bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
|
158 |
tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
|
159 |
|
160 |
with gr.Column():
|
|
|
163 |
# Button to analyze sentiment and show statistics
|
164 |
def process_input_and_analyze(text_input):
|
165 |
results, statistics = analyze_sentiment_and_statistics(text_input)
|
166 |
+
if "Message" in statistics: # All models predicted the same score
|
167 |
+
return (
|
168 |
+
f"{results['DistilBERT']}",
|
169 |
+
f"{results['Logistic Regression']}",
|
170 |
+
f"{results['BERT Multilingual (NLP Town)']}",
|
171 |
+
f"{results['TinyBERT']}",
|
172 |
+
f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
|
173 |
+
)
|
174 |
+
else: # Min and Max scores are present
|
175 |
+
return (
|
176 |
+
f"{results['DistilBERT']}",
|
177 |
+
f"{results['Logistic Regression']}",
|
178 |
+
f"{results['BERT Multilingual (NLP Town)']}",
|
179 |
+
f"{results['TinyBERT']}",
|
180 |
+
f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
|
181 |
+
)
|
182 |
|
183 |
analyze_button.click(
|
184 |
process_input_and_analyze,
|
|
|
186 |
outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
|
187 |
)
|
188 |
|
|
|
|
|
189 |
# Launch the app
|
190 |
demo.launch()
|