# Spaces: Running on Zero
import logging
class Trainer:
    """Run a supervised training loop, optional periodic validation, and a final test pass.

    Args:
        model: Callable invoked as ``model(inputs, return_backbone=True)`` that
            must return a ``(prediction, x_backbone)`` pair; it must also
            expose ``.train()``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        criterion: Callable invoked as ``criterion(x_backbone, prediction, target)``;
            its result must support ``.backward()`` and ``.item()``.
        scheduler: Optional object providing ``step()``; it is stepped after
            every optimizer step. ``None`` disables it.
        train_loader: Iterable of dict batches containing ``'input_image'``
            and ``'target_image'`` entries.
        valid_evaluator: Optional callable ``valid_evaluator(model)`` run after
            every ``config_train.valid_every``-th epoch.
        test_evaluator: Optional callable
            ``test_evaluator(model, save_results=...)`` run once after training.
        config_train: Object providing ``epochs`` (and ``valid_every`` when a
            validation evaluator is given).
        config_eval: Stored on the instance; not read by this class's methods.
        device: Target passed to ``.to(...)`` for every input/target batch.
            The default ``"cuda"`` selects the current CUDA device, which is
            equivalent to calling ``.cuda()``.
    """

    def __init__(self, model, optimizer, criterion, scheduler, train_loader,
                 valid_evaluator, test_evaluator, config_train, config_eval,
                 device="cuda"):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.train_loader = train_loader
        self.valid_evaluator = valid_evaluator
        self.test_evaluator = test_evaluator
        self.config_train = config_train
        self.config_eval = config_eval
        self.device = device

    def _train_step(self, input_image, target_image):
        """Run one optimization step on a single batch and return its loss as a Python number."""
        self.optimizer.zero_grad()
        prediction, x_backbone = self.model(
            input_image.to(self.device), return_backbone=True
        )
        loss = self.criterion(x_backbone, prediction, target_image.to(self.device))
        loss.backward()
        self.optimizer.step()
        # The scheduler advances once per batch, not once per epoch.
        # NOTE(review): confirm the scheduler is meant for per-iteration
        # stepping; an epoch-based schedule would decay once per batch here.
        if self.scheduler is not None:
            self.scheduler.step()
        return loss.item()

    def _train_epoch(self):
        """Train for one pass over ``train_loader`` and return the mean batch loss.

        The mean is taken over the batches actually iterated, so loaders
        without ``__len__`` also work. An empty loader yields ``nan`` instead
        of raising ``ZeroDivisionError``.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in self.train_loader:
            total_loss += self._train_step(batch['input_image'], batch['target_image'])
            num_batches += 1
        if num_batches == 0:
            return float("nan")
        return total_loss / num_batches

    def train(self):
        """Train for ``config_train.epochs`` epochs, then run the test evaluator.

        After each epoch the mean loss is logged. If a validation evaluator
        is set, it runs after every ``config_train.valid_every``-th epoch.
        The test evaluator (when set) is asked to save results only when
        there is no validation evaluator.
        """
        epochs = self.config_train.epochs
        for epoch in range(1, epochs + 1):
            epoch_loss = self._train_epoch()
            logging.info("Epoch %s/%s | Loss: %s", epoch, epochs, epoch_loss)
            # Short-circuit keeps `valid_every` unread when validation is disabled.
            if self.valid_evaluator is not None and epoch % self.config_train.valid_every == 0:
                self.valid_evaluator(self.model)
        if self.test_evaluator is not None:
            self.test_evaluator(self.model, save_results=self.valid_evaluator is None)
        logging.info("Training finished.")