nhull committed on
Commit
e4627b7
·
verified ·
1 Parent(s): bdb3169

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -24
app.py CHANGED
@@ -9,7 +9,6 @@ from huggingface_hub import hf_hub_download
9
  import torch
10
  import pickle
11
  import numpy as np
12
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
 
14
  # Load models and tokenizers
15
  models = {
@@ -76,13 +75,12 @@ def predict_with_bert_multilingual(text):
76
  def predict_with_tinybert(text):
77
  tokenizer = models["TinyBERT"]["tokenizer"]
78
  model = models["TinyBERT"]["model"]
79
- encodings = tokenizer([text], padding=True, truncation=True, max_length=128, return_tensors="pt").to(device)
80
  with torch.no_grad():
81
  outputs = model(**encodings)
82
  logits = outputs.logits
83
  predictions = logits.argmax(axis=-1).cpu().numpy()
84
- return int(predictions[0])
85
-
86
 
87
  # Unified function for sentiment analysis and statistics
88
  def analyze_sentiment_and_statistics(text):
@@ -95,22 +93,28 @@ def analyze_sentiment_and_statistics(text):
95
 
96
  # Calculate statistics
97
  scores = list(results.values())
98
- min_score_model = min(results, key=results.get)
99
- max_score_model = max(results, key=results.get)
100
- average_score = np.mean(scores)
101
-
102
- statistics = {
103
- "Lowest Score": f"{results[min_score_model]} (Model: {min_score_model})",
104
- "Highest Score": f"{results[max_score_model]} (Model: {max_score_model})",
105
- "Average Score": f"{average_score:.2f}",
106
- }
 
 
 
 
 
 
107
  return results, statistics
108
 
109
  # Gradio Interface
110
  with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
111
  gr.Markdown("# Sentiment Analysis App")
112
  gr.Markdown(
113
- "This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides detailed statistics."
114
  )
115
 
116
  with gr.Row():
@@ -150,7 +154,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
150
  with gr.Column():
151
  distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
152
  log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
153
- bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False),
154
  tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
155
 
156
  with gr.Column():
@@ -159,13 +163,22 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
159
  # Button to analyze sentiment and show statistics
160
  def process_input_and_analyze(text_input):
161
  results, statistics = analyze_sentiment_and_statistics(text_input)
162
- return (
163
- f"{results['DistilBERT']}",
164
- f"{results['Logistic Regression']}",
165
- f"{results['BERT Multilingual (NLP Town)']}",
166
- f"{results['TinyBERT']}",
167
- f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
168
- )
 
 
 
 
 
 
 
 
 
169
 
170
  analyze_button.click(
171
  process_input_and_analyze,
@@ -173,7 +186,5 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
173
  outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
174
  )
175
 
176
-
177
-
178
  # Launch the app
179
  demo.launch()
 
9
  import torch
10
  import pickle
11
  import numpy as np
 
12
 
13
  # Load models and tokenizers
14
  models = {
 
75
  def predict_with_tinybert(text):
76
  tokenizer = models["TinyBERT"]["tokenizer"]
77
  model = models["TinyBERT"]["model"]
78
+ encodings = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
79
  with torch.no_grad():
80
  outputs = model(**encodings)
81
  logits = outputs.logits
82
  predictions = logits.argmax(axis=-1).cpu().numpy()
83
+ return int(predictions[0] + 1)
 
84
 
85
  # Unified function for sentiment analysis and statistics
86
  def analyze_sentiment_and_statistics(text):
 
93
 
94
  # Calculate statistics
95
  scores = list(results.values())
96
+ if all(score == scores[0] for score in scores): # Check if all predictions are the same
97
+ statistics = {
98
+ "Message": "All models predict the same score.",
99
+ "Average Score": f"{scores[0]:.2f}",
100
+ }
101
+ else:
102
+ min_score_model = min(results, key=results.get)
103
+ max_score_model = max(results, key=results.get)
104
+ average_score = np.mean(scores)
105
+
106
+ statistics = {
107
+ "Lowest Score": f"{results[min_score_model]} (Model: {min_score_model})",
108
+ "Highest Score": f"{results[max_score_model]} (Model: {max_score_model})",
109
+ "Average Score": f"{average_score:.2f}",
110
+ }
111
  return results, statistics
112
 
113
  # Gradio Interface
114
  with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
115
  gr.Markdown("# Sentiment Analysis App")
116
  gr.Markdown(
117
+ "This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides basic statistics."
118
  )
119
 
120
  with gr.Row():
 
154
  with gr.Column():
155
  distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
156
  log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
157
+ bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
158
  tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
159
 
160
  with gr.Column():
 
163
  # Button to analyze sentiment and show statistics
164
  def process_input_and_analyze(text_input):
165
  results, statistics = analyze_sentiment_and_statistics(text_input)
166
+ if "Message" in statistics: # All models predicted the same score
167
+ return (
168
+ f"{results['DistilBERT']}",
169
+ f"{results['Logistic Regression']}",
170
+ f"{results['BERT Multilingual (NLP Town)']}",
171
+ f"{results['TinyBERT']}",
172
+ f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
173
+ )
174
+ else: # Min and Max scores are present
175
+ return (
176
+ f"{results['DistilBERT']}",
177
+ f"{results['Logistic Regression']}",
178
+ f"{results['BERT Multilingual (NLP Town)']}",
179
+ f"{results['TinyBERT']}",
180
+ f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
181
+ )
182
 
183
  analyze_button.click(
184
  process_input_and_analyze,
 
186
  outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
187
  )
188
 
 
 
189
  # Launch the app
190
  demo.launch()