|
import pickle |
|
from pathlib import Path |
|
|
|
from tabulate import tabulate |
|
|
|
|
|
def load_latest_ckpt(ckpt_dir: Path):
    """Load the most recent training checkpoint stored in ``ckpt_dir``.

    The checkpoint is expected at ``ckpt_dir / "duration_latest_ckpt.pickle"``
    and to contain a pickled dict with the keys ``"step"``, ``"params"``,
    ``"aux"``, ``"rng"`` and ``"optim_state"``.

    Args:
        ckpt_dir: Directory that holds the checkpoint file.

    Returns:
        ``None`` when no checkpoint file is present, otherwise the tuple
        ``(step, params, aux, rng, optim_state)`` read from the file.

    Raises:
        KeyError: If the pickled dict lacks one of the expected keys.
    """
    ckpt = ckpt_dir / "duration_latest_ckpt.pickle"
    # Open directly instead of testing ``exists()`` first: a separate check
    # can race with the file being removed between the check and the open.
    # NotADirectoryError is treated like a missing file as well, matching what
    # ``Path.exists()`` reports for such paths.
    try:
        f = open(ckpt, "rb")
    except (FileNotFoundError, NotADirectoryError):
        return None
    with f:
        print("Loading latest checkpoint from file", ckpt)
        # SECURITY: unpickling can execute arbitrary code. Only point this at
        # checkpoint directories whose contents are trusted.
        dic = pickle.load(f)
    return dic["step"], dic["params"], dic["aux"], dic["rng"], dic["optim_state"]
|
|
|
|
|
def save_ckpt(step, params, aux, rng, optim_state, ckpt_dir: Path):
    """Pickle the training state to ``ckpt_dir / "duration_latest_ckpt.pickle"``.

    The state is first written to a temporary sibling file and then moved
    over the final name, so a failure while pickling leaves any previously
    saved checkpoint untouched instead of truncating it.

    Args:
        step: Stored under the ``"step"`` key.
        params: Stored under the ``"params"`` key.
        aux: Stored under the ``"aux"`` key.
        rng: Stored under the ``"rng"`` key.
        optim_state: Stored under the ``"optim_state"`` key.
        ckpt_dir: Directory that receives the checkpoint file; it is created
            (including parents) when it does not exist yet.
    """
    dic = {
        "step": step,
        "params": params,
        "aux": aux,
        "rng": rng,
        "optim_state": optim_state,
    }
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    target = ckpt_dir / "duration_latest_ckpt.pickle"
    tmp = ckpt_dir / "duration_latest_ckpt.pickle.tmp"
    try:
        with open(tmp, "wb") as f:
            pickle.dump(dic, f)
        # Swap the finished file into place only after it is fully written.
        tmp.replace(target)
    finally:
        # After a successful replace the temp file is already gone; this only
        # cleans up the partial file left behind by a failed dump.
        if tmp.exists():
            tmp.unlink()
|
|
|
|
|
def print_flags(flags):
    """Print the non-underscore entries of ``flags`` as a table on stdout.

    Every ``(key, value)`` pair from ``flags.items()`` whose key does not
    start with ``"_"`` becomes one row handed to ``tabulate``.
    """
    rows = []
    for name, value in flags.items():
        if name.startswith("_"):
            continue
        rows.append((name, value))
    print(tabulate(rows))
|
|