jhj0517 commited on
Commit
50380bc
·
unverified ·
2 Parent(s): f12a40c c7bfcf2

Merge pull request #375 from jhj0517/fix/improve-test

Browse files
.github/workflows/ci.yml CHANGED
@@ -37,7 +37,7 @@ jobs:
37
  run: sudo apt-get update && sudo apt-get install -y git ffmpeg
38
 
39
  - name: Install dependencies
40
- run: pip install -r requirements.txt pytest
41
 
42
  - name: Run test
43
  run: python -m pytest -rs tests
 
37
  run: sudo apt-get update && sudo apt-get install -y git ffmpeg
38
 
39
  - name: Install dependencies
40
+ run: pip install -r requirements.txt pytest jiwer
41
 
42
  - name: Run test
43
  run: python -m pytest -rs tests
modules/whisper/base_transcription_pipeline.py CHANGED
@@ -179,7 +179,7 @@ class BaseTranscriptionPipeline(ABC):
179
  add_timestamp: bool = True,
180
  progress=gr.Progress(),
181
  *pipeline_params,
182
- ) -> list:
183
  """
184
  Write subtitle file from Files
185
 
@@ -250,7 +250,7 @@ class BaseTranscriptionPipeline(ABC):
250
  result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
251
  result_file_path = [info['path'] for info in files_info.values()]
252
 
253
- return [result_str, result_file_path]
254
 
255
  except Exception as e:
256
  print(f"Error transcribing file: {e}")
@@ -264,7 +264,7 @@ class BaseTranscriptionPipeline(ABC):
264
  add_timestamp: bool = True,
265
  progress=gr.Progress(),
266
  *pipeline_params,
267
- ) -> list:
268
  """
269
  Write subtitle file from microphone
270
 
@@ -314,7 +314,7 @@ class BaseTranscriptionPipeline(ABC):
314
  )
315
 
316
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
317
- return [result_str, file_path]
318
  except Exception as e:
319
  print(f"Error transcribing mic: {e}")
320
  raise
@@ -327,7 +327,7 @@ class BaseTranscriptionPipeline(ABC):
327
  add_timestamp: bool = True,
328
  progress=gr.Progress(),
329
  *pipeline_params,
330
- ) -> list:
331
  """
332
  Write subtitle file from Youtube
333
 
@@ -385,7 +385,7 @@ class BaseTranscriptionPipeline(ABC):
385
  if os.path.exists(audio):
386
  os.remove(audio)
387
 
388
- return [result_str, file_path]
389
 
390
  except Exception as e:
391
  print(f"Error transcribing youtube: {e}")
 
179
  add_timestamp: bool = True,
180
  progress=gr.Progress(),
181
  *pipeline_params,
182
+ ) -> Tuple[str, List]:
183
  """
184
  Write subtitle file from Files
185
 
 
250
  result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
251
  result_file_path = [info['path'] for info in files_info.values()]
252
 
253
+ return result_str, result_file_path
254
 
255
  except Exception as e:
256
  print(f"Error transcribing file: {e}")
 
264
  add_timestamp: bool = True,
265
  progress=gr.Progress(),
266
  *pipeline_params,
267
+ ) -> Tuple[str, str]:
268
  """
269
  Write subtitle file from microphone
270
 
 
314
  )
315
 
316
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
317
+ return result_str, file_path
318
  except Exception as e:
319
  print(f"Error transcribing mic: {e}")
320
  raise
 
327
  add_timestamp: bool = True,
328
  progress=gr.Progress(),
329
  *pipeline_params,
330
+ ) -> Tuple[str, str]:
331
  """
332
  Write subtitle file from Youtube
333
 
 
385
  if os.path.exists(audio):
386
  os.remove(audio)
387
 
388
+ return result_str, file_path
389
 
390
  except Exception as e:
391
  print(f"Error transcribing youtube: {e}")
tests/test_config.py CHANGED
@@ -1,15 +1,16 @@
1
  import functools
 
 
 
2
 
3
  from modules.utils.paths import *
4
  from modules.utils.youtube_manager import *
5
 
6
- import os
7
- import torch
8
-
9
  TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
10
  TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
 
11
  TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
12
- TEST_WHISPER_MODEL = "tiny.en"
13
  TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
14
  TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
15
  TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
@@ -34,3 +35,6 @@ def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL):
34
  print(f"Pytube has detected as a bot: {e}")
35
  return True
36
 
 
 
 
 
1
  import functools
2
+ import jiwer
3
+ import os
4
+ import torch
5
 
6
  from modules.utils.paths import *
7
  from modules.utils.youtube_manager import *
8
 
 
 
 
9
  TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
10
  TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
11
+ TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
12
  TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
13
+ TEST_WHISPER_MODEL = "tiny"
14
  TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
15
  TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
16
  TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
 
35
  print(f"Pytube has detected as a bot: {e}")
36
  return True
37
 
38
+
39
+ def calculate_wer(answer, prediction):
40
+ return jiwer.wer(answer, prediction)
tests/test_transcription.py CHANGED
@@ -1,5 +1,6 @@
1
  from modules.whisper.whisper_factory import WhisperFactory
2
  from modules.whisper.data_classes import *
 
3
  from modules.utils.paths import WEBUI_DIR
4
  from test_config import *
5
 
@@ -28,6 +29,10 @@ def test_transcribe(
28
  if not os.path.exists(audio_path):
29
  download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
30
 
 
 
 
 
31
  whisper_inferencer = WhisperFactory.create_whisper_inference(
32
  whisper_type=whisper_type,
33
  )
@@ -54,7 +59,7 @@ def test_transcribe(
54
  ),
55
  ).to_list()
56
 
57
- subtitle_str, file_path = whisper_inferencer.transcribe_file(
58
  [audio_path],
59
  None,
60
  "SRT",
@@ -62,12 +67,11 @@ def test_transcribe(
62
  gr.Progress(),
63
  *hparams,
64
  )
65
-
66
- assert isinstance(subtitle_str, str) and subtitle_str
67
- assert isinstance(file_path[0], str) and file_path
68
 
69
  if not is_pytube_detected_bot():
70
- whisper_inferencer.transcribe_youtube(
71
  TEST_YOUTUBE_URL,
72
  "SRT",
73
  False,
@@ -75,17 +79,17 @@ def test_transcribe(
75
  *hparams,
76
  )
77
  assert isinstance(subtitle_str, str) and subtitle_str
78
- assert isinstance(file_path[0], str) and file_path
79
 
80
- whisper_inferencer.transcribe_mic(
81
  audio_path,
82
  "SRT",
83
  False,
84
  gr.Progress(),
85
  *hparams,
86
  )
87
- assert isinstance(subtitle_str, str) and subtitle_str
88
- assert isinstance(file_path[0], str) and file_path
89
 
90
 
91
  def download_file(url, save_dir):
 
1
  from modules.whisper.whisper_factory import WhisperFactory
2
  from modules.whisper.data_classes import *
3
+ from modules.utils.subtitle_manager import read_file
4
  from modules.utils.paths import WEBUI_DIR
5
  from test_config import *
6
 
 
29
  if not os.path.exists(audio_path):
30
  download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
31
 
32
+ answer = TEST_ANSWER
33
+ if diarization:
34
+ answer = "SPEAKER_00|"+TEST_ANSWER
35
+
36
  whisper_inferencer = WhisperFactory.create_whisper_inference(
37
  whisper_type=whisper_type,
38
  )
 
59
  ),
60
  ).to_list()
61
 
62
+ subtitle_str, file_paths = whisper_inferencer.transcribe_file(
63
  [audio_path],
64
  None,
65
  "SRT",
 
67
  gr.Progress(),
68
  *hparams,
69
  )
70
+ subtitle = read_file(file_paths[0]).split("\n")
71
+ assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
 
72
 
73
  if not is_pytube_detected_bot():
74
+ subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
75
  TEST_YOUTUBE_URL,
76
  "SRT",
77
  False,
 
79
  *hparams,
80
  )
81
  assert isinstance(subtitle_str, str) and subtitle_str
82
+ assert os.path.exists(file_path)
83
 
84
+ subtitle_str, file_path = whisper_inferencer.transcribe_mic(
85
  audio_path,
86
  "SRT",
87
  False,
88
  gr.Progress(),
89
  *hparams,
90
  )
91
+ subtitle = read_file(file_path).split("\n")
92
+ assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
93
 
94
 
95
  def download_file(url, save_dir):