import os
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .utils import load_data
# Load model
model_name = "moussaKam/AraBART"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
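# Note: AraBART is loaded here with a freshly initialized 21-way classification head;
# those head weights are random at this point and are replaced below by the fine-tuned checkpoint.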
models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')
if os.path.exists(model_file):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint)
else:
    print(f"Error: {model_file} not found.")
# Load label encoder
encoder_file = os.path.join(models_dir, 'label_encoder.pkl')
label_encoder = load_data(encoder_file)
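# label_encoder is expected to be a fitted scikit-learn LabelEncoder pickled by the
# training pipeline; its inverse_transform maps class indices back to dialect names.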
# Load html
html_dir = os.path.join(os.path.dirname(__file__), "templates")
index_html_path = os.path.join(html_dir, "index.html")
if os.path.exists(index_html_path):
    with open(index_html_path, "r", encoding="utf-8") as html_file:
        index_html = html_file.read()
else:
    index_html = ""  # fall back to an empty header so main() can still run
    print(f"Error: {index_html_path} not found.")
def classify_arabic_dialect(text):
    tokenized_text = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**tokenized_text)
    probabilities = F.softmax(output.logits, dim=1)[0]
    labels = label_encoder.inverse_transform(range(len(probabilities)))
    predictions = {labels[i]: float(probabilities[i]) for i in range(len(probabilities))}
    return predictions
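# The returned dict maps each dialect label to its probability, e.g.
# {"Morocco": 0.41, "Algeria": 0.22, ...} (label names shown here are illustrative;
# the real ones come from label_encoder), which is the format gr.Label renders directly.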
def main():
    with gr.Blocks() as demo:
        gr.HTML(index_html)
        input_text = gr.Textbox(label="Your Arabic Text")
        submit_btn = gr.Button("Submit")
        predictions = gr.Label(num_top_classes=3)
        submit_btn.click(
            fn=classify_arabic_dialect,
            inputs=input_text,
            outputs=predictions)
        gr.Markdown("## Text Examples")
        gr.Examples(
            examples=[
                "الله يعطيك الصحة هاد الطاجين بنين تبارك لله",
                "بصح راك فاهم لازم الزيت",
                "حضرتك بروح زي كدا؟ على طول النهار ده",
            ],
            inputs=input_text)
        gr.HTML("""
            <p style="text-align: center;font-size: large;">
                Check out the <a href="https://github.com/zaidmehdi/arabic-dialect-classifier">Github Repo</a>
            </p>
        """)

    demo.launch(server_port=8080)
if __name__ == "__main__":
    main()