nhull commited on
Commit
35f0d17
·
verified ·
1 Parent(s): 0257e1e

Add all other models, reorganize sections, add description

Browse files
Files changed (1) hide show
  1. app.py +171 -67
app.py CHANGED
@@ -16,7 +16,60 @@ from tensorflow.keras.models import load_model
16
  from tensorflow.keras.preprocessing.sequence import pad_sequences
17
  import re
18
 
19
- # Load pre-trained models and tokenizers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  models = {
21
  "DistilBERT": {
22
  "tokenizer": DistilBertTokenizerFast.from_pretrained("nhull/distilbert-sentiment-model"),
@@ -37,49 +90,23 @@ models = {
37
  }
38
  }
39
 
40
- # Load logistic regression model and vectorizer
41
  logistic_regression_repo = "nhull/logistic-regression-model"
42
-
43
- # Download and load logistic regression model
44
  log_reg_model_path = hf_hub_download(repo_id=logistic_regression_repo, filename="logistic_regression_model.pkl")
45
  with open(log_reg_model_path, "rb") as model_file:
46
  log_reg_model = pickle.load(model_file)
47
 
48
- # Download and load TF-IDF vectorizer
49
  vectorizer_path = hf_hub_download(repo_id=logistic_regression_repo, filename="tfidf_vectorizer.pkl")
50
  with open(vectorizer_path, "rb") as vectorizer_file:
51
  vectorizer = pickle.load(vectorizer_file)
52
 
53
- # Move HuggingFace models to device (if GPU is available)
54
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
  for model_data in models.values():
56
  if "model" in model_data:
57
  model_data["model"].to(device)
58
 
59
- # Load GRU model and tokenizer
60
- gru_repo_id = "arjahojnik/GRU-sentiment-model"
61
- gru_model_path = hf_hub_download(repo_id=gru_repo_id, filename="best_GRU_tuning_model.h5")
62
- gru_model = load_model(gru_model_path)
63
- gru_tokenizer_path = hf_hub_download(repo_id=gru_repo_id, filename="my_tokenizer.pkl")
64
- with open(gru_tokenizer_path, "rb") as f:
65
- gru_tokenizer = pickle.load(f)
66
-
67
- # Preprocessing function for GRU
68
- def preprocess_text(text):
69
- text = text.lower()
70
- text = re.sub(r"[^a-zA-Z\s]", "", text).strip()
71
- return text
72
-
73
- # GRU prediction function
74
- def predict_with_gru(text):
75
- cleaned = preprocess_text(text)
76
- seq = gru_tokenizer.texts_to_sequences([cleaned])
77
- padded_seq = pad_sequences(seq, maxlen=200) # Ensure maxlen matches the GRU training
78
- probs = gru_model.predict(padded_seq)
79
- predicted_class = np.argmax(probs, axis=1)[0]
80
- return int(predicted_class + 1)
81
-
82
- # Functions for other model predictions
83
  def predict_with_distilbert(text):
84
  tokenizer = models["DistilBERT"]["tokenizer"]
85
  model = models["DistilBERT"]["model"]
@@ -125,18 +152,18 @@ def predict_with_roberta_ordek899(text):
125
  predictions = logits.argmax(axis=-1).cpu().numpy()
126
  return int(predictions[0] + 1)
127
 
128
- # Unified function for sentiment analysis and statistics
129
  def analyze_sentiment_and_statistics(text):
130
  results = {
 
131
  "GRU Model": predict_with_gru(text),
 
 
132
  "DistilBERT": predict_with_distilbert(text),
133
- "Logistic Regression": predict_with_logistic_regression(text),
134
  "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
135
  "TinyBERT": predict_with_tinybert(text),
136
  "RoBERTa": predict_with_roberta_ordek899(text),
137
  }
138
-
139
- # Calculate statistics
140
  scores = list(results.values())
141
  min_score = min(scores)
142
  max_score = max(scores)
@@ -158,12 +185,64 @@ def analyze_sentiment_and_statistics(text):
158
  return results, statistics
159
 
160
  # Gradio Interface
161
- with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
162
- gr.Markdown("# Sentiment Analysis App")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  gr.Markdown(
164
- "This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides basic statistics."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
-
167
  with gr.Row():
168
  with gr.Column():
169
  text_input = gr.Textbox(
@@ -184,7 +263,6 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
184
  interactive=True
185
  )
186
 
187
- # Sync dropdown with text input
188
  def update_textbox(selected_sample):
189
  return selected_sample
190
 
@@ -193,43 +271,68 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
193
  inputs=[sample_dropdown],
194
  outputs=[text_input]
195
  )
 
196
 
 
197
  with gr.Column():
198
- analyze_button = gr.Button("Analyze Sentiment")
 
199
 
200
- with gr.Row():
201
  with gr.Column():
202
- gru_output = gr.Textbox(label="Predicted Sentiment (GRU Model)", interactive=False)
203
- distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
204
- log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
205
- bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
206
- tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
207
- roberta_ordek_output = gr.Textbox(label="Predicted Sentiment (RoBERTa)", interactive=False)
208
-
209
  with gr.Column():
210
- statistics_output = gr.Textbox(label="Statistics (Lowest, Highest, Average)", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- # Button to analyze sentiment and show statistics
213
  def process_input_and_analyze(text_input):
214
  results, statistics = analyze_sentiment_and_statistics(text_input)
215
  if "Message" in statistics:
216
  return (
217
- f"{results['GRU Model']}",
218
- f"{results['DistilBERT']}",
219
- f"{results['Logistic Regression']}",
220
- f"{results['BERT Multilingual (NLP Town)']}",
221
- f"{results['TinyBERT']}",
222
- f"{results['RoBERTa']}",
 
 
223
  f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
224
  )
225
  else:
226
  return (
227
- f"{results['GRU Model']}",
228
- f"{results['DistilBERT']}",
229
- f"{results['Logistic Regression']}",
230
- f"{results['BERT Multilingual (NLP Town)']}",
231
- f"{results['TinyBERT']}",
232
- f"{results['RoBERTa']}",
 
 
233
  f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
234
  )
235
 
@@ -237,15 +340,16 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
237
  process_input_and_analyze,
238
  inputs=[text_input],
239
  outputs=[
240
- gru_output,
241
- distilbert_output,
242
  log_reg_output,
 
 
 
 
243
  bert_output,
244
  tinybert_output,
245
- roberta_ordek_output,
246
- statistics_output
247
  ]
248
  )
249
 
250
- # Launch the app
251
  demo.launch()
 
16
  from tensorflow.keras.preprocessing.sequence import pad_sequences
17
  import re
18
 
19
+ # Load GRU, LSTM, and BiLSTM models and tokenizers
20
+ gru_repo_id = "arjahojnik/GRU-sentiment-model"
21
+ gru_model_path = hf_hub_download(repo_id=gru_repo_id, filename="best_GRU_tuning_model.h5")
22
+ gru_model = load_model(gru_model_path)
23
+ gru_tokenizer_path = hf_hub_download(repo_id=gru_repo_id, filename="my_tokenizer.pkl")
24
+ with open(gru_tokenizer_path, "rb") as f:
25
+ gru_tokenizer = pickle.load(f)
26
+
27
+ lstm_repo_id = "arjahojnik/LSTM-sentiment-model"
28
+ lstm_model_path = hf_hub_download(repo_id=lstm_repo_id, filename="LSTM_model.h5")
29
+ lstm_model = load_model(lstm_model_path)
30
+ lstm_tokenizer_path = hf_hub_download(repo_id=lstm_repo_id, filename="my_tokenizer.pkl")
31
+ with open(lstm_tokenizer_path, "rb") as f:
32
+ lstm_tokenizer = pickle.load(f)
33
+
34
+ bilstm_repo_id = "arjahojnik/BiLSTM-sentiment-model"
35
+ bilstm_model_path = hf_hub_download(repo_id=bilstm_repo_id, filename="BiLSTM_model.h5")
36
+ bilstm_model = load_model(bilstm_model_path)
37
+ bilstm_tokenizer_path = hf_hub_download(repo_id=bilstm_repo_id, filename="my_tokenizer.pkl")
38
+ with open(bilstm_tokenizer_path, "rb") as f:
39
+ bilstm_tokenizer = pickle.load(f)
40
+
41
+ # Preprocessing function for text
42
+ def preprocess_text(text):
43
+ text = text.lower()
44
+ text = re.sub(r"[^a-zA-Z\s]", "", text).strip()
45
+ return text
46
+
47
+ # Prediction functions for GRU, LSTM, and BiLSTM
48
+ def predict_with_gru(text):
49
+ cleaned = preprocess_text(text)
50
+ seq = gru_tokenizer.texts_to_sequences([cleaned])
51
+ padded_seq = pad_sequences(seq, maxlen=200)
52
+ probs = gru_model.predict(padded_seq)
53
+ predicted_class = np.argmax(probs, axis=1)[0]
54
+ return int(predicted_class + 1)
55
+
56
+ def predict_with_lstm(text):
57
+ cleaned = preprocess_text(text)
58
+ seq = lstm_tokenizer.texts_to_sequences([cleaned])
59
+ padded_seq = pad_sequences(seq, maxlen=200)
60
+ probs = lstm_model.predict(padded_seq)
61
+ predicted_class = np.argmax(probs, axis=1)[0]
62
+ return int(predicted_class + 1)
63
+
64
+ def predict_with_bilstm(text):
65
+ cleaned = preprocess_text(text)
66
+ seq = bilstm_tokenizer.texts_to_sequences([cleaned])
67
+ padded_seq = pad_sequences(seq, maxlen=200)
68
+ probs = bilstm_model.predict(padded_seq)
69
+ predicted_class = np.argmax(probs, axis=1)[0]
70
+ return int(predicted_class + 1)
71
+
72
+ # Load other models
73
  models = {
74
  "DistilBERT": {
75
  "tokenizer": DistilBertTokenizerFast.from_pretrained("nhull/distilbert-sentiment-model"),
 
90
  }
91
  }
92
 
93
+ # Logistic regression model and TF-IDF vectorizer
94
  logistic_regression_repo = "nhull/logistic-regression-model"
 
 
95
  log_reg_model_path = hf_hub_download(repo_id=logistic_regression_repo, filename="logistic_regression_model.pkl")
96
  with open(log_reg_model_path, "rb") as model_file:
97
  log_reg_model = pickle.load(model_file)
98
 
 
99
  vectorizer_path = hf_hub_download(repo_id=logistic_regression_repo, filename="tfidf_vectorizer.pkl")
100
  with open(vectorizer_path, "rb") as vectorizer_file:
101
  vectorizer = pickle.load(vectorizer_file)
102
 
103
+ # Move HuggingFace models to device
104
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105
  for model_data in models.values():
106
  if "model" in model_data:
107
  model_data["model"].to(device)
108
 
109
+ # Prediction functions for other models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def predict_with_distilbert(text):
111
  tokenizer = models["DistilBERT"]["tokenizer"]
112
  model = models["DistilBERT"]["model"]
 
152
  predictions = logits.argmax(axis=-1).cpu().numpy()
153
  return int(predictions[0] + 1)
154
 
155
+ # Unified function for analysis
156
  def analyze_sentiment_and_statistics(text):
157
  results = {
158
+ "Logistic Regression": predict_with_logistic_regression(text),
159
  "GRU Model": predict_with_gru(text),
160
+ "LSTM Model": predict_with_lstm(text),
161
+ "BiLSTM Model": predict_with_bilstm(text),
162
  "DistilBERT": predict_with_distilbert(text),
 
163
  "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
164
  "TinyBERT": predict_with_tinybert(text),
165
  "RoBERTa": predict_with_roberta_ordek899(text),
166
  }
 
 
167
  scores = list(results.values())
168
  min_score = min(scores)
169
  max_score = max(scores)
 
185
  return results, statistics
186
 
187
  # Gradio Interface
188
+ with gr.Blocks(
189
+ css="""
190
+ .gradio-container {
191
+ max-width: 900px;
192
+ margin: auto;
193
+ padding: 20px;
194
+ background-color: #1e1e1e; /* Dark background for contrast */
195
+ color: white; /* White text throughout */
196
+ }
197
+ h1 {
198
+ text-align: center;
199
+ font-size: 2.5rem;
200
+ color: white; /* White text for title */
201
+ }
202
+ footer {
203
+ text-align: center;
204
+ margin-top: 20px;
205
+ font-size: 14px;
206
+ color: white; /* White text for footer */
207
+ }
208
+ .gr-button {
209
+ background-color: #4a4a4a; /* Dark gray button background */
210
+ color: white; /* White button text */
211
+ border-radius: 8px; /* Rounded buttons */
212
+ padding: 10px 20px;
213
+ font-weight: bold;
214
+ transition: background-color 0.3s ease;
215
+ }
216
+ .gr-button:hover {
217
+ background-color: #6a6a6a; /* Slightly lighter gray on hover */
218
+ }
219
+ .gr-textbox, .gr-dropdown, .gr-output {
220
+ border: 1px solid #4a4a4a; /* Subtle gray border */
221
+ border-radius: 8px; /* Rounded edges */
222
+ background-color: #2e2e2e; /* Darker gray input background */
223
+ color: white; /* White text for inputs/outputs */
224
+ }
225
+ """
226
+ ) as demo:
227
+ gr.Markdown("# Sentiment Analysis Demo")
228
  gr.Markdown(
229
+ """
230
+ This demo analyzes the sentiment of text inputs (e.g., hotel or restaurant reviews) on a scale from 1 to 5 using various machine learning, deep learning, and transformer-based models.
231
+
232
+ - **Machine Learning**: Logistic Regression with TF-IDF.
233
+ - **Deep Learning**: GRU, LSTM, and BiLSTM models.
234
+ - **Transformers**: DistilBERT, TinyBERT, BERT Multilingual, and RoBERTa.
235
+
236
+ ### Features:
237
+ - Compare predictions across different models.
238
+ - See which model predicts the highest and lowest scores.
239
+ - Get the average sentiment score across all models.
240
+ - Easily test with your own input or select from suggested reviews.
241
+
242
+ Use this app to explore how different models interpret sentiment and compare their outputs!
243
+ """
244
  )
245
+
246
  with gr.Row():
247
  with gr.Column():
248
  text_input = gr.Textbox(
 
263
  interactive=True
264
  )
265
 
 
266
  def update_textbox(selected_sample):
267
  return selected_sample
268
 
 
271
  inputs=[sample_dropdown],
272
  outputs=[text_input]
273
  )
274
+ analyze_button = gr.Button("Analyze Sentiment")
275
 
276
+ with gr.Row():
277
  with gr.Column():
278
+ gr.Markdown("### Machine Learning")
279
+ log_reg_output = gr.Textbox(label="Logistic Regression", interactive=False)
280
 
 
281
  with gr.Column():
282
+ gr.Markdown("### Deep Learning")
283
+ gru_output = gr.Textbox(label="GRU Model", interactive=False)
284
+ lstm_output = gr.Textbox(label="LSTM Model", interactive=False)
285
+ bilstm_output = gr.Textbox(label="BiLSTM Model", interactive=False)
286
+
 
 
287
  with gr.Column():
288
+ gr.Markdown("### Transformers")
289
+ distilbert_output = gr.Textbox(label="DistilBERT", interactive=False)
290
+ bert_output = gr.Textbox(label="BERT Multilingual", interactive=False)
291
+ tinybert_output = gr.Textbox(label="TinyBERT", interactive=False)
292
+ roberta_output = gr.Textbox(label="RoBERTa", interactive=False)
293
+
294
+ with gr.Row():
295
+ with gr.Column():
296
+ gr.Markdown("### Statistics")
297
+ stats_output = gr.Textbox(label="Statistics", interactive=False)
298
+
299
+ # Add footer
300
+ gr.Markdown(
301
+ """
302
+ <footer>
303
+ This demo was built as a part of the NLP course at the University of Zagreb.
304
+ Check out our GitHub repository:
305
+ <a href="https://github.com/FFZG-NLP-2024/TripAdvisor-Sentiment/" target="_blank" style="color: white; text-decoration: underline;">TripAdvisor Sentiment Analysis</a>
306
+ Explore our HuggingFace collection:
307
+ <a href="https://huggingface.co/collections/nhull/nlp-zg-6794604b85fd4216e6470d38" target="_blank" style="color: white; text-decoration: underline;">NLP Zagreb HuggingFace Collection</a>
308
+ </footer>
309
+ """
310
+ )
311
 
 
312
  def process_input_and_analyze(text_input):
313
  results, statistics = analyze_sentiment_and_statistics(text_input)
314
  if "Message" in statistics:
315
  return (
316
+ results["Logistic Regression"],
317
+ results["GRU Model"],
318
+ results["LSTM Model"],
319
+ results["BiLSTM Model"],
320
+ results["DistilBERT"],
321
+ results["BERT Multilingual (NLP Town)"],
322
+ results["TinyBERT"],
323
+ results["RoBERTa"],
324
  f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
325
  )
326
  else:
327
  return (
328
+ results["Logistic Regression"],
329
+ results["GRU Model"],
330
+ results["LSTM Model"],
331
+ results["BiLSTM Model"],
332
+ results["DistilBERT"],
333
+ results["BERT Multilingual (NLP Town)"],
334
+ results["TinyBERT"],
335
+ results["RoBERTa"],
336
  f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
337
  )
338
 
 
340
  process_input_and_analyze,
341
  inputs=[text_input],
342
  outputs=[
 
 
343
  log_reg_output,
344
+ gru_output,
345
+ lstm_output,
346
+ bilstm_output,
347
+ distilbert_output,
348
  bert_output,
349
  tinybert_output,
350
+ roberta_output,
351
+ stats_output
352
  ]
353
  )
354
 
 
355
  demo.launch()