Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from enum import Enum | |
from functools import lru_cache | |
import logging | |
import os | |
import platform | |
from pathlib import Path | |
import huggingface_hub | |
import sherpa | |
import sherpa_onnx | |
# Logger used by this module. It is the logger literally named "main" (not
# __name__), so it shares handlers/levels with any other getLogger("main") user.
main_logger = logging.getLogger(name="main")
class EnumDecodingMethod(Enum):
    """Named decoding strategies.

    Each member's value is the plain string form used for the loaders'
    ``decoding_method`` argument (whose default is ``"greedy_search"``).
    """

    # Member order is significant for iteration; keep greedy first.
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
# Registry of downloadable ASR models, keyed by a language label.
#
# Each entry is shaped to be splatted into download_model(**entry) and
# load_recognizer(**entry); the keys are the ones those functions read:
#   repo_id                       Hugging Face Hub repository to fetch from.
#   nn_model_file / encoder_model_file / decoder_model_file /
#   joiner_model_file / tokens_file
#                                 file names inside the repository.
#   <file key>_sub_folder         repo sub-folder passed to hf_hub_download
#                                 for the matching file.
#   normalize_samples             only on entries for the sherpa (non-ONNX)
#                                 loader; forwarded to sherpa.FeatureConfig.
#   loader                        name of the loader function to dispatch to.
model_map = {
    "Chinese": [
        dict(
            repo_id="csukuangfj/wenet-chinese-model",
            nn_model_file="final.zip",
            nn_model_file_sub_folder=".",
            tokens_file="units.txt",
            tokens_file_sub_folder=".",
            normalize_samples=False,
            loader="load_sherpa_offline_recognizer",
        ),
        dict(
            repo_id="csukuangfj/sherpa-onnx-paraformer-zh-2024-03-09",
            nn_model_file="model.int8.onnx",
            nn_model_file_sub_folder=".",
            tokens_file="tokens.txt",
            tokens_file_sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_paraformer",
        ),
        dict(
            repo_id="csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09",
            nn_model_file="model.int8.onnx",
            nn_model_file_sub_folder=".",
            tokens_file="tokens.txt",
            tokens_file_sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_paraformer",
        ),
        dict(
            repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
            nn_model_file="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
            # This repo keeps its files in sub-folders rather than at the root.
            nn_model_file_sub_folder="exp",
            tokens_file="tokens.txt",
            tokens_file_sub_folder="data/lang_char",
            normalize_samples=True,
            loader="load_sherpa_offline_recognizer",
        ),
        dict(
            repo_id="zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2",
            # Transducer models ship as three separate ONNX parts.
            encoder_model_file="encoder-epoch-20-avg-1.onnx",
            encoder_model_file_sub_folder=".",
            decoder_model_file="decoder-epoch-20-avg-1.onnx",
            decoder_model_file_sub_folder=".",
            joiner_model_file="joiner-epoch-20-avg-1.onnx",
            joiner_model_file_sub_folder=".",
            tokens_file="tokens.txt",
            tokens_file_sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_transducer",
        ),
    ],
    "English": [
        dict(
            repo_id="csukuangfj/sherpa-onnx-paraformer-en-2024-03-09",
            nn_model_file="model.int8.onnx",
            nn_model_file_sub_folder=".",
            tokens_file="tokens.txt",
            tokens_file_sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_paraformer",
        ),
    ],
    "Chinese+English": [
        dict(
            repo_id="csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            nn_model_file="model.int8.onnx",
            nn_model_file_sub_folder=".",
            tokens_file="tokens.txt",
            tokens_file_sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_paraformer",
        ),
    ],
    "Chinese+Cantonese+English": [
        dict(
            repo_id="csukuangfj/sherpa-onnx-paraformer-trilingual-zh-cantonese-en",
            nn_model_file="model.int8.onnx",
            nn_model_file_sub_folder=".",
            tokens_file="tokens.txt",
            tokens_file_sub_folder=".",
            loader="load_sherpa_offline_recognizer_from_paraformer",
        ),
    ],
}
def download_model(local_model_dir: str,
                   **kwargs,
                   ):
    """Download every model file named in ``kwargs`` from the Hugging Face Hub.

    The recognised file keys are checked in a fixed order: ``nn_model_file``,
    ``encoder_model_file``, ``decoder_model_file``, ``joiner_model_file``,
    ``tokens_file``. Each one present in ``kwargs`` is fetched from
    ``kwargs["repo_id"]`` with ``huggingface_hub.hf_hub_download`` into
    ``local_model_dir``. The matching ``<key>_sub_folder`` entry selects the
    repo sub-folder and defaults to ``"."`` when absent.

    Args:
        local_model_dir: destination directory, passed as ``local_dir``.
        **kwargs: one model description (e.g. an entry of ``model_map``).
            It must contain ``repo_id``; keys not listed above (such as
            ``loader``) are ignored.

    Raises:
        KeyError: if ``repo_id`` is missing.
    """
    repo_id = kwargs["repo_id"]

    # Order matters only for logging and download sequence; it matches the
    # original per-key branches.
    file_keys = (
        "nn_model_file",
        "encoder_model_file",
        "decoder_model_file",
        "joiner_model_file",
        "tokens_file",
    )
    for key in file_keys:
        if key not in kwargs:
            continue
        filename = kwargs[key]
        # Tolerate descriptions that omit the sub-folder: treat as repo root.
        sub_folder = kwargs.get("{}_sub_folder".format(key), ".")
        main_logger.info("download %s. filename: %s, subfolder: %s", key, filename, sub_folder)
        huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   normalize_samples: bool = False,
                                   ):
    """Build a CPU ``sherpa.OfflineRecognizer`` (the non-ONNX ``sherpa`` package).

    Args:
        nn_model_file: path to the network file, passed as ``nn_model``.
        tokens_file: path to the token table, passed as ``tokens``.
        sample_rate: sampling frequency for fbank frame extraction.
        num_active_paths: passed through to ``OfflineRecognizerConfig``.
        decoding_method: passed through to ``OfflineRecognizerConfig``.
        num_mel_bins: number of mel bins for the fbank features.
        frame_dither: dither amount for fbank frame extraction.
        normalize_samples: forwarded to ``sherpa.FeatureConfig``.

    Returns:
        The constructed ``sherpa.OfflineRecognizer``.

    Raises:
        AssertionError: if ``nn_model_file`` or ``tokens_file`` does not exist
            (type kept for compatibility with existing callers).
    """
    # Validate inputs before building any sherpa objects, and name the
    # offending path so a bad layout is easy to diagnose.
    for label, path in (("nn_model_file", nn_model_file), ("tokens_file", tokens_file)):
        if not os.path.exists(path):
            raise AssertionError("{} not found: {}".format(label, path))

    feat_config = sherpa.FeatureConfig(normalize_samples=normalize_samples)
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return sherpa.OfflineRecognizer(config)
def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Create a ``sherpa_onnx.OfflineRecognizer`` for a Paraformer model.

    Args:
        nn_model_file: path to the Paraformer ONNX file (``paraformer=``).
        tokens_file: path to the token table (``tokens=``).
        sample_rate: forwarded as ``sample_rate``.
        decoding_method: forwarded as ``decoding_method``.
        feature_dim: forwarded as ``feature_dim``.
        num_threads: forwarded as ``num_threads``.

    Returns:
        Whatever ``sherpa_onnx.OfflineRecognizer.from_paraformer`` returns.
    """
    options = {
        "paraformer": nn_model_file,
        "tokens": tokens_file,
        "num_threads": num_threads,
        "sample_rate": sample_rate,
        "feature_dim": feature_dim,
        "decoding_method": decoding_method,
        # Always silence sherpa-onnx debug output.
        "debug": False,
    }
    build = sherpa_onnx.OfflineRecognizer.from_paraformer
    return build(**options)
def load_sherpa_offline_recognizer_from_transducer(encoder_model_file: str,
                                                   decoder_model_file: str,
                                                   joiner_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   num_active_paths: int = 2,
                                                   ):
    """Create a ``sherpa_onnx.OfflineRecognizer`` from a three-part transducer.

    Args:
        encoder_model_file: encoder ONNX path (``encoder=``).
        decoder_model_file: decoder ONNX path (``decoder=``).
        joiner_model_file: joiner ONNX path (``joiner=``).
        tokens_file: token table path (``tokens=``).
        sample_rate: forwarded as ``sample_rate``.
        decoding_method: forwarded as ``decoding_method``.
        feature_dim: forwarded as ``feature_dim``.
        num_threads: forwarded as ``num_threads``.
        num_active_paths: forwarded under sherpa-onnx's name ``max_active_paths``.

    Returns:
        Whatever ``sherpa_onnx.OfflineRecognizer.from_transducer`` returns.
    """
    model_parts = {
        "encoder": encoder_model_file,
        "decoder": decoder_model_file,
        "joiner": joiner_model_file,
    }
    settings = {
        "tokens": tokens_file,
        "num_threads": num_threads,
        "sample_rate": sample_rate,
        "feature_dim": feature_dim,
        "decoding_method": decoding_method,
        "max_active_paths": num_active_paths,
    }
    return sherpa_onnx.OfflineRecognizer.from_transducer(**model_parts, **settings)
def load_recognizer(local_model_dir: Path,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    **kwargs
                    ):
    """Make sure a model's files are available locally, then build its recognizer.

    Args:
        local_model_dir: directory the model files live in, or are downloaded into.
        decoding_method: forwarded to the selected loader.
        num_active_paths: forwarded to the loaders that take a beam width
            (``load_sherpa_offline_recognizer`` and
            ``load_sherpa_offline_recognizer_from_transducer``).
        **kwargs: one model description (e.g. an entry of ``model_map``).
            It must contain ``loader``; ``repo_id`` is needed if a download
            is triggered.

    Returns:
        Whatever the selected loader returns.

    Raises:
        KeyError: if ``loader`` is missing.
        NotImplementedError: if ``loader`` is not a known loader name.
    """
    loader = kwargs["loader"]

    # Resolve the loader before touching the network, so an unknown name
    # fails immediately instead of after a (possibly large) download.
    if loader == "load_sherpa_offline_recognizer":
        load_fn = load_sherpa_offline_recognizer
        loader_kwargs = {"num_active_paths": num_active_paths}
    elif loader == "load_sherpa_offline_recognizer_from_paraformer":
        load_fn = load_sherpa_offline_recognizer_from_paraformer
        loader_kwargs = {}
    elif loader == "load_sherpa_offline_recognizer_from_transducer":
        load_fn = load_sherpa_offline_recognizer_from_transducer
        # The transducer loader also takes a beam width (as max_active_paths);
        # honour the caller's value instead of silently using its default.
        loader_kwargs = {"num_active_paths": num_active_paths}
    else:
        raise NotImplementedError("loader not support: {}".format(loader))

    # hf_hub_download(local_dir=...) mirrors the repo layout, storing each file
    # at <local_dir>/<sub_folder>/<filename>; look for it in the same place.
    # A "." sub-folder collapses away, so root-level files are unaffected.
    model_files = {}
    for key in ("nn_model_file", "encoder_model_file", "decoder_model_file",
                "joiner_model_file", "tokens_file"):
        if key in kwargs:
            sub_folder = kwargs.get("{}_sub_folder".format(key), ".")
            model_files[key] = local_model_dir / sub_folder / kwargs[key]

    # Download when the directory is missing, and also when it exists but a
    # required file is absent (e.g. after an interrupted download).
    if not local_model_dir.exists() or not all(p.exists() for p in model_files.values()):
        download_model(
            local_model_dir=local_model_dir.as_posix(),
            **kwargs,
        )

    for key, path in model_files.items():
        loader_kwargs[key] = path.as_posix()
    if "normalize_samples" in kwargs:
        loader_kwargs["normalize_samples"] = kwargs["normalize_samples"]

    return load_fn(decoding_method=decoding_method, **loader_kwargs)
if __name__ == "__main__":
    # No command-line behaviour is defined; running this file directly is a no-op.
    ...