Spaces:
Sleeping
Sleeping
print test accuracy
Browse files- src/model_training.py +4 -2
- src/utils.py +21 -1
src/model_training.py
CHANGED
@@ -4,7 +4,7 @@ from torch.utils.data import DataLoader
|
|
4 |
from tqdm.auto import tqdm
|
5 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
6 |
|
7 |
-
from utils import get_dataset, serialize_data, plot_training_history
|
8 |
|
9 |
|
10 |
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
|
@@ -77,7 +77,7 @@ def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, pati
|
|
77 |
best_valid_loss = valid_loss
|
78 |
epochs_no_improve = 0
|
79 |
best_model = model.state_dict()
|
80 |
-
torch.save(best_model, "best_model_checkpoint.pth")
|
81 |
else:
|
82 |
epochs_no_improve += 1
|
83 |
if epochs_no_improve == patience:
|
@@ -121,6 +121,8 @@ def main():
|
|
121 |
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
|
122 |
plot_training_history(history)
|
123 |
|
|
|
|
|
124 |
|
125 |
if __name__ == "__main__":
|
126 |
main()
|
|
|
4 |
from tqdm.auto import tqdm
|
5 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
6 |
|
7 |
+
from utils import get_dataset, serialize_data, plot_training_history, get_model_accuracy
|
8 |
|
9 |
|
10 |
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
|
|
|
77 |
best_valid_loss = valid_loss
|
78 |
epochs_no_improve = 0
|
79 |
best_model = model.state_dict()
|
80 |
+
torch.save(best_model, "../models/best_model_checkpoint.pth")
|
81 |
else:
|
82 |
epochs_no_improve += 1
|
83 |
if epochs_no_improve == patience:
|
|
|
121 |
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
|
122 |
plot_training_history(history)
|
123 |
|
124 |
+
test_accuracy = get_model_accuracy(model, test_loader)
|
125 |
+
print("The accuracy of the model on the test set is:", test_accuracy)
|
126 |
|
127 |
if __name__ == "__main__":
|
128 |
main()
|
src/utils.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
import pickle
|
2 |
|
|
|
3 |
import matplotlib.pyplot as plt
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import seaborn as sns
|
|
|
7 |
from datasets import DatasetDict, Dataset
|
8 |
from sklearn.metrics import confusion_matrix
|
9 |
from sklearn.model_selection import train_test_split
|
@@ -93,4 +95,22 @@ def plot_training_history(history):
|
|
93 |
plt.legend()
|
94 |
|
95 |
plt.tight_layout()
|
96 |
-
plt.savefig('../docs/images/training_history.png')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pickle
|
2 |
|
3 |
+
import evaluate
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
7 |
import seaborn as sns
|
8 |
+
import torch
|
9 |
from datasets import DatasetDict, Dataset
|
10 |
from sklearn.metrics import confusion_matrix
|
11 |
from sklearn.model_selection import train_test_split
|
|
|
95 |
plt.legend()
|
96 |
|
97 |
plt.tight_layout()
|
98 |
+
plt.savefig('../docs/images/training_history.png')
|
99 |
+
|
100 |
+
|
101 |
+
def get_model_accuracy(model, test_loader):
    """Compute classification accuracy of *model* over *test_loader*.

    Runs the model in eval mode with gradients disabled, takes the argmax of
    ``outputs.logits`` along the last dimension as the predicted class, and
    compares against each batch's ``"labels"`` tensor.

    Args:
        model: a torch module whose forward returns an object with a
            ``.logits`` attribute (e.g. a transformers sequence classifier).
        test_loader: iterable of batch dicts of tensors; each batch must
            contain a ``"labels"`` key, and all tensors are moved to the
            selected device before the forward pass.

    Returns:
        dict: ``{"accuracy": float}`` — fraction of correct predictions
        (0.0 for an empty loader), matching the shape of the result the
        ``evaluate`` "accuracy" metric would produce, but computed locally
        in torch so no remote metric script has to be fetched.
    """
    # Prefer GPU when available; the original moved the model here too.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    model.eval()
    correct = 0
    total = 0
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        # Predicted class = index of the max logit per example.
        predictions = torch.argmax(outputs.logits, dim=-1)
        labels = batch["labels"]
        correct += (predictions == labels).sum().item()
        total += labels.numel()

    # Guard the empty-loader case instead of dividing by zero.
    return {"accuracy": correct / total if total else 0.0}
|