"""Batch speech recognition with faster-whisper, plus a character-level error metric.

The module exposes:

* ``load_model`` - build the ``WhisperModel`` used for transcription.
* ``batch_asr`` / ``batch_asr_internal`` - transcribe a list of waveforms.
* ``calculate_wer`` / ``remove_punctuation`` - compare two transcripts.
"""

import os

# Kept ahead of the remaining imports, exactly as in the original file
# (presumably so libraries that read MODELSCOPE_CACHE at import time see it).
os.environ["MODELSCOPE_CACHE"] = ".cache/"

import string
import time
from threading import Lock

import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel

# Traditional -> Simplified Chinese converter applied to every transcript.
t2s_converter = opencc.OpenCC("t2s")

# Sample rate every waveform is resampled to before transcription.
_TARGET_SAMPLE_RATE = 16000

# A silence longer than this (seconds) between two consecutive segments marks
# the transcript as having a "huge gap".
_HUGE_GAP_SECONDS = 3.0


def load_model(*, device="cuda", compute_type=None):
    """Create the faster-whisper ``medium`` model.

    Args:
        device: Device string forwarded to ``WhisperModel``.
        compute_type: Numeric type forwarded to ``WhisperModel``.  When left
            as ``None`` it is ``"float16"`` (the original hard-coded value)
            for every device except ``"cpu"``, where ``"int8"`` is used
            because CTranslate2 refuses float16 computation on CPU.

    Returns:
        The loaded ``WhisperModel``.
    """
    if compute_type is None:
        compute_type = "int8" if device == "cpu" else "float16"

    model = WhisperModel(
        "medium",
        device=device,
        compute_type=compute_type,
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return model


def _to_mono_array(audio):
    """Return *audio* (NumPy array or torch tensor) as a 1-D NumPy array."""
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio).float()
    if audio.dim() > 1:
        # Drops singleton axes such as a leading channel dimension.
        audio = audio.squeeze()
    assert audio.dim() == 1
    return audio.numpy()


def _join_segments(segments):
    """Concatenate segment texts and report whether a huge gap was seen.

    Concatenation stops right after the first segment that follows a pause
    longer than ``_HUGE_GAP_SECONDS``; that segment's text is still included.

    Returns:
        ``(text, huge_gap)``; ``text`` is ``""`` when there are no segments.
    """
    pieces = []
    max_gap = 0.0
    previous = None
    huge_gap = False
    for segment in segments:
        # "No previous segment" identifies the first one, rather than relying
        # on segment ids being 1-based.
        if previous is not None:
            max_gap = max(segment.start - previous.end, max_gap)
        pieces.append(segment.text.strip())
        previous = segment
        if max_gap > _HUGE_GAP_SECONDS:
            huge_gap = True
            break
    return "".join(pieces), huge_gap


@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe each waveform in *audios*.

    Args:
        model: A loaded ``WhisperModel``.
        audios: Iterable of waveforms, each a NumPy array or a torch tensor
            that is 1-D after singleton axes are squeezed away.
        sr: Sample rate shared by all waveforms in *audios*.

    Returns:
        One dict per waveform with keys ``"text"`` (transcript converted to
        Simplified Chinese), ``"duration"`` (milliseconds, computed from the
        squeezed waveform length and *sr*) and ``"huge_gap"`` (``True`` when
        two consecutive segments are more than 3 seconds apart).
    """
    mono_audios = [_to_mono_array(audio) for audio in audios]
    resampled_audios = [
        librosa.resample(mono, orig_sr=sr, target_sr=_TARGET_SAMPLE_RATE)
        for mono in mono_audios
    ]

    trans_results = []
    for resampled_audio in resampled_audios:
        segments, _ = model.transcribe(
            resampled_audio,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        # ``segments`` is consumed into a list so it can be iterated later.
        trans_results.append(list(segments))

    results = []
    for segments, mono in zip(trans_results, mono_audios):
        text, huge_gap = _join_segments(segments)
        results.append(
            {
                "text": t2s_converter.convert(text),
                # Use the squeezed waveform: len() of a (1, N) input would be 1.
                "duration": len(mono) / sr * 1000,
                "huge_gap": huge_gap,
            }
        )
    return results


# NOTE(review): this lock is created but never acquired anywhere in this file.
global_lock = Lock()


def batch_asr(model, audios, sr):
    """Public entry point; forwards unchanged to ``batch_asr_internal``."""
    return batch_asr_internal(model, audios, sr)


def is_chinese(text):
    """Placeholder check: currently returns ``True`` for every input."""
    return True


def _edit_distance(a, b):
    """Return the Levenshtein distance between sequences *a* and *b*."""
    # Keep the DP row as short as possible by iterating over the longer input.
    if len(a) > len(b):
        a, b = b, a

    prev = list(range(len(a) + 1))  # row 0 distance: [0, 1, 2, ...]
    for j, item_b in enumerate(b, 1):
        curr = [j]
        for i, item_a in enumerate(a, 1):
            if item_a == item_b:
                curr.append(prev[i - 1])
            else:
                curr.append(min(prev[i], curr[i - 1], prev[i - 1]) + 1)
        prev = curr
    return prev[-1]


def calculate_wer(text1, text2, debug=False):
    """Return the character-level error rate between two transcripts.

    Both texts are stripped of punctuation and whitespace first.  The result
    is the edit distance divided by the length of the longer stripped text,
    so it lies in ``[0, 1]``.  Two texts that are both empty after stripping
    are identical and yield ``0.0``.

    Args:
        text1: First transcript (printed as ``gt`` in debug output).
        text2: Second transcript (printed as ``pred`` in debug output).
        debug: When true, print the stripped texts and the computed ratio.
    """
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)

    tot = max(len(chars1), len(chars2))
    if tot == 0:
        edits, wer = 0, 0.0
    else:
        edits = _edit_distance(chars1, chars2)
        wer = edits / tot

    if debug:
        print("            gt:   ", chars1)
        print("          pred:   ", chars2)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
    return wer


# Fullwidth counterparts (U+FF01-U+FF5E) of the ASCII punctuation marks.
# NOTE(review): the punctuation literal in the file as received held the ASCII
# forms, which looks like the result of compatibility normalisation; generating
# the fullwidth forms here keeps them safe from that - confirm against upstream.
_FULLWIDTH_PUNCTUATION = "".join(chr(ord(ch) + 0xFEE0) for ch in string.punctuation)

_EXTRA_PUNCTUATION = (
    " \n\t"
    "\u201d\u201c"  # right / left double quotation marks
    "。、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿"
    "\u2985\u2986"  # white parentheses
    "\uff5f\uff60\uff61\uff62\uff63\uff64"  # fullwidth white parens, halfwidth CJK marks
    "\u2013\u2014"  # en dash, em dash
    "\u201b\u201e\u201f"  # remaining quotation-mark variants
    "\u2026\u2027\ufe4f"  # ellipsis, hyphenation point, wavy low line
)

# Built once: maps every punctuation / whitespace character above to "deleted".
_PUNCTUATION_TABLE = str.maketrans(
    "", "", string.punctuation + _FULLWIDTH_PUNCTUATION + _EXTRA_PUNCTUATION
)


def remove_punctuation(text):
    """Return *text* without ASCII, fullwidth and CJK punctuation or whitespace."""
    return text.translate(_PUNCTUATION_TABLE)


if __name__ == "__main__":
    # Smoke test / timing run against two local 44.1 kHz recordings.
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))

    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    print("Time taken:", time.time() - start_time)