Quazim0t0 commited on
Commit
f776bb6
ยท
verified ยท
1 Parent(s): 6ed45c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -20
app.py CHANGED
@@ -12,9 +12,11 @@ from database import (
12
  get_table_schema
13
  )
14
 
15
- def get_data_table():
16
  """
17
- Fetches all data from the current table and returns it as a Pandas DataFrame.
 
 
18
  """
19
  try:
20
  # Get list of tables
@@ -24,24 +26,30 @@ def get_data_table():
24
  )).fetchall()
25
 
26
  if not tables:
27
- return pd.DataFrame()
28
 
29
  # Use the first table found
30
  table_name = tables[0][0]
31
-
 
32
  with engine.connect() as con:
33
- result = con.execute(text(f"SELECT * FROM {table_name}"))
34
- rows = result.fetchall()
35
 
36
- if not rows:
37
- return pd.DataFrame()
38
-
39
- columns = result.keys()
40
- df = pd.DataFrame(rows, columns=columns)
41
- return df
42
-
 
 
 
 
 
43
  except Exception as e:
44
- return pd.DataFrame({"Error": [str(e)]})
 
45
 
46
  def process_sql_file(file_path):
47
  """
@@ -152,14 +160,22 @@ def query_sql(user_query: str) -> str:
152
  """
153
  Converts natural language input to an SQL query using CodeAgent.
154
  """
155
- schema = get_table_schema()
156
- if not schema:
 
 
157
  return "Error: No data table exists. Please upload a file first."
158
-
 
159
  schema_info = (
160
- f"The database has the following schema:\n"
161
- f"{schema}\n"
162
- "Generate a valid SQL SELECT query using ONLY these column names.\n"
 
 
 
 
 
163
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
164
  )
165
 
@@ -171,6 +187,17 @@ def query_sql(user_query: str) -> str:
171
  if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
172
  return "Error: Only SELECT queries are allowed."
173
 
 
 
 
 
 
 
 
 
 
 
 
174
  result = sql_engine(generated_sql)
175
 
176
  try:
 
12
  get_table_schema
13
  )
14
 
15
+ def get_table_info():
16
  """
17
+ Gets the current table name and column information.
18
+ Returns:
19
+ tuple: (table_name, list of column names, column info)
20
  """
21
  try:
22
  # Get list of tables
 
26
  )).fetchall()
27
 
28
  if not tables:
29
+ return None, [], {}
30
 
31
  # Use the first table found
32
  table_name = tables[0][0]
33
+
34
+ # Get column information
35
  with engine.connect() as con:
36
+ columns = con.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
 
37
 
38
+ # Extract column names and types
39
+ column_names = [col[1] for col in columns]
40
+ column_info = {
41
+ col[1]: {
42
+ 'type': col[2],
43
+ 'is_primary': bool(col[5])
44
+ }
45
+ for col in columns
46
+ }
47
+
48
+ return table_name, column_names, column_info
49
+
50
  except Exception as e:
51
+ print(f"Error getting table info: {str(e)}")
52
+ return None, [], {}
53
 
54
  def process_sql_file(file_path):
55
  """
 
160
  """
161
  Converts natural language input to an SQL query using CodeAgent.
162
  """
163
+ # Get current table information
164
+ table_name, column_names, column_info = get_table_info()
165
+
166
+ if not table_name:
167
  return "Error: No data table exists. Please upload a file first."
168
+
169
+ # Create schema information with actual column names
170
  schema_info = (
171
+ f"The database has a table named '{table_name}' with the following columns:\n"
172
+ + "\n".join([
173
+ f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
174
+ for col, info in column_info.items()
175
+ ])
176
+ + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
177
+ "The table name is '" + table_name + "'.\n"
178
+ "If column names contain spaces, they must be quoted.\n"
179
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
180
  )
181
 
 
187
  if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
188
  return "Error: Only SELECT queries are allowed."
189
 
190
+ # Fix table names
191
+ for wrong_name in ['table_name', 'customers', 'main']:
192
+ if wrong_name in generated_sql:
193
+ generated_sql = generated_sql.replace(wrong_name, table_name)
194
+
195
+ # Add quotes around column names that need them
196
+ for col in column_names:
197
+ if ' ' in col: # If column name contains spaces
198
+ if col in generated_sql and f'"{col}"' not in generated_sql and f'`{col}`' not in generated_sql:
199
+ generated_sql = generated_sql.replace(col, f'"{col}"')
200
+
201
  result = sql_engine(generated_sql)
202
 
203
  try: