"""Natural-language-to-SQL helper backed by a Hugging Face causal LM, plus a
small fetcher for the Shopify Admin REST API."""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import requests
from config import ACCESS_TOKEN, SHOP_NAME


class SQLGenerator:
    """Generate SQL over a fixed ``products`` schema and fetch Shopify data.

    The tokenizer and model weights are loaded once, in ``__init__``, and
    reused by every ``generate_query`` call.
    """

    # Schema text embedded in every prompt so the model knows the table layout.
    # NOTE(review): Shopify product ids are large integers and
    # admin_graphql_api_id looks like a "gid://shopify/..." string, so
    # DECIMAL(8,2) seems wrong for both columns; kept unchanged because it may
    # mirror an existing table -- confirm against the real database.
    SCHEMA_INFO = """
CREATE TABLE products (
    id DECIMAL(8,2) PRIMARY KEY,
    title VARCHAR(255),
    body_html VARCHAR(255),
    vendor VARCHAR(255),
    product_type VARCHAR(255),
    created_at VARCHAR(255),
    handle VARCHAR(255),
    updated_at DATE,
    published_at VARCHAR(255),
    template_suffix VARCHAR(255),
    published_scope VARCHAR(255),
    tags VARCHAR(255),
    status VARCHAR(255),
    admin_graphql_api_id DECIMAL(8,2),
    variants VARCHAR(255),
    options VARCHAR(255),
    images VARCHAR(255),
    image VARCHAR(255)
);
"""

    def __init__(self, model_name="premai-io/prem-1B-SQL"):
        """Load the tokenizer and causal-LM weights for ``model_name``.

        Args:
            model_name: Hugging Face model id or local path. Defaults to the
                original hard-coded "premai-io/prem-1B-SQL".
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    def generate_query(self, natural_language_query, max_new_tokens=256):
        """Translate a natural-language question into SQL text.

        Args:
            natural_language_query: The question, inserted verbatim into the
                prompt.
            max_new_tokens: Maximum number of tokens generated *after* the
                prompt. (The original used ``max_length=256``, which also
                counted the long prompt and left little or no room for SQL.)

        Returns:
            Only the model's continuation (the generated SQL), decoded with
            special tokens skipped and surrounding whitespace stripped.
        """
        prompt = f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{self.SCHEMA_INFO}
### Question: {natural_language_query}
### SQL Query:"""

        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[-1]

        # Some tokenizers define no pad token; fall back to EOS explicitly
        # rather than letting generate() guess (and warn).
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                # Greedy decoding for deterministic SQL. Sampling knobs such as
                # temperature/top_k are ignored when do_sample=False, so they
                # are intentionally not passed.
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_id,
            )

        # For a causal LM, generate() returns prompt + continuation; drop the
        # prompt tokens so callers receive only the generated SQL.
        generated = outputs[0][prompt_len:]
        return self.tokenizer.decode(generated, skip_special_tokens=True).strip()

    def fetch_shopify_data(self, endpoint, api_version="2023-10", timeout=30):
        """GET ``/admin/api/<api_version>/<endpoint>.json`` from the shop.

        Args:
            endpoint: Resource path without the ``.json`` suffix,
                e.g. "products" (presumably; confirm against callers).
            api_version: Shopify Admin API version segment; defaults to the
                original hard-coded "2023-10".
            timeout: Seconds passed to ``requests.get``; without it a stalled
                connection would block forever.

        Returns:
            The parsed JSON body on HTTP 200, otherwise ``None`` after
            printing the status code and response text.

        Raises:
            requests.RequestException: network errors and timeouts are not
                caught here.
        """
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json'
        }
        # SHOP_NAME is used as the host, so it presumably holds the full
        # domain (e.g. "<store>.myshopify.com") -- TODO confirm in config.
        url = f"https://{SHOP_NAME}/admin/api/{api_version}/{endpoint}.json"
        response = requests.get(url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
            return None
        return response.json()