Update asr.py
Browse files
asr.py
CHANGED
@@ -11,6 +11,13 @@ MODEL_ID = "facebook/mms-1b-all"
|
|
11 |
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
12 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to("cuda")
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def transcribe(audio_data=None, lang="eng (English)"):
|
15 |
if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
|
16 |
return "<<ERROR: Empty Audio Input>>"
|
@@ -32,10 +39,21 @@ def transcribe(audio_data=None, lang="eng (English)"):
|
|
32 |
else:
|
33 |
return f"<<ERROR: Invalid Audio Input Instance: {type(audio_data)}>>"
|
34 |
|
|
|
35 |
lang_code = lang.split()[0]
|
36 |
-
processor.tokenizer.set_target_lang(lang_code)
|
37 |
-
model.load_adapter(lang_code)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt").to("cuda")
|
40 |
|
41 |
with torch.no_grad(), autocast():
|
@@ -44,11 +62,4 @@ def transcribe(audio_data=None, lang="eng (English)"):
|
|
44 |
ids = torch.argmax(outputs, dim=-1)[0]
|
45 |
transcription = processor.decode(ids)
|
46 |
|
47 |
-
return transcription
|
48 |
-
|
49 |
-
ASR_LANGUAGES = {
|
50 |
-
"eng": "English",
|
51 |
-
"spa": "Spanish",
|
52 |
-
"fra": "French",
|
53 |
-
# Add more languages as needed
|
54 |
-
}
|
|
|
# Load the MMS processor and CTC model once at module import time; the
# per-language adapter is swapped in later (set_target_lang / load_adapter).
processor = AutoProcessor.from_pretrained(MODEL_ID)
# NOTE(review): assumes a CUDA device is available — confirm deployment target.
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to("cuda")

# Load supported languages from the bundled table.
# Each non-empty line maps an ISO 639-3 code to a human-readable name,
# e.g. "eng English". The mapping is used to validate the `lang` argument
# of transcribe() before attempting to load a language adapter.
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for raw_line in f:
        line = raw_line.strip()
        if not line:
            # A trailing newline at EOF (or any blank line) would otherwise
            # crash the split/unpack below with a ValueError.
            continue
        # Despite the .tsv extension, the upstream file may be space-separated;
        # prefer a tab when one is present, otherwise split on the first space.
        sep = "\t" if "\t" in line else " "
        # partition() never raises: a malformed single-token line yields an
        # empty name instead of aborting the whole module import.
        iso, _, name = line.partition(sep)
        ASR_LANGUAGES[iso.strip()] = name.strip()
21 |
def transcribe(audio_data=None, lang="eng (English)"):
|
22 |
if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
|
23 |
return "<<ERROR: Empty Audio Input>>"
|
|
|
39 |
else:
|
40 |
return f"<<ERROR: Invalid Audio Input Instance: {type(audio_data)}>>"
|
41 |
|
42 |
+
# Extract language code (e.g., "eng" from "eng (English)")
|
43 |
lang_code = lang.split()[0]
|
|
|
|
|
44 |
|
45 |
+
# Validate if the language code is supported
|
46 |
+
if lang_code not in ASR_LANGUAGES:
|
47 |
+
return f"<<ERROR: Unsupported Language Code: {lang_code}>>"
|
48 |
+
|
49 |
+
try:
|
50 |
+
# Set target language and load adapter
|
51 |
+
processor.tokenizer.set_target_lang(lang_code)
|
52 |
+
model.load_adapter(lang_code)
|
53 |
+
except Exception as e:
|
54 |
+
return f"<<ERROR: Language Adaptation Failed: {str(e)}>>"
|
55 |
+
|
56 |
+
# Process audio and perform transcription
|
57 |
inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt").to("cuda")
|
58 |
|
59 |
with torch.no_grad(), autocast():
|
|
|
62 |
ids = torch.argmax(outputs, dim=-1)[0]
|
63 |
transcription = processor.decode(ids)
|
64 |
|
65 |
+
return transcription
|
|
|
|
|
|
|
|
|
|
|
|
|
|