File size: 5,987 Bytes
6a0ec6a
 
9002697
6a0ec6a
1767e22
5a55ea7
6349cf9
 
8912157
6349cf9
5a55ea7
9002697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8912157
5a55ea7
28200f6
 
 
 
 
 
 
 
 
5a55ea7
 
 
 
 
8912157
5a55ea7
 
1767e22
8912157
5a55ea7
28200f6
8912157
28200f6
 
 
 
5a55ea7
 
20e319d
8912157
5a55ea7
28200f6
 
 
 
 
 
 
 
 
5a55ea7
 
 
 
8912157
5a55ea7
28200f6
8912157
28200f6
 
 
 
 
 
 
20e319d
 
5a55ea7
20e319d
 
5a55ea7
20e319d
5a55ea7
 
 
 
20e319d
5a55ea7
20e319d
8912157
5a55ea7
28200f6
8912157
28200f6
 
 
 
 
 
 
5a55ea7
 
 
 
 
 
edb7e14
5a55ea7
 
 
6349cf9
1f7ee11
5a55ea7
 
9002697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f7ee11
5a55ea7
1767e22
5a55ea7
1767e22
 
5a55ea7
 
1767e22
 
 
5a55ea7
 
 
 
 
 
 
 
6a0ec6a
 
0380e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import os
import gradio as gr
from sqlalchemy import text, inspect
from smolagents import tool, CodeAgent, HfApiModel
import pandas as pd
import tempfile
from database import engine, initialize_database

# Prepare the database at import time via the project's `database` module,
# before any tool or UI code below touches `engine`.
# NOTE(review): the original comment says this "won't crash if empty"; that
# behavior lives in database.initialize_database — confirm there.
initialize_database()

# SQL execution tool exposed to the agent. It must be defined before the
# CodeAgent below, which lists it as a tool.
@tool
def sql_engine(query: str) -> str:
    """
    Executes a read-only SQL query (SELECT, SHOW or PRAGMA) and returns the results.

    Args:
        query (str): The SQL query string to execute. Only SELECT queries are allowed.

    Returns:
        str: A formatted string containing the query results, or an error message if the query fails.
    """
    # The agent (an LLM) may call this tool with arbitrary SQL, so enforce the
    # read-only contract here rather than trusting the docstring. The allowed
    # prefixes and message mirror the check performed in query_sql.
    if not query.strip().lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    try:
        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()
    except Exception as e:
        # Report database errors back to the agent as text instead of raising.
        return f"Error: {str(e)}"
    if not rows:
        return "No results found."
    # One line per row, column values separated by ", ".
    return "\n".join(", ".join(map(str, row)) for row in rows)

# Build the CodeAgent only after `sql_engine` exists: it is the agent's single
# tool. The language model is the Qwen2.5 Coder 32B Instruct checkpoint,
# accessed through HfApiModel.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)

def _split_sql_statements(script):
    """Split an SQL script into individual statements (without the trailing ';').

    A ';' ends a statement only when it is outside quoted text ('...', "...",
    `...`) and outside -- line comments and /* */ block comments. Standard
    doubled-quote escapes ('it''s') work. Backslash escapes and compound bodies
    such as CREATE TRIGGER ... BEGIN ...; END are NOT recognised and would be
    split at their inner semicolons. Chunks holding only whitespace and/or
    comments are dropped.
    """
    statements = []
    current = []  # text pieces of the statement being accumulated
    has_code = False  # True once `current` holds more than whitespace/comments
    i, length = 0, len(script)
    while i < length:
        ch = script[i]
        if ch in "'\"`":
            # Quoted literal/identifier: copy through the matching close quote.
            close = script.find(ch, i + 1)
            end = length if close == -1 else close + 1
            has_code = True
        elif script.startswith("--", i):
            # Line comment: copy up to (not including) the newline.
            close = script.find("\n", i)
            end = length if close == -1 else close
        elif script.startswith("/*", i):
            close = script.find("*/", i + 2)
            end = length if close == -1 else close + 2
        elif ch == ";":
            if has_code:
                statements.append("".join(current).strip())
            current, has_code = [], False
            i += 1
            continue
        else:
            end = i + 1
            has_code = has_code or not ch.isspace()
        current.append(script[i:end])
        i = end
    if has_code:
        statements.append("".join(current).strip())
    return statements


# Function to execute an uploaded SQL script
def execute_sql_script(file_path):
    """
    Executes an uploaded SQL file to initialize the database.

    The script is split into individual statements, which are executed in a
    single transaction that is committed only if every statement succeeds.

    Args:
        file_path (str): Path to the SQL file (read as UTF-8).

    Returns:
        str: Success message or error description.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            # Many drivers (e.g. sqlite3) refuse to run several statements in
            # one execute() call, so send them one at a time.
            statements = _split_sql_statements(f.read())
        # engine.begin() commits on success and rolls back on error; a bare
        # engine.connect() block is rolled back on close in SQLAlchemy 2.x.
        with engine.begin() as con:
            for statement in statements:
                # exec_driver_sql passes the SQL verbatim; text() would treat a
                # ':name' token (even inside a string literal) as a bind param.
                con.exec_driver_sql(statement)
        return "SQL file executed successfully. Database updated."
    except Exception as e:
        return f"Error: {str(e)}"

# Function to get table names dynamically
def get_table_names():
    """
    List the tables currently present in the database.

    Returns:
        list: Table names as reported by SQLAlchemy's inspector for `engine`.
    """
    return inspect(engine).get_table_names()

# Function to get table schema dynamically
def get_table_schema(table_name):
    """
    Look up the column names of a single table.

    Args:
        table_name (str): Table to describe.

    Returns:
        list: Column names, in the order the inspector reports them.
    """
    column_info = inspect(engine).get_columns(table_name)
    return [info["name"] for info in column_info]

# Function to fetch data dynamically from any table
def get_table_data(table_name):
    """
    Retrieves all rows from a specified table as a Pandas DataFrame.

    Args:
        table_name (str): Name of the table (e.g. one returned by get_table_names()).

    Returns:
        pd.DataFrame: The table's rows (empty but with column headers when the
        table has no rows), or a one-column "Error" frame if the query fails.
    """
    try:
        # Quote the identifier with the engine's dialect so names that are
        # reserved words or contain spaces/special characters work, and so
        # the name cannot smuggle extra SQL into the statement.
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        with engine.connect() as con:
            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            # Headers come from the result itself, so they always line up with
            # the row values (and no second schema lookup is needed).
            columns = list(result.keys())
            rows = result.fetchall()
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})

def _uploaded_file_path(file):
    """Return a filesystem path for a Gradio upload value, or None if it has none."""
    # Depending on the Gradio version / gr.File `type`, the value is either a
    # path string or a file-like object whose `.name` is the path on disk.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    name = getattr(file, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        return name
    return None


def _execute_sql_bytes(data):
    """Write raw upload contents to a temporary .sql file, execute it, then delete it."""
    if isinstance(data, str):
        data = data.encode("utf-8")
    fd, temp_path = tempfile.mkstemp(suffix=".sql")
    try:
        # Reuse the descriptor from mkstemp so it is closed (not leaked).
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(data)
        return execute_sql_script(temp_path)
    finally:
        os.remove(temp_path)


def _all_tables_frame():
    """Combine the data of every table into one DataFrame for the gr.Dataframe output."""
    tables = get_table_names()
    if not tables:
        return pd.DataFrame({"Error": ["No tables found after execution. Ensure your SQL file creates tables."]})
    frames = {table: get_table_data(table) for table in tables}
    # Prefix a column naming the source table; choose a label no table
    # already uses so the insert cannot collide with real columns.
    label = "table"
    while any(label in frame.columns for frame in frames.values()):
        label = f"_{label}"
    for table, frame in frames.items():
        frame.insert(0, label, table)
    # Tables with different columns are aligned; missing cells become NaN.
    return pd.concat(frames.values(), ignore_index=True)


# Function to handle SQL file uploads and execute them
def handle_file_upload(file):
    """
    Handles SQL file upload, executes SQL, and updates database schema.

    Args:
        file: The gr.File value — a path string, a file-like object with a
            `.name` path, or a readable object; None when there is no file.

    Returns:
        tuple: (execution result message, pd.DataFrame holding every table's
        rows with a leading column naming the source table).
    """
    if file is None:
        # e.g. the upload was cleared; there is nothing to execute.
        return "Error: No file uploaded.", pd.DataFrame()

    file_path = _uploaded_file_path(file)
    if file_path is not None:
        result = execute_sql_script(file_path)
    else:
        result = _execute_sql_bytes(file.read())

    return result, _all_tables_frame()

def _strip_code_fence(answer):
    """Remove a surrounding Markdown code fence (e.g. ```sql ... ```) from model output."""
    import re

    # Optional language tag must be followed by a newline, so "```SELECT 1```"
    # keeps its SELECT keyword.
    match = re.match(r"^```(?:[\w+-]*[ \t]*\n)?(.*?)```$", answer.strip(), re.DOTALL)
    return (match.group(1) if match else answer).strip()


# Function to handle natural language to SQL conversion
def query_sql(user_query: str) -> str:
    """
    Converts a user's natural language query into SQL via the agent and runs it.

    Args:
        user_query (str): The question asked by the user.

    Returns:
        str: The results of the executed SQL query, or an "Error: ..." message
        when no tables exist or the agent's answer is not a SELECT/SHOW/PRAGMA.
    """
    tables = get_table_names()
    if not tables:
        return "Error: No tables found. Please upload an SQL file first."

    # Describe every table so the model only uses column names that exist.
    table_lines = [
        f"Table '{table}' has columns: {', '.join(get_table_schema(table))}.\n"
        for table in tables
    ]
    schema_info = (
        "Available tables and columns:\n"
        + "".join(table_lines)
        + "Generate a valid SQL SELECT query using ONLY these column names. DO NOT return anything other than the SQL query itself."
    )

    answer = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    if not isinstance(answer, str):
        return "Error: Only SELECT queries are allowed."

    # Models frequently wrap SQL in Markdown fences; unwrap before validating
    # so a fenced SELECT is not mistaken for a non-SELECT statement.
    generated_sql = _strip_code_fence(answer)
    if not generated_sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."

    return sql_engine(generated_sql)

# Function to handle query input
def handle_query(user_input: str) -> str:
    """
    Handles user input and returns the SQL query result.

    Blank input (None, empty, or whitespace only) returns an empty string
    without invoking the LLM agent.

    Args:
        user_input (str): User's natural language query.

    Returns:
        str: The query result or error message ("" for blank input).
    """
    if not user_input or not user_input.strip():
        return ""
    return query_sql(user_input)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SQL Query Interface")

    with gr.Row():
        user_input = gr.Textbox(label="Ask a question about the data")
        query_output = gr.Textbox(label="Result")

    # Run on submit (Enter) instead of .change: .change fires on every edit of
    # the textbox, which would launch a full LLM agent run per keystroke.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)

    gr.Markdown("## Upload SQL File to Execute")
    file_upload = gr.File(label="Upload SQL File")
    upload_output = gr.Textbox(label="Execution Result")

    # Dynamic table display
    table_output = gr.Dataframe(label="Database Tables (Dynamic)")

    # .upload fires when a file is uploaded, not when the component is cleared.
    file_upload.upload(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)