import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForCTC, AutoProcessor
from torch.cuda.amp import autocast

ASR_SAMPLING_RATE = 16_000

# Load model and processor only once
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to("cuda")

# Load supported languages from the TSV file
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso.strip()] = name.strip()


def transcribe(audio_data=None, lang="eng (English)"):
    if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
        return "<<ERROR: Empty audio input>>"

    if isinstance(audio_data, tuple):
        # Gradio-style input: (sample_rate, int16 samples); scale to [-1, 1]
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(np.float32)
        if audio_samples.ndim > 1:
            # Down-mix multi-channel recordings to mono
            audio_samples = audio_samples.mean(axis=1)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = torchaudio.functional.resample(
                torch.tensor(audio_samples), sr, ASR_SAMPLING_RATE
            ).numpy()
    elif isinstance(audio_data, np.ndarray):
        audio_samples = audio_data
    elif isinstance(audio_data, str):
        # File path: torchaudio.load returns a (channels, frames) tensor
        audio_samples, sr = torchaudio.load(audio_data)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = torchaudio.functional.resample(
                audio_samples, sr, ASR_SAMPLING_RATE
            )
        # Down-mix to mono and drop the channel dimension
        audio_samples = audio_samples.mean(dim=0).numpy()
    else:
        return f"<<ERROR: Invalid audio input of type {type(audio_data)}>>"

    # Extract language code (e.g., "eng" from "eng (English)")
    lang_code = lang.split()[0]

    # Validate that the language code is supported
    if lang_code not in ASR_LANGUAGES:
        return f"<<ERROR: Unsupported language code '{lang_code}'>>"

    try:
        # Set target language and load the matching adapter weights
        processor.tokenizer.set_target_lang(lang_code)
        model.load_adapter(lang_code)
    except Exception as e:
        return f"<<ERROR: Failed to load adapter for '{lang_code}': {e}>>"

    # Process audio and perform transcription
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    ).to("cuda")
    with torch.no_grad(), autocast():
        outputs = model(**inputs).logits
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids)
    return transcription
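

# --- Usage sketch (not part of the original script) ---
# A minimal check of the transcribe() helper covering two of its input forms.
# The file path "sample.wav", the 8 kHz silent buffer, and the language choices
# below are illustrative assumptions, not values taken from the original code.
if __name__ == "__main__":
    # Transcribe an audio file from disk (exercises the str branch)
    print(transcribe("sample.wav", lang="eng (English)"))

    # Transcribe Gradio-style microphone input: (sample_rate, int16 samples).
    # One second of silence, just to exercise the tuple + resampling path.
    sr = 8_000
    samples = np.zeros(sr, dtype=np.int16)
    print(transcribe((sr, samples), lang="fra (French)"))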