jhj0517 commited on
Commit
3be2b51
·
1 Parent(s): db355d3

Add `from_list()` to use in gradio function

Browse files
Files changed (1) hide show
  1. modules/whisper/data_classes.py +46 -18
modules/whisper/data_classes.py CHANGED
@@ -4,6 +4,7 @@ from typing import Optional, Dict, List
4
  from pydantic import BaseModel, Field, field_validator
5
  from gradio_i18n import Translate, gettext as _
6
  from enum import Enum
 
7
  import yaml
8
 
9
  from modules.utils.constants import AUTOMATIC_DETECTION
@@ -15,7 +16,20 @@ class WhisperImpl(Enum):
15
  INSANELY_FAST_WHISPER = "insanely_fast_whisper"
16
 
17
 
18
- class VadParams(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """Voice Activity Detection parameters"""
20
  vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
21
  threshold: float = Field(
@@ -45,9 +59,6 @@ class VadParams(BaseModel):
45
  description="Padding added to each side of speech chunks"
46
  )
47
 
48
- def to_dict(self) -> Dict:
49
- return self.model_dump()
50
-
51
  @classmethod
52
  def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
53
  defaults = defaults or {}
@@ -74,8 +85,7 @@ class VadParams(BaseModel):
74
  ]
75
 
76
 
77
-
78
- class DiarizationParams(BaseModel):
79
  """Speaker diarization parameters"""
80
  is_diarize: bool = Field(default=False, description="Enable speaker diarization")
81
  hf_token: str = Field(
@@ -83,9 +93,6 @@ class DiarizationParams(BaseModel):
83
  description="Hugging Face token for downloading diarization models"
84
  )
85
 
86
- def to_dict(self) -> Dict:
87
- return self.model_dump()
88
-
89
  @classmethod
90
  def to_gradio_inputs(cls,
91
  defaults: Optional[Dict] = None,
@@ -112,7 +119,7 @@ class DiarizationParams(BaseModel):
112
  ]
113
 
114
 
115
- class BGMSeparationParams(BaseModel):
116
  """Background music separation parameters"""
117
  is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
118
  model_size: str = Field(
@@ -133,9 +140,6 @@ class BGMSeparationParams(BaseModel):
133
  description="Offload UVR model after transcription"
134
  )
135
 
136
- def to_dict(self) -> Dict:
137
- return self.model_dump()
138
-
139
  @classmethod
140
  def to_gradio_input(cls,
141
  defaults: Optional[Dict] = None,
@@ -181,7 +185,7 @@ class BGMSeparationParams(BaseModel):
181
  ]
182
 
183
 
184
- class WhisperParams(BaseModel):
185
  """Whisper parameters"""
186
  model_size: str = Field(default="large-v2", description="Whisper model size")
187
  lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
@@ -262,9 +266,6 @@ class WhisperParams(BaseModel):
262
  description="Number of segments for language detection"
263
  )
264
 
265
- def to_dict(self):
266
- return self.model_dump()
267
-
268
  @field_validator('lang')
269
  def validate_lang(cls, v):
270
  from modules.utils.constants import AUTOMATIC_DETECTION
@@ -485,9 +486,36 @@ class TranscriptionPipelineParams(BaseModel):
485
  }
486
  return data
487
 
488
- def as_list(self) -> List:
 
 
 
 
 
489
  whisper_list = [value for key, value in self.whisper.to_dict().items()]
490
  vad_list = [value for key, value in self.vad.to_dict().items()]
491
  diarization_list = [value for key, value in self.vad.to_dict().items()]
492
  bgm_sep_list = [value for key, value in self.bgm_separation.to_dict().items()]
493
  return whisper_list + vad_list + diarization_list + bgm_sep_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from pydantic import BaseModel, Field, field_validator
5
  from gradio_i18n import Translate, gettext as _
6
  from enum import Enum
7
+ from copy import deepcopy
8
  import yaml
9
 
10
  from modules.utils.constants import AUTOMATIC_DETECTION
 
16
  INSANELY_FAST_WHISPER = "insanely_fast_whisper"
17
 
18
 
19
+ class BaseParams(BaseModel):
20
+ def to_dict(self) -> Dict:
21
+ return self.model_dump()
22
+
23
+ def to_list(self) -> List:
24
+ return list(self.model_dump().values())
25
+
26
+ @classmethod
27
+ def from_list(cls, data_list: List) -> 'BaseParams':
28
+ field_names = list(cls.model_fields.keys())
29
+ return cls(**dict(zip(field_names, data_list)))
30
+
31
+
32
+ class VadParams(BaseParams):
33
  """Voice Activity Detection parameters"""
34
  vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
35
  threshold: float = Field(
 
59
  description="Padding added to each side of speech chunks"
60
  )
61
 
 
 
 
62
  @classmethod
63
  def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
64
  defaults = defaults or {}
 
85
  ]
86
 
87
 
88
+ class DiarizationParams(BaseParams):
 
89
  """Speaker diarization parameters"""
90
  is_diarize: bool = Field(default=False, description="Enable speaker diarization")
91
  hf_token: str = Field(
 
93
  description="Hugging Face token for downloading diarization models"
94
  )
95
 
 
 
 
96
  @classmethod
97
  def to_gradio_inputs(cls,
98
  defaults: Optional[Dict] = None,
 
119
  ]
120
 
121
 
122
+ class BGMSeparationParams(BaseParams):
123
  """Background music separation parameters"""
124
  is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
125
  model_size: str = Field(
 
140
  description="Offload UVR model after transcription"
141
  )
142
 
 
 
 
143
  @classmethod
144
  def to_gradio_input(cls,
145
  defaults: Optional[Dict] = None,
 
185
  ]
186
 
187
 
188
+ class WhisperParams(BaseParams):
189
  """Whisper parameters"""
190
  model_size: str = Field(default="large-v2", description="Whisper model size")
191
  lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
 
266
  description="Number of segments for language detection"
267
  )
268
 
 
 
 
269
  @field_validator('lang')
270
  def validate_lang(cls, v):
271
  from modules.utils.constants import AUTOMATIC_DETECTION
 
486
  }
487
  return data
488
 
489
+ def to_list(self) -> List:
490
+ """
491
+ Convert data class to the list because I have to pass the parameters as a list in the gradio.
492
+ Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
493
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
494
+ """
495
  whisper_list = [value for key, value in self.whisper.to_dict().items()]
496
  vad_list = [value for key, value in self.vad.to_dict().items()]
497
  diarization_list = [value for key, value in self.vad.to_dict().items()]
498
  bgm_sep_list = [value for key, value in self.bgm_separation.to_dict().items()]
499
  return whisper_list + vad_list + diarization_list + bgm_sep_list
500
+
501
+ @staticmethod
502
+ def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
503
+ """Convert list to the data class again to use it in a function."""
504
+ data_list = deepcopy(pipeline_list)
505
+ whisper_list, data_list = data_list[0:len(WhisperParams.__annotations__)]
506
+ data_list = data_list[len(WhisperParams.__annotations__):]
507
+
508
+ vad_list = data_list[0:len(VadParams.__annotations__)]
509
+ data_list = data_list[len(VadParams.__annotations__):]
510
+
511
+ diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
512
+ data_list = data_list[len(DiarizationParams.__annotations__)]
513
+
514
+ bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
515
+
516
+ return TranscriptionPipelineParams(
517
+ whisper=WhisperParams.from_list(whisper_list),
518
+ vad=VadParams.from_list(vad_list),
519
+ diarization=DiarizationParams.from_list(diarization_list),
520
+ bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
521
+ )