File size: 841 Bytes
12da6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pickle
from pathlib import Path

from tabulate import tabulate


def load_latest_ckpt(ckpt_dir: Path):
    """Load the most recent training checkpoint from ``ckpt_dir``, if any.

    Reads the file ``duration_latest_ckpt.pickle`` inside ``ckpt_dir``.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling; only
        load checkpoints from a location you trust.

    Args:
        ckpt_dir: Directory holding the checkpoint, as ``str`` or ``Path``.

    Returns:
        ``None`` when the checkpoint file does not exist; otherwise the tuple
        ``(step, params, aux, rng, optim_state)`` taken from the pickled dict.

    Raises:
        KeyError: if the pickled dict is missing one of those keys.
    """
    # Coerce so plain string directories work too (``str / str`` is a TypeError).
    ckpt = Path(ckpt_dir) / "duration_latest_ckpt.pickle"
    if not ckpt.exists():
        return None
    print("Loading latest checkpoint from file", ckpt)
    with open(ckpt, "rb") as f:
        dic = pickle.load(f)
    return dic["step"], dic["params"], dic["aux"], dic["rng"], dic["optim_state"]


def save_ckpt(step, params, aux, rng, optim_state, ckpt_dir: Path):
    """Atomically write the training state to ``ckpt_dir``.

    The state is pickled as a dict with keys ``"step"``, ``"params"``,
    ``"aux"``, ``"rng"`` and ``"optim_state"`` into
    ``duration_latest_ckpt.pickle``, replacing any previous file of that name.

    The pickle is first written to a sibling ``.tmp`` file, flushed to disk,
    and then renamed over the target. If the process dies or pickling fails
    part-way, the previous checkpoint stays intact instead of being truncated.

    Args:
        step: Training step counter to store.
        params: Model parameters to store.
        aux: Auxiliary state to store.
        rng: Random-number state to store.
        optim_state: Optimizer state to store.
        ckpt_dir: Existing directory to write into, as ``str`` or ``Path``.

    Raises:
        Any error from ``open``/``pickle.dump``/the rename (e.g.
        ``FileNotFoundError`` if ``ckpt_dir`` does not exist, or a pickling
        error for unpicklable values). The temporary file is removed
        (best effort) before the error propagates.
    """
    import os  # local import: only needed for fsync

    ckpt_dir = Path(ckpt_dir)
    dic = {
        "step": step,
        "params": params,
        "aux": aux,
        "rng": rng,
        "optim_state": optim_state,
    }
    final_path = ckpt_dir / "duration_latest_ckpt.pickle"
    tmp_path = final_path.with_name(final_path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(dic, f)
            # Make sure the bytes hit the disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        # Same-directory rename: atomic on POSIX and overwrites on Windows too.
        tmp_path.replace(final_path)
    except BaseException:
        # Best-effort cleanup; never mask the original error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


def print_flags(flags):
    """Print every public entry of ``flags`` as a two-column table.

    Entries whose key starts with an underscore are treated as private and
    skipped; the rest are rendered with ``tabulate`` and printed to stdout.
    ``flags`` must provide an ``items()`` method yielding ``(key, value)``
    pairs whose keys are strings.
    """
    public_rows = []
    for name, value in flags.items():
        if name.startswith("_"):
            continue  # private entry
        public_rows.append((name, value))
    table = tabulate(public_rows)
    print(table)