import copy
import csv
import os
import time

import numpy as np
import torch
from tqdm import tqdm


def train_model(model, criterion, dataloaders, optimizer, metrics, bpath,
                num_epochs):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    # Use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Initialize the log file for training and testing loss and metrics
    fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \
        [f'Train_{m}' for m in metrics.keys()] + \
        [f'Test_{m}' for m in metrics.keys()]
    with open(os.path.join(bpath, 'log.csv'), 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Each epoch has a training and a testing phase.
        # Collect per-batch metric values in lists; start them empty so the
        # epoch means are not skewed by a placeholder zero.
        batchsummary = {a: [] for a in fieldnames}
        batchsummary['epoch'] = epoch

        for phase in ['Train', 'Test']:
            if phase == 'Train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            # Iterate over data.
            for sample in tqdm(iter(dataloaders[phase])):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device)
                # Zero the parameter gradients
                optimizer.zero_grad()

                # Track gradient history only in the training phase
                with torch.set_grad_enabled(phase == 'Train'):
                    outputs = model(inputs)
                    loss = criterion(outputs['out'], masks)
                    y_pred = outputs['out'].data.cpu().numpy().ravel()
                    y_true = masks.data.cpu().numpy().ravel()
                    for name, metric in metrics.items():
                        if name == 'f1_score':
                            # Use a classification threshold of 0.1
                            batchsummary[f'{phase}_{name}'].append(
                                metric(y_true > 0, y_pred > 0.1))
                        else:
                            batchsummary[f'{phase}_{name}'].append(
                                metric(y_true.astype('uint8'), y_pred))

                    # Backward pass + optimizer step only in training phase
                    if phase == 'Train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item()

            # Log the loss averaged over all batches of this phase, not just
            # the loss of the last batch
            epoch_loss = running_loss / len(dataloaders[phase])
            batchsummary[f'{phase}_loss'] = epoch_loss
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # Deep copy the weights with the lowest test loss seen so far
            if phase == 'Test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        for field in fieldnames[3:]:
            batchsummary[field] = np.mean(batchsummary[field])
        print(batchsummary)
        with open(os.path.join(bpath, 'log.csv'), 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow(batchsummary)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))
    # Load the best model weights
    model.load_state_dict(best_model_wts)
    return model
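

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the trainer above). Names such as
# _RandomSegDataset, the './experiment' path, the deeplabv3_resnet50 model,
# the MSE loss, and the epoch count are illustrative assumptions; it also
# assumes torchvision >= 0.13 and scikit-learn. It shows what train_model
# expects: a segmentation model whose forward pass returns a dict with an
# 'out' key, dataloaders yielding {'image', 'mask'} batches, and a metrics
# dict whose 'f1_score' key matches the special case in the training loop.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from sklearn.metrics import f1_score, roc_auc_score
    from torch.utils.data import DataLoader, Dataset
    from torchvision import models

    class _RandomSegDataset(Dataset):
        """Tiny synthetic stand-in for a real segmentation dataset."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {'image': torch.rand(3, 128, 128),
                    'mask': (torch.rand(1, 128, 128) > 0.5).float()}

    # Single-channel DeepLabV3 built from scratch to keep the sketch
    # self-contained; in practice you would start from pretrained weights
    # and replace only the classifier head before fine-tuning.
    model = models.segmentation.deeplabv3_resnet50(
        weights=None, weights_backbone=None, num_classes=1)

    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    metrics = {'f1_score': f1_score, 'auroc': roc_auc_score}

    dataloaders = {phase: DataLoader(_RandomSegDataset(), batch_size=4,
                                     shuffle=(phase == 'Train'))
                   for phase in ['Train', 'Test']}

    bpath = './experiment'
    os.makedirs(bpath, exist_ok=True)
    trained = train_model(model, criterion, dataloaders, optimizer,
                          metrics, bpath=bpath, num_epochs=2)
    torch.save(trained.state_dict(), os.path.join(bpath, 'weights.pt'))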