#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

"""Export a pyannote segmentation checkpoint to ONNX.

Loads ``./pytorch_model.bin``, checks that it is a 7-class, frame-level,
mono-label model operating on 10-second windows at 16 kHz, and exports it to
``model.onnx``. It then attaches key/value metadata to that file and writes a
dynamically quantized copy, ``model.int8.onnx``.

Set the environment variable ``SHERPA_ONNX_IS_REVAI`` to any non-empty value
to record the Revai reverb-diarization URLs/license in the metadata instead of
the pyannote ones.
"""

import os
from typing import Any, Dict

import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from pyannote.audio import Model
from pyannote.audio.core.task import Problem, Resolution


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Any metadata properties already stored in the model are removed first,
    so afterwards the model carries exactly the entries in ``meta_data``.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs. Each value is stored as ``str(value)``.
    """
    model = onnx.load(filename)

    # Clear existing entries so stale keys from the exporter do not linger.
    while model.metadata_props:
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def _seconds_to_samples(seconds: float, sample_rate: int) -> int:
    """Convert a duration in seconds to the nearest whole number of samples.

    ``round`` is used rather than ``int``. A product such as
    ``seconds * sample_rate`` that should be an exact integer can come out a
    hair below it in floating point, and ``int`` would then truncate away a
    whole sample.
    """
    return int(round(seconds * sample_rate))


@torch.no_grad()
def main():
    # You can download ./pytorch_model.bin from
    # https://hf-mirror.com/csukuangfj/pyannote-models/tree/main/segmentation-3.0
    # or from
    # https://huggingface.co/Revai/reverb-diarization-v1/tree/main
    pt_filename = "./pytorch_model.bin"
    model = Model.from_pretrained(pt_filename)
    model.eval()
    assert model.dimension == 7, model.dimension
    print(model.specifications)

    assert (
        model.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
    ), model.specifications.problem

    assert (
        model.specifications.resolution == Resolution.FRAME
    ), model.specifications.resolution

    assert model.specifications.duration == 10.0, model.specifications.duration

    assert model.audio.sample_rate == 16000, model.audio.sample_rate

    # (batch, num_channels, num_samples)
    assert list(model.example_input_array.shape) == [
        1,
        1,
        16000 * 10,
    ], model.example_input_array.shape

    example_output = model(model.example_input_array)

    # (batch, num_frames, num_classes)
    # NOTE(review): the output shape and receptive field are printed rather
    # than asserted, presumably because they vary between checkpoints.
    # Values that were previously asserted here:
    #   example_output.shape == [1, 589, 7]
    #   receptive_field.step == 0.016875 s (270 samples at 16 kHz)
    #   receptive_field.duration == 0.0619375 s (991 samples at 16 kHz)
    print(example_output.shape)

    print(model.receptive_field.step)
    print(model.receptive_field.duration)
    print(model.receptive_field.step * 16000)
    print(model.receptive_field.duration * 16000)

    opset_version = 14

    filename = "model.onnx"
    torch.onnx.export(
        model,
        model.example_input_array,
        filename,
        opset_version=opset_version,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            # x: (N, num_channels, T), where T is the number of audio samples.
            "x": {0: "N", 2: "T"},
            # y: (N, num_frames, num_classes). num_frames is far smaller than
            # T, so it gets its own symbolic name. Reusing "T" would declare
            # the two axes equal in the exported graph.
            "y": {0: "N", 1: "num_frames"},
        },
    )

    sample_rate = model.audio.sample_rate
    window_size = _seconds_to_samples(model.specifications.duration, sample_rate)
    receptive_field_size = _seconds_to_samples(
        model.receptive_field.duration, sample_rate
    )
    receptive_field_shift = _seconds_to_samples(
        model.receptive_field.step, sample_rate
    )

    # Any non-empty value selects the Revai metadata below.
    is_revai = os.getenv("SHERPA_ONNX_IS_REVAI", "")
    if is_revai == "":
        url_1 = "https://huggingface.co/pyannote/segmentation-3.0"
        url_2 = "https://huggingface.co/csukuangfj/pyannote-models/tree/main/segmentation-3.0"
        license_url = (
            "https://huggingface.co/pyannote/segmentation-3.0/blob/main/LICENSE"
        )
        model_author = "pyannote-audio"
    else:
        url_1 = "https://huggingface.co/Revai/reverb-diarization-v1"
        url_2 = "https://huggingface.co/csukuangfj/sherpa-onnx-reverb-diarization-v1"
        license_url = (
            "https://huggingface.co/Revai/reverb-diarization-v1/blob/main/LICENSE"
        )
        model_author = "Revai"

    meta_data = {
        "num_speakers": len(model.specifications.classes),
        "powerset_max_classes": model.specifications.powerset_max_classes,
        "num_classes": model.dimension,
        "sample_rate": sample_rate,
        "window_size": window_size,
        "receptive_field_size": receptive_field_size,
        "receptive_field_shift": receptive_field_shift,
        "model_type": "pyannote-segmentation-3.0",
        "version": "1",
        "model_author": model_author,
        "maintainer": "k2-fsa",
        "url_1": url_1,
        "url_2": url_2,
        "license": license_url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    # Dynamic quantization: weights are stored as unsigned 8-bit integers.
    # The metadata added above is expected to carry over -- TODO confirm.
    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    main()