from typing import List, Optional, Union

import numpy as np
import requests
import torch
from pyannote.audio import Pipeline
from torchaudio import functional as F
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read


class ASRDiarizationPipeline:
    """Combine a `transformers` automatic-speech-recognition pipeline with a `pyannote.audio`
    speaker-diarization pipeline to produce speaker-attributed transcriptions."""

    def __init__(
        self,
        asr_pipeline,
        diarization_pipeline,
    ):
        self.asr_pipeline = asr_pipeline
        self.diarization_pipeline = diarization_pipeline
        # both pipelines consume audio at the sampling rate expected by the ASR feature extractor
        self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
    @classmethod
    def from_pretrained(
        cls,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: Optional[int] = 30,
        use_auth_token: Optional[Union[str, bool]] = True,
        **kwargs,
    ):
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            chunk_length_s=chunk_length_s,
            use_auth_token=use_auth_token,
            **kwargs,
        )
        diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
        return cls(asr_pipeline, diarization_pipeline)
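
    # Construction sketch (illustrative, not executed here): other Whisper checkpoints, e.g.
    # "openai/whisper-medium", can be swapped in via `asr_model`, and "hf_xxx" stands in for a token
    # whose account has accepted the gated pyannote/speaker-diarization user conditions.
    #
    #   pipe = ASRDiarizationPipeline.from_pretrained(
    #       asr_model="openai/whisper-small",
    #       use_auth_token="hf_xxx",
    #   )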

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        group_by_speaker: bool = True,
        **kwargs,
    ):
""" | |
Transcribe the audio sequence(s) given as inputs to text. | |
Args: | |
inputs (`np.ndarray` or `bytes` or `str` or `dict`): | |
The inputs is either : | |
- `str` that is the filename of the audio file, the file will be read at the correct sampling rate | |
to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system. | |
- `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the | |
same way. | |
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`) | |
Raw audio at the correct sampling rate (no further check will be done) | |
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this | |
pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw": | |
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to | |
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at | |
inference to provide more context to the model). Only use `stride` with CTC models. | |
Return: | |
`Dict`: A dictionary with the following keys: | |
- **text** (`str` ) -- The recognized text. | |
- **chunks** (*optional(, `List[Dict]`) | |
When using `return_timestamps`, the `chunks` will become a list containing all the various text | |
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text": | |
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing | |
`"".join(chunk["text"] for chunk in output["chunks"])`. | |
""" | |
        inputs, diarizer_inputs = self.preprocess(inputs)

        # run speaker diarization over the full recording
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
            **kwargs,
        )
        # flatten the annotation into {"segment": {"start", "end"}, "label": speaker} dicts
        # (`itertracks` replaces `for_json`, which recent pyannote.core releases no longer provide)
        segments = [
            {"segment": {"start": segment.start, "end": segment.end}, "label": label}
            for segment, _, label in diarization.itertracks(yield_label=True)
        ]

        # merge consecutive segments from the same speaker into one "super-segment"
        new_segments = []
        prev_segment = cur_segment = segments[0]
        for i in range(1, len(segments)):
            cur_segment = segments[i]
            # whenever the speaker changes, close off the previous speaker's super-segment
            if cur_segment["label"] != prev_segment["label"]:
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]
        # append the final super-segment, which runs to the end of the last diarized segment
        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )

        # transcribe with the ASR pipeline, requesting chunk-level timestamps so the text can be aligned
        # against the diarized speaker segments
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": self.sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # end time of each transcribed chunk
        end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])

        segmented_preds = []
        # for each speaker segment, take every remaining transcript chunk up to the one whose end time is
        # closest to the segment's end time, then slide the transcript / timestamp windows forward
        for segment in new_segments:
            end_time = segment["segment"]["end"]
            # index of the chunk whose end timestamp is closest to the end of the speaker segment
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

            # crop the transcript and timestamps to what is left for the remaining speaker segments
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds

    def preprocess(self, inputs):
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.sampling_rate)

        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to ASRDiarizationPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key '
                    "containing the sampling rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs

            # resample to the sampling rate expected by the ASR feature extractor if necessary
            if in_sampling_rate != self.sampling_rate:
                inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for ASRDiarizationPipeline")

        # the diarization pipeline expects a float tensor of shape (channels, num_samples)
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)

        return inputs, diarizer_inputs
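

if __name__ == "__main__":
    # Minimal end-to-end sketch: "audio.mp3" and the "hf_xxx" token are placeholders (assumptions), and
    # "pyannote/speaker-diarization" is a gated checkpoint whose user conditions must be accepted first.
    pipe = ASRDiarizationPipeline.from_pretrained(use_auth_token="hf_xxx")
    outputs = pipe("audio.mp3", group_by_speaker=True)
    # `outputs` is a list of speaker-attributed segments, e.g.
    # [{"speaker": "SPEAKER_00", "text": " Hello there.", "timestamp": {"start": 0.0, "end": 1.4}}, ...]
    print(outputs)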