lewistape committed
Commit 8fc270e · verified · 1 Parent(s): d9961d8

Update asr.py

Files changed (1): asr.py +74 -105
asr.py CHANGED
@@ -3,84 +3,115 @@ from transformers import Wav2Vec2ForCTC, AutoProcessor
  import torch
  import numpy as np
  from pathlib import Path
- import concurrent.futures
  from torch.cuda.amp import autocast
  from huggingface_hub import hf_hub_download
  from pyctcdecode import build_ctcdecoder
  import json
- import resampy  # Import resampy for faster resampling
+ import resampy  # For efficient resampling

  ASR_SAMPLING_RATE = 16_000
- CHUNK_LENGTH_S = 60  # Adjust based on your testing
- MAX_CONCURRENT_CHUNKS = 4  # Adjust based on VRAM (monitor with nvidia-smi)
- BATCH_SIZE = 4  # Batch size for processing chunks within process_chunk
+ CHUNK_LENGTH_S = 60  # Adjust chunk length in seconds
+ BATCH_SIZE = 4  # Batch size for processing chunks

  ASR_LANGUAGES = {}
- with open(f"data/asr/all_langs.tsv", "r") as f:
+
+ # Load available ASR languages
+ with open("data/asr/all_langs.tsv", "r") as f:
      for line in f:
          iso, name = line.split(" ", 1)
          ASR_LANGUAGES[iso.strip()] = name.strip()

  MODEL_ID = "facebook/mms-1b-all"
-
  processor = AutoProcessor.from_pretrained(MODEL_ID)
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

- # Optimize model for inference
- model.eval()  # Ensure the model is in evaluation mode
+ # Ensure the model is in evaluation mode for inference
+ model.eval()

- # Dictionary to store loaded adapters
- loaded_adapters = {}
+ loaded_adapters = {}  # Store loaded adapters
+ cached_decoders = {}  # Cache language model decoders for each language

- # Dictionary to cache language model decoders for each language
- cached_decoders = {}

  def load_audio(audio_data):
      if isinstance(audio_data, tuple):
          sr, audio_samples = audio_data
          audio_samples = (audio_samples / 32768.0).astype(np.float32)
          if sr != ASR_SAMPLING_RATE:
-             audio_samples = resampy.resample(audio_samples, sr, ASR_SAMPLING_RATE)  # Use resampy
+             audio_samples = resampy.resample(audio_samples, sr, ASR_SAMPLING_RATE)
      elif isinstance(audio_data, np.ndarray):
          audio_samples = audio_data
      elif isinstance(audio_data, str):
          audio_samples, sr = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)
          if sr != ASR_SAMPLING_RATE:
-             audio_samples = resampy.resample(audio_samples, sr, ASR_SAMPLING_RATE)  # Use resampy
+             audio_samples = resampy.resample(audio_samples, sr, ASR_SAMPLING_RATE)
      else:
-         raise ValueError(f"Invalid Audio Input Instance: {type(audio_data)}")
+         raise ValueError(f"Invalid Audio Input: {type(audio_data)}")
      return audio_samples

+
  def process_chunk(chunks, device, decoder=None):
-     batch_size = BATCH_SIZE  # Local batch size
      transcriptions = []
+     max_length = CHUNK_LENGTH_S * ASR_SAMPLING_RATE  # Maximum input length for truncation
+
-     for i in range(0, len(chunks), batch_size):
-         batch = chunks[i : i + batch_size]
+     for i in range(0, len(chunks), BATCH_SIZE):
+         batch = chunks[i:i + BATCH_SIZE]
+         batch = [chunk[:max_length] for chunk in batch]  # Truncate each chunk to max_length
+
          inputs = processor(
              batch,
              sampling_rate=ASR_SAMPLING_RATE,
              return_tensors="pt",
              padding=True,
              truncation=True,
-         ).to(
-             device
-         )  # Enable padding
+             max_length=max_length,
+         ).to(device)
+
          with torch.no_grad():
              with autocast():
                  outputs = model(**inputs).logits

          if decoder:
-             # Batch decoding with LM (if pyctcdecode supports it)
-             texts = decoder.decode_batch(outputs.cpu().numpy())  # Check for batch support
+             texts = decoder.decode_batch(outputs.cpu().numpy())
              transcriptions.extend(texts)
          else:
              ids = torch.argmax(outputs, dim=-1)
-             for id_tensor in ids:
-                 transcriptions.append(processor.decode(id_tensor))
+             transcriptions.extend(processor.batch_decode(ids))

      return " ".join(transcriptions)

+
+ def load_decoder_for_language(lang_code):
+     lm_decoding_configfile = hf_hub_download(
+         repo_id="facebook/mms-cclms",
+         filename="decoding_config.json",
+         subfolder="mms-1b-all",
+     )
+
+     with open(lm_decoding_configfile) as f:
+         lm_decoding_config = json.load(f)
+
+     if lang_code in lm_decoding_config:
+         decoding_config = lm_decoding_config[lang_code]
+
+         lm_file = hf_hub_download(
+             repo_id="facebook/mms-cclms",
+             filename=decoding_config["lmfile"].rsplit("/", 1)[1],
+             subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
+         )
+         vocab_dict = processor.tokenizer.get_vocab()
+         vocab_list = [key for key, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]
+         vocab_list[vocab_list.index("<s>")] = "<s>"
+         vocab_list[vocab_list.index("</s>")] = "</s>"
+
+         return build_ctcdecoder(
+             vocab_list,
+             kenlm_model_path=lm_file,
+             alpha=float(decoding_config["alpha"]),
+             beta=float(decoding_config["beta"]),
+         )
+     else:
+         raise ValueError(f"No LM configuration found for language code: {lang_code}")
+
+
  def transcribe(audio_data=None, lang="eng (English)", use_lm_decoder=False):
      if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
          return "<<ERROR: Empty Audio Input>>"
@@ -92,97 +123,35 @@ def transcribe(audio_data=None, lang="eng (English)", use_lm_decoder=False):

      lang_code = lang.split()[0]

-     # Load adapter efficiently
      if lang_code not in loaded_adapters:
          processor.tokenizer.set_target_lang(lang_code)
          model.load_adapter(lang_code)
-         loaded_adapters[lang_code] = True  # Mark as loaded
+         loaded_adapters[lang_code] = True

      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      model.to(device)

-     # Create chunks
      chunk_length = int(CHUNK_LENGTH_S * ASR_SAMPLING_RATE)
      chunks = [
-         audio_samples[i : i + chunk_length]
+         audio_samples[i:i + chunk_length]
          for i in range(0, len(audio_samples), chunk_length)
      ]

-     # Use cached language model decoder if available
-     if use_lm_decoder and lang_code in cached_decoders:
-         decoder = cached_decoders[lang_code]
-     else:
-         decoder = None
-         if use_lm_decoder:
-             lm_decoding_config = {}
-             lm_decoding_configfile = hf_hub_download(
-                 repo_id="facebook/mms-cclms",
-                 filename="decoding_config.json",
-                 subfolder="mms-1b-all",
-             )
-
-             with open(lm_decoding_configfile) as f:
-                 lm_decoding_config = json.loads(f.read())
-
-             if lang_code in lm_decoding_config:
-                 decoding_config = lm_decoding_config[lang_code]
-
-                 lm_file = hf_hub_download(
-                     repo_id="facebook/mms-cclms",
-                     filename=decoding_config["lmfile"].rsplit("/", 1)[1],
-                     subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
-                 )
-                 token_file = hf_hub_download(
-                     repo_id="facebook/mms-cclms",
-                     filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
-                     subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
-                 )
-                 lexicon_file = None
-                 if decoding_config["lexiconfile"] is not None:
-                     lexicon_file = hf_hub_download(
-                         repo_id="facebook/mms-cclms",
-                         filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
-                         subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
-                     )
-
-                 vocab_dict = processor.tokenizer.get_vocab()
-                 sort_vocab = sorted((value, key) for (key, value) in vocab_dict.items())
-                 vocab = [x[1] for x in sort_vocab]
-                 vocab_list = vocab
-                 # Update special tokens
-                 vocab_list[vocab_list.index("<s>")] = "<s>"
-                 vocab_list[vocab_list.index("</s>")] = "</s>"
-                 vocab_list[vocab_list.index("<pad>")] = "<pad>"
-
-                 decoder = build_ctcdecoder(
-                     vocab_list,
-                     kenlm_model_path=lm_file,  # either .arpa or .bin file
-                     alpha=float(decoding_config["alpha"]),
-                     beta=float(decoding_config["beta"]),
-                 )
-
-                 # Cache the decoder for this language
-                 cached_decoders[lang_code] = decoder
-
-     # Process chunks with the selected batch size
-     transcription = process_chunk(chunks, device, decoder)
-
-     return transcription
+     decoder = cached_decoders.get(lang_code) if use_lm_decoder else None
+     if use_lm_decoder and lang_code not in cached_decoders:
+         try:
+             decoder = load_decoder_for_language(lang_code)
+             cached_decoders[lang_code] = decoder
+         except Exception as e:
+             print(f"<<WARNING: Could not load LM decoder for {lang_code}: {str(e)}>>")
+
+     return process_chunk(chunks, device, decoder)

- # Example usage (Make sure the file paths are correct)
- ASR_EXAMPLES = [
-     ["upload/english.mp3", "eng (English)"],  # Update with your file paths
-     # ["upload/tamil.mp3", "tam (Tamil)"],
-     # ["upload/burmese.mp3", "mya (Burmese)"],
- ]

- # Example to transcribe with LM decoding (for supported languages like English)
- # result_with_lm = transcribe("upload/english.mp3", "eng (English)", use_lm_decoder=True)
- # print(f"Transcription with LM decoding: {result_with_lm}")

- # Example to transcribe without LM decoding
- # result_without_lm = transcribe("upload/english.mp3", "eng (English)", use_lm_decoder=False)
- # print(f"Transcription without LM decoding: {result_without_lm}")
+ # Example usage
+ ASR_EXAMPLES = [
+     ["upload/english.mp3", "eng (English)"],
+ ]

  for audio_path, lang in ASR_EXAMPLES:
      try:
@@ -198,4 +167,4 @@ for audio_path, lang in ASR_EXAMPLES:
          else:
              print(f"Error: File not found: {audio_path}")
      except Exception as e:
-         print(f"An error occurred while processing {audio_path}: {e}")
+         print(f"An error occurred while processing {audio_path}: {e}")