kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
raw
history blame
1.52 kB
from typing import Dict, Callable, List
import os
from ..utils.data_util import dict_has_keys, dict_get_keys
from .model_cfg import ModelCfg
def get_model_path(
    model_names: List[str],
    online_dir: str,
    offline_dir: str,
    download_func: Callable,
) -> Dict:
    """Resolve local paths for the given models, downloading any that are missing.

    The entries for ``model_names`` are fetched from ``ModelCfg`` via
    ``dict_get_keys``; each value is treated as a path relative to both
    ``online_dir`` and ``offline_dir``. The local path is
    ``os.path.join(offline_dir, value)``. If it does not exist yet,
    ``download_func`` is called with the matching online path.

    Args:
        model_names (List[str]): keys to look up in ``ModelCfg``.
        online_dir (str): remote root joined with each configured relative
            path to form the download source.
        offline_dir (str): local root joined with each configured relative
            path to form the returned local path.
        download_func (Callable): invoked as
            ``download_func(online_path, offline_dir)`` for every model whose
            local path does not exist yet.

    Returns:
        Dict: mapping from each key returned by ``dict_get_keys`` (presumably
        the requested model names — confirm) to its local path. Returns
        ``None`` (after printing a message) when ``ModelCfg`` does not contain
        the requested ``model_names``.
    """
    if not dict_has_keys(ModelCfg, model_names):
        print("please set online model_path at least for {}".format(model_names))
        return None

    model_basename_dct = dict_get_keys(ModelCfg, model_names)
    offline_path_dct = {}
    for name, basename in model_basename_dct.items():
        offline_path = os.path.join(offline_dir, basename)
        parent_dir = os.path.dirname(offline_path)
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when the joined path actually has one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if not os.path.exists(offline_path):
            online_path = os.path.join(online_dir, basename)
            print(
                "starting downloading models from {} to {}".format(
                    online_path, offline_path
                )
            )
            # NOTE(review): the destination handed to download_func is
            # offline_dir, not offline_path; presumably download_func derives
            # the final file name itself — confirm against its implementation.
            download_func(online_path, offline_dir)
        else:
            print("load offline model from {}".format(offline_path))
        offline_path_dct[name] = offline_path
    return offline_path_dct