Siddhant committed on
Commit
b9a6dd9
·
1 Parent(s): 58f82d5

Update demo

app.py CHANGED
@@ -5,347 +5,382 @@ except ImportError:
5
  with open('versa.sh', 'rb') as file:
6
  script = file.read()
7
  rc = call(script, shell=True)
 
8
  import os
9
  import shutil
10
- from espnet2.sds.asr.espnet_asr import ESPnetASRModel
11
- from espnet2.sds.asr.owsm_asr import OWSMModel
12
- from espnet2.sds.asr.owsm_ctc_asr import OWSMCTCModel
13
- from espnet2.sds.asr.whisper_asr import WhisperASRModel
14
- from espnet2.sds.tts.espnet_tts import ESPnetTTSModel
15
- from espnet2.sds.tts.chat_tts import ChatTTSModel
16
- from espnet2.sds.llm.hugging_face_llm import HuggingFaceLLM
17
- from espnet2.sds.vad.webrtc_vad import WebrtcVADModel
18
- from espnet2.sds.eval.TTS_intelligibility import handle_espnet_TTS_intelligibility
19
- from espnet2.sds.eval.ASR_WER import handle_espnet_ASR_WER
20
- from espnet2.sds.eval.TTS_speech_quality import TTS_psuedomos
21
- from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
22
- from espnet2.sds.utils.chat import Chat
23
- from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
24
- import argparse
25
  import torch
 
26
 
27
  access_token = os.environ.get("HF_TOKEN")
28
  ASR_name="pyf98/owsm_ctc_v3.1_1B"
29
  LLM_name="meta-llama/Llama-3.2-1B-Instruct"
30
  TTS_name="kan-bayashi/ljspeech_vits"
31
- ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper".split(",")
32
  LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
33
  TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
34
  Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
35
  upload_to_hub=None
 
 
 
36
  ASR_curr_name=None
37
  LLM_curr_name=None
38
  TTS_curr_name=None
39
- # def read_args():
40
- # global access_token
41
- # global ASR_name
42
- # global LLM_name
43
- # global TTS_name
44
- # global ASR_options
45
- # global LLM_options
46
- # global TTS_options
47
- # global Eval_options
48
- # global upload_to_hub
49
- # parser = argparse.ArgumentParser(description="Run the app with HF_TOKEN as a command-line argument.")
50
- # parser.add_argument("--HF_TOKEN", required=True, help="Provide the Hugging Face token.")
51
- # parser.add_argument("--asr_options", required=True, help="Provide the possible ASR options available to user.")
52
- # parser.add_argument("--llm_options", required=True, help="Provide the possible LLM options available to user.")
53
- # parser.add_argument("--tts_options", required=True, help="Provide the possible TTS options available to user.")
54
- # parser.add_argument("--eval_options", required=True, help="Provide the possible automatic evaluation metrics available to user.")
55
- # parser.add_argument("--default_asr_model", required=False, default="pyf98/owsm_ctc_v3.1_1B", help="Provide the default ASR model.")
56
- # parser.add_argument("--default_llm_model", required=False, default="meta-llama/Llama-3.2-1B-Instruct", help="Provide the default ASR model.")
57
- # parser.add_argument("--default_tts_model", required=False, default="kan-bayashi/ljspeech_vits", help="Provide the default ASR model.")
58
- # parser.add_argument("--upload_to_hub", required=False, default=None, help="Hugging Face dataset to upload user data")
59
- # args = parser.parse_args()
60
- # access_token=args.HF_TOKEN
61
- # ASR_name=args.default_asr_model
62
- # LLM_name=args.default_llm_model
63
- # TTS_name=args.default_tts_model
64
- # ASR_options=args.asr_options.split(",")
65
- # LLM_options=args.llm_options.split(",")
66
- # TTS_options=args.tts_options.split(",")
67
- # Eval_options=args.eval_options.split(",")
68
- # upload_to_hub=args.upload_to_hub
69
-
70
- # read_args()
71
- from huggingface_hub import HfApi
72
 
73
- api = HfApi()
74
- import nltk
75
- nltk.download('averaged_perceptron_tagger_eng')
76
- import gradio as gr
77
-
78
-
79
- import numpy as np
80
-
81
- chat = Chat(2)
82
- chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. The user is talking to you with their voice and you should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
83
- user_role = "user"
84
 
85
- text2speech=None
86
- s2t=None
87
- LM_pipe=None
88
- client=None
89
-
90
- latency_ASR=0.0
91
- latency_LM=0.0
92
- latency_TTS=0.0
93
-
94
- text_str=""
95
- asr_output_str=""
96
- vad_output=None
97
  audio_output = None
98
  audio_output1 = None
99
- LLM_response_arr=[]
100
- total_response_arr=[]
101
-
102
- def handle_selection(option):
103
- global TTS_curr_name
104
- if TTS_curr_name is not None:
105
- if option==TTS_curr_name:
106
- return
107
- yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
108
- global text2speech
109
- TTS_curr_name=option
110
- tag = option
111
- if tag=="ChatTTS":
112
- text2speech = ChatTTSModel()
113
- else:
114
- text2speech = ESPnetTTSModel(tag)
115
- text2speech.warmup()
116
- yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
117
-
118
- def handle_LLM_selection(option):
119
- global LLM_curr_name
120
- if LLM_curr_name is not None:
121
- if option==LLM_curr_name:
122
- return
123
- yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
124
- global LM_pipe
125
- LLM_curr_name=option
126
- LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
127
- LM_pipe.warmup()
128
- yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
129
-
130
- def handle_ASR_selection(option):
131
- global ASR_curr_name
132
- if option=="librispeech_asr":
133
- option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
134
- if ASR_curr_name is not None:
135
- if option==ASR_curr_name:
136
- return
137
- yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
138
- global s2t
139
- ASR_curr_name=option
140
- if option=="espnet/owsm_v3.1_ebf":
141
- s2t = OWSMModel()
142
- elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
143
- s2t = ESPnetASRModel(tag=option)
144
- elif option=="whisper":
145
- s2t = WhisperASRModel()
146
- else:
147
- s2t = OWSMCTCModel(tag=option)
148
 
149
- s2t.warmup()
150
- yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
 
151
 
152
- def handle_eval_selection(option, TTS_audio_output, LLM_Output, ASR_audio_output, ASR_transcript):
153
  global LLM_response_arr
154
  global total_response_arr
155
- yield (option,gr.Textbox(visible=True))
156
- if option=="Latency":
157
- text=f"ASR Latency: {latency_ASR:.2f}\nLLM Latency: {latency_LM:.2f}\nTTS Latency: {latency_TTS:.2f}"
158
- yield (None,text)
159
- elif option=="TTS Intelligibility":
160
- yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
161
- elif option=="TTS Speech Quality":
162
- yield (None,TTS_psuedomos(TTS_audio_output))
163
- elif option=="ASR WER":
164
- yield (None,handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
165
- elif option=="Text Dialog Metrics":
166
- yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr)+bert_score(total_response_arr)+DialoGPT_perplexity(ASR_transcript.replace("\n"," "),LLM_Output.replace("\n"," ")))
167
-
168
- def handle_eval_selection_E2E(option, TTS_audio_output, LLM_Output):
169
  global LLM_response_arr
170
  global total_response_arr
171
- yield (option,gr.Textbox(visible=True))
172
- if option=="Latency":
173
- text=f"Total Latency: {latency_TTS:.2f}"
174
- yield (None,text)
175
- elif option=="TTS Intelligibility":
176
- yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
177
- elif option=="TTS Speech Quality":
178
- yield (None,TTS_psuedomos(TTS_audio_output))
179
- elif option=="Text Dialog Metrics":
180
- yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr))
181
-
182
- def handle_type_selection(option,TTS_radio,ASR_radio,LLM_radio):
183
- global client
184
- global LM_pipe
185
- global s2t
186
- global text2speech
187
- yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False), gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False))
188
- if option=="Cascaded":
189
- client=None
190
- for _ in handle_selection(TTS_radio):
191
- continue
192
- for _ in handle_ASR_selection(ASR_radio):
193
- continue
194
- for _ in handle_LLM_selection(LLM_radio):
195
- continue
196
- yield (gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=False),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=True, interactive=True),gr.Radio(visible=False))
197
  else:
198
- text2speech=None
199
- s2t=None
200
- LM_pipe=None
201
- global ASR_curr_name
202
- global LLM_curr_name
203
- global TTS_curr_name
204
- ASR_curr_name=None
205
- LLM_curr_name=None
206
- TTS_curr_name=None
207
- handle_E2E_selection()
208
- yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
209
-
210
-
211
- def handle_E2E_selection():
212
- global client
213
- if client is None:
214
- client = MiniOmniE2EModel()
215
- client.warmup()
216
 
217
  def start_warmup():
218
- global client
219
- for opt in ASR_options:
220
- if opt==ASR_name:
221
- continue
222
- print(opt)
223
- for _ in handle_ASR_selection(opt):
224
- continue
225
- for opt in LLM_options:
226
- if opt==LLM_name:
227
- continue
228
- print(opt)
229
- for _ in handle_LLM_selection(opt):
230
- continue
231
- for opt in TTS_options:
232
- if opt==TTS_name:
233
- continue
234
- print(opt)
235
- for _ in handle_selection(opt):
236
- continue
237
- handle_E2E_selection()
238
- client=None
239
- for _ in handle_selection(TTS_name):
240
  continue
241
- for _ in handle_ASR_selection(ASR_name):
242
  continue
243
- for _ in handle_LLM_selection(LLM_name):
244
  continue
245
- dummy_input = torch.randn(
 
246
  (3000),
247
  dtype=getattr(torch, "float16"),
248
  device="cpu",
249
- ).cpu().numpy()
250
- dummy_text="This is dummy text"
 
 
 
251
  for opt in Eval_options:
252
  handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
253
 
254
- start_warmup()
255
- vad_model=WebrtcVADModel()
256
 
257
- callback = gr.CSVLogger()
258
- start_record_time=None
259
- enable_btn = gr.Button(interactive=True, visible=True)
260
- disable_btn = gr.Button(interactive=False, visible=False)
261
  def flash_buttons():
 
 
 
262
  btn_updates = (enable_btn,) * 8
263
- print(enable_btn)
264
- yield ("","",)+btn_updates
265
-
266
-
267
- def get_ip(request: gr.Request):
268
- if "cf-connecting-ip" in request.headers:
269
- ip = request.headers["cf-connecting-ip"]
270
- elif "x-forwarded-for" in request.headers:
271
- ip = request.headers["x-forwarded-for"]
272
- if "," in ip:
273
- ip = ip.split(",")[0]
274
- else:
275
- ip = request.client.host
276
- return ip
277
-
278
-
279
- def vote_last_response(vote_type, request: gr.Request):
280
- with open("save_dict.json", "a") as fout:
281
- data = {
282
- "tstamp": round(time.time(), 4),
283
- "type": vote_type,
284
- "ip": get_ip(request),
285
- }
286
- fout.write(json.dumps(data) + "\n")
287
-
288
-
289
- def natural_vote1_last_response(
290
- request: gr.Request
291
  ):
292
- ip_address1=get_ip(request)
293
- print(f"Very Natural (voted). ip: {ip_address1}")
294
- return ("Very Natural",ip_address1,)+(disable_btn,) * 4
295
-
296
- def natural_vote2_last_response(
297
- request: gr.Request
298
- ):
299
- ip_address1=get_ip(request)
300
- print(f"Somewhat Awkward (voted). ip: {ip_address1}")
301
- return ("Somewhat Awkward",ip_address1,)+(disable_btn,) * 4
302
-
303
- def natural_vote3_last_response(
304
- request: gr.Request
305
- ):
306
- ip_address1=get_ip(request)
307
- print(f"Very Awkward (voted). ip: {ip_address1}")
308
- return ("Very Awkward",ip_address1,)+(disable_btn,) * 4
309
-
310
- def natural_vote4_last_response(
311
- request: gr.Request
312
- ):
313
- ip_address1=get_ip(request)
314
- print(f"Unnatural (voted). ip: {ip_address1}")
315
- return ("Unnatural",ip_address1,)+(disable_btn,) * 4
316
-
317
- def relevant_vote1_last_response(
318
- request: gr.Request
319
- ):
320
- ip_address1=get_ip(request)
321
- print(f"Highly Relevant (voted). ip: {ip_address1}")
322
- return ("Highly Relevant",ip_address1,)+(disable_btn,) * 4
323
-
324
- def relevant_vote2_last_response(
325
- request: gr.Request
326
- ):
327
- ip_address1=get_ip(request)
328
- print(f"Partially Relevant (voted). ip: {ip_address1}")
329
- return ("Partially Relevant",ip_address1,)+(disable_btn,) * 4
330
-
331
- def relevant_vote3_last_response(
332
- request: gr.Request
333
- ):
334
- ip_address1=get_ip(request)
335
- print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
336
- return ("Slightly Irrelevant",ip_address1,)+(disable_btn,) * 4
337
-
338
- def relevant_vote4_last_response(
339
- request: gr.Request
340
- ):
341
- ip_address1=get_ip(request)
342
- print(f"Completely Irrelevant (voted). ip: {ip_address1}")
343
- return ("Completely Irrelevant",ip_address1,)+(disable_btn,) * 4
344
-
345
- import json
346
- import time
347
-
348
- def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_option):
349
  sr, y = new_chunk
350
  global text_str
351
  global chat
@@ -364,219 +399,548 @@ def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_optio
364
  global total_response_arr
365
  if stream is None:
366
  # Handle user refresh
367
- # import pdb;pdb.set_trace()
368
- for (_,_,_,_,asr_output_box,text_box,audio_box,_,_) in handle_type_selection(type_option,TTS_option,ASR_option,LLM_option):
369
  gr.Info("The models are being reloaded due to a browser refresh.")
370
- yield (stream,asr_output_box,text_box,audio_box,gr.Audio(visible=False))
371
- stream=y
372
- chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
373
- text_str=""
374
  audio_output = None
375
  audio_output1 = None
376
  else:
377
- stream=np.concatenate((stream,y))
378
- orig_sr=sr
379
- sr=16000
380
- if client is not None:
381
- array=vad_model(y,orig_sr, binary=True)
382
- else:
383
- array=vad_model(y,orig_sr)
384
-
385
- if array is not None:
386
- print("VAD: end of speech detected")
387
- start_time = time.time()
388
- if client is not None:
389
- try:
390
- (text_str, audio_output)=client(array, orig_sr)
391
- except Exception as e:
392
- text_str=""
393
- audio_output=None
394
- raise gr.Error(f"Error during audio streaming: {e}")
395
- asr_output_str=""
396
- latency_TTS=(time.time() - start_time)
397
- else:
398
- prompt=s2t(array)
399
- if len(prompt.strip().split())<2:
400
- text_str1=text_str
401
- yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
402
- return
403
-
404
-
405
- asr_output_str=prompt
406
- total_response_arr.append(prompt.replace("\n"," "))
407
- start_LM_time=time.time()
408
- latency_ASR=(start_LM_time - start_time)
409
- chat.append({"role": user_role, "content": prompt})
410
- chat_messages = chat.to_list()
411
- generated_text = LM_pipe(chat_messages)
412
- start_TTS_time=time.time()
413
- latency_LM=(start_TTS_time - start_LM_time)
414
-
415
- chat.append({"role": "assistant", "content": generated_text})
416
- text_str=generated_text
417
- audio_output=text2speech(text_str)
418
- latency_TTS=(time.time() - start_TTS_time)
419
- audio_output1=(orig_sr,stream)
420
- stream=y
421
- LLM_response_arr.append(text_str.replace("\n"," "))
422
- total_response_arr.append(text_str.replace("\n"," "))
423
- text_str1=text_str
424
- if ((text_str!="") and (start_record_time is None)):
425
- start_record_time=time.time()
426
  elif start_record_time is not None:
427
- current_record_time=time.time()
428
- if current_record_time-start_record_time>300:
429
- gr.Info("Conversations are limited to 5 minutes. The session will restart in approximately 60 seconds. Please wait for the demo to reset. Close this message once you have read it.", duration=None)
430
- yield stream,gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Audio(visible=False)
431
  if upload_to_hub is not None:
432
  api.upload_folder(
433
  folder_path="flagged_data_points",
434
- path_in_repo="checkpoint_"+str(start_record_time),
435
  repo_id=upload_to_hub,
436
  repo_type="dataset",
437
  token=access_token,
438
  )
439
- chat.buffer=[{"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."}]
440
- text_str=""
441
  audio_output = None
442
  audio_output1 = None
443
  asr_output_str = ""
444
  start_record_time = None
445
- LLM_response_arr=[]
446
- total_response_arr=[]
447
- shutil.rmtree('flagged_data_points')
448
  os.mkdir("flagged_data_points")
449
- yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
450
- yield stream,gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Audio(visible=False)
451
-
452
- yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
453
 
 
454
455
  with gr.Blocks(
456
- title="E2E Spoken Dialog System",
457
- ) as demo:
458
- with gr.Row():
459
- with gr.Column(scale=1):
460
- user_audio = gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
461
- with gr.Row():
462
- type_radio = gr.Radio(
463
- choices=["Cascaded", "E2E"],
464
- label="Choose type of Spoken Dialog:",
465
- value="Cascaded",
466
- )
467
- with gr.Row():
468
- ASR_radio = gr.Radio(
469
- choices=ASR_options,
470
- label="Choose ASR:",
471
- value=ASR_name,
472
- )
473
- with gr.Row():
474
- LLM_radio = gr.Radio(
475
- choices=LLM_options,
476
- label="Choose LLM:",
477
- value=LLM_name,
478
- )
479
- with gr.Row():
480
- radio = gr.Radio(
481
- choices=TTS_options,
482
- label="Choose TTS:",
483
- value=TTS_name,
484
- )
485
- with gr.Row():
486
- E2Eradio = gr.Radio(
487
- choices=["mini-omni"],
488
- label="Choose E2E model:",
489
- value="mini-omni",
490
- visible=False,
491
- )
492
- with gr.Row():
493
- feedback_btn = gr.Button(
494
- value="Please provide your feedback after each system response below.", visible=True, interactive=False, elem_id="button"
495
- )
496
- with gr.Row():
497
- natural_btn1 = gr.Button(
498
- value="Very Natural", visible=False, interactive=False, scale=1
499
- )
500
- natural_btn2 = gr.Button(
501
- value="Somewhat Awkward", visible=False, interactive=False, scale=1
502
- )
503
- natural_btn3 = gr.Button(value="Very Awkward", visible=False, interactive=False, scale=1)
504
- natural_btn4 = gr.Button(
505
- value="Unnatural", visible=False, interactive=False, scale=1
506
- )
507
- with gr.Row():
508
- relevant_btn1 = gr.Button(
509
- value="Highly Relevant", visible=False, interactive=False, scale=1
510
- )
511
- relevant_btn2 = gr.Button(
512
- value="Partially Relevant", visible=False, interactive=False, scale=1
513
- )
514
- relevant_btn3 = gr.Button(value="Slightly Irrelevant", visible=False, interactive=False, scale=1)
515
- relevant_btn4 = gr.Button(
516
- value= "Completely Irrelevant", visible=False, interactive=False, scale=1
517
- )
518
- with gr.Column(scale=1):
519
- output_audio = gr.Audio(label="Output", interactive=False, autoplay=True, visible=True)
520
- output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
521
- output_asr_text = gr.Textbox(label="ASR output", interactive=False)
522
- output_text = gr.Textbox(label="LLM output", interactive=False)
523
- eval_radio = gr.Radio(
524
- choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
525
- label="Choose Evaluation metrics:",
526
  )
527
- eval_radio_E2E = gr.Radio(
528
- choices=["Latency", "TTS Intelligibility", "TTS Speech Quality","Text Dialog Metrics"],
529
- label="Choose Evaluation metrics:",
530
  visible=False,
531
  )
532
- output_eval_text = gr.Textbox(label="Evaluation Results")
533
- state = gr.State()
534
- with gr.Row():
535
- privacy_text = gr.Textbox(label="Privacy Notice",interactive=False, value="By using this demo, you acknowledge that interactions with this dialog system are collected for research and improvement purposes. The data will only be used to enhance the performance and understanding of the system. If you have any concerns about data collection, please discontinue use.")
536
-
537
- btn_list=[
538
- natural_btn1,
539
- natural_btn2,
540
- natural_btn3,
541
- natural_btn4,
542
- relevant_btn1,
543
- relevant_btn2,
544
- relevant_btn3,
545
- relevant_btn4,
546
- ]
547
- natural_btn_list=[
548
- natural_btn1,
549
- natural_btn2,
550
- natural_btn3,
551
- natural_btn4,
552
- ]
553
- relevant_btn_list=[
554
- relevant_btn1,
555
- relevant_btn2,
556
- relevant_btn3,
557
- relevant_btn4,
558
- ]
559
- natural_response = gr.Textbox(label="natural_response",visible=False,interactive=False)
560
- diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
561
- ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
562
- callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address],"flagged_data_points")
563
- user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
564
- radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
565
- LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
566
- ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
567
- eval_radio.change(fn=handle_eval_selection, inputs=[eval_radio,output_audio,output_text,output_audio1,output_asr_text], outputs=[eval_radio,output_eval_text])
568
- eval_radio_E2E.change(fn=handle_eval_selection_E2E, inputs=[eval_radio_E2E,output_audio,output_text], outputs=[eval_radio_E2E,output_eval_text])
569
- type_radio.change(fn=handle_type_selection,inputs=[type_radio,radio,ASR_radio,LLM_radio], outputs=[radio,ASR_radio,LLM_radio, E2Eradio,output_asr_text, output_text, output_audio,eval_radio,eval_radio_E2E])
570
- output_audio.play(
571
- flash_buttons, [], [natural_response,diversity_response]+btn_list
572
- ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio], None,preprocess=False)
573
- natural_btn1.click(natural_vote1_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
574
- natural_btn2.click(natural_vote2_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
575
- natural_btn3.click(natural_vote3_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
576
- natural_btn4.click(natural_vote4_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
577
- relevant_btn1.click(relevant_vote1_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
578
- relevant_btn2.click(relevant_vote2_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
579
- relevant_btn3.click(relevant_vote3_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
580
- relevant_btn4.click(relevant_vote4_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,type_radio, ASR_radio, LLM_radio, radio, E2Eradio, natural_response,diversity_response,ip_address], None,preprocess=False)
581
-
582
  demo.launch(share=True)
 
 
5
  with open('versa.sh', 'rb') as file:
6
  script = file.read()
7
  rc = call(script, shell=True)
8
+
9
  import os
10
  import shutil
11
+ import time
12
+ from typing import Generator, Optional, Tuple
13
+
14
+ import gradio as gr
15
+ import nltk
16
+ import numpy as np
17
  import torch
18
+ from huggingface_hub import HfApi
19
+ from pyscripts.utils.dialog_eval.ASR_WER import handle_espnet_ASR_WER
20
+ from pyscripts.utils.dialog_eval.human_feedback import (
21
+ natural_vote1_last_response,
22
+ natural_vote2_last_response,
23
+ natural_vote3_last_response,
24
+ natural_vote4_last_response,
25
+ relevant_vote1_last_response,
26
+ relevant_vote2_last_response,
27
+ relevant_vote3_last_response,
28
+ relevant_vote4_last_response,
29
+ )
30
+ from pyscripts.utils.dialog_eval.LLM_Metrics import (
31
+ DialoGPT_perplexity,
32
+ bert_score,
33
+ perplexity,
34
+ vert,
35
+ )
36
+ from pyscripts.utils.dialog_eval.TTS_intelligibility import (
37
+ handle_espnet_TTS_intelligibility,
38
+ )
39
+ from pyscripts.utils.dialog_eval.TTS_speech_quality import TTS_psuedomos
40
+
41
+ from espnet2.sds.espnet_model import ESPnetSDSModelInterface
42
+
43
+ # ------------------------
44
+ # Hyperparameters
45
+ # ------------------------
46
 
47
  access_token = os.environ.get("HF_TOKEN")
48
  ASR_name="pyf98/owsm_ctc_v3.1_1B"
49
  LLM_name="meta-llama/Llama-3.2-1B-Instruct"
50
  TTS_name="kan-bayashi/ljspeech_vits"
51
+ ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper-large".split(",")
52
  LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
53
  TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
54
  Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
55
  upload_to_hub=None
56
+ dialogue_model = ESPnetSDSModelInterface(
57
+ ASR_name, LLM_name, TTS_name, "Cascaded", access_token
58
+ )
59
  ASR_curr_name=None
60
  LLM_curr_name=None
61
  TTS_curr_name=None
62
 
63
+ latency_ASR = 0.0
64
+ latency_LM = 0.0
65
+ latency_TTS = 0.0
66
 
67
+ text_str = ""
68
+ asr_output_str = ""
69
+ vad_output = None
70
  audio_output = None
71
  audio_output1 = None
72
+ LLM_response_arr = []
73
+ total_response_arr = []
74
+ callback = gr.CSVLogger()
75
+ start_record_time = None
76
+ enable_btn = gr.Button(interactive=True, visible=True)
77
 
78
+ # ------------------------
79
+ # Function Definitions
80
+ # ------------------------
81
 
82
+ def handle_eval_selection(
83
+ option: str,
84
+ TTS_audio_output: str,
85
+ LLM_Output: str,
86
+ ASR_audio_output: str,
87
+ ASR_transcript: str,
88
+ ):
89
+ """
90
+ Handles the evaluation of a selected metric based on
91
+ user input and provided outputs.
92
+
93
+ This function evaluates different aspects of a
94
+ cascaded conversational AI pipeline, such as:
95
+ Latency, TTS intelligibility, TTS speech quality,
96
+ ASR WER, and text dialog metrics.
97
+ It is designed to integrate with Gradio via
98
+ multiple yield statements,
99
+ allowing updates to be displayed in real time.
100
+
101
+ Parameters:
102
+ ----------
103
+ option : str
104
+ The evaluation metric selected by the user.
105
+ Supported options include:
106
+ - "Latency"
107
+ - "TTS Intelligibility"
108
+ - "TTS Speech Quality"
109
+ - "ASR WER"
110
+ - "Text Dialog Metrics"
111
+ TTS_audio_output : np.ndarray
112
+ The audio output generated by the TTS module for evaluation.
113
+ LLM_Output : str
114
+ The text output generated by the LLM module for evaluation.
115
+ ASR_audio_output : np.ndarray
116
+ The audio input/output used for ASR evaluation.
117
+ ASR_transcript : str
118
+ The transcript generated by the ASR module for evaluation.
119
+
120
+ Returns:
121
+ -------
122
+ str
123
+ A string representation of the evaluation results.
124
+ The specific result depends on the selected evaluation metric:
125
+ - "Latency": Latencies of ASR, LLM, and TTS modules.
126
+ - "TTS Intelligibility": A range of scores indicating how intelligible
127
+ the TTS audio output is based on different reference ASR models.
128
+ - "TTS Speech Quality": A range of scores representing the
129
+ speech quality of the TTS audio output.
130
+ - "ASR WER": The Word Error Rate (WER) of the ASR output
131
+ based on different judge ASR models.
132
+ - "Text Dialog Metrics": A combination of perplexity,
133
+ diversity metrics, and relevance scores for the dialog.
134
+
135
+ Raises:
136
+ ------
137
+ ValueError
138
+ If the `option` parameter does not match any supported evaluation metric.
139
+
140
+ Example:
141
+ -------
142
+ >>> result = handle_eval_selection(
143
+ option="Latency",
144
+ TTS_audio_output=audio_array,
145
+ LLM_Output="Generated response",
146
+ ASR_audio_output=audio_input,
147
+ ASR_transcript="Expected transcript"
148
+ )
149
+ >>> print(result)
150
+ "ASR Latency: 0.14
151
+ LLM Latency: 0.42
152
+ TTS Latency: 0.21"
153
+ """
154
  global LLM_response_arr
155
  global total_response_arr
156
+ yield (option, gr.Textbox(visible=True))
157
+ if option == "Latency":
158
+ text = (
159
+ f"ASR Latency: {latency_ASR:.2f}\n"
160
+ f"LLM Latency: {latency_LM:.2f}\n"
161
+ f"TTS Latency: {latency_TTS:.2f}"
162
+ )
163
+ yield (None, text)
164
+ elif option == "TTS Intelligibility":
165
+ yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
166
+ elif option == "TTS Speech Quality":
167
+ yield (None, TTS_psuedomos(TTS_audio_output))
168
+ elif option == "ASR WER":
169
+ yield (None, handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
170
+ elif option == "Text Dialog Metrics":
171
+ yield (
172
+ None,
173
+ perplexity(LLM_Output.replace("\n", " "))
174
+ + vert(LLM_response_arr)
175
+ + bert_score(total_response_arr)
176
+ + DialoGPT_perplexity(
177
+ ASR_transcript.replace("\n", " "), LLM_Output.replace("\n", " ")
178
+ ),
179
+ )
180
+ elif option is None:
181
+ return
182
+ else:
183
+ raise ValueError(f"Unknown option: {option}")
184
+
185
+
186
+ def handle_eval_selection_E2E(
187
+ option: str,
188
+ TTS_audio_output: str,
189
+ LLM_Output: str,
190
+ ):
191
+ """
192
+ Handles the evaluation of a selected metric based on user input
193
+ and provided outputs.
194
+
195
+ This function evaluates different aspects of an E2E
196
+ conversational AI model, such as:
197
+ Latency, TTS intelligibility, TTS speech quality, and
198
+ text dialog metrics.
199
+ It is designed to integrate with Gradio via
200
+ multiple yield statements,
201
+ allowing updates to be displayed in real time.
202
+
203
+ Parameters:
204
+ ----------
205
+ option : str
206
+ The evaluation metric selected by the user.
207
+ Supported options include:
208
+ - "Latency"
209
+ - "TTS Intelligibility"
210
+ - "TTS Speech Quality"
211
+ - "Text Dialog Metrics"
212
+ TTS_audio_output : np.ndarray
213
+ The audio output generated by the TTS module for evaluation.
214
+ LLM_Output : str
215
+ The text output generated by the LLM module for evaluation.
216
+
217
+ Returns:
218
+ -------
219
+ str
220
+ A string representation of the evaluation results.
221
+ The specific result depends on the selected evaluation metric:
222
+ - "Latency": Latency of the entire system.
223
+ - "TTS Intelligibility": A range of scores indicating how intelligible the
224
+ TTS audio output is based on different reference ASR models.
225
+ - "TTS Speech Quality": A range of scores representing the
226
+ speech quality of the TTS audio output.
227
+ - "Text Dialog Metrics": A combination of perplexity and
228
+ diversity metrics for the dialog.
229
+
230
+ Raises:
231
+ ------
232
+ ValueError
233
+ If the `option` parameter does not match any supported evaluation metric.
234
+
235
+ Example:
236
+ -------
237
+ >>> result = handle_eval_selection_E2E(
238
+ option="Latency",
239
+ TTS_audio_output=audio_array,
240
+ LLM_Output="Generated response",
241
+ )
242
+ >>> print(result)
243
+ "Total Latency: 2.34"
244
+ """
245
  global LLM_response_arr
246
  global total_response_arr
247
+ yield (option, gr.Textbox(visible=True))
248
+ if option == "Latency":
249
+ text = f"Total Latency: {latency_TTS:.2f}"
250
+ yield (None, text)
251
+ elif option == "TTS Intelligibility":
252
+ yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
253
+ elif option == "TTS Speech Quality":
254
+ yield (None, TTS_psuedomos(TTS_audio_output))
255
+ elif option == "Text Dialog Metrics":
256
+ yield (None, perplexity(LLM_Output.replace("\n", " ")) + vert(LLM_response_arr))
257
+ elif option is None:
258
+ return
259
  else:
260
+ raise ValueError(f"Unknown option: {option}")
261
+
262
 
263
  def start_warmup():
264
+ """
265
+ Initializes and warms up the dialogue and evaluation model.
266
+
267
+ This function is designed to ensure that all
268
+ components of the dialogue model are pre-loaded
269
+ and ready for execution, avoiding delays during runtime.
270
+ """
271
+ global dialogue_model
272
+ global ASR_options
273
+ global LLM_options
274
+ global TTS_options
275
+ global ASR_name
276
+ global LLM_name
277
+ global TTS_name
278
+ for opt_count in range(len(ASR_options)):
279
+ opt = ASR_options[opt_count]
280
+ try:
281
+ for _ in dialogue_model.handle_ASR_selection(opt):
282
+ continue
283
+ except Exception:
284
+ print("Removing " + opt + " from ASR options since it cannot be loaded.")
285
+ ASR_options = ASR_options[:opt_count] + ASR_options[(opt_count + 1) :]
286
+ if opt == ASR_name:
287
+ ASR_name = ASR_options[0]
288
+ for opt_count in range(len(LLM_options)):
289
+ opt = LLM_options[opt_count]
290
+ try:
291
+ for _ in dialogue_model.handle_LLM_selection(opt):
292
+ continue
293
+ except Exception:
294
+ print("Removing " + opt + " from LLM options since it cannot be loaded.")
295
+ LLM_options = LLM_options[:opt_count] + LLM_options[(opt_count + 1) :]
296
+ if opt == LLM_name:
297
+ LLM_name = LLM_options[0]
298
+ for opt_count in range(len(TTS_options)):
299
+ opt = TTS_options[opt_count]
300
+ try:
301
+ for _ in dialogue_model.handle_TTS_selection(opt):
302
+ continue
303
+ except Exception:
304
+ print("Removing " + opt + " from TTS options since it cannot be loaded.")
305
+ TTS_options = TTS_options[:opt_count] + TTS_options[(opt_count + 1) :]
306
+ if opt == TTS_name:
307
+ TTS_name = TTS_options[0]
308
+ dialogue_model.handle_E2E_selection()
309
+ dialogue_model.client = None
310
+ for _ in dialogue_model.handle_TTS_selection(TTS_name):
311
  continue
312
+ for _ in dialogue_model.handle_ASR_selection(ASR_name):
313
  continue
314
+ for _ in dialogue_model.handle_LLM_selection(LLM_name):
315
  continue
316
+ dummy_input = (
317
+ torch.randn(
318
  (3000),
319
  dtype=getattr(torch, "float16"),
320
  device="cpu",
321
+ )
322
+ .cpu()
323
+ .numpy()
324
+ )
325
+ dummy_text = "This is dummy text"
326
  for opt in Eval_options:
327
  handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
328
 
 
 
329
 
 
 
 
 
330
  def flash_buttons():
331
+ """
332
+ Enables human feedback buttons after displaying system output.
333
+ """
334
  btn_updates = (enable_btn,) * 8
335
+ yield (
336
+ "",
337
+ "",
338
+ ) + btn_updates
339
+
340
+
341
+ def transcribe(
342
+ stream: np.ndarray,
343
+ new_chunk: Tuple[int, np.ndarray],
344
+ TTS_option: str,
345
+ ASR_option: str,
346
+ LLM_option: str,
347
+ type_option: str,
348
  ):
349
+ """
350
+ Processes and transcribes an audio stream in real-time.
351
+
352
+ This function handles the transcription of audio input
353
+ and its transformation through a cascaded
354
+ or E2E conversational AI system.
355
+ It dynamically updates the transcription, text generation,
356
+ and synthesized speech output, while managing global states and latencies.
357
+
358
+ Args:
359
+ stream: The current audio stream buffer.
360
+ `None` if the stream is being reset (e.g., after user refresh).
361
+ new_chunk: A tuple containing:
362
+ - `sr`: Sample rate of the new audio chunk.
363
+ - `y`: New audio data chunk.
364
+ TTS_option: Selected TTS model option.
365
+ ASR_option: Selected ASR model option.
366
+ LLM_option: Selected LLM model option.
367
+ type_option: Type of system ("Cascaded" or "E2E").
368
+
369
+ Yields:
370
+ Tuple[Optional[np.ndarray], Optional[str], Optional[str],
371
+ Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
372
+ A tuple containing:
373
+ - Updated stream buffer.
374
+ - ASR output text.
375
+ - Generated LLM output text.
376
+ - Audio output as a tuple of sample rate and audio waveform.
377
+ - User input audio as a tuple of sample rate and audio waveform.
378
+
379
+ Notes:
380
+ - Resets the session if the transcription exceeds 5 minutes.
381
+ - Updates the Gradio interface elements dynamically.
382
+ - Manages latencies.
383
+ """
384
  sr, y = new_chunk
385
  global text_str
386
  global chat
 
399
  global total_response_arr
400
  if stream is None:
401
  # Handle user refresh
402
+ for (
403
+ _,
404
+ _,
405
+ _,
406
+ _,
407
+ asr_output_box,
408
+ text_box,
409
+ audio_box,
410
+ _,
411
+ _,
412
+ ) in dialogue_model.handle_type_selection(
413
+ type_option, TTS_option, ASR_option, LLM_option
414
+ ):
415
  gr.Info("The models are being reloaded due to a browser refresh.")
416
+ yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
417
+ stream = y
418
+ text_str = ""
 
419
  audio_output = None
420
  audio_output1 = None
421
  else:
422
+ stream = np.concatenate((stream, y))
423
+ (
424
+ asr_output_str,
425
+ text_str,
426
+ audio_output,
427
+ audio_output1,
428
+ latency_ASR,
429
+ latency_LM,
430
+ latency_TTS,
431
+ stream,
432
+ change,
433
+ ) = dialogue_model(
434
+ y,
435
+ sr,
436
+ stream,
437
+ asr_output_str,
438
+ text_str,
439
+ audio_output,
440
+ audio_output1,
441
+ latency_ASR,
442
+ latency_LM,
443
+ latency_TTS,
444
+ )
445
+ text_str1 = text_str
446
+ if change:
447
+ print("Output changed")
448
+ if asr_output_str != "":
449
+ total_response_arr.append(asr_output_str.replace("\n", " "))
450
+ LLM_response_arr.append(text_str.replace("\n", " "))
451
+ total_response_arr.append(text_str.replace("\n", " "))
452
+ if (text_str != "") and (start_record_time is None):
453
+ start_record_time = time.time()
454
  elif start_record_time is not None:
455
+ current_record_time = time.time()
456
+ if current_record_time - start_record_time > 300:
457
+ gr.Info(
458
+ "Conversations are limited to 5 minutes. "
459
+ "The session will restart in approximately 60 seconds. "
460
+ "Please wait for the demo to reset. "
461
+ "Close this message once you have read it.",
462
+ duration=None,
463
+ )
464
+ yield stream, gr.Textbox(visible=False), gr.Textbox(
465
+ visible=False
466
+ ), gr.Audio(visible=False), gr.Audio(visible=False)
467
  if upload_to_hub is not None:
468
  api.upload_folder(
469
  folder_path="flagged_data_points",
470
+ path_in_repo="checkpoint_" + str(start_record_time),
471
  repo_id=upload_to_hub,
472
  repo_type="dataset",
473
  token=access_token,
474
  )
475
+ dialogue_model.chat.buffer = []
476
+ text_str = ""
477
  audio_output = None
478
  audio_output1 = None
479
  asr_output_str = ""
480
  start_record_time = None
481
+ LLM_response_arr = []
482
+ total_response_arr = []
483
+ shutil.rmtree("flagged_data_points")
484
  os.mkdir("flagged_data_points")
485
+ yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
486
+ yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
487
+ visible=True
488
+ ), gr.Audio(visible=False)
489
 
490
+ yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
491
 
492
+
493
+ # ------------------------
494
+ # Executable Script
495
+ # ------------------------
496
+ api = HfApi()
497
+ nltk.download("averaged_perceptron_tagger_eng")
498
+ start_warmup()
499
  with gr.Blocks(
500
+ title="E2E Spoken Dialog System",
501
+ ) as demo:
502
+ with gr.Row():
503
+ gr.Markdown(
504
+ """
505
+ ## ESPnet-SDS
506
+ Welcome to our unified web interface for various cascaded and
507
+ E2E spoken dialogue systems built using the ESPnet-SDS toolkit,
508
+ supporting real-time automated evaluation metrics, and
509
+ human-in-the-loop feedback collection.
510
+
511
+ For more details on how to use the app, refer to the
512
+ [README](https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
513
+ """
514
+ )
515
+ with gr.Row():
516
+ with gr.Column(scale=1):
517
+ user_audio = gr.Audio(
518
+ sources=["microphone"],
519
+ streaming=True,
520
+ waveform_options=gr.WaveformOptions(sample_rate=16000),
521
+ )
522
+ with gr.Row():
523
+ type_radio = gr.Radio(
524
+ choices=["Cascaded", "E2E"],
525
+ label="Choose type of Spoken Dialog:",
526
+ value="Cascaded",
527
+ )
528
+ with gr.Row():
529
+ ASR_radio = gr.Radio(
530
+ choices=ASR_options,
531
+ label="Choose ASR:",
532
+ value=ASR_name,
533
+ )
534
+ with gr.Row():
535
+ LLM_radio = gr.Radio(
536
+ choices=LLM_options,
537
+ label="Choose LLM:",
538
+ value=LLM_name,
539
  )
540
+ with gr.Row():
541
+ radio = gr.Radio(
542
+ choices=TTS_options,
543
+ label="Choose TTS:",
544
+ value=TTS_name,
545
+ )
546
+ with gr.Row():
547
+ E2Eradio = gr.Radio(
548
+ choices=["mini-omni"],
549
+ label="Choose E2E model:",
550
+ value="mini-omni",
551
  visible=False,
552
  )
553
+ with gr.Row():
554
+ feedback_btn = gr.Button(
555
+ value=(
556
+ "Please provide your feedback "
557
+ "after each system response below."
558
+ ),
559
+ visible=True,
560
+ interactive=False,
561
+ elem_id="button",
562
+ )
563
+ with gr.Row():
564
+ natural_btn1 = gr.Button(
565
+ value="Very Natural", visible=False, interactive=False, scale=1
566
+ )
567
+ natural_btn2 = gr.Button(
568
+ value="Somewhat Awkward", visible=False, interactive=False, scale=1
569
+ )
570
+ natural_btn3 = gr.Button(
571
+ value="Very Awkward", visible=False, interactive=False, scale=1
572
+ )
573
+ natural_btn4 = gr.Button(
574
+ value="Unnatural", visible=False, interactive=False, scale=1
575
+ )
576
+ with gr.Row():
577
+ relevant_btn1 = gr.Button(
578
+ value="Highly Relevant", visible=False, interactive=False, scale=1
579
+ )
580
+ relevant_btn2 = gr.Button(
581
+ value="Partially Relevant",
582
+ visible=False,
583
+ interactive=False,
584
+ scale=1,
585
+ )
586
+ relevant_btn3 = gr.Button(
587
+ value="Slightly Irrelevant",
588
+ visible=False,
589
+ interactive=False,
590
+ scale=1,
591
+ )
592
+ relevant_btn4 = gr.Button(
593
+ value="Completely Irrelevant",
594
+ visible=False,
595
+ interactive=False,
596
+ scale=1,
597
+ )
598
+ with gr.Column(scale=1):
599
+ output_audio = gr.Audio(label="Output", autoplay=True, visible=True)
600
+ output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
601
+ output_asr_text = gr.Textbox(label="ASR output")
602
+ output_text = gr.Textbox(label="LLM output")
603
+ eval_radio = gr.Radio(
604
+ choices=[
605
+ "Latency",
606
+ "TTS Intelligibility",
607
+ "TTS Speech Quality",
608
+ "ASR WER",
609
+ "Text Dialog Metrics",
610
+ ],
611
+ label="Choose Evaluation metrics:",
612
+ )
613
+ eval_radio_E2E = gr.Radio(
614
+ choices=[
615
+ "Latency",
616
+ "TTS Intelligibility",
617
+ "TTS Speech Quality",
618
+ "Text Dialog Metrics",
619
+ ],
620
+ label="Choose Evaluation metrics:",
621
+ visible=False,
622
+ )
623
+ output_eval_text = gr.Textbox(label="Evaluation Results")
624
+ state = gr.State()
625
+ with gr.Row():
626
+ privacy_text = gr.Textbox(
627
+ label="Privacy Notice",
628
+ interactive=False,
629
+ value=(
630
+ "By using this demo, you acknowledge that "
631
+ "interactions with this dialog system are collected "
632
+ "for research and improvement purposes. The data "
633
+ "will only be used to enhance the performance and "
634
+ "understanding of the system. If you have any "
635
+ "concerns about data collection, please discontinue "
636
+ "use."
637
+ ),
638
+ )
639
+
640
+ btn_list = [
641
+ natural_btn1,
642
+ natural_btn2,
643
+ natural_btn3,
644
+ natural_btn4,
645
+ relevant_btn1,
646
+ relevant_btn2,
647
+ relevant_btn3,
648
+ relevant_btn4,
649
+ ]
650
+ natural_btn_list = [
651
+ natural_btn1,
652
+ natural_btn2,
653
+ natural_btn3,
654
+ natural_btn4,
655
+ ]
656
+ relevant_btn_list = [
657
+ relevant_btn1,
658
+ relevant_btn2,
659
+ relevant_btn3,
660
+ relevant_btn4,
661
+ ]
662
+ natural_response = gr.Textbox(
663
+ label="natural_response", visible=False, interactive=False
664
+ )
665
+ diversity_response = gr.Textbox(
666
+ label="diversity_response", visible=False, interactive=False
667
+ )
668
+ ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
669
+ callback.setup(
670
+ [
671
+ user_audio,
672
+ output_asr_text,
673
+ output_text,
674
+ output_audio,
675
+ output_audio1,
676
+ type_radio,
677
+ ASR_radio,
678
+ LLM_radio,
679
+ radio,
680
+ E2Eradio,
681
+ natural_response,
682
+ diversity_response,
683
+ ip_address,
684
+ ],
685
+ "flagged_data_points",
686
+ )
687
+ user_audio.stream(
688
+ transcribe,
689
+ inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio],
690
+ outputs=[state, output_asr_text, output_text, output_audio, output_audio1],
691
+ ).then(
692
+ lambda *args: callback.flag(list(args)), [user_audio], None, preprocess=False
693
+ )
694
+ radio.change(
695
+ fn=dialogue_model.handle_TTS_selection,
696
+ inputs=[radio],
697
+ outputs=[output_asr_text, output_text, output_audio],
698
+ )
699
+ LLM_radio.change(
700
+ fn=dialogue_model.handle_LLM_selection,
701
+ inputs=[LLM_radio],
702
+ outputs=[output_asr_text, output_text, output_audio],
703
+ )
704
+ ASR_radio.change(
705
+ fn=dialogue_model.handle_ASR_selection,
706
+ inputs=[ASR_radio],
707
+ outputs=[output_asr_text, output_text, output_audio],
708
+ )
709
+ eval_radio.change(
710
+ fn=handle_eval_selection,
711
+ inputs=[eval_radio, output_audio, output_text, output_audio1, output_asr_text],
712
+ outputs=[eval_radio, output_eval_text],
713
+ )
714
+ eval_radio_E2E.change(
715
+ fn=handle_eval_selection_E2E,
716
+ inputs=[eval_radio_E2E, output_audio, output_text],
717
+ outputs=[eval_radio_E2E, output_eval_text],
718
+ )
719
+ type_radio.change(
720
+ fn=dialogue_model.handle_type_selection,
721
+ inputs=[type_radio, radio, ASR_radio, LLM_radio],
722
+ outputs=[
723
+ radio,
724
+ ASR_radio,
725
+ LLM_radio,
726
+ E2Eradio,
727
+ output_asr_text,
728
+ output_text,
729
+ output_audio,
730
+ eval_radio,
731
+ eval_radio_E2E,
732
+ ],
733
+ )
734
+ output_audio.play(
735
+ flash_buttons, [], [natural_response, diversity_response] + btn_list
736
+ ).then(
737
+ lambda *args: callback.flag(list(args)),
738
+ [
739
+ user_audio,
740
+ output_asr_text,
741
+ output_text,
742
+ output_audio,
743
+ output_audio1,
744
+ type_radio,
745
+ ASR_radio,
746
+ LLM_radio,
747
+ radio,
748
+ E2Eradio,
749
+ ],
750
+ None,
751
+ preprocess=False,
752
+ )
753
+ natural_btn1.click(
754
+ natural_vote1_last_response,
755
+ [],
756
+ [natural_response, ip_address] + natural_btn_list,
757
+ ).then(
758
+ lambda *args: callback.flag(list(args)),
759
+ [
760
+ user_audio,
761
+ output_asr_text,
762
+ output_text,
763
+ output_audio,
764
+ output_audio1,
765
+ type_radio,
766
+ ASR_radio,
767
+ LLM_radio,
768
+ radio,
769
+ E2Eradio,
770
+ natural_response,
771
+ diversity_response,
772
+ ip_address,
773
+ ],
774
+ None,
775
+ preprocess=False,
776
+ )
777
+ natural_btn2.click(
778
+ natural_vote2_last_response,
779
+ [],
780
+ [natural_response, ip_address] + natural_btn_list,
781
+ ).then(
782
+ lambda *args: callback.flag(list(args)),
783
+ [
784
+ user_audio,
785
+ output_asr_text,
786
+ output_text,
787
+ output_audio,
788
+ output_audio1,
789
+ type_radio,
790
+ ASR_radio,
791
+ LLM_radio,
792
+ radio,
793
+ E2Eradio,
794
+ natural_response,
795
+ diversity_response,
796
+ ip_address,
797
+ ],
798
+ None,
799
+ preprocess=False,
800
+ )
801
+ natural_btn3.click(
802
+ natural_vote3_last_response,
803
+ [],
804
+ [natural_response, ip_address] + natural_btn_list,
805
+ ).then(
806
+ lambda *args: callback.flag(list(args)),
807
+ [
808
+ user_audio,
809
+ output_asr_text,
810
+ output_text,
811
+ output_audio,
812
+ output_audio1,
813
+ type_radio,
814
+ ASR_radio,
815
+ LLM_radio,
816
+ radio,
817
+ E2Eradio,
818
+ natural_response,
819
+ diversity_response,
820
+ ip_address,
821
+ ],
822
+ None,
823
+ preprocess=False,
824
+ )
825
+ natural_btn4.click(
826
+ natural_vote4_last_response,
827
+ [],
828
+ [natural_response, ip_address] + natural_btn_list,
829
+ ).then(
830
+ lambda *args: callback.flag(list(args)),
831
+ [
832
+ user_audio,
833
+ output_asr_text,
834
+ output_text,
835
+ output_audio,
836
+ output_audio1,
837
+ type_radio,
838
+ ASR_radio,
839
+ LLM_radio,
840
+ radio,
841
+ E2Eradio,
842
+ natural_response,
843
+ diversity_response,
844
+ ip_address,
845
+ ],
846
+ None,
847
+ preprocess=False,
848
+ )
849
+ relevant_btn1.click(
850
+ relevant_vote1_last_response,
851
+ [],
852
+ [diversity_response, ip_address] + relevant_btn_list,
853
+ ).then(
854
+ lambda *args: callback.flag(list(args)),
855
+ [
856
+ user_audio,
857
+ output_asr_text,
858
+ output_text,
859
+ output_audio,
860
+ output_audio1,
861
+ type_radio,
862
+ ASR_radio,
863
+ LLM_radio,
864
+ radio,
865
+ E2Eradio,
866
+ natural_response,
867
+ diversity_response,
868
+ ip_address,
869
+ ],
870
+ None,
871
+ preprocess=False,
872
+ )
873
+ relevant_btn2.click(
874
+ relevant_vote2_last_response,
875
+ [],
876
+ [diversity_response, ip_address] + relevant_btn_list,
877
+ ).then(
878
+ lambda *args: callback.flag(list(args)),
879
+ [
880
+ user_audio,
881
+ output_asr_text,
882
+ output_text,
883
+ output_audio,
884
+ output_audio1,
885
+ type_radio,
886
+ ASR_radio,
887
+ LLM_radio,
888
+ radio,
889
+ E2Eradio,
890
+ natural_response,
891
+ diversity_response,
892
+ ip_address,
893
+ ],
894
+ None,
895
+ preprocess=False,
896
+ )
897
+ relevant_btn3.click(
898
+ relevant_vote3_last_response,
899
+ [],
900
+ [diversity_response, ip_address] + relevant_btn_list,
901
+ ).then(
902
+ lambda *args: callback.flag(list(args)),
903
+ [
904
+ user_audio,
905
+ output_asr_text,
906
+ output_text,
907
+ output_audio,
908
+ output_audio1,
909
+ type_radio,
910
+ ASR_radio,
911
+ LLM_radio,
912
+ radio,
913
+ E2Eradio,
914
+ natural_response,
915
+ diversity_response,
916
+ ip_address,
917
+ ],
918
+ None,
919
+ preprocess=False,
920
+ )
921
+ relevant_btn4.click(
922
+ relevant_vote4_last_response,
923
+ [],
924
+ [diversity_response, ip_address] + relevant_btn_list,
925
+ ).then(
926
+ lambda *args: callback.flag(list(args)),
927
+ [
928
+ user_audio,
929
+ output_asr_text,
930
+ output_text,
931
+ output_audio,
932
+ output_audio1,
933
+ type_radio,
934
+ ASR_radio,
935
+ LLM_radio,
936
+ radio,
937
+ E2Eradio,
938
+ natural_response,
939
+ diversity_response,
940
+ ip_address,
941
+ ],
942
+ None,
943
+ preprocess=False,
944
+ )
945
  demo.launch(share=True)
946
+
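As a usage note (illustration only, not part of this diff): the "Text Dialog Metrics" branch of `handle_eval_selection` simply concatenates the report strings returned by the helpers imported from `pyscripts.utils.dialog_eval.LLM_Metrics`. A minimal sketch, assuming that package from this repo is importable and using made-up dialog strings:

```python
# Sketch only: mirrors the "Text Dialog Metrics" branch of handle_eval_selection.
# The dialog strings and history lists below are hypothetical placeholders.
from pyscripts.utils.dialog_eval.LLM_Metrics import (
    DialoGPT_perplexity,
    bert_score,
    perplexity,
    vert,
)

asr_transcript = "when is the meeting"             # hypothetical ASR output
llm_output = "The meeting is at three o'clock."    # hypothetical LLM response
LLM_response_arr = [llm_output]                    # all LLM responses so far
total_response_arr = [asr_transcript, llm_output]  # full dialog history

# Each helper returns a formatted report string, so "+" concatenates the reports.
report = (
    perplexity(llm_output.replace("\n", " "))
    + vert(LLM_response_arr)
    + bert_score(total_response_arr)
    + DialoGPT_perplexity(
        asr_transcript.replace("\n", " "), llm_output.replace("\n", " ")
    )
)
print(report)
```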
pyscripts/utils/dialog_eval/ASR_WER.py ADDED
@@ -0,0 +1,165 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ from espnet2.sds.utils.utils import int2float
6
+
7
+
8
+ def handle_espnet_ASR_WER(
9
+ ASR_audio_output: Tuple[int, np.ndarray], ASR_transcript: str
10
+ ) -> str:
11
+ """
12
+ Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
13
+ for multiple judge ASR systems (ESPnet, OWSM, Whisper) using the Versa library.
14
+
15
+ This function performs the following:
16
+ 1. Imports necessary metrics and setup functions from Versa.
17
+ 2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
18
+ 3. Runs the Levenshtein-based WER/CER calculations.
19
+ 4. Returns a formatted string summarizing WER and CER
20
+ results against the reference produced by each judge ASR system.
21
+
22
+ Args:
23
+ ASR_audio_output (tuple):
24
+ A tuple where:
25
+ - The first element is the frame rate.
26
+ - The second element is the audio signal (NumPy array).
27
+ ASR_transcript (str):
28
+ The transcript produced by the ASR model in the cascaded
29
+ conversational AI pipeline.
30
+
31
+ Returns:
32
+ str:
33
+ A formatted string showing the WER and CER percentages
34
+ for ESPnet, OWSM, and Whisper. Example output:
35
+
36
+ "ESPnet WER: 10.50
37
+ ESPnet CER: 7.20
38
+ OWSM WER: 11.30
39
+ OWSM CER: 8.00
40
+ Whisper WER: 9.25
41
+ Whisper CER: 6.50"
42
+
43
+ Raises:
44
+ ImportError:
45
+ If Versa is not installed or cannot be imported.
46
+
47
+ Example:
48
+ >>> asr_audio_output = (16000, audio_array)
49
+ >>> asr_transcript = "This is the ASR transcript."
50
+ >>> result = handle_espnet_ASR_WER(asr_audio_output, asr_transcript)
51
+ >>> print(result)
52
+ "ESPnet WER: 10.50
53
+ ESPnet CER: 7.20
54
+ OWSM WER: 11.30
55
+ OWSM CER: 8.00
56
+ Whisper WER: 9.25
57
+ Whisper CER: 6.50"
58
+ """
59
+ try:
60
+ from versa import (
61
+ espnet_levenshtein_metric,
62
+ espnet_wer_setup,
63
+ owsm_levenshtein_metric,
64
+ owsm_wer_setup,
65
+ whisper_levenshtein_metric,
66
+ whisper_wer_setup,
67
+ )
68
+ except Exception as e:
69
+ print("Error: Versa is not properly installed.")
70
+ raise e
71
+ score_modules_espnet = {
72
+ "module": espnet_levenshtein_metric,
73
+ "args": espnet_wer_setup(
74
+ model_tag="default",
75
+ beam_size=1,
76
+ text_cleaner="whisper_en",
77
+ use_gpu=True,
78
+ ),
79
+ }
80
+ dict1 = score_modules_espnet["module"](
81
+ score_modules_espnet["args"],
82
+ int2float(ASR_audio_output[1]),
83
+ ASR_transcript,
84
+ ASR_audio_output[0],
85
+ )
86
+ espnet_wer = (
87
+ dict1["espnet_wer_delete"]
88
+ + dict1["espnet_wer_insert"]
89
+ + dict1["espnet_wer_replace"]
90
+ ) / (
91
+ dict1["espnet_wer_insert"]
92
+ + dict1["espnet_wer_replace"]
93
+ + dict1["espnet_wer_equal"]
94
+ )
95
+ espnet_cer = (
96
+ dict1["espnet_cer_delete"]
97
+ + dict1["espnet_cer_insert"]
98
+ + dict1["espnet_cer_replace"]
99
+ ) / (
100
+ dict1["espnet_cer_insert"]
101
+ + dict1["espnet_cer_replace"]
102
+ + dict1["espnet_cer_equal"]
103
+ )
104
+ score_modules_owsm = {
105
+ "module": owsm_levenshtein_metric,
106
+ "args": owsm_wer_setup(
107
+ model_tag="default",
108
+ beam_size=1,
109
+ text_cleaner="whisper_en",
110
+ use_gpu=True,
111
+ ),
112
+ }
113
+ dict1 = score_modules_owsm["module"](
114
+ score_modules_owsm["args"],
115
+ int2float(ASR_audio_output[1]),
116
+ ASR_transcript,
117
+ ASR_audio_output[0],
118
+ )
119
+ owsm_wer = (
120
+ dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
121
+ ) / (dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
122
+ owsm_cer = (
123
+ dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
124
+ ) / (dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
125
+ score_modules_whisper = {
126
+ "module": whisper_levenshtein_metric,
127
+ "args": whisper_wer_setup(
128
+ model_tag="default",
129
+ beam_size=1,
130
+ text_cleaner="whisper_en",
131
+ use_gpu=True,
132
+ ),
133
+ }
134
+ dict1 = score_modules_whisper["module"](
135
+ score_modules_whisper["args"],
136
+ int2float(ASR_audio_output[1]),
137
+ ASR_transcript,
138
+ ASR_audio_output[0],
139
+ )
140
+ whisper_wer = (
141
+ dict1["whisper_wer_delete"]
142
+ + dict1["whisper_wer_insert"]
143
+ + dict1["whisper_wer_replace"]
144
+ ) / (
145
+ dict1["whisper_wer_insert"]
146
+ + dict1["whisper_wer_replace"]
147
+ + dict1["whisper_wer_equal"]
148
+ )
149
+ whisper_cer = (
150
+ dict1["whisper_cer_delete"]
151
+ + dict1["whisper_cer_insert"]
152
+ + dict1["whisper_cer_replace"]
153
+ ) / (
154
+ dict1["whisper_cer_insert"]
155
+ + dict1["whisper_cer_replace"]
156
+ + dict1["whisper_cer_equal"]
157
+ )
158
+ return (
159
+ f"ESPnet WER: {espnet_wer*100:.2f}\n"
160
+ f"ESPnet CER: {espnet_cer*100:.2f}\n"
161
+ f"OWSM WER: {owsm_wer*100:.2f}\n"
162
+ f"OWSM CER: {owsm_cer*100:.2f}\n"
163
+ f"Whisper WER: {whisper_wer*100:.2f}\n"
164
+ f"Whisper CER: {whisper_cer*100:.2f}"
165
+ )
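For reference, the WER/CER assembly above is just arithmetic over the Levenshtein edit counts that Versa returns. A minimal sketch with made-up counts, using the same keys and the same denominator (insert + replace + equal) as the code above:

def edit_counts_to_rate(counts: dict, prefix: str) -> float:
    # Errors are deletions + insertions + substitutions.
    errors = (
        counts[f"{prefix}_delete"]
        + counts[f"{prefix}_insert"]
        + counts[f"{prefix}_replace"]
    )
    denom = (
        counts[f"{prefix}_insert"]
        + counts[f"{prefix}_replace"]
        + counts[f"{prefix}_equal"]
    )
    return errors / denom

counts = {
    "espnet_wer_delete": 1,
    "espnet_wer_insert": 0,
    "espnet_wer_replace": 2,
    "espnet_wer_equal": 17,
}
print(f"ESPnet WER: {edit_counts_to_rate(counts, 'espnet_wer') * 100:.2f}")  # ESPnet WER: 15.79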
pyscripts/utils/dialog_eval/LLM_Metrics.py ADDED
@@ -0,0 +1,245 @@
1
+ from multiprocessing import Pool
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ from pyscripts.utils.dialog_eval.vert import (
7
+ get_auto_bleu2_geometric,
8
+ get_self_bleu2_geometric,
9
+ run_f,
10
+ )
11
+ from scipy.stats import gmean
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
14
+
15
+
16
+ def perplexity(LLM_Output: str, model_id: str = "gpt2") -> str:
17
+ """
18
+ Compute the perplexity of the given text using a specified model from the
19
+ `evaluate` library (default: GPT-2).
20
+
21
+ Args:
22
+ LLM_Output (str):
23
+ The text (string) for which perplexity is to be computed.
24
+ model_id (str, optional):
25
+ The identifier of the model to use for computing
26
+ perplexity. Defaults to "gpt2".
27
+
28
+ Returns:
29
+ str:
30
+ A formatted string showing the perplexity of the
31
+ provided text(s), for example:
32
+ "Perplexity: 45.23\n"
33
+
34
+ Raises:
35
+ ImportError:
36
+ If the `evaluate` library is not installed or cannot be imported.
37
+
38
+ Example:
39
+ >>> text = "Hello world, this is a test."
40
+ >>> result = perplexity(text, model_id="gpt2")
41
+ >>> print(result)
42
+ "Perplexity: 27.34\n"
43
+ """
44
+ try:
45
+ import evaluate
46
+ except Exception as e:
47
+ print("Error: evaluate is not properly installed.")
48
+ raise e
49
+ perplexity = evaluate.load("perplexity", module_type="metric")
50
+ results = perplexity.compute(model_id=model_id, predictions=[LLM_Output])
51
+ return f"Perplexity: {results['mean_perplexity']:.2f}\n"
52
+
53
+
54
+ def vert(LLM_response_arr: List[str]) -> str:
55
+ """
56
+ Calculate and return Self BLEU-2, Auto BLEU-2 and VERT-2
57
+ metrics for a list of LLM responses.
58
+
59
+ Args:
60
+ LLM_response_arr (List[str]):
61
+ A list of responses (strings) generated by the language
62
+ model acting as text dialog response generator.
63
+
64
+ Returns:
65
+ str:
66
+ A formatted string that includes each computed metric and the final
67
+ VERT value, for example:
68
+
69
+ "Self-BLEU2-geometric: 42.13
70
+ Auto-BLEU2-geometric: 38.94
71
+ VERT: 40.5
72
+ "
73
+
74
+ Example:
75
+ >>> # Suppose we have the following LLM responses:
76
+ >>> responses = ["Hello world", "Foo bar", "Lorem ipsum dolor sit amet"]
77
+ >>> result = vert(responses)
78
+ >>> print(result)
79
+ "Self-BLEU2-geometric: 42.13
80
+ Auto-BLEU2-geometric: 38.94
81
+ VERT: 40.5
82
+ "
83
+ """
84
+ terms = [x.strip().split() for x in LLM_response_arr]
85
+
86
+ tasks = [
87
+ ("Self-BLEU2-geometric", get_self_bleu2_geometric),
88
+ ("Auto-BLEU2-geometric", get_auto_bleu2_geometric),
89
+ ]
90
+ n_processes = min(16, len(tasks))
91
+ with Pool(n_processes) as pool:
92
+ metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
93
+ metric_arr = []
94
+ str1 = ""
95
+ for (metric_name, _), metric in zip(tasks, metrics):
96
+ metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
97
+
98
+ metric, sem = [round(100 * x, 2) for x in [metric, sem]]
99
+ metric_arr.append(metric)
100
+
101
+ str1 += f"{metric_name}: {metric}\n"
102
+ str1 += f"VERT: {round(gmean(metric_arr), 2)}\n"
103
+ return str1
104
+
105
+
106
+ def bert_score(
107
+ total_response_arr: List[str], bert_model_name: str = "bert-base-uncased"
108
+ ) -> str:
109
+ """
110
+ Compute a cosine similarity score between the concatenated
111
+ context (all but the last element)
112
+ and the final response (last element) using a BERT-based model.
113
+ This serves as a simplified
114
+ measure of how closely the response aligns with the preceding context semantically.
115
+
116
+ Args:
117
+ total_response_arr (List[str]):
118
+ A list of strings. The last element represents the response,
119
+ while all other elements
120
+ are treated as the context.
121
+ bert_model_name (str, optional):
122
+ The name or path of the BERT model to use (from the Hugging Face Model Hub).
123
+ Defaults to "bert-base-uncased".
124
+
125
+ Returns:
126
+ str:
127
+ A string containing the cosine similarity
128
+ (as a percentage) followed by a newline.
129
+ For example:
130
+ "Cosine Similarity: 85.67\n"
131
+
132
+ Example:
133
+ >>> total_responses = [
134
+ ... "User: Hi, how are you?",
135
+ ... "Assistant: I'm good! How can I help you today?",
136
+ ... "User: Can you tell me a joke?",
137
+ ... "Assistant: Sure! Here's one: Why did the chicken join a band?"
138
+ ... ]
139
+ >>> result = bert_score(total_responses, bert_model_name="bert-base-uncased")
140
+ >>> print(result)
141
+ "Cosine Similarity: 75.89\n"
142
+ """
143
+
144
+ def cosine_similarity_context_response(context, response, model, tokenizer):
145
+ # Tokenize and encode both context and response
146
+ context_inputs = tokenizer(context, return_tensors="pt", truncation=True)
147
+ response_inputs = tokenizer(response, return_tensors="pt", truncation=True)
148
+ for k in context_inputs:
149
+ context_inputs[k] = context_inputs[k].cuda()
150
+ for k in response_inputs:
151
+ response_inputs[k] = response_inputs[k].cuda()
152
+
153
+ # Get embeddings from the model
154
+ with torch.no_grad():
155
+ context_embedding = model(**context_inputs).last_hidden_state.mean(dim=1)
156
+ response_embedding = model(**response_inputs).last_hidden_state.mean(dim=1)
157
+
158
+ # Compute cosine similarity
159
+ similarity = cosine_similarity(
160
+ context_embedding.cpu().numpy(), response_embedding.cpu().numpy()
161
+ )
162
+ return similarity[0][0]
163
+
164
+ bert_model = AutoModel.from_pretrained(bert_model_name).cuda()
165
+ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
166
+ similarity = cosine_similarity_context_response(
167
+ " ".join(total_response_arr[:-1]),
168
+ total_response_arr[-1],
169
+ bert_model,
170
+ bert_tokenizer,
171
+ )
172
+ return f"Cosine Similarity: {similarity*100:.2f}" + "\n"
173
+
174
+
175
+ def DialoGPT_perplexity(
176
+ user_utterance: str,
177
+ response: str,
178
+ dialog_model_name: str = "microsoft/DialoGPT-medium",
179
+ ) -> str:
180
+ """
181
+ Compute the perplexity of a response given a user utterance using a pre-trained
182
+ DialoGPT model. The function loads DialoGPT (medium by default)
183
+ from the Hugging Face Model Hub, then calculates the perplexity
184
+ for the
185
+ (context + response) sequence.
186
+
187
+ Args:
188
+ user_utterance (str):
189
+ The user utterance preceding the model's response.
190
+ response (str):
191
+ The generated response whose perplexity needs to be evaluated.
192
+
193
+ Returns:
194
+ str:
195
+ A formatted string containing the DialoGPT perplexity score. For example:
196
+ "DialoGPT Perplexity: 25.67\n"
197
+
198
+ Example:
199
+ >>> user_text = "Hi, how are you today?"
200
+ >>> system_response = "I'm good, thank you! How can I help you?"
201
+ >>> result = DialoGPT_perplexity(user_text, system_response)
202
+ >>> print(result)
203
+ "DialoGPT Perplexity: 31.45\n"
204
+ """
205
+
206
+ def evaluate_response_with_dialoGPT(context, response, model, tokenizer):
207
+ """
208
+ Evaluate the appropriateness of a response based on the
209
+ given context using DialoGPT.
210
+
211
+ Args:
212
+ context (str): The dialogue context (previous conversation).
213
+ response (str): The generated response to evaluate.
214
+ model: Pre-trained DialoGPT model.
215
+ tokenizer: Corresponding tokenizer for the DialoGPT model.
216
+
217
+ Returns:
218
+ float: Perplexity score of the response given the context.
219
+ """
220
+ model.eval()
221
+
222
+ # Combine context and response as input
223
+ input_text = context + tokenizer.eos_token + response + tokenizer.eos_token
224
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
225
+ inputs["input_ids"] = inputs["input_ids"].cuda()
226
+ inputs["attention_mask"] = inputs["attention_mask"].cuda()
227
+ # import pdb;pdb.set_trace()
228
+
229
+ # Compute model outputs and loss
230
+ with torch.no_grad():
231
+ outputs = model(**inputs, labels=inputs["input_ids"].cuda())
232
+ loss = outputs.loss
233
+
234
+ # Calculate perplexity
235
+ perplexity = torch.exp(loss)
236
+ return perplexity.cpu().item()
237
+
238
+ # Load DialoGPT model and tokenizer
239
+ model_name = dialog_model_name
240
+ model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
241
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
242
+ perplexity = evaluate_response_with_dialoGPT(
243
+ user_utterance, response, model, tokenizer
244
+ )
245
+ return f"DialoGPT Perplexity: {perplexity:.2f}" + "\n"
pyscripts/utils/dialog_eval/TTS_intelligibility.py ADDED
@@ -0,0 +1,169 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ from espnet2.sds.utils.utils import int2float
6
+
7
+
8
+ def handle_espnet_TTS_intelligibility(
9
+ TTS_audio_output: Tuple[int, np.ndarray], LLM_Output: str
10
+ ) -> str:
11
+ """
12
+ Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
13
+ for multiple ASR systems (ESPnet, OWSM, Whisper) using the Versa library.
14
+
15
+ This function:
16
+ 1. Imports the necessary metrics and setup functions from Versa.
17
+ 2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
18
+ 3. Runs the Levenshtein-based WER/CER calculations on the provided TTS audio.
19
+ 4. Returns a formatted string summarizing WER and CER results
20
+ for hypotheses produced
21
+ by each ASR system when transcribing the TTS audio, using
22
+ the LLM output as the reference text.
23
+
24
+ Args:
25
+ TTS_audio_output (Tuple[int, np.ndarray]):
26
+ A tuple consisting of:
27
+ - The first element (int): the frame rate of the audio.
28
+ - The second element (np.ndarray):
29
+ the audio signal (e.g., a NumPy array).
30
+ LLM_Output (str):
31
+ The reference text generated by the LLM, which serves as the ground truth
32
+ for evaluating the TTS audio.
33
+
34
+ Returns:
35
+ str:
36
+ A formatted string showing the WER and CER percentages
37
+ for ESPnet, OWSM, and Whisper.
38
+ Example:
39
+
40
+ ESPnet WER: 10.50
41
+ ESPnet CER: 7.20
42
+ OWSM WER: 11.30
43
+ OWSM CER: 8.00
44
+ Whisper WER: 9.25
45
+ Whisper CER: 6.50
46
+
47
+ Raises:
48
+ ImportError:
49
+ If the Versa library is not installed or cannot be imported.
50
+
51
+ Example:
52
+ >>> tts_audio_output = (16000, audio_array)
53
+ >>> llm_output = "This is the reference text for evaluation."
54
+ >>> result = handle_espnet_TTS_intelligibility(tts_audio_output, llm_output)
55
+ >>> print(result)
56
+ ESPnet WER: 10.50
57
+ ESPnet CER: 7.20
58
+ OWSM WER: 11.30
59
+ OWSM CER: 8.00
60
+ Whisper WER: 9.25
61
+ Whisper CER: 6.50
62
+ """
63
+ try:
64
+ from versa import (
65
+ espnet_levenshtein_metric,
66
+ espnet_wer_setup,
67
+ owsm_levenshtein_metric,
68
+ owsm_wer_setup,
69
+ whisper_levenshtein_metric,
70
+ whisper_wer_setup,
71
+ )
72
+ except Exception as e:
73
+ print("Error: Versa is not properly installed.")
74
+ raise e
75
+ score_modules_espnet = {
76
+ "module": espnet_levenshtein_metric,
77
+ "args": espnet_wer_setup(
78
+ model_tag="default",
79
+ beam_size=1,
80
+ text_cleaner="whisper_en",
81
+ use_gpu=True,
82
+ ),
83
+ }
84
+ dict1 = score_modules_espnet["module"](
85
+ score_modules_espnet["args"],
86
+ int2float(TTS_audio_output[1]),
87
+ LLM_Output,
88
+ TTS_audio_output[0],
89
+ )
90
+ espnet_wer = (
91
+ dict1["espnet_wer_delete"]
92
+ + dict1["espnet_wer_insert"]
93
+ + dict1["espnet_wer_replace"]
94
+ ) / (
95
+ dict1["espnet_wer_delete"]
96
+ + dict1["espnet_wer_replace"]
97
+ + dict1["espnet_wer_equal"]
98
+ )
99
+ espnet_cer = (
100
+ dict1["espnet_cer_delete"]
101
+ + dict1["espnet_cer_insert"]
102
+ + dict1["espnet_cer_replace"]
103
+ ) / (
104
+ dict1["espnet_cer_delete"]
105
+ + dict1["espnet_cer_replace"]
106
+ + dict1["espnet_cer_equal"]
107
+ )
108
+ score_modules_owsm = {
109
+ "module": owsm_levenshtein_metric,
110
+ "args": owsm_wer_setup(
111
+ model_tag="default",
112
+ beam_size=1,
113
+ text_cleaner="whisper_en",
114
+ use_gpu=True,
115
+ ),
116
+ }
117
+ dict1 = score_modules_owsm["module"](
118
+ score_modules_owsm["args"],
119
+ int2float(TTS_audio_output[1]),
120
+ LLM_Output,
121
+ TTS_audio_output[0],
122
+ )
123
+ owsm_wer = (
124
+ dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
125
+ ) / (dict1["owsm_wer_delete"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
126
+ owsm_cer = (
127
+ dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
128
+ ) / (dict1["owsm_cer_delete"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
129
+ score_modules_whisper = {
130
+ "module": whisper_levenshtein_metric,
131
+ "args": whisper_wer_setup(
132
+ model_tag="default",
133
+ beam_size=1,
134
+ text_cleaner="whisper_en",
135
+ use_gpu=True,
136
+ ),
137
+ }
138
+ dict1 = score_modules_whisper["module"](
139
+ score_modules_whisper["args"],
140
+ int2float(TTS_audio_output[1]),
141
+ LLM_Output,
142
+ TTS_audio_output[0],
143
+ )
144
+ whisper_wer = (
145
+ dict1["whisper_wer_delete"]
146
+ + dict1["whisper_wer_insert"]
147
+ + dict1["whisper_wer_replace"]
148
+ ) / (
149
+ dict1["whisper_wer_delete"]
150
+ + dict1["whisper_wer_replace"]
151
+ + dict1["whisper_wer_equal"]
152
+ )
153
+ whisper_cer = (
154
+ dict1["whisper_cer_delete"]
155
+ + dict1["whisper_cer_insert"]
156
+ + dict1["whisper_cer_replace"]
157
+ ) / (
158
+ dict1["whisper_cer_delete"]
159
+ + dict1["whisper_cer_replace"]
160
+ + dict1["whisper_cer_equal"]
161
+ )
162
+ return (
163
+ f"ESPnet WER: {espnet_wer*100:.2f}\n"
164
+ f"ESPnet CER: {espnet_cer*100:.2f}\n"
165
+ f"OWSM WER: {owsm_wer*100:.2f}\n"
166
+ f"OWSM CER: {owsm_cer*100:.2f}\n"
167
+ f"Whisper WER: {whisper_wer*100:.2f}\n"
168
+ f"Whisper CER: {whisper_cer*100:.2f}"
169
+ )
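A short usage sketch: score a synthesized utterance against the LLM text that produced it. The wav path and the use of soundfile are assumptions; the function only needs a (sample_rate, int16 NumPy array) tuple plus the reference text, and Versa must be installed.

import soundfile as sf  # assumption: soundfile is available

from pyscripts.utils.dialog_eval.TTS_intelligibility import (
    handle_espnet_TTS_intelligibility,
)

audio, fs = sf.read("tts_output.wav", dtype="int16")  # hypothetical file
print(handle_espnet_TTS_intelligibility((fs, audio), "This is the reference text."))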
pyscripts/utils/dialog_eval/TTS_speech_quality.py ADDED
@@ -0,0 +1,98 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ from espnet2.sds.utils.utils import int2float
6
+
7
+
8
+ def TTS_psuedomos(TTS_audio_output: Tuple[int, np.ndarray]) -> str:
9
+ """
10
+ Compute and return speech quality metrics
11
+ for the given synthesized audio output
12
+ using the Versa library.
13
+
14
+ Args:
15
+ TTS_audio_output (Tuple[int, np.ndarray]):
16
+ A tuple containing:
17
+ - The first element (int): The frame rate of the audio.
18
+ - The second element (np.ndarray): The audio signal,
19
+ typically a NumPy array.
20
+
21
+ Returns:
22
+ str:
23
+ A formatted string containing each metric name
24
+ and its corresponding score, for example:
25
+
26
+ utmos: 3.54
27
+ dnsmos: 3.47
28
+ plcmos: 3.62
29
+ sheet_ssqa: 4.03
30
+
31
+ Raises:
32
+ ImportError:
33
+ If the Versa library is not installed or cannot be imported.
34
+
35
+ Example:
36
+ >>> tts_audio_output = (16000, audio_array)
37
+ >>> result = TTS_psuedomos(tts_audio_output)
38
+ >>> print(result)
39
+ utmos: 3.54
40
+ dnsmos: 3.47
41
+ plcmos: 3.62
42
+ sheet_ssqa: 4.03
43
+ """
44
+ try:
45
+ from versa import (
46
+ pseudo_mos_metric,
47
+ pseudo_mos_setup,
48
+ sheet_ssqa,
49
+ sheet_ssqa_setup,
50
+ )
51
+ except Exception as e:
52
+ print("Error: Versa is not properly installed.")
53
+ raise e
54
+
55
+ predictor_dict, predictor_fs = pseudo_mos_setup(
56
+ use_gpu=True,
57
+ predictor_types=["utmos", "dnsmos", "plcmos"],
58
+ predictor_args={
59
+ "utmos": {"fs": 16000},
60
+ "dnsmos": {"fs": 16000},
61
+ "plcmos": {"fs": 16000},
62
+ },
63
+ )
64
+ score_modules = {
65
+ "module": pseudo_mos_metric,
66
+ "args": {
67
+ "predictor_dict": predictor_dict,
68
+ "predictor_fs": predictor_fs,
69
+ "use_gpu": True,
70
+ },
71
+ }
72
+ dict1 = score_modules["module"](
73
+ int2float(TTS_audio_output[1]),
74
+ TTS_audio_output[0],
75
+ **score_modules["args"],
76
+ )
77
+ str1 = ""
78
+ for k in dict1:
79
+ str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
80
+ sheet_model = sheet_ssqa_setup(
81
+ model_tag="default",
82
+ model_path=None,
83
+ model_config=None,
84
+ use_gpu=True,
85
+ )
86
+ score_modules = {
87
+ "module": sheet_ssqa,
88
+ "args": {"model": sheet_model, "use_gpu": True},
89
+ }
90
+ dict1 = score_modules["module"](
91
+ score_modules["args"]["model"],
92
+ int2float(TTS_audio_output[1]),
93
+ TTS_audio_output[0],
94
+ use_gpu=score_modules["args"]["use_gpu"],
95
+ )
96
+ for k in dict1:
97
+ str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
98
+ return str1
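The same (sample_rate, audio) tuple drives the speech-quality metrics; continuing the hedged sketch above:

import soundfile as sf  # assumption, as before

from pyscripts.utils.dialog_eval.TTS_speech_quality import TTS_psuedomos

audio, fs = sf.read("tts_output.wav", dtype="int16")  # hypothetical file
print(TTS_psuedomos((fs, audio)))  # one "name: score" line per predictor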
pyscripts/utils/dialog_eval/__pycache__/ASR_WER.cpython-39.pyc ADDED
Binary file (4.12 kB). View file
 
pyscripts/utils/dialog_eval/__pycache__/LLM_Metrics.cpython-39.pyc ADDED
Binary file (8.51 kB). View file
 
pyscripts/utils/dialog_eval/__pycache__/TTS_intelligibility.cpython-39.pyc ADDED
Binary file (4.34 kB). View file
 
pyscripts/utils/dialog_eval/__pycache__/TTS_speech_quality.cpython-39.pyc ADDED
Binary file (2.39 kB). View file
 
pyscripts/utils/dialog_eval/__pycache__/human_feedback.cpython-39.pyc ADDED
Binary file (7.34 kB). View file
 
pyscripts/utils/dialog_eval/__pycache__/vert.cpython-39.pyc ADDED
Binary file (9.13 kB). View file
 
pyscripts/utils/dialog_eval/human_feedback.py ADDED
@@ -0,0 +1,242 @@
1
+ import gradio as gr
2
+
3
+ disable_btn = gr.Button(interactive=False, visible=False)
4
+
5
+
6
+ def get_ip(request: gr.Request) -> str:
7
+ """
8
+ Retrieve the IP address from an incoming HTTP request.
9
+
10
+ Args:
11
+ request (gr.Request):
12
+ The incoming HTTP request from which the IP address will be extracted.
13
+
14
+ Returns:
15
+ str:
16
+ The IP address as a string.
17
+ """
18
+ if "cf-connecting-ip" in request.headers:
19
+ ip = request.headers["cf-connecting-ip"]
20
+ elif "x-forwarded-for" in request.headers:
21
+ ip = request.headers["x-forwarded-for"]
22
+ if "," in ip:
23
+ ip = ip.split(",")[0]
24
+ else:
25
+ ip = request.client.host
26
+ return ip
27
+
28
+
29
+ def natural_vote1_last_response(request: gr.Request):
30
+ """
31
+ Handle a user vote for naturalness as "Very Natural".
32
+
33
+
34
+ Args:
35
+ request (gr.Request):
36
+ The Gradio request object providing access to HTTP headers and metadata.
37
+
38
+ Returns:
39
+ tuple:
40
+ A tuple containing:
41
+ ("Very Natural", <ip_address>, (disable_btn,) * 4)
42
+
43
+ - "Very Natural": The selected vote or label.
44
+ - <ip_address>: The IP address of the client retrieved from the request.
45
+ - disable_btn: An object repeated four times,
46
+ to disable natural vote buttons.
47
+ """
48
+ ip_address1 = get_ip(request)
49
+ print(f"Very Natural (voted). ip: {ip_address1}")
50
+ return (
51
+ "Very Natural",
52
+ ip_address1,
53
+ ) + (disable_btn,) * 4
54
+
55
+
56
+ def natural_vote2_last_response(request: gr.Request):
57
+ """
58
+ Handle a user vote for naturalness as "Somewhat Awkward".
59
+
60
+
61
+ Args:
62
+ request (gr.Request):
63
+ The Gradio request object providing access to HTTP headers and metadata.
64
+
65
+ Returns:
66
+ tuple:
67
+ A tuple containing:
68
+ ("Somewhat Awkward", <ip_address>, (disable_btn,) * 4)
69
+
70
+ - "Somewhat Awkward": The selected vote or label.
71
+ - <ip_address>: The IP address of the client retrieved from the request.
72
+ - disable_btn: An object repeated four times,
73
+ to disable natural vote buttons.
74
+ """
75
+ ip_address1 = get_ip(request)
76
+ print(f"Somewhat Awkward (voted). ip: {ip_address1}")
77
+ return (
78
+ "Somewhat Awkward",
79
+ ip_address1,
80
+ ) + (disable_btn,) * 4
81
+
82
+
83
+ def natural_vote3_last_response(request: gr.Request):
84
+ """
85
+ Handle a user vote for naturalness as "Very Awkward".
86
+
87
+
88
+ Args:
89
+ request (gr.Request):
90
+ The Gradio request object providing access to HTTP headers and metadata.
91
+
92
+ Returns:
93
+ tuple:
94
+ A tuple containing:
95
+ ("Very Awkward", <ip_address>, (disable_btn,) * 4)
96
+
97
+ - "Very Awkward": The selected vote or label.
98
+ - <ip_address>: The IP address of the client retrieved from the request.
99
+ - disable_btn: An object repeated four times,
100
+ to disable natural vote buttons.
101
+ """
102
+ ip_address1 = get_ip(request)
103
+ print(f"Very Awkward (voted). ip: {ip_address1}")
104
+ return (
105
+ "Very Awkward",
106
+ ip_address1,
107
+ ) + (disable_btn,) * 4
108
+
109
+
110
+ def natural_vote4_last_response(request: gr.Request):
111
+ """
112
+ Handle a user vote for naturalness as "Unnatural".
113
+
114
+
115
+ Args:
116
+ request (gr.Request):
117
+ The Gradio request object providing access to HTTP headers and metadata.
118
+
119
+ Returns:
120
+ tuple:
121
+ A tuple containing:
122
+ ("Unnatural", <ip_address>, (disable_btn,) * 4)
123
+
124
+ - "Unnatural": The selected vote or label.
125
+ - <ip_address>: The IP address of the client retrieved from the request.
126
+ - disable_btn: An object repeated four times,
127
+ to disable natural vote buttons.
128
+ """
129
+ ip_address1 = get_ip(request)
130
+ print(f"Unnatural (voted). ip: {ip_address1}")
131
+ return (
132
+ "Unnatural",
133
+ ip_address1,
134
+ ) + (disable_btn,) * 4
135
+
136
+
137
+ def relevant_vote1_last_response(request: gr.Request):
138
+ """
139
+ Handle a user vote for relevance as "Highly Relevant".
140
+
141
+
142
+ Args:
143
+ request (gr.Request):
144
+ The Gradio request object providing access to HTTP headers and metadata.
145
+
146
+ Returns:
147
+ tuple:
148
+ A tuple containing:
149
+ ("Highly Relevant", <ip_address>, (disable_btn,) * 4)
150
+
151
+ - "Highly Relevant": The selected vote or label.
152
+ - <ip_address>: The IP address of the client retrieved from the request.
153
+ - disable_btn: An object repeated four times,
154
+ to disable relevance vote buttons.
155
+ """
156
+ ip_address1 = get_ip(request)
157
+ print(f"Highly Relevant (voted). ip: {ip_address1}")
158
+ return (
159
+ "Highly Relevant",
160
+ ip_address1,
161
+ ) + (disable_btn,) * 4
162
+
163
+
164
+ def relevant_vote2_last_response(request: gr.Request):
165
+ """
166
+ Handle a user vote for relevance as "Partially Relevant".
167
+
168
+
169
+ Args:
170
+ request (gr.Request):
171
+ The Gradio request object providing access to HTTP headers and metadata.
172
+
173
+ Returns:
174
+ tuple:
175
+ A tuple containing:
176
+ ("Partially Relevant", <ip_address>, (disable_btn,) * 4)
177
+
178
+ - "Partially Relevant": The selected vote or label.
179
+ - <ip_address>: The IP address of the client retrieved from the request.
180
+ - disable_btn: An object repeated four times,
181
+ to disable relevance vote buttons.
182
+ """
183
+ ip_address1 = get_ip(request)
184
+ print(f"Partially Relevant (voted). ip: {ip_address1}")
185
+ return (
186
+ "Partially Relevant",
187
+ ip_address1,
188
+ ) + (disable_btn,) * 4
189
+
190
+
191
+ def relevant_vote3_last_response(request: gr.Request):
192
+ """
193
+ Handle a user vote for relevance as "Slightly Irrelevant".
194
+
195
+
196
+ Args:
197
+ request (gr.Request):
198
+ The Gradio request object providing access to HTTP headers and metadata.
199
+
200
+ Returns:
201
+ tuple:
202
+ A tuple containing:
203
+ ("Slightly Irrelevant", <ip_address>, (disable_btn,) * 4)
204
+
205
+ - "Slightly Irrelevant": The selected vote or label.
206
+ - <ip_address>: The IP address of the client retrieved from the request.
207
+ - disable_btn: An object repeated four times,
208
+ to disable relevance vote buttons.
209
+ """
210
+ ip_address1 = get_ip(request)
211
+ print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
212
+ return (
213
+ "Slightly Irrelevant",
214
+ ip_address1,
215
+ ) + (disable_btn,) * 4
216
+
217
+
218
+ def relevant_vote4_last_response(request: gr.Request):
219
+ """
220
+ Handle a user vote for relevance as "Completely Irrelevant".
221
+
222
+
223
+ Args:
224
+ request (gr.Request):
225
+ The Gradio request object providing access to HTTP headers and metadata.
226
+
227
+ Returns:
228
+ tuple:
229
+ A tuple containing:
230
+ ("Completely Irrelevant", <ip_address>, (disable_btn,) * 4)
231
+
232
+ - "Completely Irrelevant": The selected vote or label.
233
+ - <ip_address>: The IP address of the client retrieved from the request.
234
+ - disable_btn: An object repeated four times,
235
+ to disable relevance vote buttons.
236
+ """
237
+ ip_address1 = get_ip(request)
238
+ print(f"Completely Irrelevant (voted). ip: {ip_address1}")
239
+ return (
240
+ "Completely Irrelevant",
241
+ ip_address1,
242
+ ) + (disable_btn,) * 4
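Each handler above returns a six-element tuple — the vote label, the voter's IP, and four disabled-button updates — so the outputs list that app.py wires to the button must contain exactly those components in that order. A minimal sketch with illustrative names (Gradio injects the gr.Request argument automatically because of the type annotation):

import gradio as gr

from pyscripts.utils.dialog_eval.human_feedback import natural_vote1_last_response

with gr.Blocks() as demo:
    natural_response = gr.Textbox(visible=False)
    ip_address = gr.Textbox(visible=False)
    natural_btn1 = gr.Button("Very Natural")
    natural_btn2 = gr.Button("Somewhat Awkward")
    natural_btn3 = gr.Button("Very Awkward")
    natural_btn4 = gr.Button("Unnatural")
    natural_btn_list = [natural_btn1, natural_btn2, natural_btn3, natural_btn4]

    natural_btn1.click(
        natural_vote1_last_response,
        [],
        [natural_response, ip_address] + natural_btn_list,  # 2 values + 4 button updates
    )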
pyscripts/utils/dialog_eval/vert.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import sys
8
+ import warnings
9
+ from collections import Counter
10
+ from fractions import Fraction
11
+
12
+ import nltk
13
+ import numpy as np
14
+ from nltk.translate.bleu_score import (
15
+ SmoothingFunction,
16
+ brevity_penalty,
17
+ closest_ref_length,
18
+ modified_precision,
19
+ )
20
+
21
+
22
+ def corpus_bleu(
23
+ list_of_references,
24
+ hypotheses,
25
+ weights=(0.25, 0.25, 0.25, 0.25),
26
+ smoothing_function=None,
27
+ auto_reweigh=False,
28
+ averaging_mode="geometric",
29
+ no_length_penalty=False,
30
+ ):
31
+ """
32
+ Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
33
+ the hypotheses and their respective references.
34
+
35
+ Instead of averaging the sentence level BLEU scores (i.e. marco-average
36
+ precision), the original BLEU metric (Papineni et al. 2002) accounts for
37
+ the micro-average precision (i.e. summing the numerators and denominators
38
+ for each hypothesis-reference(s) pairs before the division).
39
+
40
+ >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
41
+ ... 'ensures', 'that', 'the', 'military', 'always',
42
+ ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
43
+ >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
44
+ ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
45
+ ... 'heed', 'Party', 'commands']
46
+ >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
47
+ ... 'guarantees', 'the', 'military', 'forces', 'always',
48
+ ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
49
+ >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
50
+ ... 'army', 'always', 'to', 'heed', 'the', 'directions',
51
+ ... 'of', 'the', 'party']
52
+
53
+ >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
54
+ ... 'interested', 'in', 'world', 'history']
55
+ >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
56
+ ... 'because', 'he', 'read', 'the', 'book']
57
+
58
+ >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
59
+ >>> hypotheses = [hyp1, hyp2]
60
+ >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
61
+ 0.5920...
62
+
63
+ The example below show that corpus_bleu() is different from averaging
64
+ sentence_bleu() for hypotheses
65
+
66
+ >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
67
+ >>> score2 = sentence_bleu([ref2a], hyp2)
68
+ >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
69
+ 0.6223...
70
+
71
+ :param list_of_references: a corpus of lists of reference
72
+ sentences, w.r.t. hypotheses
73
+ :type list_of_references: list(list(list(str)))
74
+ :param hypotheses: a list of hypothesis sentences
75
+ :type hypotheses: list(list(str))
76
+ :param weights: weights for unigrams, bigrams, trigrams and so on
77
+ :type weights: list(float)
78
+ :param smoothing_function:
79
+ :type smoothing_function: SmoothingFunction
80
+ :param auto_reweigh: Option to re-normalize the weights uniformly.
81
+ :type auto_reweigh: bool
82
+ :return: The corpus-level BLEU score.
83
+ :rtype: float
84
+ """
85
+ # Before proceeding to compute BLEU, perform sanity checks.
86
+
87
+ p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
88
+ p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref.
89
+ hyp_lengths, ref_lengths = 0, 0
90
+
91
+ assert len(list_of_references) == len(hypotheses), (
92
+ "The number of hypotheses and their reference(s) should be the " "same "
93
+ )
94
+
95
+ # Iterate through each hypothesis and their corresponding references.
96
+ for references, hypothesis in zip(list_of_references, hypotheses):
97
+ # For each order of ngram, calculate the numerator and
98
+ # denominator for the corpus-level modified precision.
99
+ for i, _ in enumerate(weights, start=1):
100
+ p_i = modified_precision(references, hypothesis, i)
101
+ p_numerators[i] += p_i.numerator
102
+ p_denominators[i] += p_i.denominator
103
+
104
+ # Calculate the hypothesis length and the closest reference length.
105
+ # Adds them to the corpus-level hypothesis and reference counts.
106
+ hyp_len = len(hypothesis)
107
+ hyp_lengths += hyp_len
108
+ ref_lengths += closest_ref_length(references, hyp_len)
109
+
110
+ # Calculate corpus-level brevity penalty.
111
+ if no_length_penalty and averaging_mode == "geometric":
112
+ bp = 1.0
113
+ elif no_length_penalty and averaging_mode == "arithmetic":
114
+ bp = 0.0
115
+ else:
116
+ assert not no_length_penalty
117
+ assert (
118
+ averaging_mode != "arithmetic"
119
+ ), "Not sure how to apply length penalty when aurithmetic mode"
120
+ bp = brevity_penalty(ref_lengths, hyp_lengths)
121
+
122
+ # Uniformly re-weighting based on maximum hypothesis lengths if largest
123
+ # order of n-grams < 4 and weights is set at default.
124
+ if auto_reweigh:
125
+ if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
126
+ weights = (1 / hyp_lengths,) * hyp_lengths
127
+
128
+ # Collects the various precision values for the different ngram orders.
129
+ p_n = [
130
+ Fraction(p_numerators[i], p_denominators[i], _normalize=False)
131
+ for i, _ in enumerate(weights, start=1)
132
+ ]
133
+
134
+ # Returns 0 if there's no matching n-grams
135
+ # We only need to check for p_numerators[1] == 0, since if there's
136
+ # no unigrams, there won't be any higher order ngrams.
137
+ if p_numerators[1] == 0:
138
+ return 0
139
+
140
+ # If there's no smoothing, set use method0 from SmoothinFunction class.
141
+ if not smoothing_function:
142
+ smoothing_function = SmoothingFunction().method0
143
+ # Smoothen the modified precision.
144
+ # Note: smoothing_function() may convert values into floats;
145
+ # it tries to retain the Fraction object as much as the
146
+ # smoothing method allows.
147
+ p_n = smoothing_function(
148
+ p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
149
+ )
150
+
151
+ if averaging_mode == "geometric":
152
+ s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
153
+ s = bp * math.exp(math.fsum(s))
154
+ elif averaging_mode == "arithmetic":
155
+ s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
156
+ s = math.fsum(s)
157
+
158
+ return s
159
+
160
+
161
+ def sentence_bleu(
162
+ references,
163
+ hypothesis,
164
+ weights=(0.25, 0.25, 0.25, 0.25),
165
+ smoothing_function=None,
166
+ auto_reweigh=False,
167
+ averaging_mode="geometric",
168
+ no_length_penalty=False,
169
+ ):
170
+ return corpus_bleu(
171
+ [references],
172
+ [hypothesis],
173
+ weights,
174
+ smoothing_function,
175
+ auto_reweigh,
176
+ averaging_mode,
177
+ no_length_penalty,
178
+ )
179
+
180
+
181
+ def get_target_sequences(manifest, ground_truth, to_take=1000):
182
+ import json
183
+ import pathlib
184
+
185
+ with open(ground_truth, "r") as fin:
186
+ original_continuations = json.loads(fin.read())
187
+
188
+ sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
189
+ assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds
190
+
191
+ sequence2length.sort(key=lambda x: x[1])
192
+ to_take_sequences = set(v[0] for v in sequence2length[:to_take])
193
+ to_take_ids = []
194
+
195
+ with open(manifest, "r") as f:
196
+ f.readline()
197
+
198
+ for i, line in enumerate(f.readlines()):
199
+ seq_id = line.split()[0]
200
+ seq_id = pathlib.Path(seq_id).name.split("__")[0]
201
+
202
+ if seq_id in to_take_sequences:
203
+ to_take_ids.append(i)
204
+
205
+ print(f"Took {len(to_take_ids)} ids")
206
+ return set(to_take_ids)
207
+
208
+
209
+ def get_self_bleu(utterances, averaging_mode, weights):
210
+ self_bleu = []
211
+
212
+ for i in range(len(utterances)):
213
+ hypo = utterances[i]
214
+ rest = utterances[:i] + utterances[i + 1 :]
215
+
216
+ self_bleu.append(
217
+ sentence_bleu(
218
+ rest,
219
+ hypo,
220
+ weights,
221
+ no_length_penalty=True,
222
+ averaging_mode=averaging_mode,
223
+ )
224
+ )
225
+
226
+ return self_bleu
227
+
228
+
229
+ def get_self_bleu2_arithmetic(utterances):
230
+ weights = (0.5, 0.5) # equal weight for unigrams and bigrams
231
+ return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)
232
+
233
+
234
+ def get_self_bleu2_geometric(utterances):
235
+ weights = (0.5, 0.5)
236
+ return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)
237
+
238
+
239
+ def get_auto_bleu2_arithmetic(utterances):
240
+ weights = (0.5, 0.5)
241
+ return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]
242
+
243
+
244
+ def get_auto_bleu2_geometric(utterances):
245
+ weights = (0.5, 0.5)
246
+ return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]
247
+
248
+
249
+ def get_auto_bleu3_geometric(utterances):
250
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
251
+ return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]
252
+
253
+
254
+ def get_auto_bleu3_arithmetic(utterances):
255
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
256
+ return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]
257
+
258
+
259
+ def get_self_bleu3_arithmetic(utterances):
260
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
261
+ return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)
262
+
263
+
264
+ def get_self_bleu3_geometric(utterances):
265
+ weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
266
+ return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)
267
+
268
+
269
+ def auto_bleu(sentence, weights, mean_mode="arithmetic"):
270
+ if len(sentence) <= 1:
271
+ return 0
272
+
273
+ N = len(weights)
274
+
275
+ bleu_n = np.zeros([N])
276
+ for n in range(N):
277
+ targ_ngrams = list(nltk.ngrams(sentence, n + 1))
278
+ for p in range(len(targ_ngrams)):
279
+ left = sentence[:p]
280
+ right = sentence[(p + n + 1) :]
281
+ rest_ngrams = list(nltk.ngrams(left, n + 1)) + list(
282
+ nltk.ngrams(right, n + 1)
283
+ )
284
+ # compute the nb of matching ngrams
285
+ bleu_n[n] += targ_ngrams[p] in rest_ngrams
286
+ bleu_n[n] /= len(targ_ngrams) # average them to get a proportion
287
+
288
+ weights = np.array(weights)
289
+ if mean_mode == "arithmetic":
290
+ return (bleu_n * weights).sum()
291
+ elif mean_mode == "geometric":
292
+ return (bleu_n**weights).prod()
293
+ else:
294
+ raise ValueError(f"Unknown agggregation mode {mean_mode}")
295
+
296
+
297
+ def run_f(task_params):
298
+ f, terms = task_params
299
+ return f(terms)
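A small worked example of the auto-BLEU idea implemented above: every n-gram of a sentence is checked against the n-grams of the rest of the same sentence, so repetitive outputs score higher. This assumes nltk is installed, as the module itself requires.

from pyscripts.utils.dialog_eval.vert import auto_bleu

repetitive = "the cat sat on the mat the cat".split()
varied = "a completely non repetitive sentence".split()

print(auto_bleu(repetitive, weights=(0.5, 0.5)))  # > 0: unigrams and bigrams repeat
print(auto_bleu(varied, weights=(0.5, 0.5)))      # 0.0: nothing repeats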