{ "cells": [ { "cell_type": "code", "execution_count": 3, "id": "b6b6ded1-0a58-43cb-9065-4f4fae02a01b", "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import logging\n", "import math\n", "import re\n", "from typing import List\n", "import sys\n", "sys.path.append('/opt/notebooks/err2020/conformer_ctc3/')\n", "import k2\n", "import kaldifeat\n", "import sentencepiece as spm\n", "import torch\n", "import torchaudio\n", "from decode import get_decoding_params\n", "from torch.nn.utils.rnn import pad_sequence\n", "from train import add_model_arguments, get_params\n", "\n", "from icefall.decode import (\n", " get_lattice,\n", " one_best_decoding,\n", " rescore_with_n_best_list,\n", " rescore_with_whole_lattice\n", ")\n", "from icefall.utils import get_texts, parse_fsa_timestamps_and_texts" ] }, { "cell_type": "markdown", "id": "52514f2f-1195-4e4f-8174-d21aa7462476", "metadata": {}, "source": [ "## Helpers" ] }, { "cell_type": "markdown", "id": "8ec024bf-7f91-47a9-9293-822fe2765c4b", "metadata": {}, "source": [ "#### Load args helpers" ] }, { "cell_type": "code", "execution_count": 4, "id": "3d69d771-b421-417f-a6ff-e1d1c64ba934", "metadata": {}, "outputs": [], "source": [ "class Args:\n", " model_filename='conformer_ctc3/exp/jit_trace.pt'\n", " bpe_model_filename=\"data/lang_bpe_500/bpe.model\"\n", " method=\"ctc-decoding\"\n", " sample_rate=16000\n", " num_classes=500 #bpe model size\n", " frame_shift_ms=10\n", " dither=0\n", " snip_edges=False\n", " num_bins=80\n", " device='cpu'\n", " \n", " def args_from_dict(self, dct):\n", " for key in dct:\n", " setattr(self, key, dct[key])\n", " \n", " def __repr__(self):\n", " text=''\n", " for k, v in self.__dict__.items():\n", " text+=f'{k} = {v}\\n'\n", " return text" ] }, { "cell_type": "markdown", "id": "57a3cd62-3037-4c99-9094-dd63429e660e", "metadata": {}, "source": [ "#### Decoder helper" ] }, { "cell_type": "code", "execution_count": 5, "id": "48306369-fb68-4abe-be62-0806d00059f8", "metadata": {}, "outputs": [], "source": [ "class ConformerCtc3Decoder:\n", " def __init__(self, params_dct=None):\n", " logging.info('loading args')\n", " self.args=Args()\n", " if params_dct is not None:\n", " self.args.args_from_dict(params_dct)\n", " logging.info('loading model')\n", " self.load_model()\n", " logging.info('loading fbank')\n", " self.get_fbank()\n", " \n", " def update_args(self, dct):\n", " self.args.args_from_dict(dct)\n", " \n", " def load_model_(self, model_filename, device):\n", " device = torch.device(\"cpu\")\n", " model = torch.jit.load(model_filename)\n", " model.to(device)\n", " model=model.eval()\n", " self.model=model\n", " \n", " def load_model(self, model_filename=None, device=None):\n", " if model_filename is not None:\n", " self.args.model_filename=model_filename\n", " if device is not None:\n", " self.args.device=device\n", " self.load_model_(self.args.model_filename, self.args.device)\n", " \n", " def get_fbank_(self, device='cpu'):\n", " opts = kaldifeat.FbankOptions()\n", " opts.device = device\n", " opts.frame_opts.dither = self.args.dither\n", " opts.frame_opts.snip_edges = self.args.snip_edges\n", " #opts.frame_opts.samp_freq = sample_rate\n", " opts.mel_opts.num_bins = self.args.num_bins\n", "\n", " fbank = kaldifeat.Fbank(opts)\n", " return fbank\n", " \n", " def get_fbank(self):\n", " self.fbank=self.get_fbank_(self.args.device)\n", " \n", " def read_sound_file_(self, filename: str, expected_sample_rate: float ) -> List[torch.Tensor]:\n", " \"\"\"Read a sound file into a 1-D float32 torch tensor.\n", " Args:\n", " filenames:\n", " A list of sound filenames.\n", " expected_sample_rate:\n", " The expected sample rate of the sound files.\n", " Returns:\n", " Return a 1-D float32 torch tensor.\n", " \"\"\"\n", " wave, sample_rate = torchaudio.load(filename)\n", " assert sample_rate == expected_sample_rate, (\n", " f\"expected sample rate: {expected_sample_rate}. \" f\"Given: {sample_rate}\"\n", " )\n", " # We use only the first channel\n", " return wave[0]\n", " \n", " def format_trs(self, hyp, timestamps):\n", " if len(hyp)!=len(timestamps):\n", " print(f'len of hyp and timestamps is not the same len hyp {len(hyp)} and len of timestamps {len(timestamps)}')\n", " return None\n", " trs ={'text': ' '.join(hyp),\n", " 'words': [{'word': w, 'start':timestamps[i][0], 'end': timestamps[i][1]} for i, w in enumerate(hyp)]\n", " }\n", " return trs\n", " \n", " def decode_(self, wave, fbank, model, device, method, bpe_model_filename, num_classes, \n", " min_active_states, max_active_states, subsampling_factor, use_double_scores, \n", " frame_shift_ms, search_beam, output_beam):\n", " \n", " wave = [wave.to(device)]\n", " logging.info(\"Decoding started\")\n", " features = fbank(wave)\n", " feature_lengths = [f.size(0) for f in features]\n", "\n", " features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))\n", " feature_lengths = torch.tensor(feature_lengths, device=device)\n", "\n", " nnet_output, _ = model(features, feature_lengths)\n", "\n", " batch_size = nnet_output.shape[0]\n", " supervision_segments = torch.tensor(\n", " [\n", " [i, 0, feature_lengths[i] // subsampling_factor]\n", " for i in range(batch_size)\n", " ],\n", " dtype=torch.int32,\n", " )\n", "\n", " if method == \"ctc-decoding\":\n", " logging.info(\"Use CTC decoding\")\n", " bpe_model = spm.SentencePieceProcessor()\n", " bpe_model.load(bpe_model_filename)\n", " max_token_id = num_classes - 1\n", "\n", " H = k2.ctc_topo(\n", " max_token=max_token_id,\n", " modified=False,\n", " device=device,\n", " )\n", "\n", " lattice = get_lattice(\n", " nnet_output=nnet_output,\n", " decoding_graph=H,\n", " supervision_segments=supervision_segments,\n", " search_beam=search_beam,\n", " output_beam=output_beam,\n", " min_active_states=min_active_states,\n", " max_active_states=max_active_states,\n", " subsampling_factor=subsampling_factor,\n", " )\n", "\n", " best_path = one_best_decoding(\n", " lattice=lattice, use_double_scores=use_double_scores\n", " )\n", "\n", " confidence=best_path.get_tot_scores(use_double_scores=False, log_semiring=False).detach()[0]\n", "\n", " timestamps, hyps = parse_fsa_timestamps_and_texts(\n", " best_paths=best_path,\n", " sp=bpe_model,\n", " subsampling_factor=subsampling_factor,\n", " frame_shift_ms=frame_shift_ms,\n", " )\n", " logging.info(f'confidence {confidence}')\n", " logging.info(timestamps)\n", " token_ids = get_texts(best_path)\n", " return self.format_trs(hyps[0], timestamps[0])\n", " \n", " def transcribe_file(self, audio_filename):\n", " wave=self.read_sound_file_(audio_filename, expected_sample_rate=self.args.sample_rate)\n", " \n", " trs=self.decode_(wave, self.fbank, self.model, self.args.device, self.args.method, \n", " self.args.bpe_model_filename, self.args.num_classes,\n", " self.args.min_active_states, self.args.max_active_states, \n", " self.args.subsampling_factor, self.args.use_double_scores, \n", " self.args.frame_shift_ms, self.args.search_beam, self.args.output_beam)\n", " return trs" ] }, { "cell_type": "markdown", "id": "b1464957-05b6-40f8-a1aa-c58edbed440c", "metadata": {}, "source": [ "## Example usage" ] }, { "cell_type": "code", "execution_count": 6, "id": "50ab7c8e-39b6-4783-8342-e79e91d2417e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "fatal: not a git repository (or any parent up to mount point /opt)\n", "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n", "fatal: not a git repository (or any parent up to mount point /opt)\n", "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n", "fatal: not a git repository (or any parent up to mount point /opt)\n", "Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n" ] } ], "source": [ "#create transcriber/decoder object\n", "#if you want to change parameters (for example model filename) you could create a dict (see class Args attribute names)\n", "#and add it to as argument decoder initialization:\n", "#conformerCtc3Decoder(get_params() | get_decoding_params() | {'model_filename':'my new model filename'})\n", "transcriber=ConformerCtc3Decoder(get_params() | get_decoding_params())" ] }, { "cell_type": "code", "execution_count": 7, "id": "8020f371-7584-4f6c-990b-f2c023e24060", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.86 s, sys: 435 ms, total: 5.29 s\n", "Wall time: 4.45 s\n" ] }, { "data": { "text/plain": [ "{'text': 'mina tahaksin homme täna ja homme kui saan all kolm krantsumadiseid veiki panna',\n", " 'words': [{'word': 'mina', 'start': 0.8, 'end': 0.84},\n", " {'word': 'tahaksin', 'start': 1.0, 'end': 1.32},\n", " {'word': 'homme', 'start': 1.48, 'end': 1.76},\n", " {'word': 'täna', 'start': 2.08, 'end': 2.12},\n", " {'word': 'ja', 'start': 3.72, 'end': 3.76},\n", " {'word': 'homme', 'start': 4.16, 'end': 4.44},\n", " {'word': 'kui', 'start': 5.96, 'end': 6.0},\n", " {'word': 'saan', 'start': 6.52, 'end': 6.84},\n", " {'word': 'all', 'start': 7.36, 'end': 7.4},\n", " {'word': 'kolm', 'start': 8.32, 'end': 8.36},\n", " {'word': 'krantsumadiseid', 'start': 8.68, 'end': 9.72},\n", " {'word': 'veiki', 'start': 9.76, 'end': 10.04},\n", " {'word': 'panna', 'start': 10.16, 'end': 10.4}]}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#transribe audiofile (NB! model assumes sample rate of 16000)\n", "%time transcriber.transcribe_file('audio/emt16k.wav')" ] }, { "cell_type": "code", "execution_count": 10, "id": "4d2a480d-f0aa-4474-bfdb-ad298a629ce5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 16.2 s, sys: 1.8 s, total: 18 s\n", "Wall time: 15.1 s\n" ] } ], "source": [ "%time trs=transcriber.transcribe_file('audio/oden_kypsis16k.wav')" ] }, { "cell_type": "code", "execution_count": 11, "id": "d3827548-bca0-4409-95bc-9aa8ba377135", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'text': 'enamus ajast nagu klikkid neid allserva tekivad need luba küpsiseid mis on nagu ilusti kohati tõlgitud eesti keelde see idee arusaadavamaks ma tean et see on kukis inglise kees ma ei saa sellest ka aru nagu mis asi on kukis on ju ma saan aru et ta vaid minee eest ära luba küpsises tava ei anna noh anna minna ma luban küpssi juhmaoloog okei on ju ma ei tea mis ta teeb lihtsalt selle eestikeelseks tõlk või eesti keelde tõlkimine kui teinud seda nagu arusaadavamaks küpsised kuule kuule veebisaid küsib sinu käest tahad tähendab on okei kui me neid kugiseid kasutame sa mingi ja mida iga mul täiesti savi või noh et et jah',\n", " 'words': [{'word': 'enamus', 'start': 3.56, 'end': 3.8},\n", " {'word': 'ajast', 'start': 3.8, 'end': 4.04},\n", " {'word': 'nagu', 'start': 4.2, 'end': 4.24},\n", " {'word': 'klikkid', 'start': 4.72, 'end': 5.12},\n", " {'word': 'neid', 'start': 5.16, 'end': 5.2},\n", " {'word': 'allserva', 'start': 5.72, 'end': 6.2},\n", " {'word': 'tekivad', 'start': 6.32, 'end': 6.64},\n", " {'word': 'need', 'start': 7.4, 'end': 7.44},\n", " {'word': 'luba', 'start': 7.72, 'end': 8.0},\n", " {'word': 'küpsiseid', 'start': 8.08, 'end': 8.64},\n", " {'word': 'mis', 'start': 9.68, 'end': 9.72},\n", " {'word': 'on', 'start': 9.76, 'end': 9.8},\n", " {'word': 'nagu', 'start': 9.92, 'end': 9.96},\n", " {'word': 'ilusti', 'start': 10.04, 'end': 10.36},\n", " {'word': 'kohati', 'start': 10.4, 'end': 10.68},\n", " {'word': 'tõlgitud', 'start': 11.08, 'end': 11.4},\n", " {'word': 'eesti', 'start': 11.6, 'end': 11.64},\n", " {'word': 'keelde', 'start': 11.8, 'end': 12.08},\n", " {'word': 'see', 'start': 12.68, 'end': 12.72},\n", " {'word': 'idee', 'start': 12.8, 'end': 13.04},\n", " {'word': 'arusaadavamaks', 'start': 13.2, 'end': 13.8},\n", " {'word': 'ma', 'start': 13.92, 'end': 13.96},\n", " {'word': 'tean', 'start': 14.04, 'end': 14.24},\n", " {'word': 'et', 'start': 14.28, 'end': 14.36},\n", " {'word': 'see', 'start': 14.4, 'end': 14.44},\n", " {'word': 'on', 'start': 14.44, 'end': 14.52},\n", " {'word': 'kukis', 'start': 14.56, 'end': 14.92},\n", " {'word': 'inglise', 'start': 14.92, 'end': 15.2},\n", " {'word': 'kees', 'start': 15.2, 'end': 15.44},\n", " {'word': 'ma', 'start': 15.84, 'end': 15.88},\n", " {'word': 'ei', 'start': 15.92, 'end': 16.0},\n", " {'word': 'saa', 'start': 16.04, 'end': 16.08},\n", " {'word': 'sellest', 'start': 16.24, 'end': 16.28},\n", " {'word': 'ka', 'start': 16.56, 'end': 16.6},\n", " {'word': 'aru', 'start': 16.76, 'end': 16.8},\n", " {'word': 'nagu', 'start': 16.96, 'end': 17.0},\n", " {'word': 'mis', 'start': 17.12, 'end': 17.16},\n", " {'word': 'asi', 'start': 17.28, 'end': 17.32},\n", " {'word': 'on', 'start': 17.36, 'end': 17.4},\n", " {'word': 'kukis', 'start': 17.48, 'end': 17.8},\n", " {'word': 'on', 'start': 17.88, 'end': 17.92},\n", " {'word': 'ju', 'start': 17.96, 'end': 18.0},\n", " {'word': 'ma', 'start': 18.28, 'end': 18.32},\n", " {'word': 'saan', 'start': 18.36, 'end': 18.48},\n", " {'word': 'aru', 'start': 18.52, 'end': 18.56},\n", " {'word': 'et', 'start': 18.72, 'end': 18.76},\n", " {'word': 'ta', 'start': 19.2, 'end': 19.24},\n", " {'word': 'vaid', 'start': 19.32, 'end': 19.44},\n", " {'word': 'minee', 'start': 19.48, 'end': 19.68},\n", " {'word': 'eest', 'start': 19.76, 'end': 19.96},\n", " {'word': 'ära', 'start': 20.12, 'end': 20.16},\n", " {'word': 'luba', 'start': 21.56, 'end': 21.88},\n", " {'word': 'küpsises', 'start': 21.96, 'end': 22.44},\n", " {'word': 'tava', 'start': 22.6, 'end': 22.76},\n", " {'word': 'ei', 'start': 22.84, 'end': 22.88},\n", " {'word': 'anna', 'start': 23.0, 'end': 23.16},\n", " {'word': 'noh', 'start': 23.4, 'end': 23.44},\n", " {'word': 'anna', 'start': 23.64, 'end': 23.76},\n", " {'word': 'minna', 'start': 24.0, 'end': 24.04},\n", " {'word': 'ma', 'start': 24.16, 'end': 24.2},\n", " {'word': 'luban', 'start': 24.24, 'end': 24.56},\n", " {'word': 'küpssi', 'start': 24.64, 'end': 24.92},\n", " {'word': 'juhmaoloog', 'start': 25.0, 'end': 25.28},\n", " {'word': 'okei', 'start': 25.28, 'end': 25.56},\n", " {'word': 'on', 'start': 25.64, 'end': 25.72},\n", " {'word': 'ju', 'start': 25.72, 'end': 25.76},\n", " {'word': 'ma', 'start': 25.84, 'end': 25.88},\n", " {'word': 'ei', 'start': 25.92, 'end': 25.96},\n", " {'word': 'tea', 'start': 26.0, 'end': 26.04},\n", " {'word': 'mis', 'start': 26.28, 'end': 26.32},\n", " {'word': 'ta', 'start': 26.36, 'end': 26.4},\n", " {'word': 'teeb', 'start': 26.56, 'end': 26.8},\n", " {'word': 'lihtsalt', 'start': 27.04, 'end': 27.08},\n", " {'word': 'selle', 'start': 27.24, 'end': 27.28},\n", " {'word': 'eestikeelseks', 'start': 28.04, 'end': 28.68},\n", " {'word': 'tõlk', 'start': 28.8, 'end': 29.08},\n", " {'word': 'või', 'start': 29.16, 'end': 29.2},\n", " {'word': 'eesti', 'start': 29.48, 'end': 29.52},\n", " {'word': 'keelde', 'start': 29.68, 'end': 30.04},\n", " {'word': 'tõlkimine', 'start': 30.2, 'end': 30.68},\n", " {'word': 'kui', 'start': 30.8, 'end': 30.84},\n", " {'word': 'teinud', 'start': 30.96, 'end': 31.16},\n", " {'word': 'seda', 'start': 31.2, 'end': 31.24},\n", " {'word': 'nagu', 'start': 31.72, 'end': 31.76},\n", " {'word': 'arusaadavamaks', 'start': 31.88, 'end': 32.6},\n", " {'word': 'küpsised', 'start': 33.52, 'end': 33.88},\n", " {'word': 'kuule', 'start': 36.96, 'end': 37.08},\n", " {'word': 'kuule', 'start': 37.32, 'end': 37.44},\n", " {'word': 'veebisaid', 'start': 37.8, 'end': 38.28},\n", " {'word': 'küsib', 'start': 38.44, 'end': 38.56},\n", " {'word': 'sinu', 'start': 38.6, 'end': 38.72},\n", " {'word': 'käest', 'start': 38.76, 'end': 39.0},\n", " {'word': 'tahad', 'start': 39.52, 'end': 39.72},\n", " {'word': 'tähendab', 'start': 40.32, 'end': 40.36},\n", " {'word': 'on', 'start': 40.8, 'end': 40.88},\n", " {'word': 'okei', 'start': 40.88, 'end': 41.2},\n", " {'word': 'kui', 'start': 41.24, 'end': 41.28},\n", " {'word': 'me', 'start': 41.36, 'end': 41.4},\n", " {'word': 'neid', 'start': 41.6, 'end': 41.64},\n", " {'word': 'kugiseid', 'start': 42.2, 'end': 42.64},\n", " {'word': 'kasutame', 'start': 42.8, 'end': 43.08},\n", " {'word': 'sa', 'start': 43.56, 'end': 43.6},\n", " {'word': 'mingi', 'start': 43.8, 'end': 43.84},\n", " {'word': 'ja', 'start': 44.04, 'end': 44.08},\n", " {'word': 'mida', 'start': 44.28, 'end': 44.32},\n", " {'word': 'iga', 'start': 44.44, 'end': 44.48},\n", " {'word': 'mul', 'start': 44.56, 'end': 44.6},\n", " {'word': 'täiesti', 'start': 44.92, 'end': 44.96},\n", " {'word': 'savi', 'start': 45.08, 'end': 45.28},\n", " {'word': 'või', 'start': 45.36, 'end': 45.4},\n", " {'word': 'noh', 'start': 45.44, 'end': 45.48},\n", " {'word': 'et', 'start': 45.6, 'end': 45.64},\n", " {'word': 'et', 'start': 47.36, 'end': 47.4},\n", " {'word': 'jah', 'start': 47.56, 'end': 47.68}]}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trs" ] }, { "cell_type": "code", "execution_count": null, "id": "ea3b25b7-a1f9-4b21-911d-35159c5f3009", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }