File size: 4,693 Bytes
6a0ec6a 6c37e10 6a0ec6a 4f04c00 20e319d 6a0ec6a 79f396e 6c37e10 1767e22 20e319d 042246b 215368b 042246b 10e2935 042246b 7306c07 edb7e14 7306c07 edb7e14 7306c07 edb7e14 7306c07 61d9b40 35cddc5 61d9b40 35cddc5 61d9b40 2443195 61d9b40 f8c651a 61d9b40 2e81bab 61d9b40 2e81bab 54c2240 2443195 10e2935 2443195 10e2935 2443195 1f7ee11 215368b edb7e14 215368b edb7e14 215368b 1f7ee11 1767e22 6a0ec6a 0380e03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import gradio as gr
from sqlalchemy import text
from smolagents import tool, CodeAgent, HfApiModel
import spaces
import pandas as pd
# Import the persistent database
from database import engine, receipts
import pandas as pd
def get_receipts_table():
    """
    Fetch every row of the 'receipts' table as a pandas DataFrame.

    The columns are selected by name (instead of ``SELECT *``) so that the
    DataFrame labels always line up with the values, whatever the physical
    column order of the table is.

    Returns:
        A DataFrame with columns receipt_id, customer_name, price and tip.
        When the table has no rows, the frame is empty but still carries
        those headers. If the query fails, a one-column DataFrame named
        "Error" holding the exception message is returned instead, so the
        UI can display the problem rather than crash.
    """
    columns = ["receipt_id", "customer_name", "price", "tip"]
    try:
        with engine.connect() as con:
            rows = con.execute(
                text("SELECT receipt_id, customer_name, price, tip FROM receipts")
            ).fetchall()
    except Exception as e:  # best-effort: surface the failure in the table view
        return pd.DataFrame({"Error": [str(e)]})
    # An empty row list still yields a frame with the expected column headers.
    return pd.DataFrame(rows, columns=columns)
@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL query on the 'receipts' table and returns formatted results.
    Args:
        query: The SQL query to execute.
    Returns:
        Query result as a formatted string.
    """
    # NOTE: smolagents' @tool derives the tool description (shown to the LLM)
    # from the docstring above, so its wording and "Args:" layout are kept.
    # Output shape: "No results found." for zero rows, the bare value for a
    # single-row/single-column result (e.g. an aggregate), otherwise one
    # ", "-joined line per row. Any failure is reported as "Error: <message>"
    # rather than raised.
    try:
        with engine.connect() as connection:
            records = connection.execute(text(query)).fetchall()
            if len(records) == 0:
                return "No results found."
            single_cell = len(records) == 1 and len(records[0]) == 1
            if single_cell:
                return str(records[0][0])
            lines = [", ".join(str(cell) for cell in record) for record in records]
            return "\n".join(lines)
    except Exception as exc:
        return "Error: " + str(exc)
def _extract_sql(raw_output: str) -> str:
    """Strip whitespace, a surrounding Markdown code fence and trailing ';' from LLM output."""
    sql = raw_output.strip()
    if sql.startswith("```") and sql.endswith("```"):
        inner = sql[3:-3]
        # The opening fence may carry a language tag on its own line, e.g. ```sql
        tag, newline, rest = inner.partition("\n")
        if newline and tag.strip().lower() in ("", "sql", "sqlite"):
            inner = rest
        sql = inner.strip()
    # A single trailing terminator is harmless; drop it (and any trailing blanks).
    return sql.rstrip("; \t\r\n")


def _format_result(result: str) -> str:
    """Render a scalar numeric result tidily (ints as-is, floats to 2 dp); pass other text through."""
    try:
        return str(int(result))
    except ValueError:
        pass
    try:
        return f"{float(result):.2f}"
    except ValueError:
        return result


def query_sql(user_query: str) -> str:
    """
    Converts natural language input to an SQL query using CodeAgent
    and returns the execution results.

    The agent's output is cleaned (Markdown code fences and a trailing ';'
    are removed), then checked so that only a single statement starting with
    SELECT/SHOW/PRAGMA is executed through ``sql_engine``.

    Args:
        user_query: The user's request in natural language.

    Returns:
        The query result from the database as a formatted string. A scalar
        integer is shown as-is and a scalar float is rounded to 2 decimals.
        Rejected or unexpected agent output is reported as a message string.
    """
    # Provide the AI with the correct schema and strict instructions
    schema_info = (
        "The database has a table named 'receipts' with the following schema:\n"
        "- receipt_id (INTEGER, primary key)\n"
        "- customer_name (VARCHAR(16))\n"
        "- price (FLOAT)\n"
        "- tip (FLOAT)\n"
        "Generate a valid SQL SELECT query using ONLY these column names.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    # The agent may return a non-string final answer (e.g. a number it computed itself).
    if not isinstance(generated_sql, str):
        return f"Unexpected result: {generated_sql}"
    # Log the generated SQL for debugging
    print(f"Generated SQL: {generated_sql}")
    sql = _extract_sql(generated_sql)
    if not sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    # Defence in depth: a prefix check alone is bypassed by stacked statements
    # such as "SELECT 1; DELETE FROM receipts" on drivers that accept them.
    # NOTE(review): this also rejects a ';' inside a string literal.
    if ";" in sql:
        return "Error: Only a single SQL statement is allowed."
    result = sql_engine(sql)
    print(f"{result}")
    return _format_result(result)
def handle_query(user_input: str) -> str:
    """
    Gradio callback: forward the user's question to query_sql and hand
    back its text for the result box.

    Args:
        user_input: The natural-language question typed into the UI.

    Returns:
        Exactly the string produced by query_sql.
    """
    answer = query_sql(user_input)
    return answer
# CodeAgent that turns natural-language requests into SQL. The sql_engine tool
# is registered too, so the agent is able to execute queries on its own.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)
with gr.Blocks() as demo:
    gr.Markdown("## Natural Language to SQL Executor with Live Data")
    with gr.Row():
        with gr.Column(scale=1):  # Left: query interface
            user_input = gr.Textbox(label="Enter your query in plain English")
            query_output = gr.Textbox(label="Query Result")
        with gr.Column(scale=2):  # Right: live database table
            gr.Markdown("### Receipts Table (Live View)")
            # Passing the function itself (not its result) makes Gradio evaluate
            # it when the page loads and re-run it every 5 seconds. This
            # replaces the deprecated `demo.load(..., every=5)` refresh.
            receipts_table = gr.Dataframe(
                value=get_receipts_table,
                every=5,
                label="Receipts Table",
            )
    # Run the slow, LLM-backed query only when the user presses Enter.
    # `.change` fired on every keystroke and started one agent run per key.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and also request a public share link.
    demo.launch(
        share=True,
        server_port=7860,
        server_name="0.0.0.0",
    )
|