#!/usr/bin/env python3
# Copyright      2024  Xiaomi Corp.        (authors: Fangjun Kuang)

import os
from typing import Any, Dict

import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from pyannote.audio import Model
from pyannote.audio.core.task import Problem, Resolution


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

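    # Remove any existing metadata entries before adding the new ones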
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    # You can download ./pytorch_model.bin from
    # https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
    # or from
    # https://huggingface.co/Revai/reverb-diarization-v1/tree/main
    pt_filename = "./pytorch_model.bin"
    model = Model.from_pretrained(pt_filename)
    model.eval()
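    # The model outputs 7 classes: the powerset encoding of 3 speakers
    # with at most 2 active at once, C(3,0) + C(3,1) + C(3,2) = 1 + 3 + 3 = 7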
    assert model.dimension == 7, model.dimension
    print(model.specifications)

    assert (
        model.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
    ), model.specifications.problem

    assert (
        model.specifications.resolution == Resolution.FRAME
    ), model.specifications.resolution

    assert model.specifications.duration == 10.0, model.specifications.duration

    assert model.audio.sample_rate == 16000, model.audio.sample_rate

    # (batch, num_channels, num_samples)
    assert list(model.example_input_array.shape) == [
        1,
        1,
        16000 * 10,
    ], model.example_input_array.shape

    example_output = model(model.example_input_array)

    # (batch, num_frames, num_classes)
    #  assert list(example_output.shape) == [1, 589, 7], example_output.shape
    print(example_output.shape)

    print(model.receptive_field.step)
    print(model.receptive_field.duration)
    print(model.receptive_field.step * 16000)
    print(model.receptive_field.duration * 16000)

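    # The values below hold for pyannote/segmentation-3.0; they are kept
    # commented out since other checkpoints (e.g. reverb-diarization-v1)
    # may use a different receptive field.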
    #  assert model.receptive_field.step == 0.016875, model.receptive_field.step
    #  assert model.receptive_field.duration == 0.0619375, model.receptive_field.duration
    #  assert model.receptive_field.step * 16000 == 270, model.receptive_field.step * 16000
    #  assert model.receptive_field.duration * 16000 == 991, (
    #      model.receptive_field.duration * 16000
    #  )

    opset_version = 14

    filename = "model.onnx"
    torch.onnx.export(
        model,
        model.example_input_array,
        filename,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 2: "T"},
            "y": {0: "N", 1: "T"},
        },
    )

    sample_rate = model.audio.sample_rate

    window_size = int(model.specifications.duration) * sample_rate
    receptive_field_size = int(model.receptive_field.duration * sample_rate)
    receptive_field_shift = int(model.receptive_field.step * sample_rate)

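    # Set SHERPA_ONNX_IS_REVAI to a non-empty value when exporting the
    # Revai checkpoint so the metadata below points to the right URLs.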
    is_revai = os.getenv("SHERPA_ONNX_IS_REVAI", "")
    if is_revai == "":
        url_1 = "https://huggingface.co/pyannote/segmentation-3.0"
        url_2 = "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0"
        license_url = (
            "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE"
        )
        model_author = "pyannote-audio"
    else:
        url_1 = "https://huggingface.co/Revai/reverb-diarization-v1"
        url_2 = "https://huggingface.co/csukuangfj/sherpa-onnx-reverb-diarization-v1"
        license_url = (
            "https://huggingface.co/Revai/reverb-diarization-v1/blob/main/LICENSE"
        )
        model_author = "Revai"

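    # Metadata read by the downstream runtime (e.g. sherpa-onnx) to
    # configure framing and powerset decoding.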
    meta_data = {
        "num_speakers": len(model.specifications.classes),
        "powerset_max_classes": model.specifications.powerset_max_classes,
        "num_classes": model.dimension,
        "sample_rate": sample_rate,
        "window_size": window_size,
        "receptive_field_size": receptive_field_size,
        "receptive_field_shift": receptive_field_shift,
        "model_type": "pyannote-segmentation-3.0",
        "version": "1",
        "model_author": model_author,
        "maintainer": "k2-fsa",
        "url_1": url_1,
        "url_2": url_2,
        "license": license_url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)
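
    # Optional sanity check (a minimal sketch, not part of the original
    # export recipe): run the exported model with onnxruntime and compare
    # it against the PyTorch output computed above.
    import onnxruntime as ort

    session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
    (onnx_y,) = session.run(["y"], {"x": model.example_input_array.numpy()})
    torch.testing.assert_close(
        torch.from_numpy(onnx_y), example_output, atol=1e-3, rtol=1e-3
    )
    print("onnxruntime and PyTorch outputs match")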

    print("Generate int8 quantization models")

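    # Dynamic quantization stores the weights as 8-bit integers and
    # dequantizes them on the fly at inference time, roughly quartering
    # the model size with no calibration data required.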
    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    main()