jhj0517 committed
Commit 6ae85bd · unverified · 2 Parent(s): bc6b2e9 87cbb02

Merge pull request #330 from jhj0517/fix/compute-type

app.py CHANGED
@@ -298,8 +298,7 @@ class App:
                     tb_api_key = gr.Textbox(label="Your Auth Key (API KEY)", value=deepl_params["api_key"])
                     with gr.Row():
                         dd_source_lang = gr.Dropdown(label="Source Language", value=deepl_params["source_lang"],
-                                                     choices=list(
-                                                         self.deepl_api.available_source_langs.keys()))
+                                                     choices=list(self.deepl_api.available_source_langs.keys()))
                         dd_target_lang = gr.Dropdown(label="Target Language", value=deepl_params["target_lang"],
                                                      choices=list(self.deepl_api.available_target_langs.keys()))
                     with gr.Row():
modules/translation/nllb_inference.py CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
 import os
 
 from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
-from modules.translation.translation_base import TranslationBase
+import modules.translation.translation_base as base
 
 
-class NLLBInference(TranslationBase):
+class NLLBInference(base.TranslationBase):
     def __init__(self,
                  model_dir: str = NLLB_MODELS_DIR,
                  output_dir: str = TRANSLATION_OUTPUT_DIR
@@ -29,7 +29,7 @@ class NLLBInference(TranslationBase):
             text,
             max_length=max_length
         )
-        return result[0]['translation_text']
+        return result[0]["translation_text"]
 
     def update_model(self,
                      model_size: str,
@@ -41,8 +41,7 @@ class NLLBInference(TranslationBase):
             if lang in NLLB_AVAILABLE_LANGS:
                 return NLLB_AVAILABLE_LANGS[lang]
             elif lang not in NLLB_AVAILABLE_LANGS.values():
-                raise ValueError(
-                    f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+                raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
             return lang
 
         src_lang = validate_language(src_lang)
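
Note: swapping `from modules.translation.translation_base import TranslationBase` for `import modules.translation.translation_base as base` is the usual way to tolerate the circular dependency this merge introduces (translation_base.py now imports nllb_inference, see the next file): the module-alias form defers the attribute lookup until the class statement actually executes. A minimal sketch of the pattern, using hypothetical modules base_mod/impl_mod; it only resolves cleanly when impl_mod is the side imported first, which is presumably the case here since app.py drives NLLBInference.

# base_mod.py -- stands in for translation_base.py
import impl_mod                      # circular partner; only *used* at call time

class Base:
    def lang_code(self, name: str) -> str:
        # Attribute access happens after both modules have finished loading.
        return impl_mod.AVAILABLE_LANGS[name]

# impl_mod.py -- stands in for nllb_inference.py
import base_mod                      # binds the module object even mid-import;
                                     # "from base_mod import Base" could raise here

AVAILABLE_LANGS = {"English": "eng_Latn"}

class Impl(base_mod.Base):           # base_mod.Base is looked up when this line runs
    pass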
modules/translation/translation_base.py CHANGED
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List
 from datetime import datetime
 
+import modules.translation.nllb_inference as nllb
 from modules.whisper.whisper_parameter import *
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
@@ -166,11 +167,17 @@ class TranslationBase(ABC):
                            tgt_lang: str,
                            max_length: int,
                            add_timestamp: bool):
+        def validate_lang(lang: str):
+            if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
+                flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
+                return flipped[lang]
+            return lang
+
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
         cached_params["translation"]["nllb"] = {
             "model_size": model_size,
-            "source_lang": src_lang,
-            "target_lang": tgt_lang,
+            "source_lang": validate_lang(src_lang),
+            "target_lang": validate_lang(tgt_lang),
             "max_length": max_length,
         }
         cached_params["translation"]["add_timestamp"] = add_timestamp
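
Note: the nested validate_lang normalizes what gets written to the cache: if it receives an NLLB language code (a dict value such as "eng_Latn"), it flips NLLB_AVAILABLE_LANGS and stores the display name (the key) instead, so the cached YAML stays consistent with what the dropdowns expect. A standalone sketch of the same idea, with a hypothetical two-entry table:

NLLB_AVAILABLE_LANGS = {"English": "eng_Latn", "Korean": "kor_Hang"}  # hypothetical subset

def validate_lang(lang: str) -> str:
    # A code (a dict value) is flipped back to its display name (the key);
    # anything else is assumed to already be a display name and passes through.
    if lang in NLLB_AVAILABLE_LANGS.values():
        flipped = {value: key for key, value in NLLB_AVAILABLE_LANGS.items()}
        return flipped[lang]
    return lang

assert validate_lang("eng_Latn") == "English"  # code -> display name
assert validate_lang("Korean") == "Korean"     # display name is unchanged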
modules/whisper/faster_whisper_inference.py CHANGED
@@ -35,8 +35,6 @@ class FasterWhisperInference(WhisperBase):
         self.model_paths = self.get_model_paths()
         self.device = self.get_device()
         self.available_models = self.model_paths.keys()
-        self.available_compute_types = ctranslate2.get_supported_compute_types(
-            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
 
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -35,7 +35,6 @@ class InsanelyFastWhisperInference(WhisperBase):
         openai_models = whisper.available_models()
         distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
         self.available_models = openai_models + distil_models
-        self.available_compute_types = ["float16"]
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
modules/whisper/whisper_base.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import torch
 import whisper
+import ctranslate2
 import gradio as gr
 import torchaudio
 from abc import ABC, abstractmethod
@@ -47,8 +48,8 @@ class WhisperBase(ABC):
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
         self.device = self.get_device()
-        self.available_compute_types = ["float16", "float32"]
-        self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+        self.available_compute_types = self.get_available_compute_type()
+        self.current_compute_type = self.get_compute_type()
 
     @abstractmethod
     def transcribe(self,
@@ -371,6 +372,20 @@ class WhisperBase(ABC):
         finally:
             self.release_cuda_memory()
 
+    def get_compute_type(self):
+        if "float16" in self.available_compute_types:
+            return "float16"
+        if "float32" in self.available_compute_types:
+            return "float32"
+        else:
+            return self.available_compute_types[0]
+
+    def get_available_compute_type(self):
+        if self.device == "cuda":
+            return list(ctranslate2.get_supported_compute_types("cuda"))
+        else:
+            return list(ctranslate2.get_supported_compute_types("cpu"))
+
     @staticmethod
     def generate_and_write_file(file_name: str,
                                 transcribed_segments: list,
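
Note: this hunk is the substance of the fix/compute-type branch. Instead of hard-coding ["float16", "float32"] in the base class (and per-subclass lists in the two files above), presumably brittle on hardware whose backend supports neither, the base class now asks ctranslate2 which compute types the device actually supports and prefers float16, then float32, then the first entry offered. A standalone sketch of the same selection logic; ctranslate2.get_supported_compute_types("cuda"/"cpu") is the real ctranslate2 call used in the diff:

import ctranslate2
import torch

def pick_compute_type(device: str) -> str:
    # Ask ctranslate2 which compute types this device's backend supports,
    # e.g. {"int8", "float32"} on a CPU without FP16 kernels.
    supported = list(ctranslate2.get_supported_compute_types(device))
    # Prefer float16 for speed, fall back to float32, else take what exists.
    if "float16" in supported:
        return "float16"
    if "float32" in supported:
        return "float32"
    return supported[0]

device = "cuda" if torch.cuda.is_available() else "cpu"
print(pick_compute_type(device))  # e.g. "float16" on most modern GPUs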