whisper-speaker-diarization / asr_diarizer.py
sanchit-gandhi's picture
Update asr_diarizer.py
51fd668
raw
history blame
7.76 kB
from typing import List, Optional, Union
import numpy as np
import torch
from torchaudio import functional as F
import requests
from pyannote.audio import Pipeline
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
class ASRDiarizationPipeline:
    """Combine a speech-recognition pipeline with a speaker-diarization pipeline.

    The diarizer decides who speaks when; the ASR pipeline (run with `return_timestamps=True`)
    decides what is said and when each text chunk ends. Each speaker turn is matched to the
    ASR chunk whose end timestamp is closest to the end of that turn.
    """

    def __init__(
        self,
        asr_pipeline,
        diarization_pipeline,
    ):
        """
        Args:
            asr_pipeline:
                A `transformers` automatic-speech-recognition pipeline. The `sampling_rate` of its
                feature extractor is the rate all inputs are converted to.
            diarization_pipeline:
                A `pyannote.audio` diarization `Pipeline` (called with a waveform dict).
        """
        self.asr_pipeline = asr_pipeline
        self.diarization_pipeline = diarization_pipeline
        self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate

    @classmethod
    def from_pretrained(
        cls,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: Optional[int] = 30,
        use_auth_token: Optional[Union[str, bool]] = True,
        **kwargs,
    ) -> "ASRDiarizationPipeline":
        """Load both pipelines from pretrained checkpoints and wrap them.

        Args:
            asr_model (`str`): checkpoint for the `transformers` ASR pipeline.
            diarizer_model (`str`): checkpoint for the `pyannote.audio` diarization pipeline.
            chunk_length_s (`int`): chunk length forwarded to the ASR pipeline.
            use_auth_token (`str` or `bool`): auth token forwarded to both loaders.
            kwargs: extra keyword arguments forwarded to the ASR `pipeline(...)` factory only.

        Returns:
            `ASRDiarizationPipeline`: the assembled pipeline.
        """
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            chunk_length_s=chunk_length_s,
            use_auth_token=use_auth_token,
            **kwargs,
        )
        diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
        return cls(asr_pipeline, diarization_pipeline)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str, dict],
        group_by_speaker: bool = True,
        **kwargs,
    ) -> List[dict]:
        """
        Transcribe the audio given as `inputs` and attribute the text to speakers.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either:
                    - `str`: a local filename or an `http://` / `https://` URL. The content is decoded
                      with *ffmpeg* at the correct sampling rate (requires *ffmpeg* on the system).
                    - `bytes`: the content of an audio file, decoded by *ffmpeg* in the same way.
                    - `np.ndarray` of shape `(n,)`: raw mono audio, assumed to already be at the
                      correct sampling rate (no further check is done on the rate).
                    - `dict`: `{"sampling_rate": int, "raw": np.ndarray}` (or `"array"` instead of
                      `"raw"`, as produced by `datasets`). The audio is resampled if needed. Any other
                      keys are ignored, and the dict itself is not modified.
            group_by_speaker (`bool`, defaults to `True`):
                If `True`, return one entry per speaker turn with the concatenated text of that turn.
                If `False`, return one entry per ASR chunk, each tagged with its speaker.
            kwargs:
                Forwarded to BOTH the diarization pipeline and the ASR pipeline, so every keyword
                must be accepted by both.

        Return:
            `List[Dict]`:
                - with `group_by_speaker=True`: dicts with keys `"speaker"` (diarizer label), `"text"`
                  (concatenated ASR text of the turn) and `"timestamp"` (`{"start", "end"}` taken from
                  the first and last ASR chunk assigned to the turn).
                - with `group_by_speaker=False`: each ASR chunk dict (as returned by the ASR pipeline,
                  e.g. `"text"` and `"timestamp"`) with an added `"speaker"` key.
                An empty list is returned when the diarizer finds no speaker segments.
        """
        inputs, diarizer_inputs = self.preprocess(inputs)

        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
            **kwargs,
        )
        segments = diarization.for_json()["content"]

        speaker_turns = self._merge_speaker_turns(segments)
        if not speaker_turns:
            # No speaker segments means there is nothing to attribute text to.
            return []

        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": self.sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        return self._align_transcript(asr_out["chunks"], speaker_turns, group_by_speaker)

    @staticmethod
    def _merge_speaker_turns(segments):
        """Collapse consecutive diarizer segments with the same label into one turn per speaker change.

        Args:
            segments: list of dicts with keys `"segment"` (`{"start", "end"}`) and `"label"`, as read
                from `for_json()["content"]`. Presumably in chronological order; confirm with pyannote.

        Returns:
            List of `{"segment": {"start", "end"}, "speaker": label}`. Each turn ends where the next
            speaker's first segment starts; the last turn ends at the end of the last segment.
            Empty when `segments` is empty.
        """
        if not segments:
            return []

        turns = []
        # First segment of the turn currently being accumulated.
        turn_start = segments[0]
        for segment in segments[1:]:
            if segment["label"] != turn_start["label"]:
                turns.append(
                    {
                        "segment": {"start": turn_start["segment"]["start"], "end": segment["segment"]["start"]},
                        "speaker": turn_start["label"],
                    }
                )
                turn_start = segment

        # The final turn is never closed by a speaker change, so close it at the last segment's end.
        turns.append(
            {
                "segment": {"start": turn_start["segment"]["start"], "end": segments[-1]["segment"]["end"]},
                "speaker": turn_start["label"],
            }
        )
        return turns

    @staticmethod
    def _align_transcript(transcript, speaker_turns, group_by_speaker):
        """Assign ASR chunks to speaker turns, in order.

        For each turn, the remaining ASR chunk whose end timestamp is closest to the turn's end is taken
        as the turn's last chunk. Every unassigned chunk up to and including it goes to that speaker.
        Turns left over once all chunks are assigned receive nothing.
        """
        # NOTE(review): a chunk whose end timestamp is None would break the subtraction below —
        # confirm the ASR pipeline always closes the final chunk.
        end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
        segmented_preds = []

        for turn in speaker_turns:
            if not transcript:
                # All chunks are assigned; np.argmin would raise on the empty array.
                break

            end_time = turn["segment"]["end"]
            upto_idx = int(np.argmin(np.abs(end_timestamps - end_time)))

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": turn["speaker"],
                        "text": "".join(chunk["text"] for chunk in transcript[: upto_idx + 1]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                segmented_preds.extend({"speaker": turn["speaker"], **chunk} for chunk in transcript[: upto_idx + 1])

            # Drop the assigned chunks so the next turn only searches what is left.
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds

    def preprocess(self, inputs):
        """Convert `inputs` to mono audio at `self.sampling_rate`.

        Args:
            inputs: filename/URL `str`, audio-file `bytes`, 1-D `np.ndarray`, or a dict with
                `"sampling_rate"` and `"raw"`/`"array"` (see `__call__`). A dict is read, not modified.

        Returns:
            `(np.ndarray, torch.Tensor)`: the 1-D waveform (for the ASR pipeline) and the same audio as
            a float32 tensor of shape `(1, num_samples)` (for the diarizer).

        Raises:
            ValueError: if a dict lacks the required keys, the decoded input is not an `np.ndarray`,
                or the array is not 1-D.
        """
        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.sampling_rate)

        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to ASRDiarizePipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )
            # Read with .get() rather than pop() so the caller's dict (e.g. a `datasets` row) is left intact.
            audio = inputs.get("raw")
            if audio is None:
                audio = inputs.get("array")
            in_sampling_rate = inputs["sampling_rate"]
            inputs = audio
            if in_sampling_rate != self.sampling_rate:
                inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")

        # The diarizer is fed a float32 tensor with a leading channel axis: (1, num_samples).
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        return inputs, diarizer_inputs