import os
import uuid
import warnings
from datetime import datetime as dt
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm import tqdm

from zoedepth.utils.config import flatten
from zoedepth.utils.misc import RunningAverageDict, colorize, colors
def is_rank_zero(args):
    """Return True when ``args.rank`` equals 0, False otherwise."""
    rank = args.rank
    return rank == 0
class BaseTrainer:
    """Training driver: optimisation loop, periodic validation, checkpoints, wandb logging.

    Subclasses must implement ``train_on_batch`` and ``validate_on_batch``.
    ``should_write``, ``should_log``, ``step`` and ``epoch`` are only set once
    ``train()`` has started, so the logging / checkpoint helpers are meant to
    be called from inside a training run.
    """

    def __init__(self, config, model, train_loader, test_loader=None, device=None):
        """ Base Trainer class for training a model.

        Args:
            config: Attribute-style configuration object. It is also queried
                with ``.get(...)`` in ``should_early_stop``.
            model: Model to train. ``model.module`` is used wherever
                ``config.multigpu`` is truthy.
            train_loader: Sized iterable of training batches.
            test_loader: Optional sized iterable of validation batches.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        self.config = config
        # Key into the validation metrics dict; a lower value counts as better.
        self.metric_criterion = "abs_rel"
        if device is None:
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.device = device
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        # The scheduler reads the optimizer's param groups, so order matters.
        self.optimizer = self.init_optimizer()
        self.scheduler = self.init_scheduler()

    def resize_to_target(self, prediction, target):
        """Bilinearly resize ``prediction`` to the last two dims of ``target``.

        ``prediction`` is returned untouched when ``prediction.shape[2:]``
        already equals ``target.shape[-2:]``.
        Assumes ``prediction`` is laid out as (N, C, H, W) -- TODO confirm.
        """
        target_size = target.shape[-2:]
        if prediction.shape[2:] != target_size:
            prediction = nn.functional.interpolate(
                prediction, size=target_size, mode="bilinear", align_corners=True
            )
        return prediction

    def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"):
        """Load model weights named by the config, if any.

        ``config.checkpoint`` takes precedence. Otherwise ``config.ckpt_pattern``
        is globbed as ``*{pattern}*{ckpt_type}*`` inside ``checkpoint_dir``.
        When the config has neither attribute, nothing is loaded.

        Raises:
            ValueError: If ``ckpt_pattern`` is set but no file matches.
        """
        import glob

        from zoedepth.models.model_io import load_wts

        if hasattr(self.config, "checkpoint"):
            checkpoint = self.config.checkpoint
        elif hasattr(self.config, "ckpt_pattern"):
            pattern = self.config.ckpt_pattern
            # glob returns matches in arbitrary order; sort so that the chosen
            # checkpoint is reproducible across runs and machines.
            matches = sorted(glob.glob(os.path.join(
                checkpoint_dir, f"*{pattern}*{ckpt_type}*")))
            if not matches:
                raise ValueError(f"No matches found for the pattern {pattern}")
            checkpoint = matches[0]
        else:
            return
        model = load_wts(self.model, checkpoint)
        print("Loaded weights from {0}".format(checkpoint))
        warnings.warn(
            "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.")
        self.model = model

    def init_optimizer(self):
        """Build the AdamW optimizer.

        With ``config.same_lr`` every model parameter shares ``config.lr``
        (after ``core.unfreeze()`` is called when the model has a ``core``).
        Otherwise the parameter groups come from ``model.get_lr_params``.

        Raises:
            NotImplementedError: If per-group LRs are requested but the model
                has no ``get_lr_params``.
        """
        m = self.model.module if self.config.multigpu else self.model
        if self.config.same_lr:
            print("Using same LR")
            if hasattr(m, 'core'):
                m.core.unfreeze()
            params = self.model.parameters()
        else:
            print("Using diff LR")
            if not hasattr(m, 'get_lr_params'):
                raise NotImplementedError(
                    f"Model {m.__class__.__name__} does not implement get_lr_params. Please implement it or use the same LR for all parameters.")
            params = m.get_lr_params(self.config.lr)
        return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd)

    def init_scheduler(self):
        """Build a OneCycleLR schedule whose max LRs are the optimizer's current group LRs."""
        lrs = [group['lr'] for group in self.optimizer.param_groups]
        return optim.lr_scheduler.OneCycleLR(
            self.optimizer, lrs,
            epochs=self.config.epochs,
            steps_per_epoch=len(self.train_loader),
            cycle_momentum=self.config.cycle_momentum,
            base_momentum=0.85, max_momentum=0.95,
            div_factor=self.config.div_factor,
            final_div_factor=self.config.final_div_factor,
            pct_start=self.config.pct_start,
            three_phase=self.config.three_phase)

    def train_on_batch(self, batch, train_step):
        """Run one optimisation step; must be overridden.

        ``train()`` expects the return value to be a mapping of loss name to
        tensor (it calls ``.item()`` on each value and checks it for NaN).
        """
        raise NotImplementedError

    def validate_on_batch(self, batch, val_step):
        """Evaluate one validation batch; must be overridden.

        ``validate()`` unpacks the return value as ``(metrics, losses)`` and
        skips whichever of the two is falsy.
        """
        raise NotImplementedError

    def raise_if_nan(self, losses):
        """Raise ``ValueError`` naming the first loss that contains a NaN."""
        for key, value in losses.items():
            # .any() keeps the check valid for multi-element tensors, where a
            # bare truth test on the isnan mask would be ambiguous.
            if torch.isnan(value).any():
                raise ValueError(f"{key} is NaN, Stopping training")

    @property
    def iters_per_epoch(self):
        """Number of batches in one pass over ``train_loader``."""
        return len(self.train_loader)

    @property
    def total_iters(self):
        """Total number of training iterations: ``config.epochs * iters_per_epoch``."""
        return self.config.epochs * self.iters_per_epoch

    def should_early_stop(self):
        """Return True once ``self.step`` exceeds a truthy ``config.early_stop``."""
        limit = self.config.get('early_stop', False)
        # Always return a real bool (the previous version fell through to None).
        return bool(limit) and self.step > limit

    @staticmethod
    def _format_losses(losses):
        """Render a loss dict as ``name: value`` pairs for the progress bar."""
        # Format the raw value: rounding to 3 decimals first made every loss
        # below 5e-4 print as 0.0000e+00.
        return "; ".join(
            f"{colors.fg.purple}{name}{colors.reset}: {value.item():.4e}"
            for name, value in losses.items())

    def _validate_and_track_best(self, best_loss):
        """Validate, log the results, save a new best checkpoint; return the updated best value."""
        metrics, test_losses = self.validate()
        if self.should_log:
            wandb.log(
                {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step)
            wandb.log({f"Metrics/{k}": v for k, v in metrics.items()}, step=self.step)
        if self.should_write and metrics[self.metric_criterion] < best_loss:
            self.save_checkpoint(f"{self.config.experiment_id}_best.pt")
            best_loss = metrics[self.metric_criterion]
        return best_loss

    def train(self):
        """Run the full training loop.

        Fills in ``config.uid`` / ``run_id`` / ``experiment_id``, initialises
        wandb on the writing process, trains for ``config.epochs`` epochs with
        validation every ``config.validate_every`` (a fraction of an epoch)
        iterations, then saves and validates once more at the end.
        """
        print(f"Training {self.config.name}")
        if self.config.uid is None:
            self.config.uid = str(uuid.uuid4()).split('-')[-1]
        # %b is the documented spelling of the abbreviated month name
        # (%h is a platform-specific alias for it).
        run_id = f"{dt.now().strftime('%d-%b_%H-%M')}-{self.config.uid}"
        self.config.run_id = run_id
        self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}"
        # Only one process (rank 0, or the sole process) writes files and logs.
        self.should_write = ((not self.config.distributed)
                             or self.config.rank == 0)
        self.should_log = self.should_write
        if self.should_log:
            tags = self.config.tags.split(',') if self.config.tags != '' else None
            wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root,
                       tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork"))

        self.model.train()
        self.step = 0
        # Defined up front so the final checkpoint / validation still work when
        # the epoch loop never runs (epochs == 0 or an immediate early stop).
        self.epoch = 0
        best_loss = np.inf
        # Clamp to 1: a short loader can make the product truncate to 0, which
        # would turn the modulo below into a ZeroDivisionError.
        validate_every = max(1, int(self.config.validate_every * self.iters_per_epoch))

        if self.config.prefetch:
            # One throw-away pass over the training loader before training.
            prefetch_iter = tqdm(enumerate(self.train_loader), desc="Prefetching...",
                                 total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader)
            for _ in prefetch_iter:
                pass

        for epoch in range(self.config.epochs):
            if self.should_early_stop():
                break
            self.epoch = epoch
            if self.should_log:
                wandb.log({"Epoch": epoch}, step=self.step)
            pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train",
                        total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader)
            for i, batch in pbar:
                if self.should_early_stop():
                    print("Early stopping")
                    break
                losses = self.train_on_batch(batch, i)
                self.raise_if_nan(losses)
                # pbar is only a tqdm object on rank zero.
                if is_rank_zero(self.config) and self.config.print_losses:
                    pbar.set_description(
                        f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. Losses: {self._format_losses(losses)}")
                self.scheduler.step()

                if self.should_log and self.step % 50 == 0:
                    wandb.log({f"Train/{name}": loss.item()
                               for name, loss in losses.items()}, step=self.step)

                self.step += 1

                if self.test_loader and (self.step % validate_every) == 0:
                    self.model.eval()
                    if self.should_write:
                        self.save_checkpoint(
                            f"{self.config.experiment_id}_latest.pt")
                    # Every process validates; only the writer logs and saves.
                    best_loss = self._validate_and_track_best(best_loss)
                    self.model.train()
                    if self.config.distributed:
                        dist.barrier()

        # Final save / validation, logged one step past the last training step.
        self.step += 1
        self.model.eval()
        self.save_checkpoint(f"{self.config.experiment_id}_latest.pt")
        if self.test_loader:
            best_loss = self._validate_and_track_best(best_loss)
        self.model.train()

    def validate(self):
        """Average per-batch metrics and losses over ``test_loader`` under ``torch.no_grad()``.

        Returns:
            ``(metrics, losses)`` as produced by ``RunningAverageDict.get_value()``.
        """
        with torch.no_grad():
            losses_avg = RunningAverageDict()
            metrics_avg = RunningAverageDict()
            batches = tqdm(
                enumerate(self.test_loader),
                desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. Loop: Validation",
                total=len(self.test_loader),
                disable=not is_rank_zero(self.config))
            for i, batch in batches:
                metrics, losses = self.validate_on_batch(batch, val_step=i)
                if losses:
                    losses_avg.update(losses)
                if metrics:
                    metrics_avg.update(metrics)
            return metrics_avg.get_value(), losses_avg.get_value()

    def save_checkpoint(self, filename):
        """Write model weights and the current epoch to ``config.save_dir/filename``.

        Does nothing unless ``self.should_write`` is set. The optimizer state
        is stored as ``None``; it is not saved.
        """
        if not self.should_write:
            return
        root = self.config.save_dir
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(root, exist_ok=True)
        fpath = os.path.join(root, filename)
        m = self.model.module if self.config.multigpu else self.model
        torch.save(
            {
                "model": m.state_dict(),
                "optimizer": None,
                "epoch": self.epoch
            }, fpath)

    def log_images(self, rgb: Optional[Dict[str, list]] = None, depth: Optional[Dict[str, list]] = None, scalar_field: Optional[Dict[str, list]] = None, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None):
        """Log captioned images to wandb under ``prefix + "Predictions"``.

        Args:
            rgb: Caption -> image, logged as given. ``None`` means no images.
            depth: Caption -> depth map, colorized with ``min_depth`` / ``max_depth``.
            scalar_field: Caption -> scalar map, colorized with ``scalar_cmap``
                and no fixed value range.
            prefix: String prepended to the wandb key.
            scalar_cmap: Colormap name passed to ``colorize`` for ``scalar_field``.
            min_depth: Lower colorize bound. When ``None``, both bounds are
                read from the config (or left ``None`` if the config lacks them).
            max_depth: Upper colorize bound; see the note in the body.
        """
        if not self.should_log:
            return
        # None defaults replace the previous shared mutable ``{}`` defaults.
        rgb = {} if rgb is None else rgb
        depth = {} if depth is None else depth
        scalar_field = {} if scalar_field is None else scalar_field

        if min_depth is None:
            # NOTE(review): this also overwrites a caller-supplied max_depth
            # whenever min_depth is None -- behaviour kept as-is; confirm intent.
            try:
                min_depth = self.config.min_depth
                max_depth = self.config.max_depth
            except AttributeError:
                min_depth = None
                max_depth = None

        depth = {k: colorize(v, vmin=min_depth, vmax=max_depth)
                 for k, v in depth.items()}
        scalar_field = {k: colorize(
            v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()}
        images = {**rgb, **depth, **scalar_field}
        wimages = {
            prefix + "Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]}
        wandb.log(wimages, step=self.step)

    def log_line_plot(self, data):
        """Plot ``data`` with pyplot and log the figure to wandb as "Scale factors"."""
        if not self.should_log:
            return
        plt.plot(data)
        plt.ylabel("Scale factors")
        wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step)
        plt.close()

    def log_bar_plot(self, title, labels, values):
        """Log a wandb bar chart of ``values`` against ``labels`` under ``title``."""
        if not self.should_log:
            return
        data = [[label, val] for label, val in zip(labels, values)]
        table = wandb.Table(data=data, columns=["label", "value"])
        wandb.log({title: wandb.plot.bar(table, "label",
                                         "value", title=title)}, step=self.step)