sanchit-gandhi (HF staff) committed
Commit 51fd668
Parent: 2503b95

Update asr_diarizer.py

Files changed (1):
  asr_diarizer.py  +82 -11
asr_diarizer.py CHANGED
@@ -2,8 +2,13 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
+from torchaudio import functional as F
+
+import requests
+
 from pyannote.audio import Pipeline
 from transformers import pipeline
+from transformers.pipelines.audio_utils import ffmpeg_read
 
 
 class ASRDiarizationPipeline:
@@ -14,14 +19,16 @@ class ASRDiarizationPipeline:
     ):
         self.asr_pipeline = asr_pipeline
         self.diarization_pipeline = diarization_pipeline
+
+        self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
 
     @classmethod
     def from_pretrained(
         cls,
         asr_model: Optional[str] = "openai/whisper-small",
         diarizer_model: Optional[str] = "pyannote/speaker-diarization",
-        chunk_length_s: int = 30,
-        use_auth_token: Union[str, bool] = True,
+        chunk_length_s: Optional[int] = 30,
+        use_auth_token: Optional[Union[str, bool]] = True,
         **kwargs,
     ):
         asr_pipeline = pipeline(
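
With `from_pretrained` now taking optional `chunk_length_s` and `use_auth_token` arguments, instantiation stays a one-liner. A minimal sketch of the entry point (assuming the module is importable as `asr_diarizer`, and that a Hugging Face token with access to the gated pyannote checkpoint is available locally):

    from asr_diarizer import ASRDiarizationPipeline

    # Builds both sub-pipelines; the constructor now caches the ASR feature
    # extractor's sampling rate (16 kHz for Whisper checkpoints) on the instance.
    pipe = ASRDiarizationPipeline.from_pretrained(
        asr_model="openai/whisper-small",
        diarizer_model="pyannote/speaker-diarization",
        use_auth_token=True,  # pyannote/speaker-diarization is gated
    )
    print(pipe.sampling_rate)  # 16000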
@@ -37,21 +44,42 @@ class ASRDiarizationPipeline:
     def __call__(
         self,
         inputs: Union[np.ndarray, List[np.ndarray]],
-        sampling_rate: int,
         group_by_speaker: bool = True,
         **kwargs,
     ):
-        if not isinstance(inputs, np.ndarray):
-            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
-        if len(inputs.shape) != 1:
-            raise ValueError(f"Expected a single channel audio as input, got `{len(inputs.shape)}` channels.")
+        """
+        Transcribe the audio sequence(s) given as inputs to text.
+
+        Args:
+            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
+                The inputs is either:
+                    - `str`: the filename of the audio file. The file will be read at the correct sampling rate
+                      to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
+                    - `bytes`: the content of an audio file, which is interpreted by *ffmpeg* in the
+                      same way.
+                    - `np.ndarray` of shape (n,) and of type `np.float32` or `np.float64`:
+                      raw audio at the correct sampling rate (no further check will be done).
+                    - `dict`: can be used to pass raw audio sampled at an arbitrary `sampling_rate` and let this
+                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
+                      np.array}`, optionally with a `"stride": (left: int, right: int)` that asks the pipeline to
+                      ignore the first `left` samples and last `right` samples in decoding (but use them at
+                      inference to provide more context to the model). Only use `stride` with CTC models.
+
+        Return:
+            `Dict`: A dictionary with the following keys:
+                - **text** (`str`) -- The recognized text.
+                - **chunks** (*optional*, `List[Dict]`) --
+                  When using `return_timestamps`, `chunks` will become a list containing all the various text
+                  chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5, 0.9)}, {"text":
+                  "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
+                  `"".join(chunk["text"] for chunk in output["chunks"])`.
+        """
+        inputs, diarizer_inputs = self.preprocess(inputs)
 
-        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
         diarization = self.diarization_pipeline(
-            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
+            {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
             **kwargs,
         )
-        del diarizer_inputs
 
         segments = diarization.for_json()["content"]
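
Taken together with the docstring above, the new `preprocess` entry point means `__call__` accepts all the standard ASR pipeline input forms rather than requiring a pre-loaded array plus an explicit `sampling_rate`. A short sketch of the call patterns (assuming the `pipe` instance from the previous example; `meeting.wav` is a placeholder path):

    import numpy as np

    # 1. A local file path or URL: decoded by ffmpeg at the pipeline's sampling rate.
    out = pipe("meeting.wav")

    # 2. Raw audio already at the correct sampling rate: used as-is.
    audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
    out = pipe(audio)

    # 3. A datasets-style dict at an arbitrary rate: resampled internally.
    out = pipe({"array": audio, "sampling_rate": 8000})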
@@ -78,7 +106,7 @@ class ASRDiarizationPipeline:
         )
 
         asr_out = self.asr_pipeline(
-            {"array": inputs, "sampling_rate": sampling_rate},
+            {"array": inputs, "sampling_rate": self.sampling_rate},
             return_timestamps=True,
             **kwargs,
         )
@@ -110,3 +138,46 @@ class ASRDiarizationPipeline:
         end_timestamps = end_timestamps[upto_idx + 1 :]
 
         return segmented_preds
+
+    def preprocess(self, inputs):
+        if isinstance(inputs, str):
+            if inputs.startswith("http://") or inputs.startswith("https://"):
+                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
+                # like http_huggingface_co.png
+                inputs = requests.get(inputs).content
+            else:
+                with open(inputs, "rb") as f:
+                    inputs = f.read()
+
+        if isinstance(inputs, bytes):
+            inputs = ffmpeg_read(inputs, self.sampling_rate)
+
+        if isinstance(inputs, dict):
+            # Accepting `"array"` which is the key defined in `datasets` for better integration
+            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                raise ValueError(
+                    "When passing a dictionary to ASRDiarizationPipeline, the dict needs to contain a "
+                    '"raw" key containing the numpy array representing the audio, and a "sampling_rate" key '
+                    "containing the sampling rate associated with that array"
+                )
+
+            _inputs = inputs.pop("raw", None)
+            if _inputs is None:
+                # Remove path which will not be used from `datasets`.
+                inputs.pop("path", None)
+                _inputs = inputs.pop("array", None)
+            in_sampling_rate = inputs.pop("sampling_rate")
+            inputs = _inputs
+            if in_sampling_rate != self.sampling_rate:
+                inputs = F.resample(
+                    torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate
+                ).numpy()
+
+        if not isinstance(inputs, np.ndarray):
+            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
+        if len(inputs.shape) != 1:
+            raise ValueError("We expect single channel audio input for ASRDiarizationPipeline")
+
+        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
+
+        return inputs, diarizer_inputs
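
The resampling branch is the part most worth sanity-checking: when the dict's rate differs from the ASR model's, the waveform is converted with `torchaudio.functional.resample` and handed back as a numpy array. A standalone sketch of that conversion under the same assumptions (16 kHz target, as with Whisper):

    import numpy as np
    import torch
    from torchaudio import functional as F

    in_sampling_rate, target_rate = 44100, 16000
    audio = np.random.randn(2 * in_sampling_rate).astype(np.float32)  # two seconds at 44.1 kHz

    # Mirrors the preprocess() branch: numpy -> torch, resample, back to numpy.
    resampled = F.resample(torch.from_numpy(audio), in_sampling_rate, target_rate).numpy()
    print(resampled.shape)  # (32000,) -- two seconds at 16 kHz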