nhull committed
Commit 0257e1e · verified · Parent: 572334d

Add GRU model (fingers crossed it works)

Files changed (1): app.py (+36 −2)
app.py CHANGED
@@ -1,3 +1,6 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Disable GPU and enforce CPU execution
+
 import gradio as gr
 from transformers import (
     DistilBertTokenizerFast,
@@ -9,8 +12,11 @@ from huggingface_hub import hf_hub_download
 import torch
 import pickle
 import numpy as np
+from tensorflow.keras.models import load_model
+from tensorflow.keras.preprocessing.sequence import pad_sequences
+import re

-# Load models and tokenizers
+# Load pre-trained models and tokenizers
 models = {
     "DistilBERT": {
         "tokenizer": DistilBertTokenizerFast.from_pretrained("nhull/distilbert-sentiment-model"),
@@ -50,7 +56,30 @@ for model_data in models.values():
     if "model" in model_data:
         model_data["model"].to(device)

-# Functions for prediction
+# Load GRU model and tokenizer
+gru_repo_id = "arjahojnik/GRU-sentiment-model"
+gru_model_path = hf_hub_download(repo_id=gru_repo_id, filename="best_GRU_tuning_model.h5")
+gru_model = load_model(gru_model_path)
+gru_tokenizer_path = hf_hub_download(repo_id=gru_repo_id, filename="my_tokenizer.pkl")
+with open(gru_tokenizer_path, "rb") as f:
+    gru_tokenizer = pickle.load(f)
+
+# Preprocessing function for GRU
+def preprocess_text(text):
+    text = text.lower()
+    text = re.sub(r"[^a-zA-Z\s]", "", text).strip()
+    return text
+
+# GRU prediction function
+def predict_with_gru(text):
+    cleaned = preprocess_text(text)
+    seq = gru_tokenizer.texts_to_sequences([cleaned])
+    padded_seq = pad_sequences(seq, maxlen=200) # Ensure maxlen matches the GRU training
+    probs = gru_model.predict(padded_seq)
+    predicted_class = np.argmax(probs, axis=1)[0]
+    return int(predicted_class + 1)
+
+# Functions for other model predictions
 def predict_with_distilbert(text):
     tokenizer = models["DistilBERT"]["tokenizer"]
     model = models["DistilBERT"]["model"]
@@ -99,6 +128,7 @@ def predict_with_roberta_ordek899(text):
 # Unified function for sentiment analysis and statistics
 def analyze_sentiment_and_statistics(text):
     results = {
+        "GRU Model": predict_with_gru(text),
         "DistilBERT": predict_with_distilbert(text),
         "Logistic Regression": predict_with_logistic_regression(text),
         "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
@@ -169,6 +199,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:

     with gr.Row():
         with gr.Column():
+            gru_output = gr.Textbox(label="Predicted Sentiment (GRU Model)", interactive=False)
             distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
             log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
             bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
@@ -183,6 +214,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
         results, statistics = analyze_sentiment_and_statistics(text_input)
         if "Message" in statistics:
             return (
+                f"{results['GRU Model']}",
                 f"{results['DistilBERT']}",
                 f"{results['Logistic Regression']}",
                 f"{results['BERT Multilingual (NLP Town)']}",
@@ -192,6 +224,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
             )
         else:
             return (
+                f"{results['GRU Model']}",
                 f"{results['DistilBERT']}",
                 f"{results['Logistic Regression']}",
                 f"{results['BERT Multilingual (NLP Town)']}",
@@ -204,6 +237,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
         process_input_and_analyze,
         inputs=[text_input],
        outputs=[
+            gru_output,
             distilbert_output,
             log_reg_output,
             bert_output,
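
Since the commit message hedges on whether the GRU path actually works, a quick way to exercise it outside the Gradio UI is to replay the new load-and-predict steps on a single review. Below is a minimal sketch, assuming the same Hub artifacts referenced in app.py (best_GRU_tuning_model.h5 and my_tokenizer.pkl from arjahojnik/GRU-sentiment-model); the sample review text is purely illustrative.

# Minimal smoke test for the GRU path added in this commit.
# Assumes the same Hub artifacts as app.py; the sample review is made up.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # keep the CPU-only behaviour from app.py

import pickle
import re
import numpy as np
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences

repo_id = "arjahojnik/GRU-sentiment-model"
gru_model = load_model(hf_hub_download(repo_id=repo_id, filename="best_GRU_tuning_model.h5"))
with open(hf_hub_download(repo_id=repo_id, filename="my_tokenizer.pkl"), "rb") as f:
    gru_tokenizer = pickle.load(f)

text = "The hotel was lovely and the staff were very helpful."  # illustrative input
cleaned = re.sub(r"[^a-zA-Z\s]", "", text.lower()).strip()       # same preprocessing as app.py
padded = pad_sequences(gru_tokenizer.texts_to_sequences([cleaned]), maxlen=200)
probs = gru_model.predict(padded)
print("GRU predicted rating:", int(np.argmax(probs, axis=1)[0]) + 1)  # 1-5 scale

If this prints a rating between 1 and 5, the GRU path is wired correctly; any failure here would also surface in the Space at startup, since app.py loads the GRU model at import time.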