File size: 2,716 Bytes
3352589 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
import torch.nn as nn
import torch.optim as optim
from resnet_model import ResNet50
from data_utils import get_train_transform, get_test_transform, get_data_loaders
from train_test import train, test
from utils import save_checkpoint, load_checkpoint, plot_training_curves, plot_misclassified_samples
def main(num_epochs=25, checkpoint_path="checkpoint.pth", classes=None):
    """Train ResNet50 with per-epoch checkpointing, then plot the results.

    Builds the model, cross-entropy loss and SGD optimizer, and loads the
    train/test loaders via the ``data_utils`` helpers. It resumes from
    ``checkpoint_path`` when ``load_checkpoint`` finds it, then runs
    ``train``/``test`` once per epoch. Afterwards it plots the accuracy, loss
    and learning-rate curves, plus the misclassified samples reported by the
    final ``test`` pass.

    Args:
        num_epochs: Number of the last epoch to run (epochs are 1-based).
            Defaults to 25, matching the previously hard-coded
            ``range(start_epoch, 26)``.
        checkpoint_path: Checkpoint file loaded on start-up (if present) and
            overwritten after every epoch.
        classes: Class names indexed by label, used to annotate the
            misclassified-sample plot. When ``None``, names are taken from
            ``testloader.dataset.classes`` if the dataset exposes that
            attribute. Otherwise the misclassified-sample plot is skipped.
    """
    # Initialize model, loss function, and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet50().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

    # Load data
    train_transform = get_train_transform()
    test_transform = get_test_transform()
    trainloader, testloader = get_data_loaders(train_transform, test_transform)

    # Load checkpoint if it exists
    try:
        # NOTE(review): start_epoch is used below as the first epoch to *run*.
        # save_checkpoint stores the last *completed* epoch. Confirm that
        # load_checkpoint returns saved_epoch + 1; otherwise that epoch is
        # trained twice after a resume.
        model, optimizer, start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_path)
    except FileNotFoundError:
        print("No checkpoint found, starting from scratch.")
        start_epoch = 1

    # One tuple per epoch:
    # (epoch, train_acc1, train_acc5, test_acc1, test_acc5, train_loss, test_loss)
    results = []
    # optimizer.param_groups[0]['lr'] recorded at the end of each epoch
    learning_rates = []

    # Training loop
    for epoch in range(start_epoch, num_epochs + 1):
        train_accuracy1, train_accuracy5, train_loss = train(
            model, device, trainloader, optimizer, criterion, epoch
        )
        # Only the misclassified samples from the last epoch survive the loop.
        (test_accuracy1, test_accuracy5, test_loss,
         misclassified_images, misclassified_labels, misclassified_preds) = test(
            model, device, testloader, criterion
        )
        print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Test Top-1 Acc: {test_accuracy1:.2f}')

        results.append((epoch, train_accuracy1, train_accuracy5, test_accuracy1, test_accuracy5, train_loss, test_loss))
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # Save checkpoint after every epoch so an interrupted run can resume.
        save_checkpoint(model, optimizer, epoch, test_loss, checkpoint_path)

    # A resumed run whose start_epoch is already past num_epochs trains nothing.
    # Without this guard the misclassified_* names would be unbound below.
    if not results:
        print(f"No epochs to run (start epoch {start_epoch} > num_epochs {num_epochs}); nothing to plot.")
        return

    # Transpose the per-epoch tuples into one list per metric.
    epochs, train_acc1, train_acc5, test_acc1, test_acc5, train_losses, test_losses = (
        list(column) for column in zip(*results)
    )

    # Plot training curves
    plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates)

    # Resolve class names: an explicit argument wins. Otherwise use the
    # dataset's ``classes`` attribute, which torchvision datasets such as
    # CIFAR10 / ImageFolder provide. Other datasets, and Subset wrappers,
    # may not have it.
    if classes is None:
        classes = getattr(getattr(testloader, "dataset", None), "classes", None)
    if classes is None:
        print("Class names unavailable; skipping misclassified-sample plot. Pass classes=[...] to main().")
    else:
        plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes=classes)
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()