from typing import List, Optional, Union

import numpy as np
import torch
from pyannote.audio import Pipeline
from transformers import pipeline


class ASRDiarizationPipeline:
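    """Combine a `transformers` automatic-speech-recognition pipeline with a `pyannote.audio`
    speaker diarization pipeline, aligning the transcript with the detected speaker turns.

    Example sketch (assumes ``audio`` is a mono float numpy array and that you are logged in to the
    Hugging Face Hub, since the default diarization checkpoint is gated):

        >>> pipe = ASRDiarizationPipeline.from_pretrained()
        >>> pipe(audio, sampling_rate=16000)
    """
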
    def __init__(
        self,
        asr_pipeline,
        diarization_pipeline,
    ):
        self.asr_pipeline = asr_pipeline
        self.diarization_pipeline = diarization_pipeline

    @classmethod
    def from_pretrained(
        cls,
        asr_model: Optional[str] = "openai/whisper-small",
        diarizer_model: Optional[str] = "pyannote/speaker-diarization",
        chunk_length_s: int = 30,
        use_auth_token: Union[str, bool] = True,
        **kwargs,
    ):
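        """Load the ASR pipeline from `transformers` and the diarization pipeline from
        `pyannote.audio`, both fetched from the Hugging Face Hub."""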
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            chunk_length_s=chunk_length_s,
            use_auth_token=use_auth_token,
            **kwargs,
        )
        diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
        return cls(asr_pipeline, diarization_pipeline)

    def __call__(
        self,
        inputs: Union[np.ndarray, List[np.ndarray]],
        sampling_rate: int,
        group_by_speaker: bool = True,
        **kwargs,
    ):
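        """Transcribe `inputs` and attribute each part of the transcript to a diarized speaker.

        Returns a list of dicts with `speaker`, `text` and `timestamp` keys: one entry per speaker
        turn if `group_by_speaker=True`, otherwise one entry per ASR chunk.
        """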
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"Expected a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(f"Expected a single channel audio as input, got `{len(inputs.shape)}` channels.")

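        # pyannote expects a (channel, time) float tensor, so add a leading channel dimension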
        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        diarization = self.diarization_pipeline(
            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
            **kwargs,
        )
        del diarizer_inputs

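        # Each diarization entry carries a "segment" with "start"/"end" times and a speaker "label"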
        segments = diarization.for_json()["content"]

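        # Combine consecutive segments from the same speaker into contiguous speaker turns; each turn
        # runs from the speaker's first segment to the start of the next speaker's first segment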
        new_segments = []
        prev_segment = cur_segment = segments[0]

        for i in range(1, len(segments)):
            cur_segment = segments[i]

            # Speaker change: close the previous speaker's turn at the start of the current segment
            if cur_segment["label"] != prev_segment["label"]:
                new_segments.append(
                    {
                        "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                        "speaker": prev_segment["label"],
                    }
                )
                prev_segment = segments[i]

        new_segments.append(
            {
                "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
                "speaker": prev_segment["label"],
            }
        )

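        # Transcribe the full audio with chunk-level timestamps so chunks can be matched to speaker turns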
        asr_out = self.asr_pipeline(
            {"array": inputs, "sampling_rate": sampling_rate},
            return_timestamps=True,
            **kwargs,
        )
        transcript = asr_out["chunks"]

        # The final Whisper chunk can come back with an open-ended (None) end timestamp; fall back to
        # infinity so that it is only ever matched by the last speaker turn
        end_timestamps = np.array(
            [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else float("inf") for chunk in transcript]
        )
        segmented_preds = []

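        # Assign to each speaker turn all remaining ASR chunks up to the chunk whose end timestamp is
        # closest to the end of the turn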
        for segment in new_segments:
            end_time = segment["segment"]["end"]
            upto_idx = np.argmin(np.abs(end_timestamps - end_time))

            if group_by_speaker:
                segmented_preds.append(
                    {
                        "speaker": segment["speaker"],
                        "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                        "timestamp": {
                            "start": transcript[0]["timestamp"][0],
                            "end": transcript[upto_idx]["timestamp"][1],
                        },
                    }
                )
            else:
                for i in range(upto_idx + 1):
                    segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

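            # Drop the chunks (and their end timestamps) already assigned to this speaker turn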
            transcript = transcript[upto_idx + 1 :]
            end_timestamps = end_timestamps[upto_idx + 1 :]

        return segmented_preds