nileshhanotia commited on
Commit
6126ba8
1 Parent(s): 8cf01c9

Update sql_generator.py

Browse files
Files changed (1) hide show
  1. sql_generator.py +62 -36
sql_generator.py CHANGED
@@ -1,40 +1,66 @@
1
- import sqlite3
2
- from typing import List, Dict, Any
3
- import logging
 
4
 
5
  class SQLGenerator:
6
- def __init__(self, db_path: str = "shopify.db"):
7
- self.db_path = db_path
8
- self.setup_logging()
 
9
 
10
- def setup_logging(self):
11
- logging.basicConfig(level=logging.INFO)
12
- self.logger = logging.getLogger(__name__)
13
-
14
- def execute_query(self, query: str) -> List[Dict[str, Any]]:
15
- """
16
- Execute SQL query and return results as a list of dictionaries
17
- """
18
- try:
19
- with sqlite3.connect(self.db_path) as conn:
20
- conn.row_factory = sqlite3.Row
21
- cursor = conn.cursor()
22
- cursor.execute(query)
23
- results = [dict(row) for row in cursor.fetchall()]
24
- self.logger.info(f"Successfully executed query: {query[:100]}...")
25
- return results
26
- except sqlite3.Error as e:
27
- self.logger.error(f"Database error: {e}")
28
- raise
29
- except Exception as e:
30
- self.logger.error(f"Error executing query: {e}")
31
- raise
32
-
33
- def validate_query(self, query: str) -> bool:
34
- """
35
- Validate SQL query before execution
36
  """
37
- # Basic validation - you might want to add more sophisticated validation
38
- dangerous_keywords = ["DROP", "DELETE", "TRUNCATE", "UPDATE", "INSERT"]
39
- return not any(keyword in query.upper() for keyword in dangerous_keywords)
40
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import requests
4
+ from config import ACCESS_TOKEN, SHOP_NAME
5
 
6
  class SQLGenerator:
7
+ def __init__(self):
8
+ self.model_name = "premai-io/prem-1B-SQL"
9
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
 
12
+ def generate_query(self, natural_language_query):
13
+ schema_info = """
14
+ CREATE TABLE products (
15
+ id DECIMAL(8,2) PRIMARY KEY,
16
+ title VARCHAR(255),
17
+ body_html VARCHAR(255),
18
+ vendor VARCHAR(255),
19
+ product_type VARCHAR(255),
20
+ created_at VARCHAR(255),
21
+ handle VARCHAR(255),
22
+ updated_at DATE,
23
+ published_at VARCHAR(255),
24
+ template_suffix VARCHAR(255),
25
+ published_scope VARCHAR(255),
26
+ tags VARCHAR(255),
27
+ status VARCHAR(255),
28
+ admin_graphql_api_id DECIMAL(8,2),
29
+ variants VARCHAR(255),
30
+ options VARCHAR(255),
31
+ images VARCHAR(255),
32
+ image VARCHAR(255)
33
+ );
 
 
 
 
34
  """
35
+
36
+ prompt = f"""### Task: Generate a SQL query to answer the following question.
37
+ ### Database Schema:
38
+ {schema_info}
39
+ ### Question: {natural_language_query}
40
+ ### SQL Query:"""
41
+
42
+ inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
43
+ outputs = self.model.generate(
44
+ inputs["input_ids"],
45
+ max_length=256,
46
+ do_sample=False,
47
+ num_return_sequences=1,
48
+ eos_token_id=self.tokenizer.eos_token_id,
49
+ pad_token_id=self.tokenizer.pad_token_id
50
+ )
51
+
52
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
53
+
54
+ def fetch_shopify_data(self, endpoint):
55
+ headers = {
56
+ 'X-Shopify-Access-Token': ACCESS_TOKEN,
57
+ 'Content-Type': 'application/json'
58
+ }
59
+ url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
60
+ response = requests.get(url, headers=headers)
61
+
62
+ if response.status_code == 200:
63
+ return response.json()
64
+ else:
65
+ print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
66
+ return None