vineelpratap committed on
Commit
65d863f
1 Parent(s): 69b07b9

Update asr_lm_eng.py

Browse files
Files changed (1) hide show
  1. asr_lm_eng.py +48 -63
asr_lm_eng.py CHANGED
@@ -21,54 +21,56 @@ processor = AutoProcessor.from_pretrained(MODEL_ID)
21
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
22
 
23
 
24
- # lm_decoding_config = {}
25
- # lm_decoding_configfile = hf_hub_download(
26
- # repo_id="facebook/mms-cclms",
27
- # filename="decoding_config.json",
28
- # subfolder="mms-1b-all",
29
- # )
30
-
31
- # with open(lm_decoding_configfile) as f:
32
- # lm_decoding_config = json.loads(f.read())
33
-
34
- # # allow language model decoding for "eng"
35
-
36
- # decoding_config = lm_decoding_config["eng"]
37
-
38
- # lm_file = hf_hub_download(
39
- # repo_id="facebook/mms-cclms",
40
- # filename=decoding_config["lmfile"].rsplit("/", 1)[1],
41
- # subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
42
- # )
43
- # token_file = hf_hub_download(
44
- # repo_id="facebook/mms-cclms",
45
- # filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
46
- # subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
47
- # )
48
- # lexicon_file = None
49
- # if decoding_config["lexiconfile"] is not None:
50
- # lexicon_file = hf_hub_download(
51
- # repo_id="facebook/mms-cclms",
52
- # filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
53
- # subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
54
- # )
55
-
56
- # beam_search_decoder = ctc_decoder(
57
- # lexicon=lexicon_file,
58
- # tokens=token_file,
59
- # lm=lm_file,
60
- # nbest=1,
61
- # beam_size=500,
62
- # beam_size_token=50,
63
- # lm_weight=float(decoding_config["lmweight"]),
64
- # word_score=float(decoding_config["wordscore"]),
65
- # sil_score=float(decoding_config["silweight"]),
66
- # blank_token="<s>",
67
- # )
68
 
69
 
70
  def transcribe(audio_data=None, lang="eng (English)"):
71
 
 
 
72
  if not audio_data:
73
  return "<<ERROR: Empty Audio Input>>"
74
 
@@ -113,24 +115,7 @@ def transcribe(audio_data=None, lang="eng (English)"):
113
  with torch.no_grad():
114
  outputs = model(**inputs).logits
115
 
116
- if lang_code != "eng" or True:
117
- ids = torch.argmax(outputs, dim=-1)[0]
118
- transcription = processor.decode(ids)
119
- else:
120
- assert False
121
- # beam_search_result = beam_search_decoder(outputs.to("cpu"))
122
- # transcription = " ".join(beam_search_result[0][0].words).strip()
123
 
124
  return transcription
125
-
126
-
127
- ASR_EXAMPLES = [
128
- ["upload/english.mp3", "eng (English)"],
129
- # ["upload/tamil.mp3", "tam (Tamil)"],
130
- # ["upload/burmese.mp3", "mya (Burmese)"],
131
- ]
132
-
133
- ASR_NOTE = """
134
- The above demo doesn't use beam-search decoding using a language model.
135
- Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
136
- """
 
21
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
22
 
23
 
24
+ lm_decoding_config = {}
25
+ lm_decoding_configfile = hf_hub_download(
26
+ repo_id="facebook/mms-cclms",
27
+ filename="decoding_config.json",
28
+ subfolder="mms-1b-all",
29
+ )
30
+
31
+ with open(lm_decoding_configfile) as f:
32
+ lm_decoding_config = json.loads(f.read())
33
+
34
+ # allow language model decoding for "eng"
35
+
36
+ decoding_config = lm_decoding_config["eng"]
37
+
38
+ lm_file = hf_hub_download(
39
+ repo_id="facebook/mms-cclms",
40
+ filename=decoding_config["lmfile"].rsplit("/", 1)[1],
41
+ subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
42
+ )
43
+ token_file = hf_hub_download(
44
+ repo_id="facebook/mms-cclms",
45
+ filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
46
+ subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
47
+ )
48
+ lexicon_file = None
49
+ if decoding_config["lexiconfile"] is not None:
50
+ lexicon_file = hf_hub_download(
51
+ repo_id="facebook/mms-cclms",
52
+ filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
53
+ subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
54
+ )
55
+
56
+ beam_search_decoder = ctc_decoder(
57
+ lexicon=lexicon_file,
58
+ tokens=token_file,
59
+ lm=lm_file,
60
+ nbest=1,
61
+ beam_size=500,
62
+ beam_size_token=50,
63
+ lm_weight=float(decoding_config["lmweight"]),
64
+ word_score=float(decoding_config["wordscore"]),
65
+ sil_score=float(decoding_config["silweight"]),
66
+ blank_token="<s>",
67
+ )
68
 
69
 
70
  def transcribe(audio_data=None, lang="eng (English)"):
71
 
72
+ assert lang.startswith("eng")
73
+
74
  if not audio_data:
75
  return "<<ERROR: Empty Audio Input>>"
76
 
 
115
  with torch.no_grad():
116
  outputs = model(**inputs).logits
117
 
118
+ beam_search_result = beam_search_decoder(outputs.to("cpu"))
119
+ transcription = " ".join(beam_search_result[0][0].words).strip()
 
 
 
 
 
120
 
121
  return transcription