Spaces:
Sleeping
Sleeping
nileshhanotia
committed on
Commit
•
419da4a
1
Parent(s):
a93ae9f
Create db_query.py
Browse files- db_query.py +131 -0
db_query.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer, DistilBertTokenizer, DistilBertForSequenceClassification
|
4 |
+
|
5 |
+
# Shopify API credentials.
# SECURITY: these literal values were committed to a public repository and
# must be considered compromised — rotate them in the Shopify admin.
# They are retained only as fallbacks for backward compatibility; set the
# SHOPIFY_* environment variables instead of relying on them.
API_KEY = os.environ.get('SHOPIFY_API_KEY', 'c7c07dac8f8eaebd316ac4aa16350472')
API_SECRET = os.environ.get('SHOPIFY_API_SECRET', '046b2c0bfb7879db2bba86f24a281ec2')
ACCESS_TOKEN = os.environ.get('SHOPIFY_ACCESS_TOKEN', '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')
SHOP_NAME = os.environ.get('SHOPIFY_SHOP_NAME', '6znwwf-77.myshopify.com')
|
10 |
+
|
11 |
+
# Load the text-to-SQL causal language model.
# NOTE(review): both from_pretrained calls download weights from the
# Hugging Face Hub on first run — module import is slow and needs network.
model_name = "premai-io/prem-1B-SQL"
tokenizer_sql = AutoTokenizer.from_pretrained(model_name)
model_sql = AutoModelForCausalLM.from_pretrained(model_name)

# Load the DistilGPT-2 model and tokenizer used to turn raw Shopify
# results into a natural-language summary.
distil_gpt2_model_name = "distilgpt2"
tokenizer_gpt2 = AutoTokenizer.from_pretrained(distil_gpt2_model_name)
model_gpt2 = AutoModelForCausalLM.from_pretrained(distil_gpt2_model_name)
|
20 |
+
|
21 |
+
def fetch_data_from_shopify(endpoint):
    """Fetch a resource collection from the Shopify Admin REST API.

    Args:
        endpoint: Resource name without extension, e.g. "products".

    Returns:
        The decoded JSON payload (dict) on HTTP 200, or None on any HTTP
        error status or network failure (errors are printed, not raised,
        so the Gradio handler degrades gracefully).
    """
    headers = {
        'X-Shopify-Access-Token': ACCESS_TOKEN,
        'Content-Type': 'application/json'
    }
    url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
    try:
        # timeout prevents the request from hanging the worker forever on a
        # stalled connection (the original call had no timeout at all).
        response = requests.get(url, headers=headers, timeout=30)
    except requests.RequestException as exc:
        # Connection errors / timeouts previously propagated uncaught.
        print(f"Error fetching {endpoint}: {exc}")
        return None

    if response.status_code == 200:
        return response.json()
    print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
    return None
|
35 |
+
|
36 |
+
def generate_sql(natural_language_query):
    """Generate a SQL query from natural language and summarize Shopify data.

    Pipeline: build a schema-grounded prompt, generate SQL with the
    prem-1B-SQL model, fetch products from Shopify (endpoint is currently
    hard-coded — the generated SQL is displayed but not executed), then
    summarize the product list with DistilGPT-2.

    Args:
        natural_language_query: Free-form question from the UI textbox.

    Returns:
        A string containing the generated SQL and the model summary, or an
        error message when Shopify returned nothing.
    """
    # Schema given to the SQL model so it targets real column names.
    schema_info = """
    CREATE TABLE products (
        id DECIMAL(8,2) PRIMARY KEY,
        title VARCHAR(255),
        body_html VARCHAR(255),
        vendor VARCHAR(255),
        product_type VARCHAR(255),
        created_at VARCHAR(255),
        handle VARCHAR(255),
        updated_at DATE,
        published_at VARCHAR(255),
        template_suffix VARCHAR(255),
        published_scope VARCHAR(255),
        tags VARCHAR(255),
        status VARCHAR(255),
        admin_graphql_api_id DECIMAL(8,2),
        variants VARCHAR(255),
        options VARCHAR(255),
        images VARCHAR(255),
        image VARCHAR(255)
    );
    """

    # Construct the prompt for SQL generation.
    prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{schema_info}
### Question: {natural_language_query}
### SQL Query:"""

    # Tokenize and generate the SQL query.
    inputs = tokenizer_sql(prompt, return_tensors="pt", add_special_tokens=False).to(model_sql.device)
    outputs = model_sql.generate(
        inputs["input_ids"],
        # max_new_tokens (not max_length): the schema prompt alone can
        # exceed a 256-token total budget, which would leave no room for
        # any generated SQL at all.
        max_new_tokens=256,
        do_sample=False,  # Greedy search for fast, deterministic output
        num_return_sequences=1,
        eos_token_id=tokenizer_sql.eos_token_id,
        # Fall back to EOS when the tokenizer defines no pad token.
        pad_token_id=tokenizer_sql.pad_token_id
        if tokenizer_sql.pad_token_id is not None
        else tokenizer_sql.eos_token_id,
    )

    # Decode only the newly generated tokens — decoding outputs[0] whole
    # would echo the entire schema prompt back as part of the "SQL query".
    prompt_len = inputs["input_ids"].shape[1]
    sql_query = tokenizer_sql.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # The generated SQL is not executed yet: the endpoint is fixed.
    # TODO: parse sql_query to choose endpoint/parameters dynamically.
    endpoint = "products"
    products = fetch_data_from_shopify(endpoint)

    if products and 'products' in products:
        results = products['products']

        # .get with defaults: the Shopify API may omit fields on some
        # products, and a bare KeyError would take down the whole request.
        result_text = "\n".join(
            f"Title: {product.get('title', 'N/A')}, Vendor: {product.get('vendor', 'N/A')}"
            for product in results
        )

        # Prepare prompt for DistilGPT-2.
        gpt2_prompt = f"Here are the products based on your query:\n{result_text}\n\nProvide a summary of these products."

        inputs_gpt2 = tokenizer_gpt2(gpt2_prompt, return_tensors="pt", add_special_tokens=False).to(model_gpt2.device)
        outputs_gpt2 = model_gpt2.generate(
            inputs_gpt2["input_ids"],
            max_new_tokens=150,  # max_length=150 could be smaller than the prompt itself
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1,
            eos_token_id=tokenizer_gpt2.eos_token_id,
            # distilgpt2 has no pad token, so pad_token_id was None before;
            # EOS is the standard substitute for GPT-2-family models.
            pad_token_id=tokenizer_gpt2.pad_token_id
            if tokenizer_gpt2.pad_token_id is not None
            else tokenizer_gpt2.eos_token_id,
        )

        # Again, decode only the continuation, not the echoed prompt.
        gpt2_prompt_len = inputs_gpt2["input_ids"].shape[1]
        gpt2_response = tokenizer_gpt2.decode(outputs_gpt2[0][gpt2_prompt_len:], skip_special_tokens=True)

        return f"SQL Query: {sql_query}\n\nResults:\n{gpt2_response}"
    else:
        return "No results found or error fetching data from Shopify."
|
118 |
+
|
119 |
+
def main():
    """Wire generate_sql into a text-in/text-out Gradio interface and serve it."""
    ui_config = {
        "fn": generate_sql,
        "inputs": "text",
        "outputs": "text",
        "title": "Natural Language to SQL Query Generator",
        "description": "Enter a natural language query to generate the corresponding SQL query and display results.",
    }
    gr.Interface(**ui_config).launch()
|
129 |
+
|
130 |
+
# Script entry point: launch the Gradio app only when run directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|