nileshhanotia commited on
Commit
419da4a
1 Parent(s): a93ae9f

Create db_query.py

Browse files
Files changed (1) hide show
  1. db_query.py +131 -0
db_query.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DistilBertTokenizer, DistilBertForSequenceClassification
4
+
5
+ # Shopify API credentials
6
+ API_KEY = 'c7c07dac8f8eaebd316ac4aa16350472'
7
+ API_SECRET = '046b2c0bfb7879db2bba86f24a281ec2'
8
+ ACCESS_TOKEN = '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'
9
+ SHOP_NAME = '6znwwf-77.myshopify.com'
10
+
11
+ # Load the SQL generation model
12
+ model_name = "premai-io/prem-1B-SQL"
13
+ tokenizer_sql = AutoTokenizer.from_pretrained(model_name)
14
+ model_sql = AutoModelForCausalLM.from_pretrained(model_name)
15
+
16
+ # Load the DistilGPT-2 model and tokenizer for processing results
17
+ distil_gpt2_model_name = "distilgpt2"
18
+ tokenizer_gpt2 = AutoTokenizer.from_pretrained(distil_gpt2_model_name)
19
+ model_gpt2 = AutoModelForCausalLM.from_pretrained(distil_gpt2_model_name)
20
+
21
+ def fetch_data_from_shopify(endpoint):
22
+ """Fetch data from Shopify API."""
23
+ headers = {
24
+ 'X-Shopify-Access-Token': ACCESS_TOKEN,
25
+ 'Content-Type': 'application/json'
26
+ }
27
+ url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
28
+ response = requests.get(url, headers=headers)
29
+
30
+ if response.status_code == 200:
31
+ return response.json()
32
+ else:
33
+ print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
34
+ return None
35
+
36
+ def generate_sql(natural_language_query):
37
+ """Generate SQL query from natural language and retrieve data from Shopify."""
38
+
39
+ # Define your schema information
40
+ schema_info = """
41
+ CREATE TABLE products (
42
+ id DECIMAL(8,2) PRIMARY KEY,
43
+ title VARCHAR(255),
44
+ body_html VARCHAR(255),
45
+ vendor VARCHAR(255),
46
+ product_type VARCHAR(255),
47
+ created_at VARCHAR(255),
48
+ handle VARCHAR(255),
49
+ updated_at DATE,
50
+ published_at VARCHAR(255),
51
+ template_suffix VARCHAR(255),
52
+ published_scope VARCHAR(255),
53
+ tags VARCHAR(255),
54
+ status VARCHAR(255),
55
+ admin_graphql_api_id DECIMAL(8,2),
56
+ variants VARCHAR(255),
57
+ options VARCHAR(255),
58
+ images VARCHAR(255),
59
+ image VARCHAR(255)
60
+ );
61
+ """
62
+
63
+ # Construct the prompt for SQL generation
64
+ prompt = f"""### Task: Generate a SQL query to answer the following question.
65
+ ### Database Schema:
66
+ {schema_info}
67
+ ### Question: {natural_language_query}
68
+ ### SQL Query:"""
69
+ # Tokenize and generate SQL query
70
+ inputs = tokenizer_sql(prompt, return_tensors="pt", add_special_tokens=False).to(model_sql.device)
71
+ outputs = model_sql.generate(
72
+ inputs["input_ids"],
73
+ max_length=256, # Keep it short to speed up generation
74
+ do_sample=False, # Use greedy search for faster results
75
+ num_return_sequences=1,
76
+ eos_token_id=tokenizer_sql.eos_token_id,
77
+ pad_token_id=tokenizer_sql.pad_token_id
78
+ )
79
+
80
+ # Decode and clean up the response
81
+ generated_query = tokenizer_sql.decode(outputs[0], skip_special_tokens=True)
82
+ sql_query = generated_query.strip() # Clean the output to get the SQL query
83
+
84
+
85
+ # Use the generated SQL query to determine the endpoint and parameters
86
+ endpoint = "products" # This can be dynamic based on your SQL parsing logic
87
+ products = fetch_data_from_shopify(endpoint)
88
+
89
+ # Prepare results based on the Shopify data
90
+ if products and 'products' in products:
91
+ results = products['products'] # Assuming you want the product data
92
+
93
+ # Process results through DistilGPT-2 model
94
+ result_text = "\n".join([f"Title: {product['title']}, Vendor: {product['vendor']}" for product in results])
95
+
96
+ # Prepare prompt for DistilGPT-2
97
+ gpt2_prompt = f"Here are the products based on your query:\n{result_text}\n\nProvide a summary of these products."
98
+
99
+ # Tokenize and generate response
100
+ inputs_gpt2 = tokenizer_gpt2(gpt2_prompt, return_tensors="pt", add_special_tokens=False).to(model_gpt2.device)
101
+ outputs_gpt2 = model_gpt2.generate(
102
+ inputs_gpt2["input_ids"],
103
+ max_length=150,
104
+ temperature=0.7,
105
+ do_sample=True,
106
+ top_p=0.95,
107
+ num_return_sequences=1,
108
+ eos_token_id=tokenizer_gpt2.eos_token_id,
109
+ pad_token_id=tokenizer_gpt2.pad_token_id
110
+ )
111
+
112
+ # Decode GPT-2 output
113
+ gpt2_response = tokenizer_gpt2.decode(outputs_gpt2[0], skip_special_tokens=True)
114
+
115
+ return f"SQL Query: {sql_query}\n\nResults:\n{gpt2_response}"
116
+ else:
117
+ return "No results found or error fetching data from Shopify."
118
+
119
+ def main():
120
+ # Gradio interface setup
121
+ iface = gr.Interface(
122
+ fn=generate_sql,
123
+ inputs="text",
124
+ outputs="text",
125
+ title="Natural Language to SQL Query Generator",
126
+ description="Enter a natural language query to generate the corresponding SQL query and display results."
127
+ )
128
+ iface.launch()
129
+
130
+ if __name__ == "__main__":
131
+ main()