from typing import List, Optional, Union

import numpy as np
import torch
from pyannote.audio import Pipeline
from transformers import pipeline


class ASRDiarizationPipeline:
    """Attribute speech-recognition output to speakers.

    Wraps two callables: an ASR pipeline that returns timestamped text chunks
    and a diarization pipeline that returns labelled speaker segments.
    Calling the instance runs both on the same audio and assigns each ASR
    chunk to the speaker turn whose end time is closest to the chunk's end.
    """

    def __init__(
        self,
        asr_pipeline,
        diarization_pipeline,
    ):
        """Store the two already-constructed pipelines.

        Args:
            asr_pipeline: Callable invoked as
                ``asr_pipeline({"array": ..., "sampling_rate": ...},
                return_timestamps=True, **kwargs)``; its result must expose a
                ``"chunks"`` list of dicts with ``"text"`` and ``"timestamp"``.
            diarization_pipeline: Callable invoked as
                ``diarization_pipeline({"waveform": ..., "sample_rate": ...},
                **kwargs)``; its result must expose ``for_json()``.
        """
        self.asr_pipeline = asr_pipeline
        self.diarization_pipeline = diarization_pipeline

    @classmethod
    def from_pretrained(
        cls,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        use_auth_token: Union[str, bool] = True,
        **kwargs,
    ):
        """Build both pipelines from model identifiers and return an instance.

        Args:
            asr_model: Model identifier handed to the ``transformers``
                ``pipeline`` factory.
            diarizer_model: Model identifier handed to
                ``pyannote.audio.Pipeline.from_pretrained``.
            chunk_length_s: Forwarded to the ASR pipeline factory.
            use_auth_token: Forwarded to both factories.
            **kwargs: Extra keyword arguments forwarded to the ASR pipeline
                factory only.

        Returns:
            A new ``ASRDiarizationPipeline`` (or subclass) instance.
        """
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            chunk_length_s=chunk_length_s,
            use_auth_token=use_auth_token,
            **kwargs,
        )
        diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
        # The instance must be returned; without `return` this alternate
        # constructor silently produced None.
        return cls(asr_pipeline, diarization_pipeline)

    @staticmethod
    def _merge_speaker_turns(segments):
        """Collapse consecutive same-label segments into one turn per speaker change.

        ``segments`` must be non-empty. A turn starts at the first segment of
        a run and ends where the next speaker's first segment starts; the
        final turn ends at the end of the last segment.
        """
        turns = []
        prev_segment = cur_segment = segments[0]
        for cur_segment in segments[1:]:
            if cur_segment["label"] != prev_segment["label"]:
                turns.append(
                    {
                        "segment": {
                            "start": prev_segment["segment"]["start"],
                            "end": cur_segment["segment"]["start"],
                        },
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = cur_segment
        # Close the last open turn using the end of the final segment.
        turns.append(
            {
                "segment": {
                    "start": prev_segment["segment"]["start"],
                    "end": cur_segment["segment"]["end"],
                },
                "speaker": prev_segment["label"],
            }
        )
        return turns

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
        """Transcribe ``inputs`` and label the transcript with speakers.

        Args:
            inputs: One-dimensional numpy array of audio samples. Despite the
                annotation, anything other than a single ``np.ndarray`` is
                rejected.
            sampling_rate: Sampling rate of ``inputs``; passed to both pipelines.
            group_by_speaker: If true, emit one entry per speaker turn with the
                joined text of its chunks. Otherwise emit one entry per ASR
                chunk, tagged with its speaker.
            **kwargs: Forwarded to BOTH the diarization and the ASR call.

        Returns:
            A list of dicts. Grouped mode: ``{"speaker", "text", "timestamp":
            {"start", "end"}}``. Ungrouped mode: each ASR chunk dict with an
            added ``"speaker"`` key. Empty when diarization yields no segments.

        Raises:
            ValueError: If ``inputs`` is not a numpy array or is not 1-D.
        """
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(f"Expected a single channel audio as input, got `{len(inputs.shape)}` channels.")

        # The diarizer receives a float tensor with a leading dimension of size 1.
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs  # release the tensor copy before running ASR

        segments = diarization.for_json()["content"]
        if not segments:
            # No speaker turns detected: nothing to attribute text to.
            return []
        new_segments = self._merge_speaker_turns(segments)

        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # A chunk whose end timestamp is None (open-ended final chunk) would
        # make the array object-dtype and break the subtraction below, so for
        # matching purposes only it is treated as ending at the end of the
        # audio. The value reported in the output is left untouched.
        chunk_ends = []
        for chunk in transcript:
            chunk_end = chunk["timestamp"][-1]
            if chunk_end is None:
                chunk_end = len(inputs) / sampling_rate
            chunk_ends.append(chunk_end)
        end_timestamps = np.array(chunk_ends, dtype=float)

        segmented_preds = []
        for segment in new_segments:
            if len(transcript) == 0:
                # Every ASR chunk is already assigned; np.argmin on an empty
                # array would raise, so remaining turns get no entry.
                break

            end_time = segment["segment"]["end"]
            # Index of the remaining chunk whose end is closest to the turn end.
            upto_idx = int(np.argmin(np.abs(end_timestamps - end_time)))

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                for chunk in transcript[: upto_idx + 1]:
                    segmented_preds.append({"speaker": segment["speaker"], **chunk})

            # Drop the consumed chunks so the next turn matches against the rest.
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds