import json
from pathlib import Path

import gradio as gr
from coremltools import ComputeUnit
from huggingface_hub import hf_hub_download, HfApi
from transformers.onnx.utils import get_preprocessor

from exporters.coreml import export
from exporters.coreml.features import FeaturesManager
from exporters.coreml.validate import validate_model_outputs


# UI label -> value passed to the exporter. The label lists feed the radio
# buttons; the first label of each list is the default selection.
compute_units_mapping = {
    "All": ComputeUnit.ALL,
    "CPU": ComputeUnit.CPU_ONLY,
    "CPU + GPU": ComputeUnit.CPU_AND_GPU,
    "CPU + NE": ComputeUnit.CPU_AND_NE,
}
compute_units_labels = list(compute_units_mapping.keys())

framework_mapping = {
    "PyTorch": "pt",
    "TensorFlow": "tf",
}
framework_labels = list(framework_mapping.keys())

precision_mapping = {
    "Float32": "float32",
    "Float16 quantization": "float16",
}
precision_labels = list(precision_mapping.keys())

# `None` means "use the tolerance defined by the model's Core ML config".
tolerance_mapping = {
    "Model default": None,
    "1e-2": 1e-2,
    "1e-3": 1e-3,
    "1e-4": 1e-4,
}
tolerance_labels = list(tolerance_mapping.keys())


def error_str(error, title="Error"):
    """Format `error` as a Markdown snippet, or return "" when there is no error."""
    if not error:
        return ""
    return f"#### {title}\n{error}"


def url_to_model_id(model_id_str):
    """
    Turn a Hub URL into a model id; anything that is not a Hub URL is returned as is
    (apart from surrounding whitespace being removed).

    Model ids are either `name` or `owner/name`, so at most the first two path
    segments after the host are kept. This also copes with trailing slashes and
    with deeper URLs such as `.../owner/name/tree/main`.
    """
    prefix = "https://huggingface.co/"
    model_id_str = model_id_str.strip()
    if not model_id_str.startswith(prefix):
        return model_id_str

    segments = [part for part in model_id_str[len(prefix):].split("/") if part]
    return "/".join(segments[:2])


def supported_frameworks(model_id):
    """
    Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
    Only PyTorch and Tensorflow are supported.
    """
    tag_to_label = {"pytorch": "PyTorch", "tf": "TensorFlow"}
    tags = HfApi().model_info(model_id).tags or []
    return sorted(tag_to_label[tag] for tag in tags if tag in tag_to_label)


def on_model_change(model):
    """
    Load the model's `config.json` from the Hub and work out which tasks and
    frameworks can be offered for conversion.

    Always returns four `gr.update` values, in the order of the outputs this
    handler is wired to: settings column, tasks radio, framework radio, error text.
    On failure the settings are hidden and the error message is shown instead.
    """
    model = url_to_model_id(model)

    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")

        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_type = config["model_type"]

        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())
        frameworks = supported_frameworks(model)
    except Exception as e:
        # Every output component still needs a value on the failure path,
        # otherwise the error never reaches the user.
        return (
            gr.update(visible=False),               # Settings column
            gr.update(choices=[], value=None),      # Tasks
            gr.update(visible=False),               # Frameworks
            gr.update(value=error_str(e)),          # Error
        )

    selected_framework = frameworks[0] if frameworks else None
    return (
        gr.update(visible=bool(model_type)),                                  # Settings column
        gr.update(choices=tasks, value=tasks[0] if tasks else None),          # Tasks
        # The framework selector is only shown when there is an actual choice.
        gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),
        gr.update(value=error_str(None)),                                     # Error (cleared)
    )


def convert_model(preprocessor, model, model_coreml_config,
                  compute_units, precision, tolerance, output,
                  use_past=False, seq2seq=None):
    """
    Export `model` to Core ML, save it next to `output` and validate its outputs.

    `model_coreml_config` is called with `model.config`, `use_past` and `seq2seq`
    to build the export configuration. When `seq2seq` is "encoder" or "decoder"
    the saved file name is prefixed with `encoder_` / `decoder_`. When `tolerance`
    is None, the config's `atol_for_validation` is used for validation.
    Exceptions raised by export, saving or validation propagate to the caller.
    """
    coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)

    mlmodel = export(
        preprocessor,
        model,
        coreml_config,
        quantize=precision,
        compute_units=compute_units,
    )

    filename = output
    if seq2seq == "encoder":
        filename = filename.parent / ("encoder_" + filename.name)
    elif seq2seq == "decoder":
        filename = filename.parent / ("decoder_" + filename.name)
    filename = filename.as_posix()

    mlmodel.save(filename)

    if tolerance is None:
        tolerance = coreml_config.atol_for_validation
    validate_model_outputs(coreml_config, preprocessor, model, mlmodel, tolerance)


def convert(model, task, compute_units, precision, tolerance, framework):
    """
    Convert the given Hub model for `task` and return a status string for the UI.

    The arguments are the raw UI values (labels); they are translated through the
    module-level mappings. Returns "Done" on success, or a Markdown error message
    built with `error_str` when anything fails, including invalid selections.
    """
    try:
        # A missing task or framework is possible when model info was not
        # loaded or the model carries no pytorch/tf tag: report it clearly.
        if not task:
            raise ValueError("No task selected. Load the model info and choose a task first.")
        if framework not in framework_mapping:
            raise ValueError(f"Unsupported or missing framework: {framework!r}")

        model_id = url_to_model_id(model)
        compute_units = compute_units_mapping[compute_units]
        precision = precision_mapping[precision]
        tolerance = tolerance_mapping[tolerance]
        framework = framework_mapping[framework]

        # TODO: support legacy format
        output = Path("exported") / model_id / "coreml" / task
        output.mkdir(parents=True, exist_ok=True)
        output = output / f"{precision}_model.mlpackage"

        preprocessor = get_preprocessor(model_id)
        loaded_model = FeaturesManager.get_model_from_feature(task, model_id, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(loaded_model, feature=task)

        if task in ["seq2seq-lm", "speech-seq2seq"]:
            # Convert encoder / decoder
            for part in ("encoder", "decoder"):
                convert_model(
                    preprocessor,
                    loaded_model,
                    model_coreml_config,
                    compute_units,
                    precision,
                    tolerance,
                    output,
                    seq2seq=part,
                )
        else:
            convert_model(
                preprocessor,
                loaded_model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
            )

        # TODO: push to hub, whatever
        return "Done"
    except Exception as e:
        return error_str(e)


DESCRIPTION = """
## Convert a transformers model to Core ML

With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.

Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).

After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.
"""

# NOTE(review): the nesting of the layout below was reconstructed from the
# order of the statements; confirm it against the intended page layout.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1. Load model info")
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name or URL, such as apple/mobilevit-small",
                placeholder="distilbert-base-uncased",
                value="distilbert-base-uncased",
            )
            btn_get_tasks = gr.Button("Load")
        with gr.Column(scale=3):
            # Hidden until model info has been loaded successfully.
            with gr.Column(visible=False) as group_settings:
                gr.Markdown("## 2. Select Task")
                radio_tasks = gr.Radio(label="Choose the task for the converted model.")
                gr.Markdown("The `default` task is suitable for feature extraction.")
                radio_framework = gr.Radio(
                    visible=False,
                    label="Framework",
                    choices=framework_labels,
                    value=framework_labels[0],
                )
                radio_compute = gr.Radio(
                    label="Compute Units",
                    choices=compute_units_labels,
                    value=compute_units_labels[0],
                )
                radio_precision = gr.Radio(
                    label="Precision",
                    choices=precision_labels,
                    value=precision_labels[0],
                )
                radio_tolerance = gr.Radio(
                    label="Absolute Tolerance for Validation",
                    choices=tolerance_labels,
                    value=tolerance_labels[0],
                )
                btn_convert = gr.Button("Convert")
                gr.Markdown("Conversion will take a few minutes.")

    error_output = gr.Markdown(label="Output")

    # The handler returns one update per output, in this order.
    btn_get_tasks.click(
        fn=on_model_change,
        inputs=input_model,
        outputs=[group_settings, radio_tasks, radio_framework, error_output],
        queue=False,
        scroll_to_output=True,
    )

    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
        outputs=error_output,
        scroll_to_output=True,
    )

    # gr.HTML("""
    #   Footer