File size: 3,092 Bytes
6126ba8
 
 
 
18a1e7e
 
6126ba8
 
 
 
42de116
6126ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56abc73
6126ba8
 
8f0ce1c
 
 
 
 
 
6126ba8
8f0ce1c
 
42de116
8f0ce1c
 
 
42de116
8f0ce1c
 
42de116
 
 
 
6126ba8
 
 
 
 
 
 
42de116
6126ba8
 
 
 
8f0ce1c
42de116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import requests
from config import ACCESS_TOKEN, SHOP_NAME

class SQLGenerator:
    """Turn natural-language questions into SQL with a causal language model.

    The tokenizer and model named by ``model_name`` are loaded once in
    ``__init__``. The class also carries a small helper for fetching JSON
    resources from the Shopify Admin REST API.
    """

    # Budget for *generated* tokens only; the prompt is not counted against it.
    MAX_NEW_TOKENS = 256
    # Seconds to wait for Shopify before giving up instead of hanging forever.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self):
        """Load the tokenizer and model weights for ``premai-io/prem-1B-SQL``."""
        self.model_name = "premai-io/prem-1B-SQL"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    def generate_query(self, natural_language_query: str) -> str:
        """Generate a SQL query answering ``natural_language_query``.

        The question is embedded in a prompt together with the ``products``
        table schema, the model samples a continuation, and only that
        continuation (not the echoed prompt) is returned.

        Args:
            natural_language_query: The question to translate into SQL.

        Returns:
            The generated text following the prompt, stripped of surrounding
            whitespace.
        """
        # NOTE(review): `id` and `admin_graphql_api_id` are declared as
        # DECIMAL(8,2) below; confirm these match the real table, since the
        # model is steered by whatever schema it is shown.
        schema_info = """
        CREATE TABLE products (
          id DECIMAL(8,2) PRIMARY KEY,
          title VARCHAR(255),
          body_html VARCHAR(255),
          vendor VARCHAR(255),
          product_type VARCHAR(255),
          created_at VARCHAR(255),
          handle VARCHAR(255),
          updated_at DATE,
          published_at VARCHAR(255),
          template_suffix VARCHAR(255),
          published_scope VARCHAR(255),
          tags VARCHAR(255),
          status VARCHAR(255),
          admin_graphql_api_id DECIMAL(8,2),
          variants VARCHAR(255),
          options VARCHAR(255),
          images VARCHAR(255),
          image VARCHAR(255)
        );
        """

        prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{schema_info}
### Question: {natural_language_query}
### SQL Query:"""

        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        # Remember how long the prompt is so it can be cut off the output.
        prompt_length = len(inputs["input_ids"][0])

        # Some causal-LM tokenizers define no pad token; reuse EOS in that case
        # so generate() does not receive pad_token_id=None.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        outputs = self.model.generate(
            **inputs,  # forwards input_ids together with the attention mask
            # max_length would include the (long) prompt; max_new_tokens
            # limits only what the model adds.
            max_new_tokens=self.MAX_NEW_TOKENS,
            do_sample=True,  # Enable sampling to use temperature
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            temperature=0.7,  # Allow temperature to affect output
            top_k=50,  # Consider top k predictions for variability
        )

        # The returned sequence starts with the prompt tokens; keep only the
        # continuation so callers get the SQL and not the prompt text.
        completion_ids = outputs[0][prompt_length:]
        generated_query = self.tokenizer.decode(
            completion_ids, skip_special_tokens=True
        ).strip()
        return generated_query

    def fetch_shopify_data(self, endpoint: str):
        """Fetch ``{endpoint}.json`` from the shop's Admin API (version 2023-10).

        Args:
            endpoint: Resource path without the ``.json`` suffix, e.g.
                ``"products"``.

        Returns:
            The decoded JSON body when the response status is 200; otherwise
            ``None`` after printing the status code and response text.

        Raises:
            requests.RequestException: Propagated from ``requests.get`` on
                connection failures or when the timeout elapses.
        """
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json'
        }
        url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
        # Without a timeout, requests waits indefinitely on a stalled server.
        response = requests.get(
            url, headers=headers, timeout=self.REQUEST_TIMEOUT_SECONDS
        )

        if response.status_code == 200:
            return response.json()

        print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
        return None

# Demo entry point: build one generator, ask it a question, then pull the
# product listing from Shopify and print both results.
if __name__ == "__main__":
    generator = SQLGenerator()

    # Translate a sample question into SQL and show it.
    question = "Show me shirts with red color"
    print(f"Generated SQL Query: {generator.generate_query(question)}")

    # Retrieve the "products" resource from Shopify and show it.
    products_payload = generator.fetch_shopify_data("products")
    print(f"Shopify Data: {products_payload}")