Spaces: Running on L40S
Siddhant committed
Commit · b9a6dd9
1 Parent(s): 58f82d5
Update demo
- app.py +856 -492
- pyscripts/utils/dialog_eval/ASR_WER.py +165 -0
- pyscripts/utils/dialog_eval/LLM_Metrics.py +245 -0
- pyscripts/utils/dialog_eval/TTS_intelligibility.py +169 -0
- pyscripts/utils/dialog_eval/TTS_speech_quality.py +98 -0
- pyscripts/utils/dialog_eval/__pycache__/ASR_WER.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/LLM_Metrics.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/TTS_intelligibility.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/TTS_speech_quality.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/human_feedback.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/vert.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/human_feedback.py +242 -0
- pyscripts/utils/dialog_eval/vert.py +299 -0
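
For orientation before the diff itself, the short sketch below mirrors the "Text Dialog Metrics" branch that app.py gains in this commit, using the new pyscripts/utils/dialog_eval helpers. It is a minimal illustration assembled from the imports and calls visible in the diff, not a verbatim excerpt: the placeholder dialog turns are assumptions for demonstration only, and the metric functions still need the demo's evaluation dependencies (language models, transformers, etc.) installed.

from pyscripts.utils.dialog_eval.LLM_Metrics import (
    DialoGPT_perplexity,
    bert_score,
    perplexity,
    vert,
)

# Placeholder dialog turns; in the demo these come from the live conversation.
user_turn = "What is the weather like today?"
responses = [
    "It is sunny and warm outside.",
    "Can I help you with anything else?",
]
all_turns = [user_turn] + responses

# Each helper returns a formatted string, so the app simply concatenates them.
report = (
    perplexity(responses[0])
    + vert(responses)
    + bert_score(all_turns)
    + DialoGPT_perplexity(user_turn, responses[0])
)
print(report)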
app.py
CHANGED
@@ -5,347 +5,382 @@ except ImportError:
5 |   with open('versa.sh', 'rb') as file:
6 |   script = file.read()
7 |   rc = call(script, shell=True)
8 |   import os
9 |   import shutil
10 | -
11 | - from
12 | -
13 | -
14 | -
15 | -
16 | - from espnet2.sds.llm.hugging_face_llm import HuggingFaceLLM
17 | - from espnet2.sds.vad.webrtc_vad import WebrtcVADModel
18 | - from espnet2.sds.eval.TTS_intelligibility import handle_espnet_TTS_intelligibility
19 | - from espnet2.sds.eval.ASR_WER import handle_espnet_ASR_WER
20 | - from espnet2.sds.eval.TTS_speech_quality import TTS_psuedomos
21 | - from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
22 | - from espnet2.sds.utils.chat import Chat
23 | - from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
24 | - import argparse
25 |   import torch
26 |
27 |   access_token = os.environ.get("HF_TOKEN")
28 |   ASR_name="pyf98/owsm_ctc_v3.1_1B"
29 |   LLM_name="meta-llama/Llama-3.2-1B-Instruct"
30 |   TTS_name="kan-bayashi/ljspeech_vits"
31 | - ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper".split(",")
32 |   LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
33 |   TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
34 |   Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
35 |   upload_to_hub=None
36 |   ASR_curr_name=None
37 |   LLM_curr_name=None
38 |   TTS_curr_name=None
39 | - # def read_args():
40 | - # global access_token
41 | - # global ASR_name
42 | - # global LLM_name
43 | - # global TTS_name
44 | - # global ASR_options
45 | - # global LLM_options
46 | - # global TTS_options
47 | - # global Eval_options
48 | - # global upload_to_hub
49 | - # parser = argparse.ArgumentParser(description="Run the app with HF_TOKEN as a command-line argument.")
50 | - # parser.add_argument("--HF_TOKEN", required=True, help="Provide the Hugging Face token.")
51 | - # parser.add_argument("--asr_options", required=True, help="Provide the possible ASR options available to user.")
52 | - # parser.add_argument("--llm_options", required=True, help="Provide the possible LLM options available to user.")
53 | - # parser.add_argument("--tts_options", required=True, help="Provide the possible TTS options available to user.")
54 | - # parser.add_argument("--eval_options", required=True, help="Provide the possible automatic evaluation metrics available to user.")
55 | - # parser.add_argument("--default_asr_model", required=False, default="pyf98/owsm_ctc_v3.1_1B", help="Provide the default ASR model.")
56 | - # parser.add_argument("--default_llm_model", required=False, default="meta-llama/Llama-3.2-1B-Instruct", help="Provide the default ASR model.")
57 | - # parser.add_argument("--default_tts_model", required=False, default="kan-bayashi/ljspeech_vits", help="Provide the default ASR model.")
58 | - # parser.add_argument("--upload_to_hub", required=False, default=None, help="Hugging Face dataset to upload user data")
59 | - # args = parser.parse_args()
60 | - # access_token=args.HF_TOKEN
61 | - # ASR_name=args.default_asr_model
62 | - # LLM_name=args.default_llm_model
63 | - # TTS_name=args.default_tts_model
64 | - # ASR_options=args.asr_options.split(",")
65 | - # LLM_options=args.llm_options.split(",")
66 | - # TTS_options=args.tts_options.split(",")
67 | - # Eval_options=args.eval_options.split(",")
68 | - # upload_to_hub=args.upload_to_hub
69 | -
70 | - # read_args()
71 | - from huggingface_hub import HfApi
72 |
73 | -
74 | -
75 | -
76 | - import gradio as gr
77 | -
78 | -
79 | - import numpy as np
80 | -
81 | - chat = Chat(2)
82 | - chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. The user is talking to you with their voice and you should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
83 | - user_role = "user"
84 |
85 | -
86 | -
87 | -
88 | - client=None
89 | -
90 | - latency_ASR=0.0
91 | - latency_LM=0.0
92 | - latency_TTS=0.0
93 | -
94 | - text_str=""
95 | - asr_output_str=""
96 | - vad_output=None
97 |   audio_output = None
98 |   audio_output1 = None
99 | - LLM_response_arr=[]
100 | - total_response_arr=[]
101 | -
102 | -
103 | -
104 | - if TTS_curr_name is not None:
105 | - if option==TTS_curr_name:
106 | - return
107 | - yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
108 | - global text2speech
109 | - TTS_curr_name=option
110 | - tag = option
111 | - if tag=="ChatTTS":
112 | - text2speech = ChatTTSModel()
113 | - else:
114 | - text2speech = ESPnetTTSModel(tag)
115 | - text2speech.warmup()
116 | - yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
117 | -
118 | - def handle_LLM_selection(option):
119 | - global LLM_curr_name
120 | - if LLM_curr_name is not None:
121 | - if option==LLM_curr_name:
122 | - return
123 | - yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
124 | - global LM_pipe
125 | - LLM_curr_name=option
126 | - LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
127 | - LM_pipe.warmup()
128 | - yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
129 | -
130 | - def handle_ASR_selection(option):
131 | - global ASR_curr_name
132 | - if option=="librispeech_asr":
133 | - option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
134 | - if ASR_curr_name is not None:
135 | - if option==ASR_curr_name:
136 | - return
137 | - yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
138 | - global s2t
139 | - ASR_curr_name=option
140 | - if option=="espnet/owsm_v3.1_ebf":
141 | - s2t = OWSMModel()
142 | - elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
143 | - s2t = ESPnetASRModel(tag=option)
144 | - elif option=="whisper":
145 | - s2t = WhisperASRModel()
146 | - else:
147 | - s2t = OWSMCTCModel(tag=option)
148 |
149 | -
150 | -
151 |
152 | - def handle_eval_selection(
153 |   global LLM_response_arr
154 |   global total_response_arr
155 | - yield (option,gr.Textbox(visible=True))
156 | - if option=="Latency":
157 | - text=
158 | -
159 | -
160 | -
161 | -
162 | - yield (None,
163 | - elif option=="
164 | - yield (None,
165 | - elif option=="
166 | - yield (None,
167 | -
168 | -
169 |   global LLM_response_arr
170 |   global total_response_arr
171 | - yield (option,gr.Textbox(visible=True))
172 | - if option=="Latency":
173 | - text=f"Total Latency: {latency_TTS:.2f}"
174 | - yield (None,text)
175 | - elif option=="TTS Intelligibility":
176 | - yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
177 | - elif option=="TTS Speech Quality":
178 | - yield (None,TTS_psuedomos(TTS_audio_output))
179 | - elif option=="Text Dialog Metrics":
180 | - yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr))
181 | -
182 | -
183 | - global client
184 | - global LM_pipe
185 | - global s2t
186 | - global text2speech
187 | - yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False), gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False))
188 | - if option=="Cascaded":
189 | - client=None
190 | - for _ in handle_selection(TTS_radio):
191 | - continue
192 | - for _ in handle_ASR_selection(ASR_radio):
193 | - continue
194 | - for _ in handle_LLM_selection(LLM_radio):
195 | - continue
196 | - yield (gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=False),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=True, interactive=True),gr.Radio(visible=False))
197 |   else:
198 | -
199 | -
200 | - LM_pipe=None
201 | - global ASR_curr_name
202 | - global LLM_curr_name
203 | - global TTS_curr_name
204 | - ASR_curr_name=None
205 | - LLM_curr_name=None
206 | - TTS_curr_name=None
207 | - handle_E2E_selection()
208 | - yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
209 | -
210 | -
211 | - def handle_E2E_selection():
212 | - global client
213 | - if client is None:
214 | - client = MiniOmniE2EModel()
215 | - client.warmup()
216 |
217 |   def start_warmup():
218-239 | -
240 |   continue
241 | - for _ in handle_ASR_selection(ASR_name):
242 |   continue
243 | - for _ in handle_LLM_selection(LLM_name):
244 |   continue
245 | - dummy_input =
246 |   (3000),
247 |   dtype=getattr(torch, "float16"),
248 |   device="cpu",
249 | -
250 | -
251 |   for opt in Eval_options:
252 |   handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
253 |
254 | - start_warmup()
255 | - vad_model=WebrtcVADModel()
256 |
257 | - callback = gr.CSVLogger()
258 | - start_record_time=None
259 | - enable_btn = gr.Button(interactive=True, visible=True)
260 | - disable_btn = gr.Button(interactive=False, visible=False)
261 |   def flash_buttons():
262 |   btn_updates = (enable_btn,) * 8
263-275 | -
276 | - return ip
277 | -
278 | -
279 | - def vote_last_response(vote_type, request: gr.Request):
280 | - with open("save_dict.json", "a") as fout:
281 | - data = {
282 | - "tstamp": round(time.time(), 4),
283 | - "type": vote_type,
284 | - "ip": get_ip(request),
285 | - }
286 | - fout.write(json.dumps(data) + "\n")
287 | -
288 | -
289 | - def natural_vote1_last_response(
290 | - request: gr.Request
291 |   ):
292-326 | -
327 | - ip_address1=get_ip(request)
328 | - print(f"Partially Relevant (voted). ip: {ip_address1}")
329 | - return ("Partially Relevant",ip_address1,)+(disable_btn,) * 4
330 | -
331 | - def relevant_vote3_last_response(
332 | - request: gr.Request
333 | - ):
334 | - ip_address1=get_ip(request)
335 | - print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
336 | - return ("Slightly Irrelevant",ip_address1,)+(disable_btn,) * 4
337 | -
338 | - def relevant_vote4_last_response(
339 | - request: gr.Request
340 | - ):
341 | - ip_address1=get_ip(request)
342 | - print(f"Completely Irrelevant (voted). ip: {ip_address1}")
343 | - return ("Completely Irrelevant",ip_address1,)+(disable_btn,) * 4
344 | -
345 | - import json
346 | - import time
347 | -
348 | - def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_option):
349 |   sr, y = new_chunk
350 |   global text_str
351 |   global chat
@@ -364,219 +399,548 @@ def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_optio
364 |   global total_response_arr
365 |   if stream is None:
366 |   # Handle user refresh
367 | -
368 | -
369 |   gr.Info("The models are being reloaded due to a browser refresh.")
370 | - yield (stream,asr_output_box,text_box,audio_box,gr.Audio(visible=False))
371 | - stream=y
372 | -
373 | - text_str=""
374 |   audio_output = None
375 |   audio_output1 = None
376 |   else:
377 | - stream=np.concatenate((stream,y))
378-408 | -
409 | - chat.append({"role": user_role, "content": prompt})
410 | - chat_messages = chat.to_list()
411 | - generated_text = LM_pipe(chat_messages)
412 | - start_TTS_time=time.time()
413 | - latency_LM=(start_TTS_time - start_LM_time)
414 | -
415 | - chat.append({"role": "assistant", "content": generated_text})
416 | - text_str=generated_text
417 | - audio_output=text2speech(text_str)
418 | - latency_TTS=(time.time() - start_TTS_time)
419 | - audio_output1=(orig_sr,stream)
420 | - stream=y
421 | - LLM_response_arr.append(text_str.replace("\n"," "))
422 | - total_response_arr.append(text_str.replace("\n"," "))
423 | - text_str1=text_str
424 | - if ((text_str!="") and (start_record_time is None)):
425 | - start_record_time=time.time()
426 |   elif start_record_time is not None:
427 | - current_record_time=time.time()
428 | - if current_record_time-start_record_time>300:
429 | - gr.Info(
430 | -
431 |   if upload_to_hub is not None:
432 |   api.upload_folder(
433 |   folder_path="flagged_data_points",
434 | - path_in_repo="checkpoint_"+str(start_record_time),
435 |   repo_id=upload_to_hub,
436 |   repo_type="dataset",
437 |   token=access_token,
438 |   )
439 | - chat.buffer=[
440 | - text_str=""
441 |   audio_output = None
442 |   audio_output1 = None
443 |   asr_output_str = ""
444 |   start_record_time = None
445 | - LLM_response_arr=[]
446 | - total_response_arr=[]
447 | - shutil.rmtree(
448 |   os.mkdir("flagged_data_points")
449 | - yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
450 | - yield stream,gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(
451 | -
452 | -
453 |
454 |
455 |   with gr.Blocks(
456-494 | -
495 | - )
496 | - with gr.Row():
497 | - natural_btn1 = gr.Button(
498 | - value="Very Natural", visible=False, interactive=False, scale=1
499 | - )
500 | - natural_btn2 = gr.Button(
501 | - value="Somewhat Awkward", visible=False, interactive=False, scale=1
502 | - )
503 | - natural_btn3 = gr.Button(value="Very Awkward", visible=False, interactive=False, scale=1)
504 | - natural_btn4 = gr.Button(
505 | - value="Unnatural", visible=False, interactive=False, scale=1
506 | - )
507 | - with gr.Row():
508 | - relevant_btn1 = gr.Button(
509 | - value="Highly Relevant", visible=False, interactive=False, scale=1
510 | - )
511 | - relevant_btn2 = gr.Button(
512 | - value="Partially Relevant", visible=False, interactive=False, scale=1
513 | - )
514 | - relevant_btn3 = gr.Button(value="Slightly Irrelevant", visible=False, interactive=False, scale=1)
515 | - relevant_btn4 = gr.Button(
516 | - value= "Completely Irrelevant", visible=False, interactive=False, scale=1
517 | - )
518 | - with gr.Column(scale=1):
519 | - output_audio = gr.Audio(label="Output", interactive=False, autoplay=True, visible=True)
520 | - output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
521 | - output_asr_text = gr.Textbox(label="ASR output", interactive=False)
522 | - output_text = gr.Textbox(label="LLM output", interactive=False)
523 | - eval_radio = gr.Radio(
524 | - choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
525 | - label="Choose Evaluation metrics:",
526 |   )
527 | -
528 | -
529 | -
530 |   visible=False,
531 |   )
532-581 | -
582 |   demo.launch(share=True)
5 |   with open('versa.sh', 'rb') as file:
6 |   script = file.read()
7 |   rc = call(script, shell=True)
8 | +
9 |   import os
10 |   import shutil
11 | + import time
12 | + from typing import Generator, Optional, Tuple
13 | +
14 | + import gradio as gr
15 | + import nltk
16 | + import numpy as np
17 |   import torch
18 | + from huggingface_hub import HfApi
19 | + from pyscripts.utils.dialog_eval.ASR_WER import handle_espnet_ASR_WER
20 | + from pyscripts.utils.dialog_eval.human_feedback import (
21 | + natural_vote1_last_response,
22 | + natural_vote2_last_response,
23 | + natural_vote3_last_response,
24 | + natural_vote4_last_response,
25 | + relevant_vote1_last_response,
26 | + relevant_vote2_last_response,
27 | + relevant_vote3_last_response,
28 | + relevant_vote4_last_response,
29 | + )
30 | + from pyscripts.utils.dialog_eval.LLM_Metrics import (
31 | + DialoGPT_perplexity,
32 | + bert_score,
33 | + perplexity,
34 | + vert,
35 | + )
36 | + from pyscripts.utils.dialog_eval.TTS_intelligibility import (
37 | + handle_espnet_TTS_intelligibility,
38 | + )
39 | + from pyscripts.utils.dialog_eval.TTS_speech_quality import TTS_psuedomos
40 | +
41 | + from espnet2.sds.espnet_model import ESPnetSDSModelInterface
42 | +
43 | + # ------------------------
44 | + # Hyperparameters
45 | + # ------------------------
46 |
47 |   access_token = os.environ.get("HF_TOKEN")
48 |   ASR_name="pyf98/owsm_ctc_v3.1_1B"
49 |   LLM_name="meta-llama/Llama-3.2-1B-Instruct"
50 |   TTS_name="kan-bayashi/ljspeech_vits"
51 | + ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper-large".split(",")
52 |   LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
53 |   TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
54 |   Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
55 |   upload_to_hub=None
56 | + dialogue_model = ESPnetSDSModelInterface(
57 | + ASR_name, LLM_name, TTS_name, "Cascaded", access_token
58 | + )
59 |   ASR_curr_name=None
60 |   LLM_curr_name=None
61 |   TTS_curr_name=None
62 |
63 | + latency_ASR = 0.0
64 | + latency_LM = 0.0
65 | + latency_TTS = 0.0
66 |
67 | + text_str = ""
68 | + asr_output_str = ""
69 | + vad_output = None
70 |   audio_output = None
71 |   audio_output1 = None
72 | + LLM_response_arr = []
73 | + total_response_arr = []
74 | + callback = gr.CSVLogger()
75 | + start_record_time = None
76 | + enable_btn = gr.Button(interactive=True, visible=True)
77 |
78 | + # ------------------------
79 | + # Function Definitions
80 | + # ------------------------
81 |
+
def handle_eval_selection(
|
83 |
+
option: str,
|
84 |
+
TTS_audio_output: str,
|
85 |
+
LLM_Output: str,
|
86 |
+
ASR_audio_output: str,
|
87 |
+
ASR_transcript: str,
|
88 |
+
):
|
89 |
+
"""
|
90 |
+
Handles the evaluation of a selected metric based on
|
91 |
+
user input and provided outputs.
|
92 |
+
|
93 |
+
This function evaluates different aspects of a
|
94 |
+
casacaded conversational AI pipeline, such as:
|
95 |
+
Latency, TTS intelligibility, TTS speech quality,
|
96 |
+
ASR WER, and text dialog metrics.
|
97 |
+
It is designed to integrate with Gradio via
|
98 |
+
multiple yield statements,
|
99 |
+
allowing updates to be displayed in real time.
|
100 |
+
|
101 |
+
Parameters:
|
102 |
+
----------
|
103 |
+
option : str
|
104 |
+
The evaluation metric selected by the user.
|
105 |
+
Supported options include:
|
106 |
+
- "Latency"
|
107 |
+
- "TTS Intelligibility"
|
108 |
+
- "TTS Speech Quality"
|
109 |
+
- "ASR WER"
|
110 |
+
- "Text Dialog Metrics"
|
111 |
+
TTS_audio_output : np.ndarray
|
112 |
+
The audio output generated by the TTS module for evaluation.
|
113 |
+
LLM_Output : str
|
114 |
+
The text output generated by the LLM module for evaluation.
|
115 |
+
ASR_audio_output : np.ndarray
|
116 |
+
The audio input/output used for ASR evaluation.
|
117 |
+
ASR_transcript : str
|
118 |
+
The transcript generated by the ASR module for evaluation.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
-------
|
122 |
+
str
|
123 |
+
A string representation of the evaluation results.
|
124 |
+
The specific result depends on the selected evaluation metric:
|
125 |
+
- "Latency": Latencies of ASR, LLM, and TTS modules.
|
126 |
+
- "TTS Intelligibility": A range of scores indicating how intelligible
|
127 |
+
the TTS audio output is based on different reference ASR models.
|
128 |
+
- "TTS Speech Quality": A range of scores representing the
|
129 |
+
speech quality of the TTS audio output.
|
130 |
+
- "ASR WER": The Word Error Rate (WER) of the ASR output
|
131 |
+
based on different judge ASR models.
|
132 |
+
- "Text Dialog Metrics": A combination of perplexity,
|
133 |
+
diversity metrics, and relevance scores for the dialog.
|
134 |
+
|
135 |
+
Raises:
|
136 |
+
------
|
137 |
+
ValueError
|
138 |
+
If the `option` parameter does not match any supported evaluation metric.
|
139 |
+
|
140 |
+
Example:
|
141 |
+
-------
|
142 |
+
>>> result = handle_eval_selection(
|
143 |
+
option="Latency",
|
144 |
+
TTS_audio_output=audio_array,
|
145 |
+
LLM_Output="Generated response",
|
146 |
+
ASR_audio_output=audio_input,
|
147 |
+
ASR_transcript="Expected transcript"
|
148 |
+
)
|
149 |
+
>>> print(result)
|
150 |
+
"ASR Latency: 0.14
|
151 |
+
LLM Latency: 0.42
|
152 |
+
TTS Latency: 0.21"
|
153 |
+
"""
|
154 |
global LLM_response_arr
|
155 |
global total_response_arr
|
156 |
+
yield (option, gr.Textbox(visible=True))
|
157 |
+
if option == "Latency":
|
158 |
+
text = (
|
159 |
+
f"ASR Latency: {latency_ASR:.2f}\n"
|
160 |
+
f"LLM Latency: {latency_LM:.2f}\n"
|
161 |
+
f"TTS Latency: {latency_TTS:.2f}"
|
162 |
+
)
|
163 |
+
yield (None, text)
|
164 |
+
elif option == "TTS Intelligibility":
|
165 |
+
yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
|
166 |
+
elif option == "TTS Speech Quality":
|
167 |
+
yield (None, TTS_psuedomos(TTS_audio_output))
|
168 |
+
elif option == "ASR WER":
|
169 |
+
yield (None, handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
|
170 |
+
elif option == "Text Dialog Metrics":
|
171 |
+
yield (
|
172 |
+
None,
|
173 |
+
perplexity(LLM_Output.replace("\n", " "))
|
174 |
+
+ vert(LLM_response_arr)
|
175 |
+
+ bert_score(total_response_arr)
|
176 |
+
+ DialoGPT_perplexity(
|
177 |
+
ASR_transcript.replace("\n", " "), LLM_Output.replace("\n", " ")
|
178 |
+
),
|
179 |
+
)
|
180 |
+
elif option is None:
|
181 |
+
return
|
182 |
+
else:
|
183 |
+
raise ValueError(f"Unknown option: {option}")
|
184 |
+
|
185 |
+
|
186 |
+
def handle_eval_selection_E2E(
|
187 |
+
option: str,
|
188 |
+
TTS_audio_output: str,
|
189 |
+
LLM_Output: str,
|
190 |
+
):
|
191 |
+
"""
|
192 |
+
Handles the evaluation of a selected metric based on user input
|
193 |
+
and provided outputs.
|
194 |
+
|
195 |
+
This function evaluates different aspects of an E2E
|
196 |
+
conversational AI model, such as:
|
197 |
+
Latency, TTS intelligibility, TTS speech quality, and
|
198 |
+
text dialog metrics.
|
199 |
+
It is designed to integrate with Gradio via
|
200 |
+
multiple yield statements,
|
201 |
+
allowing updates to be displayed in real time.
|
202 |
+
|
203 |
+
Parameters:
|
204 |
+
----------
|
205 |
+
option : str
|
206 |
+
The evaluation metric selected by the user.
|
207 |
+
Supported options include:
|
208 |
+
- "Latency"
|
209 |
+
- "TTS Intelligibility"
|
210 |
+
- "TTS Speech Quality"
|
211 |
+
- "Text Dialog Metrics"
|
212 |
+
TTS_audio_output : np.ndarray
|
213 |
+
The audio output generated by the TTS module for evaluation.
|
214 |
+
LLM_Output : str
|
215 |
+
The text output generated by the LLM module for evaluation.
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
-------
|
219 |
+
str
|
220 |
+
A string representation of the evaluation results.
|
221 |
+
The specific result depends on the selected evaluation metric:
|
222 |
+
- "Latency": Latency of the entire system.
|
223 |
+
- "TTS Intelligibility": A range of scores indicating how intelligible the
|
224 |
+
TTS audio output is based on different reference ASR models.
|
225 |
+
- "TTS Speech Quality": A range of scores representing the
|
226 |
+
speech quality of the TTS audio output.
|
227 |
+
- "Text Dialog Metrics": A combination of perplexity and
|
228 |
+
diversity metrics for the dialog.
|
229 |
+
|
230 |
+
Raises:
|
231 |
+
------
|
232 |
+
ValueError
|
233 |
+
If the `option` parameter does not match any supported evaluation metric.
|
234 |
+
|
235 |
+
Example:
|
236 |
+
-------
|
237 |
+
>>> result = handle_eval_selection(
|
238 |
+
option="Latency",
|
239 |
+
TTS_audio_output=audio_array,
|
240 |
+
LLM_Output="Generated response",
|
241 |
+
)
|
242 |
+
>>> print(result)
|
243 |
+
"Total Latency: 2.34"
|
244 |
+
"""
|
245 |
global LLM_response_arr
|
246 |
global total_response_arr
|
247 |
+
yield (option, gr.Textbox(visible=True))
|
248 |
+
if option == "Latency":
|
249 |
+
text = f"Total Latency: {latency_TTS:.2f}"
|
250 |
+
yield (None, text)
|
251 |
+
elif option == "TTS Intelligibility":
|
252 |
+
yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
|
253 |
+
elif option == "TTS Speech Quality":
|
254 |
+
yield (None, TTS_psuedomos(TTS_audio_output))
|
255 |
+
elif option == "Text Dialog Metrics":
|
256 |
+
yield (None, perplexity(LLM_Output.replace("\n", " ")) + vert(LLM_response_arr))
|
257 |
+
elif option is None:
|
258 |
+
return
|
259 |
else:
|
260 |
+
raise ValueError(f"Unknown option: {option}")
|
261 |
+
|
262 |
|
263 |
def start_warmup():
|
264 |
+
"""
|
265 |
+
Initializes and warms up the dialogue and evaluation model.
|
266 |
+
|
267 |
+
This function is designed to ensure that all
|
268 |
+
components of the dialogue model are pre-loaded
|
269 |
+
and ready for execution, avoiding delays during runtime.
|
270 |
+
"""
|
271 |
+
global dialogue_model
|
272 |
+
global ASR_options
|
273 |
+
global LLM_options
|
274 |
+
global TTS_options
|
275 |
+
global ASR_name
|
276 |
+
global LLM_name
|
277 |
+
global TTS_name
|
278 |
+
for opt_count in range(len(ASR_options)):
|
279 |
+
opt = ASR_options[opt_count]
|
280 |
+
try:
|
281 |
+
for _ in dialogue_model.handle_ASR_selection(opt):
|
282 |
+
continue
|
283 |
+
except Exception:
|
284 |
+
print("Removing " + opt + " from ASR options since it cannot be loaded.")
|
285 |
+
ASR_options = ASR_options[:opt_count] + ASR_options[(opt_count + 1) :]
|
286 |
+
if opt == ASR_name:
|
287 |
+
ASR_name = ASR_options[0]
|
288 |
+
for opt_count in range(len(LLM_options)):
|
289 |
+
opt = LLM_options[opt_count]
|
290 |
+
try:
|
291 |
+
for _ in dialogue_model.handle_LLM_selection(opt):
|
292 |
+
continue
|
293 |
+
except Exception:
|
294 |
+
print("Removing " + opt + " from LLM options since it cannot be loaded.")
|
295 |
+
LLM_options = LLM_options[:opt_count] + LLM_options[(opt_count + 1) :]
|
296 |
+
if opt == LLM_name:
|
297 |
+
LLM_name = LLM_options[0]
|
298 |
+
for opt_count in range(len(TTS_options)):
|
299 |
+
opt = TTS_options[opt_count]
|
300 |
+
try:
|
301 |
+
for _ in dialogue_model.handle_TTS_selection(opt):
|
302 |
+
continue
|
303 |
+
except Exception:
|
304 |
+
print("Removing " + opt + " from TTS options since it cannot be loaded.")
|
305 |
+
TTS_options = TTS_options[:opt_count] + TTS_options[(opt_count + 1) :]
|
306 |
+
if opt == TTS_name:
|
307 |
+
TTS_name = TTS_options[0]
|
308 |
+
dialogue_model.handle_E2E_selection()
|
309 |
+
dialogue_model.client = None
|
310 |
+
for _ in dialogue_model.handle_TTS_selection(TTS_name):
|
311 |
continue
|
312 |
+
for _ in dialogue_model.handle_ASR_selection(ASR_name):
|
313 |
continue
|
314 |
+
for _ in dialogue_model.handle_LLM_selection(LLM_name):
|
315 |
continue
|
316 |
+
dummy_input = (
|
317 |
+
torch.randn(
|
318 |
(3000),
|
319 |
dtype=getattr(torch, "float16"),
|
320 |
device="cpu",
|
321 |
+
)
|
322 |
+
.cpu()
|
323 |
+
.numpy()
|
324 |
+
)
|
325 |
+
dummy_text = "This is dummy text"
|
326 |
for opt in Eval_options:
|
327 |
handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
|
328 |
|
|
|
|
|
329 |
|
|
|
|
|
|
|
|
|
330 |
def flash_buttons():
|
331 |
+
"""
|
332 |
+
Enables human feedback buttons after displaying system output.
|
333 |
+
"""
|
334 |
btn_updates = (enable_btn,) * 8
|
335 |
+
yield (
|
336 |
+
"",
|
337 |
+
"",
|
338 |
+
) + btn_updates
|
339 |
+
|
340 |
+
|
341 |
+
def transcribe(
|
342 |
+
stream: np.ndarray,
|
343 |
+
new_chunk: Tuple[int, np.ndarray],
|
344 |
+
TTS_option: str,
|
345 |
+
ASR_option: str,
|
346 |
+
LLM_option: str,
|
347 |
+
type_option: str,
|
348 |
):
|
349 |
+
"""
|
350 |
+
Processes and transcribes an audio stream in real-time.
|
351 |
+
|
352 |
+
This function handles the transcription of audio input
|
353 |
+
and its transformation through a cascaded
|
354 |
+
or E2E conversational AI system.
|
355 |
+
It dynamically updates the transcription, text generation,
|
356 |
+
and synthesized speech output, while managing global states and latencies.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
stream: The current audio stream buffer.
|
360 |
+
`None` if the stream is being reset (e.g., after user refresh).
|
361 |
+
new_chunk: A tuple containing:
|
362 |
+
- `sr`: Sample rate of the new audio chunk.
|
363 |
+
- `y`: New audio data chunk.
|
364 |
+
TTS_option: Selected TTS model option.
|
365 |
+
ASR_option: Selected ASR model option.
|
366 |
+
LLM_option: Selected LLM model option.
|
367 |
+
type_option: Type of system ("Cascaded" or "E2E").
|
368 |
+
|
369 |
+
Yields:
|
370 |
+
Tuple[Optional[np.ndarray], Optional[str], Optional[str],
|
371 |
+
Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
|
372 |
+
A tuple containing:
|
373 |
+
- Updated stream buffer.
|
374 |
+
- ASR output text.
|
375 |
+
- Generated LLM output text.
|
376 |
+
- Audio output as a tuple of sample rate and audio waveform.
|
377 |
+
- User input audio as a tuple of sample rate and audio waveform.
|
378 |
+
|
379 |
+
Notes:
|
380 |
+
- Resets the session if the transcription exceeds 5 minutes.
|
381 |
+
- Updates the Gradio interface elements dynamically.
|
382 |
+
- Manages latencies.
|
383 |
+
"""
|
384 |
sr, y = new_chunk
|
385 |
global text_str
|
386 |
global chat
|
|
|
399 |
global total_response_arr
|
400 |
if stream is None:
|
401 |
# Handle user refresh
|
402 |
+
for (
|
403 |
+
_,
|
404 |
+
_,
|
405 |
+
_,
|
406 |
+
_,
|
407 |
+
asr_output_box,
|
408 |
+
text_box,
|
409 |
+
audio_box,
|
410 |
+
_,
|
411 |
+
_,
|
412 |
+
) in dialogue_model.handle_type_selection(
|
413 |
+
type_option, TTS_option, ASR_option, LLM_option
|
414 |
+
):
|
415 |
gr.Info("The models are being reloaded due to a browser refresh.")
|
416 |
+
yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
|
417 |
+
stream = y
|
418 |
+
text_str = ""
|
|
|
419 |
audio_output = None
|
420 |
audio_output1 = None
|
421 |
else:
|
422 |
+
stream = np.concatenate((stream, y))
|
423 |
+
(
|
424 |
+
asr_output_str,
|
425 |
+
text_str,
|
426 |
+
audio_output,
|
427 |
+
audio_output1,
|
428 |
+
latency_ASR,
|
429 |
+
latency_LM,
|
430 |
+
latency_TTS,
|
431 |
+
stream,
|
432 |
+
change,
|
433 |
+
) = dialogue_model(
|
434 |
+
y,
|
435 |
+
sr,
|
436 |
+
stream,
|
437 |
+
asr_output_str,
|
438 |
+
text_str,
|
439 |
+
audio_output,
|
440 |
+
audio_output1,
|
441 |
+
latency_ASR,
|
442 |
+
latency_LM,
|
443 |
+
latency_TTS,
|
444 |
+
)
|
445 |
+
text_str1 = text_str
|
446 |
+
if change:
|
447 |
+
print("Output changed")
|
448 |
+
if asr_output_str != "":
|
449 |
+
total_response_arr.append(asr_output_str.replace("\n", " "))
|
450 |
+
LLM_response_arr.append(text_str.replace("\n", " "))
|
451 |
+
total_response_arr.append(text_str.replace("\n", " "))
|
452 |
+
if (text_str != "") and (start_record_time is None):
|
453 |
+
start_record_time = time.time()
|
454 |
elif start_record_time is not None:
|
455 |
+
current_record_time = time.time()
|
456 |
+
if current_record_time - start_record_time > 300:
|
457 |
+
gr.Info(
|
458 |
+
"Conversations are limited to 5 minutes. "
|
459 |
+
"The session will restart in approximately 60 seconds. "
|
460 |
+
"Please wait for the demo to reset. "
|
461 |
+
"Close this message once you have read it.",
|
462 |
+
duration=None,
|
463 |
+
)
|
464 |
+
yield stream, gr.Textbox(visible=False), gr.Textbox(
|
465 |
+
visible=False
|
466 |
+
), gr.Audio(visible=False), gr.Audio(visible=False)
|
467 |
if upload_to_hub is not None:
|
468 |
api.upload_folder(
|
469 |
folder_path="flagged_data_points",
|
470 |
+
path_in_repo="checkpoint_" + str(start_record_time),
|
471 |
repo_id=upload_to_hub,
|
472 |
repo_type="dataset",
|
473 |
token=access_token,
|
474 |
)
|
475 |
+
dialogue_model.chat.buffer = []
|
476 |
+
text_str = ""
|
477 |
audio_output = None
|
478 |
audio_output1 = None
|
479 |
asr_output_str = ""
|
480 |
start_record_time = None
|
481 |
+
LLM_response_arr = []
|
482 |
+
total_response_arr = []
|
483 |
+
shutil.rmtree("flagged_data_points")
|
484 |
os.mkdir("flagged_data_points")
|
485 |
+
yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
486 |
+
yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
|
487 |
+
visible=True
|
488 |
+
), gr.Audio(visible=False)
|
489 |
|
490 |
+
yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
491 |
|
492 |
+
|
493 |
+
# ------------------------
|
494 |
+
# Executable Script
|
495 |
+
# ------------------------
|
496 |
+
api = HfApi()
|
497 |
+
nltk.download("averaged_perceptron_tagger_eng")
|
498 |
+
start_warmup()
|
499 |
with gr.Blocks(
|
500 |
+
title="E2E Spoken Dialog System",
|
501 |
+
) as demo:
|
502 |
+
with gr.Row():
|
503 |
+
gr.Markdown(
|
504 |
+
"""
|
505 |
+
## ESPnet-SDS
|
506 |
+
Welcome to our unified web interface for various cascaded and
|
507 |
+
E2E spoken dialogue systems built using ESPnet-SDS toolkit,
|
508 |
+
supporting real-time automated evaluation metrics, and
|
509 |
+
human-in-the-loop feedback collection.
|
510 |
+
|
511 |
+
For more details on how to use the app, refer to the [README]
|
512 |
+
(https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
|
513 |
+
"""
|
514 |
+
)
|
515 |
+
with gr.Row():
|
516 |
+
with gr.Column(scale=1):
|
517 |
+
user_audio = gr.Audio(
|
518 |
+
sources=["microphone"],
|
519 |
+
streaming=True,
|
520 |
+
waveform_options=gr.WaveformOptions(sample_rate=16000),
|
521 |
+
)
|
522 |
+
with gr.Row():
|
523 |
+
type_radio = gr.Radio(
|
524 |
+
choices=["Cascaded", "E2E"],
|
525 |
+
label="Choose type of Spoken Dialog:",
|
526 |
+
value="Cascaded",
|
527 |
+
)
|
528 |
+
with gr.Row():
|
529 |
+
ASR_radio = gr.Radio(
|
530 |
+
choices=ASR_options,
|
531 |
+
label="Choose ASR:",
|
532 |
+
value=ASR_name,
|
533 |
+
)
|
534 |
+
with gr.Row():
|
535 |
+
LLM_radio = gr.Radio(
|
536 |
+
choices=LLM_options,
|
537 |
+
label="Choose LLM:",
|
538 |
+
value=LLM_name,
|
539 |
)
|
540 |
+
with gr.Row():
|
541 |
+
radio = gr.Radio(
|
542 |
+
choices=TTS_options,
|
543 |
+
label="Choose TTS:",
|
544 |
+
value=TTS_name,
|
545 |
+
)
|
546 |
+
with gr.Row():
|
547 |
+
E2Eradio = gr.Radio(
|
548 |
+
choices=["mini-omni"],
|
549 |
+
label="Choose E2E model:",
|
550 |
+
value="mini-omni",
|
551 |
visible=False,
|
552 |
)
|
553 |
+
with gr.Row():
|
554 |
+
feedback_btn = gr.Button(
|
555 |
+
value=(
|
556 |
+
"Please provide your feedback "
|
557 |
+
"after each system response below."
|
558 |
+
),
|
559 |
+
visible=True,
|
560 |
+
interactive=False,
|
561 |
+
elem_id="button",
|
562 |
+
)
|
563 |
+
with gr.Row():
|
564 |
+
natural_btn1 = gr.Button(
|
565 |
+
value="Very Natural", visible=False, interactive=False, scale=1
|
566 |
+
)
|
567 |
+
natural_btn2 = gr.Button(
|
568 |
+
value="Somewhat Awkward", visible=False, interactive=False, scale=1
|
569 |
+
)
|
570 |
+
natural_btn3 = gr.Button(
|
571 |
+
value="Very Awkward", visible=False, interactive=False, scale=1
|
572 |
+
)
|
573 |
+
natural_btn4 = gr.Button(
|
574 |
+
value="Unnatural", visible=False, interactive=False, scale=1
|
575 |
+
)
|
576 |
+
with gr.Row():
|
577 |
+
relevant_btn1 = gr.Button(
|
578 |
+
value="Highly Relevant", visible=False, interactive=False, scale=1
|
579 |
+
)
|
580 |
+
relevant_btn2 = gr.Button(
|
581 |
+
value="Partially Relevant",
|
582 |
+
visible=False,
|
583 |
+
interactive=False,
|
584 |
+
scale=1,
|
585 |
+
)
|
586 |
+
relevant_btn3 = gr.Button(
|
587 |
+
value="Slightly Irrelevant",
|
588 |
+
visible=False,
|
589 |
+
interactive=False,
|
590 |
+
scale=1,
|
591 |
+
)
|
592 |
+
relevant_btn4 = gr.Button(
|
593 |
+
value="Completely Irrelevant",
|
594 |
+
visible=False,
|
595 |
+
interactive=False,
|
596 |
+
scale=1,
|
597 |
+
)
|
598 |
+
with gr.Column(scale=1):
|
599 |
+
output_audio = gr.Audio(label="Output", autoplay=True, visible=True)
|
600 |
+
output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
|
601 |
+
output_asr_text = gr.Textbox(label="ASR output")
|
602 |
+
output_text = gr.Textbox(label="LLM output")
|
603 |
+
eval_radio = gr.Radio(
|
604 |
+
choices=[
|
605 |
+
"Latency",
|
606 |
+
"TTS Intelligibility",
|
607 |
+
"TTS Speech Quality",
|
608 |
+
"ASR WER",
|
609 |
+
"Text Dialog Metrics",
|
610 |
+
],
|
611 |
+
label="Choose Evaluation metrics:",
|
612 |
+
)
|
613 |
+
eval_radio_E2E = gr.Radio(
|
614 |
+
choices=[
|
615 |
+
"Latency",
|
616 |
+
"TTS Intelligibility",
|
617 |
+
"TTS Speech Quality",
|
618 |
+
"Text Dialog Metrics",
|
619 |
+
],
|
620 |
+
label="Choose Evaluation metrics:",
|
621 |
+
visible=False,
|
622 |
+
)
|
623 |
+
output_eval_text = gr.Textbox(label="Evaluation Results")
|
624 |
+
state = gr.State()
|
625 |
+
with gr.Row():
|
626 |
+
privacy_text = gr.Textbox(
|
627 |
+
label="Privacy Notice",
|
628 |
+
interactive=False,
|
629 |
+
value=(
|
630 |
+
"By using this demo, you acknowledge that"
|
631 |
+
"interactions with this dialog system are collected "
|
632 |
+
"for research and improvement purposes. The data "
|
633 |
+
"will only be used to enhance the performance and "
|
634 |
+
"understanding of the system. If you have any "
|
635 |
+
"concerns about data collection, please discontinue "
|
636 |
+
"use."
|
637 |
+
),
|
638 |
+
)
|
639 |
+
|
640 |
+
btn_list = [
|
641 |
+
natural_btn1,
|
642 |
+
natural_btn2,
|
643 |
+
natural_btn3,
|
644 |
+
natural_btn4,
|
645 |
+
relevant_btn1,
|
646 |
+
relevant_btn2,
|
647 |
+
relevant_btn3,
|
648 |
+
relevant_btn4,
|
649 |
+
]
|
650 |
+
natural_btn_list = [
|
651 |
+
natural_btn1,
|
652 |
+
natural_btn2,
|
653 |
+
natural_btn3,
|
654 |
+
natural_btn4,
|
655 |
+
]
|
656 |
+
relevant_btn_list = [
|
657 |
+
relevant_btn1,
|
658 |
+
relevant_btn2,
|
659 |
+
relevant_btn3,
|
660 |
+
relevant_btn4,
|
661 |
+
]
|
662 |
+
natural_response = gr.Textbox(
|
663 |
+
label="natural_response", visible=False, interactive=False
|
664 |
+
)
|
665 |
+
diversity_response = gr.Textbox(
|
666 |
+
label="diversity_response", visible=False, interactive=False
|
667 |
+
)
|
668 |
+
ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
|
669 |
+
callback.setup(
|
670 |
+
[
|
671 |
+
user_audio,
|
672 |
+
output_asr_text,
|
673 |
+
output_text,
|
674 |
+
output_audio,
|
675 |
+
output_audio1,
|
676 |
+
type_radio,
|
677 |
+
ASR_radio,
|
678 |
+
LLM_radio,
|
679 |
+
radio,
|
680 |
+
E2Eradio,
|
681 |
+
natural_response,
|
682 |
+
diversity_response,
|
683 |
+
ip_address,
|
684 |
+
],
|
685 |
+
"flagged_data_points",
|
686 |
+
)
|
687 |
+
user_audio.stream(
|
688 |
+
transcribe,
|
689 |
+
inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio],
|
690 |
+
outputs=[state, output_asr_text, output_text, output_audio, output_audio1],
|
691 |
+
).then(
|
692 |
+
lambda *args: callback.flag(list(args)), [user_audio], None, preprocess=False
|
693 |
+
)
|
694 |
+
radio.change(
|
695 |
+
fn=dialogue_model.handle_TTS_selection,
|
696 |
+
inputs=[radio],
|
697 |
+
outputs=[output_asr_text, output_text, output_audio],
|
698 |
+
)
|
699 |
+
LLM_radio.change(
|
700 |
+
fn=dialogue_model.handle_LLM_selection,
|
701 |
+
inputs=[LLM_radio],
|
702 |
+
outputs=[output_asr_text, output_text, output_audio],
|
703 |
+
)
|
704 |
+
ASR_radio.change(
|
705 |
+
fn=dialogue_model.handle_ASR_selection,
|
706 |
+
inputs=[ASR_radio],
|
707 |
+
outputs=[output_asr_text, output_text, output_audio],
|
708 |
+
)
|
709 |
+
eval_radio.change(
|
710 |
+
fn=handle_eval_selection,
|
711 |
+
inputs=[eval_radio, output_audio, output_text, output_audio1, output_asr_text],
|
712 |
+
outputs=[eval_radio, output_eval_text],
|
713 |
+
)
|
714 |
+
eval_radio_E2E.change(
|
715 |
+
fn=handle_eval_selection_E2E,
|
716 |
+
inputs=[eval_radio_E2E, output_audio, output_text],
|
717 |
+
outputs=[eval_radio_E2E, output_eval_text],
|
718 |
+
)
|
719 |
+
type_radio.change(
|
720 |
+
fn=dialogue_model.handle_type_selection,
|
721 |
+
inputs=[type_radio, radio, ASR_radio, LLM_radio],
|
722 |
+
outputs=[
|
723 |
+
radio,
|
724 |
+
ASR_radio,
|
725 |
+
LLM_radio,
|
726 |
+
E2Eradio,
|
727 |
+
output_asr_text,
|
728 |
+
output_text,
|
729 |
+
output_audio,
|
730 |
+
eval_radio,
|
731 |
+
eval_radio_E2E,
|
732 |
+
],
|
733 |
+
)
|
734 |
+
output_audio.play(
|
735 |
+
flash_buttons, [], [natural_response, diversity_response] + btn_list
|
736 |
+
).then(
|
737 |
+
lambda *args: callback.flag(list(args)),
|
738 |
+
[
|
739 |
+
user_audio,
|
740 |
+
output_asr_text,
|
741 |
+
output_text,
|
742 |
+
output_audio,
|
743 |
+
output_audio1,
|
744 |
+
type_radio,
|
745 |
+
ASR_radio,
|
746 |
+
LLM_radio,
|
747 |
+
radio,
|
748 |
+
E2Eradio,
|
749 |
+
],
|
750 |
+
None,
|
751 |
+
preprocess=False,
|
752 |
+
)
|
753 |
+
natural_btn1.click(
|
754 |
+
natural_vote1_last_response,
|
755 |
+
[],
|
756 |
+
[natural_response, ip_address] + natural_btn_list,
|
757 |
+
).then(
|
758 |
+
lambda *args: callback.flag(list(args)),
|
759 |
+
[
|
760 |
+
user_audio,
|
761 |
+
output_asr_text,
|
762 |
+
output_text,
|
763 |
+
output_audio,
|
764 |
+
output_audio1,
|
765 |
+
type_radio,
|
766 |
+
ASR_radio,
|
767 |
+
LLM_radio,
|
768 |
+
radio,
|
769 |
+
E2Eradio,
|
770 |
+
natural_response,
|
771 |
+
diversity_response,
|
772 |
+
ip_address,
|
773 |
+
],
|
774 |
+
None,
|
775 |
+
preprocess=False,
|
776 |
+
)
|
777 |
+
natural_btn2.click(
|
778 |
+
natural_vote2_last_response,
|
779 |
+
[],
|
780 |
+
[natural_response, ip_address] + natural_btn_list,
|
781 |
+
).then(
|
782 |
+
lambda *args: callback.flag(list(args)),
|
783 |
+
[
|
784 |
+
user_audio,
|
785 |
+
output_asr_text,
|
786 |
+
output_text,
|
787 |
+
output_audio,
|
788 |
+
output_audio1,
|
789 |
+
type_radio,
|
790 |
+
ASR_radio,
|
791 |
+
LLM_radio,
|
792 |
+
radio,
|
793 |
+
E2Eradio,
|
794 |
+
natural_response,
|
795 |
+
diversity_response,
|
796 |
+
ip_address,
|
797 |
+
],
|
798 |
+
None,
|
799 |
+
preprocess=False,
|
800 |
+
)
|
801 |
+
natural_btn3.click(
|
802 |
+
natural_vote3_last_response,
|
803 |
+
[],
|
804 |
+
[natural_response, ip_address] + natural_btn_list,
|
805 |
+
).then(
|
806 |
+
lambda *args: callback.flag(list(args)),
|
807 |
+
[
|
808 |
+
user_audio,
|
809 |
+
output_asr_text,
|
810 |
+
output_text,
|
811 |
+
output_audio,
|
812 |
+
output_audio1,
|
813 |
+
type_radio,
|
814 |
+
ASR_radio,
|
815 |
+
LLM_radio,
|
816 |
+
radio,
|
817 |
+
E2Eradio,
|
818 |
+
natural_response,
|
819 |
+
diversity_response,
|
820 |
+
ip_address,
|
821 |
+
],
|
822 |
+
None,
|
823 |
+
preprocess=False,
|
824 |
+
)
|
825 |
+
natural_btn4.click(
|
826 |
+
natural_vote4_last_response,
|
827 |
+
[],
|
828 |
+
[natural_response, ip_address] + natural_btn_list,
|
829 |
+
).then(
|
830 |
+
lambda *args: callback.flag(list(args)),
|
831 |
+
[
|
832 |
+
user_audio,
|
833 |
+
output_asr_text,
|
834 |
+
output_text,
|
835 |
+
output_audio,
|
836 |
+
output_audio1,
|
837 |
+
type_radio,
|
838 |
+
ASR_radio,
|
839 |
+
LLM_radio,
|
840 |
+
radio,
|
841 |
+
E2Eradio,
|
842 |
+
natural_response,
|
843 |
+
diversity_response,
|
844 |
+
ip_address,
|
845 |
+
],
|
846 |
+
None,
|
847 |
+
preprocess=False,
|
848 |
+
)
|
849 |
+
relevant_btn1.click(
|
850 |
+
relevant_vote1_last_response,
|
851 |
+
[],
|
852 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
853 |
+
).then(
|
854 |
+
lambda *args: callback.flag(list(args)),
|
855 |
+
[
|
856 |
+
user_audio,
|
857 |
+
output_asr_text,
|
858 |
+
output_text,
|
859 |
+
output_audio,
|
860 |
+
output_audio1,
|
861 |
+
type_radio,
|
862 |
+
ASR_radio,
|
863 |
+
LLM_radio,
|
864 |
+
radio,
|
865 |
+
E2Eradio,
|
866 |
+
natural_response,
|
867 |
+
diversity_response,
|
868 |
+
ip_address,
|
869 |
+
],
|
870 |
+
None,
|
871 |
+
preprocess=False,
|
872 |
+
)
|
873 |
+
relevant_btn2.click(
|
874 |
+
relevant_vote2_last_response,
|
875 |
+
[],
|
876 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
877 |
+
).then(
|
878 |
+
lambda *args: callback.flag(list(args)),
|
879 |
+
[
|
880 |
+
user_audio,
|
881 |
+
output_asr_text,
|
882 |
+
output_text,
|
883 |
+
output_audio,
|
884 |
+
output_audio1,
|
885 |
+
type_radio,
|
886 |
+
ASR_radio,
|
887 |
+
LLM_radio,
|
888 |
+
radio,
|
889 |
+
E2Eradio,
|
890 |
+
natural_response,
|
891 |
+
diversity_response,
|
892 |
+
ip_address,
|
893 |
+
],
|
894 |
+
None,
|
895 |
+
preprocess=False,
|
896 |
+
)
|
897 |
+
relevant_btn3.click(
|
898 |
+
relevant_vote3_last_response,
|
899 |
+
[],
|
900 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
901 |
+
).then(
|
902 |
+
lambda *args: callback.flag(list(args)),
|
903 |
+
[
|
904 |
+
user_audio,
|
905 |
+
output_asr_text,
|
906 |
+
output_text,
|
907 |
+
output_audio,
|
908 |
+
output_audio1,
|
909 |
+
type_radio,
|
910 |
+
ASR_radio,
|
911 |
+
LLM_radio,
|
912 |
+
radio,
|
913 |
+
E2Eradio,
|
914 |
+
natural_response,
|
915 |
+
diversity_response,
|
916 |
+
ip_address,
|
917 |
+
],
|
918 |
+
None,
|
919 |
+
preprocess=False,
|
920 |
+
)
|
921 |
+
relevant_btn4.click(
|
922 |
+
relevant_vote4_last_response,
|
923 |
+
[],
|
924 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
925 |
+
).then(
|
926 |
+
lambda *args: callback.flag(list(args)),
|
927 |
+
[
|
928 |
+
user_audio,
|
929 |
+
output_asr_text,
|
930 |
+
output_text,
|
931 |
+
output_audio,
|
932 |
+
output_audio1,
|
933 |
+
type_radio,
|
934 |
+
ASR_radio,
|
935 |
+
LLM_radio,
|
936 |
+
radio,
|
937 |
+
E2Eradio,
|
938 |
+
natural_response,
|
939 |
+
diversity_response,
|
940 |
+
ip_address,
|
941 |
+
],
|
942 |
+
None,
|
943 |
+
preprocess=False,
|
944 |
+
)
|
945 |
demo.launch(share=True)
|
946 |
+
|
pyscripts/utils/dialog_eval/ASR_WER.py
ADDED
@@ -0,0 +1,165 @@
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from espnet2.sds.utils.utils import int2float
|
6 |
+
|
7 |
+
|
8 |
+
def handle_espnet_ASR_WER(
|
9 |
+
ASR_audio_output: Tuple[int, np.ndarray], ASR_transcript: str
|
10 |
+
) -> str:
|
11 |
+
"""
|
12 |
+
Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
|
13 |
+
for multiple judge ASR systems (ESPnet, OWSM, Whisper) using the Versa library.
|
14 |
+
|
15 |
+
This function performs the following:
|
16 |
+
1. Imports necessary metrics and setup functions from Versa.
|
17 |
+
2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
|
18 |
+
3. Runs the Levenshtein-based WER/CER calculations.
|
19 |
+
4. Returns a formatted string summarizing WER and CER
|
20 |
+
    results for reference produced by each ASR system.

    Args:
        ASR_audio_output (tuple):
            A tuple where:
            - The first element is the frame rate.
            - The second element is the audio signal (NumPy array).
        ASR_transcript (str):
            The transcript produced by the ASR model in the cascaded
            conversational AI pipeline.

    Returns:
        str:
            A formatted string showing the WER and CER percentages
            for ESPnet, OWSM, and Whisper. Example output:

            "ESPnet WER: 10.50
            ESPnet CER: 7.20
            OWSM WER: 11.30
            OWSM CER: 8.00
            Whisper WER: 9.25
            Whisper CER: 6.50"

    Raises:
        ImportError:
            If Versa is not installed or cannot be imported.

    Example:
        >>> asr_audio_output = (16000, audio_array)
        >>> asr_transcript = "This is the ASR transcript."
        >>> result = handle_espnet_ASR_WER(asr_audio_output, asr_transcript)
        >>> print(result)
        "ESPnet WER: 10.50
        ESPnet CER: 7.20
        OWSM WER: 11.30
        OWSM CER: 8.00
        Whisper WER: 9.25
        Whisper CER: 6.50"
    """
    try:
        from versa import (
            espnet_levenshtein_metric,
            espnet_wer_setup,
            owsm_levenshtein_metric,
            owsm_wer_setup,
            whisper_levenshtein_metric,
            whisper_wer_setup,
        )
    except Exception as e:
        print("Error: Versa is not properly installed.")
        raise e
    score_modules_espnet = {
        "module": espnet_levenshtein_metric,
        "args": espnet_wer_setup(
            model_tag="default", beam_size=1, text_cleaner="whisper_en", use_gpu=True
        ),
    }
    dict1 = score_modules_espnet["module"](
        score_modules_espnet["args"],
        int2float(ASR_audio_output[1]),
        ASR_transcript,
        ASR_audio_output[0],
    )
    espnet_wer = (
        dict1["espnet_wer_delete"] + dict1["espnet_wer_insert"] + dict1["espnet_wer_replace"]
    ) / (dict1["espnet_wer_insert"] + dict1["espnet_wer_replace"] + dict1["espnet_wer_equal"])
    espnet_cer = (
        dict1["espnet_cer_delete"] + dict1["espnet_cer_insert"] + dict1["espnet_cer_replace"]
    ) / (dict1["espnet_cer_insert"] + dict1["espnet_cer_replace"] + dict1["espnet_cer_equal"])
    score_modules_owsm = {
        "module": owsm_levenshtein_metric,
        "args": owsm_wer_setup(
            model_tag="default", beam_size=1, text_cleaner="whisper_en", use_gpu=True
        ),
    }
    dict1 = score_modules_owsm["module"](
        score_modules_owsm["args"],
        int2float(ASR_audio_output[1]),
        ASR_transcript,
        ASR_audio_output[0],
    )
    owsm_wer = (
        dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
    ) / (dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
    owsm_cer = (
        dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
    ) / (dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
    score_modules_whisper = {
        "module": whisper_levenshtein_metric,
        "args": whisper_wer_setup(
            model_tag="default", beam_size=1, text_cleaner="whisper_en", use_gpu=True
        ),
    }
    dict1 = score_modules_whisper["module"](
        score_modules_whisper["args"],
        int2float(ASR_audio_output[1]),
        ASR_transcript,
        ASR_audio_output[0],
    )
    whisper_wer = (
        dict1["whisper_wer_delete"] + dict1["whisper_wer_insert"] + dict1["whisper_wer_replace"]
    ) / (dict1["whisper_wer_insert"] + dict1["whisper_wer_replace"] + dict1["whisper_wer_equal"])
    whisper_cer = (
        dict1["whisper_cer_delete"] + dict1["whisper_cer_insert"] + dict1["whisper_cer_replace"]
    ) / (dict1["whisper_cer_insert"] + dict1["whisper_cer_replace"] + dict1["whisper_cer_equal"])
    return (
        f"ESPnet WER: {espnet_wer*100:.2f}\n"
        f"ESPnet CER: {espnet_cer*100:.2f}\n"
        f"OWSM WER: {owsm_wer*100:.2f}\n"
        f"OWSM CER: {owsm_cer*100:.2f}\n"
        f"Whisper WER: {whisper_wer*100:.2f}\n"
        f"Whisper CER: {whisper_cer*100:.2f}"
    )
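Editor's note: the report returned by handle_espnet_ASR_WER is a plain formatted string, so callers that want numbers have to parse it. Below is a minimal, hedged sketch of such a parser; parse_wer_report and the regex are illustrative assumptions, not part of this commit.

import re

def parse_wer_report(report: str) -> dict:
    """Turn lines such as 'ESPnet WER: 10.50' into {'ESPnet WER': 10.5, ...}."""
    scores = {}
    for line in report.splitlines():
        match = re.match(r"^(.+):\s*([0-9.]+)$", line.strip())
        if match:
            scores[match.group(1)] = float(match.group(2))
    return scores

# Example (assuming `report` holds the string returned by handle_espnet_ASR_WER):
# parse_wer_report("ESPnet WER: 10.50\nESPnet CER: 7.20")
# -> {'ESPnet WER': 10.5, 'ESPnet CER': 7.2}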
pyscripts/utils/dialog_eval/LLM_Metrics.py
ADDED
@@ -0,0 +1,245 @@
from multiprocessing import Pool
from typing import List

import numpy as np
import torch
from pyscripts.utils.dialog_eval.vert import (
    get_auto_bleu2_geometric,
    get_self_bleu2_geometric,
    run_f,
)
from scipy.stats import gmean
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def perplexity(LLM_Output: str, model_id: str = "gpt2") -> str:
    """
    Compute the perplexity of the given text using a specified model from the
    `evaluate` library (default: GPT-2).

    Args:
        LLM_Output (str):
            The text (string) for which perplexity is to be computed.
        model_id (str, optional):
            The identifier of the model to use for computing perplexity.
            Defaults to "gpt2".

    Returns:
        str:
            A formatted string showing the perplexity of the provided text(s),
            for example: "Perplexity: 45.23\n"

    Raises:
        ImportError:
            If the `evaluate` library is not installed or cannot be imported.

    Example:
        >>> text = "Hello world, this is a test."
        >>> result = perplexity(text, model_id="gpt2")
        >>> print(result)
        "Perplexity: 27.34\n"
    """
    try:
        import evaluate
    except Exception as e:
        print("Error: evaluate is not properly installed.")
        raise e
    perplexity = evaluate.load("perplexity", module_type="metric")
    results = perplexity.compute(model_id=model_id, predictions=[LLM_Output])
    return f"Perplexity: {results['mean_perplexity']:.2f}\n"


def vert(LLM_response_arr: List[str]) -> str:
    """
    Calculate and return Self BLEU-2, Auto BLEU-2 and VERT-2
    metrics for a list of LLM responses.

    Args:
        LLM_response_arr (List[str]):
            A list of responses (strings) generated by the language
            model acting as text dialog response generator.

    Returns:
        str:
            A formatted string that includes each computed metric and the final
            VERT value, for example:

            "Self-BLEU2-geometric: 42.13
            Auto-BLEU2-geometric: 38.94
            VERT: 40.5
            "

    Example:
        >>> # Suppose we have the following LLM responses:
        >>> responses = ["Hello world", "Foo bar", "Lorem ipsum dolor sit amet"]
        >>> result = vert(responses)
        >>> print(result)
        "Self-BLEU2-geometric: 42.13
        Auto-BLEU2-geometric: 38.94
        VERT: 40.5
        "
    """
    terms = [x.strip().split() for x in LLM_response_arr]

    tasks = [
        ("Self-BLEU2-geometric", get_self_bleu2_geometric),
        ("Auto-BLEU2-geometric", get_auto_bleu2_geometric),
    ]
    n_processes = min(16, len(tasks))
    with Pool(n_processes) as pool:
        metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
    metric_arr = []
    str1 = ""
    for (metric_name, _), metric in zip(tasks, metrics):
        metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
        metric, sem = [round(100 * x, 2) for x in [metric, sem]]
        metric_arr.append(metric)
        str1 += f"{metric_name}: {metric}\n"
    str1 += f"VERT: {round(gmean(metric_arr), 2)}\n"
    return str1


def bert_score(
    total_response_arr: List[str], bert_model_name: str = "bert-base-uncased"
) -> str:
    """
    Compute a cosine similarity score between the concatenated context
    (all but the last element) and the final response (last element)
    using a BERT-based model. This serves as a simplified measure of how
    closely the response aligns with the preceding context semantically.

    Args:
        total_response_arr (List[str]):
            A list of strings. The last element represents the response,
            while all other elements are treated as the context.
        bert_model_name (str, optional):
            The name or path of the BERT model to use (from the Hugging Face
            Model Hub). Defaults to "bert-base-uncased".

    Returns:
        str:
            A string containing the cosine similarity (as a percentage)
            followed by a newline. For example:
            "Cosine Similarity: 85.67\n"

    Example:
        >>> total_responses = [
        ...     "User: Hi, how are you?",
        ...     "Assistant: I'm good! How can I help you today?",
        ...     "User: Can you tell me a joke?",
        ...     "Assistant: Sure! Here's one: Why did the chicken join a band?"
        ... ]
        >>> result = bert_score(total_responses, bert_model_name="bert-base-uncased")
        >>> print(result)
        "Cosine Similarity: 75.89\n"
    """

    def cosine_similarity_context_response(context, response, model, tokenizer):
        # Tokenize and encode both context and response
        context_inputs = tokenizer(context, return_tensors="pt", truncation=True)
        response_inputs = tokenizer(response, return_tensors="pt", truncation=True)
        for k in context_inputs:
            context_inputs[k] = context_inputs[k].cuda()
        for k in response_inputs:
            response_inputs[k] = response_inputs[k].cuda()

        # Get embeddings from the model
        with torch.no_grad():
            context_embedding = model(**context_inputs).last_hidden_state.mean(dim=1)
            response_embedding = model(**response_inputs).last_hidden_state.mean(dim=1)

        # Compute cosine similarity
        similarity = cosine_similarity(
            context_embedding.cpu().numpy(), response_embedding.cpu().numpy()
        )
        return similarity[0][0]

    bert_model = AutoModel.from_pretrained(bert_model_name).cuda()
    bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
    similarity = cosine_similarity_context_response(
        " ".join(total_response_arr[:-1]),
        total_response_arr[-1],
        bert_model,
        bert_tokenizer,
    )
    return f"Cosine Similarity: {similarity*100:.2f}" + "\n"


def DialoGPT_perplexity(
    user_utterance: str,
    response: str,
    dialog_model_name: str = "microsoft/DialoGPT-medium",
) -> str:
    """
    Compute the perplexity of a response given a user utterance using a
    pre-trained DialoGPT model. The function loads DialoGPT (medium by default)
    from the Hugging Face Model Hub, then calculates the perplexity for the
    (context + response) sequence.

    Args:
        user_utterance (str):
            The user utterance preceding the model's response.
        response (str):
            The generated response whose perplexity needs to be evaluated.

    Returns:
        str:
            A formatted string containing the DialoGPT perplexity score. For example:
            "DialoGPT Perplexity: 25.67\n"

    Example:
        >>> user_text = "Hi, how are you today?"
        >>> system_response = "I'm good, thank you! How can I help you?"
        >>> result = DialoGPT_perplexity(user_text, system_response)
        >>> print(result)
        "DialoGPT Perplexity: 31.45\n"
    """

    def evaluate_response_with_dialoGPT(context, response, model, tokenizer):
        """
        Evaluate the appropriateness of a response based on the
        given context using DialoGPT.

        Args:
            context (str): The dialogue context (previous conversation).
            response (str): The generated response to evaluate.
            model: Pre-trained DialoGPT model.
            tokenizer: Corresponding tokenizer for the DialoGPT model.

        Returns:
            float: Perplexity score of the response given the context.
        """
        model.eval()

        # Combine context and response as input
        input_text = context + tokenizer.eos_token + response + tokenizer.eos_token
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
        inputs["input_ids"] = inputs["input_ids"].cuda()
        inputs["attention_mask"] = inputs["attention_mask"].cuda()

        # Compute model outputs and loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"].cuda())
            loss = outputs.loss

        # Calculate perplexity
        perplexity = torch.exp(loss)
        return perplexity.cpu().item()

    # Load DialoGPT model and tokenizer
    model_name = dialog_model_name
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    perplexity = evaluate_response_with_dialoGPT(
        user_utterance, response, model, tokenizer
    )
    return f"DialoGPT Perplexity: {perplexity:.2f}" + "\n"
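Editor's note: as the vert() docstring indicates, the VERT-2 value is simply the geometric mean of the Self-BLEU2 and Auto-BLEU2 percentages. A minimal sketch of that aggregation follows; the numeric values are illustrative only.

from scipy.stats import gmean

self_bleu2 = 42.13  # assumed Self-BLEU2-geometric score (percent)
auto_bleu2 = 38.94  # assumed Auto-BLEU2-geometric score (percent)

vert2 = gmean([self_bleu2, auto_bleu2])
print(f"VERT: {round(vert2, 2)}")  # ~40.5 for these example numbers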
pyscripts/utils/dialog_eval/TTS_intelligibility.py
ADDED
@@ -0,0 +1,169 @@
from typing import Tuple

import numpy as np

from espnet2.sds.utils.utils import int2float


def handle_espnet_TTS_intelligibility(
    TTS_audio_output: Tuple[int, np.ndarray], LLM_Output: str
) -> str:
    """
    Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
    for multiple ASR systems (ESPnet, OWSM, Whisper) using the Versa library.

    This function:
        1. Imports the necessary metrics and setup functions from Versa.
        2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
        3. Runs the Levenshtein-based WER/CER calculations on the provided TTS audio.
        4. Returns a formatted string summarizing WER and CER results for hypotheses
           produced by each ASR system when transcribing the TTS audio, using the
           LLM output as the reference text.

    Args:
        TTS_audio_output (Tuple[int, np.ndarray]):
            A tuple consisting of:
            - The first element (int): the frame rate of the audio.
            - The second element (np.ndarray): the audio signal (e.g., a NumPy array).
        LLM_Output (str):
            The reference text generated by the LLM, which serves as the ground truth
            for evaluating the TTS audio.

    Returns:
        str:
            A formatted string showing the WER and CER percentages
            for ESPnet, OWSM, and Whisper. Example:

            ESPnet WER: 10.50
            ESPnet CER: 7.20
            OWSM WER: 11.30
            OWSM CER: 8.00
            Whisper WER: 9.25
            Whisper CER: 6.50

    Raises:
        ImportError:
            If the Versa library is not installed or cannot be imported.

    Example:
        >>> tts_audio_output = (16000, audio_array)
        >>> llm_output = "This is the reference text for evaluation."
        >>> result = handle_espnet_TTS_intelligibility(tts_audio_output, llm_output)
        >>> print(result)
        ESPnet WER: 10.50
        ESPnet CER: 7.20
        OWSM WER: 11.30
        OWSM CER: 8.00
        Whisper WER: 9.25
        Whisper CER: 6.50
    """
    try:
        from versa import (
            espnet_levenshtein_metric,
            espnet_wer_setup,
            owsm_levenshtein_metric,
            owsm_wer_setup,
            whisper_levenshtein_metric,
            whisper_wer_setup,
        )
    except Exception as e:
        print("Error: Versa is not properly installed.")
        raise e
    score_modules_espnet = {
        "module": espnet_levenshtein_metric,
        "args": espnet_wer_setup(
            model_tag="default", beam_size=1, text_cleaner="whisper_en", use_gpu=True
        ),
    }
    dict1 = score_modules_espnet["module"](
        score_modules_espnet["args"],
        int2float(TTS_audio_output[1]),
        LLM_Output,
        TTS_audio_output[0],
    )
    espnet_wer = (
        dict1["espnet_wer_delete"] + dict1["espnet_wer_insert"] + dict1["espnet_wer_replace"]
    ) / (dict1["espnet_wer_delete"] + dict1["espnet_wer_replace"] + dict1["espnet_wer_equal"])
    espnet_cer = (
        dict1["espnet_cer_delete"] + dict1["espnet_cer_insert"] + dict1["espnet_cer_replace"]
    ) / (dict1["espnet_cer_delete"] + dict1["espnet_cer_replace"] + dict1["espnet_cer_equal"])
    score_modules_owsm = {
        "module": owsm_levenshtein_metric,
        "args": owsm_wer_setup(
            model_tag="default", beam_size=1, text_cleaner="whisper_en", use_gpu=True
        ),
    }
    dict1 = score_modules_owsm["module"](
        score_modules_owsm["args"],
        int2float(TTS_audio_output[1]),
        LLM_Output,
        TTS_audio_output[0],
    )
    owsm_wer = (
        dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
    ) / (dict1["owsm_wer_delete"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
    owsm_cer = (
        dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
    ) / (dict1["owsm_cer_delete"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
    score_modules_whisper = {
        "module": whisper_levenshtein_metric,
        "args": whisper_wer_setup(
            model_tag="default", beam_size=1, text_cleaner="whisper_en", use_gpu=True
        ),
    }
    dict1 = score_modules_whisper["module"](
        score_modules_whisper["args"],
        int2float(TTS_audio_output[1]),
        LLM_Output,
        TTS_audio_output[0],
    )
    whisper_wer = (
        dict1["whisper_wer_delete"] + dict1["whisper_wer_insert"] + dict1["whisper_wer_replace"]
    ) / (dict1["whisper_wer_delete"] + dict1["whisper_wer_replace"] + dict1["whisper_wer_equal"])
    whisper_cer = (
        dict1["whisper_cer_delete"] + dict1["whisper_cer_insert"] + dict1["whisper_cer_replace"]
    ) / (dict1["whisper_cer_delete"] + dict1["whisper_cer_replace"] + dict1["whisper_cer_equal"])
    return (
        f"ESPnet WER: {espnet_wer*100:.2f}\n"
        f"ESPnet CER: {espnet_cer*100:.2f}\n"
        f"OWSM WER: {owsm_wer*100:.2f}\n"
        f"OWSM CER: {owsm_cer*100:.2f}\n"
        f"Whisper WER: {whisper_wer*100:.2f}\n"
        f"Whisper CER: {whisper_cer*100:.2f}"
    )
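Editor's note: the WER used in this file divides the edit operations by the reference length (deletions + substitutions + correct words). A tiny sketch of that formula, with made-up counts, is shown below; wer_from_counts is an illustrative helper, not part of this commit.

def wer_from_counts(delete: int, insert: int, replace: int, equal: int) -> float:
    # Edit operations over reference length, mirroring the ratios computed above.
    return (delete + insert + replace) / (delete + replace + equal)

# e.g. 1 deletion, 0 insertions, 1 substitution, 18 correct words -> 10% WER
print(f"{wer_from_counts(1, 0, 1, 18) * 100:.2f}")  # 10.00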
pyscripts/utils/dialog_eval/TTS_speech_quality.py
ADDED
@@ -0,0 +1,98 @@
from typing import Tuple

import numpy as np

from espnet2.sds.utils.utils import int2float


def TTS_psuedomos(TTS_audio_output: Tuple[int, np.ndarray]) -> str:
    """
    Compute and return speech quality metrics for the given synthesized
    audio output using the Versa library.

    Args:
        TTS_audio_output (Tuple[int, np.ndarray]):
            A tuple containing:
            - The first element (int): The frame rate of the audio.
            - The second element (np.ndarray): The audio signal, typically a NumPy array.

    Returns:
        str:
            A formatted string containing each metric name and its
            corresponding score, for example:

            utmos: 3.54
            dnsmos: 3.47
            plcmos: 3.62
            sheet_ssqa: 4.03

    Raises:
        ImportError:
            If the Versa library is not installed or cannot be imported.

    Example:
        >>> tts_audio_output = (16000, audio_array)
        >>> result = TTS_psuedomos(tts_audio_output)
        >>> print(result)
        utmos: 3.54
        dnsmos: 3.47
        plcmos: 3.62
        sheet_ssqa: 4.03
    """
    try:
        from versa import (
            pseudo_mos_metric,
            pseudo_mos_setup,
            sheet_ssqa,
            sheet_ssqa_setup,
        )
    except Exception as e:
        print("Error: Versa is not properly installed.")
        raise e

    predictor_dict, predictor_fs = pseudo_mos_setup(
        use_gpu=True,
        predictor_types=["utmos", "dnsmos", "plcmos"],
        predictor_args={
            "utmos": {"fs": 16000},
            "dnsmos": {"fs": 16000},
            "plcmos": {"fs": 16000},
        },
    )
    score_modules = {
        "module": pseudo_mos_metric,
        "args": {
            "predictor_dict": predictor_dict,
            "predictor_fs": predictor_fs,
            "use_gpu": True,
        },
    }
    dict1 = score_modules["module"](
        int2float(TTS_audio_output[1]),
        TTS_audio_output[0],
        **score_modules["args"],
    )
    str1 = ""
    for k in dict1:
        str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
    sheet_model = sheet_ssqa_setup(
        model_tag="default",
        model_path=None,
        model_config=None,
        use_gpu=True,
    )
    score_modules = {
        "module": sheet_ssqa,
        "args": {"model": sheet_model, "use_gpu": True},
    }
    dict1 = score_modules["module"](
        score_modules["args"]["model"],
        int2float(TTS_audio_output[1]),
        TTS_audio_output[0],
        use_gpu=score_modules["args"]["use_gpu"],
    )
    for k in dict1:
        str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
    return str1
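Editor's note: since TTS_psuedomos also returns a "name: score" report string, a caller that wants a single figure can average the individual pseudo-MOS predictions. A minimal sketch (mean_pseudo_mos is an assumed helper, and the report format is the one shown in the docstring above):

def mean_pseudo_mos(report: str) -> float:
    # Average the per-predictor scores from a "name: score" report.
    values = [float(line.split(":")[1]) for line in report.strip().splitlines()]
    return sum(values) / len(values)

# mean_pseudo_mos("utmos: 3.54\ndnsmos: 3.47\nplcmos: 3.62\nsheet_ssqa: 4.03")
# -> 3.665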
pyscripts/utils/dialog_eval/__pycache__/ASR_WER.cpython-39.pyc
ADDED
Binary file (4.12 kB)
pyscripts/utils/dialog_eval/__pycache__/LLM_Metrics.cpython-39.pyc
ADDED
Binary file (8.51 kB)
pyscripts/utils/dialog_eval/__pycache__/TTS_intelligibility.cpython-39.pyc
ADDED
Binary file (4.34 kB)
pyscripts/utils/dialog_eval/__pycache__/TTS_speech_quality.cpython-39.pyc
ADDED
Binary file (2.39 kB)
pyscripts/utils/dialog_eval/__pycache__/human_feedback.cpython-39.pyc
ADDED
Binary file (7.34 kB)
pyscripts/utils/dialog_eval/__pycache__/vert.cpython-39.pyc
ADDED
Binary file (9.13 kB)
pyscripts/utils/dialog_eval/human_feedback.py
ADDED
@@ -0,0 +1,242 @@
import gradio as gr

disable_btn = gr.Button(interactive=False, visible=False)


def get_ip(request: gr.Request) -> str:
    """
    Retrieve the IP address from an incoming HTTP request.

    Args:
        request (gr.Request):
            The incoming HTTP request from which the IP address will be extracted.

    Returns:
        str:
            The IP address as a string.
    """
    if "cf-connecting-ip" in request.headers:
        ip = request.headers["cf-connecting-ip"]
    elif "x-forwarded-for" in request.headers:
        ip = request.headers["x-forwarded-for"]
        if "," in ip:
            ip = ip.split(",")[0]
    else:
        ip = request.client.host
    return ip


def natural_vote1_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Very Natural".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Very Natural", <ip_address>) followed by disable_btn repeated four
            times, disabling the naturalness vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Very Natural (voted). ip: {ip_address1}")
    return ("Very Natural", ip_address1) + (disable_btn,) * 4


def natural_vote2_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Somewhat Awkward".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Somewhat Awkward", <ip_address>) followed by disable_btn repeated four
            times, disabling the naturalness vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Somewhat Awkward (voted). ip: {ip_address1}")
    return ("Somewhat Awkward", ip_address1) + (disable_btn,) * 4


def natural_vote3_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Very Awkward".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Very Awkward", <ip_address>) followed by disable_btn repeated four
            times, disabling the naturalness vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Very Awkward (voted). ip: {ip_address1}")
    return ("Very Awkward", ip_address1) + (disable_btn,) * 4


def natural_vote4_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Unnatural".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Unnatural", <ip_address>) followed by disable_btn repeated four
            times, disabling the naturalness vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Unnatural (voted). ip: {ip_address1}")
    return ("Unnatural", ip_address1) + (disable_btn,) * 4


def relevant_vote1_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Highly Relevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Highly Relevant", <ip_address>) followed by disable_btn repeated four
            times, disabling the relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Highly Relevant (voted). ip: {ip_address1}")
    return ("Highly Relevant", ip_address1) + (disable_btn,) * 4


def relevant_vote2_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Partially Relevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Partially Relevant", <ip_address>) followed by disable_btn repeated four
            times, disabling the relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Partially Relevant (voted). ip: {ip_address1}")
    return ("Partially Relevant", ip_address1) + (disable_btn,) * 4


def relevant_vote3_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Slightly Irrelevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Slightly Irrelevant", <ip_address>) followed by disable_btn repeated four
            times, disabling the relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
    return ("Slightly Irrelevant", ip_address1) + (disable_btn,) * 4


def relevant_vote4_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Completely Irrelevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            ("Completely Irrelevant", <ip_address>) followed by disable_btn repeated
            four times, disabling the relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Completely Irrelevant (voted). ip: {ip_address1}")
    return ("Completely Irrelevant", ip_address1) + (disable_btn,) * 4
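Editor's note: these handlers are meant to be bound to Gradio buttons in the demo UI. The sketch below shows one plausible wiring, assuming four naturalness buttons, a textbox for the vote label, and a state holding the voter IP; all component names here are assumptions, not taken from this commit.

import gradio as gr

with gr.Blocks() as feedback_demo:
    natural_response = gr.Textbox(label="Naturalness vote")
    ip_state = gr.State("")
    buttons = [
        gr.Button("Very Natural"),
        gr.Button("Somewhat Awkward"),
        gr.Button("Very Awkward"),
        gr.Button("Unnatural"),
    ]
    handlers = [
        natural_vote1_last_response,
        natural_vote2_last_response,
        natural_vote3_last_response,
        natural_vote4_last_response,
    ]
    # Each click records the vote and returns updates that disable all four buttons.
    for btn, handler in zip(buttons, handlers):
        btn.click(handler, inputs=[], outputs=[natural_response, ip_state, *buttons])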
pyscripts/utils/dialog_eval/vert.py
ADDED
@@ -0,0 +1,299 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
import warnings
from collections import Counter
from fractions import Fraction

import nltk
import numpy as np
from nltk.translate.bleu_score import (
    SmoothingFunction,
    brevity_penalty,
    closest_ref_length,
    modified_precision,
)


def corpus_bleu(
    list_of_references,
    hypotheses,
    weights=(0.25, 0.25, 0.25, 0.25),
    smoothing_function=None,
    auto_reweigh=False,
    averaging_mode="geometric",
    no_length_penalty=False,
):
    """
    Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
    the hypotheses and their respective references.

    Instead of averaging the sentence level BLEU scores (i.e. macro-average
    precision), the original BLEU metric (Papineni et al. 2002) accounts for
    the micro-average precision (i.e. summing the numerators and denominators
    for each hypothesis-reference(s) pairs before the division).

    >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
    ...         'ensures', 'that', 'the', 'military', 'always',
    ...         'obeys', 'the', 'commands', 'of', 'the', 'party']
    >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
    ...          'ensures', 'that', 'the', 'military', 'will', 'forever',
    ...          'heed', 'Party', 'commands']
    >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
    ...          'guarantees', 'the', 'military', 'forces', 'always',
    ...          'being', 'under', 'the', 'command', 'of', 'the', 'Party']
    >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
    ...          'army', 'always', 'to', 'heed', 'the', 'directions',
    ...          'of', 'the', 'party']

    >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
    ...         'interested', 'in', 'world', 'history']
    >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
    ...          'because', 'he', 'read', 'the', 'book']

    >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
    >>> hypotheses = [hyp1, hyp2]
    >>> corpus_bleu(list_of_references, hypotheses)  # doctest: +ELLIPSIS
    0.5920...

    The example below shows that corpus_bleu() is different from averaging
    sentence_bleu() for hypotheses

    >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
    >>> score2 = sentence_bleu([ref2a], hyp2)
    >>> (score1 + score2) / 2  # doctest: +ELLIPSIS
    0.6223...

    :param list_of_references: a corpus of lists of reference
        sentences, w.r.t. hypotheses
    :type list_of_references: list(list(list(str)))
    :param hypotheses: a list of hypothesis sentences
    :type hypotheses: list(list(str))
    :param weights: weights for unigrams, bigrams, trigrams and so on
    :type weights: list(float)
    :param smoothing_function:
    :type smoothing_function: SmoothingFunction
    :param auto_reweigh: Option to re-normalize the weights uniformly.
    :type auto_reweigh: bool
    :return: The corpus-level BLEU score.
    :rtype: float
    """
    # Before proceeding to compute BLEU, perform sanity checks.

    p_numerators = Counter()  # Key = ngram order, and value = no. of ngram matches.
    p_denominators = Counter()  # Key = ngram order, and value = no. of ngram in ref.
    hyp_lengths, ref_lengths = 0, 0

    assert len(list_of_references) == len(hypotheses), (
        "The number of hypotheses and their reference(s) should be the same"
    )

    # Iterate through each hypothesis and their corresponding references.
    for references, hypothesis in zip(list_of_references, hypotheses):
        # For each order of ngram, calculate the numerator and
        # denominator for the corpus-level modified precision.
        for i, _ in enumerate(weights, start=1):
            p_i = modified_precision(references, hypothesis, i)
            p_numerators[i] += p_i.numerator
            p_denominators[i] += p_i.denominator

        # Calculate the hypothesis length and the closest reference length.
        # Adds them to the corpus-level hypothesis and reference counts.
        hyp_len = len(hypothesis)
        hyp_lengths += hyp_len
        ref_lengths += closest_ref_length(references, hyp_len)

    # Calculate corpus-level brevity penalty.
    if no_length_penalty and averaging_mode == "geometric":
        bp = 1.0
    elif no_length_penalty and averaging_mode == "arithmetic":
        bp = 0.0
    else:
        assert not no_length_penalty
        assert (
            averaging_mode != "arithmetic"
        ), "Not sure how to apply length penalty in arithmetic mode"
        bp = brevity_penalty(ref_lengths, hyp_lengths)

    # Uniformly re-weighting based on maximum hypothesis lengths if largest
    # order of n-grams < 4 and weights is set at default.
    if auto_reweigh:
        if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
            weights = (1 / hyp_lengths,) * hyp_lengths

    # Collects the various precision values for the different ngram orders.
    p_n = [
        Fraction(p_numerators[i], p_denominators[i], _normalize=False)
        for i, _ in enumerate(weights, start=1)
    ]

    # Returns 0 if there's no matching n-grams
    # We only need to check for p_numerators[1] == 0, since if there's
    # no unigrams, there won't be any higher order ngrams.
    if p_numerators[1] == 0:
        return 0

    # If there's no smoothing, use method0 from the SmoothingFunction class.
    if not smoothing_function:
        smoothing_function = SmoothingFunction().method0
    # Smooth the modified precision.
    # Note: smoothing_function() may convert values into floats;
    # it tries to retain the Fraction object as much as the
    # smoothing method allows.
    p_n = smoothing_function(
        p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
    )

    if averaging_mode == "geometric":
        s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
        s = bp * math.exp(math.fsum(s))
    elif averaging_mode == "arithmetic":
        s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
        s = math.fsum(s)

    return s


def sentence_bleu(
    references,
    hypothesis,
    weights=(0.25, 0.25, 0.25, 0.25),
    smoothing_function=None,
    auto_reweigh=False,
    averaging_mode="geometric",
    no_length_penalty=False,
):
    return corpus_bleu(
        [references],
        [hypothesis],
        weights,
        smoothing_function,
        auto_reweigh,
        averaging_mode,
        no_length_penalty,
    )


def get_target_sequences(manifest, ground_truth, to_take=1000):
    import json
    import pathlib

    with open(ground_truth, "r") as fin:
        original_continuations = json.loads(fin.read())

    sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
    assert all(float(v) >= 6.0 for (_, v) in sequence2length)  # 6 seconds

    sequence2length.sort(key=lambda x: x[1])
    to_take_sequences = set(v[0] for v in sequence2length[:to_take])
    to_take_ids = []

    with open(manifest, "r") as f:
        f.readline()

        for i, line in enumerate(f.readlines()):
            seq_id = line.split()[0]
            seq_id = pathlib.Path(seq_id).name.split("__")[0]

            if seq_id in to_take_sequences:
                to_take_ids.append(i)

    print(f"Took {len(to_take_ids)} ids")
    return set(to_take_ids)


def get_self_bleu(utterances, averaging_mode, weights):
    self_bleu = []

    for i in range(len(utterances)):
        hypo = utterances[i]
        rest = utterances[:i] + utterances[i + 1 :]

        self_bleu.append(
            sentence_bleu(
                rest,
                hypo,
                weights,
                no_length_penalty=True,
                averaging_mode=averaging_mode,
            )
        )

    return self_bleu


def get_self_bleu2_arithmetic(utterances):
    weights = (0.5, 0.5)  # equal weight for unigrams and bigrams
    return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)


def get_self_bleu2_geometric(utterances):
    weights = (0.5, 0.5)
    return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)


def get_auto_bleu2_arithmetic(utterances):
    weights = (0.5, 0.5)
    return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]


def get_auto_bleu2_geometric(utterances):
    weights = (0.5, 0.5)
    return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]


def get_auto_bleu3_geometric(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]


def get_auto_bleu3_arithmetic(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]


def get_self_bleu3_arithmetic(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)


def get_self_bleu3_geometric(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)


def auto_bleu(sentence, weights, mean_mode="arithmetic"):
    if len(sentence) <= 1:
        return 0

    N = len(weights)

    bleu_n = np.zeros([N])
    for n in range(N):
        targ_ngrams = list(nltk.ngrams(sentence, n + 1))
        for p in range(len(targ_ngrams)):
            left = sentence[:p]
            right = sentence[(p + n + 1) :]
            rest_ngrams = list(nltk.ngrams(left, n + 1)) + list(
                nltk.ngrams(right, n + 1)
            )
            # compute the nb of matching ngrams
            bleu_n[n] += targ_ngrams[p] in rest_ngrams
        bleu_n[n] /= len(targ_ngrams)  # average them to get a proportion

    weights = np.array(weights)
    if mean_mode == "arithmetic":
        return (bleu_n * weights).sum()
    elif mean_mode == "geometric":
        return (bleu_n**weights).prod()
    else:
        raise ValueError(f"Unknown aggregation mode {mean_mode}")


def run_f(task_params):
    f, terms = task_params
    return f(terms)
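Editor's note: the auto-BLEU idea above asks, for each n-gram of a sentence, whether the same n-gram also occurs elsewhere in that sentence. A simplified worked example for bigrams follows (auto_bleu itself additionally excludes n-grams that overlap the current position, so this is only an approximation of its matching rule).

import nltk

sentence = "the cat sat on the cat".split()
bigrams = list(nltk.ngrams(sentence, 2))
matches = sum(
    bigrams[i] in (bigrams[:i] + bigrams[i + 1 :]) for i in range(len(bigrams))
)
print(matches / len(bigrams))  # 0.4: 'the cat' repeats among the 5 bigrams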