File size: 2,681 Bytes
ec4a7b0
c1f16ee
7ecfa8c
a025587
 
 
9a23b5c
a025587
99757c1
 
747f8ea
a025587
 
 
ec4a7b0
a025587
 
ec4a7b0
 
9d724cb
 
a025587
ec4a7b0
 
99757c1
a025587
 
 
 
747f8ea
 
 
 
 
 
 
 
 
 
 
7ecfa8c
a025587
 
 
 
2d9e152
9365c1c
2d9e152
c1f16ee
a025587
d93cfa5
 
 
 
 
 
 
 
 
 
 
 
 
22914d3
d93cfa5
22914d3
d93cfa5
 
 
22914d3
d93cfa5
 
 
 
 
47e7f1f
6c4846e
c1f16ee
 
 
d93cfa5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from .utils import load_data


# Load the AraBART backbone with a 21-way sequence-classification head; the
# fine-tuned weights are restored from a local checkpoint below if present.
model_name = "moussaKam/AraBART"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)

models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')

# The model is never moved off the CPU in this module, so the checkpoint is
# mapped onto the CPU too. Mapping it onto CUDA would only leave an extra copy
# of the weights resident in GPU memory while load_state_dict copies them
# into the CPU parameters.
device = torch.device("cpu")
if os.path.exists(model_file):
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint)
    # The weights now live in the model's parameters; release the duplicate
    # instead of keeping it referenced at module level for the process lifetime.
    del checkpoint
else:
    # Execution continues, so predictions would come from the freshly
    # initialised classification head.
    print(f"Error: {model_file} not found.")

# Load label encoder (maps class indices back to dialect labels; used via
# inverse_transform in classify_arabic_dialect). load_data is a project helper
# -- presumably it unpickles the file; confirm in utils.
encoder_file = os.path.join(models_dir, 'label_encoder.pkl')
label_encoder = load_data(encoder_file)

# Load the landing-page HTML rendered at the top of the Gradio UI.
html_dir = os.path.join(os.path.dirname(__file__), "templates")
index_html_path = os.path.join(html_dir, "index.html")

# Default to an empty page so main() does not hit a NameError on `index_html`
# when the template is missing.
index_html = ""
if os.path.exists(index_html_path):
    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252 on
    # Windows) cannot decode arbitrary non-ASCII characters.
    with open(index_html_path, "r", encoding="utf-8") as html_file:
        index_html = html_file.read()
else:
    print(f"Error: {index_html_path} not found.")


def classify_arabic_dialect(text):
    """Score ``text`` against every dialect label known to the label encoder.

    Args:
        text: The input text to classify.

    Returns:
        A dict mapping each label (decoded via ``label_encoder``) to its
        softmax probability as a plain Python float, which is the
        label -> confidence format ``gr.Label`` displays.
    """
    # truncation=True clips inputs longer than the tokenizer's model_max_length
    # instead of letting the forward pass fail on them.
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Put the inputs wherever the model's weights live.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    # Inference only: don't record an autograd graph for every request.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Single input -> take row 0; .tolist() yields plain floats, not tensors.
    probabilities = F.softmax(logits, dim=1)[0].tolist()
    labels = label_encoder.inverse_transform(range(len(probabilities)))
    return dict(zip(labels, probabilities))


def main(server_port=8080):
    """Build the Gradio demo for the dialect classifier and start serving it.

    Args:
        server_port: TCP port handed to ``demo.launch``. Defaults to 8080,
            the port that was previously hard-coded here.
    """
    with gr.Blocks() as demo:
        gr.HTML(index_html)

        input_text = gr.Textbox(label="Your Arabic Text")
        submit_btn = gr.Button("Submit")
        predictions = gr.Label(num_top_classes=3)

        # Classify on button click and also when Enter is pressed in the box.
        submit_btn.click(
            fn=classify_arabic_dialect,
            inputs=input_text,
            outputs=predictions,
        )
        input_text.submit(
            fn=classify_arabic_dialect,
            inputs=input_text,
            outputs=predictions,
        )

        gr.Markdown("## Text Examples")
        gr.Examples(
            examples=[
                "الله يعطيك الصحة هاد الطاجين بنين تبارك لله",
                "بصح راك فاهم لازم الزيت",
                "حضرتك بروح زي كدا؟ على طول النهار ده",
            ],
            inputs=input_text,
        )
        gr.HTML("""
                <p style="text-align: center;font-size: large;">
                Checkout the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
                </p>
                """)

    demo.launch(server_port=server_port)


if __name__ == "__main__":
    # This module uses a relative import (`from .utils import load_data`), so it
    # must be started as part of its package (`python -m <package>.<module>`),
    # not as a plain script path.
    main()