# Commit fd52b7f by sadjava: "changed to pipelines"
import copy
import csv
import os
import time
import numpy as np
import torch
from tqdm import tqdm
def _run_phase(model, criterion, dataloader, optimizer, metrics, phase,
               device):
    """Run one pass of *phase* over *dataloader* and return loss and metrics.

    In the 'Train' phase the model is put in training mode and the optimizer
    is stepped after every batch. In any other phase the model is put in
    evaluation mode and autograd history is not recorded.

    Args:
        model: Module whose forward output is indexed as ``outputs['out']``.
        criterion: Callable ``criterion(outputs['out'], masks)`` returning a
            single-element loss tensor (``.item()`` is called on it).
        dataloader: Iterable of dicts with ``'image'`` and ``'mask'`` tensors.
        optimizer: Optimizer stepped once per batch in the 'Train' phase.
        metrics: Mapping of metric name to ``metric(y_true, y_pred)``.
            'f1_score' receives boolean arrays ``(y_true > 0, y_pred > 0.1)``.
            Every other metric receives ``(y_true cast to uint8, raw y_pred)``.
            Both arrays are flattened.
        phase: 'Train' or 'Test'.
        device: torch device that inputs and masks are moved to.

    Returns:
        ``(epoch_loss, metric_means)``. ``epoch_loss`` is the mean of the
        per-batch losses as a float. ``metric_means`` maps each metric name
        to the mean of its per-batch values. Loss and metrics are averaged
        the same way.

    Raises:
        ValueError: if *dataloader* yields no batches. Without any batches
            there is no loss to report.
    """
    training = phase == 'Train'
    if training:
        model.train()  # Set model to training mode
    else:
        model.eval()  # Set model to evaluate mode

    batch_losses = []
    # Start every list empty. A seed value (e.g. [0]) would bias the mean.
    metric_values = {name: [] for name in metrics}
    for sample in tqdm(dataloader):
        inputs = sample['image'].to(device)
        masks = sample['mask'].to(device)
        if training:
            optimizer.zero_grad()
        # Track autograd history only when the weights will be updated.
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            loss = criterion(outputs['out'], masks)
            y_pred = outputs['out'].detach().cpu().numpy().ravel()
            y_true = masks.detach().cpu().numpy().ravel()
            for name, metric in metrics.items():
                if name == 'f1_score':
                    # Use a classification threshold of 0.1
                    value = metric(y_true > 0, y_pred > 0.1)
                else:
                    value = metric(y_true.astype('uint8'), y_pred)
                metric_values[name].append(value)
            # backward + optimize only if in training phase
            if training:
                loss.backward()
                optimizer.step()
        batch_losses.append(loss.item())

    if not batch_losses:
        raise ValueError(f"The '{phase}' dataloader yielded no batches")
    metric_means = {name: np.mean(values)
                    for name, values in metric_values.items()}
    return float(np.mean(batch_losses)), metric_means


def train_model(model, criterion, dataloaders, optimizer, metrics, bpath,
                num_epochs):
    """Train *model* and return it loaded with its best 'Test' weights.

    Each epoch runs a 'Train' pass followed by a 'Test' pass. After each
    epoch, one row with the mean loss and mean metric values of both phases
    is appended to ``<bpath>/log.csv``. That file is recreated with a header
    when training starts. The weights from the epoch with the lowest mean
    'Test' loss are restored into *model* before it is returned.

    Args:
        model: Module whose forward output is indexed as ``outputs['out']``.
            It is moved to CUDA device 0 when available, otherwise to CPU.
        criterion: Callable ``criterion(outputs['out'], masks)`` returning a
            single-element loss tensor.
        dataloaders: Mapping with 'Train' and 'Test' keys. Each value yields
            dicts with 'image' and 'mask' tensors.
        optimizer: Optimizer over the model's parameters.
        metrics: Mapping of metric name to ``metric(y_true, y_pred)``.
            'f1_score' gets boolean arrays thresholded at 0 (truth) and 0.1
            (prediction). The others get uint8 truth and raw predictions.
        bpath: Directory in which ``log.csv`` is written.
        num_epochs: Number of epochs to train.

    Returns:
        *model*, with the best-'Test'-loss weights loaded.

    Raises:
        ValueError: if a 'Train' or 'Test' dataloader yields no batches.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    # Use gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Initialize the log file for training and testing loss and metrics
    fieldnames = (['epoch', 'Train_loss', 'Test_loss']
                  + [f'Train_{m}' for m in metrics]
                  + [f'Test_{m}' for m in metrics])
    log_path = os.path.join(bpath, 'log.csv')
    with open(log_path, 'w', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=fieldnames).writeheader()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Pre-seed keys in CSV column order so the printed summary is ordered.
        summary = dict.fromkeys(fieldnames)
        summary['epoch'] = epoch
        for phase in ('Train', 'Test'):
            epoch_loss, metric_means = _run_phase(
                model, criterion, dataloaders[phase], optimizer, metrics,
                phase, device)
            summary[f'{phase}_loss'] = epoch_loss
            for name, value in metric_means.items():
                summary[f'{phase}_{name}'] = value
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            # Keep the weights from the epoch with the lowest mean Test loss.
            if phase == 'Test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
        print(summary)
        with open(log_path, 'a', newline='') as csvfile:
            csv.DictWriter(csvfile, fieldnames=fieldnames).writerow(summary)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model