Update app.py
Browse files
app.py
CHANGED
@@ -12,9 +12,11 @@ from database import (
|
|
12 |
get_table_schema
|
13 |
)
|
14 |
|
15 |
-
def
|
16 |
"""
|
17 |
-
|
|
|
|
|
18 |
"""
|
19 |
try:
|
20 |
# Get list of tables
|
@@ -24,24 +26,30 @@ def get_data_table():
|
|
24 |
)).fetchall()
|
25 |
|
26 |
if not tables:
|
27 |
-
return
|
28 |
|
29 |
# Use the first table found
|
30 |
table_name = tables[0][0]
|
31 |
-
|
|
|
32 |
with engine.connect() as con:
|
33 |
-
|
34 |
-
rows = result.fetchall()
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
43 |
except Exception as e:
|
44 |
-
|
|
|
45 |
|
46 |
def process_sql_file(file_path):
|
47 |
"""
|
@@ -152,14 +160,22 @@ def query_sql(user_query: str) -> str:
|
|
152 |
"""
|
153 |
Converts natural language input to an SQL query using CodeAgent.
|
154 |
"""
|
155 |
-
|
156 |
-
|
|
|
|
|
157 |
return "Error: No data table exists. Please upload a file first."
|
158 |
-
|
|
|
159 |
schema_info = (
|
160 |
-
f"The database has the following
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
163 |
"DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
|
164 |
)
|
165 |
|
@@ -171,6 +187,17 @@ def query_sql(user_query: str) -> str:
|
|
171 |
if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
|
172 |
return "Error: Only SELECT queries are allowed."
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
result = sql_engine(generated_sql)
|
175 |
|
176 |
try:
|
|
|
12 |
get_table_schema
|
13 |
)
|
14 |
|
15 |
+
def get_table_info():
|
16 |
"""
|
17 |
+
Gets the current table name and column information.
|
18 |
+
Returns:
|
19 |
+
tuple: (table_name, list of column names, column info)
|
20 |
"""
|
21 |
try:
|
22 |
# Get list of tables
|
|
|
26 |
)).fetchall()
|
27 |
|
28 |
if not tables:
|
29 |
+
return None, [], {}
|
30 |
|
31 |
# Use the first table found
|
32 |
table_name = tables[0][0]
|
33 |
+
|
34 |
+
# Get column information
|
35 |
with engine.connect() as con:
|
36 |
+
columns = con.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
|
|
|
37 |
|
38 |
+
# Extract column names and types
|
39 |
+
column_names = [col[1] for col in columns]
|
40 |
+
column_info = {
|
41 |
+
col[1]: {
|
42 |
+
'type': col[2],
|
43 |
+
'is_primary': bool(col[5])
|
44 |
+
}
|
45 |
+
for col in columns
|
46 |
+
}
|
47 |
+
|
48 |
+
return table_name, column_names, column_info
|
49 |
+
|
50 |
except Exception as e:
|
51 |
+
print(f"Error getting table info: {str(e)}")
|
52 |
+
return None, [], {}
|
53 |
|
54 |
def process_sql_file(file_path):
|
55 |
"""
|
|
|
160 |
"""
|
161 |
Converts natural language input to an SQL query using CodeAgent.
|
162 |
"""
|
163 |
+
# Get current table information
|
164 |
+
table_name, column_names, column_info = get_table_info()
|
165 |
+
|
166 |
+
if not table_name:
|
167 |
return "Error: No data table exists. Please upload a file first."
|
168 |
+
|
169 |
+
# Create schema information with actual column names
|
170 |
schema_info = (
|
171 |
+
f"The database has a table named '{table_name}' with the following columns:\n"
|
172 |
+
+ "\n".join([
|
173 |
+
f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
|
174 |
+
for col, info in column_info.items()
|
175 |
+
])
|
176 |
+
+ "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
|
177 |
+
"The table name is '" + table_name + "'.\n"
|
178 |
+
"If column names contain spaces, they must be quoted.\n"
|
179 |
"DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
|
180 |
)
|
181 |
|
|
|
187 |
if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
|
188 |
return "Error: Only SELECT queries are allowed."
|
189 |
|
190 |
+
# Fix table names
|
191 |
+
for wrong_name in ['table_name', 'customers', 'main']:
|
192 |
+
if wrong_name in generated_sql:
|
193 |
+
generated_sql = generated_sql.replace(wrong_name, table_name)
|
194 |
+
|
195 |
+
# Add quotes around column names that need them
|
196 |
+
for col in column_names:
|
197 |
+
if ' ' in col: # If column name contains spaces
|
198 |
+
if col in generated_sql and f'"{col}"' not in generated_sql and f'`{col}`' not in generated_sql:
|
199 |
+
generated_sql = generated_sql.replace(col, f'"{col}"')
|
200 |
+
|
201 |
result = sql_engine(generated_sql)
|
202 |
|
203 |
try:
|