Spaces:
Runtime error
Runtime error
from typing import List, Optional, Union | |
import numpy as np | |
import torch | |
from pyannote.audio import Pipeline | |
from transformers import pipeline | |
class ASRDiarizationPipeline:
    """Combine a speech-recognition pipeline with a speaker-diarization pipeline.

    Calling an instance runs the diarizer and the ASR model on the same audio
    and attributes each transcribed chunk to a speaker by matching the end of
    every speaker turn to the closest ASR chunk end timestamp.
    """

    def __init__(
        self,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        **kwargs,
    ):
        """Load both underlying pipelines.

        Args:
            asr_model: Model identifier passed to ``transformers.pipeline``.
            diarizer_model: Identifier passed to ``Pipeline.from_pretrained``.
            chunk_length_s: Forwarded to the ASR pipeline as ``chunk_length_s``.
            **kwargs: Extra keyword arguments forwarded to the ASR pipeline
                constructor only.
        """
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            use_auth_token=True,
            chunk_length_s=chunk_length_s,
            **kwargs,
        )
        self.diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=True)

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
        """Transcribe ``inputs`` and label the transcript with speakers.

        Args:
            inputs: One-dimensional numpy array holding the audio samples.
                Anything that is not an ``np.ndarray`` is rejected.
            sampling_rate: Sampling rate of ``inputs``; passed to both pipelines.
            group_by_speaker: If True, emit one entry per speaker turn with the
                chunk texts concatenated. If False, emit one entry per ASR
                chunk with a ``"speaker"`` key added.
            **kwargs: Forwarded unchanged to BOTH the diarization call and the
                ASR call.

        Returns:
            A list of dicts. With ``group_by_speaker=True`` each dict has the
            keys ``"speaker"``, ``"text"`` and ``"timestamp"`` (a dict with
            ``"start"`` and ``"end"``). Otherwise each dict is the ASR chunk
            plus a ``"speaker"`` key. The list is empty when the diarizer
            reports no segments or the ASR pipeline returns no chunks.

        Raises:
            ValueError: If ``inputs`` is not a numpy array or is not 1-D.
        """
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(f"Expected a single channel audio as input, got `{len(inputs.shape)}` channels.")

        # The diarizer is called with a float tensor that has an extra leading axis.
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs  # release the tensor copy before running the ASR model

        segments = diarization.for_json()["content"]
        if not segments:
            # No speaker turns were reported, so there is nothing to attribute
            # text to. Previously this crashed with IndexError on segments[0].
            return []

        speaker_turns = self._merge_speaker_turns(segments)

        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )

        return self._align_transcript(
            speaker_turns,
            asr_out["chunks"],
            audio_duration=len(inputs) / sampling_rate,
            group_by_speaker=group_by_speaker,
        )

    @staticmethod
    def _merge_speaker_turns(segments: List[dict]) -> List[dict]:
        """Collapse consecutive same-label segments into single speaker turns.

        ``segments`` must be non-empty. Each turn ends where the next speaker's
        first segment starts; the final turn ends at the last segment's end.
        """
        turns = []
        turn_first = segments[0]
        current = segments[0]

        for current in segments[1:]:
            if current["label"] != turn_first["label"]:
                turns.append(
                    {
                        "segment": {
                            "start": turn_first["segment"]["start"],
                            "end": current["segment"]["start"],
                        },
                        "speaker": turn_first["label"],
                    }
                )
                turn_first = current

        turns.append(
            {
                "segment": {
                    "start": turn_first["segment"]["start"],
                    "end": current["segment"]["end"],
                },
                "speaker": turn_first["label"],
            }
        )
        return turns

    @staticmethod
    def _align_transcript(
        speaker_turns: List[dict],
        transcript: List[dict],
        audio_duration: float,
        group_by_speaker: bool,
    ) -> List[dict]:
        """Assign ASR chunks to speaker turns by nearest end timestamp."""
        # A chunk may carry ``None`` as its end timestamp (presumably when the
        # ASR model could not predict a closing timestamp - TODO confirm for
        # the model in use). Treat it as the end of the audio for alignment
        # only; the chunk itself is returned untouched.
        end_timestamps = np.array(
            [
                audio_duration if chunk["timestamp"][-1] is None else chunk["timestamp"][-1]
                for chunk in transcript
            ]
        )

        segmented_preds = []
        for turn in speaker_turns:
            if len(end_timestamps) == 0:
                # Every chunk has already been assigned; np.argmin would raise
                # ValueError on an empty array.
                break

            end_time = turn["segment"]["end"]
            # Index of the remaining chunk whose end is closest to the turn end.
            upto_idx = int(np.argmin(np.abs(end_timestamps - end_time)))
            assigned = transcript[: upto_idx + 1]

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": turn["speaker"],
                        "text": "".join(chunk["text"] for chunk in assigned),
                        "timestamp": {
                            "start": assigned[0]["timestamp"][0],
                            "end": assigned[-1]["timestamp"][1],
                        },
                    }
                )
            else:
                segmented_preds.extend({"speaker": turn["speaker"], **chunk} for chunk in assigned)

            # Drop the consumed chunks so the next turn starts after them.
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds