import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import requests
from config import ACCESS_TOKEN, SHOP_NAME

class SQLGenerator:
    """Generates SQL from natural-language questions with the prem-1B-SQL model
    and fetches raw data from the Shopify Admin REST API."""

    def __init__(self):
        self.model_name = "premai-io/prem-1B-SQL"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    def generate_query(self, natural_language_query):
        schema_info = """
        CREATE TABLE products (
          id DECIMAL(8,2) PRIMARY KEY,
          title VARCHAR(255),
          body_html VARCHAR(255),
          vendor VARCHAR(255),
          product_type VARCHAR(255),
          created_at VARCHAR(255),
          handle VARCHAR(255),
          updated_at DATE,
          published_at VARCHAR(255),
          template_suffix VARCHAR(255),
          published_scope VARCHAR(255),
          tags VARCHAR(255),
          status VARCHAR(255),
          admin_graphql_api_id DECIMAL(8,2),
          variants VARCHAR(255),
          options VARCHAR(255),
          images VARCHAR(255),
          image VARCHAR(255)
        );
        """
        
        prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{schema_info}
### Question: {natural_language_query}
### SQL Query:"""

        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
        outputs = self.model.generate(
            inputs["input_ids"],
            max_new_tokens=256,   # cap new tokens instead of total length so the long schema prompt is not truncated
            do_sample=False,      # greedy decoding for deterministic SQL; sampling knobs like temperature/top_k would be ignored here
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Decode only the newly generated tokens so the prompt is not echoed back in the result
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def fetch_shopify_data(self, endpoint):
        # Call the Shopify Admin REST API for the given endpoint (e.g. "products")
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json'
        }
        url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
        response = requests.get(url, headers=headers, timeout=30)

        if response.status_code == 200:
            return response.json()
        else:
            print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
            return None