# ResNet50_replicate / utils.py
# atiwari751's picture
# modularised resnet_execute into individual scripts
# 3352589
# raw
# history blame
# 2.26 kB
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Serialize a training checkpoint to ``path`` using ``torch.save``.

    The saved object is a dict with four entries, in this order:
    ``'epoch'``, ``'model_state_dict'``, ``'optimizer_state_dict'`` and
    ``'loss'``.

    Args:
        model: Object whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        epoch: Value stored unchanged under ``'epoch'``.
        loss: Value stored unchanged under ``'loss'``.
        path: Destination handed straight to ``torch.save``.
    """
    snapshot = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(snapshot, path)
def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    The file is expected to hold a dict with the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        model: Object whose ``load_state_dict`` receives
            ``checkpoint['model_state_dict']``; updated in place.
        optimizer: Object whose ``load_state_dict`` receives
            ``checkpoint['optimizer_state_dict']``; updated in place.
        path: Checkpoint location passed to ``torch.load``.
        map_location: Optional device remapping forwarded to
            ``torch.load`` (for example ``'cpu'``). The default ``None``
            keeps ``torch.load``'s own default behaviour, so existing
            callers are unaffected.

    Returns:
        Tuple ``(model, optimizer, epoch, loss)`` where ``model`` and
        ``optimizer`` are the same objects that were passed in.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    # Forwarding map_location lets a checkpoint written on one device be
    # restored on another (e.g. GPU-trained weights on a CPU-only host).
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch'], checkpoint['loss']
def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
    """Draw a 2x2 figure of training metrics against ``epochs`` and show it.

    Panels, in subplot order 1 to 4:
        1. Top-1 accuracy (train and test curves).
        2. Top-5 accuracy (train and test curves).
        3. Loss (train and test curves).
        4. Learning rate (single curve).

    Args:
        epochs: X-axis values shared by every curve.
        train_acc1: Y values for the 'Train Top-1 Acc' curve.
        test_acc1: Y values for the 'Test Top-1 Acc' curve.
        train_acc5: Y values for the 'Train Top-5 Acc' curve.
        test_acc5: Y values for the 'Test Top-5 Acc' curve.
        train_losses: Y values for the 'Train Loss' curve.
        test_losses: Y values for the 'Test Loss' curve.
        learning_rates: Y values for the 'Learning Rate' curve.
    """
    # Each entry: (panel title, y-axis label, ((series, legend label), ...)).
    panel_specs = (
        ('Top-1 Accuracy', 'Accuracy',
         ((train_acc1, 'Train Top-1 Acc'), (test_acc1, 'Test Top-1 Acc'))),
        ('Top-5 Accuracy', 'Accuracy',
         ((train_acc5, 'Train Top-5 Acc'), (test_acc5, 'Test Top-5 Acc'))),
        ('Loss', 'Loss',
         ((train_losses, 'Train Loss'), (test_losses, 'Test Loss'))),
        ('Learning Rate', 'Learning Rate',
         ((learning_rates, 'Learning Rate'),)),
    )

    plt.figure(figsize=(12, 8))
    for slot, (panel_title, y_label, curves) in enumerate(panel_specs, start=1):
        plt.subplot(2, 2, slot)
        for series, legend_label in curves:
            plt.plot(epochs, series, label=legend_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(panel_title)

    plt.tight_layout()
    plt.show()
def plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes):
    """Show up to 16 misclassified images in a 4-column grid.

    Does nothing when ``misclassified_images`` is ``None`` or empty.
    Otherwise prints a short banner, tiles the first 16 images with
    ``make_grid`` (``normalize=True, scale_each=True``) and displays the
    result with matplotlib.

    Args:
        misclassified_images: Images to display; only the first 16 are used.
            Accepts whatever ``make_grid`` accepts after slicing (a list of
            image tensors or a single batched tensor).
        misclassified_labels: Currently unused by this function.
        misclassified_preds: Currently unused by this function.
        classes: Currently unused by this function.
    """
    # Use len() rather than truth-testing: bool() of a tensor holding more
    # than one element raises, so a batched tensor would crash the guard.
    if misclassified_images is None or len(misclassified_images) == 0:
        return

    print("\nDisplaying some misclassified samples:")
    misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True)
    plt.figure(figsize=(8, 8))
    # matplotlib converts its input to a NumPy array, which is impossible
    # for tensors living on an accelerator; .cpu() is a no-op on CPU tensors.
    # permute reorders the grid's three axes (1, 2, 0) before display.
    plt.imshow(misclassified_grid.cpu().permute(1, 2, 0))
    plt.title("Misclassified Samples")
    plt.axis('off')
    plt.show()