Spaces:
Runtime error
Runtime error
from pathlib import Path

import click
import torch
from sklearn.metrics import f1_score
from torch.utils import data

# NOTE: keep `utils` (star import) before `model`/`trainer` so it cannot
# shadow the explicitly imported names below.
from utils import *
from model import createDeepLabv3
from trainer import train_model
def main(data_directory, exp_directory, epochs, batch_size, learning_rate=1e-4):
    """Fine-tune a DeepLabv3 model and save it to ``exp_directory``.

    Args:
        data_directory: Path (or str) of the dataset folder handed to
            ``get_dataloader_single_folder``.
        exp_directory: Path (or str) of the experiment output folder. It is
            created, together with any missing parents, if it does not exist,
            and is passed to ``train_model`` as ``bpath``.
        epochs: Number of training epochs passed to ``train_model``.
        batch_size: Batch size passed to ``get_dataloader_single_folder``.
        learning_rate: Adam learning rate (defaults to the previously
            hard-coded ``1e-4``).

    Raises:
        FileNotFoundError: If ``data_directory`` is not an existing directory.
    """
    data_directory = Path(data_directory)
    # Fail fast, before the (potentially expensive) model construction.
    if not data_directory.is_dir():
        raise FileNotFoundError(f"Data directory not found: {data_directory}")

    # Create the experiment directory if not present. exist_ok avoids the
    # check-then-create race; parents=True allows nested output paths.
    exp_directory = Path(exp_directory)
    exp_directory.mkdir(parents=True, exist_ok=True)

    # NOTE(review): per the original comment, createDeepLabv3 builds a
    # DeepLabv3-ResNet101 pretrained on a subset of COCO train2017 (the 20
    # Pascal VOC categories). It is defined in model.py; confirm there.
    model = createDeepLabv3()
    model.train()

    # Loss function and optimizer (lower learning rate for fine-tuning).
    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Evaluation metrics; `iou` is presumably provided by
    # `from utils import *`, so confirm it is exported there.
    metrics = {'f1_score': f1_score, 'iou': iou}

    dataloaders = get_dataloader_single_folder(
        data_directory, batch_size=batch_size)

    # The return value is not used: the `model` object itself is saved below
    # (assumes train_model updates it in place; TODO confirm in trainer.py).
    train_model(model,
                criterion,
                dataloaders,
                optimizer,
                bpath=exp_directory,
                metrics=metrics,
                num_epochs=epochs)

    # Save the trained model. This pickles the whole module rather than only
    # its state_dict, so loading it requires the model class to be importable.
    torch.save(model, exp_directory / 'weights.pt')
if __name__ == "__main__":
    # `main` takes its four arguments explicitly, so a bare `main()` call
    # raised TypeError. Parse them from the command line with click (already
    # imported) while leaving `main` callable as a plain function.
    @click.command()
    @click.option("--data-directory", required=True,
                  help="Specify the data directory.")
    @click.option("--exp-directory", required=True,
                  help="Specify the experiment directory.")
    @click.option("--epochs", default=25, type=int, show_default=True,
                  help="Number of epochs to train for.")
    @click.option("--batch-size", default=4, type=int, show_default=True,
                  help="Batch size for the dataloader.")
    def cli(data_directory, exp_directory, epochs, batch_size):
        """Command-line wrapper that forwards parsed options to main()."""
        main(data_directory, exp_directory, epochs, batch_size)

    cli()