import torch import matplotlib.pyplot as plt from torchvision.utils import make_grid def save_checkpoint(model, optimizer, epoch, loss, path): torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, path) def load_checkpoint(model, optimizer, path): checkpoint = torch.load(path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] return model, optimizer, epoch, loss def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates): plt.figure(figsize=(12, 8)) plt.subplot(2, 2, 1) plt.plot(epochs, train_acc1, label='Train Top-1 Acc') plt.plot(epochs, test_acc1, label='Test Top-1 Acc') plt.xlabel('Epoch') plt.ylabel('Accuracy') plt.legend() plt.title('Top-1 Accuracy') plt.subplot(2, 2, 2) plt.plot(epochs, train_acc5, label='Train Top-5 Acc') plt.plot(epochs, test_acc5, label='Test Top-5 Acc') plt.xlabel('Epoch') plt.ylabel('Accuracy') plt.legend() plt.title('Top-5 Accuracy') plt.subplot(2, 2, 3) plt.plot(epochs, train_losses, label='Train Loss') plt.plot(epochs, test_losses, label='Test Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.title('Loss') plt.subplot(2, 2, 4) plt.plot(epochs, learning_rates, label='Learning Rate') plt.xlabel('Epoch') plt.ylabel('Learning Rate') plt.legend() plt.title('Learning Rate') plt.tight_layout() plt.show() def plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes): if misclassified_images: print("\nDisplaying some misclassified samples:") misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True) plt.figure(figsize=(8, 8)) plt.imshow(misclassified_grid.permute(1, 2, 0)) plt.title("Misclassified Samples") plt.axis('off') plt.show()