|
|
|
from tests import test_cases |
|
from evaluate.utils.logging import get_logger |
|
from evaluate.utils import ( |
|
infer_gradio_input_types, |
|
parse_gradio_data, |
|
json_to_string_type, |
|
parse_readme, |
|
parse_test_cases, |
|
) |
|
from pathlib import Path |
|
|
|
import evaluate |
|
import sys |
|
import evaluate |
|
|
|
# Module-level logger from evaluate's logging utilities, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
def launch_gradio_widget(metric, test_cases):
    """Build a Gradio widget for `metric` and launch it when run as a script.

    Args:
        metric: A loaded ``evaluate`` module. Its ``features`` (a mapping of
            feature name to type, or a list of such mappings), ``name``,
            ``info.description`` and ``compute`` are used to build the UI.
        test_cases: Raw test cases handed to ``parse_test_cases``. Each parsed
            case becomes one clickable example, and the first one also
            pre-fills the input table.

    Returns:
        The ``gr.Interface`` that was built. When this module is executed as
        ``__main__``, the interface is launched before being returned.

    Raises:
        ImportError: If ``gradio`` is not installed.
    """
    try:
        import gradio as gr
    except ImportError:
        logger.error(
            "To create a metric widget with Gradio make sure gradio is installed."
        )
        raise

    # sys.path[0] is the directory of the script that started the interpreter;
    # the README rendered below the widget is read from there.
    local_path = Path(sys.path[0])

    # `features` may be a list of feature mappings; use the first one.
    if isinstance(metric.features, list):
        feature_names, feature_types = zip(*metric.features[0].items())
    else:
        feature_names, feature_types = zip(*metric.features.items())
    gradio_input_types = infer_gradio_input_types(feature_types)

    parsed_test_cases = parse_test_cases(test_cases, feature_names, gradio_input_types)
    if parsed_test_cases:
        default_table = parsed_test_cases[0]
        # Gradio expects one inner list per example, holding one value per
        # input component. There is a single Dataframe input here, so each
        # test case must be wrapped on its own rather than passing all test
        # cases as one example.
        examples = [[case] for case in parsed_test_cases]
    else:
        # No usable test cases: start with an empty table and no examples.
        default_table = None
        examples = None

    def compute(data):
        return metric.compute(**parse_gradio_data(data, gradio_input_types))

    demo = gr.Interface(
        fn=compute,
        inputs=gr.Dataframe(
            value=default_table,
            headers=list(feature_names),
            col_count=len(feature_names),
            datatype=json_to_string_type(gradio_input_types),
        ),
        outputs=gr.Textbox(label=metric.name),
        description=(
            metric.info.description
            + "\nISCO codes must be wrapped in double quotes."
        ),
        title=f"Metric: {metric.name}",
        article=parse_readme(local_path / "README.md"),
        examples=examples,
    )

    if __name__ == "__main__":
        demo.launch()
    # Always hand the interface back so callers get a consistent return value.
    return demo
|
|
|
|
|
# Load the `danieldux/isco_hierarchical_accuracy` metric via `evaluate.load`.
# This runs at import time, so `module` is available to importers as well.
module = evaluate.load(
    "danieldux/isco_hierarchical_accuracy"
)

if __name__ == "__main__":
    # Only start the widget when this file is executed directly.
    launch_gradio_widget(metric=module, test_cases=test_cases)
|
|