sanchit-gandhi HF staff committed on
Commit
7860c23
1 Parent(s): cc504bd

Create asr_diarizer.py

Files changed (1)
  1. asr_diarizer.py +101 -0
asr_diarizer.py ADDED
@@ -0,0 +1,101 @@
from typing import List, Optional, Union

import numpy as np
import torch
from pyannote.audio import Pipeline
from transformers import pipeline


class ASRDiarizationPipeline:
    def __init__(
        self,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        **kwargs,
    ):
        # Transformers ASR pipeline (Whisper by default), chunked for long-form audio.
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            use_auth_token=True,
            chunk_length_s=chunk_length_s,
            **kwargs,
        )
        # pyannote speaker-diarization pipeline.
        self.diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=True)

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(f"Expected single-channel audio as input, got `{len(inputs.shape)}` channels.")

        # pyannote expects a (channel, time) float tensor.
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs

        segments = diarization.for_json()["content"]

        # Merge consecutive segments from the same speaker into one segment per speaker turn.
        new_segments = []
        prev_segment = cur_segment = segments[0]

        for i in range(1, len(segments)):
            cur_segment = segments[i]

            # Speaker change: close off the previous speaker's turn.
            if cur_segment["label"] != prev_segment["label"] and i < len(segments):
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]

        # Close off the final speaker's turn.
        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )

        # Run ASR with chunk-level timestamps so the transcript can be aligned to the speaker turns.
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # End timestamp of each transcribed chunk, used to align chunks to speaker turns.
        end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
        segmented_preds = []

        for segment in new_segments:
            # Find the transcript chunk whose end time is closest to the end of the speaker turn.
            end_time = segment["segment"]["end"]
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            if group_by_speaker:
                # Concatenate all chunks up to the turn boundary into one entry per speaker turn.
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                # Keep one entry per ASR chunk, tagged with the current speaker.
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

            # Drop the chunks that have been assigned to this speaker turn.
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds
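For reference, a minimal usage sketch (not part of this commit) might look like the following. The soundfile dependency, the file name sample.wav, and the assumption of single-channel 16 kHz audio are illustrative only; any mono waveform loaded as a 1-D numpy array with a known sampling rate will do. Since the class loads both models with use_auth_token=True, a Hugging Face token (e.g. via huggingface-cli login) is assumed to be configured.

    import soundfile as sf

    from asr_diarizer import ASRDiarizationPipeline

    # Hypothetical example: transcribe and diarize a local mono WAV file.
    asr_diarizer = ASRDiarizationPipeline(
        asr_model="openai/whisper-small",
        diarizer_model="pyannote/speaker-diarization",
    )

    # sf.read returns a 1-D float array for mono audio, plus the sampling rate.
    audio, sampling_rate = sf.read("sample.wav")

    outputs = asr_diarizer(audio, sampling_rate=sampling_rate, group_by_speaker=True)

    # With group_by_speaker=True, each entry holds a speaker label, the concatenated
    # text for that speaker's turn, and the turn's start/end timestamps.
    for segment in outputs:
        print(f'{segment["speaker"]}: {segment["text"]}')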