# Hugging Face Spaces - status: Sleeping
from huggingface_hub import hf_hub_download, hf_hub_url
# NOTE(review): `from mmcv import Config` assumes mmcv 1.x; mmcv 2.x moved
# Config to mmengine - confirm against the pinned mmcv version.
from mmcv import Config
import torch
from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders

# Download the training config and checkpoint of the pretrained risk-biased
# predictor from the Hugging Face Hub, then build the predictor and load its
# weights on CPU.
#
# `hf_hub_download` replaces `cached_download`, which is deprecated and removed
# in recent huggingface_hub releases. It caches the file locally and returns
# the local path, keeping the repository filename. That filename still ends in
# `.py`, which `Config.fromfile` relies on, so no `force_filename` is needed.
MODEL_REPO_ID = "jmercat/risk_biased_model"

config_file = hf_hub_download(repo_id=MODEL_REPO_ID, filename="learning_config.py")

# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code - only load checkpoints from a trusted repository. Newer torch releases
# default to weights_only=True, which may reject checkpoints that contain
# non-tensor objects; confirm against the installed torch version.
ckpt = torch.load(
    hf_hub_download(repo_id=MODEL_REPO_ID, filename="last.ckpt"),
    map_location="cpu",  # load onto CPU regardless of the device used when saving
)

cfg = Config.fromfile(config_file)
# The dataloader's trajectory un-normalization is passed to the predictor factory
# as a callable (not invoked here).
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
predictor = load_weights(predictor, ckpt)
print("Model loaded")