csukuangfj committed on
Commit
c64aaa1
·
1 Parent(s): da4b8f8
Files changed (3) hide show
  1. app.py +28 -1
  2. examples.py +1 -1
  3. model.py +3 -3
app.py CHANGED
@@ -21,22 +21,35 @@
21
 
22
  import logging
23
  import os
 
24
  import time
25
  import uuid
26
  from datetime import datetime
27
 
28
  import gradio as gr
29
 
 
30
  from model import (
31
  embedding2models,
 
32
  get_speaker_diarization,
33
  read_wave,
34
  speaker_segmentation_models,
35
  )
36
- from examples import examples
37
 
38
  embedding_frameworks = list(embedding2models.keys())
39
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def MyPrint(s):
42
  now = datetime.now()
@@ -288,6 +301,20 @@ with demo:
288
  uploaded_output = gr.Textbox(label="Result from uploaded file")
289
  uploaded_html_info = gr.HTML(label="Info")
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  upload_button.click(
292
  process_uploaded_file,
293
  inputs=[
 
21
 
22
  import logging
23
  import os
24
+ import shutil
25
  import time
26
  import uuid
27
  from datetime import datetime
28
 
29
  import gradio as gr
30
 
31
+ from examples import examples
32
  from model import (
33
  embedding2models,
34
+ get_file,
35
  get_speaker_diarization,
36
  read_wave,
37
  speaker_segmentation_models,
38
  )
 
39
 
40
  embedding_frameworks = list(embedding2models.keys())
41
 
42
+ waves = [e[-1] for e in examples]
43
+
44
+ for name in waves:
45
+ filename = get_file(
46
+ "k2-fsa/speaker-diarization",
47
+ name,
48
+ subfolder="test_wavs",
49
+ )
50
+
51
+ shutil.copyfile(filename, name)
52
+
53
 
54
  def MyPrint(s):
55
  now = datetime.now()
 
301
  uploaded_output = gr.Textbox(label="Result from uploaded file")
302
  uploaded_html_info = gr.HTML(label="Info")
303
 
304
+ gr.Examples(
305
+ examples=examples,
306
+ inputs=[
307
+ embedding_framework_radio,
308
+ embedding_model_dropdown,
309
+ speaker_segmentation_model_dropdown,
310
+ input_num_speakers,
311
+ input_threshold,
312
+ uploaded_file,
313
+ ],
314
+ outputs=[uploaded_output, uploaded_html_info],
315
+ fn=process_uploaded_file,
316
+ )
317
+
318
  upload_button.click(
319
  process_uploaded_file,
320
  inputs=[
examples.py CHANGED
@@ -5,6 +5,6 @@ examples = [
5
  "pyannote/segmentation-3.0",
6
  "4",
7
  "0",
8
- "./test_wavs/0-four-speakers-zh.wav",
9
  ],
10
  ]
 
5
  "pyannote/segmentation-3.0",
6
  "4",
7
  "0",
8
+ "./0-four-speakers-zh.wav",
9
  ],
10
  ]
model.py CHANGED
@@ -49,7 +49,7 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
49
 
50
 
51
  @lru_cache(maxsize=30)
52
- def _get_nn_model_filename(
53
  repo_id: str,
54
  filename: str,
55
  subfolder: str = ".",
@@ -66,7 +66,7 @@ def get_speaker_segmentation_model(repo_id) -> str:
66
  assert repo_id in ("pyannote/segmentation-3.0",)
67
 
68
  if repo_id == "pyannote/segmentation-3.0":
69
- return _get_nn_model_filename(
70
  repo_id="csukuangfj/sherpa-onnx-pyannote-segmentation-3-0",
71
  filename="model.onnx",
72
  )
@@ -81,7 +81,7 @@ def get_speaker_embedding_model(model_name) -> str:
81
  )
82
  model_name = model_name.split("|")[0]
83
 
84
- return _get_nn_model_filename(
85
  repo_id="csukuangfj/speaker-embedding-models",
86
  filename=model_name,
87
  )
 
49
 
50
 
51
  @lru_cache(maxsize=30)
52
+ def get_file(
53
  repo_id: str,
54
  filename: str,
55
  subfolder: str = ".",
 
66
  assert repo_id in ("pyannote/segmentation-3.0",)
67
 
68
  if repo_id == "pyannote/segmentation-3.0":
69
+ return get_file(
70
  repo_id="csukuangfj/sherpa-onnx-pyannote-segmentation-3-0",
71
  filename="model.onnx",
72
  )
 
81
  )
82
  model_name = model_name.split("|")[0]
83
 
84
+ return get_file(
85
  repo_id="csukuangfj/speaker-embedding-models",
86
  filename=model_name,
87
  )