nileshhanotia commited on
Commit
8f0ce1c
1 Parent(s): c923b38

Update sql_generator.py

Browse files
Files changed (1) hide show
  1. sql_generator.py +19 -18
sql_generator.py CHANGED
@@ -8,7 +8,7 @@ class SQLGenerator:
8
  self.model_name = "premai-io/prem-1B-SQL"
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
-
12
  def generate_query(self, natural_language_query):
13
  schema_info = """
14
  CREATE TABLE products (
@@ -34,23 +34,24 @@ class SQLGenerator:
34
  """
35
 
36
  prompt = f"""### Task: Generate a SQL query to answer the following question.
37
- ### Database Schema:
38
- {schema_info}
39
- ### Question: {natural_language_query}
40
- ### SQL Query:"""
41
-
 
42
  outputs = self.model.generate(
43
- inputs["input_ids"],
44
- max_length=256,
45
- do_sample=False,
46
- num_return_sequences=1,
47
- eos_token_id=self.tokenizer.eos_token_id,
48
- pad_token_id=self.tokenizer.pad_token_id,
49
- temperature=0.7, # Adjust temperature for more creative output
50
- top_k=50 # Consider top k predictions for variability
51
- )
52
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
53
-
54
  def fetch_shopify_data(self, endpoint):
55
  headers = {
56
  'X-Shopify-Access-Token': ACCESS_TOKEN,
@@ -58,9 +59,9 @@ class SQLGenerator:
58
  }
59
  url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
60
  response = requests.get(url, headers=headers)
61
-
62
  if response.status_code == 200:
63
  return response.json()
64
  else:
65
  print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
66
- return None
 
8
  self.model_name = "premai-io/prem-1B-SQL"
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
+
12
  def generate_query(self, natural_language_query):
13
  schema_info = """
14
  CREATE TABLE products (
 
34
  """
35
 
36
  prompt = f"""### Task: Generate a SQL query to answer the following question.
37
+ ### Database Schema:
38
+ {schema_info}
39
+ ### Question: {natural_language_query}
40
+ ### SQL Query:"""
41
+
42
+ inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
43
  outputs = self.model.generate(
44
+ inputs["input_ids"],
45
+ max_length=256,
46
+ do_sample=False,
47
+ num_return_sequences=1,
48
+ eos_token_id=self.tokenizer.eos_token_id,
49
+ pad_token_id=self.tokenizer.pad_token_id,
50
+ temperature=0.7, # Adjust temperature for more creative output
51
+ top_k=50 # Consider top k predictions for variability
52
+ )
53
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
54
+
55
  def fetch_shopify_data(self, endpoint):
56
  headers = {
57
  'X-Shopify-Access-Token': ACCESS_TOKEN,
 
59
  }
60
  url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
61
  response = requests.get(url, headers=headers)
62
+
63
  if response.status_code == 200:
64
  return response.json()
65
  else:
66
  print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
67
+ return None