lewistape committed
Commit 6a7ab5f · verified · 1 Parent(s): 2b90fdf

Update asr.py

Files changed (1):
  1. asr.py +21 -10
asr.py CHANGED
@@ -11,6 +11,13 @@ MODEL_ID = "facebook/mms-1b-all"
 processor = AutoProcessor.from_pretrained(MODEL_ID)
 model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to("cuda")
 
+# Load supported languages from the TSV file
+ASR_LANGUAGES = {}
+with open("data/asr/all_langs.tsv") as f:
+    for line in f:
+        iso, name = line.split(" ", 1)
+        ASR_LANGUAGES[iso.strip()] = name.strip()
+
 def transcribe(audio_data=None, lang="eng (English)"):
     if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
         return "<<ERROR: Empty Audio Input>>"
@@ -32,10 +39,21 @@ def transcribe(audio_data=None, lang="eng (English)"):
     else:
         return f"<<ERROR: Invalid Audio Input Instance: {type(audio_data)}>>"
 
+    # Extract language code (e.g., "eng" from "eng (English)")
     lang_code = lang.split()[0]
-    processor.tokenizer.set_target_lang(lang_code)
-    model.load_adapter(lang_code)
 
+    # Validate if the language code is supported
+    if lang_code not in ASR_LANGUAGES:
+        return f"<<ERROR: Unsupported Language Code: {lang_code}>>"
+
+    try:
+        # Set target language and load adapter
+        processor.tokenizer.set_target_lang(lang_code)
+        model.load_adapter(lang_code)
+    except Exception as e:
+        return f"<<ERROR: Language Adaptation Failed: {str(e)}>>"
+
+    # Process audio and perform transcription
     inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt").to("cuda")
 
     with torch.no_grad(), autocast():
@@ -44,11 +62,4 @@ def transcribe(audio_data=None, lang="eng (English)"):
     ids = torch.argmax(outputs, dim=-1)[0]
     transcription = processor.decode(ids)
 
-    return transcription
-
-ASR_LANGUAGES = {
-    "eng": "English",
-    "spa": "Spanish",
-    "fra": "French",
-    # Add more languages as needed
-}
+    return transcription
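
For reference, a minimal sketch of how the TSV-driven language table and the new validation path behave, run against an in-memory sample rather than data/asr/all_langs.tsv (whose contents are not part of this diff). The two sample rows and the check_label helper are illustrative assumptions, not code from the commit.

# Sketch of the ASR_LANGUAGES loading and language-code check added in this commit.
# Sample rows and check_label() are assumptions for illustration; the real table
# is read from data/asr/all_langs.tsv at import time.
import io

sample = io.StringIO("eng English\nspa Spanish\n")

ASR_LANGUAGES = {}
for line in sample:
    # Same parsing as the commit: split on the first space, ISO code first, name after.
    iso, name = line.split(" ", 1)
    ASR_LANGUAGES[iso.strip()] = name.strip()

print(ASR_LANGUAGES)  # {'eng': 'English', 'spa': 'Spanish'}

def check_label(label):
    # Mirrors the new validation in transcribe(): "eng (English)" -> "eng".
    code = label.split()[0]
    return code if code in ASR_LANGUAGES else None

print(check_label("eng (English)"))  # 'eng'
print(check_label("zzz (Unknown)"))  # None; transcribe() would return
                                     # "<<ERROR: Unsupported Language Code: zzz>>"

Note that the loader splits each row on the first space; if the rows of all_langs.tsv are tab-delimited, that split would not separate code from name, so the delimiter is worth double-checking against the file.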