Quazim0t0 commited on
Commit
5a55ea7
ยท
verified ยท
1 Parent(s): 18bb121

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -95
app.py CHANGED
@@ -1,144 +1,124 @@
1
  import os
2
  import gradio as gr
3
- from sqlalchemy import text
4
  from smolagents import tool, CodeAgent, HfApiModel
5
- import spaces
6
- import pandas as pd
7
- from database import engine, receipts
8
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def get_receipts_table():
11
- """
12
- Fetches all data from the 'receipts' table and returns it as a Pandas DataFrame.
 
13
 
14
- Returns:
15
- A Pandas DataFrame containing all receipt data.
16
- """
 
 
 
 
 
17
  try:
18
  with engine.connect() as con:
19
- result = con.execute(text("SELECT * FROM receipts"))
20
  rows = result.fetchall()
21
-
22
- if not rows:
23
- return pd.DataFrame(columns=["receipt_id", "customer_name", "price", "tip"])
24
 
25
- # Convert rows into a DataFrame
26
- df = pd.DataFrame(rows, columns=["receipt_id", "customer_name", "price", "tip"])
27
- return df
28
 
 
 
 
 
29
  except Exception as e:
30
- return pd.DataFrame({"Error": [str(e)]}) # Return error message in DataFrame format
31
 
 
32
  @tool
33
  def sql_engine(query: str) -> str:
34
- """
35
- Executes an SQL query on the 'receipts' table and returns formatted results.
36
-
37
- Args:
38
- query: The SQL query to execute.
39
-
40
- Returns:
41
- Query result as a formatted string.
42
- """
43
  try:
44
  with engine.connect() as con:
45
  rows = con.execute(text(query)).fetchall()
46
-
47
  if not rows:
48
  return "No results found."
49
-
50
- if len(rows) == 1 and len(rows[0]) == 1:
51
- return str(rows[0][0]) # Convert numerical result to string
52
-
53
  return "\n".join([", ".join(map(str, row)) for row in rows])
54
-
55
  except Exception as e:
56
  return f"Error: {str(e)}"
57
 
 
58
  def query_sql(user_query: str) -> str:
59
- """
60
- Converts natural language input to an SQL query using CodeAgent
61
- and returns the execution results.
62
-
63
- Args:
64
- user_query: The user's request in natural language.
65
-
66
- Returns:
67
- The query result from the database as a formatted string.
68
- """
69
-
70
- schema_info = (
71
- "The database has a table named 'receipts' with the following schema:\n"
72
- "- receipt_id (INTEGER, primary key)\n"
73
- "- customer_name (VARCHAR(16))\n"
74
- "- price (FLOAT)\n"
75
- "- tip (FLOAT)\n"
76
- "Generate a valid SQL SELECT query using ONLY these column names.\n"
77
- "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
78
- )
79
-
80
- generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
81
 
82
- if not isinstance(generated_sql, str):
83
- return f"{generated_sql}" # Handle unexpected numerical result
84
 
85
- print(f"{generated_sql}")
86
 
87
- if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
88
  return "Error: Only SELECT queries are allowed."
89
 
90
- result = sql_engine(generated_sql)
91
-
92
- print(f"{result}")
93
-
94
- try:
95
- float_result = float(result)
96
- return f"{float_result:.2f}"
97
- except ValueError:
98
- return result
99
 
 
100
  def handle_query(user_input: str) -> str:
101
- """
102
- Calls query_sql, captures the output, and directly returns it to the UI.
103
 
104
- Args:
105
- user_input: The user's natural language question.
 
 
 
 
 
 
106
 
107
- Returns:
108
- The SQL query result as a plain string to be displayed in the UI.
109
- """
110
- return query_sql(user_input)
111
 
 
 
 
112
  agent = CodeAgent(
113
  tools=[sql_engine],
114
  model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
115
  )
116
 
 
117
  with gr.Blocks() as demo:
118
- gr.Markdown("""
119
- ## Plain Text Query Interface
120
-
121
- This tool allows you to query a receipts database using natural language. Simply type your question into the input box, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
122
-
123
- ### How to Use:
124
- 1. Enter a question related to the receipts data in the text box.
125
- 2. The tool will convert your question into an SQL query and execute it.
126
- 3. The result will be displayed in the output box.
127
- 4. The current receipts table is also displayed for reference.
128
- """)
129
 
130
  with gr.Row():
131
- with gr.Column(scale=1):
132
- user_input = gr.Textbox(label="Ask a question about the data")
133
- query_output = gr.Textbox(label="Result")
134
-
135
- with gr.Column(scale=2):
136
- gr.Markdown("### Receipts Table")
137
- receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
138
 
139
  user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)
140
 
141
- demo.load(fn=get_receipts_table, outputs=receipts_table)
 
 
 
 
 
 
 
142
 
143
  if __name__ == "__main__":
144
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
  import os
2
  import gradio as gr
3
+ from sqlalchemy import text, create_engine, inspect
4
  from smolagents import tool, CodeAgent, HfApiModel
 
 
 
5
  import pandas as pd
6
+ import tempfile
7
+ from database import engine
8
+
9
+ # Function to execute SQL script from uploaded file
10
+ def execute_sql_script(file_path):
11
+ try:
12
+ with engine.connect() as con:
13
+ with open(file_path, "r") as f:
14
+ sql_script = f.read()
15
+ con.execute(text(sql_script))
16
+ return "SQL file executed successfully."
17
+ except Exception as e:
18
+ return f"Error: {str(e)}"
19
 
20
+ # Function to fetch table names dynamically
21
+ def get_table_names():
22
+ inspector = inspect(engine)
23
+ return inspector.get_table_names()
24
 
25
+ # Function to fetch table schema dynamically
26
+ def get_table_schema(table_name):
27
+ inspector = inspect(engine)
28
+ columns = inspector.get_columns(table_name)
29
+ return [col["name"] for col in columns]
30
+
31
+ # Function to fetch table data dynamically
32
+ def get_table_data(table_name):
33
  try:
34
  with engine.connect() as con:
35
+ result = con.execute(text(f"SELECT * FROM {table_name}"))
36
  rows = result.fetchall()
 
 
 
37
 
38
+ columns = get_table_schema(table_name)
 
 
39
 
40
+ if not rows:
41
+ return pd.DataFrame(columns=columns)
42
+
43
+ return pd.DataFrame(rows, columns=columns)
44
  except Exception as e:
45
+ return pd.DataFrame({"Error": [str(e)]})
46
 
47
+ # SQL Execution Tool
48
  @tool
49
  def sql_engine(query: str) -> str:
 
 
 
 
 
 
 
 
 
50
  try:
51
  with engine.connect() as con:
52
  rows = con.execute(text(query)).fetchall()
 
53
  if not rows:
54
  return "No results found."
 
 
 
 
55
  return "\n".join([", ".join(map(str, row)) for row in rows])
 
56
  except Exception as e:
57
  return f"Error: {str(e)}"
58
 
59
+ # Function to generate and execute SQL queries dynamically
60
  def query_sql(user_query: str) -> str:
61
+ # Get schema details dynamically
62
+ tables = get_table_names()
63
+ schema_info = "Available tables and columns:\n"
64
+
65
+ for table in tables:
66
+ columns = get_table_schema(table)
67
+ schema_info += f"Table '{table}' has columns: {', '.join(columns)}.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ schema_info += "Generate a valid SQL SELECT query using ONLY these column names. DO NOT return anything other than the SQL query itself."
 
70
 
71
+ generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
72
 
73
+ if not isinstance(generated_sql, str) or not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
74
  return "Error: Only SELECT queries are allowed."
75
 
76
+ return sql_engine(generated_sql)
 
 
 
 
 
 
 
 
77
 
78
+ # Function to handle query input
79
  def handle_query(user_input: str) -> str:
80
+ return query_sql(user_input)
 
81
 
82
+ # Function to handle SQL file uploads
83
+ def handle_file_upload(file):
84
+ temp_file_path = tempfile.mkstemp(suffix=".sql")[1]
85
+ with open(temp_file_path, "wb") as temp_file:
86
+ temp_file.write(file.read())
87
+
88
+ result = execute_sql_script(temp_file_path)
89
+ tables = get_table_names()
90
 
91
+ if tables:
92
+ table_data = {table: get_table_data(table) for table in tables}
93
+ else:
94
+ table_data = {}
95
 
96
+ return result, table_data
97
+
98
+ # Initialize CodeAgent for SQL query generation
99
  agent = CodeAgent(
100
  tools=[sql_engine],
101
  model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
102
  )
103
 
104
+ # Gradio UI
105
  with gr.Blocks() as demo:
106
+ gr.Markdown("## SQL Query Interface")
 
 
 
 
 
 
 
 
 
 
107
 
108
  with gr.Row():
109
+ user_input = gr.Textbox(label="Ask a question about the data")
110
+ query_output = gr.Textbox(label="Result")
 
 
 
 
 
111
 
112
  user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)
113
 
114
+ gr.Markdown("## Upload SQL File to Execute")
115
+ file_upload = gr.File(label="Upload SQL File")
116
+ upload_output = gr.Textbox(label="Execution Result")
117
+
118
+ # Dynamic table display
119
+ table_output = gr.Dataframe(label="Database Tables (Dynamic)")
120
+
121
+ file_upload.change(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])
122
 
123
  if __name__ == "__main__":
124
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)