rristo committed on
Commit
86680a3
1 Parent(s): 6e4e7cd

more decoding methods

Browse files
Files changed (1) hide show
  1. err2020/conformer_ctc3_usage.ipynb +270 -37
err2020/conformer_ctc3_usage.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 3,
6
  "id": "b6b6ded1-0a58-43cb-9065-4f4fae02a01b",
7
  "metadata": {},
8
  "outputs": [],
@@ -50,18 +50,51 @@
50
  },
51
  {
52
  "cell_type": "code",
53
- "execution_count": 4,
54
  "id": "3d69d771-b421-417f-a6ff-e1d1c64ba934",
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
58
  "class Args:\n",
59
- " model_filename='conformer_ctc3/exp/jit_trace.pt'\n",
60
- " bpe_model_filename=\"data/lang_bpe_500/bpe.model\"\n",
61
- " method=\"ctc-decoding\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  " sample_rate=16000\n",
63
- " num_classes=500 #bpe model size\n",
64
- " frame_shift_ms=10\n",
65
  " dither=0\n",
66
  " snip_edges=False\n",
67
  " num_bins=80\n",
@@ -88,7 +121,7 @@
88
  },
89
  {
90
  "cell_type": "code",
91
- "execution_count": 5,
92
  "id": "48306369-fb68-4abe-be62-0806d00059f8",
93
  "metadata": {},
94
  "outputs": [],
@@ -163,7 +196,9 @@
163
  " \n",
164
  " def decode_(self, wave, fbank, model, device, method, bpe_model_filename, num_classes, \n",
165
  " min_active_states, max_active_states, subsampling_factor, use_double_scores, \n",
166
- " frame_shift_ms, search_beam, output_beam):\n",
 
 
167
  " \n",
168
  " wave = [wave.to(device)]\n",
169
  " logging.info(\"Decoding started\")\n",
@@ -223,15 +258,127 @@
223
  " logging.info(timestamps)\n",
224
  " token_ids = get_texts(best_path)\n",
225
  " return self.format_trs(hyps[0], timestamps[0])\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  " \n",
227
- " def transcribe_file(self, audio_filename):\n",
228
  " wave=self.read_sound_file_(audio_filename, expected_sample_rate=self.args.sample_rate)\n",
229
  " \n",
230
- " trs=self.decode_(wave, self.fbank, self.model, self.args.device, self.args.method, \n",
 
 
 
231
  " self.args.bpe_model_filename, self.args.num_classes,\n",
232
  " self.args.min_active_states, self.args.max_active_states, \n",
233
  " self.args.subsampling_factor, self.args.use_double_scores, \n",
234
- " self.args.frame_shift_ms, self.args.search_beam, self.args.output_beam)\n",
 
 
235
  " return trs"
236
  ]
237
  },
@@ -245,23 +392,10 @@
245
  },
246
  {
247
  "cell_type": "code",
248
- "execution_count": 6,
249
  "id": "50ab7c8e-39b6-4783-8342-e79e91d2417e",
250
  "metadata": {},
251
- "outputs": [
252
- {
253
- "name": "stderr",
254
- "output_type": "stream",
255
- "text": [
256
- "fatal: not a git repository (or any parent up to mount point /opt)\n",
257
- "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n",
258
- "fatal: not a git repository (or any parent up to mount point /opt)\n",
259
- "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n",
260
- "fatal: not a git repository (or any parent up to mount point /opt)\n",
261
- "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n"
262
- ]
263
- }
264
- ],
265
  "source": [
266
  "#create transcriber/decoder object\n",
267
  "#if you want to change parameters (for example model filename) you could create a dict (see class Args attribute names)\n",
@@ -272,7 +406,7 @@
272
  },
273
  {
274
  "cell_type": "code",
275
- "execution_count": 7,
276
  "id": "8020f371-7584-4f6c-990b-f2c023e24060",
277
  "metadata": {},
278
  "outputs": [
@@ -280,8 +414,8 @@
280
  "name": "stdout",
281
  "output_type": "stream",
282
  "text": [
283
- "CPU times: user 4.86 s, sys: 435 ms, total: 5.29 s\n",
284
- "Wall time: 4.45 s\n"
285
  ]
286
  },
287
  {
@@ -303,7 +437,7 @@
303
  " {'word': 'panna', 'start': 10.16, 'end': 10.4}]}"
304
  ]
305
  },
306
- "execution_count": 7,
307
  "metadata": {},
308
  "output_type": "execute_result"
309
  }
@@ -315,7 +449,7 @@
315
  },
316
  {
317
  "cell_type": "code",
318
- "execution_count": 10,
319
  "id": "4d2a480d-f0aa-4474-bfdb-ad298a629ce5",
320
  "metadata": {},
321
  "outputs": [
@@ -323,8 +457,8 @@
323
  "name": "stdout",
324
  "output_type": "stream",
325
  "text": [
326
- "CPU times: user 16.2 s, sys: 1.8 s, total: 18 s\n",
327
- "Wall time: 15.1 s\n"
328
  ]
329
  }
330
  ],
@@ -334,7 +468,7 @@
334
  },
335
  {
336
  "cell_type": "code",
337
- "execution_count": 11,
338
  "id": "d3827548-bca0-4409-95bc-9aa8ba377135",
339
  "metadata": {},
340
  "outputs": [
@@ -458,7 +592,7 @@
458
  " {'word': 'jah', 'start': 47.56, 'end': 47.68}]}"
459
  ]
460
  },
461
- "execution_count": 11,
462
  "metadata": {},
463
  "output_type": "execute_result"
464
  }
@@ -467,10 +601,109 @@
467
  "trs"
468
  ]
469
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  {
471
  "cell_type": "code",
472
  "execution_count": null,
473
- "id": "ea3b25b7-a1f9-4b21-911d-35159c5f3009",
474
  "metadata": {},
475
  "outputs": [],
476
  "source": []
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "b6b6ded1-0a58-43cb-9065-4f4fae02a01b",
7
  "metadata": {},
8
  "outputs": [],
 
50
  },
51
  {
52
  "cell_type": "code",
53
+ "execution_count": 2,
54
  "id": "3d69d771-b421-417f-a6ff-e1d1c64ba934",
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
58
  "class Args:\n",
59
+ " model_filename='conformer_ctc3/exp/jit_trace.pt' #Path to the torchscript model.\n",
60
+ " bpe_model_filename='data/lang_bpe_500/bpe.model' #Path to bpe.model.\n",
61
+ " #Used only when method is ctc-decoding.\n",
62
+ " method=\"ctc-decoding\" #decoding method\n",
63
+ " # ctc-decoding - Use CTC decoding. It uses a sentence\n",
64
+ " # piece model, i.e., lang_dir/bpe.model, to convert\n",
65
+ " # word pieces to words. It needs neither a lexicon\n",
66
+ " # nor an n-gram LM.\n",
67
+ " # (1) 1best - Use the best path as decoding output. Only\n",
68
+ " # the transformer encoder output is used for decoding.\n",
69
+ " # We call it HLG decoding.\n",
70
+ " # (2) nbest-rescoring. Extract n paths from the decoding lattice,\n",
71
+ " # rescore them with an LM, the path with\n",
72
+ " # the highest score is the decoding result.\n",
73
+ " # We call it HLG decoding + n-gram LM rescoring.\n",
74
+ " # (3) whole-lattice-rescoring - Use an LM to rescore the\n",
75
+ " # decoding lattice and then use 1best to decode the\n",
76
+ " # rescored lattice.\n",
77
+ " # We call it HLG decoding + n-gram LM rescoring.\n",
78
+ " HLG='data/lang_bpe_500/HLG.pt' #Path to HLG.pt.\n",
79
+ " #Used only when method is not ctc-decoding.\n",
80
+ " G='data/lm/G_4_gram.pt' #Used only when method is\n",
81
+ " #whole-lattice-rescoring or nbest-rescoring.\n",
82
+ " #It's usually a 4-gram LM.\n",
83
+ " words_file='data/lang_phone/words.txt' #Path to words.txt.\n",
84
+ " #Used only when method is not ctc-decoding.\n",
85
+ " num_paths=100 # Used only when method is nbest-rescoring.\n",
86
+ " #It specifies the size of n-best list.\n",
87
+ " ngram_lm_scale=0.1 #Used only when method is whole-lattice-rescoring or nbest-rescoring.\n",
88
+ " #It specifies the scale for n-gram LM scores.\n",
89
+ " #(Note: You need to tune it on a dataset.)\n",
90
+ " nbest_scale=0.5 #Used only when method is nbest-rescoring.\n",
91
+ " # It specifies the scale for lattice.scores when\n",
92
+ " # extracting n-best lists. A smaller value results in\n",
93
+ " # more unique number of paths with the risk of missing\n",
94
+ " # the best path.\n",
95
  " sample_rate=16000\n",
96
+ " num_classes=500 #Vocab size in the BPE model.\n",
97
+ " frame_shift_ms=10 #Frame shift in milliseconds between two contiguous frames.\n",
98
  " dither=0\n",
99
  " snip_edges=False\n",
100
  " num_bins=80\n",
 
121
  },
122
  {
123
  "cell_type": "code",
124
+ "execution_count": 25,
125
  "id": "48306369-fb68-4abe-be62-0806d00059f8",
126
  "metadata": {},
127
  "outputs": [],
 
196
  " \n",
197
  " def decode_(self, wave, fbank, model, device, method, bpe_model_filename, num_classes, \n",
198
  " min_active_states, max_active_states, subsampling_factor, use_double_scores, \n",
199
+ " frame_shift_ms, search_beam, output_beam, HLG=None, G=None, words_file=None,\n",
200
+ " num_paths=None, ngram_lm_scale=None, nbest_scale=None):\n",
201
+ " \n",
202
  " \n",
203
  " wave = [wave.to(device)]\n",
204
  " logging.info(\"Decoding started\")\n",
 
258
  " logging.info(timestamps)\n",
259
  " token_ids = get_texts(best_path)\n",
260
  " return self.format_trs(hyps[0], timestamps[0])\n",
261
+ " \n",
262
+ " elif method in [\n",
263
+ " \"1best\",\n",
264
+ " \"nbest-rescoring\",\n",
265
+ " \"whole-lattice-rescoring\",\n",
266
+ " ]:\n",
267
+ " logging.info(f\"Loading HLG from {HLG}\")\n",
268
+ " HLG = k2.Fsa.from_dict(torch.load(HLG, map_location=\"cpu\"))\n",
269
+ " HLG = HLG.to(device)\n",
270
+ " if not hasattr(HLG, \"lm_scores\"):\n",
271
+ " # For whole-lattice-rescoring and attention-decoder\n",
272
+ " HLG.lm_scores = HLG.scores.clone()\n",
273
+ "\n",
274
+ " if method in [\n",
275
+ " \"nbest-rescoring\",\n",
276
+ " \"whole-lattice-rescoring\",\n",
277
+ " ]:\n",
278
+ " logging.info(f\"Loading G from {G}\")\n",
279
+ " G = k2.Fsa.from_dict(torch.load(G, map_location=\"cpu\"))\n",
280
+ " G = G.to(device)\n",
281
+ " if method == \"whole-lattice-rescoring\":\n",
282
+ " # Add epsilon self-loops to G as we will compose\n",
283
+ " # it with the whole lattice later\n",
284
+ " G = k2.add_epsilon_self_loops(G)\n",
285
+ " G = k2.arc_sort(G)\n",
286
+ "\n",
287
+ " # G.lm_scores is used to replace HLG.lm_scores during\n",
288
+ " # LM rescoring.\n",
289
+ " G.lm_scores = G.scores.clone()\n",
290
+ " if method == \"nbest-rescoring\" or method == \"whole-lattice-rescoring\":\n",
291
+ " #adjusts the symbol table; otherwise returns empty text\n",
292
+ " #https://github.com/k2-fsa/k2/issues/874\n",
293
+ " def is_disambig_symbol(symbol: str, pattern: re.Pattern = re.compile(r'^#\\d+$')) -> bool:\n",
294
+ " return pattern.match(symbol) is not None\n",
295
+ "\n",
296
+ " def find_first_disambig_symbol(symbols: k2.SymbolTable) -> int:\n",
297
+ " return min(v for k, v in symbols._sym2id.items() if is_disambig_symbol(k))\n",
298
+ " symbol_table = k2.SymbolTable.from_file(words_file)\n",
299
+ " first_word_disambig_id = find_first_disambig_symbol(symbol_table)\n",
300
+ " print(\"disambig id:\", first_word_disambig_id)\n",
301
+ " G.labels[G.labels >= first_word_disambig_id] = 0\n",
302
+ " G.labels_sym = symbol_table\n",
303
+ "\n",
304
+ " #added part, transforms G from Fsa to FsaVec otherwise throws error\n",
305
+ " G = k2.create_fsa_vec([G])\n",
306
+ " #https://github.com/k2-fsa/k2/blob/master/k2/python/k2/utils.py\n",
307
+ " delattr(G, \"aux_labels\")\n",
308
+ " G = k2.arc_sort(G)\n",
309
+ "\n",
310
+ "\n",
311
+ " lattice = get_lattice(\n",
312
+ " nnet_output=nnet_output,\n",
313
+ " decoding_graph=HLG,\n",
314
+ " supervision_segments=supervision_segments,\n",
315
+ " search_beam=search_beam,\n",
316
+ " output_beam=output_beam,\n",
317
+ " min_active_states=min_active_states,\n",
318
+ " max_active_states=max_active_states,\n",
319
+ " subsampling_factor=subsampling_factor,\n",
320
+ " )\n",
321
+ "\n",
322
+ " ############\n",
323
+ " # scored_lattice = k2.top_sort(k2.connect(k2.intersect(lattice, G, treat_epsilons_specially=True)))\n",
324
+ " # scored_lattice[0].draw(\"after_intersection.svg\", title=\"after_intersection\")\n",
325
+ " # scores = scored_lattice.get_forward_scores(True, True)\n",
326
+ " # print(scores)\n",
327
+ " #########################\n",
328
+ " if method == \"1best\":\n",
329
+ " logging.info(\"Use HLG decoding\")\n",
330
+ " best_path = one_best_decoding(\n",
331
+ " lattice=lattice, use_double_scores=use_double_scores\n",
332
+ " )\n",
333
+ "\n",
334
+ " timestamps, hyps = parse_fsa_timestamps_and_texts(\n",
335
+ " best_paths=best_path,\n",
336
+ " word_table=word_table,\n",
337
+ " subsampling_factor=subsampling_factor,\n",
338
+ " frame_shift_ms=frame_shift_ms,\n",
339
+ " )\n",
340
+ "\n",
341
+ " if method == \"nbest-rescoring\":\n",
342
+ " logging.info(\"Use HLG decoding + LM rescoring\")\n",
343
+ " best_path_dict = rescore_with_n_best_list(\n",
344
+ " lattice=lattice,\n",
345
+ " G=G,\n",
346
+ " num_paths=num_paths,\n",
347
+ " lm_scale_list=[ngram_lm_scale],\n",
348
+ " nbest_scale=nbest_scale,\n",
349
+ " )\n",
350
+ " best_path = next(iter(best_path_dict.values()))\n",
351
+ " \n",
352
+ " elif method == \"whole-lattice-rescoring\":\n",
353
+ " logging.info(\"Use HLG decoding + LM rescoring\")\n",
354
+ " best_path_dict = rescore_with_whole_lattice(\n",
355
+ " lattice=lattice,\n",
356
+ " G_with_epsilon_loops=G,\n",
357
+ " lm_scale_list=[ngram_lm_scale],\n",
358
+ " )\n",
359
+ " best_path = next(iter(best_path_dict.values()))\n",
360
+ "\n",
361
+ " hyps = get_texts(best_path)\n",
362
+ " word_sym_table = k2.SymbolTable.from_file(words_file)\n",
363
+ " hyps = [[word_sym_table[i] for i in ids] for ids in hyps]\n",
364
+ " return hyps\n",
365
+ " else:\n",
366
+ " raise ValueError(f\"Unsupported decoding method: {method}\")\n",
367
+ "\n",
368
  " \n",
369
+ " def transcribe_file(self, audio_filename, method=None):\n",
370
  " wave=self.read_sound_file_(audio_filename, expected_sample_rate=self.args.sample_rate)\n",
371
  " \n",
372
+ " if method is None:\n",
373
+ " method=self.args.method\n",
374
+ " \n",
375
+ " trs=self.decode_(wave, self.fbank, self.model, self.args.device, method, \n",
376
  " self.args.bpe_model_filename, self.args.num_classes,\n",
377
  " self.args.min_active_states, self.args.max_active_states, \n",
378
  " self.args.subsampling_factor, self.args.use_double_scores, \n",
379
+ " self.args.frame_shift_ms, self.args.search_beam, self.args.output_beam,\n",
380
+ " self.args.HLG, self.args.G, self.args.words_file, self.args.num_paths,\n",
381
+ " self.args.ngram_lm_scale, self.args.nbest_scale)\n",
382
  " return trs"
383
  ]
384
  },
 
392
  },
393
  {
394
  "cell_type": "code",
395
+ "execution_count": 26,
396
  "id": "50ab7c8e-39b6-4783-8342-e79e91d2417e",
397
  "metadata": {},
398
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  "source": [
400
  "#create transcriber/decoder object\n",
401
  "#if you want to change parameters (for example model filename) you could create a dict (see class Args attribute names)\n",
 
406
  },
407
  {
408
  "cell_type": "code",
409
+ "execution_count": 9,
410
  "id": "8020f371-7584-4f6c-990b-f2c023e24060",
411
  "metadata": {},
412
  "outputs": [
 
414
  "name": "stdout",
415
  "output_type": "stream",
416
  "text": [
417
+ "CPU times: user 6.22 s, sys: 432 ms, total: 6.65 s\n",
418
+ "Wall time: 5.79 s\n"
419
  ]
420
  },
421
  {
 
437
  " {'word': 'panna', 'start': 10.16, 'end': 10.4}]}"
438
  ]
439
  },
440
+ "execution_count": 9,
441
  "metadata": {},
442
  "output_type": "execute_result"
443
  }
 
449
  },
450
  {
451
  "cell_type": "code",
452
+ "execution_count": 12,
453
  "id": "4d2a480d-f0aa-4474-bfdb-ad298a629ce5",
454
  "metadata": {},
455
  "outputs": [
 
457
  "name": "stdout",
458
  "output_type": "stream",
459
  "text": [
460
+ "CPU times: user 28.4 s, sys: 2.93 s, total: 31.3 s\n",
461
+ "Wall time: 27.3 s\n"
462
  ]
463
  }
464
  ],
 
468
  },
469
  {
470
  "cell_type": "code",
471
+ "execution_count": 13,
472
  "id": "d3827548-bca0-4409-95bc-9aa8ba377135",
473
  "metadata": {},
474
  "outputs": [
 
592
  " {'word': 'jah', 'start': 47.56, 'end': 47.68}]}"
593
  ]
594
  },
595
+ "execution_count": 13,
596
  "metadata": {},
597
  "output_type": "execute_result"
598
  }
 
601
  "trs"
602
  ]
603
  },
604
+ {
605
+ "cell_type": "markdown",
606
+ "id": "6740a04c-09e1-4497-84e2-5227acd9dda3",
607
+ "metadata": {},
608
+ "source": [
609
+ "## Some other decoding"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "markdown",
614
+ "id": "b012c0d7-04ab-451e-8414-85b4b9ac9165",
615
+ "metadata": {},
616
+ "source": [
617
+ "1best decoding is currently not working."
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": 27,
623
+ "id": "15fcf012-265a-4464-8da7-1c7e1a46556a",
624
+ "metadata": {},
625
+ "outputs": [
626
+ {
627
+ "name": "stdout",
628
+ "output_type": "stream",
629
+ "text": [
630
+ "disambig id: 157281\n",
631
+ "CPU times: user 3min 56s, sys: 7.52 s, total: 4min 3s\n",
632
+ "Wall time: 2min 22s\n"
633
+ ]
634
+ },
635
+ {
636
+ "data": {
637
+ "text/plain": [
638
+ "[['mina',\n",
639
+ " 'tahaksin',\n",
640
+ " 'homme',\n",
641
+ " 'täna',\n",
642
+ " 'ja',\n",
643
+ " 'homme',\n",
644
+ " 'kui',\n",
645
+ " 'saan',\n",
646
+ " 'kontsu',\n",
647
+ " 'madise',\n",
648
+ " 'vei',\n",
649
+ " 'panna']]"
650
+ ]
651
+ },
652
+ "execution_count": 27,
653
+ "metadata": {},
654
+ "output_type": "execute_result"
655
+ }
656
+ ],
657
+ "source": [
658
+ "%time transcriber.transcribe_file('audio/emt16k.wav', method='nbest-rescoring')"
659
+ ]
660
+ },
661
+ {
662
+ "cell_type": "code",
663
+ "execution_count": 28,
664
+ "id": "31591ee0-605c-4b20-b01f-cb8643fefdd1",
665
+ "metadata": {},
666
+ "outputs": [
667
+ {
668
+ "name": "stdout",
669
+ "output_type": "stream",
670
+ "text": [
671
+ "disambig id: 157281\n",
672
+ "CPU times: user 41.2 s, sys: 409 ms, total: 41.6 s\n",
673
+ "Wall time: 31.3 s\n"
674
+ ]
675
+ },
676
+ {
677
+ "data": {
678
+ "text/plain": [
679
+ "[['mina',\n",
680
+ " 'tahaksin',\n",
681
+ " 'homme',\n",
682
+ " 'täna',\n",
683
+ " 'ja',\n",
684
+ " 'homme',\n",
685
+ " 'kui',\n",
686
+ " 'saan',\n",
687
+ " 'all',\n",
688
+ " 'kontsu',\n",
689
+ " 'madise',\n",
690
+ " 'vei',\n",
691
+ " 'panna']]"
692
+ ]
693
+ },
694
+ "execution_count": 28,
695
+ "metadata": {},
696
+ "output_type": "execute_result"
697
+ }
698
+ ],
699
+ "source": [
700
+ "%time transcriber.transcribe_file('audio/emt16k.wav', method='whole-lattice-rescoring')"
701
+ ]
702
+ },
703
  {
704
  "cell_type": "code",
705
  "execution_count": null,
706
+ "id": "80dfe34d-a76b-4ddc-a47c-c481c5e1432f",
707
  "metadata": {},
708
  "outputs": [],
709
  "source": []