Commit 7860c23
Parent(s): cc504bd
Create asr_diarizer.py

Files changed: asr_diarizer.py (+101, -0)
asr_diarizer.py
ADDED
@@ -0,0 +1,101 @@
from typing import List, Optional, Union

import numpy as np
import torch
from pyannote.audio import Pipeline
from transformers import pipeline


class ASRDiarizationPipeline:
    def __init__(
        self,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        **kwargs,
    ):
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            use_auth_token=True,
            chunk_length_s=chunk_length_s,
            **kwargs,
        )
        self.diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=True)

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(f"Expected a single channel audio as input, got `{len(inputs.shape)}` channels.")

        # run the diarization pipeline on the mono waveform to get speaker turns
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs

        segments = diarization.for_json()["content"]

        # merge consecutive segments that belong to the same speaker
        new_segments = []
        prev_segment = cur_segment = segments[0]

        for i in range(1, len(segments)):
            cur_segment = segments[i]

            # check whether the speaker ("label") has changed
            if cur_segment["label"] != prev_segment["label"] and i < len(segments):
                # add the start/end times of the merged segment to the new list
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]

        # add the final segment
        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )

        # transcribe the audio with chunk-level timestamps
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # align the ASR chunks with the diarized segments via their end timestamps
        end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
        segmented_preds = []

        for segment in new_segments:
            # find the ASR chunk whose end time is closest to the segment's end time
            end_time = segment["segment"]["end"]
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

            # crop the transcript and timestamps to whatever is left to assign
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds
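
A minimal usage sketch, not part of this commit: it assumes a 16 kHz mono recording on disk (the file name "meeting.wav" is hypothetical), that librosa is installed for loading audio, and that the environment is logged in to the Hugging Face Hub so the gated pyannote/speaker-diarization checkpoint can be downloaded.

# Illustrative only: drives ASRDiarizationPipeline on a local audio file.
import librosa

from asr_diarizer import ASRDiarizationPipeline

# load a mono waveform at 16 kHz (hypothetical file name)
audio, sr = librosa.load("meeting.wav", sr=16000, mono=True)

pipe = ASRDiarizationPipeline(asr_model="openai/whisper-small")
outputs = pipe(audio, sampling_rate=sr, group_by_speaker=True)

# with group_by_speaker=True, each entry is one speaker turn with its text and timestamps
for turn in outputs:
    print(f'{turn["speaker"]} [{turn["timestamp"]["start"]:.1f}-{turn["timestamp"]["end"]:.1f}]: {turn["text"]}')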