|
import os
import re
import sqlite3

import gradio as gr
import pandas as pd
import spaces
from smolagents import tool, CodeAgent, HfApiModel
from sqlalchemy import text

from database import (
    engine,
    create_dynamic_table,
    clear_database,
    insert_rows_into_table,
    get_table_schema,
)
|
|
|
def get_data_table():
    """
    Fetch every row of the current table as a pandas DataFrame.

    The "current" table is the first entry returned by the ``sqlite_master``
    lookup below (internal ``sqlite_*`` tables are excluded).

    Returns:
        pandas.DataFrame: The table contents. The frame has no columns when
        the database holds no user table, carries the column headers with zero
        rows when the table is empty, and consists of a single ``Error``
        column holding the exception message when anything fails.
    """
    try:
        # One connection serves both the table lookup and the data fetch.
        with engine.connect() as con:
            tables = con.execute(text(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )).fetchall()

            if not tables:
                return pd.DataFrame()

            # Double-quote the identifier (doubling embedded quotes) so table
            # names containing spaces or reserved words remain valid SQL.
            quoted_name = '"' + tables[0][0].replace('"', '""') + '"'

            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            # Read the column names while the result is still open.
            columns = list(result.keys())
            rows = [tuple(row) for row in result.fetchall()]

        # Building the frame from an empty row list still keeps the headers,
        # so an empty table shows its columns instead of a blank frame.
        return pd.DataFrame(rows, columns=columns)

    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
|
|
|
def get_table_info():
    """
    Describe the current table: its name and its columns.

    The "current" table is the first entry returned by the ``sqlite_master``
    lookup below (internal ``sqlite_*`` tables are excluded).

    Returns:
        tuple: ``(table_name, column_names, column_info)`` where
        ``column_names`` is a list of column names and ``column_info`` maps
        each column name to ``{'type': ..., 'is_primary': bool}``.
        ``(None, [], {})`` is returned when no table exists or when the
        lookup fails (the error is printed).
    """
    try:
        # One connection serves both the table lookup and the PRAGMA call.
        with engine.connect() as con:
            tables = con.execute(text(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )).fetchall()

            if not tables:
                return None, [], {}

            table_name = tables[0][0]

            # Double-quote the identifier (doubling embedded quotes) so table
            # names containing spaces or reserved words remain valid SQL.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            columns = con.execute(text(f"PRAGMA table_info({quoted_name})")).fetchall()

        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk):
        # index 1 is the column name, 2 its declared type, 5 the primary-key flag.
        column_names = [col[1] for col in columns]
        column_info = {
            col[1]: {
                'type': col[2],
                'is_primary': bool(col[5]),
            }
            for col in columns
        }

        return table_name, column_names, column_info

    except Exception as e:
        print(f"Error getting table info: {str(e)}")
        return None, [], {}
|
|
|
def _split_sql_statements(sql_content):
    """Split a SQL script into statements without cutting inside literals or comments."""
    statements = []
    pending = []
    for piece in sql_content.split(';'):
        pending.append(piece)
        candidate = ';'.join(pending)
        # SQLite's own tokenizer decides whether the semicolon really ends a
        # statement or sits inside a string literal, comment or trigger body.
        if sqlite3.complete_statement(candidate + ';'):
            if candidate.strip():
                statements.append(candidate.strip())
            pending = []

    # Text that never completes (e.g. an unterminated quote) is passed on
    # unchanged so the database reports the real syntax error.
    leftover = ';'.join(pending).strip()
    if leftover:
        statements.append(leftover)
    return statements


def process_sql_file(file_path):
    """
    Execute the statements of an SQL file against a freshly cleared database.

    ``AUTO_INCREMENT`` is rewritten to ``AUTOINCREMENT`` before execution.
    All statements run inside a single ``engine.begin()`` transaction.

    Args:
        file_path: Path of the SQL file to read.

    Returns:
        tuple: ``(True, message)`` on success, ``(False, message)`` when the
        file holds no statements or any step raises.
    """
    try:
        # utf-8-sig reads plain UTF-8 and also drops a leading byte-order
        # mark, which would otherwise corrupt the first statement.
        with open(file_path, 'r', encoding='utf-8-sig') as sql_file:
            sql_content = sql_file.read()

        sql_content = sql_content.replace('AUTO_INCREMENT', 'AUTOINCREMENT')

        statements = _split_sql_statements(sql_content)

        # Bail out before touching the database: an empty script must not
        # wipe the existing data and then claim success.
        if not statements:
            return False, "Error: SQL file contains no SQL statements"

        clear_database()

        with engine.begin() as conn:
            for statement in statements:
                conn.execute(text(statement))

        return True, "SQL file successfully executed!"

    except Exception as e:
        return False, f"Error processing SQL file: {str(e)}"
|
|
|
def process_csv_file(file_path):
    """
    Load a CSV file into the database, replacing whatever was there.

    The file is read with pandas; the database is cleared, a table is created
    from the frame via ``create_dynamic_table`` and the rows are inserted with
    ``insert_rows_into_table``.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        tuple: ``(True, message)`` on success, ``(False, message)`` when the
        file has no columns or any step raises.
    """
    try:
        frame = pd.read_csv(file_path)

        if frame.columns.empty:
            return False, "Error: File contains no columns"

        # Existing data is dropped only once the CSV has parsed successfully.
        clear_database()
        target_table = create_dynamic_table(frame)
        insert_rows_into_table(frame.to_dict(orient='records'), target_table)
    except Exception as exc:
        return False, f"Error processing CSV file: {exc!s}"

    return True, "CSV file successfully loaded!"
|
|
|
def process_uploaded_file(file):
    """
    Route an uploaded file to the SQL or CSV loader based on its extension.

    Args:
        file: Path of the uploaded file, or ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(success, message)`` as produced by the selected loader, or
        ``(False, message)`` for a missing file, an unsupported extension or
        an unexpected error.
    """
    try:
        if file is None:
            return False, "Please upload a file."

        # The extension is compared case-insensitively.
        _, extension = os.path.splitext(file)
        loaders = {
            '.sql': process_sql_file,
            '.csv': process_csv_file,
        }
        loader = loaders.get(extension.lower())

        if loader is None:
            return False, "Error: Unsupported file type. Please upload either a .sql or .csv file."
        return loader(file)

    except Exception as exc:
        return False, f"Error processing file: {exc!s}"
|
|
|
@tool
def sql_engine(query: str) -> str:
    """
    Executes a read-only SQL query and returns formatted results.

    Args:
        query: The SQL query string to execute on the database. Must be a valid SELECT query.

    Returns:
        str: The formatted query results as a string.
    """
    # The query text is produced by a language model, so enforce the
    # documented SELECT-only contract instead of trusting it: anything that
    # does not start with SELECT (or a WITH ... SELECT CTE) is refused.
    statement = query.strip() if isinstance(query, str) else ""
    if not statement.lower().startswith(("select", "with")):
        return "Error: Only SELECT queries are allowed."

    try:
        with engine.connect() as con:
            rows = con.execute(text(statement)).fetchall()

        if not rows:
            return "No results found."

        # A single cell (e.g. COUNT(*)) is returned bare, without separators.
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])

        # Otherwise: one line per row, values separated by ", ".
        return "\n".join(", ".join(map(str, row)) for row in rows)

    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
# Module-level agent: a CodeAgent backed by the hosted Qwen coder model whose
# only tool is sql_engine.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)
|
|
|
def _first_sql_statement(sql):
    """Return *sql* cut at the first semicolon that lies outside a quoted span."""
    open_quote = None
    for index, char in enumerate(sql):
        if open_quote:
            if char == open_quote:
                open_quote = None
        elif char in "'\"`":
            open_quote = char
        elif char == ";":
            return sql[:index]
    return sql


def _extract_select_statement(raw_answer):
    """Pull the SQL statement out of the model's answer, keeping multi-line queries whole."""
    # Prefer the contents of a fenced code block when the model produced one.
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", raw_answer, flags=re.IGNORECASE | re.DOTALL)
    candidate = (fenced.group(1) if fenced else raw_answer).strip()

    # Drop any prose in front of SELECT, but leave "WITH ... SELECT" CTEs alone.
    if not re.match(r"with\b", candidate, flags=re.IGNORECASE):
        select_at = re.search(r"\bselect\b", candidate, flags=re.IGNORECASE)
        if select_at:
            candidate = candidate[select_at.start():]

    return _first_sql_statement(candidate).strip()


def _fix_identifiers(sql, table_name, column_names):
    """Swap placeholder table names for the real one and quote columns containing spaces."""
    known_columns = set(column_names)
    for placeholder in ("table_name", "customers", "main"):
        pattern = rf"\b{re.escape(placeholder)}\b"
        # Leave the word alone when it is a real column or is part of the
        # real table name; replacing it would corrupt a correct query.
        if placeholder in known_columns or re.search(pattern, table_name):
            continue
        # Whole-word match only, so "domain" or "remaining" stay untouched.
        sql = re.sub(pattern, lambda _match: table_name, sql)

    # Longest names first so "first name full" wins over "first name".
    spaced = sorted((col for col in column_names if " " in col), key=len, reverse=True)
    if spaced:
        names = "|".join(re.escape(col) for col in spaced)
        # Quoted spans and string literals are matched first and returned
        # unchanged, so only bare occurrences of a column name get quoted.
        quoted_spans = r"""'[^']*'|"[^"]*"|`[^`]*`|\[[^\]]*\]"""
        pattern = re.compile(rf"{quoted_spans}|(?<!\w)({names})(?!\w)")
        sql = pattern.sub(
            lambda match: f'"{match.group(1)}"' if match.group(1) else match.group(0),
            sql,
        )
    return sql


def _format_result(result):
    """Add thousands separators to a numeric result; return anything else unchanged."""
    result = str(result)
    try:
        number = float(result)
    except ValueError:
        return result
    # Whole numbers print without decimals; fractions (e.g. AVG) keep two
    # places instead of being silently rounded to an integer.
    if number.is_integer():
        return f"{number:,.0f}"
    return f"{number:,.2f}"


def query_sql(user_query: str) -> str:
    """
    Answer a natural-language question by generating and running an SQL query.

    The current table's schema is described to the agent, the agent's answer
    is reduced to a single SQL statement, placeholder table names and unquoted
    column names are repaired, and the statement is run through sql_engine.

    Args:
        user_query: The question typed by the user.

    Returns:
        str: The formatted query result, or a message describing why no
        result could be produced.
    """
    # Checked first so blank input never costs a schema lookup or an LLM call.
    if not user_query or not user_query.strip():
        return "Please enter a question."

    table_name, column_names, column_info = get_table_info()

    if not table_name:
        return "Error: No data table exists. Please upload a file first."

    column_lines = "\n".join(
        f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
        for col, info in column_info.items()
    )
    schema_info = (
        f"The database has a table named '{table_name}' with the following columns:\n"
        f"{column_lines}"
        "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
        f"The table name is '{table_name}'.\n"
        "If column names contain spaces, they must be quoted.\n"
        "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )

    # Model/API failures become a message instead of an unhandled UI error.
    try:
        generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    except Exception as e:
        return f"Error generating query: {str(e)}"

    if not isinstance(generated_sql, str):
        return "Error: Invalid query generated"

    generated_sql = _extract_select_statement(generated_sql)
    generated_sql = _fix_identifiers(generated_sql, table_name, column_names)

    try:
        result = sql_engine(generated_sql)
    except Exception as e:
        return f"Error executing query: {str(e)}"

    return _format_result(result)
|
|
|
|
|
def _format_schema(column_info):
    """Render the column description mapping as the Markdown shown under the data table."""
    schema = "\n".join(
        f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
        for col, info in column_info.items()
    )
    return f"### Current Schema:\n```\n{schema}\n```"


def _upload_failed(message):
    """Build the handler outputs that keep the upload view visible and the query view hidden."""
    return (
        message,
        None,
        "No schema available",
        gr.update(visible=True),
        gr.update(visible=False),
    )


# Two-stage interface: an upload view shown first, and a query view that
# becomes visible once a file has been loaded successfully.
with gr.Blocks() as demo:
    with gr.Group() as upload_group:
        gr.Markdown("""
        # Data Query Interface

        Upload your data file to begin.

        ### Supported File Types:
        - SQL (.sql): SQL file containing CREATE TABLE and INSERT statements
        - CSV (.csv): CSV file with headers that will be automatically converted to a table

        ### CSV Requirements:
        - Must include headers
        - First column will be used as the primary key
        - Column types will be automatically detected

        ### SQL Requirements:
        - Must contain valid SQL statements
        - Statements must be separated by semicolons
        - Should include CREATE TABLE and data insertion statements
        """)

        file_input = gr.File(
            label="Upload Data File",
            file_types=[".csv", ".sql"],
            type="filepath",
        )
        status = gr.Textbox(label="Status", interactive=False)

    with gr.Group(visible=False) as query_group:
        with gr.Row():
            with gr.Column(scale=1):
                user_input = gr.Textbox(label="Ask a question about the data")
                query_output = gr.Textbox(label="Result")

            with gr.Column(scale=2):
                gr.Markdown("### Current Data")
                data_table = gr.Dataframe(
                    value=None,
                    label="Data Table",
                    interactive=False,
                )

        schema_display = gr.Markdown(value="Loading schema...")
        refresh_btn = gr.Button("Refresh Data")

    def handle_upload(file_obj):
        """Load the uploaded file and, on success, switch from the upload view to the query view.

        Returns values for, in order: status text, data table, schema
        Markdown, upload group update, query group update.
        """
        if file_obj is None:
            return _upload_failed("Please upload a file.")

        success, message = process_uploaded_file(file_obj)
        if not success:
            return _upload_failed(message)

        df = get_data_table()
        _, _, column_info = get_table_info()
        return (
            message,
            df,
            _format_schema(column_info),
            gr.update(visible=False),
            gr.update(visible=True),
        )

    def refresh_data():
        """Re-read the table contents and schema from the database."""
        df = get_data_table()
        _, _, column_info = get_table_info()
        return df, _format_schema(column_info)

    file_input.upload(
        fn=handle_upload,
        inputs=file_input,
        outputs=[
            status,
            data_table,
            schema_display,
            upload_group,
            query_group,
        ],
    )

    # Run the agent when the user presses Enter. Listening to `change` would
    # start a full LLM round-trip on every keystroke.
    user_input.submit(
        fn=query_sql,
        inputs=user_input,
        outputs=query_output,
    )

    refresh_btn.click(
        fn=refresh_data,
        outputs=[data_table, schema_display],
    )
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script:
    # listen on every network interface, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")