import json
from pathlib import Path
from typing import Callable, Dict, List, Tuple
import langcodes as lc
import streamlit as st
import yaml
from datasets.utils.metadata import DatasetMetadata
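# Streamlit app for adding structured tags to datasets in the library: pick an existing
# dataset to preload its tag set, edit it through the form in the left-hand column, and
# copy the validated yaml block rendered on the right.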
st.set_page_config(
    page_title="HF Dataset Tagging App",
    page_icon="https://huggingface.co/front/assets/huggingface_logo.svg",
    layout="wide",
    initial_sidebar_state="auto",
)
# XXX: restyle `st.error` output: streamlit neither preserves whitespace nor scrolls
# horizontally by default, both of which make long error reports easier to read
st.markdown(
    """
<style>
    /* assumed rule (the original stylesheet was not preserved): keep whitespace
       intact and let long error lines scroll horizontally */
    div.stAlert p { white-space: pre; overflow-x: auto; }
</style>
""",
    unsafe_allow_html=True,
)
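# Assumed file shapes, inferred from how they are used below: task_set.json maps each
# task category to {"description": ..., "options": [...]}; license_set.json maps a
# license id to its human-readable description.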
with open("task_set.json") as f:
    task_set = json.load(f)
with open("license_set.json") as f:
    license_set = json.load(f)
multilinguality_set = {
    "monolingual": "contains a single language",
    "multilingual": "contains multiple languages",
    "translation": "contains translated or aligned text",
    "other": "other type of language distribution",
}
creator_set = {
    "language": [
        "found",
        "crowdsourced",
        "expert-generated",
        "machine-generated",
        "other",
    ],
    "annotations": [
        "found",
        "crowdsourced",
        "expert-generated",
        "machine-generated",
        "no-annotation",
        "other",
    ],
}
########################
## Helper functions
########################
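# The metadata_* exports are assumed to be shaped like {dataset_id: {"metadata": {...}}},
# matching how they are unpacked into `existing_tag_sets` further down.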
def load_ds_datas():
    """Return the contents of the most recent metadata_* export in the working directory."""
    metadata_exports = sorted(
        [f for f in Path.cwd().iterdir() if f.name.startswith("metadata_")],
        key=lambda f: f.lstat().st_mtime,
        reverse=True,
    )
    if len(metadata_exports) == 0:
        raise ValueError("need to run ./build_metada_file.py at least once")
    with metadata_exports[0].open() as fi:
        return json.load(fi)
def split_known(vals: List[str], okset: List[str]) -> Tuple[List[str], List[str]]:
    """Partition `vals` into (values present in `okset`, values that are not)."""
    if vals is None:
        return [], []
    return [v for v in vals if v in okset], [v for v in vals if v not in okset]
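# e.g. split_known(["en", "klingon"], ["en", "fr"]) -> (["en"], ["klingon"])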
def multiselect(
    w: st.delta_generator.DeltaGenerator,
    title: str,
    markdown: str,
    values: List[str],
    valid_set: List[str],
    format_func: Callable = str,
):
    """Render a titled multiselect on `w`, flagging any values that fall outside `valid_set`."""
    valid_values, invalid_values = split_known(values, valid_set)
    w.markdown(f"#### {title}")
    if len(invalid_values) > 0:
        w.markdown("Found the following invalid values:")
        w.error(invalid_values)
    return w.multiselect(markdown, valid_set, default=valid_values, format_func=format_func)
def validate_dict(w: st.delta_generator.DeltaGenerator, state_dict: Dict):
    """Check `state_dict` against the DatasetMetadata schema and report the result on `w`."""
    try:
        DatasetMetadata(**state_dict)
        w.markdown("✅ This is a valid tagset! 🤗")
    except Exception as e:
        w.markdown("❌ This is an invalid tagset, here are the errors in it:")
        w.error(e)
def new_state():
    return {
        "task_categories": [],
        "task_ids": [],
        "multilinguality": [],
        "languages": [],
        "language_creators": [],
        "annotations_creators": [],
        "source_datasets": [],
        "size_categories": [],
        "licenses": [],
    }
state = new_state()
datasets_md = load_ds_datas()
existing_tag_sets = {name: mds["metadata"] for name, mds in datasets_md.items()}
all_dataset_ids = list(existing_tag_sets.keys())
########################
## Dataset selection
########################
st.sidebar.markdown(
    """
# HuggingFace Dataset Tagger

This app aims to make it easier to add structured tags to the datasets present in the library.
"""
)
queryparams = st.experimental_get_query_params()
preload = queryparams.get("preload_dataset", list())
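# e.g. visiting <app_url>?preload_dataset=squad (a hypothetical dataset id) pre-fills
# the form with that dataset's existing tag set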
preloaded_id = None
initial_state = None
did_index = 0
if len(preload) == 1 and preload[0] in all_dataset_ids:
    preloaded_id, *_ = preload
    initial_state = existing_tag_sets.get(preloaded_id)
    state = initial_state or new_state()
    did_index = all_dataset_ids.index(preloaded_id)
preloaded_id = st.sidebar.selectbox(
    label="Choose dataset to load tag set from", options=all_dataset_ids, index=did_index
)
leftbtn, rightbtn = st.sidebar.beta_columns(2)
if leftbtn.button("pre-load"):
    initial_state = existing_tag_sets[preloaded_id]
    state = initial_state or new_state()
    st.experimental_set_query_params(preload_dataset=preloaded_id)
if sum(len(v) if v is not None else 0 for v in state.values()) > 0:
    if rightbtn.button("flush state"):
        state = new_state()
        initial_state = None
        preloaded_id = None
        st.experimental_set_query_params()
if preloaded_id is not None and initial_state is not None:
    st.sidebar.markdown(
        f"""
---
The current base tagset is [`{preloaded_id}`](https://huggingface.co/datasets/{preloaded_id})
"""
    )
    validate_dict(st.sidebar, initial_state)
    st.sidebar.markdown(
        f"""
Here is the matching yaml block:

```yaml
{yaml.dump(initial_state)}
```
"""
    )
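# NOTE: yaml.dump sorts keys alphabetically by default, so the dumped blocks may list
# fields in a different order than the UI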
leftcol, _, rightcol = st.beta_columns([12, 1, 12])
leftcol.markdown("### Supported tasks")
state["task_categories"] = multiselect(
leftcol,
"Task category",
"What categories of task does the dataset support?",
values=state["task_categories"],
valid_set=list(task_set.keys()),
format_func=lambda tg: f"{tg}: {task_set[tg]['description']}",
)
task_specifics = []
for tg in state["task_categories"]:
    specs = multiselect(
        leftcol,
        f"Specific _{tg}_ tasks",
        "What specific tasks does the dataset support?",
        values=[ts for ts in (state["task_ids"] or []) if ts in task_set[tg]["options"]],
        valid_set=task_set[tg]["options"],
    )
    if "other" in specs:
        other_task = st.text_input(
            "You selected 'other' task. Please enter a short hyphen-separated description for the task:",
            value="my-task-description",
        )
        st.write(f"Registering {tg}-other-{other_task} task")
        specs[specs.index("other")] = f"{tg}-other-{other_task}"
    task_specifics += specs
state["task_ids"] = task_specifics
leftcol.markdown("### Languages")
state["multilinguality"] = multiselect(
leftcol,
"Monolingual?",
"Does the dataset contain more than one language?",
values=state["multilinguality"],
valid_set=list(multilinguality_set.keys()),
format_func=lambda m: f"{m} : {multilinguality_set[m]}",
)
if "other" in state["multilinguality"]:
other_multilinguality = st.text_input(
"You selected 'other' type of multilinguality. Please enter a short hyphen-separated description:",
value="my-multilinguality",
)
st.write(f"Registering other-{other_multilinguality} multilinguality")
state["multilinguality"][state["multilinguality"].index("other")] = f"other-{other_multilinguality}"
valid_values, invalid_values = list(), list()
for langtag in state["languages"]:
    try:
        lc.get(langtag)
        valid_values.append(langtag)
    except Exception:  # langcodes raises on unparseable tags
        invalid_values.append(langtag)
leftcol.markdown("#### Languages")
if len(invalid_values) > 0:
    leftcol.markdown("Found the following invalid values:")
    leftcol.error(invalid_values)
langtags = leftcol.text_area(
    "What languages are represented in the dataset? Expected format is BCP47 tags separated by ';', e.g. 'en-US;fr-FR'",
    value=";".join(valid_values),
)
state["languages"] = langtags.split(";")
leftcol.markdown("### Dataset creators")
state["language_creators"] = multiselect(
leftcol,
"Data origin",
"Where does the text in the dataset come from?",
values=state["language_creators"],
valid_set=creator_set["language"],
)
state["annotations_creators"] = multiselect(
leftcol,
"Annotations origin",
"Where do the annotations in the dataset come from?",
values=state["annotations_creators"],
valid_set=creator_set["annotations"],
)
state["licenses"] = multiselect(
leftcol,
"Licenses",
"What licenses is the dataset under?",
valid_set=list(license_set.keys()),
values=state["licenses"],
format_func=lambda l: f"{l} : {license_set[l]}",
)
if "other" in state["licenses"]:
other_license = st.text_input(
"You selected 'other' type of license. Please enter a short hyphen-separated description:",
value="my-license",
)
st.write(f"Registering other-{other_license} license")
state["licenses"][state["licenses"].index("other")] = f"other-{other_license}"
# link to supported datasets
pre_select_ext_a = []
if "original" in state["source_datasets"]:
    pre_select_ext_a += ["original"]
if any(p.startswith("extended") for p in state["source_datasets"]):
    pre_select_ext_a += ["extended"]
state["extended"] = multiselect(
    leftcol,
    "Relations to existing work",
    "Does the dataset contain original data and/or was it extended from other datasets?",
    values=pre_select_ext_a,
    valid_set=["original", "extended"],
)
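# source_datasets entries use the "extended|<dataset_id>" convention (e.g. "extended|squad",
# a hypothetical id), with a bare "original" marking datasets built from original data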
state["source_datasets"] = ["original"] if "original" in state["extended"] else []
if "extended" in state["extended"]:
pre_select_ext_b = [p.split("|")[1] for p in state["source_datasets"] if p.startswith("extended")]
extended_sources = multiselect(
leftcol,
"Linked datasets",
"Which other datasets does this one use data from?",
values=pre_select_ext_b,
valid_set=all_dataset_ids + ["other"],
)
if "other" in extended_sources:
other_extended_sources = st.text_input(
"You selected 'other' dataset. Please enter a short hyphen-separated description:",
value="my-dataset",
)
st.write(f"Registering other-{other_extended_sources} dataset")
extended_sources[extended_sources.index("other")] = f"other-{other_extended_sources}"
state["source_datasets"] += [f"extended|{src}" for src in extended_sources]
size_cats = ["unknown", "n<1K", "1K1M"]
current_size_cats = state.get("size_categories") or ["unknown"]
ok, nonok = split_known(current_size_cats, size_cats)
if len(nonok) > 0:
    leftcol.markdown(f"**Found bad codes in existing tagset**:\n{nonok}")
state["size_categories"] = [
    leftcol.selectbox(
        "What is the size category of the dataset?",
        options=size_cats,
        index=size_cats.index(ok[0]) if len(ok) > 0 else 0,
    )
]
########################
## Show results
########################
rightcol.markdown("### Finalized tag set")
validate_dict(rightcol, state)
rightcol.markdown(
    f"""
```yaml
{yaml.dump(state)}
```

---

#### Arbitrary yaml validator

This is a standalone tool: use it to check an existing tagset for errors, or to edit the yaml text directly instead of going through the UI on the left.
"""
)
yamlblock = rightcol.text_area("Input your yaml here")
if yamlblock.strip() != "":
    try:
        validate_dict(rightcol, yaml.safe_load(yamlblock))
    except yaml.YAMLError as e:  # report malformed yaml instead of letting the app crash
        rightcol.error(e)