File size: 1,517 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
from typing import Callable, Dict, List, Optional

from ..utils.data_util import dict_has_keys, dict_get_keys
from .model_cfg import ModelCfg


def get_model_path(
    model_names: List[str],
    online_dir: str,
    offline_dir: str,
    download_func: Callable,
) -> Optional[Dict]:
    """Resolve local (offline) model paths, downloading models that are missing.

    For each entry returned by ``dict_get_keys(ModelCfg, model_names)``, the
    value is joined onto ``offline_dir`` to form the local path. If nothing
    exists at that path, ``download_func`` is called to fetch it from the same
    relative location under ``online_dir``.

    Args:
        model_names (List[str]): Names of the models to resolve; presumably
            keys of ``ModelCfg``.
        online_dir (str): Base location the models are downloaded from.
        offline_dir (str): Local base directory where models are stored.
        download_func (Callable): Called as
            ``download_func(online_path, offline_dir)`` for every model whose
            local path does not exist yet.

    Returns:
        Optional[Dict]: Mapping from each key returned by ``dict_get_keys`` to
        its local path. Returns ``None`` (after printing a message) when
        ``dict_has_keys(ModelCfg, model_names)`` is falsy, presumably meaning
        some requested name has no ``ModelCfg`` entry.
    """
    if not dict_has_keys(ModelCfg, model_names):
        print("please set online model_path at least for {}".format(model_names))
        return None

    model_basename_dct = dict_get_keys(ModelCfg, model_names)
    offline_path_dct = {}
    for name, rel_path in model_basename_dct.items():
        offline_path = os.path.join(offline_dir, rel_path)
        parent_dir = os.path.dirname(offline_path)
        # os.makedirs("") raises FileNotFoundError; dirname is "" when
        # offline_dir is "" and rel_path has no directory component.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        if not os.path.exists(offline_path):
            online_path = os.path.join(online_dir, rel_path)
            print(
                "starting downloading models from {} to {}".format(
                    online_path, offline_path
                )
            )
            # NOTE(review): the destination handed to download_func is
            # offline_dir, not offline_path or its parent directory. If
            # rel_path contains sub-directories, confirm download_func still
            # places the file at offline_path.
            download_func(online_path, offline_dir)
        else:
            print("load offline model from {}".format(offline_path))
        offline_path_dct[name] = offline_path
    return offline_path_dct