In [4]:
import argparse
import logging
import math
import re
from typing import List
import sys
sys.path.append('/opt/notebooks/err2020/conformer_ctc3/')
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params

from icefall.decode import (
 get_lattice,
 one_best_decoding,
 rescore_with_n_best_list,
 rescore_with_whole_lattice
)
from icefall.utils import get_texts, parse_fsa_timestamps_and_texts

## Helpers

#### Load args helpers

In [5]:
class Args:
    """Default settings for ConformerCtc3Decoder.

    Defaults are class attributes; ``args_from_dict`` overrides them on the
    instance. ``ConformerCtc3Decoder.transcribe_file`` additionally reads
    min_active_states, max_active_states, subsampling_factor, use_double_scores,
    search_beam and output_beam, which have no default here and must be supplied
    through ``args_from_dict`` (e.g. from ``get_params() | get_decoding_params()``).
    """

    model_filename = 'conformer_ctc3/exp/jit_trace.pt'  # Path to the torchscript model.
    # Path to bpe.model. Used only when method is ctc-decoding.
    bpe_model_filename = 'data/lang_bpe_500/bpe.model'
    # Decoding method. Supported values:
    #   ctc-decoding - Use CTC decoding. It uses a sentence piece model,
    #     i.e., lang_dir/bpe.model, to convert word pieces to words.
    #     It needs neither a lexicon nor an n-gram LM.
    #   1best - Use the best path of the HLG decoding lattice as output.
    #     Only the transformer encoder output is used for decoding.
    #     We call it HLG decoding.
    #   nbest-rescoring - Extract n paths from the decoding lattice and rescore
    #     them with an LM; the path with the highest score is the result.
    #     We call it HLG decoding + n-gram LM rescoring.
    #   whole-lattice-rescoring - Use an LM to rescore the decoding lattice and
    #     then use 1best to decode the rescored lattice.
    #     We call it HLG decoding + n-gram LM rescoring.
    method = "ctc-decoding"
    # Path to HLG.pt. Used only when method is not ctc-decoding.
    HLG = 'data/lang_bpe_500/HLG.pt'
    # Path to the n-gram LM (usually a 4-gram). Used only when method is
    # whole-lattice-rescoring or nbest-rescoring.
    G = 'data/lm/G_4_gram.pt'
    # Path to words.txt. Used only when method is not ctc-decoding.
    words_file = 'data/lang_phone/words.txt'
    # Size of the n-best list. Used only when method is nbest-rescoring.
    num_paths = 100
    # Scale for n-gram LM scores. Used only when method is whole-lattice-rescoring
    # or nbest-rescoring. (Note: you need to tune it on a dataset.)
    ngram_lm_scale = 0.1
    # Scale for lattice.scores when extracting n-best lists. A smaller value
    # results in more unique paths, at the risk of missing the best path.
    # Used only when method is nbest-rescoring.
    nbest_scale = 0.5
    sample_rate = 16000  # Expected sample rate of the input audio files.
    num_classes = 500  # Vocab size in the BPE model.
    frame_shift_ms = 10  # Frame shift in milliseconds between two contiguous frames.
    dither = 0  # Fbank dither setting.
    snip_edges = False  # Fbank snip_edges setting.
    num_bins = 80  # Number of mel bins of the fbank features.
    device = 'cpu'  # Torch device for the model and the fbank extractor.

    def args_from_dict(self, dct):
        """Override settings on this instance with the items of ``dct``."""
        for key, value in dct.items():
            setattr(self, key, value)

    def __repr__(self):
        """One ``name = value`` line per setting: class defaults plus instance overrides."""
        # Class-level defaults live on the class, not in self.__dict__, so collect
        # them first (base classes first, declaration order), skipping dunders and
        # methods; instance overrides then replace them in place.
        settings = {}
        for klass in reversed(type(self).__mro__):
            for key, value in vars(klass).items():
                if key.startswith('__'):
                    continue
                if callable(value) or isinstance(value, (staticmethod, classmethod, property)):
                    continue
                settings[key] = value
        settings.update(vars(self))
        return ''.join(f'{k} = {v}\n' for k, v in settings.items())

#### Decoder helper

In [6]:
class ConformerCtc3Decoder:
    """Transcribe audio files with a TorchScript icefall conformer_ctc3 model.

    Configuration is kept in ``self.args`` (an ``Args`` instance whose defaults can
    be overridden with a dict). Supported decoding methods: ``ctc-decoding``,
    ``1best``, ``nbest-rescoring`` and ``whole-lattice-rescoring``.
    """

    _SUPPORTED_METHODS = (
        "ctc-decoding",
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    )

    # Word-level disambiguation symbols in words.txt look like "#0", "#1", ...
    _DISAMBIG_PATTERN = re.compile(r'^#\d+$')

    def __init__(self, params_dct=None):
        """Load the settings, the TorchScript model and the fbank extractor.

        Args:
          params_dct:
            Optional mapping whose items override the ``Args`` defaults,
            e.g. ``get_params() | get_decoding_params()``.
        """
        logging.info('loading args')
        self.args = Args()
        if params_dct is not None:
            self.args.args_from_dict(params_dct)
        logging.info('loading model')
        self.load_model()
        logging.info('loading fbank')
        self.get_fbank()

    def update_args(self, dct):
        """Override entries of ``self.args`` with the items of ``dct``.

        Only the settings change; the loaded model and fbank extractor are not
        rebuilt. Call ``load_model()`` / ``get_fbank()`` if the change affects them.
        """
        self.args.args_from_dict(dct)

    def load_model_(self, model_filename, device):
        """Load a TorchScript model onto ``device`` and store it, in eval mode, as ``self.model``."""
        # Honour the requested device instead of always using the CPU.
        device = torch.device(device)
        # map_location lets a model saved from another device be loaded here.
        model = torch.jit.load(model_filename, map_location=device)
        model.to(device)
        self.model = model.eval()

    def load_model(self, model_filename=None, device=None):
        """(Re)load the model, optionally updating ``args.model_filename`` / ``args.device`` first."""
        if model_filename is not None:
            self.args.model_filename = model_filename
        if device is not None:
            self.args.device = device
        self.load_model_(self.args.model_filename, self.args.device)

    def get_fbank_(self, device='cpu'):
        """Create a kaldifeat Fbank extractor configured from ``self.args``.

        Uses ``args.dither``, ``args.snip_edges``, ``args.sample_rate`` and ``args.num_bins``.
        """
        opts = kaldifeat.FbankOptions()
        opts.device = device
        opts.frame_opts.dither = self.args.dither
        opts.frame_opts.snip_edges = self.args.snip_edges
        # Keep the extractor's sampling rate in sync with the expected audio rate.
        opts.frame_opts.samp_freq = self.args.sample_rate
        opts.mel_opts.num_bins = self.args.num_bins
        return kaldifeat.Fbank(opts)

    def get_fbank(self):
        """Build the fbank extractor for ``args.device`` and store it as ``self.fbank``."""
        self.fbank = self.get_fbank_(self.args.device)

    def read_sound_file_(self, filename: str, expected_sample_rate: float) -> torch.Tensor:
        """Read a sound file into a 1-D float32 torch tensor.

        Args:
          filename:
            Path to the sound file.
          expected_sample_rate:
            The sample rate the file must have.
        Returns:
          The samples of the first channel as a 1-D tensor.
        Raises:
          AssertionError: if the file's sample rate differs from ``expected_sample_rate``.
        """
        wave, sample_rate = torchaudio.load(filename)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    def format_trs(self, hyp, timestamps):
        """Combine a word hypothesis with per-word ``(start, end)`` times.

        Returns:
          ``{'text': <words joined by spaces>, 'words': [{'word', 'start', 'end'}, ...]}``,
          or None (after logging a warning) if ``hyp`` and ``timestamps`` differ in length.
        """
        if len(hyp) != len(timestamps):
            logging.warning(
                f'len of hyp and timestamps is not the same len hyp {len(hyp)} '
                f'and len of timestamps {len(timestamps)}'
            )
            return None
        return {
            'text': ' '.join(hyp),
            'words': [
                {'word': word, 'start': span[0], 'end': span[1]}
                for word, span in zip(hyp, timestamps)
            ],
        }

    @staticmethod
    def _run_acoustic_model(wave, fbank, model, device, subsampling_factor):
        """Extract fbank features for one waveform and run the model on them.

        Returns:
          ``(nnet_output, supervision_segments)`` in the form passed to ``get_lattice``.
        """
        features = fbank([wave.to(device)])
        lengths = [f.size(0) for f in features]
        features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
        feature_lengths = torch.tensor(lengths, device=device)

        nnet_output, _ = model(features, feature_lengths)

        batch_size = nnet_output.shape[0]
        supervision_segments = torch.tensor(
            [[i, 0, lengths[i] // subsampling_factor] for i in range(batch_size)],
            dtype=torch.int32,
        )
        return nnet_output, supervision_segments

    @staticmethod
    def _load_hlg(hlg_filename, device):
        """Load HLG.pt as a k2.Fsa on ``device``, making sure it has ``lm_scores``."""
        logging.info(f"Loading HLG from {hlg_filename}")
        hlg_graph = k2.Fsa.from_dict(torch.load(hlg_filename, map_location="cpu"))
        hlg_graph = hlg_graph.to(device)
        if not hasattr(hlg_graph, "lm_scores"):
            # For whole-lattice-rescoring and attention-decoder
            hlg_graph.lm_scores = hlg_graph.scores.clone()
        return hlg_graph

    @classmethod
    def _first_disambig_id(cls, symbol_table):
        """Return the smallest id of a ``#N`` disambiguation symbol, or None if there is none."""
        ids = [
            symbol_table[symbol]
            for symbol in symbol_table.symbols
            if cls._DISAMBIG_PATTERN.match(symbol)
        ]
        return min(ids, default=None)

    def _load_rescoring_lm(self, g_filename, method, word_table, device):
        """Load the n-gram LM ``G`` and prepare it for ``method``'s rescoring step."""
        logging.info(f"Loading G from {g_filename}")
        lm_graph = k2.Fsa.from_dict(torch.load(g_filename, map_location="cpu"))
        lm_graph = lm_graph.to(device)
        if method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            lm_graph = k2.add_epsilon_self_loops(lm_graph)
            lm_graph = k2.arc_sort(lm_graph)

        # G.lm_scores is used to replace HLG.lm_scores during LM rescoring.
        lm_graph.lm_scores = lm_graph.scores.clone()

        # Map word disambiguation symbols to epsilon, otherwise rescoring returns
        # empty text: https://github.com/k2-fsa/k2/issues/874
        first_word_disambig_id = self._first_disambig_id(word_table)
        logging.info(f"disambig id: {first_word_disambig_id}")
        if first_word_disambig_id is not None:
            lm_graph.labels[lm_graph.labels >= first_word_disambig_id] = 0
        lm_graph.labels_sym = word_table

        # Rescoring throws an error on a single Fsa, so wrap G into an FsaVec.
        lm_graph = k2.create_fsa_vec([lm_graph])
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/utils.py
        delattr(lm_graph, "aux_labels")
        return k2.arc_sort(lm_graph)

    @torch.no_grad()  # inference only: do not build an autograd graph
    def decode_(self, wave, fbank, model, device, method, bpe_model_filename, num_classes,
                min_active_states, max_active_states, subsampling_factor, use_double_scores,
                frame_shift_ms, search_beam, output_beam, HLG=None, G=None, words_file=None,
                num_paths=None, ngram_lm_scale=None, nbest_scale=None):
        """Decode a single waveform with ``method``.

        Returns:
          - ``ctc-decoding`` and ``1best``: the dict built by ``format_trs`` (text
            plus per-word start/end times), or None if word and timestamp counts differ.
          - ``nbest-rescoring`` / ``whole-lattice-rescoring``: a list with one word
            list per utterance (a single utterance here).
        Raises:
          ValueError: if ``method`` is not a supported decoding method.
        """
        # Reject unknown methods before the (expensive) feature extraction and forward pass.
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unsupported decoding method: {method}")

        logging.info("Decoding started")
        nnet_output, supervision_segments = self._run_acoustic_model(
            wave, fbank, model, device, subsampling_factor
        )

        def build_lattice(decoding_graph):
            # Same search settings for every decoding graph.
            return get_lattice(
                nnet_output=nnet_output,
                decoding_graph=decoding_graph,
                supervision_segments=supervision_segments,
                search_beam=search_beam,
                output_beam=output_beam,
                min_active_states=min_active_states,
                max_active_states=max_active_states,
                subsampling_factor=subsampling_factor,
            )

        if method == "ctc-decoding":
            logging.info("Use CTC decoding")
            bpe_model = spm.SentencePieceProcessor()
            bpe_model.load(bpe_model_filename)
            H = k2.ctc_topo(max_token=num_classes - 1, modified=False, device=device)

            best_path = one_best_decoding(
                lattice=build_lattice(H), use_double_scores=use_double_scores
            )
            confidence = best_path.get_tot_scores(
                use_double_scores=False, log_semiring=False
            ).detach()[0]
            logging.info(f'confidence {confidence}')

            timestamps, hyps = parse_fsa_timestamps_and_texts(
                best_paths=best_path,
                sp=bpe_model,
                subsampling_factor=subsampling_factor,
                frame_shift_ms=frame_shift_ms,
            )
            logging.info(timestamps)
            return self.format_trs(hyps[0], timestamps[0])

        # HLG-based methods: 1best, nbest-rescoring, whole-lattice-rescoring.
        hlg_graph = self._load_hlg(HLG, device)
        word_table = k2.SymbolTable.from_file(words_file)
        lattice = build_lattice(hlg_graph)

        if method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(lattice=lattice, use_double_scores=use_double_scores)
            timestamps, hyps = parse_fsa_timestamps_and_texts(
                best_paths=best_path,
                word_table=word_table,
                subsampling_factor=subsampling_factor,
                frame_shift_ms=frame_shift_ms,
            )
            return self.format_trs(hyps[0], timestamps[0])

        logging.info("Use HLG decoding + LM rescoring")
        lm_graph = self._load_rescoring_lm(G, method, word_table, device)
        if method == "nbest-rescoring":
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=lm_graph,
                num_paths=num_paths,
                lm_scale_list=[ngram_lm_scale],
                nbest_scale=nbest_scale,
            )
        else:  # whole-lattice-rescoring
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=lm_graph,
                lm_scale_list=[ngram_lm_scale],
            )
        # A single LM scale was requested, so the dict holds exactly one result.
        best_path = next(iter(best_path_dict.values()))
        return [[word_table[i] for i in ids] for ids in get_texts(best_path)]

    def transcribe_file(self, audio_filename, method=None):
        """Read ``audio_filename`` and decode it with ``method`` (default: ``args.method``).

        Besides the ``Args`` defaults this reads ``args.min_active_states``,
        ``max_active_states``, ``subsampling_factor``, ``use_double_scores``,
        ``search_beam`` and ``output_beam``, which must have been supplied when
        constructing the decoder (e.g. via ``get_decoding_params()``).
        """
        args = self.args
        wave = self.read_sound_file_(audio_filename, expected_sample_rate=args.sample_rate)
        if method is None:
            method = args.method
        return self.decode_(
            wave, self.fbank, self.model, args.device, method,
            args.bpe_model_filename, args.num_classes,
            args.min_active_states, args.max_active_states,
            args.subsampling_factor, args.use_double_scores,
            args.frame_shift_ms, args.search_beam, args.output_beam,
            HLG=args.HLG, G=args.G, words_file=args.words_file, num_paths=args.num_paths,
            ngram_lm_scale=args.ngram_lm_scale, nbest_scale=args.nbest_scale,
        )

## Example usage

In [7]:
# Create the transcriber/decoder object.
# To change parameters (for example the model filename), pass a dict whose keys are
# `Args` attribute names when initializing the decoder, e.g.:
# ConformerCtc3Decoder(get_params() | get_decoding_params() | {'model_filename': 'my new model filename'})
transcriber = ConformerCtc3Decoder(get_params() | get_decoding_params())

In [8]:
# Transcribe an audio file (NB! the model assumes a sample rate of 16000 Hz).
# With the default ctc-decoding method the result is a dict with the full text
# and per-word start/end times.
%time transcriber.transcribe_file('audio/emt16k.wav')

CPU times: user 4.83 s, sys: 210 ms, total: 5.04 s
Wall time: 4.13 s


{'text': 'mina tahaksin homme täna ja homme kui saan all kolm krantsumadiseid veiki panna',
 'words': [{'word': 'mina', 'start': 0.8, 'end': 0.84},
 {'word': 'tahaksin', 'start': 1.0, 'end': 1.32},
 {'word': 'homme', 'start': 1.48, 'end': 1.76},
 {'word': 'täna', 'start': 2.08, 'end': 2.12},
 {'word': 'ja', 'start': 3.72, 'end': 3.76},
 {'word': 'homme', 'start': 4.16, 'end': 4.44},
 {'word': 'kui', 'start': 5.96, 'end': 6.0},
 {'word': 'saan', 'start': 6.52, 'end': 6.84},
 {'word': 'all', 'start': 7.36, 'end': 7.4},
 {'word': 'kolm', 'start': 8.32, 'end': 8.36},
 {'word': 'krantsumadiseid', 'start': 8.68, 'end': 9.72},
 {'word': 'veiki', 'start': 9.76, 'end': 10.04},
 {'word': 'panna', 'start': 10.16, 'end': 10.4}]}

In [9]:
# Transcribe a longer recording; the result is stored in `trs` and shown in the next cell.
%time trs=transcriber.transcribe_file('audio/oden_kypsis16k.wav')

CPU times: user 16 s, sys: 1.13 s, total: 17.2 s
Wall time: 14.4 s


In [10]:
# Display the transcription produced in the previous cell.
trs

{'text': 'enamus ajast nagu klikkid neid allserva tekivad need luba küpsiseid mis on nagu ilusti kohati tõlgitud eesti keelde see idee arusaadavamaks ma tean et see on kukis inglise kees ma ei saa sellest ka aru nagu mis asi on kukis on ju ma saan aru et ta vaid minee eest ära luba küpsises tava ei anna noh anna minna ma luban küpssi juhmaoloog okei on ju ma ei tea mis ta teeb lihtsalt selle eestikeelseks tõlk või eesti keelde tõlkimine kui teinud seda nagu arusaadavamaks küpsised kuule kuule veebisaid küsib sinu käest tahad tähendab on okei kui me neid kugiseid kasutame sa mingi ja mida iga mul täiesti savi või noh et et jah',
 'words': [{'word': 'enamus', 'start': 3.56, 'end': 3.8},
 {'word': 'ajast', 'start': 3.8, 'end': 4.04},
 {'word': 'nagu', 'start': 4.2, 'end': 4.24},
 {'word': 'klikkid', 'start': 4.72, 'end': 5.12},
 {'word': 'neid', 'start': 5.16, 'end': 5.2},
 {'word': 'allserva', 'start': 5.72, 'end': 6.2},
 {'word': 'tekivad', 'start': 6.32, 'end': 6.64},
 {'word': 'need',

## Other decoding methods

The `1best` decoding method was not working when these cells were run, so only the LM-rescoring methods are shown below.

In [27]:
# Decode with HLG + n-gram LM n-best rescoring. This returns a list of word lists
# (one per utterance) without timestamps, and is much slower than ctc-decoding here.
%time transcriber.transcribe_file('audio/emt16k.wav', method='nbest-rescoring')

disambig id: 157281
CPU times: user 3min 56s, sys: 7.52 s, total: 4min 3s
Wall time: 2min 22s


[['mina',
 'tahaksin',
 'homme',
 'täna',
 'ja',
 'homme',
 'kui',
 'saan',
 'kontsu',
 'madise',
 'vei',
 'panna']]

In [28]:
# Decode with HLG + whole-lattice n-gram LM rescoring; also returns a list of
# word lists (one per utterance) without timestamps.
%time transcriber.transcribe_file('audio/emt16k.wav', method='whole-lattice-rescoring')

disambig id: 157281
CPU times: user 41.2 s, sys: 409 ms, total: 41.6 s
Wall time: 31.3 s


[['mina',
 'tahaksin',
 'homme',
 'täna',
 'ja',
 'homme',
 'kui',
 'saan',
 'all',
 'kontsu',
 'madise',
 'vei',
 'panna']]