File size: 3,702 Bytes
f037e2e
cdac71f
 
 
 
 
f037e2e
cdac71f
 
f037e2e
 
cdac71f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f037e2e
 
cdac71f
 
 
 
 
f037e2e
 
 
cdac71f
f037e2e
cdac71f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f037e2e
cdac71f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f037e2e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

import duckdb
import pandas as pd
import gradio as gr
from datasets import load_dataset
import tempfile
from query import sql_query

# Preview limits used when slicing dataframes for display: rows are capped,
# while a ``None`` column bound keeps every column.
max_rows = 20
max_cols = None

# Keyword arguments passed to the ``gr.Dataframe`` preview components.
df_display_kwargs = {
    "wrap": True,
    "row_count": 3,
    "col_count": 4,
}

# Dataset ids offered in the "HF Dataset" dropdown.
dataset_choices = ["rotten_tomatoes", "sciq"]

def apply_sql(input_table, sql_query):
    """Run ``sql_query`` through DuckDB and return the result as a DataFrame.

    ``input_table`` is never referenced explicitly in the body. It is kept as
    a local name, presumably so that a query mentioning ``input_table`` can
    be resolved by DuckDB against this argument -- confirm against the
    DuckDB Python replacement-scan documentation.
    """
    relation = duckdb.query(sql_query)
    return relation.to_df()

def display_dataset(dataset_id):
    """Load the ``train`` split of a dataset and return a preview plus the full frame.

    Returns:
        A ``(display_df, df)`` tuple: ``df`` is the whole split converted to
        pandas, and ``display_df`` is its first ``max_rows`` rows and
        ``max_cols`` columns.
    """
    full_df = load_dataset(dataset_id, split="train").to_pandas()
    preview_df = full_df.iloc[:max_rows, :max_cols]
    return preview_df, full_df

def upload_dataset(dataset_file):
    """Read an uploaded CSV and return a truncated preview plus the full frame.

    Args:
        dataset_file: Upload object exposing the file path as ``.name``,
            or ``None`` when there is no file.

    Returns:
        ``(display_df, df)`` where ``df`` holds every row of the CSV and
        ``display_df`` is limited to the first ``max_rows`` rows and
        ``max_cols`` columns; ``(None, None)`` when no file is given.
    """
    if dataset_file is None:
        return None, None

    # Only the preview is truncated. The second return value must carry the
    # whole file so that anything consuming the full dataset sees all rows.
    df = pd.read_csv(dataset_file.name)
    display_df = df.iloc[:max_rows, :max_cols]

    return display_df, df


def process_dataset(full_dataset, sql_query):
    """Run ``sql_query`` with DuckDB and export the result to a temporary CSV.

    Args:
        full_dataset: Data to query. It is bound to the local name
            ``input_table`` before the query runs.
        sql_query: SQL text passed to ``duckdb.query``.

    Returns:
        ``(output_df, file_path)``: the query result as a DataFrame and the
        path of a temporary ``.csv`` file containing it. The file is created
        with ``delete=False``, so it is not removed automatically.
    """
    # Bound to a local on purpose even though it looks unused: the query
    # presumably refers to the data as ``input_table`` and DuckDB resolves
    # that name from this scope -- TODO confirm against DuckDB docs.
    input_table = full_dataset  # noqa: F841
    output_df = duckdb.query(sql_query).to_df()

    # Only reserve a path here and let the handle close before pandas writes:
    # reopening a still-open NamedTemporaryFile by name is not portable (it
    # fails on Windows). The suffix gives the exported file a .csv extension.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
        file_path = temp_file.name

    output_df.to_csv(file_path)

    return output_df, file_path


# Theme applied to the Blocks app below.
theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
)

# UI definition. Layout, top to bottom: a row with the dataset controls and
# the SQL controls, a row with the upload/download file widgets, and a panel
# with the input and output dataframe previews. Event wiring follows.
with gr.Blocks(theme=theme) as demo:
    # Receives the second output of display_dataset / upload_dataset and is
    # fed back into process_dataset as its first input.
    full_dataset = gr.State()

    with gr.Column():
        with gr.Row().style(equal_height=True):

            # Left panel: dark-mode toggle, dataset loader and dataset picker.
            with gr.Column(variant="panel"):

                with gr.Row():
                    dark_mode_btn = gr.Button("Dark Mode", variant="primary")
                    load_dataset_button = gr.Button("Load HF Dataset", variant="secondary")

                dataset_selector = gr.Dropdown(label="HF Dataset", choices=dataset_choices, value=dataset_choices[0])


            # Right panel: query actions and the SQL editor, pre-filled with
            # the `sql_query` value imported from the `query` module.
            with gr.Column(variant="compact"):

                with gr.Row():
                    sql_query_btn = gr.Button("Apply SQL Query", variant="secondary")
                    download_dataset_btn = gr.Button("Download Queried Dataset", variant="primary")

                sql_query_comp = gr.Code(language=None, label="SQL Query", lines=3, value=sql_query)

        # File widgets: one for uploading a dataset, one that receives the
        # file path returned by process_dataset.
        with gr.Row().style(equal_height=True):
            upload_dataset_comp = gr.File(label="Upload Dataset")
            download_dataset_comp = gr.File(label="Download Dataset")

        # Dataframe previews for the loaded data and for the query result.
        with gr.Column(variant="panel"):
            input_df_display = gr.Dataframe(**df_display_kwargs, label=f"Input Dataframe (Truncated to first {max_rows} Rows)")

            output_df_display = gr.Dataframe(**df_display_kwargs, label=f"Output Dataframe (Truncated to first {max_rows} Rows)")

    # Both loaders fill the input preview and the full-dataset state.
    load_dataset_button.click(fn=display_dataset, inputs=[dataset_selector], outputs=[input_df_display, full_dataset])
    upload_dataset_comp.change(fn=upload_dataset, inputs=[upload_dataset_comp], outputs=[input_df_display, full_dataset])

    # "Apply SQL Query" runs against the contents of the input preview.
    sql_query_btn.click(fn=apply_sql, inputs=[input_df_display, sql_query_comp], outputs=[output_df_display])

    # "Download Queried Dataset" runs against the full-dataset state and also
    # populates the download file widget.
    download_dataset_btn.click(fn=process_dataset, inputs=[full_dataset, sql_query_comp], outputs=[output_df_display, download_dataset_comp])

    # Browser-side handler with no Python callback (fn=None): if any element
    # has the 'dark' class it is removed from all of them, otherwise 'dark'
    # is added to <body>.
    toggle_dark_mode_args = dict(
        fn=None,
        inputs=None,
        outputs=None,
        _js="""() => {
        if (document.querySelectorAll('.dark').length) {
                document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
            } else {
                document.querySelector('body').classList.add('dark');
            }
        }""",
    )
    # The same toggle is registered for the app's load event and for clicks
    # on the "Dark Mode" button.
    demo.load(**toggle_dark_mode_args)
    dark_mode_btn.click(**toggle_dark_mode_args)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    # Explicit host and port: serve on 0.0.0.0, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)