Spaces:
No application file
No application file
import os
from typing import Callable, Dict, List, Optional

from ..utils.data_util import dict_get_keys, dict_has_keys
from .model_cfg import ModelCfg

def get_model_path(
    model_names: List[str],
    online_dir: str,
    offline_dir: str,
    download_func: Callable,
) -> Optional[Dict]:
    """Resolve local (offline) paths for the given models, downloading missing ones.

    The entries for ``model_names`` are fetched from ``ModelCfg`` via
    ``dict_get_keys``; each resulting value is joined onto ``offline_dir`` to
    form the local path. If that local path does not exist, ``download_func``
    is called with the corresponding path under ``online_dir``.

    Args:
        model_names (List[str]): names to look up in ``ModelCfg``.
        online_dir (str): base location joined with each configured model
            path to form the download source.
        offline_dir (str): local base directory joined with each configured
            model path to form the returned local path.
        download_func (Callable): invoked as
            ``download_func(online_path, offline_dir)`` for every model whose
            local path does not exist yet.

    Returns:
        Optional[Dict]: mapping from each key returned by ``dict_get_keys`` to
        its local path, or ``None`` when ``dict_has_keys(ModelCfg, model_names)``
        is falsy (presumably: some name has no entry in ``ModelCfg``).
    """
    if not dict_has_keys(ModelCfg, model_names):
        print("please set online model_path at least for {}".format(model_names))
        return None

    # Assumed to be the sub-mapping of ModelCfg restricted to model_names
    # (name -> model path relative to online_dir/offline_dir) — TODO confirm.
    model_basename_dct = dict_get_keys(ModelCfg, model_names)
    offline_path_dct = {}
    for model_name, rel_path in model_basename_dct.items():
        offline_path = os.path.join(offline_dir, rel_path)
        parent_dir = os.path.dirname(offline_path)
        # os.makedirs("") raises FileNotFoundError; an empty parent occurs when
        # offline_dir is "" and rel_path has no directory component.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if not os.path.exists(offline_path):
            online_path = os.path.join(online_dir, rel_path)
            print(
                "starting downloading models from {} to {}".format(
                    online_path, offline_path
                )
            )
            # NOTE(review): the destination passed is offline_dir, not
            # offline_path or its parent; this assumes download_func itself
            # places the file at offline_dir/<rel_path>. Confirm against the
            # download helper used by callers.
            download_func(online_path, offline_dir)
        else:
            print("load offline model from {}".format(offline_path))
        offline_path_dct[model_name] = offline_path
    return offline_path_dct