File size: 5,981 Bytes
8cb4f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import os, re, io
import json
import dill as pickle
from shutil import copy2 as copy
# Extension shared by the model checkpoint files and the pickled vocabulary files.
MODEL_EXTENSION = ".pkl"
# Checkpoint file name, e.g. "model_3.pkl".
MODEL_FILE_FORMAT = "{:s}_{:d}{:s}" # name prefix, checkpoint index, extension
# Default name of the JSON file that stores the list of best-model scores.
BEST_MODEL_FILE = ".model_score.txt"
# Default name of the file that save_model_name writes the model name into.
MODEL_SERVE_FILE = ".serve.txt"
# Vocabulary file name: name prefix, language extension, file extension.
VOCAB_FILE_FORMAT = "{:s}{:s}{:s}"

def save_model_name(name, path, serve_config_path=MODEL_SERVE_FILE):
  """Write `name` as the whole content of the serve-config file inside `path`.

  Args:
    name: text to store (written as UTF-8).
    path: directory holding the serve-config file.
    serve_config_path: file name, relative to `path`; any previous content is replaced.
  """
  target = os.path.join(path, serve_config_path)
  with io.open(target, "w", encoding="utf-8") as handle:
    handle.write(name)

def save_vocab_to_path(path, language_tuple, fields, name_prefix="vocab", check_saved_vocab=True):
  """Pickle the vocabularies of the source and target fields into `path`.

  Args:
    path: directory receiving the two vocabulary files.
    language_tuple: (source, target) language extensions used in the file names.
    fields: (source, target) field objects; each one's `.vocab` attribute is dumped.
    name_prefix: leading part of the vocabulary file names.
    check_saved_vocab: when True and both files already exist, nothing is written.
  """
  src_field, trg_field = fields
  src_ext, trg_ext = language_tuple
  # pair every destination file with the field whose vocabulary goes into it
  targets = [
    (os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, ext, MODEL_EXTENSION)), field)
    for ext, field in ((src_ext, src_field), (trg_ext, trg_field))
  ]
  if check_saved_vocab and all(os.path.isfile(vocab_path) for vocab_path, _ in targets):
    # both vocabularies were dumped earlier; keep them as they are
    return
  for vocab_path, field in targets:
    with io.open(vocab_path, "wb") as vocab_file:
      pickle.dump(field.vocab, vocab_file)

def load_vocab_from_path(path, language_tuple, fields, name_prefix="vocab"):
  """Restore the pickled vocabularies found in `path` into the source and target fields.

  Args:
    path: directory holding the vocabulary files.
    language_tuple: (source, target) language extensions used in the file names.
    fields: (source, target) field objects; their `.vocab` attribute is overwritten.
    name_prefix: leading part of the vocabulary file names.

  Returns:
    False, with the fields left untouched, unless both vocabulary files exist;
    True once both vocabularies have been loaded.

  NOTE: unpickling can execute arbitrary code, so only use this on trusted files.
  """
  src_field, trg_field = fields
  src_ext, trg_ext = language_tuple
  src_location = os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, src_ext, MODEL_EXTENSION))
  trg_location = os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, trg_ext, MODEL_EXTENSION))
  both_present = os.path.isfile(src_location) and os.path.isfile(trg_location)
  if not both_present:
    # at least one vocabulary was never dumped
    return False
  with io.open(src_location, "rb") as src_stream, io.open(trg_location, "rb") as trg_stream:
    src_field.vocab = pickle.load(src_stream)
    trg_field.vocab = pickle.load(trg_stream)
  return True

def save_model_to_path(model, path, name_prefix="model", checkpoint_idx=0, save_vocab=True):
  """Store the state dict of `model` as a checkpoint file inside `path`.

  Args:
    model: object exposing `state_dict()`; when `save_vocab` is True it must also
      provide `loader._language_tuple` and `fields`.
    path: directory receiving the checkpoint.
    name_prefix: leading part of the checkpoint file name.
    checkpoint_idx: integer index embedded in the checkpoint file name.
    save_vocab: when True, the vocabularies are dumped next to the checkpoint
      through save_vocab_to_path.
  """
  file_name = MODEL_FILE_FORMAT.format(name_prefix, checkpoint_idx, MODEL_EXTENSION)
  destination = os.path.join(path, file_name)
  torch.save(model.state_dict(), destination)
  if not save_vocab:
    return
  save_vocab_to_path(path, model.loader._language_tuple, model.fields)

def load_model_from_path(model, path, name_prefix="model", checkpoint_idx=0):
  """Load checkpoint `checkpoint_idx` stored in `path` into `model`.

  Only the state dict is restored; the vocabulary is deliberately not loaded here,
  as its structure is decided in model.loader.build_vocab.

  Args:
    model: object exposing `load_state_dict()`.
    path: directory holding the checkpoint.
    name_prefix: leading part of the checkpoint file name.
    checkpoint_idx: integer index embedded in the checkpoint file name.
  """
  file_name = MODEL_FILE_FORMAT.format(name_prefix, checkpoint_idx, MODEL_EXTENSION)
  checkpoint_state = torch.load(os.path.join(path, file_name))
  model.load_state_dict(checkpoint_state)


def load_model(model, model_path):
  """Load the state dict stored in the file `model_path` into `model`.

  Args:
    model: object exposing `load_state_dict()`.
    model_path: full path of a file readable by `torch.load`.
  """
  checkpoint_state = torch.load(model_path)
  model.load_state_dict(checkpoint_state)

def check_model_in_path(path, name_prefix="model", return_all_checkpoint=False):
  """Find the checkpoint indices of the model files stored directly inside `path`.

  A file is a checkpoint when its whole name is `<name_prefix>_<digits><MODEL_EXTENSION>`.
  The prefix and the extension are matched literally and the name must end right
  after the extension, so e.g. `model_3.pkl.bak` or `model_3xpkl` are ignored.

  Args:
    path: directory to scan; subdirectories are not searched.
    name_prefix: leading part of the checkpoint file names.
    return_all_checkpoint: when True, return every index found instead of only the largest.

  Returns:
    With `return_all_checkpoint`: the indices sorted in ascending order, or an empty
    list when `path` is not a directory or holds no checkpoint.
    Otherwise: the largest index, or 0 when there is none (indistinguishable from
    a directory whose only checkpoint has index 0).
  """
  if not os.path.isdir(path):
    # keep the return type in line with the requested mode
    return [] if return_all_checkpoint else 0
  # escape both parts so regex metacharacters (like the "." of the extension) are literal,
  # and anchor the end so longer file names are not mistaken for checkpoints
  model_re = re.compile(r"{:s}_(\d+){:s}\Z".format(re.escape(name_prefix), re.escape(MODEL_EXTENSION)))
  indices = []
  for file_name in os.listdir(path):
    found = model_re.match(file_name)
    if found is not None and os.path.isfile(os.path.join(path, file_name)):
      indices.append(int(found.group(1)))
  indices.sort()
  if return_all_checkpoint:
    return indices
  return indices[-1] if indices else 0

def save_and_clear_model(model, path, name_prefix="model", checkpoint_idx=0, maximum_saved_model=5):
  """Save `model` as checkpoint `checkpoint_idx`, keeping at most `maximum_saved_model` checkpoints.

  The checkpoints with the smallest indices are deleted first. The new checkpoint
  is always written, whatever its index: with checkpoint_idx=0 and checkpoints
  3 4 5 6 7 in `path` (limit 5), checkpoint 3 is removed and 0 is saved.

  Args:
    model: model handed to save_model_to_path.
    path: directory holding the checkpoints.
    name_prefix: leading part of the checkpoint file names.
    checkpoint_idx: index of the checkpoint being written.
    maximum_saved_model: number of checkpoints left in `path` after the call;
      values below 1 behave like 1, since the new checkpoint is always saved.
  """
  # normalize a falsy "nothing found" result to an empty list
  indices = check_model_in_path(path, name_prefix=name_prefix, return_all_checkpoint=True) or []
  # a checkpoint with the same index is overwritten by the save below, so it takes no extra slot
  others = [i for i in indices if i != checkpoint_idx]
  # number of older checkpoints that may stay next to the one about to be written
  keep = max(maximum_saved_model - 1, 0)
  # count from the front instead of slicing with a negative bound: `others[:-0]` would be empty
  surplus = len(others) - keep
  if surplus > 0:
    for i in others[:surplus]:
      os.remove(os.path.join(path, MODEL_FILE_FORMAT.format(name_prefix, i, MODEL_EXTENSION)))
  # perform save as normal
  save_model_to_path(model, path, name_prefix=name_prefix, checkpoint_idx=checkpoint_idx)

def load_model_score(path, score_file=BEST_MODEL_FILE):
  """Read the JSON score dump stored in `path`.

  The content is expected to be a list of scores ordered from best to worst;
  neither the type nor the order is checked here.

  Args:
    path: directory holding the score file.
    score_file: name of the score file inside `path`.

  Returns:
    The decoded JSON content, or an empty list when the score file does not exist.
  """
  score_file_path = os.path.join(path, score_file)
  if os.path.isfile(score_file_path):
    with io.open(score_file_path, "r") as handle:
      return json.load(handle)
  return []

def write_model_score(path, score_obj, score_file=BEST_MODEL_FILE):
  """Dump `score_obj` as JSON into the score file inside `path`, replacing its content.

  Args:
    path: directory holding the score file.
    score_obj: JSON-serializable object to store.
    score_file: name of the score file inside `path`.
  """
  target = os.path.join(path, score_file)
  with io.open(target, "w") as handle:
    json.dump(score_obj, handle)

def save_model_best_to_path(model, path, score_obj, model_metric, best_model_prefix="best_model", maximum_saved_model=5, score_file=BEST_MODEL_FILE, save_after_update=True):
  """Save `model` among the best models in `path` if its metric ranks high enough.

  The best models are stored as `<best_model_prefix>_<rank>` checkpoints, rank 0
  being the best, and `score_obj` lists their metrics in the same order (higher
  is better). The model is stored when it ranks within the first
  `maximum_saved_model` positions, either by beating a recorded score or by
  taking a free slot at the end of the list. On a tie the earlier model keeps
  the better rank.

  Args:
    model: model handed to save_model_to_path.
    path: directory holding the best-model checkpoints and the score file.
    score_obj: list of recorded metrics, best first; it is updated in place.
    model_metric: metric of `model`, compared against the recorded ones with `>`.
    best_model_prefix: leading part of the best-model checkpoint file names.
    maximum_saved_model: number of best models to keep.
    score_file: name of the score file inside `path`.
    save_after_update: when True, the updated scores are written to `score_file`.

  Returns:
    `score_obj`, holding at most `maximum_saved_model` scores when the model was stored.
  """
  # rank of the new model: the first recorded score it beats, or the end of the list
  insert_loc = next((idx for idx, score in enumerate(score_obj) if model_metric > score), len(score_obj))
  if insert_loc >= maximum_saved_model:
    # not good enough to enter the kept models; nothing changes
    return score_obj
  # every model from insert_loc on moves down one rank; the one that would land past
  # the last kept rank is dropped (its file is overwritten by the model above it)
  first_unmoved = min(len(score_obj), maximum_saved_model - 1)
  # go from the worst rank upwards so each file is copied away before it is overwritten
  for i in reversed(range(insert_loc, first_unmoved)):
    old_loc = os.path.join(path, MODEL_FILE_FORMAT.format(best_model_prefix, i, MODEL_EXTENSION))
    new_loc = os.path.join(path, MODEL_FILE_FORMAT.format(best_model_prefix, i + 1, MODEL_EXTENSION))
    copy(old_loc, new_loc)
  # save the model to the selected rank
  save_model_to_path(model, path, name_prefix=best_model_prefix, checkpoint_idx=insert_loc)
  # update the scores in place so the caller's list stays in sync with the files
  score_obj.insert(insert_loc, model_metric)
  del score_obj[maximum_saved_model:]
  # also update on disk, if enabled
  if save_after_update:
    write_model_score(path, score_obj, score_file=score_file)
  return score_obj