Update app.py
Browse files
app.py
CHANGED
@@ -4,10 +4,14 @@ from sqlalchemy import text, create_engine, inspect
|
|
4 |
from smolagents import tool, CodeAgent, HfApiModel
|
5 |
import pandas as pd
|
6 |
import tempfile
|
7 |
-
from database import engine
|
|
|
|
|
|
|
8 |
|
9 |
# Function to execute SQL script from uploaded file
|
10 |
def execute_sql_script(file_path):
|
|
|
11 |
try:
|
12 |
with engine.connect() as con:
|
13 |
with open(file_path, "r") as f:
|
@@ -19,17 +23,20 @@ def execute_sql_script(file_path):
|
|
19 |
|
20 |
# Function to fetch table names dynamically
|
21 |
def get_table_names():
|
|
|
22 |
inspector = inspect(engine)
|
23 |
return inspector.get_table_names()
|
24 |
|
25 |
# Function to fetch table schema dynamically
|
26 |
def get_table_schema(table_name):
|
|
|
27 |
inspector = inspect(engine)
|
28 |
columns = inspector.get_columns(table_name)
|
29 |
return [col["name"] for col in columns]
|
30 |
|
31 |
# Function to fetch table data dynamically
|
32 |
def get_table_data(table_name):
|
|
|
33 |
try:
|
34 |
with engine.connect() as con:
|
35 |
result = con.execute(text(f"SELECT * FROM {table_name}"))
|
@@ -47,6 +54,7 @@ def get_table_data(table_name):
|
|
47 |
# SQL Execution Tool
|
48 |
@tool
|
49 |
def sql_engine(query: str) -> str:
|
|
|
50 |
try:
|
51 |
with engine.connect() as con:
|
52 |
rows = con.execute(text(query)).fetchall()
|
@@ -58,8 +66,11 @@ def sql_engine(query: str) -> str:
|
|
58 |
|
59 |
# Function to generate and execute SQL queries dynamically
|
60 |
def query_sql(user_query: str) -> str:
|
61 |
-
|
62 |
tables = get_table_names()
|
|
|
|
|
|
|
63 |
schema_info = "Available tables and columns:\n"
|
64 |
|
65 |
for table in tables:
|
@@ -77,10 +88,12 @@ def query_sql(user_query: str) -> str:
|
|
77 |
|
78 |
# Function to handle query input
|
79 |
def handle_query(user_input: str) -> str:
|
|
|
80 |
return query_sql(user_input)
|
81 |
|
82 |
# Function to handle SQL file uploads
|
83 |
def handle_file_upload(file):
|
|
|
84 |
temp_file_path = tempfile.mkstemp(suffix=".sql")[1]
|
85 |
with open(temp_file_path, "wb") as temp_file:
|
86 |
temp_file.write(file.read())
|
@@ -91,7 +104,7 @@ def handle_file_upload(file):
|
|
91 |
if tables:
|
92 |
table_data = {table: get_table_data(table) for table in tables}
|
93 |
else:
|
94 |
-
table_data = {}
|
95 |
|
96 |
return result, table_data
|
97 |
|
@@ -122,3 +135,4 @@ with gr.Blocks() as demo:
|
|
122 |
|
123 |
if __name__ == "__main__":
|
124 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
|
4 |
from smolagents import tool, CodeAgent, HfApiModel
|
5 |
import pandas as pd
|
6 |
import tempfile
|
7 |
+
from database import engine, initialize_database
|
8 |
+
|
9 |
+
# Ensure the database initializes
|
10 |
+
initialize_database()
|
11 |
|
12 |
# Function to execute SQL script from uploaded file
|
13 |
def execute_sql_script(file_path):
|
14 |
+
"""Executes an uploaded SQL file to initialize the database."""
|
15 |
try:
|
16 |
with engine.connect() as con:
|
17 |
with open(file_path, "r") as f:
|
|
|
23 |
|
24 |
# Function to fetch table names dynamically
|
25 |
def get_table_names():
|
26 |
+
"""Returns a list of all tables in the database."""
|
27 |
inspector = inspect(engine)
|
28 |
return inspector.get_table_names()
|
29 |
|
30 |
# Function to fetch table schema dynamically
|
31 |
def get_table_schema(table_name):
|
32 |
+
"""Returns a list of column names for a given table."""
|
33 |
inspector = inspect(engine)
|
34 |
columns = inspector.get_columns(table_name)
|
35 |
return [col["name"] for col in columns]
|
36 |
|
37 |
# Function to fetch table data dynamically
|
38 |
def get_table_data(table_name):
|
39 |
+
"""Retrieves all rows from the specified table as a Pandas DataFrame."""
|
40 |
try:
|
41 |
with engine.connect() as con:
|
42 |
result = con.execute(text(f"SELECT * FROM {table_name}"))
|
|
|
54 |
# SQL Execution Tool
|
55 |
@tool
|
56 |
def sql_engine(query: str) -> str:
|
57 |
+
"""Executes an SQL SELECT query and returns formatted results."""
|
58 |
try:
|
59 |
with engine.connect() as con:
|
60 |
rows = con.execute(text(query)).fetchall()
|
|
|
66 |
|
67 |
# Function to generate and execute SQL queries dynamically
|
68 |
def query_sql(user_query: str) -> str:
|
69 |
+
"""Processes a user’s natural language query and generates an SQL query dynamically."""
|
70 |
tables = get_table_names()
|
71 |
+
if not tables:
|
72 |
+
return "Error: No tables found. Please upload an SQL file first."
|
73 |
+
|
74 |
schema_info = "Available tables and columns:\n"
|
75 |
|
76 |
for table in tables:
|
|
|
88 |
|
89 |
# Function to handle query input
|
90 |
def handle_query(user_input: str) -> str:
|
91 |
+
"""Handles user input and returns the SQL query result."""
|
92 |
return query_sql(user_input)
|
93 |
|
94 |
# Function to handle SQL file uploads
|
95 |
def handle_file_upload(file):
|
96 |
+
"""Handles file upload, executes SQL, and updates database schema dynamically."""
|
97 |
temp_file_path = tempfile.mkstemp(suffix=".sql")[1]
|
98 |
with open(temp_file_path, "wb") as temp_file:
|
99 |
temp_file.write(file.read())
|
|
|
104 |
if tables:
|
105 |
table_data = {table: get_table_data(table) for table in tables}
|
106 |
else:
|
107 |
+
table_data = {"Error": ["No tables found after execution. Ensure your SQL file creates tables."]}
|
108 |
|
109 |
return result, table_data
|
110 |
|
|
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
138 |
+
|