Spaces:
Running
Running
Merge pull request #330 from jhj0517/fix/compute-type
Browse files
app.py
CHANGED
@@ -298,8 +298,7 @@ class App:
|
|
298 |
tb_api_key = gr.Textbox(label="Your Auth Key (API KEY)", value=deepl_params["api_key"])
|
299 |
with gr.Row():
|
300 |
dd_source_lang = gr.Dropdown(label="Source Language", value=deepl_params["source_lang"],
|
301 |
-
|
302 |
-
self.deepl_api.available_source_langs.keys()))
|
303 |
dd_target_lang = gr.Dropdown(label="Target Language", value=deepl_params["target_lang"],
|
304 |
choices=list(self.deepl_api.available_target_langs.keys()))
|
305 |
with gr.Row():
|
|
|
298 |
tb_api_key = gr.Textbox(label="Your Auth Key (API KEY)", value=deepl_params["api_key"])
|
299 |
with gr.Row():
|
300 |
dd_source_lang = gr.Dropdown(label="Source Language", value=deepl_params["source_lang"],
|
301 |
+
choices=list(self.deepl_api.available_source_langs.keys()))
|
|
|
302 |
dd_target_lang = gr.Dropdown(label="Target Language", value=deepl_params["target_lang"],
|
303 |
choices=list(self.deepl_api.available_target_langs.keys()))
|
304 |
with gr.Row():
|
modules/translation/nllb_inference.py
CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
|
|
3 |
import os
|
4 |
|
5 |
from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
|
6 |
-
|
7 |
|
8 |
|
9 |
-
class NLLBInference(TranslationBase):
|
10 |
def __init__(self,
|
11 |
model_dir: str = NLLB_MODELS_DIR,
|
12 |
output_dir: str = TRANSLATION_OUTPUT_DIR
|
@@ -29,7 +29,7 @@ class NLLBInference(TranslationBase):
|
|
29 |
text,
|
30 |
max_length=max_length
|
31 |
)
|
32 |
-
return result[0][
|
33 |
|
34 |
def update_model(self,
|
35 |
model_size: str,
|
@@ -41,8 +41,7 @@ class NLLBInference(TranslationBase):
|
|
41 |
if lang in NLLB_AVAILABLE_LANGS:
|
42 |
return NLLB_AVAILABLE_LANGS[lang]
|
43 |
elif lang not in NLLB_AVAILABLE_LANGS.values():
|
44 |
-
raise ValueError(
|
45 |
-
f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
|
46 |
return lang
|
47 |
|
48 |
src_lang = validate_language(src_lang)
|
|
|
3 |
import os
|
4 |
|
5 |
from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
|
6 |
+
import modules.translation.translation_base as base
|
7 |
|
8 |
|
9 |
+
class NLLBInference(base.TranslationBase):
|
10 |
def __init__(self,
|
11 |
model_dir: str = NLLB_MODELS_DIR,
|
12 |
output_dir: str = TRANSLATION_OUTPUT_DIR
|
|
|
29 |
text,
|
30 |
max_length=max_length
|
31 |
)
|
32 |
+
return result[0]["translation_text"]
|
33 |
|
34 |
def update_model(self,
|
35 |
model_size: str,
|
|
|
41 |
if lang in NLLB_AVAILABLE_LANGS:
|
42 |
return NLLB_AVAILABLE_LANGS[lang]
|
43 |
elif lang not in NLLB_AVAILABLE_LANGS.values():
|
44 |
+
raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
|
|
|
45 |
return lang
|
46 |
|
47 |
src_lang = validate_language(src_lang)
|
modules/translation/translation_base.py
CHANGED
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
|
5 |
from typing import List
|
6 |
from datetime import datetime
|
7 |
|
|
|
8 |
from modules.whisper.whisper_parameter import *
|
9 |
from modules.utils.subtitle_manager import *
|
10 |
from modules.utils.files_manager import load_yaml, save_yaml
|
@@ -166,11 +167,17 @@ class TranslationBase(ABC):
|
|
166 |
tgt_lang: str,
|
167 |
max_length: int,
|
168 |
add_timestamp: bool):
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
170 |
cached_params["translation"]["nllb"] = {
|
171 |
"model_size": model_size,
|
172 |
-
"source_lang": src_lang,
|
173 |
-
"target_lang": tgt_lang,
|
174 |
"max_length": max_length,
|
175 |
}
|
176 |
cached_params["translation"]["add_timestamp"] = add_timestamp
|
|
|
5 |
from typing import List
|
6 |
from datetime import datetime
|
7 |
|
8 |
+
import modules.translation.nllb_inference as nllb
|
9 |
from modules.whisper.whisper_parameter import *
|
10 |
from modules.utils.subtitle_manager import *
|
11 |
from modules.utils.files_manager import load_yaml, save_yaml
|
|
|
167 |
tgt_lang: str,
|
168 |
max_length: int,
|
169 |
add_timestamp: bool):
|
170 |
+
def validate_lang(lang: str):
|
171 |
+
if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
|
172 |
+
flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
|
173 |
+
return flipped[lang]
|
174 |
+
return lang
|
175 |
+
|
176 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
177 |
cached_params["translation"]["nllb"] = {
|
178 |
"model_size": model_size,
|
179 |
+
"source_lang": validate_lang(src_lang),
|
180 |
+
"target_lang": validate_lang(tgt_lang),
|
181 |
"max_length": max_length,
|
182 |
}
|
183 |
cached_params["translation"]["add_timestamp"] = add_timestamp
|
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -35,8 +35,6 @@ class FasterWhisperInference(WhisperBase):
|
|
35 |
self.model_paths = self.get_model_paths()
|
36 |
self.device = self.get_device()
|
37 |
self.available_models = self.model_paths.keys()
|
38 |
-
self.available_compute_types = ctranslate2.get_supported_compute_types(
|
39 |
-
"cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
|
40 |
|
41 |
def transcribe(self,
|
42 |
audio: Union[str, BinaryIO, np.ndarray],
|
|
|
35 |
self.model_paths = self.get_model_paths()
|
36 |
self.device = self.get_device()
|
37 |
self.available_models = self.model_paths.keys()
|
|
|
|
|
38 |
|
39 |
def transcribe(self,
|
40 |
audio: Union[str, BinaryIO, np.ndarray],
|
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -35,7 +35,6 @@ class InsanelyFastWhisperInference(WhisperBase):
|
|
35 |
openai_models = whisper.available_models()
|
36 |
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
37 |
self.available_models = openai_models + distil_models
|
38 |
-
self.available_compute_types = ["float16"]
|
39 |
|
40 |
def transcribe(self,
|
41 |
audio: Union[str, np.ndarray, torch.Tensor],
|
|
|
35 |
openai_models = whisper.available_models()
|
36 |
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
37 |
self.available_models = openai_models + distil_models
|
|
|
38 |
|
39 |
def transcribe(self,
|
40 |
audio: Union[str, np.ndarray, torch.Tensor],
|
modules/whisper/whisper_base.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
import whisper
|
|
|
4 |
import gradio as gr
|
5 |
import torchaudio
|
6 |
from abc import ABC, abstractmethod
|
@@ -47,8 +48,8 @@ class WhisperBase(ABC):
|
|
47 |
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
|
48 |
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
|
49 |
self.device = self.get_device()
|
50 |
-
self.available_compute_types =
|
51 |
-
self.current_compute_type =
|
52 |
|
53 |
@abstractmethod
|
54 |
def transcribe(self,
|
@@ -371,6 +372,20 @@ class WhisperBase(ABC):
|
|
371 |
finally:
|
372 |
self.release_cuda_memory()
|
373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
@staticmethod
|
375 |
def generate_and_write_file(file_name: str,
|
376 |
transcribed_segments: list,
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import whisper
|
4 |
+
import ctranslate2
|
5 |
import gradio as gr
|
6 |
import torchaudio
|
7 |
from abc import ABC, abstractmethod
|
|
|
48 |
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
|
49 |
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
|
50 |
self.device = self.get_device()
|
51 |
+
self.available_compute_types = self.get_available_compute_type()
|
52 |
+
self.current_compute_type = self.get_compute_type()
|
53 |
|
54 |
@abstractmethod
|
55 |
def transcribe(self,
|
|
|
372 |
finally:
|
373 |
self.release_cuda_memory()
|
374 |
|
375 |
+
def get_compute_type(self):
|
376 |
+
if "float16" in self.available_compute_types:
|
377 |
+
return "float16"
|
378 |
+
if "float32" in self.available_compute_types:
|
379 |
+
return "float32"
|
380 |
+
else:
|
381 |
+
return self.available_compute_types[0]
|
382 |
+
|
383 |
+
def get_available_compute_type(self):
|
384 |
+
if self.device == "cuda":
|
385 |
+
return list(ctranslate2.get_supported_compute_types("cuda"))
|
386 |
+
else:
|
387 |
+
return list(ctranslate2.get_supported_compute_types("cpu"))
|
388 |
+
|
389 |
@staticmethod
|
390 |
def generate_and_write_file(file_name: str,
|
391 |
transcribed_segments: list,
|