Update app.py
Browse files
app.py
CHANGED
@@ -1,144 +1,124 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
-
from sqlalchemy import text
|
4 |
from smolagents import tool, CodeAgent, HfApiModel
|
5 |
-
import spaces
|
6 |
-
import pandas as pd
|
7 |
-
from database import engine, receipts
|
8 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
17 |
try:
|
18 |
with engine.connect() as con:
|
19 |
-
result = con.execute(text("SELECT * FROM
|
20 |
rows = result.fetchall()
|
21 |
-
|
22 |
-
if not rows:
|
23 |
-
return pd.DataFrame(columns=["receipt_id", "customer_name", "price", "tip"])
|
24 |
|
25 |
-
|
26 |
-
df = pd.DataFrame(rows, columns=["receipt_id", "customer_name", "price", "tip"])
|
27 |
-
return df
|
28 |
|
|
|
|
|
|
|
|
|
29 |
except Exception as e:
|
30 |
-
return pd.DataFrame({"Error": [str(e)]})
|
31 |
|
|
|
32 |
@tool
|
33 |
def sql_engine(query: str) -> str:
|
34 |
-
"""
|
35 |
-
Executes an SQL query on the 'receipts' table and returns formatted results.
|
36 |
-
|
37 |
-
Args:
|
38 |
-
query: The SQL query to execute.
|
39 |
-
|
40 |
-
Returns:
|
41 |
-
Query result as a formatted string.
|
42 |
-
"""
|
43 |
try:
|
44 |
with engine.connect() as con:
|
45 |
rows = con.execute(text(query)).fetchall()
|
46 |
-
|
47 |
if not rows:
|
48 |
return "No results found."
|
49 |
-
|
50 |
-
if len(rows) == 1 and len(rows[0]) == 1:
|
51 |
-
return str(rows[0][0]) # Convert numerical result to string
|
52 |
-
|
53 |
return "\n".join([", ".join(map(str, row)) for row in rows])
|
54 |
-
|
55 |
except Exception as e:
|
56 |
return f"Error: {str(e)}"
|
57 |
|
|
|
58 |
def query_sql(user_query: str) -> str:
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
Returns:
|
67 |
-
The query result from the database as a formatted string.
|
68 |
-
"""
|
69 |
-
|
70 |
-
schema_info = (
|
71 |
-
"The database has a table named 'receipts' with the following schema:\n"
|
72 |
-
"- receipt_id (INTEGER, primary key)\n"
|
73 |
-
"- customer_name (VARCHAR(16))\n"
|
74 |
-
"- price (FLOAT)\n"
|
75 |
-
"- tip (FLOAT)\n"
|
76 |
-
"Generate a valid SQL SELECT query using ONLY these column names.\n"
|
77 |
-
"DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
|
78 |
-
)
|
79 |
-
|
80 |
-
generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
|
81 |
|
82 |
-
|
83 |
-
return f"{generated_sql}" # Handle unexpected numerical result
|
84 |
|
85 |
-
|
86 |
|
87 |
-
if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
|
88 |
return "Error: Only SELECT queries are allowed."
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
print(f"{result}")
|
93 |
-
|
94 |
-
try:
|
95 |
-
float_result = float(result)
|
96 |
-
return f"{float_result:.2f}"
|
97 |
-
except ValueError:
|
98 |
-
return result
|
99 |
|
|
|
100 |
def handle_query(user_input: str) -> str:
|
101 |
-
|
102 |
-
Calls query_sql, captures the output, and directly returns it to the UI.
|
103 |
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
|
|
|
|
|
|
|
112 |
agent = CodeAgent(
|
113 |
tools=[sql_engine],
|
114 |
model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
|
115 |
)
|
116 |
|
|
|
117 |
with gr.Blocks() as demo:
|
118 |
-
gr.Markdown(""
|
119 |
-
## Plain Text Query Interface
|
120 |
-
|
121 |
-
This tool allows you to query a receipts database using natural language. Simply type your question into the input box, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
|
122 |
-
|
123 |
-
### How to Use:
|
124 |
-
1. Enter a question related to the receipts data in the text box.
|
125 |
-
2. The tool will convert your question into an SQL query and execute it.
|
126 |
-
3. The result will be displayed in the output box.
|
127 |
-
4. The current receipts table is also displayed for reference.
|
128 |
-
""")
|
129 |
|
130 |
with gr.Row():
|
131 |
-
|
132 |
-
|
133 |
-
query_output = gr.Textbox(label="Result")
|
134 |
-
|
135 |
-
with gr.Column(scale=2):
|
136 |
-
gr.Markdown("### Receipts Table")
|
137 |
-
receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
|
138 |
|
139 |
user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)
|
140 |
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if __name__ == "__main__":
|
144 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
+
from sqlalchemy import text, create_engine, inspect
|
4 |
from smolagents import tool, CodeAgent, HfApiModel
|
|
|
|
|
|
|
5 |
import pandas as pd
|
6 |
+
import tempfile
|
7 |
+
from database import engine
|
8 |
+
|
9 |
+
# Function to execute SQL script from uploaded file
|
10 |
+
def execute_sql_script(file_path):
|
11 |
+
try:
|
12 |
+
with engine.connect() as con:
|
13 |
+
with open(file_path, "r") as f:
|
14 |
+
sql_script = f.read()
|
15 |
+
con.execute(text(sql_script))
|
16 |
+
return "SQL file executed successfully."
|
17 |
+
except Exception as e:
|
18 |
+
return f"Error: {str(e)}"
|
19 |
|
20 |
+
# Function to fetch table names dynamically
|
21 |
+
def get_table_names():
|
22 |
+
inspector = inspect(engine)
|
23 |
+
return inspector.get_table_names()
|
24 |
|
25 |
+
# Function to fetch table schema dynamically
|
26 |
+
def get_table_schema(table_name):
|
27 |
+
inspector = inspect(engine)
|
28 |
+
columns = inspector.get_columns(table_name)
|
29 |
+
return [col["name"] for col in columns]
|
30 |
+
|
31 |
+
# Function to fetch table data dynamically
|
32 |
+
def get_table_data(table_name):
|
33 |
try:
|
34 |
with engine.connect() as con:
|
35 |
+
result = con.execute(text(f"SELECT * FROM {table_name}"))
|
36 |
rows = result.fetchall()
|
|
|
|
|
|
|
37 |
|
38 |
+
columns = get_table_schema(table_name)
|
|
|
|
|
39 |
|
40 |
+
if not rows:
|
41 |
+
return pd.DataFrame(columns=columns)
|
42 |
+
|
43 |
+
return pd.DataFrame(rows, columns=columns)
|
44 |
except Exception as e:
|
45 |
+
return pd.DataFrame({"Error": [str(e)]})
|
46 |
|
47 |
+
# SQL Execution Tool
|
48 |
@tool
|
49 |
def sql_engine(query: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
try:
|
51 |
with engine.connect() as con:
|
52 |
rows = con.execute(text(query)).fetchall()
|
|
|
53 |
if not rows:
|
54 |
return "No results found."
|
|
|
|
|
|
|
|
|
55 |
return "\n".join([", ".join(map(str, row)) for row in rows])
|
|
|
56 |
except Exception as e:
|
57 |
return f"Error: {str(e)}"
|
58 |
|
59 |
+
# Function to generate and execute SQL queries dynamically
|
60 |
def query_sql(user_query: str) -> str:
|
61 |
+
# Get schema details dynamically
|
62 |
+
tables = get_table_names()
|
63 |
+
schema_info = "Available tables and columns:\n"
|
64 |
+
|
65 |
+
for table in tables:
|
66 |
+
columns = get_table_schema(table)
|
67 |
+
schema_info += f"Table '{table}' has columns: {', '.join(columns)}.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
schema_info += "Generate a valid SQL SELECT query using ONLY these column names. DO NOT return anything other than the SQL query itself."
|
|
|
70 |
|
71 |
+
generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
|
72 |
|
73 |
+
if not isinstance(generated_sql, str) or not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
|
74 |
return "Error: Only SELECT queries are allowed."
|
75 |
|
76 |
+
return sql_engine(generated_sql)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
+
# Function to handle query input
|
79 |
def handle_query(user_input: str) -> str:
|
80 |
+
return query_sql(user_input)
|
|
|
81 |
|
82 |
+
# Function to handle SQL file uploads
|
83 |
+
def handle_file_upload(file):
|
84 |
+
temp_file_path = tempfile.mkstemp(suffix=".sql")[1]
|
85 |
+
with open(temp_file_path, "wb") as temp_file:
|
86 |
+
temp_file.write(file.read())
|
87 |
+
|
88 |
+
result = execute_sql_script(temp_file_path)
|
89 |
+
tables = get_table_names()
|
90 |
|
91 |
+
if tables:
|
92 |
+
table_data = {table: get_table_data(table) for table in tables}
|
93 |
+
else:
|
94 |
+
table_data = {}
|
95 |
|
96 |
+
return result, table_data
|
97 |
+
|
98 |
+
# Initialize CodeAgent for SQL query generation
|
99 |
agent = CodeAgent(
|
100 |
tools=[sql_engine],
|
101 |
model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
|
102 |
)
|
103 |
|
104 |
+
# Gradio UI
|
105 |
with gr.Blocks() as demo:
|
106 |
+
gr.Markdown("## SQL Query Interface")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
with gr.Row():
|
109 |
+
user_input = gr.Textbox(label="Ask a question about the data")
|
110 |
+
query_output = gr.Textbox(label="Result")
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)
|
113 |
|
114 |
+
gr.Markdown("## Upload SQL File to Execute")
|
115 |
+
file_upload = gr.File(label="Upload SQL File")
|
116 |
+
upload_output = gr.Textbox(label="Execution Result")
|
117 |
+
|
118 |
+
# Dynamic table display
|
119 |
+
table_output = gr.Dataframe(label="Database Tables (Dynamic)")
|
120 |
+
|
121 |
+
file_upload.change(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])
|
122 |
|
123 |
if __name__ == "__main__":
|
124 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|