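"""Gradio app for converting a 🤗 Transformers model to Core ML.

A thin UI over the Hugging Face `exporters` package: pick a model, task,
compute units, precision, and validation tolerance, then export and
validate an `.mlpackage`.
"""
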
import json
from pathlib import Path

import gradio as gr
from coremltools import ComputeUnit
from huggingface_hub import hf_hub_download, HfApi
from transformers.onnx.utils import get_preprocessor

from exporters.coreml import export
from exporters.coreml.features import FeaturesManager
from exporters.coreml.validate import validate_model_outputs

compute_units_mapping = {
    "All": ComputeUnit.ALL,
    "CPU": ComputeUnit.CPU_ONLY,
    "CPU + GPU": ComputeUnit.CPU_AND_GPU,
    "CPU + NE": ComputeUnit.CPU_AND_NE,
}
compute_units_labels = list(compute_units_mapping.keys())

framework_mapping = {
    "PyTorch": "pt",
    "TensorFlow": "tf",
}
framework_labels = list(framework_mapping.keys())

precision_mapping = {
    "Float32": "float32",
    "Float16 quantization": "float16",
}
precision_labels = list(precision_mapping.keys())

tolerance_mapping = {
    "Model default": None,
    "1e-2": 1e-2,
    "1e-3": 1e-3,
    "1e-4": 1e-4,
}
tolerance_labels = list(tolerance_mapping.keys())


def error_str(error, title="Error"):
    return f"""#### {title}
{error}""" if error else ""


def url_to_model_id(model_id_str):
    # Accept either a bare model id ("org/model") or a full Hub URL.
    if not model_id_str.startswith("https://huggingface.co/"):
        return model_id_str
    return "/".join(model_id_str.rstrip("/").split("/")[-2:])
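
# For instance (illustrative only, not executed by the app):
#   url_to_model_id("https://huggingface.co/apple/mobilevit-small")  # -> "apple/mobilevit-small"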


def supported_frameworks(model_id):
    """
    Return a list of the frameworks (`PyTorch` or `TensorFlow`) with available
    weights for a given model_id, based on the model's tags on the Hub.
    """
    api = HfApi()
    model_info = api.model_info(model_id)
    frameworks = [tag for tag in model_info.tags if tag in ["pytorch", "tf"]]
    return sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks])
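
# e.g. a checkpoint with (hypothetical) tags ["pytorch", "tf", "bert"]
# would yield ["PyTorch", "TensorFlow"].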


def on_model_change(model):
    model = url_to_model_id(model)
    tasks = None
    error = None

    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")

        with open(config_file, "r") as f:
            config = json.load(f)
        model_type = config["model_type"]

        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())

        frameworks = supported_frameworks(model)
        selected_framework = frameworks[0] if len(frameworks) > 0 else None
        return (
            gr.update(visible=bool(model_type)),
            gr.update(choices=tasks, value=tasks[0] if tasks else None),
            gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),
            gr.update(value=error_str(error)),
        )
    except Exception as e:
        # On failure, keep the settings pane hidden and surface the error.
        return (
            gr.update(visible=False),
            gr.update(),
            gr.update(),
            gr.update(value=error_str(e)),
        )


def convert_model(preprocessor, model, model_coreml_config, compute_units, precision, tolerance, output, use_past=False, seq2seq=None):
    coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)

    mlmodel = export(
        preprocessor,
        model,
        coreml_config,
        quantize=precision,
        compute_units=compute_units,
    )

    # Seq2seq models are exported in two pieces; prefix the filename accordingly.
    filename = output
    if seq2seq == "encoder":
        filename = filename.parent / ("encoder_" + filename.name)
    elif seq2seq == "decoder":
        filename = filename.parent / ("decoder_" + filename.name)
    filename = filename.as_posix()

    mlmodel.save(filename)

    if tolerance is None:
        tolerance = coreml_config.atol_for_validation
    validate_model_outputs(coreml_config, preprocessor, model, mlmodel, tolerance)
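
# Illustrative on-disk layout: for a seq2seq model exported at Float32, the
# encoder/decoder halves end up as sibling files of the requested output, e.g.
#   exported/<model>/coreml/seq2seq-lm/encoder_float32_model.mlpackage
#   exported/<model>/coreml/seq2seq-lm/decoder_float32_model.mlpackage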


def convert(model, task, compute_units, precision, tolerance, framework):
    model = url_to_model_id(model)
    compute_units = compute_units_mapping[compute_units]
    precision = precision_mapping[precision]
    tolerance = tolerance_mapping[tolerance]
    framework = framework_mapping[framework]

    output = Path("exported") / model / "coreml" / task
    output.mkdir(parents=True, exist_ok=True)
    output = output / f"{precision}_model.mlpackage"

    try:
        preprocessor = get_preprocessor(model)
        model = FeaturesManager.get_model_from_feature(task, model, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(model, feature=task)

        if task in ["seq2seq-lm", "speech-seq2seq"]:
            # Export encoder and decoder as separate Core ML models.
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq="encoder",
            )
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq="decoder",
            )
        else:
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
            )

        return "Done"
    except Exception as e:
        return error_str(e)
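
# Minimal smoke test of the conversion path without the UI -- a sketch; the
# label strings must be keys of the mapping dicts defined above:
#   print(convert("distilbert-base-uncased", "default", "CPU", "Float32", "Model default", "PyTorch"))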


DESCRIPTION = """
## Convert a transformers model to Core ML

With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.

Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).

After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.
"""


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1. Load model info")
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name or URL, such as apple/mobilevit-small",
                placeholder="distilbert-base-uncased",
                value="distilbert-base-uncased",
            )
            btn_get_tasks = gr.Button("Load")
        with gr.Column(scale=3):
            with gr.Column(visible=False) as group_settings:
                gr.Markdown("## 2. Select Task")
                radio_tasks = gr.Radio(label="Choose the task for the converted model.")
                gr.Markdown("The `default` task is suitable for feature extraction.")
                radio_framework = gr.Radio(
                    visible=False,
                    label="Framework",
                    choices=framework_labels,
                    value=framework_labels[0],
                )
                radio_compute = gr.Radio(
                    label="Compute Units",
                    choices=compute_units_labels,
                    value=compute_units_labels[0],
                )
                radio_precision = gr.Radio(
                    label="Precision",
                    choices=precision_labels,
                    value=precision_labels[0],
                )
                radio_tolerance = gr.Radio(
                    label="Absolute Tolerance for Validation",
                    choices=tolerance_labels,
                    value=tolerance_labels[0],
                )
                btn_convert = gr.Button("Convert")
                gr.Markdown("Conversion will take a few minutes.")

    error_output = gr.Markdown(label="Output")

    btn_get_tasks.click(
        fn=on_model_change,
        inputs=input_model,
        outputs=[group_settings, radio_tasks, radio_framework, error_output],
        queue=False,
        scroll_to_output=True,
    )

    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
        outputs=error_output,
        scroll_to_output=True,
    )

demo.queue(concurrency_count=1, max_size=10)
demo.launch(debug=True, share=False)