from typing import List, Optional, Union

import numpy as np
import torch
from torchaudio import functional as F
import requests
from pyannote.audio import Pipeline
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read


class ASRDiarizationPipeline:
    """Combine a speech-recognition pipeline with a speaker-diarization pipeline.

    The diarization pipeline decides who speaks when; the ASR pipeline produces
    timestamped text chunks. For each speaker turn, the text chunks up to the one
    whose end timestamp is closest to the turn's end are attributed to that speaker.
    """

    def __init__(
        self,
        asr_pipeline,
        diarization_pipeline,
    ):
        """
        Args:
            asr_pipeline:
                A `transformers` automatic-speech-recognition pipeline. Its
                `feature_extractor.sampling_rate` becomes the working sampling rate for all audio.
            diarization_pipeline:
                A `pyannote.audio` diarization pipeline, called with
                `{"waveform": ..., "sample_rate": ...}`.
        """
        self.asr_pipeline = asr_pipeline
        self.diarization_pipeline = diarization_pipeline

        self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate

    @classmethod
    def from_pretrained(
        cls,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: Optional[int] = 30,
        use_auth_token: Optional[Union[str, bool]] = True,
        **kwargs,
    ):
        """Load both pipelines from pretrained checkpoints and wrap them in this class.

        Args:
            asr_model: Model id/path for the `transformers` ASR pipeline.
            diarizer_model: Model id/path for `pyannote.audio.Pipeline.from_pretrained`.
            chunk_length_s: Chunk length (seconds) passed to the ASR pipeline.
            use_auth_token: Hub auth token (or `True` to use the cached one), passed to both loaders.
            **kwargs: Extra keyword arguments forwarded to `transformers.pipeline`.

        Returns:
            `ASRDiarizationPipeline`: the constructed pipeline.
        """
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            chunk_length_s=chunk_length_s,
            use_auth_token=use_auth_token,
            **kwargs,
        )
        diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
        return cls(asr_pipeline, diarization_pipeline)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str, dict],
        group_by_speaker: bool = True,
        **kwargs,
    ) -> List[dict]:
        """
        Transcribe the audio given as input to text and attribute the text to speakers.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The input is either:
                    - `str`: the filename of the audio file, or an `http://` / `https://` URL that is
                      downloaded first. The file is decoded at the correct sampling rate using *ffmpeg*.
                      This requires *ffmpeg* to be installed on the system.
                    - `bytes`: the content of an audio file, decoded by *ffmpeg* in the same way.
                    - `np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`: raw mono audio
                      at the correct sampling rate (no further check will be done).
                    - `dict`: raw audio sampled at an arbitrary `sampling_rate`, which this pipeline
                      resamples. The dict must be in the format `{"sampling_rate": int, "raw": np.array}`;
                      the `datasets` key `"array"` is accepted in place of `"raw"`. Other keys (such as
                      `"stride"` or `"path"`) are ignored. The caller's dict is not modified.
            group_by_speaker (`bool`, *optional*, defaults to `True`):
                If `True`, return one entry per speaker turn containing the concatenated text of that
                turn. If `False`, return one entry per ASR text chunk, tagged with its speaker.
            kwargs:
                Forwarded unchanged to both the diarization pipeline call and the ASR pipeline call.

        Return:
            `List[Dict]`: one dictionary per output item, in chronological order.
                - With `group_by_speaker=True`: `{"speaker": label, "text": str,
                  "timestamp": {"start": ..., "end": ...}}`.
                - With `group_by_speaker=False`: each ASR chunk dict (its `"text"` and
                  `"timestamp"` `(start, end)` entries) plus a `"speaker"` key.
                An empty list is returned when the diarizer detects no speech.
        """
        inputs, diarizer_inputs = self.preprocess(inputs)

        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
            **kwargs,
        )
        segments = diarization.for_json()["content"]
        speaker_turns = self._merge_speaker_turns(segments)
        if not speaker_turns:
            # No speech detected: there is no speaker to attribute any text to.
            return []

        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": self.sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        return self._assign_chunks_to_speakers(asr_out["chunks"], speaker_turns, group_by_speaker)

    @staticmethod
    def _merge_speaker_turns(segments: List[dict]) -> List[dict]:
        """Collapse consecutive diarization segments with the same label into speaker turns.

        A turn ends where the next speaker's first segment starts (gaps are absorbed into
        the preceding turn); the final turn ends at the end of the last segment.
        Returns `[]` for empty input.
        """
        if not segments:
            return []

        turns = []
        turn_first = segments[0]
        for segment in segments[1:]:
            if segment["label"] != turn_first["label"]:
                turns.append(
                    {
                        "segment": {"start": turn_first["segment"]["start"], "end": segment["segment"]["start"]},
                        "speaker": turn_first["label"],
                    }
                )
                turn_first = segment
        turns.append(
            {
                "segment": {"start": turn_first["segment"]["start"], "end": segments[-1]["segment"]["end"]},
                "speaker": turn_first["label"],
            }
        )
        return turns

    @staticmethod
    def _chunk_end_time(chunk: dict) -> float:
        """End timestamp of an ASR chunk, falling back to its start if the end is `None`."""
        start, end = chunk["timestamp"]
        return start if end is None else end

    @staticmethod
    def _assign_chunks_to_speakers(
        transcript: List[dict],
        speaker_turns: List[dict],
        group_by_speaker: bool,
    ) -> List[dict]:
        """Attribute timestamped ASR chunks to speaker turns.

        For every turn except the last, the chunks up to (and including) the one whose end
        timestamp is closest to the turn's end are assigned to that turn. The last turn
        receives all remaining chunks so that no transcribed text is dropped.
        """
        end_timestamps = np.array(
            [ASRDiarizationPipeline._chunk_end_time(chunk) for chunk in transcript], dtype=float
        )
        last_turn_idx = len(speaker_turns) - 1

        segmented_preds = []
        for turn_idx, turn in enumerate(speaker_turns):
            if not transcript:
                # Every chunk has been attributed already; remaining turns carry no text.
                break

            if turn_idx == last_turn_idx:
                upto_idx = len(transcript) - 1
            else:
                upto_idx = int(np.argmin(np.abs(end_timestamps - turn["segment"]["end"])))

            chunks = transcript[: upto_idx + 1]
            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": turn["speaker"],
                        "text": "".join(chunk["text"] for chunk in chunks),
                        "timestamp": {
                            "start": chunks[0]["timestamp"][0],
                            "end": chunks[-1]["timestamp"][1],
                        },
                    }
                )
            else:
                segmented_preds.extend({"speaker": turn["speaker"], **chunk} for chunk in chunks)

            # Drop the consumed chunks so the next turn only searches what is left.
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds

    def preprocess(self, inputs):
        """Convert any supported input into a mono waveform at `self.sampling_rate`.

        Args:
            inputs: A filename or `http(s)://` URL (`str`), encoded audio (`bytes`), a raw
                1-D `np.ndarray`, or a dict with `"sampling_rate"` and `"raw"`/`"array"` keys.

        Returns:
            Tuple `(inputs, diarizer_inputs)`: the 1-D numpy waveform, and the same audio as a
            float32 torch tensor with a leading dimension of size 1 (via `unsqueeze(0)`).

        Raises:
            ValueError: if a dict is missing required keys, or the result is not a 1-D ndarray.
            requests.HTTPError: if downloading a URL returns an error status.
        """
        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                response = requests.get(inputs)
                # Fail loudly instead of handing an error page to ffmpeg.
                response.raise_for_status()
                inputs = response.content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.sampling_rate)

        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to ASRDiarizePipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )
            # Shallow copy so the pops below do not mutate the caller's dict.
            inputs = dict(inputs)

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs
            if in_sampling_rate != self.sampling_rate:
                inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")

        # The diarization pipeline receives a float32 tensor with a leading dimension of size 1.
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)

        return inputs, diarizer_inputs