aadnk committed on
Commit
44d964a
·
1 Parent(s): 512321e

Add configuration file and support for custom models

Browse files

Custom models can be added to the configuration file,
under the "models" section. See the comments for more
details.

Files changed (8) hide show
  1. .gitignore +1 -0
  2. app.py +37 -17
  3. cli.py +72 -38
  4. config.json5 +62 -0
  5. requirements.txt +3 -1
  6. src/config.py +134 -0
  7. src/conversion/hf_converter.py +67 -0
  8. src/whisperContainer.py +29 -3
.gitignore CHANGED
@@ -1,5 +1,6 @@
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
 
3
  flagged/
4
  *.py[cod]
5
  *$py.class
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
+ .vscode/
4
  flagged/
5
  *.py[cod]
6
  *$py.class
app.py CHANGED
@@ -11,6 +11,7 @@ import zipfile
11
  import numpy as np
12
 
13
  import torch
 
14
  from src.modelCache import ModelCache
15
  from src.source import get_audio_source_collection
16
  from src.vadParallel import ParallelContext, ParallelTranscription
@@ -62,7 +63,8 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large
62
 
63
  class WhisperTranscriber:
64
  def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None,
65
- vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None):
 
66
  self.model_cache = ModelCache()
67
  self.parallel_device_list = None
68
  self.gpu_parallel_context = None
@@ -75,6 +77,8 @@ class WhisperTranscriber:
75
  self.deleteUploadedFiles = delete_uploaded_files
76
  self.output_dir = output_dir
77
 
 
 
78
  def set_parallel_devices(self, vad_parallel_devices: str):
79
  self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
80
 
@@ -115,7 +119,7 @@ class WhisperTranscriber:
115
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
116
  selectedModel = modelName if modelName is not None else "base"
117
 
118
- model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
119
 
120
  # Result
121
  download = []
@@ -360,8 +364,8 @@ class WhisperTranscriber:
360
  def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
361
  default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None,
362
  vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False,
363
- output_dir: str = None):
364
- ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir)
365
 
366
  # Specify a list of devices to use for parallel processing
367
  ui.set_parallel_devices(vad_parallel_devices)
@@ -378,8 +382,10 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
378
 
379
  ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
380
 
 
 
381
  simple_inputs = lambda : [
382
- gr.Dropdown(choices=WHISPER_MODELS, value=default_model_name, label="Model"),
383
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
384
  gr.Text(label="URL (YouTube, etc.)"),
385
  gr.File(label="Upload Files", file_count="multiple"),
@@ -429,18 +435,32 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
429
  ui.close()
430
 
431
  if __name__ == '__main__':
 
 
 
432
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
433
- parser.add_argument("--input_audio_max_duration", type=int, default=DEFAULT_INPUT_AUDIO_MAX_DURATION, help="Maximum audio file length in seconds, or -1 for no limit.")
434
- parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
435
- parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
436
- parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
437
- parser.add_argument("--default_model_name", type=str, choices=WHISPER_MODELS, default="medium", help="The default model name.")
438
- parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
439
- parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
440
- parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
441
- parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
442
- parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
443
- parser.add_argument("--output_dir", "-o", type=str, default=None, help="directory to save the outputs")
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  args = parser.parse_args().__dict__
446
- create_ui(**args)
 
11
  import numpy as np
12
 
13
  import torch
14
+ from src.config import ApplicationConfig
15
  from src.modelCache import ModelCache
16
  from src.source import get_audio_source_collection
17
  from src.vadParallel import ParallelContext, ParallelTranscription
 
63
 
64
  class WhisperTranscriber:
65
  def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None,
66
+ vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None,
67
+ app_config: ApplicationConfig = None):
68
  self.model_cache = ModelCache()
69
  self.parallel_device_list = None
70
  self.gpu_parallel_context = None
 
77
  self.deleteUploadedFiles = delete_uploaded_files
78
  self.output_dir = output_dir
79
 
80
+ self.app_config = app_config
81
+
82
  def set_parallel_devices(self, vad_parallel_devices: str):
83
  self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
84
 
 
119
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
120
  selectedModel = modelName if modelName is not None else "base"
121
 
122
+ model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
123
 
124
  # Result
125
  download = []
 
364
  def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
365
  default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None,
366
  vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False,
367
+ output_dir: str = None, app_config: ApplicationConfig = None):
368
+ ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir, app_config)
369
 
370
  # Specify a list of devices to use for parallel processing
371
  ui.set_parallel_devices(vad_parallel_devices)
 
382
 
383
  ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
384
 
385
+ whisper_models = app_config.get_model_names()
386
+
387
  simple_inputs = lambda : [
388
+ gr.Dropdown(choices=whisper_models, value=default_model_name, label="Model"),
389
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
390
  gr.Text(label="URL (YouTube, etc.)"),
391
  gr.File(label="Upload Files", file_count="multiple"),
 
435
  ui.close()
436
 
437
  if __name__ == '__main__':
438
+ app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
439
+ whisper_models = app_config.get_model_names()
440
+
441
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
442
+ parser.add_argument("--input_audio_max_duration", type=int, default=app_config.input_audio_max_duration, \
443
+ help="Maximum audio file length in seconds, or -1 for no limit.") # 600
444
+ parser.add_argument("--share", type=bool, default=app_config.share, \
445
+ help="True to share the app on HuggingFace.") # False
446
+ parser.add_argument("--server_name", type=str, default=app_config.server_name, \
447
+ help="The host or IP to bind to. If None, bind to localhost.") # None
448
+ parser.add_argument("--server_port", type=int, default=app_config.server_port, \
449
+ help="The port to bind to.") # 7860
450
+ parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=app_config.default_model_name, \
451
+ help="The default model name.") # medium
452
+ parser.add_argument("--default_vad", type=str, default=app_config.default_vad, \
453
+ help="The default VAD.") # silero-vad
454
+ parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
455
+ help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
456
+ parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
457
+ help="The number of CPU cores to use for VAD pre-processing.") # 1
458
+ parser.add_argument("--vad_process_timeout", type=float, default=app_config.vad_process_timeout, \
459
+ help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
460
+ parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
461
+ help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
462
+ parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
463
+ help="directory to save the outputs") # None
464
 
465
  args = parser.parse_args().__dict__
466
+ create_ui(app_config=app_config, **args)
cli.py CHANGED
@@ -6,48 +6,81 @@ import warnings
6
  import numpy as np
7
 
8
  import torch
9
- from app import LANGUAGES, WHISPER_MODELS, WhisperTranscriber
 
10
  from src.download import download_url
11
 
12
  from src.utils import optional_float, optional_int, str2bool
13
  from src.whisperContainer import WhisperContainer
14
 
15
  def cli():
 
 
 
16
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17
- parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
18
- parser.add_argument("--model", default="small", choices=WHISPER_MODELS, help="name of the Whisper model to use")
19
- parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
20
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
21
- parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
22
- parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
23
-
24
- parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
25
- parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
26
-
27
- parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
28
- parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
29
- parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
30
- parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
31
- parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
32
- parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
33
- parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
34
- parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
35
-
36
- parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
37
- parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
38
- parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
39
- parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
40
- parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
41
-
42
- parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
43
- parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
44
- parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
45
- parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
46
-
47
- parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
48
- parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
49
- parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
50
- parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  args = parser.parse_args().__dict__
53
  model_name: str = args.pop("model")
@@ -74,12 +107,13 @@ def cli():
74
  vad_prompt_window = args.pop("vad_prompt_window")
75
  vad_cpu_cores = args.pop("vad_cpu_cores")
76
  auto_parallel = args.pop("auto_parallel")
77
-
78
- model = WhisperContainer(model_name, device=device, download_root=model_dir)
79
- transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
80
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
81
  transcriber.set_auto_parallel(auto_parallel)
82
 
 
 
83
  if (transcriber._has_parallel_devices()):
84
  print("Using parallel devices:", transcriber.parallel_device_list)
85
 
 
6
  import numpy as np
7
 
8
  import torch
9
+ from app import LANGUAGES, WhisperTranscriber
10
+ from src.config import ApplicationConfig
11
  from src.download import download_url
12
 
13
  from src.utils import optional_float, optional_int, str2bool
14
  from src.whisperContainer import WhisperContainer
15
 
16
  def cli():
17
+ app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
18
+ whisper_models = app_config.get_model_names()
19
+
20
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
21
+ parser.add_argument("audio", nargs="+", type=str, \
22
+ help="audio file(s) to transcribe")
23
+ parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
24
+ help="name of the Whisper model to use") # medium
25
+ parser.add_argument("--model_dir", type=str, default=None, \
26
+ help="the path to save model files; uses ~/.cache/whisper by default")
27
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", \
28
+ help="device to use for PyTorch inference")
29
+ parser.add_argument("--output_dir", "-o", type=str, default=".", \
30
+ help="directory to save the outputs")
31
+ parser.add_argument("--verbose", type=str2bool, default=True, \
32
+ help="whether to print out the progress and debug messages")
33
+
34
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], \
35
+ help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
36
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), \
37
+ help="language spoken in the audio, specify None to perform language detection")
38
+
39
+ parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
40
+ help="The voice activity detection algorithm to use") # silero-vad
41
+ parser.add_argument("--vad_merge_window", type=optional_float, default=5, \
42
+ help="The window size (in seconds) to merge voice segments")
43
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=30,\
44
+ help="The maximum size (in seconds) of a voice segment")
45
+ parser.add_argument("--vad_padding", type=optional_float, default=1, \
46
+ help="The padding (in seconds) to add to each voice segment")
47
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=3, \
48
+ help="The window size of the prompt to pass to Whisper")
49
+ parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
50
+ help="The number of CPU cores to use for VAD pre-processing.") # 1
51
+ parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
52
+ help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
53
+ parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
54
+ help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
55
+
56
+ parser.add_argument("--temperature", type=float, default=0, \
57
+ help="temperature to use for sampling")
58
+ parser.add_argument("--best_of", type=optional_int, default=5, \
59
+ help="number of candidates when sampling with non-zero temperature")
60
+ parser.add_argument("--beam_size", type=optional_int, default=5, \
61
+ help="number of beams in beam search, only applicable when temperature is zero")
62
+ parser.add_argument("--patience", type=float, default=None, \
63
+ help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
64
+ parser.add_argument("--length_penalty", type=float, default=None, \
65
+ help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
66
+
67
+ parser.add_argument("--suppress_tokens", type=str, default="-1", \
68
+ help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
69
+ parser.add_argument("--initial_prompt", type=str, default=None, \
70
+ help="optional text to provide as a prompt for the first window.")
71
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, \
72
+ help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
73
+ parser.add_argument("--fp16", type=str2bool, default=True, \
74
+ help="whether to perform inference in fp16; True by default")
75
+
76
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, \
77
+ help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
78
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, \
79
+ help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
80
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, \
81
+ help="if the average log probability is lower than this value, treat the decoding as failed")
82
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, \
83
+ help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
84
 
85
  args = parser.parse_args().__dict__
86
  model_name: str = args.pop("model")
 
107
  vad_prompt_window = args.pop("vad_prompt_window")
108
  vad_cpu_cores = args.pop("vad_cpu_cores")
109
  auto_parallel = args.pop("auto_parallel")
110
+
111
+ transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
 
112
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
113
  transcriber.set_auto_parallel(auto_parallel)
114
 
115
+ model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)
116
+
117
  if (transcriber._has_parallel_devices()):
118
  print("Using parallel devices:", transcriber.parallel_device_list)
119
 
config.json5 ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models": [
3
+ // Configuration for the built-in models. You can remove any of these
4
+ // if you don't want to use the default models.
5
+ {
6
+ "name": "tiny",
7
+ "url": "tiny"
8
+ },
9
+ {
10
+ "name": "base",
11
+ "url": "base"
12
+ },
13
+ {
14
+ "name": "small",
15
+ "url": "small"
16
+ },
17
+ {
18
+ "name": "medium",
19
+ "url": "medium"
20
+ },
21
+ {
22
+ "name": "large",
23
+ "url": "large"
24
+ },
25
+ {
26
+ "name": "large-v2",
27
+ "url": "large-v2"
28
+ },
29
+ // Uncomment to add custom Japanese models
30
+ //{
31
+ // "name": "whisper-large-v2-mix-jp",
32
+ // "url": "vumichien/whisper-large-v2-mix-jp",
33
+ // // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
34
+ // // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
35
+ // "type": "huggingface",
36
+ //}
37
+ ],
38
+ // Configuration options that will be used if they are not specified in the command line arguments.
39
+
40
+ // Maximum audio file length in seconds, or -1 for no limit.
41
+ "input_audio_max_duration": 600,
42
+ // True to share the app on HuggingFace.
43
+ "share": false,
44
+ // The host or IP to bind to. If null, bind to localhost.
45
+ "server_name": null,
46
+ // The port to bind to.
47
+ "server_port": 7860,
48
+ // The default model name.
49
+ "default_model_name": "medium",
50
+ // The default VAD.
51
+ "default_vad": "silero-vad",
52
+ // A comma-delimited list of CUDA devices to use for parallel processing. If empty or null, parallel processing is disabled.
53
+ "vad_parallel_devices": "",
54
+ // The number of CPU cores to use for VAD pre-processing.
55
+ "vad_cpu_cores": 1,
56
+ // The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or null for no timeout.
57
+ "vad_process_timeout": 1800,
58
+ // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
59
+ "auto_parallel": false,
60
+ // Directory to save the outputs
61
+ "output_dir": null
62
+ }
requirements.txt CHANGED
@@ -1,7 +1,9 @@
 
1
  git+https://github.com/openai/whisper.git
2
  transformers
3
  ffmpeg-python==0.2.0
4
  gradio==3.13.0
5
  yt-dlp
6
  torchaudio
7
- altair
 
 
1
+ git+https://github.com/huggingface/transformers
2
  git+https://github.com/openai/whisper.git
3
  transformers
4
  ffmpeg-python==0.2.0
5
  gradio==3.13.0
6
  yt-dlp
7
  torchaudio
8
+ altair
9
+ json5
src/config.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib
2
+
3
+ import os
4
+ from typing import List
5
+ from urllib.parse import urlparse
6
+
7
+ from tqdm import tqdm
8
+
9
+ from src.conversion.hf_converter import convert_hf_whisper
10
+
11
+ class ModelConfig:
12
+ def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
13
+ """
14
+ Initialize a model configuration.
15
+
16
+ name: Name of the model
17
+ url: URL to download the model from
18
+ path: Path to the model file. If not set, the model will be downloaded from the URL.
19
+ type: Type of model. Can be whisper or huggingface.
20
+ """
21
+ self.name = name
22
+ self.url = url
23
+ self.path = path
24
+ self.type = type
25
+
26
+ def download_url(self, root_dir: str):
27
+ import whisper
28
+
29
+ # See if path is already set
30
+ if self.path is not None:
31
+ return self.path
32
+
33
+ if root_dir is None:
34
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
35
+
36
+ model_type = self.type.lower() if self.type is not None else "whisper"
37
+
38
+ if model_type in ["huggingface", "hf"]:
39
+ self.path = self.url
40
+ destination_target = os.path.join(root_dir, self.name + ".pt")
41
+
42
+ # Convert from HuggingFace format to Whisper format
43
+ if os.path.exists(destination_target):
44
+ print(f"File {destination_target} already exists, skipping conversion")
45
+ else:
46
+ print("Saving HuggingFace model in Whisper format to " + destination_target)
47
+ convert_hf_whisper(self.url, destination_target)
48
+
49
+ self.path = destination_target
50
+
51
+ elif model_type in ["whisper", "w"]:
52
+ self.path = self.url
53
+
54
+ # See if URL is just a file
55
+ if self.url in whisper._MODELS:
56
+ # No need to download anything - Whisper will handle it
57
+ self.path = self.url
58
+ elif self.url.startswith("file://"):
59
+ # Get file path
60
+ self.path = urlparse(self.url).path
61
+ # See if it is an URL
62
+ elif self.url.startswith("http://") or self.url.startswith("https://"):
63
+ # Extension (or file name)
64
+ extension = os.path.splitext(self.url)[-1]
65
+ download_target = os.path.join(root_dir, self.name + extension)
66
+
67
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
68
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
69
+
70
+ if not os.path.isfile(download_target):
71
+ self._download_file(self.url, download_target)
72
+ else:
73
+ print(f"File {download_target} already exists, skipping download")
74
+
75
+ self.path = download_target
76
+ # Must be a local file
77
+ else:
78
+ self.path = self.url
79
+
80
+ else:
81
+ raise ValueError(f"Unknown model type {model_type}")
82
+
83
+ return self.path
84
+
85
+ def _download_file(self, url: str, destination: str):
86
+ with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
87
+ with tqdm(
88
+ total=int(source.info().get("Content-Length")),
89
+ ncols=80,
90
+ unit="iB",
91
+ unit_scale=True,
92
+ unit_divisor=1024,
93
+ ) as loop:
94
+ while True:
95
+ buffer = source.read(8192)
96
+ if not buffer:
97
+ break
98
+
99
+ output.write(buffer)
100
+ loop.update(len(buffer))
101
+
102
+ class ApplicationConfig:
103
+ def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
104
+ share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium",
105
+ default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
106
+ auto_parallel: bool = False, output_dir: str = None):
107
+ self.models = models
108
+ self.input_audio_max_duration = input_audio_max_duration
109
+ self.share = share
110
+ self.server_name = server_name
111
+ self.server_port = server_port
112
+ self.default_model_name = default_model_name
113
+ self.default_vad = default_vad
114
+ self.vad_parallel_devices = vad_parallel_devices
115
+ self.vad_cpu_cores = vad_cpu_cores
116
+ self.vad_process_timeout = vad_process_timeout
117
+ self.auto_parallel = auto_parallel
118
+ self.output_dir = output_dir
119
+
120
+ def get_model_names(self):
121
+ return [ x.name for x in self.models ]
122
+
123
+ @staticmethod
124
+ def parse_file(config_path: str):
125
+ import json5
126
+
127
+ with open(config_path, "r") as f:
128
+ # Load using json5
129
+ data = json5.load(f)
130
+ data_models = data.pop("models", [])
131
+
132
+ models = [ ModelConfig(**x) for x in data_models ]
133
+
134
+ return ApplicationConfig(models, **data)
src/conversion/hf_converter.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
2
+
3
+ from copy import deepcopy
4
+ import torch
5
+ from transformers import WhisperForConditionalGeneration
6
+
7
+ WHISPER_MAPPING = {
8
+ "layers": "blocks",
9
+ "fc1": "mlp.0",
10
+ "fc2": "mlp.2",
11
+ "final_layer_norm": "mlp_ln",
12
+ "layers": "blocks",
13
+ ".self_attn.q_proj": ".attn.query",
14
+ ".self_attn.k_proj": ".attn.key",
15
+ ".self_attn.v_proj": ".attn.value",
16
+ ".self_attn_layer_norm": ".attn_ln",
17
+ ".self_attn.out_proj": ".attn.out",
18
+ ".encoder_attn.q_proj": ".cross_attn.query",
19
+ ".encoder_attn.k_proj": ".cross_attn.key",
20
+ ".encoder_attn.v_proj": ".cross_attn.value",
21
+ ".encoder_attn_layer_norm": ".cross_attn_ln",
22
+ ".encoder_attn.out_proj": ".cross_attn.out",
23
+ "decoder.layer_norm.": "decoder.ln.",
24
+ "encoder.layer_norm.": "encoder.ln_post.",
25
+ "embed_tokens": "token_embedding",
26
+ "encoder.embed_positions.weight": "encoder.positional_embedding",
27
+ "decoder.embed_positions.weight": "decoder.positional_embedding",
28
+ "layer_norm": "ln_post",
29
+ }
30
+
31
+
32
+ def rename_keys(s_dict):
33
+ keys = list(s_dict.keys())
34
+ for key in keys:
35
+ new_key = key
36
+ for k, v in WHISPER_MAPPING.items():
37
+ if k in key:
38
+ new_key = new_key.replace(k, v)
39
+
40
+ print(f"{key} -> {new_key}")
41
+
42
+ s_dict[new_key] = s_dict.pop(key)
43
+ return s_dict
44
+
45
+
46
+ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
47
+ transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
48
+ config = transformer_model.config
49
+
50
+ # first build dims
51
+ dims = {
52
+ 'n_mels': config.num_mel_bins,
53
+ 'n_vocab': config.vocab_size,
54
+ 'n_audio_ctx': config.max_source_positions,
55
+ 'n_audio_state': config.d_model,
56
+ 'n_audio_head': config.encoder_attention_heads,
57
+ 'n_audio_layer': config.encoder_layers,
58
+ 'n_text_ctx': config.max_target_positions,
59
+ 'n_text_state': config.d_model,
60
+ 'n_text_head': config.decoder_attention_heads,
61
+ 'n_text_layer': config.decoder_layers
62
+ }
63
+
64
+ state_dict = deepcopy(transformer_model.model.state_dict())
65
+ state_dict = rename_keys(state_dict)
66
+
67
+ torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
src/whisperContainer.py CHANGED
@@ -1,11 +1,14 @@
1
  # External programs
2
  import os
 
3
  import whisper
 
4
 
5
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
6
 
7
  class WhisperContainer:
8
- def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
 
9
  self.model_name = model_name
10
  self.device = device
11
  self.download_root = download_root
@@ -13,6 +16,9 @@ class WhisperContainer:
13
 
14
  # Will be created on demand
15
  self.model = None
 
 
 
16
 
17
  def get_model(self):
18
  if self.model is None:
@@ -32,21 +38,40 @@ class WhisperContainer:
32
  # Warning: Using private API here
33
  try:
34
  root_dir = self.download_root
 
35
 
36
  if root_dir is None:
37
  root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
38
 
39
  if self.model_name in whisper._MODELS:
40
  whisper._download(whisper._MODELS[self.model_name], root_dir, False)
 
 
 
41
  return True
 
42
  except Exception as e:
43
  # Given that the API is private, it could change at any time. We don't want to crash the program
44
  print("Error pre-downloading model: " + str(e))
45
  return False
46
 
 
 
 
 
 
 
 
 
 
47
  def _create_model(self):
48
  print("Loading whisper model " + self.model_name)
49
- return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
 
 
 
 
 
50
 
51
  def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
52
  """
@@ -71,12 +96,13 @@ class WhisperContainer:
71
 
72
  # This is required for multiprocessing
73
  def __getstate__(self):
74
- return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
75
 
76
  def __setstate__(self, state):
77
  self.model_name = state["model_name"]
78
  self.device = state["device"]
79
  self.download_root = state["download_root"]
 
80
  self.model = None
81
  # Depickled objects must use the global cache
82
  self.cache = GLOBAL_MODEL_CACHE
 
1
  # External programs
2
  import os
3
+ from typing import List
4
  import whisper
5
+ from src.config import ModelConfig
6
 
7
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
8
 
9
  class WhisperContainer:
10
+ def __init__(self, model_name: str, device: str = None, download_root: str = None,
11
+ cache: ModelCache = None, models: List[ModelConfig] = []):
12
  self.model_name = model_name
13
  self.device = device
14
  self.download_root = download_root
 
16
 
17
  # Will be created on demand
18
  self.model = None
19
+
20
+ # List of known models
21
+ self.models = models
22
 
23
  def get_model(self):
24
  if self.model is None:
 
38
  # Warning: Using private API here
39
  try:
40
  root_dir = self.download_root
41
+ model_config = self.get_model_config()
42
 
43
  if root_dir is None:
44
  root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
45
 
46
  if self.model_name in whisper._MODELS:
47
  whisper._download(whisper._MODELS[self.model_name], root_dir, False)
48
+ else:
49
+ # If the model is not in the official list, see if it needs to be downloaded
50
+ model_config.download_url(root_dir)
51
  return True
52
+
53
  except Exception as e:
54
  # Given that the API is private, it could change at any time. We don't want to crash the program
55
  print("Error pre-downloading model: " + str(e))
56
  return False
57
 
58
def get_model_config(self) -> ModelConfig:
    """
    Look up the configuration entry for this container's model.

    Returns
    -------
    The first entry in self.models whose name equals self.model_name,
    or None if no entry matches.
    """
    matching = (entry for entry in self.models if entry.name == self.model_name)
    return next(matching, None)
66
+
67
def _create_model(self):
    """
    Load the Whisper model for this container.

    If self.models contains an entry for self.model_name, the value returned
    by that entry's download_url(self.download_root) is handed to
    whisper.load_model. If there is no such entry (for instance when the
    container was created without a model list), the model name itself is
    passed to whisper.load_model instead of failing on a missing
    configuration.

    Returns
    -------
    The model returned by whisper.load_model.
    """
    print("Loading whisper model " + self.model_name)

    model_config = self.get_model_config()

    if model_config is None:
        # No configuration entry for this model: let whisper resolve the name itself
        model_path = self.model_name
    else:
        # Note that the model will not be downloaded in the case of an official Whisper model
        model_path = model_config.download_url(self.download_root)

    return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
75
 
76
  def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
77
  """
 
96
 
97
  # This is required for multiprocessing
98
  def __getstate__(self):
99
+ return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
100
 
101
def __setstate__(self, state):
    """
    Restore a container from the dictionary produced by __getstate__.

    Parameters
    ----------
    state: dict
        Must contain the keys "model_name", "device", "download_root" and "models".
    """
    for field in ("model_name", "device", "download_root", "models"):
        setattr(self, field, state[field])

    # The model is not part of the pickled state, so start without one
    self.model = None
    # Depickled objects must use the global cache
    self.cache = GLOBAL_MODEL_CACHE