zaidmehdi committed
Commit 1964ece
Parent: 1c28db7

print test accuracy

Files changed (2):
  1. src/model_training.py +4 -2
  2. src/utils.py +21 -1
src/model_training.py CHANGED
@@ -4,7 +4,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
-from utils import get_dataset, serialize_data, plot_training_history
+from utils import get_dataset, serialize_data, plot_training_history, get_model_accuracy
 
 
 def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
@@ -77,7 +77,7 @@ def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, pati
             best_valid_loss = valid_loss
             epochs_no_improve = 0
             best_model = model.state_dict()
-            torch.save(best_model, "best_model_checkpoint.pth")
+            torch.save(best_model, "../models/best_model_checkpoint.pth")
         else:
             epochs_no_improve += 1
             if epochs_no_improve == patience:
@@ -121,6 +121,8 @@ def main():
     model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
     plot_training_history(history)
 
+    test_accuracy = get_model_accuracy(model, test_loader)
+    print("The accuracy of the model on the test set is:", test_accuracy)
 
 if __name__ == "__main__":
     main()
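
Note: get_model_accuracy (added to src/utils.py below) returns the dict produced by the evaluate library's "accuracy" metric, so the new print statement emits something like {'accuracy': 0.87} rather than a bare float. A minimal, self-contained sketch of that behaviour; the toy predictions and references here are illustrative, not taken from the repository:

    import evaluate

    # evaluate's "accuracy" metric accumulates batches and returns a dict on compute().
    metric = evaluate.load("accuracy")
    metric.add_batch(predictions=[0, 1, 1, 0], references=[0, 1, 0, 0])
    result = metric.compute()  # -> {'accuracy': 0.75}
    print("The accuracy of the model on the test set is:", result)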
src/utils.py CHANGED
@@ -1,9 +1,11 @@
 import pickle
 
+import evaluate
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import seaborn as sns
+import torch
 from datasets import DatasetDict, Dataset
 from sklearn.metrics import confusion_matrix
 from sklearn.model_selection import train_test_split
@@ -93,4 +95,22 @@ def plot_training_history(history):
     plt.legend()
 
     plt.tight_layout()
-    plt.savefig('../docs/images/training_history.png')
+    plt.savefig('../docs/images/training_history.png')
+
+
+def get_model_accuracy(model, test_loader):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model.to(device)
+
+    metric = evaluate.load("accuracy")
+    model.eval()
+    for batch in test_loader:
+        batch = {k: v.to(device) for k, v in batch.items()}
+        with torch.no_grad():
+            outputs = model(**batch)
+
+        logits = outputs.logits
+        predictions = torch.argmax(logits, dim=-1)
+        metric.add_batch(predictions=predictions, references=batch["labels"])
+
+    return metric.compute()
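
For reference, a minimal smoke test for the new get_model_accuracy helper. The dummy model and toy loader below are invented for illustration and are not part of the repository; they only mimic the two things the helper relies on: batches that are dicts of tensors containing a "labels" key, and a model whose output exposes a .logits attribute (as transformers sequence-classification models do).

    from types import SimpleNamespace

    import torch
    from torch import nn
    from torch.utils.data import DataLoader

    from utils import get_model_accuracy  # the helper added in this commit


    class DummyModel(nn.Module):
        """Stand-in for a transformers sequence-classification model."""

        def __init__(self, num_labels=2):
            super().__init__()
            self.linear = nn.Linear(4, num_labels)

        def forward(self, input_ids=None, labels=None, **kwargs):
            # Return an object exposing .logits, like transformers model outputs.
            return SimpleNamespace(logits=self.linear(input_ids.float()))


    # Toy batches: the default collate_fn stacks these dicts of tensors key by key.
    features = torch.randn(16, 4)
    labels = torch.randint(0, 2, (16,))
    samples = [{"input_ids": f, "labels": l} for f, l in zip(features, labels)]
    test_loader = DataLoader(samples, batch_size=4)

    print(get_model_accuracy(DummyModel(), test_loader))  # e.g. {'accuracy': 0.5}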