# Utilities for saving and loading model checkpoints, vocabularies and score files.
import torch
import os, re, io
import json
import dill as pickle
from shutil import copy2 as copy
# File-naming conventions shared by all save/load helpers in this module.
MODEL_EXTENSION = ".pkl"
MODEL_FILE_FORMAT = "{:s}_{:d}{:s}" # model_prefix, epoch and extension
BEST_MODEL_FILE = ".model_score.txt"  # JSON dump of the best-model score list (best first)
MODEL_SERVE_FILE = ".serve.txt"  # plain-text file naming the model chosen for serving
VOCAB_FILE_FORMAT = "{:s}{:s}{:s}"  # name_prefix, language extension, file extension
def save_model_name(name, path, serve_config_path=MODEL_SERVE_FILE):
    """Record *name* as the model to serve by writing it into the serve config file under *path*."""
    config_path = os.path.join(path, serve_config_path)
    with io.open(config_path, "w", encoding="utf-8") as config_out:
        config_out.write(name)
def save_vocab_to_path(path, language_tuple, fields, name_prefix="vocab", check_saved_vocab=True):
    """Pickle the source and target vocab of *fields* into *path*.

    When check_saved_vocab is True and both dump files already exist,
    nothing is written."""
    src_field, trg_field = fields
    src_ext, trg_ext = language_tuple
    vocab_paths = [os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, ext, MODEL_EXTENSION))
                   for ext in (src_ext, trg_ext)]
    if check_saved_vocab and all(os.path.isfile(p) for p in vocab_paths):
        # both vocab dumps are already on disk; do nothing
        return
    for vocab_path, field in zip(vocab_paths, (src_field, trg_field)):
        with io.open(vocab_path, "wb") as vocab_file:
            pickle.dump(field.vocab, vocab_file)
def load_vocab_from_path(path, language_tuple, fields, name_prefix="vocab"):
    """Load the vocabulary from path into respective fields. If files doesn't exist, return False; if loaded properly, return True"""
    src_field, trg_field = fields
    src_ext, trg_ext = language_tuple
    src_path = os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, src_ext, MODEL_EXTENSION))
    trg_path = os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, trg_ext, MODEL_EXTENSION))
    if not (os.path.isfile(src_path) and os.path.isfile(trg_path)):
        # the vocab was never dumped for this path/prefix
        return False
    for field, vocab_path in ((src_field, src_path), (trg_field, trg_path)):
        with io.open(vocab_path, "rb") as vocab_file:
            field.vocab = pickle.load(vocab_file)
    return True
def save_model_to_path(model, path, name_prefix="model", checkpoint_idx=0, save_vocab=True):
    """Serialize model.state_dict() to {path}/{name_prefix}_{checkpoint_idx}{ext}; optionally dump its vocab alongside."""
    checkpoint_name = MODEL_FILE_FORMAT.format(name_prefix, checkpoint_idx, MODEL_EXTENSION)
    torch.save(model.state_dict(), os.path.join(path, checkpoint_name))
    if save_vocab:
        save_vocab_to_path(path, model.loader._language_tuple, model.fields)
def load_model_from_path(model, path, name_prefix="model", checkpoint_idx=0):
    """Load the checkpoint {name_prefix}_{checkpoint_idx}{ext} from *path* into *model* in place."""
    # do not load vocab here, as the vocab structure will be decided in model.loader.build_vocab
    checkpoint_name = MODEL_FILE_FORMAT.format(name_prefix, checkpoint_idx, MODEL_EXTENSION)
    state_dict = torch.load(os.path.join(path, checkpoint_name))
    model.load_state_dict(state_dict)
def load_model(model, model_path):
    """Restore *model*'s parameters in place from the state_dict stored at *model_path*."""
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
def check_model_in_path(path, name_prefix="model", return_all_checkpoint=False):
    """Scan *path* for checkpoint files named {name_prefix}_{idx}{MODEL_EXTENSION}.

    Returns the ascending list of checkpoint indices when return_all_checkpoint
    is True (empty list when none found or *path* is not a directory),
    otherwise the largest index found (0 when none)."""
    # re.escape both parts so the "." in MODEL_EXTENSION matches a literal dot,
    # and anchor the pattern so e.g. "model_1.pkl.bak" or "model_2xpkl" are not
    # miscounted (the original unescaped, unanchored pattern matched both)
    model_re = re.compile(r"{:s}_(\d+){:s}$".format(re.escape(name_prefix), re.escape(MODEL_EXTENSION)))
    if not os.path.isdir(path):
        # return the type the caller asked for; the original returned 0 here even
        # for return_all_checkpoint=True, crashing save_and_clear_model on len()
        return [] if return_all_checkpoint else 0
    matches = (model_re.match(f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)))
    indices = sorted(int(m.group(1)) for m in matches if m is not None)
    if return_all_checkpoint:
        return indices
    return indices[-1] if indices else 0
def save_and_clear_model(model, path, name_prefix="model", checkpoint_idx=0, maximum_saved_model=5):
    """Keep only last n models when saving. Explicitly save the model regardless of its checkpoint index, e.g if checkpoint_idx=0 & model 3 4 5 6 7 is in path, it will remove 3 and save 0 instead."""
    indices = check_model_in_path(path, name_prefix=name_prefix, return_all_checkpoint=True)
    # delete the oldest (lowest-index) checkpoints until at most
    # maximum_saved_model-1 remain, so the upcoming save stays within the cap.
    # The original sliced indices[:-(maximum_saved_model-1)], which is [:-0]
    # (i.e. nothing) when maximum_saved_model == 1 and never freed a slot.
    excess = len(indices) - (maximum_saved_model - 1)
    for old_idx in (indices[:excess] if excess > 0 else []):
        os.remove(os.path.join(path, MODEL_FILE_FORMAT.format(name_prefix, old_idx, MODEL_EXTENSION)))
    # perform save as normal
    save_model_to_path(model, path, name_prefix=name_prefix, checkpoint_idx=checkpoint_idx)
def load_model_score(path, score_file=BEST_MODEL_FILE):
    """Load the model score as a list from a json dump, organized from best to worst."""
    score_path = os.path.join(path, score_file)
    if not os.path.isfile(score_path):
        # no score file written yet
        return []
    with io.open(score_path, "r") as score_in:
        return json.load(score_in)
def write_model_score(path, score_obj, score_file=BEST_MODEL_FILE):
    """Dump *score_obj* (score list, best first) as JSON into *path*/*score_file*."""
    score_path = os.path.join(path, score_file)
    with io.open(score_path, "w") as score_out:
        json.dump(score_obj, score_out)
def save_model_best_to_path(model, path, score_obj, model_metric, best_model_prefix="best_model", maximum_saved_model=5, score_file=BEST_MODEL_FILE, save_after_update=True):
    """Insert *model* into the top-{maximum_saved_model} leaderboard kept in *path*.

    *score_obj* is the score list (higher is better, sorted best first) as
    returned by load_model_score.  When *model_metric* beats the current worst
    score, the ranked checkpoint files are shifted down one slot, the model is
    saved at its rank, and the truncated score list is returned — and also
    rewritten to disk when *save_after_update* is True.

    NOTE(review): an empty leaderboard uses -1.0 as sentinel worst score, which
    assumes metrics are always > -1.0 — confirm against the metric in use."""
    worst_score = score_obj[-1] if len(score_obj) > 0 else -1.0
    if model_metric > worst_score:
        # rank of the new model: first slot whose score it beats (0 when empty)
        insert_loc = next((idx for idx, score in enumerate(score_obj) if model_metric > score), 0)
        # Shift every saved model from insert_loc downward one slot, bottom-up.
        # The original loop had two bugs: (1) it copied i -> i+1 in ascending
        # order, clobbering slot i+1 before it was propagated, so the model at
        # insert_loc was duplicated into every lower slot; (2) its upper bound
        # min(len, maximum_saved_model)-1 silently dropped the last model file
        # whenever len(score_obj) < maximum_saved_model.  The last kept slot is
        # maximum_saved_model-1, so sources run up to min(len, max-1)-1.
        last_source = min(len(score_obj), maximum_saved_model - 1)
        for i in reversed(range(insert_loc, last_source)):
            old_loc = os.path.join(path, MODEL_FILE_FORMAT.format(best_model_prefix, i, MODEL_EXTENSION))
            new_loc = os.path.join(path, MODEL_FILE_FORMAT.format(best_model_prefix, i + 1, MODEL_EXTENSION))
            copy(old_loc, new_loc)
        # save the model to the selected loc
        save_model_to_path(model, path, name_prefix=best_model_prefix, checkpoint_idx=insert_loc)
        # update the score obj and drop anything past the leaderboard size
        score_obj.insert(insert_loc, model_metric)
        score_obj = score_obj[:maximum_saved_model]
        # also update in disk, if enabled
        if save_after_update:
            write_model_score(path, score_obj, score_file=score_file)
    # after routine had been done, return the obj
    return score_obj