Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,799 Bytes
117183e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import logging
class Trainer():
def __init__(self, model, optimizer, criterion, scheduler, train_loader, valid_evaluator, test_evaluator, config_train, config_eval):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.scheduler = scheduler
self.train_loader = train_loader
self.valid_evaluator = valid_evaluator
self.test_evaluator = test_evaluator
self.config_train = config_train
self.config_eval = config_eval
def _train_step(self, input_image, target_image):
self.optimizer.zero_grad()
prediction, x_backbone = self.model(input_image.cuda(), return_backbone=True)
loss = self.criterion(x_backbone, prediction, target_image.cuda())
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
return loss.item()
def _train_epoch(self):
epoch_loss = 0
self.model.train()
for data in self.train_loader:
input_image, target_image, name = data['input_image'], data['target_image'], data['name']
loss = self._train_step(input_image, target_image)
epoch_loss += loss
return epoch_loss / len(self.train_loader)
def train(self):
for epoch in range(self.config_train.epochs):
epoch_loss = self._train_epoch()
logging.info(f"Epoch {epoch+1}/{self.config_train.epochs} | Loss: {epoch_loss}")
if self.valid_evaluator is not None and (epoch+1) % self.config_train.valid_every == 0:
self.valid_evaluator(self.model)
self.test_evaluator(self.model, save_results=True if self.valid_evaluator is None else False)
logging.info("Training finished.") |