from typing import Optional

import numpy as np
import torch
from pyannote.audio import Pipeline
from transformers import pipeline


class ASRDiarizationPipeline:
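    """Transcribe audio with Whisper and attribute each portion of the transcript to a speaker.

    The diarization pipeline yields speaker turns and the ASR pipeline yields a transcript
    with chunk-level timestamps; ``__call__`` aligns the two by assigning transcript chunks
    to the speaker segment whose end time they fall closest to. Both checkpoints are loaded
    with ``use_auth_token=True``, so you must be logged in to the Hugging Face Hub (the
    default pyannote checkpoint is gated).
    """
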
    def __init__(
        self,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        **kwargs,
    ):
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            use_auth_token=True,
            chunk_length_s=chunk_length_s,
            **kwargs,
        )
        self.diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=True)

    def __call__(
        self,
        inputs: np.ndarray,
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
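        """Diarize and transcribe a single-channel waveform, returning speaker-attributed text.

        Returns a list of dicts with ``speaker``, ``text`` and ``timestamp`` keys. When
        ``group_by_speaker`` is True, all transcript chunks falling within one speaker turn
        are merged into a single entry; otherwise one entry per ASR chunk is returned.
        """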
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if inputs.ndim != 1:
            raise ValueError(f"Expected single-channel audio as input, got an array with `{inputs.ndim}` dimensions.")

        # run speaker diarization; pyannote expects a (channel, time) float tensor plus the sample rate
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs

        # flatten the diarization annotation into a list of {"segment", "track", "label"} dicts
        segments = diarization.for_json()["content"]

        # merge consecutive segments that share the same speaker label
        new_segments = []
        prev_segment = cur_segment = segments[0]

        for i in range(1, len(segments)):
            cur_segment = segments[i]

            # a speaker change closes the previous speaker's segment at the start of the current one
            if cur_segment["label"] != prev_segment["label"]:
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]

        # close the final speaker segment, which runs to the end of the last diarized turn
        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )

        # transcribe the full audio, keeping chunk-level timestamps for alignment with the diarization
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # end time of each transcript chunk; the final chunk's end timestamp can be `None`
        # when the audio is cut off mid-segment, so fall back to the total audio duration
        end_timestamps = np.array(
            [
                chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else len(inputs) / sampling_rate
                for chunk in transcript
            ]
        )
        segmented_preds = []

        # for each speaker segment, attribute every transcript chunk up to the one whose
        # end time lies closest to the segment's end time
        for segment in new_segments:
            end_time = segment["segment"]["end"]
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            # either merge everything assigned to this segment into one entry per speaker turn,
            # or keep one entry per ASR chunk
            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

            # drop the chunks already assigned so the next segment aligns against the remainder
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds
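

# Example usage (a minimal sketch, not part of the pipeline above; "meeting.wav" and the
# soundfile dependency are illustrative assumptions):
#
#     import soundfile as sf
#
#     pipe = ASRDiarizationPipeline()
#     audio, sr = sf.read("meeting.wav", dtype="float32")  # 1-d mono waveform
#     for utterance in pipe(audio, sampling_rate=sr):
#         print(utterance["speaker"], utterance["text"])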