|
import os |
|
import gradio as gr |
|
from sqlalchemy import text |
|
from smolagents import tool, CodeAgent, HfApiModel |
|
import spaces |
|
import pandas as pd |
|
from database import ( |
|
engine, |
|
create_dynamic_table, |
|
clear_database, |
|
insert_rows_into_table, |
|
get_table_schema |
|
) |
|
|
|
def process_uploaded_file(file):
    """
    Read an uploaded CSV file and load its contents into the database.

    The CSV is parsed with pandas. ``clear_database()`` is then called,
    a table is created from the DataFrame via ``create_dynamic_table()``
    and the rows are inserted with ``insert_rows_into_table()``.

    Args:
        file: The upload handed over by the UI. Either a filesystem path
            (``str`` or ``os.PathLike``) or an object that exposes the
            path through a ``.name`` attribute (for example a temporary
            file wrapper). ``None`` means nothing was uploaded.

    Returns:
        tuple: ``(success_flag, message)`` where ``success_flag`` is a
        bool and ``message`` is a status string for the user. Failures are
        reported through the message; no exception is propagated.
    """
    if file is None:
        return False, "Please upload a file."

    try:
        # Path-like values are checked first: ``pathlib.Path.name`` is only
        # the final path component, so it must not be mistaken for a
        # wrapper's ``.name`` attribute.
        if isinstance(file, (str, os.PathLike)):
            csv_path = file
        else:
            csv_path = file.name

        df = pd.read_csv(csv_path)

        if df.columns.empty:
            return False, "Error: File contains no columns"

        clear_database()
        table = create_dynamic_table(df)

        insert_rows_into_table(df.to_dict("records"), table)

        return True, "File successfully loaded! Proceeding to query interface..."

    except Exception as e:
        # UI boundary: any failure becomes a status message instead of a crash.
        return False, f"Error processing file: {e}"
|
|
|
def get_data_table():
    """
    Fetch all rows of ``data_table`` and return them as a pandas DataFrame.

    Returns:
        pandas.DataFrame: the rows labelled with the query's column names;
        an empty DataFrame when the query yields no rows; or a DataFrame
        with a single ``"Error"`` column holding the exception text when
        anything goes wrong.
    """
    try:
        with engine.connect() as con:
            result = con.execute(text("SELECT * FROM data_table"))
            rows = result.fetchall()
            # Capture the column names while the connection is still open,
            # so the result object is never touched after it is released.
            columns = list(result.keys())

        if not rows:
            return pd.DataFrame()

        return pd.DataFrame(rows, columns=columns)

    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
|
|
|
# NOTE(review): this tool executes whatever statement it is given. The
# SELECT-only prefix check lives in query_sql, so a direct tool call by the
# agent is not restricted by it.
@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL query and returns formatted results.

    Args:
        query: The SQL statement to execute against the database, for example a SELECT on 'data_table'.

    Returns:
        "No results found." when the query yields no rows, the bare value when the result is a single row with a single column, otherwise one line per row with values separated by ", ". On failure, a string starting with "Error: ".
    """
    try:
        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()

        if not rows:
            return "No results found."

        # A lone scalar (e.g. COUNT(*)) is returned without row formatting.
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])

        return "\n".join(", ".join(map(str, row)) for row in rows)

    except Exception as e:
        return f"Error: {e}"
|
|
|
def query_sql(user_query: str) -> str:
    """
    Turn a natural-language question into SQL via the agent and run it.

    Args:
        user_query: The question typed by the user.

    Returns:
        str: An empty string for blank input; an ``"Error: ..."`` message
        when no schema is available, the agent call fails, or the generated
        text does not start with ``select``, ``show`` or ``pragma``; the
        agent's output rendered as text when it is not a string; otherwise
        the output of ``sql_engine``, formatted with two decimals when it
        parses as a number.
    """
    # Nothing to ask: skip the schema lookup and the agent call entirely.
    if not user_query or not user_query.strip():
        return ""

    schema = get_table_schema()
    if not schema:
        return "Error: No data table exists. Please upload a file first."

    schema_info = (
        "The database has a table named 'data_table' with the following schema:\n"
        f"{schema}\n"
        "Generate a valid SQL SELECT query using ONLY these column names.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )

    try:
        generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    except Exception as e:
        # Report agent/model failures the same way sql_engine reports errors.
        return f"Error: {e}"

    # The agent may hand back a non-string final answer; show it as-is.
    if not isinstance(generated_sql, str):
        return f"{generated_sql}"

    if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."

    result = sql_engine(generated_sql)

    # Numeric results are shown with two decimals; anything else is passed through.
    try:
        return f"{float(result):.2f}"
    except ValueError:
        return result
|
|
|
# Module-level agent used by query_sql; it is given sql_engine as its only tool.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)
|
|
|
|
|
def create_upload_interface():
    """
    Build the CSV upload view and wire its upload event.

    The view contains an introductory Markdown block, a CSV file input and a
    read-only status textbox. When a file is uploaded it is passed to
    ``process_uploaded_file``; the status box shows the returned message and
    the visibility of the upload and query views is switched on success.

    NOTE: the upload event lists the module-level ``query_interface`` as an
    output, so that name must already be defined when this function is
    called.

    Returns:
        gr.Blocks: the upload interface.
    """
    with gr.Blocks() as upload_interface:
        gr.Markdown("""
        # Data Query Interface

        Upload your CSV file to begin.

        ### Requirements:
        - File must be in CSV format
        - First column will be used as the primary key
        - All columns will be automatically typed based on their content
        """)

        # `type` is deliberately left at the installed Gradio's default.
        # The previously hard-coded "file" value is, per the Gradio docs,
        # only valid on older releases (where it is also the default).
        file_input = gr.File(
            label="Upload CSV File",
            file_types=[".csv"],
        )
        status = gr.Textbox(label="Status", interactive=False)

        def handle_upload(file):
            """Load the file and return (status message, upload update, query update)."""
            success, message = process_uploaded_file(file)
            # gr.update is the generic update helper; on success the upload
            # view is hidden and the query view shown, otherwise the reverse.
            return (
                message,
                gr.update(visible=not success),
                gr.update(visible=success),
            )

        file_input.upload(
            fn=handle_upload,
            inputs=[file_input],
            outputs=[status, upload_interface, query_interface],
        )

    return upload_interface
|
|
|
|
|
def create_query_interface():
    """
    Build the (initially hidden) natural-language query view.

    Layout: a question textbox and a result textbox next to a read-only
    data table, followed by a schema display and two refresh buttons.

    Events:
        - Submitting the question textbox (pressing Enter) runs ``query_sql``
          and writes its return value to the result textbox. The query is
          run on submit rather than on every change so the agent is not
          invoked for each keystroke.
        - "Refresh Table" reloads the table from ``get_data_table``.
        - "Refresh Schema" and the interface's load event refresh the
          schema display from ``get_table_schema``.

    Returns:
        gr.Blocks: the query interface.
    """
    with gr.Blocks(visible=False) as query_interface:
        gr.Markdown("""
        ## Data Query Interface

        Enter your questions about the data in natural language.
        The AI will convert your questions into SQL queries.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                user_input = gr.Textbox(label="Ask a question about the data")
                query_output = gr.Textbox(label="Result")

            with gr.Column(scale=2):
                gr.Markdown("### Current Data")
                # The initial value is fetched once, while the UI is built.
                data_table = gr.Dataframe(
                    value=get_data_table(),
                    label="Data Table",
                    interactive=False,
                )

        schema_display = gr.Markdown(value="Loading schema...")

        def update_schema():
            """Return the current schema as a Markdown code block, or a placeholder."""
            schema = get_table_schema()
            if schema:
                return f"### Current Schema:\n```\n{schema}\n```"
            return "No data loaded"

        user_input.submit(
            fn=query_sql,
            inputs=[user_input],
            outputs=[query_output],
        )

        with gr.Row():
            refresh_table_btn = gr.Button("Refresh Table")
            refresh_schema_btn = gr.Button("Refresh Schema")

        refresh_table_btn.click(
            fn=get_data_table,
            outputs=[data_table],
        )

        refresh_schema_btn.click(
            fn=update_schema,
            outputs=[schema_display],
        )

        query_interface.load(
            fn=update_schema,
            outputs=[schema_display],
        )

    return query_interface
|
|
|
|
|
# Assemble the app. The query interface is created first because
# create_upload_interface() references the module-level `query_interface`
# while wiring its upload event; creating it afterwards raises NameError.
# NOTE(review): this also changes the order in which the two views are
# created inside `demo` - confirm the resulting layout is acceptable.
with gr.Blocks() as demo:
    query_interface = create_query_interface()
    upload_interface = create_upload_interface()


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|