import os
import re
import tempfile

import gradio as gr
import pandas as pd
from smolagents import CodeAgent, HfApiModel, tool
from sqlalchemy import inspect, text

from database import engine, initialize_database

# Ensure the database initializes (won't crash if empty)
initialize_database()

# First keywords (case-insensitive) accepted for a read-only query.
_READ_ONLY_KEYWORDS = frozenset({"select", "show", "pragma"})

# Matches a Markdown code fence (optionally tagged ``sql``) wrapping model output.
_CODE_FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def _split_sql_statements(script: str) -> list[str]:
    """Split a SQL script into individual statements.

    Statements are separated by semicolons that appear outside quoted text
    ('...', "...", `...`) and outside comments (``-- ...`` to end of line and
    ``/* ... */``). Comments are replaced by a space; empty statements are
    dropped.

    Args:
        script (str): SQL text possibly containing several statements.

    Returns:
        list: The non-empty statements, stripped, without trailing semicolons.
    """
    statements = []
    current = []
    quote = None  # quote character of the literal currently open, if any
    i, length = 0, len(script)
    while i < length:
        ch = script[i]
        if quote is not None:
            current.append(ch)
            # A doubled quote ('') closes and immediately reopens the literal,
            # so toggling on every matching quote handles SQL-style escaping.
            if ch == quote:
                quote = None
            i += 1
        elif ch in "'\"`":
            quote = ch
            current.append(ch)
            i += 1
        elif script.startswith("--", i):
            end = script.find("\n", i)
            i = length if end == -1 else end  # keep the newline itself
            current.append(" ")
        elif script.startswith("/*", i):
            end = script.find("*/", i + 2)
            i = length if end == -1 else end + 2
            current.append(" ")
        elif ch == ";":
            statements.append("".join(current).strip())
            current = []
            i += 1
        else:
            current.append(ch)
            i += 1
    statements.append("".join(current).strip())
    return [stmt for stmt in statements if stmt]


def _is_read_only_query(query: str) -> bool:
    """Return True if *query* is exactly one statement starting with a read-only keyword.

    Rejecting multi-statement input blocks payloads such as
    ``SELECT 1; DROP TABLE t``. Inputs the splitter cannot parse unambiguously
    tend to split into several pieces and are therefore rejected.

    Args:
        query (str): SQL text to check.

    Returns:
        bool: Whether the query may be executed by ``sql_engine``.
    """
    if not isinstance(query, str):
        return False
    statements = _split_sql_statements(query)
    if len(statements) != 1:
        return False
    match = re.match(r"[A-Za-z_]+", statements[0])
    return bool(match) and match.group(0).lower() in _READ_ONLY_KEYWORDS


def _extract_sql(response: str) -> str:
    """Return the SQL inside the first Markdown code fence of *response*, or the whole text, stripped."""
    match = _CODE_FENCE_RE.search(response)
    return (match.group(1) if match else response).strip()


# SQL Execution Tool (FIXED - Defined BEFORE Use)
@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL SELECT query and returns the results.

    Args:
        query (str): The SQL query string to execute. Only SELECT queries are allowed.

    Returns:
        str: A formatted string containing the query results, or an error message if the query fails.
    """
    # The agent can invoke this tool directly, so the read-only rule must be
    # enforced here and not only in query_sql.
    if not _is_read_only_query(query):
        return "Error: Only SELECT queries are allowed."
    try:
        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()
    except Exception as e:
        return f"Error: {str(e)}"
    if not rows:
        return "No results found."
    return "\n".join(", ".join(map(str, row)) for row in rows)


# Initialize CodeAgent for SQL query generation (Moved Below `sql_engine`)
agent = CodeAgent(
    tools=[sql_engine],
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
)


# Function to execute an uploaded SQL script
def execute_sql_script(file_path):
    """
    Executes an uploaded SQL file to initialize the database.

    The script is split into individual statements, which are run in a single
    transaction: it is committed on success and rolled back if any statement
    fails (whether DDL is rolled back depends on the database backend).

    Args:
        file_path (str): Path to the SQL file.

    Returns:
        str: Success message or error description.
    """
    try:
        # utf-8-sig also strips a leading BOM that would corrupt the first statement.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            statements = _split_sql_statements(f.read())
        if not statements:
            return "Error: The SQL file contains no statements."
        # engine.begin() commits on exit; a plain connect() would roll back.
        # exec_driver_sql runs raw SQL without text()'s ":name" bind parsing,
        # so literals containing colons are passed through untouched.
        with engine.begin() as con:
            for statement in statements:
                con.exec_driver_sql(statement)
        return "SQL file executed successfully. Database updated."
    except Exception as e:
        return f"Error: {str(e)}"


# Function to get table names dynamically
def get_table_names():
    """
    Returns a list of tables available in the database.

    Returns:
        list: List of table names.
    """
    inspector = inspect(engine)
    return inspector.get_table_names()


# Function to get table schema dynamically
def get_table_schema(table_name):
    """
    Returns a list of column names for a given table.

    Args:
        table_name (str): Name of the table.

    Returns:
        list: List of column names.
    """
    inspector = inspect(engine)
    columns = inspector.get_columns(table_name)
    return [col["name"] for col in columns]


# Function to fetch data dynamically from any table
def get_table_data(table_name):
    """
    Retrieves all rows from a specified table as a Pandas DataFrame.

    Args:
        table_name (str): Name of the table.

    Returns:
        pd.DataFrame: Table data (an empty frame with the table's columns if it
        has no rows), or a single-column ``Error`` frame if the query fails.
    """
    try:
        # Quote the identifier so names with spaces, reserved words or quote
        # characters cannot break (or inject into) the statement.
        quoted = engine.dialect.identifier_preparer.quote(table_name)
        with engine.connect() as con:
            result = con.exec_driver_sql(f"SELECT * FROM {quoted}")
            columns = list(result.keys())
            rows = [tuple(row) for row in result.fetchall()]
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})


def _resolve_upload_path(file):
    """Return the filesystem path of a Gradio upload, or None when nothing was uploaded.

    Depending on the Gradio version, ``gr.File`` may pass a path string or a
    temp-file wrapper whose ``name`` attribute is the path.
    """
    if file is None:
        return None
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return getattr(file, "name", None)


# Function to handle SQL file uploads and execute them
def handle_file_upload(file):
    """
    Handles SQL file upload, executes SQL, and updates database schema.

    Args:
        file (File): Uploaded SQL file (path string or file wrapper from gr.File),
            or None when the upload is cleared.

    Returns:
        tuple: Execution result message and a DataFrame listing each table with
        its columns (or an ``Error`` frame when no tables exist).
    """
    path = _resolve_upload_path(file)
    if path is None:
        return "No file uploaded.", pd.DataFrame()

    # Execute the uploaded file in place; Gradio already stored it on disk.
    result = execute_sql_script(path)

    tables = get_table_names()
    if not tables:
        return result, pd.DataFrame(
            {"Error": ["No tables found after execution. Ensure your SQL file creates tables."]}
        )

    overview = pd.DataFrame(
        {
            "Table": tables,
            "Columns": [", ".join(get_table_schema(table)) for table in tables],
        }
    )
    return result, overview


# Function to handle natural language to SQL conversion
def query_sql(user_query: str) -> str:
    """
    Converts a user's natural language query into an SQL query.

    Args:
        user_query (str): The question asked by the user.

    Returns:
        str: The results of the executed SQL query, or an error message.
    """
    tables = get_table_names()
    if not tables:
        return "Error: No tables found. Please upload an SQL file first."

    schema_info = "Available tables and columns:\n"
    for table in tables:
        columns = get_table_schema(table)
        schema_info += f"Table '{table}' has columns: {', '.join(columns)}.\n"
    schema_info += (
        "Generate a valid SQL SELECT query using ONLY these column names. "
        "DO NOT return anything other than the SQL query itself."
    )

    try:
        generated = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    except Exception as e:
        return f"Error: {str(e)}"

    if not isinstance(generated, str):
        return "Error: The model did not return an SQL query."

    # Models frequently wrap SQL in ```sql fences despite the instruction.
    generated_sql = _extract_sql(generated)
    if not _is_read_only_query(generated_sql):
        return "Error: Only SELECT queries are allowed."
    return sql_engine(generated_sql)


# Function to handle query input
def handle_query(user_input: str) -> str:
    """
    Handles user input and returns the SQL query result.

    Args:
        user_input (str): User's natural language query.

    Returns:
        str: The query result or error message.
    """
    if not user_input or not user_input.strip():
        return "Please enter a question."
    return query_sql(user_input.strip())


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SQL Query Interface")
    with gr.Row():
        user_input = gr.Textbox(label="Ask a question about the data")
        query_output = gr.Textbox(label="Result")
    # Run on Enter rather than on every keystroke: each call invokes the LLM agent.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)

    gr.Markdown("## Upload SQL File to Execute")
    file_upload = gr.File(label="Upload SQL File", file_types=[".sql"])
    upload_output = gr.Textbox(label="Execution Result")

    # Dynamic table display
    table_output = gr.Dataframe(label="Database Tables (Dynamic)")

    file_upload.change(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)