import copy
import csv
import os
import time

import numpy as np
import torch
from tqdm import tqdm


def train_model(model, criterion, dataloaders, optimizer, metrics, bpath,
                num_epochs):
    """Train `model` and log per-epoch loss and metrics to `bpath`/log.csv.

    `dataloaders` is a dict with 'Train' and 'Test' DataLoaders whose samples
    are dicts containing 'image' and 'mask' tensors; `metrics` maps metric
    names to callables taking flat NumPy arrays (y_true, y_pred).
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    # Use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Initialize the log file for training and testing loss and metrics
    fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \
        [f'Train_{m}' for m in metrics.keys()] + \
        [f'Test_{m}' for m in metrics.keys()]
    with open(os.path.join(bpath, 'log.csv'), 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Each epoch has a training and a testing phase; collect per-batch
        # metric values here (empty lists so the epoch means are not biased
        # by a spurious initial zero)
        batchsummary = {a: [] for a in fieldnames}

        for phase in ['Train', 'Test']:
            if phase == 'Train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            # Iterate over data
            for sample in tqdm(dataloaders[phase]):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device)
                # Zero the parameter gradients
                optimizer.zero_grad()

                # Track gradient history only in the training phase
                with torch.set_grad_enabled(phase == 'Train'):
                    outputs = model(inputs)
                    loss = criterion(outputs['out'], masks)
                    y_pred = outputs['out'].data.cpu().numpy().ravel()
                    y_true = masks.data.cpu().numpy().ravel()
                    for name, metric in metrics.items():
                        if name == 'f1_score':
                            # Use a classification threshold of 0.1
                            batchsummary[f'{phase}_{name}'].append(
                                metric(y_true > 0, y_pred > 0.1))
                        else:
                            batchsummary[f'{phase}_{name}'].append(
                                metric(y_true.astype('uint8'), y_pred))

                    # Backward pass + optimizer step only in the training phase
                    if phase == 'Train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)

            # Average the loss over the whole epoch rather than reporting
            # only the last batch
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            batchsummary[f'{phase}_loss'] = epoch_loss
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # Deep copy the model whenever the test loss improves
            if phase == 'Test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        batchsummary['epoch'] = epoch
        for field in fieldnames[3:]:
            batchsummary[field] = np.mean(batchsummary[field])
        print(batchsummary)
        with open(os.path.join(bpath, 'log.csv'), 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow(batchsummary)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model
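

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training utility above).
# It assumes a torchvision DeepLabV3 backbone whose forward pass returns a dict
# with an 'out' key, and data loaders yielding {'image', 'mask'} samples as
# train_model expects. RandomSegDataset, the classifier-head replacement, and
# the chosen loss/metrics are placeholder assumptions for demonstration only.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    from sklearn.metrics import f1_score, roc_auc_score
    from torch.utils.data import DataLoader, Dataset
    from torchvision import models

    class RandomSegDataset(Dataset):
        """Stand-in dataset: random images paired with random binary masks."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {'image': torch.rand(3, 128, 128),
                    'mask': (torch.rand(1, 128, 128) > 0.5).float()}

    dataloaders = {
        'Train': DataLoader(RandomSegDataset(), batch_size=2, shuffle=True),
        'Test': DataLoader(RandomSegDataset(), batch_size=2, shuffle=False),
    }

    # DeepLabV3 with its final classifier conv swapped for a single-channel
    # output, so predictions can be compared against the binary masks.
    model = models.segmentation.deeplabv3_resnet50()
    model.classifier[4] = torch.nn.Conv2d(256, 1, kernel_size=1)

    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    metrics = {'f1_score': f1_score, 'auroc': roc_auc_score}

    train_model(model, criterion, dataloaders, optimizer, metrics,
                bpath='.', num_epochs=2)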