zaidmehdi committed on
Commit 3698678
1 Parent(s): f9a4b3a

refactoring notebook code

Files changed (1)
  1. src/model_training.py +120 -0
src/model_training.py ADDED
@@ -0,0 +1,120 @@
+ from datasets import DatasetDict, Dataset
+ import numpy as np
+ import pandas as pd
+ import torch
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.metrics import accuracy_score, f1_score
+ from transformers import AutoModel, AutoTokenizer
+
+ from .utils import serialize_data, load_data
+
+
+ class PreProcessor:
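+     """Turn the raw TSV splits into serialized feature vectors: tokenize the
+     tweets with the pretrained tokenizer, run them through the encoder, and
+     keep the hidden state of each sequence's first token as its embedding."""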
+     def __init__(self, model_name, train_path: str, test_path: str, output_path: str):
+         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         self.model = AutoModel.from_pretrained(model_name).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.df_train = pd.read_csv(train_path, sep="\t")
+         self.df_test = pd.read_csv(test_path, sep="\t")
+         self.output_path = output_path
+
+     def _get_datasetdict_object(self):
+         mapper = {"#2_tweet": "tweet", "#3_country_label": "label"}
+         columns_to_keep = ["tweet", "label"]
+
+         df_train = self.df_train.rename(columns=mapper)[columns_to_keep]
+         df_test = self.df_test.rename(columns=mapper)[columns_to_keep]
+
+         train_dataset = Dataset.from_pandas(df_train)
+         test_dataset = Dataset.from_pandas(df_test)
+         data = DatasetDict({'train': train_dataset, 'test': test_dataset})
+
+         return data
+
+     def _tokenize(self, batch):
+         return self.tokenizer(batch["tweet"], padding=True)
+
+     def _encode_data(self, data):
+         data_encoded = data.map(self._tokenize, batched=True, batch_size=None)
+         return data_encoded
+
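+     # The encoder's last_hidden_state has shape (batch, seq_len, hidden_dim);
+     # only the first token's vector is kept as the sequence-level feature.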
+     def _extract_hidden_states(self, batch):
+         inputs = {k: v.to(self.device) for k, v in batch.items()
+                   if k in self.tokenizer.model_input_names}
+         with torch.no_grad():
+             last_hidden_state = self.model(**inputs).last_hidden_state
+
+         return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}
+
+     def _get_features(self, data_encoded):
+         data_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
+         data_hidden = data_encoded.map(self._extract_hidden_states, batched=True, batch_size=50)
+         return data_hidden
+
+     def preprocess_data(self):
+         data = self._get_datasetdict_object()
+         data_encoded = self._encode_data(data)
+         data_hidden = self._get_features(data_encoded)
+         serialize_data(data_hidden, output_path=self.output_path)
+
+
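+ # Classifier head: trains a scikit-learn classifier on the serialized
+ # hidden-state features and reports accuracy and F1 on both splits.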
+ class Model:
+     def __init__(self, data_input_path: str, model_name: str):
+         self.model_name = model_name
+         self.model = None
+         self.data = load_data(input_path=data_input_path)
+         self.X_train = np.array(self.data["train"]["hidden_state"])
+         self.X_test = np.array(self.data["test"]["hidden_state"])
+         self.y_train = np.array(self.data["train"]["label"])
+         self.y_test = np.array(self.data["test"]["label"])
+
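+     # Multinomial logistic regression with balanced class weights, fit on the
+     # frozen transformer embeddings.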
+     def _train_logistic_regression(self, X_train, y_train):
+         lr_model = LogisticRegression(multi_class='multinomial',
+                                       class_weight="balanced",
+                                       max_iter=1000,
+                                       random_state=2024)
+         lr_model.fit(X_train, y_train)
+         return lr_model
+
+     def train_model(self, output_path):
+         if self.model_name != "lr":
+             raise ValueError(f"Model name {self.model_name} does not exist. Please try 'lr'!")
+
+         lr_model = self._train_logistic_regression(self.X_train, self.y_train)
+         self.model = lr_model
+         serialize_data(lr_model, output_path)
+
+     def _get_metrics(self, y_true, y_preds):
+         accuracy = accuracy_score(y_true, y_preds)
+         f1_macro = f1_score(y_true, y_preds, average="macro")
+         f1_weighted = f1_score(y_true, y_preds, average="weighted")
+         print(f"Accuracy: {accuracy}")
+         print(f"F1 macro average: {f1_macro}")
+         print(f"F1 weighted average: {f1_weighted}")
+
+     def evaluate_predictions(self):
+         train_preds = self.model.predict(self.X_train)
+         test_preds = self.model.predict(self.X_test)
+
+         print(self.model_name)
+         print("\nTrain set:")
+         self._get_metrics(self.y_train, train_preds)
+         print("-" * 50)
+         print("Test set:")
+         self._get_metrics(self.y_test, test_preds)
+
+
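+ # End-to-end run: featurize the labeled TSV splits, then train and evaluate
+ # the logistic regression classifier.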
+ def main():
+     file_path = "../data/data_hidden.pkl"
+     preprocessor = PreProcessor(model_name="moussaKam/AraBART",
+                                 train_path="../data/DA_train_labeled.tsv",
+                                 test_path="../data/DA_dev_labeled.tsv",
+                                 output_path=file_path)
+     preprocessor.preprocess_data()
+     model = Model(data_input_path=file_path, model_name="lr")
+     model.train_model("../models/logistic_regression.pkl")
+     model.evaluate_predictions()
+
+ if __name__ == "__main__":
+     main()
+
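For reference, serialize_data and load_data are imported from src/utils.py, which is not part of this commit. Below is a minimal pickle-based sketch consistent with how they are called above; this is an assumption for illustration, and the actual implementations may differ.

import pickle


def serialize_data(data, output_path: str):
    # Persist any Python object (a dataset or a fitted model) to disk.
    with open(output_path, "wb") as f:
        pickle.dump(data, f)


def load_data(input_path: str):
    # Load a previously serialized object back into memory.
    with open(input_path, "rb") as f:
        return pickle.load(f)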