Spaces:
Sleeping
Sleeping
import json | |
import logging | |
import os | |
import urllib.parse | |
from typing import Any | |
import gradio as gr | |
import requests | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
# Module-level logger; query_dataset uses it to record non-200 LLM API responses.
logger = logging.getLogger(__name__)
# NOTE(review): `example` is not referenced anywhere else in the visible code, yet
# it instantiates a HuggingfaceHubSearch component at import time — confirm it is
# still needed before relying on it.
example = HuggingfaceHubSearch().example_value()
# Markdown shown as the first element of the Gradio UI.
HEADER_CONTENT = (
    "# 🤗 Dataset DuckDB Query Chatbot\n\n"
    "This is a basic text to SQL tool that allows you to query datasets on Hugging Face Hub. "
    "It's a fork of "
    "[davidberenstein1957/text-to-sql-hub-datasets](https://huggingface.co/spaces/davidberenstein1957/text-to-sql-hub-datasets) "
    "that adds chat capability and table name generation."
)
# Markdown shown inside the collapsible "About/Help" accordion.
ABOUT_CONTENT = """
This space uses [Llama 3.1 70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
via [together.ai](https://together.ai).
It also uses the
[dataset-server API](https://redocly.github.io/redoc/?url=https://datasets-server.huggingface.co/openapi.json#operation/isValidDataset).
Query history is saved and given to the chat model so you can chat to refine your query as you go.
When the DuckDB modal is presented, you may need to click on the name of the
config/split at the base of the modal to get the table loaded for DuckDB's use.
Search for and select a dataset to begin.
"""
# System prompt for the SQL-suggestion model. Filled in by get_system_prompt()
# via str.format with the `table_name` and `features` placeholders.
SYSTEM_PROMPT_TEMPLATE = (
    "You are a SQL query expert assistant that returns DuckDB SQL queries "
    "based on the user's natural language query and dataset features. "
    "You might need to use DuckDB functions for lists and aggregations, "
    "given the features. Only return the SQL query, no other text. The "
    "user may ask you to make various adjustments to the query. Every "
    "time, your response should only include the refined SQL query and "
    "nothing else.\n\n"
    "The table being queried is named: {table_name}.\n\n"
    "# Features\n"
    "{features}"
)
def get_iframe(hub_repo_id, sql_query=None):
    """Return HTML that embeds the Hub dataset viewer for *hub_repo_id*.

    Args:
        hub_repo_id: Hub dataset id (e.g. ``"user/dataset"``). It comes from
            user input, so it is percent-encoded before going into the URL.
        sql_query: optional SQL to pre-load into the viewer's SQL console.

    Returns:
        An ``<iframe>`` HTML snippet pointing at the embedded viewer.

    Raises:
        ValueError: if *hub_repo_id* is empty or None.
    """
    import html  # only this helper needs HTML escaping

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    # Keep "/" between owner and name, but encode everything else so odd
    # input cannot alter the URL structure.
    repo_path = urllib.parse.quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{repo_path}/embed/viewer"
    if sql_query:
        url += f"?sql_console=true&sql={urllib.parse.quote(sql_query)}"
    # Escape for the HTML attribute context (quotes, "&", "<", ">").
    return f"""
    <iframe
        src="{html.escape(url, quote=True)}"
        frameborder="0"
        width="100%"
        height="800px"
    ></iframe>
    """
def get_table_info(hub_repo_id):
    """Fetch the dataset info for *hub_repo_id* from the dataset viewer API.

    Args:
        hub_repo_id: Hub dataset id selected in the search box.

    Returns:
        The response's ``dataset_info`` value serialized as a JSON string;
        the UI stores it in the hidden ``card_data`` component.

    Raises:
        gr.Error: if the request fails or times out, the server answers with
            an error status, or the body is not a JSON object. The error must
            be *raised* (not merely constructed) for Gradio to show it.
    """
    try:
        response = requests.get(
            "https://datasets-server.huggingface.co/info",
            # Let requests URL-encode the id instead of splicing it in.
            params={"dataset": hub_repo_id},
            # Never hang the UI indefinitely on a slow backend.
            timeout=30,
        )
        response.raise_for_status()
        dataset_info = response.json().get("dataset_info")
    except (requests.RequestException, ValueError, AttributeError) as e:
        raise gr.Error(f"Error getting column info: {e}") from e
    return json.dumps(dataset_info)
def get_table_name( | |
config: str | None, | |
split: str | None, | |
config_choices: list[str], | |
split_choices: list[str], | |
): | |
if len(config_choices) > 0 and config is None: | |
config = config_choices[0] | |
if len(split_choices) > 0 and split is None: | |
split = split_choices[0] | |
if len(config_choices) > 1 and len(split_choices) > 1: | |
base_name = f"{config}_{split}" | |
elif len(config_choices) >= 1 and len(split_choices) <= 1: | |
base_name = config | |
else: | |
base_name = split | |
def replace_char(c): | |
if c.isalnum(): | |
return c | |
if c in ["-", "_", "/"]: | |
return "_" | |
return "" | |
table_name = "".join(replace_char(c) for c in base_name) | |
if table_name[0].isdigit(): | |
table_name = f"_{table_name}" | |
return table_name.lower() | |
def get_system_prompt(
    card_data: dict[str, Any],
    config: str | None,
    split: str | None,
):
    """Build the system prompt for the SQL-suggestion model.

    Args:
        card_data: parsed dataset info, keyed by config name. Each value is
            expected to carry a ``"features"`` entry.
        config: selected config, or None to use the first available one.
        split: selected split, or None (resolved inside get_table_name).

    Returns:
        SYSTEM_PROMPT_TEMPLATE filled with the derived table name and the
        selected config's features.
    """
    config_choices = get_config_choices(card_data)
    split_choices = get_split_choices(card_data)
    # Apply the same default as get_table_name. Otherwise an unselected config
    # dropdown would become card_data[None] and raise KeyError below.
    if config is None and config_choices:
        config = config_choices[0]
    table_name = get_table_name(config, split, config_choices, split_choices)
    features = card_data[config]["features"]
    return SYSTEM_PROMPT_TEMPLATE.format(
        table_name=table_name,
        features=features,
    )
def get_config_choices(card_data: dict[str, Any]) -> list[str]:
    """List the top-level keys of *card_data* (the config names), in dict order."""
    return [config_name for config_name in card_data]
def get_split_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the distinct split names across all configs, in first-seen order.

    The order is deterministic, unlike iterating a ``set`` of strings, which
    depends on hash randomization. Callers use the first entry as the default
    split, so it must not change between processes. Configs without a
    ``"splits"`` entry, or with ``"splits": None``, contribute nothing.
    """
    seen: dict[str, None] = {}  # dict keeps insertion order and dedupes
    for config_info in card_data.values():
        for split_name in config_info.get("splits") or {}:
            seen.setdefault(split_name, None)
    return list(seen)
def query_dataset(hub_repo_id, card_data, query, config, split, history):
    """Ask the LLM for a DuckDB query and show it in the embedded dataset viewer.

    Args:
        hub_repo_id: dataset id from the Hub search box.
        card_data: JSON string produced by get_table_info (may be None/empty).
        query: the user's natural-language request.
        config: selected config name, or None.
        split: selected split name, or None.
        history: chatbot history as (user, assistant) pairs.

    Returns:
        A tuple ``(sql, iframe_html, new_history, "")``. The trailing ``""``
        clears the query textbox. Without dataset info, the SQL is empty and
        the history is reset.

    Raises:
        gr.Error: if the API key is missing or the LLM request fails.
    """
    # "null" (no dataset_info) parses to None and is treated like no data.
    info = json.loads(card_data) if card_data else None
    if not info:
        if hub_repo_id:
            iframe = get_iframe(hub_repo_id)
        else:
            iframe = "<p>No dataset selected.</p>"
        return "", iframe, [], ""

    history = history or []
    system_prompt = get_system_prompt(info, config, split)
    messages = _build_chat_messages(system_prompt, history, query)
    duck_query = _sanitize_duck_query(_request_sql_suggestion(messages))
    # Return a new list rather than mutating the caller's history in place.
    new_history = [*history, (query, duck_query)]
    return duck_query, get_iframe(hub_repo_id, duck_query), new_history, ""


def _build_chat_messages(system_prompt, history, query):
    """Assemble the chat-completions message list: system, prior turns, new query."""
    messages = [{"role": "system", "content": system_prompt}]
    for user_turn, assistant_turn in history:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": query})
    return messages


def _request_sql_suggestion(messages):
    """Send *messages* to Together's chat-completions endpoint; return the reply text."""
    api_key = os.environ.get("API_KEY_TOGETHER_AI", "").strip()
    if not api_key:
        raise gr.Error("API_KEY_TOGETHER_AI is not set; cannot query the LLM.")
    try:
        response = requests.post(
            "https://api.together.xyz/v1/chat/completions",
            json=dict(
                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                messages=messages,
                max_tokens=1000,
            ),
            headers={"Authorization": f"Bearer {api_key}"},
            # Bound the wait so a stalled API call cannot hang the UI forever.
            timeout=120,
        )
    except requests.RequestException as e:
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e
    if response.status_code != 200:
        logger.warning(response.text)
    try:
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
        # Raise (not just construct) gr.Error so the failure reaches the user
        # instead of surfacing later as an unrelated KeyError.
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e
def _sanitize_duck_query(duck_query: str) -> str: | |
# Sometimes the LLM wraps the query like this: | |
# ```sql | |
# select * from x; | |
# ``` | |
# This removes that wrapping if present. | |
if "```" not in duck_query: | |
return duck_query | |
start_idx = duck_query.index("```") + len("```") | |
end_idx = duck_query.rindex("```") | |
duck_query = duck_query[start_idx:end_idx] | |
if duck_query.startswith("sql\n"): | |
duck_query = duck_query.replace("sql\n", "", 1) | |
return duck_query | |
with gr.Blocks() as demo:
    gr.Markdown(HEADER_CONTENT)
    with gr.Accordion("About/Help", open=False):
        gr.Markdown(ABOUT_CONTENT)
    with gr.Row():
        search_in = HuggingfaceHubSearch(
            label="Search Hugging Face Hub",
            placeholder="Search for datasets on Hugging Face",
            search_type="dataset",
            # NOTE(review): "sumbit" looks misspelled, but it is presumably the
            # component's own keyword name — confirm before renaming.
            sumbit_on_select=True,
        )
    with gr.Row():
        show_btn = gr.Button("Show Dataset")
    with gr.Row():
        # Hidden holder for the latest suggested query (an output of query_dataset).
        sql_out = gr.Code(
            label="DuckDB SQL Query",
            interactive=True,
            language="sql",
            lines=1,
            visible=False,
        )
    with gr.Row():
        # Hidden JSON store for the dataset info returned by get_table_info.
        card_data = gr.Code(label="Card data", language="json", visible=False)

    # Created unrendered here and placed in the layout after the dynamic
    # section below. Keeping it outside @gr.render means a re-render
    # (triggered by card_data changing) does not recreate it and wipe the
    # viewer that was just displayed.
    search_out = gr.HTML(label="Search Results", render=False)

    @gr.render(inputs=[card_data])
    def show_config_split_choices(data):
        """(Re)build the config/split pickers and chat controls for *data*.

        Without the @gr.render decorator this function was never called, so
        none of these components or their event listeners existed.
        """
        try:
            parsed = json.loads(data.strip())
            config_choices = get_config_choices(parsed)
            split_choices = get_split_choices(parsed)
        except (AttributeError, TypeError, ValueError):
            # No dataset loaded yet (data is None/empty) or unexpected JSON.
            config_choices = []
            split_choices = []
        initial_config = config_choices[0] if config_choices else None
        initial_split = split_choices[0] if split_choices else None
        with gr.Row():
            with gr.Column():
                config_selection = gr.Dropdown(
                    label="Config Name", choices=config_choices, value=initial_config
                )
            with gr.Column():
                split_selection = gr.Dropdown(
                    label="Split Name", choices=split_choices, value=initial_split
                )
        with gr.Accordion("Query Suggestion History.", open=False) as accordion:
            chatbot = gr.Chatbot(height=200, layout="bubble")
        with gr.Row():
            query = gr.Textbox(
                label="Query Description",
                placeholder="Enter a natural language query to generate SQL",
            )
        with gr.Row():
            with gr.Column():
                query_btn = gr.Button("Get Suggested Query")
            with gr.Column():
                gr.ClearButton([query, chatbot], value="Reset Query History")
        gr.on(
            [query_btn.click, query.submit],
            fn=query_dataset,
            inputs=[
                search_in,
                card_data,
                query,
                config_selection,
                split_selection,
                chatbot,
            ],
            outputs=[sql_out, search_out, chatbot, query],
        )
        # Reveal the history accordion whenever a suggestion is requested,
        # whether by button click or by pressing Enter in the query box.
        gr.on(
            [query_btn.click, query.submit],
            fn=lambda: gr.update(open=True),
            outputs=[accordion],
        )

    with gr.Row():
        search_out.render()

    # Show the viewer, then load the dataset info into card_data, which in
    # turn triggers show_config_split_choices to re-render.
    gr.on(
        [show_btn.click, search_in.submit],
        fn=get_iframe,
        inputs=[search_in],
        outputs=[search_out],
    ).then(
        fn=get_table_info,
        inputs=[search_in],
        outputs=[card_data],
    )
# Start the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()