whisper-speaker-diarization / asr_diarizer.py
from typing import List, Optional, Union

import numpy as np
import torch
from pyannote.audio import Pipeline
from transformers import pipeline


class ASRDiarizationPipeline:
    def __init__(
        self,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        **kwargs,
    ):
        # Chunked long-form transcription with Whisper via the Transformers ASR pipeline.
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            use_auth_token=True,
            chunk_length_s=chunk_length_s,
            **kwargs,
        )
        # Pretrained pyannote speaker-diarization pipeline (gated model, requires an authenticated Hugging Face token).
        self.diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=True)

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(f"Expected single-channel audio as input, got `{len(inputs.shape)}` channels.")

        # pyannote expects a (channel, time) float tensor.
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs
        # Merge consecutive diarization segments that share the same speaker label
        # into one contiguous span per speaker turn.
        segments = diarization.for_json()["content"]
        new_segments = []
        prev_segment = cur_segment = segments[0]

        for i in range(1, len(segments)):
            cur_segment = segments[i]
            # The speaker changed: close the previous span at the start of the current segment.
            if cur_segment["label"] != prev_segment["label"] and i < len(segments):
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]

        # Close the final speaker span.
        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )
        # Transcribe the full audio with chunk-level timestamps.
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # Align transcript chunks to speaker spans: for each span, take every chunk up to
        # the one whose end timestamp is closest to the span's end time.
        end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
        segmented_preds = []

        for segment in new_segments:
            end_time = segment["segment"]["end"]
            # Index of the transcript chunk whose end timestamp is nearest to the span end.
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            if group_by_speaker:
                # Concatenate all chunks in the span into one utterance per speaker turn.
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                # Keep chunk-level granularity, tagging each chunk with its speaker.
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

            # Drop the chunks already assigned and move on to the next speaker span.
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds
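

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how the class above might be driven. It assumes a local
# mono 16 kHz file "audio.wav", the `soundfile` package for loading audio, and a
# cached Hugging Face token for the gated pyannote model; these are assumptions
# made for illustration only.
if __name__ == "__main__":
    import soundfile as sf

    diarization_pipeline = ASRDiarizationPipeline(asr_model="openai/whisper-small")
    audio, sr = sf.read("audio.wav", dtype="float32")  # 1-D float32 waveform and its sample rate
    outputs = diarization_pipeline(audio, sampling_rate=sr, group_by_speaker=True)
    for segment in outputs:
        print(f"{segment['speaker']}: {segment['text']}")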