{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Arabic Dialect Classifier\n", "This notebook trains the classifier model. The goal is to classify Arabic dialects at the country level." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mehdi/miniconda3/envs/adc/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import pickle\n", "\n", "from datasets import DatasetDict, Dataset\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.preprocessing import LabelEncoder\n", "import torch\n", "from transformers import AutoModel, AutoTokenizer\n", "import xgboost as xgb\n", "\n", "from utils import evaluate_predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Exploring the Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df_train = pd.read_csv(\"../data/DA_train_labeled.tsv\", sep=\"\\t\")\n", "df_test = pd.read_csv(\"../data/DA_dev_labeled.tsv\", sep=\"\\t\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | #1_tweetid | #2_tweet | #3_country_label | #4_province_label |
|---|---|---|---|---|
| 0 | TRAIN_0 | حاجة حلوة اكيد | Egypt | eg_Faiyum |
| 1 | TRAIN_1 | عم بشتغلوا للشعب الاميركي اما نحن يكذبوا ويغشو... | Iraq | iq_Dihok |
| 2 | TRAIN_2 | ابشر طال عمرك | Saudi_Arabia | sa_Ha'il |
| 3 | TRAIN_3 | منطق 2017: أنا والغريب علي إبن عمي وأنا والغري... | Mauritania | mr_Nouakchott |
| 4 | TRAIN_4 | شهرين وتروح والباقي غير صيف ملينا | Algeria | dz_El-Oued |

|  | #1_tweetid | #2_tweet | #3_country_label | #4_province_label |
|---|---|---|---|---|
| 0 | DEV_0 | قولنا اون لاين لا يا علي اون لاين لا | Egypt | eg_Alexandria |
| 1 | DEV_1 | ههههه بايخه ههههه URL … | Oman | om_Muscat |
| 2 | DEV_2 | ربنا يخليك يا دوك ولك المثل :D | Lebanon | lb_South-Lebanon |
| 3 | DEV_3 | #اوامر_ملكيه ياشباب اي واحد فيكم عنده شي يذكره... | Syria | sy_Damascus-City |
| 4 | DEV_4 | شد عالخط حتى هيا اكويسه | Libya | ly_Misrata |