Spaces:
Sleeping
Sleeping
refactoring notebook code
Browse files- src/model_training.py +120 -0
src/model_training.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import DatasetDict, Dataset
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
from sklearn.linear_model import LogisticRegression
|
6 |
+
from sklearn.metrics import accuracy_score, f1_score
|
7 |
+
from transformers import AutoModel, AutoTokenizer
|
8 |
+
|
9 |
+
from .utils import serialize_data, load_data
|
10 |
+
|
11 |
+
|
12 |
+
class PreProcessor:
|
13 |
+
def __init__(self, model_name, train_path:str, test_path:str, output_path:str):
|
14 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
15 |
+
self.model = AutoModel.from_pretrained(model_name).to(self.device)
|
16 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
17 |
+
self.df_train = pd.read_csv(train_path, sep="\t")
|
18 |
+
self.df_test = pd.read_csv(test_path, sep="\t")
|
19 |
+
self.output_path = output_path
|
20 |
+
|
21 |
+
def _get_datasetdict_object(self):
|
22 |
+
mapper = {"#2_tweet": "tweet", "#3_country_label": "label"}
|
23 |
+
columns_to_keep = ["tweet", "label"]
|
24 |
+
|
25 |
+
df_train = self.df_train.rename(columns=mapper)[columns_to_keep]
|
26 |
+
df_test = self.df_test.rename(columns=mapper)[columns_to_keep]
|
27 |
+
|
28 |
+
train_dataset = Dataset.from_pandas(df_train)
|
29 |
+
test_dataset = Dataset.from_pandas(df_test)
|
30 |
+
data = DatasetDict({'train': train_dataset, 'test': test_dataset})
|
31 |
+
|
32 |
+
return data
|
33 |
+
|
34 |
+
def _tokenize(self, batch):
|
35 |
+
return self.tokenizer(batch["tweet"], padding=True)
|
36 |
+
|
37 |
+
def _encode_data(self, data):
|
38 |
+
data_encoded = data.map(self._tokenize, batched=True, batch_size=None)
|
39 |
+
return data_encoded
|
40 |
+
|
41 |
+
def _extract_hidden_states(self, batch):
|
42 |
+
inputs = {k:v.to(self.device) for k,v in batch.items()
|
43 |
+
if k in self.tokenizer.model_input_names}
|
44 |
+
with torch.no_grad():
|
45 |
+
last_hidden_state = self.model(**inputs).last_hidden_state
|
46 |
+
|
47 |
+
return {"hidden_state": last_hidden_state[:,0].cpu().numpy()}
|
48 |
+
|
49 |
+
def _get_features(self, data_encoded):
|
50 |
+
data_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
|
51 |
+
data_hidden = data_encoded.map(self._extract_hidden_states, batched=True, batch_size=50)
|
52 |
+
return data_hidden
|
53 |
+
|
54 |
+
def preprocess_data(self):
|
55 |
+
data = self._get_datasetdict_object()
|
56 |
+
data_encoded = self._encode_data(data)
|
57 |
+
data_hidden = self._get_features(data_encoded)
|
58 |
+
serialize_data(data_hidden, output_path=self.output_path)
|
59 |
+
|
60 |
+
|
61 |
+
class Model():
|
62 |
+
def __init__(self, data_input_path:str, model_name:str):
|
63 |
+
self.model_name = model_name
|
64 |
+
self.model = None
|
65 |
+
self.data = load_data(input_path=data_input_path)
|
66 |
+
self.X_train = np.array(self.data["train"]["hidden_state"])
|
67 |
+
self.X_test = np.array(self.data["test"]["hidden_state"])
|
68 |
+
self.y_train = np.array(self.data["train"]["label"])
|
69 |
+
self.y_test = np.array(self.data["test"]["label"])
|
70 |
+
|
71 |
+
def _train_logistic_regression(X_train, y_train):
|
72 |
+
lr_model = LogisticRegression(multi_class='multinomial',
|
73 |
+
class_weight="balanced",
|
74 |
+
max_iter=1000,
|
75 |
+
random_state=2024)
|
76 |
+
lr_model.fit(X_train, y_train)
|
77 |
+
return lr_model
|
78 |
+
|
79 |
+
def train_model(self, output_path):
|
80 |
+
if self.model_name != "lr":
|
81 |
+
raise ValueError(f"Model name {self.model_name} does not exist. Please try 'lr'!")
|
82 |
+
|
83 |
+
lr_model = self._train_logistic_regression(self.X_train, self.y_train)
|
84 |
+
self.model = lr_model
|
85 |
+
serialize_data(lr_model, output_path)
|
86 |
+
|
87 |
+
def _get_metrics(self, y_true, y_preds):
|
88 |
+
accuracy = accuracy_score(y_true, y_preds)
|
89 |
+
f1_macro = f1_score(y_true, y_preds, average="macro")
|
90 |
+
f1_weighted = f1_score(y_true, y_preds, average="weighted")
|
91 |
+
print(f"Accuracy: {accuracy}")
|
92 |
+
print(f"F1 macro average: {f1_macro}")
|
93 |
+
print(f"F1 weighted average: {f1_weighted}")
|
94 |
+
|
95 |
+
def evaluate_predictions(self):
|
96 |
+
train_preds = self.model.predict(self.X_train)
|
97 |
+
test_preds = self.model.predict(self.X_test)
|
98 |
+
|
99 |
+
print(self.model_name)
|
100 |
+
print("\nTrain set:")
|
101 |
+
self._get_metrics(self.y_train, train_preds)
|
102 |
+
print("-"*50)
|
103 |
+
print("Test set:")
|
104 |
+
self._get_metrics(self.y_test, test_preds)
|
105 |
+
|
106 |
+
|
107 |
+
def main():
|
108 |
+
file_path = "../data/data_hidden.pkl"
|
109 |
+
preprocessor = PreProcessor(model_name="moussaKam/AraBART",
|
110 |
+
train_path="../data/DA_train_labeled.tsv",
|
111 |
+
test_path="../data/DA_dev_labeled.tsv",
|
112 |
+
output_path=file_path)
|
113 |
+
preprocessor.preprocess_data()
|
114 |
+
model = Model(data_input_path=file_path, model_name="lr")
|
115 |
+
model.train_model("../models/logistic_regression.pkl")
|
116 |
+
model.evaluate_predictions()
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
main()
|
120 |
+
|