File size: 704 Bytes
d9b2258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import click
import os
import onnx
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
from sparseml.onnx.utils import ONNXGraph
@click.command()
@click.option('--input-file', help='Path to the input ONNX model file')
@click.option('--output-file', help='Output path for the modified model')
def modify_model(input_file, output_file):
    model = onnx.load(input_file, load_external_data=False)
    model = KeyValueCacheInjector(model_path=os.path.dirname(input_file)).apply(model)
    graph = ONNXGraph(model)
    graph.delete_orphaned_node_branches()
    onnx.save(model, output_file)
    print(f"Modified model saved to: {output_file}")
if __name__ == '__main__':
    modify_model()