Ubuntu commited on
Commit
41b8141
·
1 Parent(s): d695662

Added checkpoint and early stopping

Browse files
Files changed (2) hide show
  1. checkpoint.py +21 -0
  2. resnet_execute.py +24 -2
checkpoint.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path="checkpoint.pth"):
4
+ checkpoint = {
5
+ 'epoch': epoch,
6
+ 'model_state_dict': model.state_dict(),
7
+ 'optimizer_state_dict': optimizer.state_dict(),
8
+ 'loss': loss
9
+ }
10
+ torch.save(checkpoint, checkpoint_path)
11
+ print(f"Checkpoint saved at epoch {epoch}")
12
+
13
+ def load_checkpoint(model, optimizer, checkpoint_path="checkpoint.pth"):
14
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
15
+ model.load_state_dict(checkpoint['model_state_dict'])
16
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
17
+ epoch = checkpoint['epoch']
18
+ loss = checkpoint['loss']
19
+ print(f"Checkpoint loaded, resuming from epoch {epoch}")
20
+ return model, optimizer, loss
21
+
resnet_execute.py CHANGED
@@ -7,6 +7,7 @@ import torch.optim as optim
7
  from resnet_model import ResNet50
8
  from tqdm import tqdm
9
  from torchvision import datasets
 
10
 
11
  # Define transformations
12
  transform = transforms.Compose([
@@ -89,11 +90,32 @@ def test(model, device, test_loader, criterion):
89
 
90
  test_accuracy = 100.*correct/total
91
  print(f'Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {test_accuracy:.2f}%')
92
- return test_accuracy
93
 
94
  # Main execution
95
  if __name__ == '__main__':
 
 
 
 
 
 
 
 
 
 
 
96
  for epoch in range(1, 6): # 20 epochs
97
  train_accuracy = train(model, device, trainloader, optimizer, criterion, epoch)
98
- test_accuracy = test(model, device, testloader, criterion)
99
  print(f'Epoch {epoch} | Train Accuracy: {train_accuracy:.2f}% | Test Accuracy: {test_accuracy:.2f}%')
 
 
 
 
 
 
 
 
 
 
 
7
  from resnet_model import ResNet50
8
  from tqdm import tqdm
9
  from torchvision import datasets
10
+ from checkpoint import save_checkpoint, load_checkpoint
11
 
12
  # Define transformations
13
  transform = transforms.Compose([
 
90
 
91
  test_accuracy = 100.*correct/total
92
  print(f'Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {test_accuracy:.2f}%')
93
+ return test_accuracy, test_loss/len(test_loader)
94
 
95
  # Main execution
96
  if __name__ == '__main__':
97
+ # Early stopping parameters and checkpoint path
98
+ checkpoint_path = "checkpoint.pth"
99
+ best_loss = float('inf')
100
+ patience = 5
101
+ patience_counter = 0
102
+ # Load checkpoint if it exists to resume training
103
+ try:
104
+ model, optimizer, best_test_accuracy = load_checkpoint(model, optimizer, checkpoint_path)
105
+ except FileNotFoundError:
106
+ print("No checkpoint found, starting from scratch.")
107
+
108
  for epoch in range(1, 6): # 20 epochs
109
  train_accuracy = train(model, device, trainloader, optimizer, criterion, epoch)
110
+ test_accuracy, test_loss = test(model, device, testloader, criterion)
111
  print(f'Epoch {epoch} | Train Accuracy: {train_accuracy:.2f}% | Test Accuracy: {test_accuracy:.2f}%')
112
+ if test_loss < best_loss:
113
+ best_loss = test_loss
114
+ patience_counter = 0
115
+ save_checkpoint(model, optimizer, epoch, test_loss, checkpoint_path)
116
+ else:
117
+ patience_counter += 1
118
+
119
+ if patience_counter >= patience:
120
+ print("Early stopping triggered. Training terminated.")
121
+ break