Uploading ngram base model

- LICENSE +201 -0
- README.md +142 -3
- beam_search_utils.py +325 -0
- requirements.txt +9 -0
- run_speaker_tagging_beam_search.sh +59 -0
- speaker_tagging_beamsearch.py +76 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,3 +1,142 @@
# llm_speaker_tagging

SLT 2024 Challenge GenSEC Track 2: Post-ASR-Speaker-Tagging Baseline

## Features

- Data download and cleaning
- n-gram + beam-search-decoder based baseline system

## Installation

Run the following commands at the top level of this repository.

### Conda Environment

```
conda create --name llmspk python=3.10
conda activate llmspk
```

### Install requirements

You need to install the following packages:

```
kenlm
arpa
numpy
hydra-core
meeteval
tqdm
requests
simplejson
pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
```

Simply install all the requirements:

```
pip install -r requirements.txt
```

### Download ARPA language model

```
mkdir -p arpa_model
cd arpa_model
wget https://kaldi-asr.org/models/5/4gram_small.arpa.gz
gunzip 4gram_small.arpa.gz
```

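To sanity-check the downloaded model, you can load the ARPA file with kenlm and score a sentence (a minimal sketch; this mirrors how `speaker_tagging_beamsearch.py` loads the model):

```
import kenlm

# Load the 4-gram ARPA model downloaded above and print a log10 probability.
lm = kenlm.Model("arpa_model/4gram_small.arpa")
print(lm.score("hello there", bos=True, eos=True))
```
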
### Download track-2 challenge dev set and eval set

Clone the dataset from the Hugging Face server:

```
git clone https://huggingface.co/datasets/GenSEC-LLM/SLT-Task2-Post-ASR-Speaker-Tagging
```

Build the source and reference file lists:

```
find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev -name "*.seglst.json" > err_dev.src.list
find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev -name "*.seglst.json" > err_dev.ref.list
```

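Each `.seglst.json` file holds a SegLST (segment-wise long-form transcription) list: one entry per speaker turn, carrying a session ID, speaker label, words, and start/end times. A minimal illustration (the field values here are hypothetical):

```
[
  {
    "session_id": "session_abc",
    "speaker": "speaker_0",
    "start_time": 0.0,
    "end_time": 0.0,
    "words": "hello there how are you"
  }
]
```
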
### Launch the baseline script

Now you are ready to launch the baseline script, `run_speaker_tagging_beam_search.sh`:

```
BASEPATH=${PWD}
DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
ASRDIAR_FILE_NAME=err_dev
WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
DIAR_OUT_DOWNLOAD=$WORKSPACE/short2_all_seglst_infer
mkdir -p $DIAR_OUT_DOWNLOAD

### SLT 2024 Speaker Tagging Setting v1.0.2
ALPHA=0.4
BETA=0.04
PARALLEL_CHUNK_WORD_LEN=100
BEAM_WIDTH=16
WORD_WINDOW=32
PEAK_PROB=0.95
USE_NGRAM=True
LM_METHOD=ngram

# Get the base name of the test_manifest and remove the extension
UNIQ_MEMO=$(basename "${INPUT_ERROR_SRC_LIST_PATH}" .json | sed 's/\./_/g')
echo "UNIQ MEMO:" $UNIQ_MEMO
TRIAL=telephonic
BATCH_SIZE=11

rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json

python $BASEPATH/speaker_tagging_beamsearch.py \
    port=[5501,5502,5511,5512,5521,5522,5531,5532] \
    arpa_language_model=$DIAR_LM_PATH \
    batch_size=$BATCH_SIZE \
    groundtruth_ref_list_path=$GROUNDTRUTH_REF_LIST_PATH \
    input_error_src_list_path=$INPUT_ERROR_SRC_LIST_PATH \
    parallel_chunk_word_len=$PARALLEL_CHUNK_WORD_LEN \
    use_ngram=$USE_NGRAM \
    alpha=$ALPHA \
    beta=$BETA \
    beam_width=$BEAM_WIDTH \
    word_window=$WORD_WINDOW \
    peak_prob=$PEAK_PROB \
    out_dir=$DIAR_OUT_DOWNLOAD
```

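`PEAK_PROB` controls how peaked the placeholder speaker posteriors are: each word is assigned probability `PEAK_PROB` on its tagged speaker, with the remainder spread evenly over the other speaker slots (see `add_placeholder_speaker_softmax` in `beam_search_utils.py`). A small sketch of that construction:

```
import numpy as np

peak_prob, max_spks = 0.95, 4
base = np.ones(max_spks) * (1 - peak_prob) / (max_spks - 1)
softmax = base.copy()
softmax[2] = peak_prob   # word tagged with speaker index 2
print(softmax)           # approx. [0.0167 0.0167 0.95 0.0167], sums to 1.0
```
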
### Evaluate

We use the [MeetEval](https://github.com/fgnt/meeteval) toolkit to evaluate `cpWER`.
cpWER measures both speaker tagging and word error rate (WER) by scoring every speaker permutation of the transcripts and keeping the permutation that gives the lowest error.

```
echo "Evaluating the original source transcript."
meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
echo "Source cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json)

echo "Evaluating the original hypothesis transcript."
meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
```

+
### Reference
|
134 |
+
|
135 |
+
@inproceedings{park2024enhancing,
|
136 |
+
title={Enhancing speaker diarization with large language models: A contextual beam search approach},
|
137 |
+
author={Park, Tae Jin and Dhawan, Kunal and Koluguri, Nithin and Balam, Jagadeesh},
|
138 |
+
booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
|
139 |
+
pages={10861--10865},
|
140 |
+
year={2024},
|
141 |
+
organization={IEEE}
|
142 |
+
}
|
beam_search_utils.py
ADDED
@@ -0,0 +1,325 @@
from tqdm import tqdm
from typing import Dict, List
from pydiardecode import build_diardecoder
import numpy as np
import copy
import os
import json
import concurrent.futures
import kenlm

__INFO_TAG__ = "[INFO]"


class SpeakerTaggingBeamSearchDecoder:
    def __init__(self, loaded_kenlm_model: kenlm.Model, cfg: dict):
        self.realigning_lm_params = cfg
        self.realigning_lm = self._load_realigning_LM(loaded_kenlm_model=loaded_kenlm_model)
        self._SPLITSYM = "@"

    def _load_realigning_LM(self, loaded_kenlm_model: kenlm.Model):
        """
        Build the beam-search decoder around the ARPA language model used for
        realigning speaker labels of words.
        """
        diar_decoder = build_diardecoder(
            loaded_kenlm_model=loaded_kenlm_model,
            kenlm_model_path=self.realigning_lm_params['arpa_language_model'],
            alpha=self.realigning_lm_params['alpha'],
            beta=self.realigning_lm_params['beta'],
            word_window=self.realigning_lm_params['word_window'],
            use_ngram=self.realigning_lm_params['use_ngram'],
        )
        return diar_decoder

    def realign_words_with_lm(self, word_dict_seq_list: List[Dict[str, float]], speaker_count: int = None, port_num=None) -> List[Dict[str, float]]:
        # Build the speaker inventory either from the word sequence itself or
        # from the given speaker count.
        if speaker_count is None:
            spk_list = [line_dict['speaker'] for line_dict in word_dict_seq_list]
        else:
            spk_list = [f"speaker_{k}" for k in range(speaker_count)]

        realigned_list = self.realigning_lm.decode_beams(beam_width=self.realigning_lm_params['beam_width'],
                                                         speaker_list=sorted(set(spk_list)),
                                                         word_dict_seq_list=word_dict_seq_list,
                                                         port_num=port_num)
        return realigned_list

    def beam_search_diarization(
        self,
        trans_info_dict: Dict[str, Dict[str, list]],
        port_num: int = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Match the diarization result with the ASR output.
        The words and the timestamps for the corresponding words are matched in a for loop.

        Args:
            trans_info_dict (dict):
                Dictionary containing word timestamps, speaker labels and words from all sessions.
                Each session is indexed by a unique ID.
            port_num (int, optional):
                Port number for the decoder backend; None when the n-gram LM is used.

        Returns:
            trans_info_dict (dict):
                The same dictionary with the word-level speaker labels realigned by beam search.
        """
        for uniq_id, session_dict in tqdm(trans_info_dict.items(), total=len(trans_info_dict), disable=True):
            word_dict_seq_list = session_dict['words']
            output_beams = self.realign_words_with_lm(word_dict_seq_list=word_dict_seq_list, speaker_count=session_dict['speaker_count'], port_num=port_num)
            # Keep the word sequence of the best (top-1) beam.
            word_dict_seq_list = output_beams[0][2]
            trans_info_dict[uniq_id]['words'] = word_dict_seq_list
        return trans_info_dict

    def merge_div_inputs(self, div_trans_info_dict, org_trans_info_dict, win_len=250, word_window=16):
        """
        Merge the outputs of chunked parallel processing back into
        per-session word sequences.
        """
        uniq_id_list = list(org_trans_info_dict.keys())
        sub_div_dict = {}
        for seq_id in div_trans_info_dict.keys():
            div_info = seq_id.split(self._SPLITSYM)
            uniq_id, sub_idx, total_count = div_info[0], int(div_info[1]), int(div_info[2])
            if uniq_id not in sub_div_dict:
                sub_div_dict[uniq_id] = [None] * total_count
            sub_div_dict[uniq_id][sub_idx] = div_trans_info_dict[seq_id]['words']

        for uniq_id in uniq_id_list:
            org_trans_info_dict[uniq_id]['words'] = []
            for k, div_words in enumerate(sub_div_dict[uniq_id]):
                # The first chunk has no left context; every later chunk carries
                # `word_window` overlap words that are dropped after decoding.
                if k == 0:
                    div_words = div_words[:win_len]
                else:
                    div_words = div_words[word_window:]
                org_trans_info_dict[uniq_id]['words'].extend(div_words)
        return org_trans_info_dict

    def divide_chunks(self, trans_info_dict, win_len, word_window, port):
        """
        Divide each session's word sequence into chunks of length `win_len`,
        plus `word_window` words of left context, for parallel processing.

        Args:
            trans_info_dict (dict): Per-session transcripts with word-level entries.
            win_len (int): Chunk length in words. If None, each session is split
                evenly across the available workers.
            word_window (int): Number of left-context overlap words per chunk.
            port (list): Port list; its length sets the number of workers.
        """
        num_workers = len(port) if len(port) > 1 else 1
        div_trans_info_dict = {}
        for uniq_id in trans_info_dict.keys():
            uniq_trans = trans_info_dict[uniq_id]
            # Drop fields that are not needed during chunked decoding.
            del uniq_trans['status']
            del uniq_trans['transcription']
            del uniq_trans['sentences']
            word_seq = uniq_trans['words']

            div_word_seq = []
            if win_len is None:
                win_len = int(np.ceil(len(word_seq) / num_workers))
            n_chunks = int(np.ceil(len(word_seq) / win_len))

            for k in range(n_chunks):
                div_word_seq.append(word_seq[max(k * win_len - word_window, 0):(k + 1) * win_len])

            total_count = len(div_word_seq)
            for k, w_seq in enumerate(div_word_seq):
                seq_id = uniq_id + f"{self._SPLITSYM}{k}{self._SPLITSYM}{total_count}"
                div_trans_info_dict[seq_id] = dict(uniq_trans)
                div_trans_info_dict[seq_id]['words'] = w_seq
        return div_trans_info_dict


def run_mp_beam_search_decoding(
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    trans_info_dict,
    org_trans_info_dict,
    div_mp,
    win_len,
    word_window,
    port=None,
    use_ngram=False
):
    if len(port) > 1:
        port = [int(p) for p in port]
    if use_ngram:
        # The n-gram decoder needs no server ports; use a fixed-size pool.
        port = [None]
        num_workers = 36
    else:
        num_workers = len(port)

    uniq_id_list = sorted(trans_info_dict.keys())
    tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers)
    futures = []

    count = 0
    for uniq_id in uniq_id_list:
        print(f"{__INFO_TAG__} Running beam search decoding for {uniq_id}...")
        # Assign ports to sessions in round-robin fashion.
        port_num = port[count % len(port)]
        count += 1
        uniq_trans_info_dict = {uniq_id: trans_info_dict[uniq_id]}
        futures.append(tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num))

    pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
    output_trans_info_dict = {}
    for done_future in concurrent.futures.as_completed(futures):
        pbar.update()
        output_trans_info_dict.update(done_future.result())
    pbar.close()
    tp.shutdown()
    if div_mp:
        output_trans_info_dict = speaker_beam_search_decoder.merge_div_inputs(div_trans_info_dict=output_trans_info_dict,
                                                                              org_trans_info_dict=org_trans_info_dict,
                                                                              win_len=win_len,
                                                                              word_window=word_window)
    return output_trans_info_dict


def count_num_of_spks(json_trans_list):
    """Map each speaker label found in a session to an integer index."""
    spk_set = set()
    for sentence_dict in json_trans_list:
        spk_set.add(sentence_dict['speaker'])
    # Sort for a reproducible speaker-to-index mapping.
    speaker_map = {spk_str: idx for idx, spk_str in enumerate(sorted(spk_set))}
    return speaker_map


def add_placeholder_speaker_softmax(json_trans_list, peak_prob=0.94, max_spks=4):
    """
    Build word-level entries with a synthetic speaker posterior: the tagged
    speaker gets `peak_prob` and the remainder is spread uniformly over the
    other speaker slots.
    """
    nemo_json_dict = {}
    word_dict_seq_list = []
    if peak_prob > 1 or peak_prob < 0:
        raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
    speaker_map = count_num_of_spks(json_trans_list)
    base_array = np.ones(max_spks) * (1 - peak_prob) / (max_spks - 1)
    stt_sec, end_sec = None, None  # No word timestamps are available in this task.
    for sentence_dict in json_trans_list:
        word_list = sentence_dict['words'].split()
        speaker = sentence_dict['speaker']
        for word in word_list:
            speaker_softmax = copy.deepcopy(base_array)
            speaker_softmax[speaker_map[speaker]] = peak_prob
            word_dict_seq_list.append({'word': word,
                                       'start_time': stt_sec,
                                       'end_time': end_sec,
                                       'speaker': speaker_map[speaker],
                                       'speaker_softmax': speaker_softmax})
    nemo_json_dict.update({'words': word_dict_seq_list,
                           'status': "success",
                           'sentences': json_trans_list,
                           'speaker_count': len(speaker_map),
                           'transcription': None})
    return nemo_json_dict


def convert_nemo_json_to_seglst(trans_info_dict):
    seg_lst_dict, spk_wise_trans_sessions = {}, {}
    for uniq_id in trans_info_dict.keys():
        spk_wise_trans_sessions[uniq_id] = {}
        seglst_seq_list = []
        word_seq_list = trans_info_dict[uniq_id]['words']
        prev_speaker, sentence = None, ''
        for widx, word_dict in enumerate(word_seq_list):
            curr_speaker = word_dict['speaker']

            # For making speaker-wise transcriptions
            word = word_dict['word']
            if curr_speaker not in spk_wise_trans_sessions[uniq_id]:
                spk_wise_trans_sessions[uniq_id][curr_speaker] = word
            else:
                spk_wise_trans_sessions[uniq_id][curr_speaker] = f"{spk_wise_trans_sessions[uniq_id][curr_speaker]} {word}"

            # For making segment-wise transcriptions
            if curr_speaker != prev_speaker and prev_speaker is not None:
                seglst_seq_list.append({'session_id': uniq_id,
                                        'words': sentence.strip(),
                                        'start_time': 0.0,
                                        'end_time': 0.0,
                                        'speaker': prev_speaker,
                                        })
                sentence = word_dict['word']
            else:
                sentence = f"{sentence} {word_dict['word']}"
            prev_speaker = curr_speaker

            # For the last word:
            # (1) If there is no speaker change, add the existing sentence and exit the loop
            # (2) If there is a speaker change, add the last word and exit the loop
            if widx == len(word_seq_list) - 1:
                seglst_seq_list.append({'session_id': uniq_id,
                                        'words': sentence.strip(),
                                        'start_time': 0.0,
                                        'end_time': 0.0,
                                        'speaker': curr_speaker,
                                        })
        seg_lst_dict[uniq_id] = seglst_seq_list
    return seg_lst_dict


def load_input_jsons(input_error_src_list_path, ext_str=".seglst.json", peak_prob=0.94, max_spks=4):
    """Load the error-source seglst files and attach placeholder speaker posteriors."""
    trans_info_dict = {}
    with open(input_error_src_list_path) as f:
        json_filepath_list = f.readlines()
    for json_path in json_filepath_list:
        json_path = json_path.strip()
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if os.path.exists(json_path):
            with open(json_path, "r") as file:
                json_trans = json.load(file)
        else:
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        nemo_json_dict = add_placeholder_speaker_softmax(json_trans, peak_prob=peak_prob, max_spks=max_spks)
        trans_info_dict[uniq_id] = nemo_json_dict
    return trans_info_dict


def load_reference_jsons(reference_seglst_list_path, ext_str=".seglst.json"):
    """Load reference seglst files and stamp each segment with its session ID."""
    reference_info_dict = {}
    with open(reference_seglst_list_path) as f:
        json_filepath_list = f.readlines()
    for json_path in json_filepath_list:
        json_path = json_path.strip()
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if os.path.exists(json_path):
            with open(json_path, "r") as file:
                json_trans = json.load(file)
        else:
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        json_trans_uniq_id = []
        for sentence_dict in json_trans:
            sentence_dict['session_id'] = uniq_id
            json_trans_uniq_id.append(sentence_dict)
        reference_info_dict[uniq_id] = json_trans_uniq_id
    return reference_info_dict


def write_seglst_jsons(
    seg_lst_sessions_dict: dict,
    input_error_src_list_path: str,
    diar_out_path: str,
    ext_str: str,
    write_individual_seglst_jsons=True
):
    """
    Writes the segment list (seglst) JSON files to the output directory.

    Parameters:
        seg_lst_sessions_dict (dict): A dictionary containing session IDs as keys and their corresponding segment lists as values.
        input_error_src_list_path (str): The path to the input error source list file.
        diar_out_path (str): The path to the output directory where the seglst JSON files will be written.
        ext_str (str): A string marking the type of the seglst JSON files (e.g., 'hyp' for hypothesis or 'ref' for reference).
        write_individual_seglst_jsons (bool, optional): A flag indicating whether to write individual seglst JSON files for each session. Defaults to True.

    Returns:
        None
    """
    total_infer_list = []
    total_output_filename = os.path.split(input_error_src_list_path)[-1].replace(".list", "")
    for session_id, seg_lst_list in seg_lst_sessions_dict.items():
        total_infer_list.extend(seg_lst_list)
        if write_individual_seglst_jsons:
            print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
            with open(f'{diar_out_path}/{session_id}.seglst.json', 'w') as file:
                json.dump(seg_lst_list, file, indent=4)  # indent=4 for pretty printing

    total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
    print(f"{__INFO_TAG__} Writing {diar_out_path}/../{total_output_filename}.seglst.json")
    with open(f'{diar_out_path}/../{total_output_filename}.seglst.json', 'w') as file:
        json.dump(total_infer_list, file, indent=4)  # indent=4 for pretty printing
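A quick way to see what `convert_nemo_json_to_seglst` produces (a toy check, assuming `beam_search_utils.py` is importable from the current directory):

```python
from beam_search_utils import convert_nemo_json_to_seglst

trans_info_dict = {
    "session0": {
        "words": [
            {"word": "hello", "speaker": "speaker_0"},
            {"word": "there", "speaker": "speaker_0"},
            {"word": "hi", "speaker": "speaker_1"},
        ]
    }
}
# Consecutive same-speaker words are merged into one seglst segment:
# [{'session_id': 'session0', 'words': 'hello there', ..., 'speaker': 'speaker_0'},
#  {'session_id': 'session0', 'words': 'hi', ..., 'speaker': 'speaker_1'}]
print(convert_nemo_json_to_seglst(trans_info_dict))
```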
requirements.txt
ADDED
@@ -0,0 +1,9 @@
kenlm
arpa
numpy
hydra-core
meeteval
tqdm
requests
simplejson
pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
run_speaker_tagging_beam_search.sh
ADDED
@@ -0,0 +1,59 @@
### Speaker Tagging Task-2 Parameters

BASEPATH=${PWD}
DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
ASRDIAR_FILE_NAME=err_dev
WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
DIAR_OUT_DOWNLOAD=$WORKSPACE/short2_all_seglst_infer
mkdir -p $DIAR_OUT_DOWNLOAD


### SLT 2024 Speaker Tagging Setting v1.0.2
ALPHA=0.4
BETA=0.04
PARALLEL_CHUNK_WORD_LEN=100
BEAM_WIDTH=16
WORD_WINDOW=32
PEAK_PROB=0.95
USE_NGRAM=True
LM_METHOD=ngram

# Get the base name of the test_manifest and remove the extension
UNIQ_MEMO=$(basename "${INPUT_ERROR_SRC_LIST_PATH}" .json | sed 's/\./_/g')
echo "UNIQ MEMO:" $UNIQ_MEMO
TRIAL=telephonic
BATCH_SIZE=11


rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json


python $BASEPATH/speaker_tagging_beamsearch.py \
    port=[5501,5502,5511,5512,5521,5522,5531,5532] \
    arpa_language_model=$DIAR_LM_PATH \
    batch_size=$BATCH_SIZE \
    groundtruth_ref_list_path=$GROUNDTRUTH_REF_LIST_PATH \
    input_error_src_list_path=$INPUT_ERROR_SRC_LIST_PATH \
    parallel_chunk_word_len=$PARALLEL_CHUNK_WORD_LEN \
    use_ngram=$USE_NGRAM \
    alpha=$ALPHA \
    beta=$BETA \
    beam_width=$BEAM_WIDTH \
    word_window=$WORD_WINDOW \
    peak_prob=$PEAK_PROB \
    out_dir=$DIAR_OUT_DOWNLOAD


echo "Evaluating the original source transcript."
meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
echo "Source cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json)

echo "Evaluating the original hypothesis transcript."
meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
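`PARALLEL_CHUNK_WORD_LEN` and `WORD_WINDOW` drive the chunked parallel decoding in `beam_search_utils.py`: each chunk after the first carries `WORD_WINDOW` words of left context, which are dropped again when the chunks are merged. A scaled-down sketch of that arithmetic:

```python
# Toy version of divide_chunks / merge_div_inputs (win_len and word_window
# stand in for PARALLEL_CHUNK_WORD_LEN and WORD_WINDOW).
word_seq = list(range(10))
win_len, word_window = 4, 2
n_chunks = -(-len(word_seq) // win_len)  # ceiling division
chunks = [word_seq[max(k * win_len - word_window, 0):(k + 1) * win_len]
          for k in range(n_chunks)]
# chunks == [[0, 1, 2, 3], [2, 3, 4, 5, 6, 7], [6, 7, 8, 9]]
merged = chunks[0][:win_len] + [w for c in chunks[1:] for w in c[word_window:]]
assert merged == word_seq  # overlap words are dropped on merge
```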
speaker_tagging_beamsearch.py
ADDED
@@ -0,0 +1,76 @@
import hydra
from typing import List, Optional
from dataclasses import dataclass, field
import kenlm
from beam_search_utils import (
    SpeakerTaggingBeamSearchDecoder,
    load_input_jsons,
    load_reference_jsons,
    write_seglst_jsons,
    run_mp_beam_search_decoding,
    convert_nemo_json_to_seglst,
)
from hydra.core.config_store import ConfigStore

__INFO_TAG__ = "[INFO]"


@dataclass
class RealigningLanguageModelParameters:
    batch_size: int = 32
    use_mp: bool = True
    input_error_src_list_path: Optional[str] = None
    groundtruth_ref_list_path: Optional[str] = None
    arpa_language_model: Optional[str] = None
    word_window: int = 32
    port: List[int] = field(default_factory=list)
    parallel_chunk_word_len: int = 250
    use_ngram: bool = True
    peak_prob: float = 0.95
    alpha: float = 0.5
    beta: float = 0.05
    beam_width: int = 16
    out_dir: Optional[str] = None


cs = ConfigStore.instance()
cs.store(name="config", node=RealigningLanguageModelParameters)


@hydra.main(config_name="config", version_base="1.1")
def main(cfg: RealigningLanguageModelParameters) -> None:
    trans_info_dict = load_input_jsons(input_error_src_list_path=cfg.input_error_src_list_path, peak_prob=float(cfg.peak_prob))
    reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
    source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)
    loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)

    speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)

    # Split each session into overlapping chunks so they can be decoded in parallel.
    div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(trans_info_dict=trans_info_dict,
                                                                    win_len=cfg.parallel_chunk_word_len,
                                                                    word_window=cfg.word_window,
                                                                    port=cfg.port,)

    trans_info_dict = run_mp_beam_search_decoding(speaker_beam_search_decoder,
                                                  loaded_kenlm_model=loaded_kenlm_model,
                                                  trans_info_dict=div_trans_info_dict,
                                                  org_trans_info_dict=trans_info_dict,
                                                  div_mp=True,
                                                  win_len=cfg.parallel_chunk_word_len,
                                                  word_window=cfg.word_window,
                                                  port=cfg.port,
                                                  use_ngram=cfg.use_ngram,
                                                  )
    hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)

    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')
    print(f"{__INFO_TAG__} Parameters used: \
            \n ALPHA: {cfg.alpha} \
            \n BETA: {cfg.beta} \
            \n BEAM WIDTH: {cfg.beam_width} \
            \n Word Window: {cfg.word_window} \
            \n Use Ngram: {cfg.use_ngram} \
            \n Chunk Word Len: {cfg.parallel_chunk_word_len} \
            \n SpeakerLM Model: {cfg.arpa_language_model}")


if __name__ == '__main__':
    main()