from __future__ import annotations

import os
from collections import defaultdict
from collections.abc import Callable

import numpy as np
import pandas as pd
import torch
import wandb
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from utmosv2.utils import calc_metrics, print_metrics


def _train_1epoch(
    cfg,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
) -> dict[str, float]:
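    """Run one training epoch with mixed precision and optional mixup.

    Returns the batch-averaged training loss, keyed by ``"loss"`` plus one
    entry per criterion name when ``cfg.loss`` is a list of weighted criteria.
    """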
    model.train()
    train_loss = defaultdict(float)
    scaler = GradScaler()
    print(f" (lr: {scheduler.get_last_lr()[0]:.6f})")
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for i, t in enumerate(pbar):
        # Each batch is a tuple of one or more input tensors followed by the target.
        x, y = t[:-1], t[-1]
        x = [t.to(device, non_blocking=True) for t in x]
        y = y.to(device, non_blocking=True)
        if cfg.run.mixup:
            # Sample a mixing coefficient and pair each sample with a random
            # partner from the same batch.
            lmd = np.random.beta(cfg.run.mixup_alpha, cfg.run.mixup_alpha)
            perm = torch.randperm(x[0].shape[0]).to(device)
            x2 = [t[perm, :] for t in x]
            y2 = y[perm]
        optimizer.zero_grad()
        with autocast():
            if cfg.run.mixup:
                # Interpolate the inputs; the loss is interpolated against the
                # original and permuted targets with the same coefficient.
                output = model(
                    *[lmd * t + (1 - lmd) * t2 for t, t2 in zip(x, x2)]
                ).squeeze(1)
                if isinstance(cfg.loss, list):
                    loss = [
                        (w1, lmd * l1 + (1 - lmd) * l2)
                        for (w1, l1), (_, l2) in zip(
                            criterion(output, y), criterion(output, y2)
                        )
                    ]
                else:
                    loss = lmd * criterion(output, y) + (1 - lmd) * criterion(
                        output, y2
                    )
            else:
                output = model(*x).squeeze(1)
                loss = criterion(output, y)
        # With multiple criteria, the total loss is their weighted sum.
        if isinstance(loss, list):
            loss_total = sum(w * ls for w, ls in loss)
        else:
            loss_total = loss
        scaler.scale(loss_total).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        train_loss["loss"] += loss_total.detach().float().cpu().item()
        if isinstance(loss, list):
            for (cl, _), (_, ls) in zip(cfg.loss, loss):
                train_loss[cl.name] += ls.detach().float().cpu().item()
        pbar.set_description(
            f' loss: {train_loss["loss"] / (i + 1):.4f}'
            + (
                f' ({", ".join([f"{cl.name}: {train_loss[cl.name] / (i + 1):.4f}" for cl, _ in cfg.loss])})'
                if isinstance(loss, list)
                else ""
            )
        )
    return {name: v / len(train_dataloader) for name, v in train_loss.items()}
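

# A minimal, self-contained sketch of the mixup scheme used in _train_1epoch:
# inputs are interpolated with a Beta-distributed coefficient, and the loss is
# interpolated against the original and permuted targets with that same
# coefficient. The tensors, linear "model", and MSE criterion below are
# illustrative stand-ins, not the project's model or loss configuration.
def _mixup_example() -> None:
    x = torch.randn(8, 4)  # a batch of 8 hypothetical feature vectors
    y = torch.randn(8)  # hypothetical regression targets
    w = torch.randn(4, requires_grad=True)  # stand-in "model" parameters
    criterion = torch.nn.MSELoss()
    lmd = np.random.beta(0.4, 0.4)  # mixing coefficient (alpha=0.4 is arbitrary)
    perm = torch.randperm(x.shape[0])  # random partner for each sample
    output = (lmd * x + (1 - lmd) * x[perm]) @ w  # forward pass on mixed inputs
    loss = lmd * criterion(output, y) + (1 - lmd) * criterion(output, y[perm])
    loss.backward()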


def _validate_1epoch(
    cfg,
    model: torch.nn.Module,
    valid_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    device: torch.device,
) -> tuple[dict[str, float], dict[str, float], np.ndarray]:
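    """Evaluate the model for one epoch.

    Returns the batch-averaged validation loss, the batch-averaged value of
    each metric in ``metrics``, and the concatenated predictions for the
    whole validation set.
    """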
    model.eval()
    valid_loss = defaultdict(float)
    valid_metrics = {name: 0.0 for name in metrics}
    valid_preds = []
    pbar = tqdm(valid_dataloader, total=len(valid_dataloader))
    with torch.no_grad():
        for i, t in enumerate(pbar):
            x, y = t[:-1], t[-1]
            x = [t.to(device, non_blocking=True) for t in x]
            # Keep a CPU copy of the targets for the numpy-based metrics.
            y_cpu = y
            y = y.to(device, non_blocking=True)
            with autocast():
                output = model(*x).squeeze(1)
                loss = criterion(output, y)
            if isinstance(loss, list):
                loss_total = sum(w * ls for w, ls in loss)
            else:
                loss_total = loss
            valid_loss["loss"] += loss_total.detach().float().cpu().item()
            if isinstance(loss, list):
                for (cl, _), (_, ls) in zip(cfg.loss, loss):
                    valid_loss[cl.name] += ls.detach().float().cpu().item()
            output = output.cpu().numpy()
            for name, metric in metrics.items():
                valid_metrics[name] += metric(output, y_cpu.numpy())
            pbar.set_description(
                f' val_loss: {valid_loss["loss"] / (i + 1):.4f} '
                + (
                    f'({", ".join([f"{cl.name}: {valid_loss[cl.name] / (i + 1):.4f}" for cl, _ in cfg.loss])}) '
                    if isinstance(loss, list)
                    else ""
                )
                + " - ".join(
                    [
                        f"val_{name}: {v / (i + 1):.4f}"
                        for name, v in valid_metrics.items()
                    ]
                )
            )
            valid_preds.append(output)
    valid_loss = {name: v / len(valid_dataloader) for name, v in valid_loss.items()}
    valid_metrics = {
        name: v / len(valid_dataloader) for name, v in valid_metrics.items()
    }
    valid_preds = np.concatenate(valid_preds)
    return valid_loss, valid_metrics, valid_preds


def run_train(
    cfg,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    valid_data: pd.DataFrame,
    oof_preds: np.ndarray,
    now_fold: int,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
) -> None:
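    """Train the model for ``cfg.run.num_epochs`` epochs on one fold.

    After every epoch the model is validated, metrics are printed, the best
    and last checkpoints are saved under ``cfg.save_path``, and the rows of
    ``oof_preds`` belonging to this fold are updated in place whenever the
    main metric improves.
    """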
    best_metric = 0.0
    os.makedirs(cfg.save_path, exist_ok=True)
    for epoch in range(cfg.run.num_epochs):
        print(f"[Epoch {epoch + 1}/{cfg.run.num_epochs}]")
        train_loss = _train_1epoch(
            cfg, model, train_dataloader, criterion, optimizer, scheduler, device
        )
        valid_loss, _, valid_preds = _validate_1epoch(
            cfg, model, valid_dataloader, criterion, metrics, device
        )
        print(f"Validation dataset: {cfg.validation_dataset}")
        if cfg.validation_dataset == "each":
            # Compute the metrics per dataset, then average them so that every
            # dataset contributes equally regardless of its size.
            dataset = valid_data["dataset"].unique()
            val_metrics = [
                calc_metrics(
                    valid_data[valid_data["dataset"] == ds],
                    valid_preds[valid_data["dataset"] == ds],
                )
                for ds in dataset
            ]
            val_metrics = {
                name: sum(m[name] for m in val_metrics) / len(val_metrics)
                for name in val_metrics[0].keys()
            }
        elif cfg.validation_dataset == "all":
            # Compute the metrics over the pooled validation set.
            val_metrics = calc_metrics(valid_data, valid_preds)
        else:
            # Compute the metrics on a single named dataset.
            val_metrics = calc_metrics(
                valid_data[valid_data["dataset"] == cfg.validation_dataset],
                valid_preds[valid_data["dataset"] == cfg.validation_dataset],
            )
        print_metrics(val_metrics)
        if val_metrics[cfg.main_metric] > best_metric:
            new_metric = val_metrics[cfg.main_metric]
            print(f"(Found best metric: {best_metric:.4f} -> {new_metric:.4f})")
            best_metric = new_metric
            save_path = (
                cfg.save_path / f"fold{now_fold}_s{cfg.split.seed}_best_model.pth"
            )
            torch.save(model.state_dict(), save_path)
            print(f"Save best model: {save_path}")
            # Record this fold's out-of-fold predictions from the best epoch.
            oof_preds[valid_data.index] = valid_preds
        save_path = cfg.save_path / f"fold{now_fold}_s{cfg.split.seed}_last_model.pth"
        torch.save(model.state_dict(), save_path)
        print()
        val_metrics["train_loss"] = train_loss["loss"]
        val_metrics["val_loss"] = valid_loss["loss"]
        if isinstance(cfg.loss, list):
            for cl, _ in cfg.loss:
                val_metrics[f"train_loss_{cl.name}"] = train_loss[cl.name]
                val_metrics[f"val_loss_{cl.name}"] = valid_loss[cl.name]
        if cfg.wandb:
            wandb.log(val_metrics)
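

# A minimal usage sketch for run_train. Everything below is hypothetical
# wiring: load_config, build_model, build_data, and build_criterion are
# stand-ins for whatever the project uses to construct these objects, shown
# only to make the expected call shape explicit.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     cfg = load_config("my_config")  # hypothetical config loader
#     model = build_model(cfg).to(device)  # hypothetical model factory
#     train_dataloader, valid_dataloader, valid_data = build_data(cfg)  # hypothetical
#     criterion = build_criterion(cfg)  # hypothetical
#     metrics = {"mse": lambda pred, t: float(np.mean((pred - t) ** 2))}
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#         optimizer, T_max=cfg.run.num_epochs * len(train_dataloader)
#     )
#     # Full-dataset buffer; run_train fills the rows indexed by valid_data.index.
#     oof_preds = np.zeros(num_all_samples, dtype=np.float32)
#     run_train(
#         cfg, model, train_dataloader, valid_dataloader, valid_data,
#         oof_preds, now_fold=0, criterion=criterion, metrics=metrics,
#         optimizer=optimizer, scheduler=scheduler, device=device,
#     )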