OratioAI / config.py
torinriley's picture
Update config.py
a60d524 verified
raw
history blame
1.05 kB
import re
from pathlib import Path
def get_config():
    """Build and return the default run configuration as a new dict.

    A fresh dict is created on every call, so callers may mutate the result
    without affecting later calls.

    Keys (values are the fixed defaults written below):
        batch_size, num_epochs, lr, seq_len, d_model: numeric settings; the
            names suggest batch size, epoch count, learning rate, maximum
            sequence length and model width (confirm against the trainer).
        datasource: dataset name; it also prefixes the checkpoint folder,
            which the weight-path helpers in this file build as
            ``{datasource}_{model_folder}``.
        lang_src, lang_tgt: language codes, here "en" and "fr".
        model_folder, model_basename: checkpoint folder suffix and the
            file-name prefix that the weight-path helpers put before the epoch.
        preload: None by default; presumably names a checkpoint to resume
            from when set (TODO confirm against the training script).
        tokenizer_file: ``str.format`` template with a single ``{0}`` slot,
            presumably filled with a language code (verify against caller).
        experiment_name: presumably the run's logging directory (TODO confirm).
    """
    return dict(
        batch_size=8,
        num_epochs=40,
        lr=10**-4,
        seq_len=512,
        d_model=512,
        datasource="opus_books",
        lang_src="en",
        lang_tgt="fr",
        model_folder="weights_fr",
        model_basename="fr_model_",
        preload=None,
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/oratio_en_fr",
    )
def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path for ``epoch`` as a string.

    The path is relative to the current working directory and has the form
    ``{datasource}_{model_folder}/{model_basename}{epoch}.pt``. ``epoch`` is
    interpolated as-is, so any zero-padding must be applied by the caller.
    Only the path string is computed; nothing is created or checked on disk.

    Raises:
        KeyError: if ``config`` lacks ``datasource``, ``model_folder`` or
            ``model_basename``.
    """
    parts = (
        f"{config['datasource']}_{config['model_folder']}",
        f"{config['model_basename']}{epoch}.pt",
    )
    return str(Path(".").joinpath(*parts))
def _natural_sort_key(path):
    """Sort key for a file path that compares digit runs in its name numerically."""
    # re.split with a capturing group always alternates text, digits, text, ...
    # so every key holds str at even positions and int at odd positions, and
    # two keys can therefore be compared element-wise without type errors.
    tokens = re.split(r"(\d+)", path.name)
    parts = [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)]
    # The full path breaks ties such as "m_02" vs "m_2", keeping the choice deterministic.
    return parts, str(path)


def latest_weights_file_path(config):
    """Return the path of the most recent checkpoint, or None if there is none.

    Searches ``{datasource}_{model_folder}`` (relative to the current working
    directory) for files whose names start with ``model_basename``. It returns
    the one that sorts last in *natural* order, where runs of digits are
    compared as integers. Thus ``fr_model_10.pt`` ranks after
    ``fr_model_9.pt`` even when epochs are not zero-padded. For zero-padded
    names the result is the same as a plain lexicographic sort.

    Args:
        config: mapping that provides ``datasource``, ``model_folder`` and
            ``model_basename``.

    Returns:
        The chosen checkpoint path as a string, or ``None`` when the folder is
        missing or contains no matching files.

    Raises:
        KeyError: if one of the required config keys is missing.
    """
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    # Globbing a non-existent directory yields nothing, so a missing folder
    # falls through to the empty-list case below.
    weights_files = list(Path(model_folder).glob(model_filename))
    if not weights_files:
        return None
    # A lexicographic sort would rank "..._10" before "..._9"; compare numerically.
    return str(max(weights_files, key=_natural_sort_key))