Spaces:
Sleeping
Sleeping
from itertools import count, islice | |
from typing import Any, Iterable | |
import gradio as gr | |
import pandas as pd | |
import requests | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
# Maximum number of rows shown in the input preview table.
NUM_ROWS_PREVIEW = 5
# One HTTP session shared by every datasets-server request made by this app.
session = requests.Session()
# A blank table with three empty columns named "1", "2" and "3".
empty_dataframe = pd.DataFrame({"1": [], "2": [], "3": []})
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 Dataset ReWriter ✍️✨\n\n"
        "Adjust, translate or transform completely existing datasets.\n\n"
    )
    with gr.Row():
        with gr.Column(scale=3):
            # Hub dataset picker.
            dataset_search = HuggingfaceHubSearch(
                label="Hub Dataset ID",
                placeholder="Search for dataset id on Huggingface",
                search_type="dataset",
            )
            # Both dropdowns start hidden; they are only shown when a dataset
            # has more than one subset / split to choose from.
            subset_dropdown = gr.Dropdown(info="Subset", show_label=False, visible=False)
            split_dropdown = gr.Dropdown(info="Split", show_label=False, visible=False)
            input_query = gr.Textbox(label="Enter the adjustment or transformation to apply to the dataset:")
            rewrite_button = gr.Button("ReWrite Dataset", variant="primary")
            gr.Markdown("### Input")
            # Read-only tables: the first rows of the source split, and the rewritten result.
            input_preview = gr.DataFrame(interactive=False, wrap=True)
            gr.Markdown("### Output")
            output_preview = gr.DataFrame(interactive=False, wrap=True)
            # Disabled for now; there is no rewritten dataset to save yet.
            save_button = gr.Button("Save ReWritten Dataset", interactive=False)
############ | |
# | |
# Utils | |
# | |
########### | |
def stream_rows(dataset: str, subset: str, split: str, batch_size: int = 100) -> Iterable[dict[str, Any]]:
    """Lazily yield the rows of one split of a Hub dataset from the datasets-server.

    Rows are requested from the ``/rows`` endpoint ``batch_size`` at a time and
    yielded one by one. A consumer that stops early (e.g. via ``itertools.islice``)
    therefore only triggers the requests it actually needs.

    Args:
        dataset: Hub dataset id.
        subset: Subset (config) name.
        split: Split name.
        batch_size: Number of rows requested per call.

    Yields:
        The ``"row"`` mapping of each row item returned by the API.

    Raises:
        RuntimeError: If an API response contains an ``"error"`` key.
    """
    offset = 0
    while True:
        rows_resp = session.get(
            "https://datasets-server.huggingface.co/rows",
            # Let requests URL-encode the values: subset/split names containing
            # "&", "+", spaces, ... would corrupt a hand-built query string.
            params={
                "dataset": dataset,
                "config": subset,
                "split": split,
                "offset": offset,
                "length": batch_size,
            },
            timeout=10,
        ).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        rows = rows_resp["rows"]
        if not rows:
            break
        for row_item in rows:
            yield row_item["row"]
        # Advance by what was actually returned rather than by batch_size, so a
        # server-side cap on `length` cannot make us skip rows.
        offset += len(rows)
############ | |
# | |
# Events | |
# | |
########### | |
def _resolve_dataset_selection(dataset: str, default_subset: str, default_split: str) -> "tuple[str | None, str | None, dict]":
    """Resolve which subset and split of ``dataset`` to preview.

    Queries the datasets-server ``/info`` endpoint for the dataset's subsets
    (configs) and their splits. It keeps ``default_subset`` / ``default_split``
    when they exist and otherwise falls back to the first one listed.

    Args:
        dataset: Hub dataset id as provided by the search box (may be empty or None).
        default_subset: Preferred subset name.
        default_split: Preferred split name.

    Returns:
        ``(subset, split, updates)``, where ``updates`` maps the subset and split
        dropdowns to their new state. ``subset`` and ``split`` are ``None``, and
        both dropdowns are hidden, when the dataset cannot be resolved.
    """

    def hidden_dropdowns() -> dict:
        # Updates that hide both selection dropdowns.
        return {
            subset_dropdown: gr.Dropdown(visible=False),
            split_dropdown: gr.Dropdown(visible=False),
        }

    # Only ids containing a "/" (namespace/name) are looked up. An empty or None
    # search value is treated the same way instead of crashing on .strip().
    if not dataset or "/" not in dataset.strip().strip("/"):
        return None, None, hidden_dropdowns()
    info_resp = session.get(
        "https://datasets-server.huggingface.co/info",
        params={"dataset": dataset},  # URL-encoded by requests
        timeout=3,
    ).json()
    dataset_info = info_resp.get("dataset_info")
    if "error" in info_resp or not dataset_info:
        return None, None, hidden_dropdowns()
    subsets: list[str] = list(dataset_info)
    subset = default_subset if default_subset in subsets else subsets[0]
    # The /info payload keys "splits" by split name (a mapping, not a list), so
    # take its keys explicitly: indexing the mapping with [0] raises KeyError.
    splits: list[str] = list(dataset_info[subset].get("splits") or [])
    if not splits:
        return None, None, hidden_dropdowns()
    split = default_split if default_split in splits else splits[0]
    return subset, split, {
        subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
        split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
    }
def _show_input_preview(dataset: str, default_subset: str, default_split: str) -> dict:
    """Build the UI updates that preview the first rows of a dataset split.

    Resolves the subset/split to show via ``_resolve_dataset_selection``. It then
    fetches up to ``NUM_ROWS_PREVIEW`` rows and renders every cell value as a string.

    Args:
        dataset: Hub dataset id.
        default_subset: Preferred subset name.
        default_split: Preferred split name.

    Returns:
        A dict mapping components (the input preview table and the subset/split
        dropdowns) to their updated values.

    Raises:
        RuntimeError: Propagated from ``stream_rows`` when the rows API reports an error.
    """
    subset, split, output = _resolve_dataset_selection(dataset, default_subset=default_subset, default_split=default_split)
    if subset is None or split is None:
        # Nothing to preview: blank the table so rows from a previously selected
        # dataset do not linger next to the now-hidden dropdowns.
        return {input_preview: empty_dataframe, **output}
    rows = stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW)
    # Render every cell value as a string for display.
    preview = [{key: str(value) for key, value in row.items()} for row in islice(rows, NUM_ROWS_PREVIEW)]
    return {input_preview: pd.DataFrame(preview), **output}
def show_input_from_dataset_search(dataset: str) -> dict:
    """Refresh the input preview after a new dataset id is picked.

    Prefers the "default" subset and the "train" split when the dataset has them.
    """
    preferred_subset, preferred_split = "default", "train"
    return _show_input_preview(dataset, default_subset=preferred_subset, default_split=preferred_split)
def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
    """Refresh the input preview after the subset dropdown changes.

    The split preference is reset to "train" for the newly chosen subset.
    """
    return _show_input_preview(dataset, subset, "train")
def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
    """Refresh the input preview for the subset and split currently selected."""
    selection = {"default_subset": subset, "default_split": split}
    return _show_input_preview(dataset, **selection)
def rewrite(dataset: str, subset: str, split: str, input_preview_df: pd.DataFrame) -> dict:
    """Placeholder for the dataset rewrite step.

    Not implemented yet: it ignores its arguments and fills the output preview
    with a one-row table saying so.

    Returns:
        A dict mapping the output preview table to its new value.
    """
    # TODO: implement
    # Dict-of-columns form: one "TODO" column holding the single string
    # "implement" (the previous list-of-records form stored a list in the cell).
    return {output_preview: pd.DataFrame({"TODO": ["implement"]})}
# Register the event handlers, then start the app. Gradio only accepts event
# listeners inside a Blocks context, so `demo` is re-entered to attach them.
# NOTE(review): nothing else in this file attaches these handlers; if they are
# also registered elsewhere (e.g. via decorators), drop this block to avoid
# double registration.
with demo:
    preview_outputs = [input_preview, subset_dropdown, split_dropdown]
    dataset_search.change(
        show_input_from_dataset_search,
        inputs=[dataset_search],
        outputs=preview_outputs,
    )
    subset_dropdown.change(
        show_input_from_subset_dropdown,
        inputs=[dataset_search, subset_dropdown],
        outputs=preview_outputs,
    )
    split_dropdown.change(
        show_input_from_split_dropdown,
        inputs=[dataset_search, subset_dropdown, split_dropdown],
        outputs=preview_outputs,
    )
    rewrite_button.click(
        rewrite,
        inputs=[dataset_search, subset_dropdown, split_dropdown, input_preview],
        outputs=[output_preview],
    )
demo.launch()