zaidmehdi commited on
Commit
747f8ea
1 Parent(s): 755fdb5

customizing layout + adding html

Browse files
Files changed (2) hide show
  1. src/main.py +27 -5
  2. src/templates/index.html +10 -0
src/main.py CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoModel, AutoTokenizer
7
  from .utils import extract_hidden_state
8
 
9
 
 
10
  models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
11
  model_file = os.path.join(models_dir, 'logistic_regression.pkl')
12
 
@@ -16,21 +17,42 @@ if os.path.exists(model_file):
16
  else:
17
  print(f"Error: {model_file} not found.")
18
 
 
 
 
 
 
 
 
 
 
 
 
19
  model_name = "moussaKam/AraBART"
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  language_model = AutoModel.from_pretrained(model_name)
22
 
 
23
  def classify_arabic_dialect(text):
24
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
25
  predicted_class = model.predict(text_embeddings)[0]
26
 
27
  return predicted_class
28
 
29
- demo = gr.Interface(
30
- fn=classify_arabic_dialect,
31
- inputs=["text"],
32
- outputs=["text"],
33
- )
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  if __name__ == "__main__":
 
7
  from .utils import extract_hidden_state
8
 
9
 
10
+ # Load model
11
  models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
12
  model_file = os.path.join(models_dir, 'logistic_regression.pkl')
13
 
 
17
  else:
18
  print(f"Error: {model_file} not found.")
19
 
20
+ # Load html
21
+ html_dir = os.path.join(os.path.dirname(__file__), "templates")
22
+ index_html_path = os.path.join(html_dir, "index.html")
23
+
24
+ if os.path.exists(index_html_path):
25
+ with open(index_html_path, "r") as html_file:
26
+ index_html = html_file.read()
27
+ else:
28
+ print(f"Error: {index_html_path} not found.")
29
+
30
+ # Load pre-trained model
31
  model_name = "moussaKam/AraBART"
32
  tokenizer = AutoTokenizer.from_pretrained(model_name)
33
  language_model = AutoModel.from_pretrained(model_name)
34
 
35
+
36
  def classify_arabic_dialect(text):
37
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
38
  predicted_class = model.predict(text_embeddings)[0]
39
 
40
  return predicted_class
41
 
42
+
43
+ with gr.Blocks() as demo:
44
+ gr.HTML(index_html)
45
+ input_text = gr.Textbox(label="Your Arabic Text")
46
+ submit_btn = gr.Button("Submit")
47
+ with gr.Row():
48
+ first_country = gr.Textbox()
49
+ second_country = gr.Textbox()
50
+ third_country = gr.Textbox()
51
+ submit_btn.click(
52
+ fn=classify_arabic_dialect,
53
+ inputs=input_text,
54
+ outputs=[first_country, second_country, third_country])
55
+ gr.HTML("<p>Checkout the Github Repo:</p>")
56
 
57
 
58
  if __name__ == "__main__":
src/templates/index.html ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <title>Arabic Dialect Classifier</title>
5
+ </head>
6
+ <body>
7
+ <h1>Arabic Dialect Classifier</h1>
8
+ <p>Write some arabic text and get which dialect it is from</p>
9
+ </body>
10
+ </html>