# tts/vietTTS/nat/utils.py — uploaded by tobiccino (commit 12da6cc)
import pickle
from pathlib import Path
from tabulate import tabulate
def load_latest_ckpt(ckpt_dir: Path, *, filename: str = "duration_latest_ckpt.pickle"):
    """Load the latest checkpoint pickle from ``ckpt_dir``, if present.

    By default this reads ``duration_latest_ckpt.pickle``, the same file name
    that ``save_ckpt`` writes.

    Args:
        ckpt_dir: Directory holding the checkpoint. A ``str`` path is also
            accepted.
        filename: Name of the checkpoint file inside ``ckpt_dir``.

    Returns:
        ``None`` if the checkpoint file does not exist. Otherwise the tuple
        ``(step, params, aux, rng, optim_state)``, with each value exactly as
        stored in the pickled dict.

    Raises:
        KeyError: If the pickled dict lacks one of the expected keys.

    Warning:
        ``pickle.load`` can execute arbitrary code. Only load checkpoints
        that come from a trusted source.
    """
    # Coerce so callers may pass a plain string; ``str / str`` would raise.
    ckpt = Path(ckpt_dir) / filename
    if not ckpt.exists():
        return None
    print("Loading latest checkpoint from file", ckpt)
    # NOTE: unpickling untrusted data is unsafe; see docstring warning.
    with open(ckpt, "rb") as f:
        dic = pickle.load(f)
    return dic["step"], dic["params"], dic["aux"], dic["rng"], dic["optim_state"]
def save_ckpt(
    step,
    params,
    aux,
    rng,
    optim_state,
    ckpt_dir: Path,
    *,
    filename: str = "duration_latest_ckpt.pickle",
):
    """Pickle the training state to ``ckpt_dir / filename``, replacing any previous file.

    The state is stored as a dict with keys ``"step"``, ``"params"``,
    ``"aux"``, ``"rng"`` and ``"optim_state"``. The data is first written to a
    sibling ``.tmp`` file and then moved over the target. An interrupted save
    therefore leaves the previous checkpoint intact instead of a truncated one.

    Args:
        step, params, aux, rng, optim_state: Training state to persist. Each
            must be picklable.
        ckpt_dir: Existing directory to write into. A ``str`` path is also
            accepted.
        filename: Name of the checkpoint file inside ``ckpt_dir``.
    """
    target = Path(ckpt_dir) / filename
    tmp = target.with_name(target.name + ".tmp")
    dic = {
        "step": step,
        "params": params,
        "aux": aux,
        "rng": rng,
        "optim_state": optim_state,
    }
    with open(tmp, "wb") as f:
        pickle.dump(dic, f)
    # Atomic on the same filesystem (os.replace semantics): readers see either
    # the old checkpoint or the complete new one, never a partial write.
    tmp.replace(target)
def print_flags(flags):
    """Print the public entries of ``flags`` as a table via ``tabulate``.

    Entries whose key starts with an underscore are omitted. The remaining
    entries keep their iteration order.

    Args:
        flags: Any object exposing ``.items()`` that yields ``(name, value)``
            pairs.
    """
    public_rows = []
    for name, value in flags.items():
        if name.startswith("_"):
            continue
        public_rows.append((name, value))
    print(tabulate(public_rows))