UTMOSv2 / utmosv2 /runner /_inference.py
kAIto47802
Resolved conflict in README.md
b55d767
raw
history blame
1.17 kB
from __future__ import annotations
import numpy as np
import pandas as pd
import torch
from torch.cuda.amp import autocast
from tqdm import tqdm
from utmosv2.utils import calc_metrics, print_metrics
def run_inference(
    cfg,
    model: torch.nn.Module,
    test_dataloader: torch.utils.data.DataLoader,
    cycle: int,
    test_data: pd.DataFrame,
    device: torch.device,
) -> tuple[np.ndarray, dict[str, float] | None]:
    """Run one inference pass over ``test_dataloader`` and collect predictions.

    Args:
        cfg: Configuration object. This function reads
            ``cfg.inference.num_tta`` (progress-bar label only),
            ``cfg.input_dir`` (selects how per-batch outputs are combined)
            and ``cfg.reproduce`` (enables metric computation).
        model: Model to evaluate; it is switched to eval mode here.
        test_dataloader: Yields batches. The last element of each batch is
            dropped and the remaining elements are moved to ``device`` and
            passed positionally to ``model``.
        cycle: Zero-based index of the current pass, shown as ``cycle + 1``
            in the progress bar.
        test_data: Forwarded to ``calc_metrics`` when ``cfg.reproduce`` is
            truthy; unused otherwise.
        device: Device the model inputs are moved to.

    Returns:
        A ``(test_preds, test_metrics)`` tuple. ``test_preds`` is the
        concatenation of the per-batch outputs when ``cfg.input_dir`` is
        truthy, otherwise ``np.array`` of the list of per-batch outputs.
        ``test_metrics`` is the result of ``calc_metrics`` when
        ``cfg.reproduce`` is truthy, else ``None``.
    """
    model.eval()
    batch_preds: list[np.ndarray] = []
    pbar = tqdm(
        test_dataloader,
        total=len(test_dataloader),
        desc=f" [Inference] ({cycle + 1}/{cfg.inference.num_tta})",
    )
    with torch.no_grad():
        for batch in pbar:
            # Everything except the last batch element is a model input.
            inputs = [item.to(device, non_blocking=True) for item in batch[:-1]]
            with autocast():
                output = model(*inputs).squeeze()
            batch_preds.append(output.cpu().numpy())

    if cfg.input_dir:
        # squeeze() turns a single-sample batch into a 0-d array, which
        # np.concatenate rejects; promote each batch result to at least 1-D.
        test_preds = np.concatenate([np.atleast_1d(p) for p in batch_preds])
    else:
        test_preds = np.array(batch_preds)

    if cfg.reproduce:
        test_metrics = calc_metrics(test_data, test_preds)
        print_metrics(test_metrics)
    else:
        test_metrics = None
    return test_preds, test_metrics