Automatic Speech Recognition
Transformers
Safetensors
Japanese
whisper
audio
hf-asr-leaderboard
Inference Endpoints
asahi417 committed
Commit fb72b03
1 Parent(s): 8bf9b0a
pipeline/kotoba_whisper.py ADDED
@@ -0,0 +1,306 @@
+ from typing import Union, Optional, Dict, List, Any
+ import requests
+
+ import torch
+ import numpy as np
+
+ from transformers.pipelines.audio_utils import ffmpeg_read
+ from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline, chunk_iter
+ from transformers.utils import is_torchaudio_available
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from stable_whisper import WhisperResult
+ from punctuators.models import PunctCapSegModelONNX
+
+
+ # Restores punctuation on each transcription chunk with a multilingual punctuation model;
+ # chunks that already contain Japanese punctuation are returned unchanged.
+ class Punctuator:
+
+     ja_punctuations = ["!", "?", "、", "。"]
+
+     def __init__(self, model: str = "pcs_47lang"):
+         self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
+
+     def punctuate(self, pipeline_chunk: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+
+         def validate_punctuation(raw: str, punctuated: str):
+             if 'unk' in punctuated.lower() or any(p in raw for p in self.ja_punctuations):
+                 return raw
+             if punctuated.count("。") > 1:
+                 ind = punctuated.rfind("。")
+                 punctuated = punctuated.replace("。", "")
+                 punctuated = punctuated[:ind] + "。" + punctuated[ind:]
+             return punctuated
+
+         text_edit = self.punctuation_model.infer([c['text'] for c in pipeline_chunk])
+         return [
+             {
+                 'timestamp': c['timestamp'],
+                 'text': validate_punctuation(c['text'], "".join(e))
+             } for c, e in zip(pipeline_chunk, text_edit)
+         ]
+
+
+ # Convert pipeline chunks into a stable-ts WhisperResult, filling in any missing start/end timestamps.
+ def _fix_timestamp(sample_rate: int, result: List[Dict[str, Any]], audio: np.ndarray) -> WhisperResult:
+
+     def replace_none_ts(parts):
+         total_dur = round(audio.shape[-1] / sample_rate, 3)
+         _medium_dur = _ts_nonzero_mask = None
+
+         def ts_nonzero_mask() -> np.ndarray:
+             nonlocal _ts_nonzero_mask
+             if _ts_nonzero_mask is None:
+                 _ts_nonzero_mask = np.array([(p['end'] or p['start']) is not None for p in parts])
+             return _ts_nonzero_mask
+
+         def medium_dur() -> float:
+             nonlocal _medium_dur
+             if _medium_dur is None:
+                 nonzero_durs = [p['end'] - p['start'] for p in parts if None not in (p['end'], p['start'])]
+                 nonzero_durs = np.array(nonzero_durs)
+                 _medium_dur = np.median(nonzero_durs) * 2 if len(nonzero_durs) else 2.0
+             return _medium_dur
+
+         def _curr_max_end(start: float, next_idx: int) -> float:
+             max_end = total_dur
+             if next_idx != len(parts):
+                 mask = np.flatnonzero(ts_nonzero_mask()[next_idx:])
+                 if len(mask):
+                     _part = parts[mask[0] + next_idx]
+                     max_end = _part['start'] or _part['end']
+
+             new_end = round(start + medium_dur(), 3)
+             if new_end > max_end:
+                 return max_end
+             return new_end
+
+         for i, part in enumerate(parts, 1):
+             if part['start'] is None:
+                 is_first = i == 1
+                 if is_first:
+                     new_start = round((part['end'] or 0) - medium_dur(), 3)
+                     part['start'] = max(new_start, 0.0)
+                 else:
+                     part['start'] = parts[i - 2]['end']
+             if part['end'] is None:
+                 no_next_start = i == len(parts) or parts[i]['start'] is None
+                 part['end'] = _curr_max_end(part['start'], i) if no_next_start else parts[i]['start']
+
+     words = [dict(start=word['timestamp'][0], end=word['timestamp'][1], word=word['text']) for word in result]
+     replace_none_ts(words)
+     return WhisperResult([words], force_order=True, check_sorted=True)
+
+
+ # Refine chunk timestamps with stable-ts: adjust boundaries around detected silence and regroup segments.
+ def fix_timestamp(pipeline_output: List[Dict[str, Any]], audio: np.ndarray, sample_rate: int) -> List[Dict[str, Any]]:
+     result = _fix_timestamp(sample_rate=sample_rate, audio=audio, result=pipeline_output)
+     result.adjust_by_silence(
+         audio,
+         q_levels=20,
+         k_size=5,
+         sample_rate=sample_rate,
+         min_word_dur=None,
+         word_level=True,
+         verbose=True,
+         nonspeech_error=0.1,
+         use_word_position=True
+     )
+     if result.has_words:
+         result.regroup(True)
+     return [{"timestamp": [s.start, s.end], "text": s.text} for s in result.segments]
+
+
+ # ASR pipeline that extends the standard Whisper pipeline with optional punctuation restoration
+ # and stable-ts timestamp refinement.
+ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
+
+     def __init__(self,
+                  model: "PreTrainedModel",
+                  feature_extractor: Union["SequenceFeatureExtractor", str] = None,
+                  tokenizer: Optional[PreTrainedTokenizer] = None,
+                  device: Union[int, "torch.device"] = None,
+                  torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
+                  punctuator: bool = True,
+                  stable_ts: bool = False,
+                  **kwargs):
+         self.type = "seq2seq_whisper"
+         self.stable_ts = stable_ts
+         if punctuator:
+             self.punctuator = Punctuator()
+         else:
+             self.punctuator = None
+         super().__init__(
+             model=model,
+             feature_extractor=feature_extractor,
+             tokenizer=tokenizer,
+             device=device,
+             torch_dtype=torch_dtype,
+             **kwargs
+         )
+
+     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
+         if isinstance(inputs, str):
+             if inputs.startswith("http://") or inputs.startswith("https://"):
+                 # We need to actually check for a real protocol, otherwise it's impossible to use a local file
+                 # like http_huggingface_co.png
+                 inputs = requests.get(inputs).content
+             else:
+                 with open(inputs, "rb") as f:
+                     inputs = f.read()
+
+         if isinstance(inputs, bytes):
+             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
+
+         stride = None
+         extra = {}
+         if isinstance(inputs, dict):
+             stride = inputs.pop("stride", None)
+             # Accepting `"array"` which is the key defined in `datasets` for
+             # better integration
+             if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                 raise ValueError(
+                     "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
+                     '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                     "containing the sampling_rate associated with that array"
+                 )
+
+             _inputs = inputs.pop("raw", None)
+             if _inputs is None:
+                 # Remove path which will not be used from `datasets`.
+                 inputs.pop("path", None)
+                 _inputs = inputs.pop("array", None)
+             in_sampling_rate = inputs.pop("sampling_rate")
+             extra = inputs
+             inputs = _inputs
+             if in_sampling_rate != self.feature_extractor.sampling_rate:
+                 if is_torchaudio_available():
+                     from torchaudio import functional as F
+                 else:
+                     raise ImportError(
+                         "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
+                         "The torchaudio package can be installed through: `pip install torchaudio`."
+                     )
+
+                 inputs = F.resample(
+                     torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                 ).numpy()
+                 ratio = self.feature_extractor.sampling_rate / in_sampling_rate
+             else:
+                 ratio = 1
+             if stride is not None:
+                 if stride[0] + stride[1] > inputs.shape[0]:
+                     raise ValueError("Stride is too large for input")
+
+                 # Stride needs to get the chunk length here, it's going to get
+                 # swallowed by the `feature_extractor` later, and then batching
+                 # can add extra data in the inputs, so we need to keep track
+                 # of the original length in the stride so we can cut properly.
+                 stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
+         if not isinstance(inputs, np.ndarray):
+             raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
+         if len(inputs.shape) != 1:
+             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
+
+         if chunk_length_s:
+             if stride_length_s is None:
+                 stride_length_s = chunk_length_s / 6
+
+             if isinstance(stride_length_s, (int, float)):
+                 stride_length_s = [stride_length_s, stride_length_s]
+
+             # XXX: Careful, this variable will not exist in the `seq2seq` setting.
+             # Currently chunking is not possible at this level for `seq2seq` so
+             # it's ok.
+             align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
+             chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
+             stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
+             stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
+
+             if chunk_len < stride_left + stride_right:
+                 raise ValueError("Chunk length must be superior to stride length")
+
+             for item in chunk_iter(
+                 inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
+             ):
+                 item["audio_array"] = inputs
+                 yield item
+         else:
+             if inputs.shape[0] > self.feature_extractor.n_samples:
+                 processed = self.feature_extractor(
+                     inputs,
+                     sampling_rate=self.feature_extractor.sampling_rate,
+                     truncation=False,
+                     padding="longest",
+                     return_tensors="pt",
+                 )
+             else:
+                 processed = self.feature_extractor(
+                     inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+                 )
+
+             if self.torch_dtype is not None:
+                 processed = processed.to(dtype=self.torch_dtype)
+             if stride is not None:
+                 processed["stride"] = stride
+             yield {"is_last": True, "audio_array": inputs, **processed, **extra}
+
+     def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
+         attention_mask = model_inputs.pop("attention_mask", None)
+         stride = model_inputs.pop("stride", None)
+         is_last = model_inputs.pop("is_last")
+         audio_array = model_inputs.pop("audio_array")
+         encoder = self.model.get_encoder()
+         # Consume values so we can let extra information flow freely through
+         # the pipeline (important for `partial` in microphone)
+         if type(return_timestamps) is not bool:
+             raise ValueError("return_timestamps should be bool")
+         if "input_features" in model_inputs:
+             inputs = model_inputs.pop("input_features")
+         elif "input_values" in model_inputs:
+             inputs = model_inputs.pop("input_values")
+         else:
+             raise ValueError(
+                 "Seq2Seq speech recognition model requires either a "
+                 f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
+             )
+
+         # custom processing for Whisper timestamps and word-level timestamps
+         generate_kwargs["return_timestamps"] = True
+         if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
+             generate_kwargs["input_features"] = inputs
+         else:
+             generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
+
+         tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
+         # whisper longform generation stores timestamps in "segments"
+         out = {"tokens": tokens}
+         if self.type == "seq2seq_whisper":
+             if stride is not None:
+                 out["stride"] = stride
+
+         # Leftover
+         extra = model_inputs
+         return {"is_last": is_last, "audio_array": audio_array, **out, **extra}
+
+     def postprocess(self,
+                     model_outputs,
+                     decoder_kwargs: Optional[Dict] = None,
+                     return_timestamps=None,
+                     return_language=None):
+         assert len(model_outputs) > 0
+         # Strip the raw audio threaded through the pipeline; it is only needed here for stable-ts refinement.
+         for model_output in model_outputs:
+             audio_array = model_output.pop("audio_array")[0]
+         outputs = super().postprocess(
+             model_outputs=model_outputs,
+             decoder_kwargs=decoder_kwargs,
+             return_timestamps=True,
+             return_language=return_language
+         )
+         if self.stable_ts:
+             outputs["chunks"] = fix_timestamp(
+                 pipeline_output=outputs["chunks"], audio=audio_array, sample_rate=self.feature_extractor.sampling_rate
+             )
+         if self.punctuator:
+             outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
+         outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
+         if not return_timestamps:
+             outputs.pop("chunks")
+         return outputs
+
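For reference, a minimal sketch of how the Punctuator above behaves on pipeline-style chunks. It assumes the repository's pipeline/ directory is on the import path and that the punctuators package and the pcs_47lang model are available; the chunk texts are made up.

from kotoba_whisper import Punctuator

punctuator = Punctuator()
chunks = [
    {"timestamp": [0.0, 2.1], "text": "こんにちは"},
    {"timestamp": [2.1, 4.5], "text": "今日はいい天気ですね"},
]
# Timestamps pass through unchanged; only the text may gain punctuation,
# and chunks that already contain Japanese punctuation are returned as-is.
print(punctuator.punctuate(chunks))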
pipeline/push_pipeline.py ADDED
@@ -0,0 +1,29 @@
+ from kotoba_whisper import KotobaWhisperPipeline
+ from transformers.pipelines import PIPELINE_REGISTRY, pipeline
+ from transformers import WhisperForConditionalGeneration, TFWhisperForConditionalGeneration
+
+
+ model_alias = "kotoba-tech/kotoba-whisper-v2.1"
+ # Register the custom pipeline class under the "kotoba-whisper" task name.
+ PIPELINE_REGISTRY.register_pipeline(
+     "kotoba-whisper",
+     pipeline_class=KotobaWhisperPipeline,
+     pt_model=WhisperForConditionalGeneration,
+     tf_model=TFWhisperForConditionalGeneration
+ )
+ # Instantiate the pipeline from the base model and push it, together with the custom code, to the Hub.
+ pipe = pipeline(
+     task="kotoba-whisper",
+     model="kotoba-tech/kotoba-whisper-v2.0",
+     chunk_length_s=15,
+     batch_size=16,
+     punctuator=True,
+     stable_ts=True,
+ )
+ pipe.push_to_hub(model_alias)
+ # Reload the pushed pipeline from the Hub to confirm it resolves with trust_remote_code.
+ pipe = pipeline(model=model_alias,
+                 punctuator=True,
+                 stable_ts=True,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+
+
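Once pushed, the pipeline can be loaded from the Hub like any other. A minimal usage sketch follows; the audio path is a placeholder, and punctuator/stable_ts fall back to their defaults (True/False) when omitted.

from transformers import pipeline

pipe = pipeline(model="kotoba-tech/kotoba-whisper-v2.1",
                chunk_length_s=15,
                batch_size=16,
                trust_remote_code=True)
# Accepts a local path, a URL, raw bytes, or a dict with "array"/"sampling_rate".
result = pipe("sample_ja.wav",
              return_timestamps=True,
              generate_kwargs={"language": "japanese", "task": "transcribe"})
print(result["text"])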
pipeline/test_pipeline.py ADDED
@@ -0,0 +1,154 @@
+ # Smoke-test the remote kotoba-whisper pipeline on one long-form evaluation sample,
+ # comparing punctuator / stable_ts settings with and without timestamps.
+ from pprint import pprint
+ from datasets import load_dataset
+ from transformers.pipelines import pipeline
+
+ model_alias = "kotoba-tech/kotoba-whisper-v1.1"
+
+ print("""### P + S ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=True,
+                 stable_ts=True,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             return_timestamps=True,
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### P ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=True,
+                 stable_ts=False,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             return_timestamps=True,
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### S ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=False,
+                 stable_ts=True,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             return_timestamps=True,
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### RAW ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=False,
+                 stable_ts=False,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             return_timestamps=True,
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### P + S ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=True,
+                 stable_ts=True,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### P ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=True,
+                 stable_ts=False,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### S ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=False,
+                 stable_ts=True,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
+ print("""### RAW ###""")
+ pipe = pipeline(model=model_alias,
+                 punctuator=False,
+                 stable_ts=False,
+                 chunk_length_s=15,
+                 batch_size=16,
+                 trust_remote_code=True)
+ dataset = load_dataset("kotoba-tech/kotoba-whisper-eval", split="train")
+ for i in dataset:
+     if i["audio"]["path"] == "long_interview_1.mp3":
+         i["audio"]["array"] = i["audio"]["array"][:7938000]
+         prediction = pipe(
+             i["audio"],
+             generate_kwargs={"language": "japanese", "task": "transcribe"}
+         )
+         pprint(prediction)
+         break
+
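For orientation, each pprint call above prints a dict of roughly the following shape when return_timestamps=True; the values are illustrative, not actual model output, and the timestamp entries may be tuples or lists depending on whether stable_ts is enabled. Without return_timestamps, the chunks key is dropped and only text remains.

# {'chunks': [{'text': 'こんにちは。', 'timestamp': [0.0, 2.1]},
#             {'text': '今日はいい天気ですね。', 'timestamp': [2.1, 4.5]}],
#  'text': 'こんにちは。今日はいい天気ですね。'}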