|
import os |
|
import gradio as gr |
|
from sqlalchemy import text, inspect |
|
from smolagents import tool, CodeAgent, HfApiModel |
|
import pandas as pd |
|
import tempfile |
|
from database import engine, initialize_database |
|
|
|
|
|
# Runs at import time, before the agent and the Gradio UI below are built.
# NOTE(review): what this sets up (schema, seed data, ...) is defined in the local
# `database` module, not visible here — confirm there.
initialize_database()
|
|
|
|
|
@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL SELECT query and returns the results.

    Args:
        query (str): The SQL query string to execute. Only SELECT queries are allowed.

    Returns:
        str: A formatted string containing the query results, or an error message if the query fails.
    """
    try:
        # The query is produced by the LLM from end-user text, so enforce the
        # read-only contract stated above rather than trusting the prompt.
        # The accepted prefixes are the same ones query_sql() checks.
        # NOTE(review): some PRAGMAs can modify database settings; tighten if needed.
        if not query.strip().lower().startswith(("select", "show", "pragma")):
            return "Error: Only SELECT queries are allowed."

        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()

        if not rows:
            return "No results found."

        # One output line per row, column values separated by ", ".
        return "\n".join(", ".join(map(str, row)) for row in rows)

    except Exception as e:
        # Tool results are fed back to the agent, so report errors as text.
        return f"Error: {e}"
|
|
|
|
|
# Hugging Face Inference model that drives the SQL-writing agent.
_AGENT_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct"

# Code-writing agent whose only tool is the sql_engine query runner.
agent = CodeAgent(
    model=HfApiModel(model_id=_AGENT_MODEL_ID),
    tools=[sql_engine],
)
|
|
|
|
|
def _split_sql_statements(script):
    """Split an SQL script into statements, honouring quotes, comments and triggers."""
    import sqlite3  # stdlib; only its SQL-completeness tokenizer is used

    statements = []
    pending = ""
    fragments = script.split(";")
    for index, fragment in enumerate(fragments):
        pending += fragment
        if index == len(fragments) - 1:
            break  # the tail after the last ";" is handled below
        pending += ";"
        # A ";" inside a string literal, comment or CREATE TRIGGER body does not
        # end a statement; complete_statement() tells the two cases apart.
        if sqlite3.complete_statement(pending):
            statement = pending.strip()
            if statement.rstrip(";").strip():  # skip empty ";" statements
                statements.append(statement)
            pending = ""
    if pending.strip():
        # Trailing statement without a terminating semicolon.
        statements.append(pending.strip())
    return statements


def execute_sql_script(file_path):
    """
    Executes an uploaded SQL file to initialize the database.

    The script is split into individual statements (drivers such as sqlite3
    reject several statements in one execute call), and all of them run in a
    single transaction that is committed on success.

    Args:
        file_path (str): Path to the SQL file.

    Returns:
        str: Success message or error description.
    """
    try:
        # "utf-8-sig" also strips a leading BOM, which would otherwise corrupt
        # the first statement.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            sql_script = f.read()

        statements = _split_sql_statements(sql_script)

        # engine.begin() commits when the block succeeds and rolls back if a
        # statement raises; engine.connect() alone never commits, so under
        # SQLAlchemy 2.x the uploaded changes were silently discarded.
        # NOTE(review): DDL rollback depends on the backend (e.g. MySQL auto-commits DDL).
        with engine.begin() as con:
            for statement in statements:
                con.execute(text(statement))

        return "SQL file executed successfully. Database updated."

    except Exception as e:
        return f"Error: {e}"
|
|
|
|
|
def get_table_names():
    """
    List the tables that currently exist in the database.

    A fresh inspector is created on every call, so tables created since the
    previous call are included.

    Returns:
        list: Table names as reported by the SQLAlchemy inspector.
    """
    return inspect(engine).get_table_names()
|
|
|
|
|
def get_table_schema(table_name):
    """
    Look up the column names of one table.

    Args:
        table_name (str): Table whose columns should be listed.

    Returns:
        list: Column names, in the order the inspector reports them.
    """
    column_infos = inspect(engine).get_columns(table_name)
    names = []
    for column_info in column_infos:
        names.append(column_info["name"])
    return names
|
|
|
|
|
def get_table_data(table_name):
    """
    Retrieves all rows from a specified table as a Pandas DataFrame.

    Args:
        table_name (str): Name of the table.

    Returns:
        pd.DataFrame: The table's rows with its column names. An empty table
        yields an empty frame with the right columns. On failure, a one-cell
        frame with an "Error" column holding the message.
    """
    try:
        # Quote the name with the dialect's own rules so tables named with
        # reserved words, spaces or quotes work and cannot inject extra SQL.
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)

        with engine.connect() as con:
            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            # Column names come from the result itself, so they always match
            # the row layout and no separate schema lookup is needed.
            columns = list(result.keys())
            rows = result.fetchall()

        # With no rows this still produces an empty frame carrying the columns.
        return pd.DataFrame(rows, columns=columns)

    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
|
|
|
|
|
def _uploaded_file_path(file):
    """Return an on-disk path for the upload if it already has one, else None."""
    # Depending on the Gradio version / `type` setting, an upload arrives as a
    # path string or as a file object whose `name` is the saved file's path.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    name = getattr(file, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        return name
    return None


def _execute_sql_bytes(data):
    """Spool raw SQL to a temporary .sql file, execute it, then delete the file."""
    if isinstance(data, str):
        data = data.encode("utf-8")
    fd, temp_file_path = tempfile.mkstemp(suffix=".sql")
    try:
        # Write through the descriptor mkstemp opened so it is not leaked.
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(data)
        return execute_sql_script(temp_file_path)
    finally:
        try:
            os.remove(temp_file_path)
        except OSError:
            pass  # best-effort cleanup; the execution result is already known


def _combine_table_frames(frames):
    """Stack per-table DataFrames into one frame led by a source-table column."""
    label = "table"
    # Pick a tag column name that does not collide with any real column.
    while any(label in frame.columns for frame in frames.values()):
        label = f"_{label}"
    tagged = [frame.assign(**{label: name}) for name, frame in frames.items()]
    combined = pd.concat(tagged, ignore_index=True)
    return combined[[label] + [column for column in combined.columns if column != label]]


def handle_file_upload(file):
    """
    Handles SQL file upload, executes SQL, and updates database schema.

    Args:
        file: The uploaded SQL file. Accepted forms are a path string or
            os.PathLike, an object whose `name` is a saved file path, or a
            readable file-like object. None (upload cleared) is also handled.

    Returns:
        tuple: (execution result message, pandas DataFrame for the table view).
        The DataFrame holds the rows of every table, tagged with their table
        name, or a one-column "Error" frame when no tables exist.
    """
    if file is None:
        return "No file uploaded.", pd.DataFrame()

    path = _uploaded_file_path(file)
    if path is not None:
        result = execute_sql_script(path)
    else:
        result = _execute_sql_bytes(file.read())

    tables = get_table_names()
    if not tables:
        return result, pd.DataFrame(
            {"Error": ["No tables found after execution. Ensure your SQL file creates tables."]}
        )

    # gr.Dataframe renders a single frame, not a dict of frames, so stack them.
    frames = {table: get_table_data(table) for table in tables}
    return result, _combine_table_frames(frames)
|
|
|
|
|
def _strip_code_fences(answer: str) -> str:
    """Return the answer with one surrounding Markdown code fence removed, if present."""
    cleaned = answer.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        cleaned = cleaned[3:-3]
        # The opening fence may carry a language tag (e.g. ```sql); drop that line.
        tag, newline, body = cleaned.partition("\n")
        if newline and tag.strip().lower() in {"", "sql", "sqlite", "mysql", "postgres", "postgresql"}:
            cleaned = body
    return cleaned.strip()


def query_sql(user_query: str) -> str:
    """
    Answers a natural-language question by having the agent write SQL and running it.

    The agent is given the current table/column listing. Its answer is checked
    to be a read-only query (SELECT/SHOW/PRAGMA) before execution.

    Args:
        user_query (str): The question asked by the user.

    Returns:
        str: The results of the executed SQL query, or an error message.
    """
    tables = get_table_names()
    if not tables:
        return "Error: No tables found. Please upload an SQL file first."

    schema_lines = ["Available tables and columns:"]
    for table in tables:
        columns = get_table_schema(table)
        schema_lines.append(f"Table '{table}' has columns: {', '.join(columns)}.")
    schema_lines.append(
        "Generate a valid SQL SELECT query using ONLY these column names. "
        "DO NOT return anything other than the SQL query itself."
    )
    schema_info = "\n".join(schema_lines)

    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")

    if not isinstance(generated_sql, str):
        return "Error: Only SELECT queries are allowed."

    # Models often wrap SQL in ```sql fences; without stripping them the prefix
    # check below rejects otherwise valid queries.
    generated_sql = _strip_code_fences(generated_sql)
    if not generated_sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."

    return sql_engine(generated_sql)
|
|
|
|
|
def handle_query(user_input: str) -> str:
    """
    Handles user input and returns the SQL query result.

    Args:
        user_input (str): User's natural language query.

    Returns:
        str: The query result or error message. An empty string is returned
        for blank input.
    """
    # Don't invoke the LLM agent when there is nothing to ask.
    if not user_input or not user_input.strip():
        return ""
    return query_sql(user_input)
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## SQL Query Interface")

    with gr.Row():
        user_input = gr.Textbox(label="Ask a question about the data")
        query_output = gr.Textbox(label="Result")

    # Run the agent when the user presses Enter. A `change` listener would
    # start an LLM call for every keystroke.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)

    gr.Markdown("## Upload SQL File to Execute")
    file_upload = gr.File(label="Upload SQL File", file_types=[".sql"])
    upload_output = gr.Textbox(label="Execution Result")

    table_output = gr.Dataframe(label="Database Tables (Dynamic)")

    # `upload` fires only for a new file. Unlike `change`, it is not also
    # triggered with None when the file is cleared.
    file_upload.upload(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])


if __name__ == "__main__":
    # NOTE(review): share=True publishes a public Gradio share link for an app
    # that executes uploaded SQL; confirm this exposure is intended.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|