Spaces:
Running
Running
jhj0517
committed on
Commit
·
3be2b51
1
Parent(s):
db355d3
Add `from_list()` to use in gradio function
Browse files- modules/whisper/data_classes.py +46 -18
modules/whisper/data_classes.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Optional, Dict, List
|
|
4 |
from pydantic import BaseModel, Field, field_validator
|
5 |
from gradio_i18n import Translate, gettext as _
|
6 |
from enum import Enum
|
|
|
7 |
import yaml
|
8 |
|
9 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
@@ -15,7 +16,20 @@ class WhisperImpl(Enum):
|
|
15 |
INSANELY_FAST_WHISPER = "insanely_fast_whisper"
|
16 |
|
17 |
|
18 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
"""Voice Activity Detection parameters"""
|
20 |
vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
|
21 |
threshold: float = Field(
|
@@ -45,9 +59,6 @@ class VadParams(BaseModel):
|
|
45 |
description="Padding added to each side of speech chunks"
|
46 |
)
|
47 |
|
48 |
-
def to_dict(self) -> Dict:
|
49 |
-
return self.model_dump()
|
50 |
-
|
51 |
@classmethod
|
52 |
def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
|
53 |
defaults = defaults or {}
|
@@ -74,8 +85,7 @@ class VadParams(BaseModel):
|
|
74 |
]
|
75 |
|
76 |
|
77 |
-
|
78 |
-
class DiarizationParams(BaseModel):
|
79 |
"""Speaker diarization parameters"""
|
80 |
is_diarize: bool = Field(default=False, description="Enable speaker diarization")
|
81 |
hf_token: str = Field(
|
@@ -83,9 +93,6 @@ class DiarizationParams(BaseModel):
|
|
83 |
description="Hugging Face token for downloading diarization models"
|
84 |
)
|
85 |
|
86 |
-
def to_dict(self) -> Dict:
|
87 |
-
return self.model_dump()
|
88 |
-
|
89 |
@classmethod
|
90 |
def to_gradio_inputs(cls,
|
91 |
defaults: Optional[Dict] = None,
|
@@ -112,7 +119,7 @@ class DiarizationParams(BaseModel):
|
|
112 |
]
|
113 |
|
114 |
|
115 |
-
class BGMSeparationParams(
|
116 |
"""Background music separation parameters"""
|
117 |
is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
|
118 |
model_size: str = Field(
|
@@ -133,9 +140,6 @@ class BGMSeparationParams(BaseModel):
|
|
133 |
description="Offload UVR model after transcription"
|
134 |
)
|
135 |
|
136 |
-
def to_dict(self) -> Dict:
|
137 |
-
return self.model_dump()
|
138 |
-
|
139 |
@classmethod
|
140 |
def to_gradio_input(cls,
|
141 |
defaults: Optional[Dict] = None,
|
@@ -181,7 +185,7 @@ class BGMSeparationParams(BaseModel):
|
|
181 |
]
|
182 |
|
183 |
|
184 |
-
class WhisperParams(
|
185 |
"""Whisper parameters"""
|
186 |
model_size: str = Field(default="large-v2", description="Whisper model size")
|
187 |
lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
|
@@ -262,9 +266,6 @@ class WhisperParams(BaseModel):
|
|
262 |
description="Number of segments for language detection"
|
263 |
)
|
264 |
|
265 |
-
def to_dict(self):
|
266 |
-
return self.model_dump()
|
267 |
-
|
268 |
@field_validator('lang')
|
269 |
def validate_lang(cls, v):
|
270 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
@@ -485,9 +486,36 @@ class TranscriptionPipelineParams(BaseModel):
|
|
485 |
}
|
486 |
return data
|
487 |
|
488 |
-
def
|
|
|
|
|
|
|
|
|
|
|
489 |
whisper_list = [value for key, value in self.whisper.to_dict().items()]
|
490 |
vad_list = [value for key, value in self.vad.to_dict().items()]
|
491 |
diarization_list = [value for key, value in self.vad.to_dict().items()]
|
492 |
bgm_sep_list = [value for key, value in self.bgm_separation.to_dict().items()]
|
493 |
return whisper_list + vad_list + diarization_list + bgm_sep_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from pydantic import BaseModel, Field, field_validator
|
5 |
from gradio_i18n import Translate, gettext as _
|
6 |
from enum import Enum
|
7 |
+
from copy import deepcopy
|
8 |
import yaml
|
9 |
|
10 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
|
|
16 |
INSANELY_FAST_WHISPER = "insanely_fast_whisper"
|
17 |
|
18 |
|
19 |
+
class BaseParams(BaseModel):
    """Common base for the parameter models, adding dict/list conversion helpers."""

    def to_dict(self) -> Dict:
        """Return the model's fields as a dictionary, as produced by ``model_dump()``."""
        dumped = self.model_dump()
        return dumped

    def to_list(self) -> List:
        """Return only the field values, in the order ``model_dump()`` yields them."""
        return [value for value in self.model_dump().values()]

    @classmethod
    def from_list(cls, data_list: List) -> 'BaseParams':
        """
        Build an instance by pairing ``data_list`` positionally with the model's field names.

        Args:
            data_list: Values in the same order as ``cls.model_fields``.

        Returns:
            A new instance of ``cls``.
        """
        # zip() stops at the shorter input, so surplus values or field names are ignored here.
        kwargs = {name: value for name, value in zip(cls.model_fields, data_list)}
        return cls(**kwargs)
|
30 |
+
|
31 |
+
|
32 |
+
class VadParams(BaseParams):
|
33 |
"""Voice Activity Detection parameters"""
|
34 |
vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
|
35 |
threshold: float = Field(
|
|
|
59 |
description="Padding added to each side of speech chunks"
|
60 |
)
|
61 |
|
|
|
|
|
|
|
62 |
@classmethod
|
63 |
def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
|
64 |
defaults = defaults or {}
|
|
|
85 |
]
|
86 |
|
87 |
|
88 |
+
class DiarizationParams(BaseParams):
|
|
|
89 |
"""Speaker diarization parameters"""
|
90 |
is_diarize: bool = Field(default=False, description="Enable speaker diarization")
|
91 |
hf_token: str = Field(
|
|
|
93 |
description="Hugging Face token for downloading diarization models"
|
94 |
)
|
95 |
|
|
|
|
|
|
|
96 |
@classmethod
|
97 |
def to_gradio_inputs(cls,
|
98 |
defaults: Optional[Dict] = None,
|
|
|
119 |
]
|
120 |
|
121 |
|
122 |
+
class BGMSeparationParams(BaseParams):
|
123 |
"""Background music separation parameters"""
|
124 |
is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
|
125 |
model_size: str = Field(
|
|
|
140 |
description="Offload UVR model after transcription"
|
141 |
)
|
142 |
|
|
|
|
|
|
|
143 |
@classmethod
|
144 |
def to_gradio_input(cls,
|
145 |
defaults: Optional[Dict] = None,
|
|
|
185 |
]
|
186 |
|
187 |
|
188 |
+
class WhisperParams(BaseParams):
|
189 |
"""Whisper parameters"""
|
190 |
model_size: str = Field(default="large-v2", description="Whisper model size")
|
191 |
lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
|
|
|
266 |
description="Number of segments for language detection"
|
267 |
)
|
268 |
|
|
|
|
|
|
|
269 |
@field_validator('lang')
|
270 |
def validate_lang(cls, v):
|
271 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
|
|
486 |
}
|
487 |
return data
|
488 |
|
489 |
+
def to_list(self) -> List:
    """
    Flatten every sub-model of the pipeline into one list of values.

    The parameters have to be passed as a flat list in Gradio.
    Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
    See more about Gradio pre-processing: https://www.gradio.app/docs/components

    Returns:
        The values of ``whisper``, ``vad``, ``diarization`` and ``bgm_separation``,
        concatenated in that order.
    """
    whisper_list = self.whisper.to_list()
    vad_list = self.vad.to_list()
    # Must read from `diarization`; reading from `vad` here would duplicate the VAD
    # values and drop the diarization settings from the flattened list.
    diarization_list = self.diarization.to_list()
    bgm_sep_list = self.bgm_separation.to_list()
    return whisper_list + vad_list + diarization_list + bgm_sep_list
|
500 |
+
|
501 |
+
@staticmethod
def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
    """
    Convert a flat list back into the data class so it can be used in a function.

    Args:
        pipeline_list: Values ordered as whisper, vad, diarization, bgm_separation,
            i.e. the layout returned by ``to_list()``.

    Returns:
        A ``TranscriptionPipelineParams`` whose sub-models are rebuilt from
        consecutive slices of ``pipeline_list``.
    """
    # Work on a copy so the caller's list (and any nested values) stay untouched.
    data_list = deepcopy(pipeline_list)

    # Section order must mirror to_list(). Field counts come from `model_fields`,
    # the same source each class's own from_list() zips against.
    param_classes = (WhisperParams, VadParams, DiarizationParams, BGMSeparationParams)
    sections = []
    offset = 0
    for param_cls in param_classes:
        field_count = len(param_cls.model_fields)
        sections.append(param_cls.from_list(data_list[offset:offset + field_count]))
        offset += field_count

    whisper, vad, diarization, bgm_separation = sections
    return TranscriptionPipelineParams(
        whisper=whisper,
        vad=vad,
        diarization=diarization,
        bgm_separation=bgm_separation
    )
|