"""Generate SQL from natural-language questions with a small LLM, and fetch Shopify data."""

import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import ACCESS_TOKEN, SHOP_NAME


class SQLGenerator:
    """Turn natural-language questions into SQL and fetch raw Shopify REST data.

    The SQL model and tokenizer are downloaded/loaded in ``__init__``. The
    Shopify helper does not use the model; it only needs ``ACCESS_TOKEN`` and
    ``SHOP_NAME`` from ``config``.
    """

    # Shopify Admin REST API version used when building request URLs.
    API_VERSION = "2023-10"

    # Seconds to wait for Shopify before giving up; without a timeout
    # ``requests`` can block forever on an unresponsive server.
    REQUEST_TIMEOUT = 30

    # DDL describing the table the model is asked to query. Built once here
    # instead of on every ``generate_query`` call.
    # NOTE(review): several column types look auto-generated (e.g. ``id`` as
    # DECIMAL(8,2), ``variants``/``images`` as VARCHAR) — confirm they reflect
    # the real table the generated SQL will run against.
    SCHEMA_INFO = """
CREATE TABLE products (
    id DECIMAL(8,2) PRIMARY KEY,
    title VARCHAR(255),
    body_html VARCHAR(255),
    vendor VARCHAR(255),
    product_type VARCHAR(255),
    created_at VARCHAR(255),
    handle VARCHAR(255),
    updated_at DATE,
    published_at VARCHAR(255),
    template_suffix VARCHAR(255),
    published_scope VARCHAR(255),
    tags VARCHAR(255),
    status VARCHAR(255),
    admin_graphql_api_id DECIMAL(8,2),
    variants VARCHAR(255),
    options VARCHAR(255),
    images VARCHAR(255),
    image VARCHAR(255)
);
"""

    def __init__(self):
        """Load the prem-1B-SQL tokenizer and causal-LM weights."""
        self.model_name = "premai-io/prem-1B-SQL"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    def generate_query(self, natural_language_query, max_new_tokens=256):
        """Generate a SQL query answering ``natural_language_query``.

        Args:
            natural_language_query: The user's question in plain language.
            max_new_tokens: Upper bound on tokens generated *after* the
                prompt. The prompt itself is not counted against this budget.

        Returns:
            The decoded model completion (prompt excluded), stripped of
            surrounding whitespace. Sampling is enabled, so repeated calls
            may return different queries.
        """
        prompt = (
            "### Task:\n"
            "Generate a SQL query to answer the following question.\n\n"
            f"### Database Schema:\n{self.SCHEMA_INFO}\n"
            f"### Question:\n{natural_language_query}\n\n"
            "### SQL Query:"
        )

        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)

        # Some causal-LM tokenizers define no pad token; fall back to EOS so
        # generate() does not receive pad_token_id=None.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # max_new_tokens (not max_length): max_length includes the long
            # schema prompt and could leave no room for the answer.
            max_new_tokens=max_new_tokens,
            do_sample=True,  # Enable sampling to use temperature
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            temperature=0.7,  # Allow temperature to affect output
            top_k=50,  # Consider top k predictions for variability
        )

        # generate() returns prompt + completion; keep only the new tokens so
        # the caller gets the SQL rather than an echo of the whole prompt.
        prompt_length = inputs["input_ids"].shape[-1]
        completion_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    def fetch_shopify_data(self, endpoint):
        """GET ``/admin/api/<version>/<endpoint>.json`` from the Shopify store.

        Args:
            endpoint: REST resource path without the ``.json`` suffix,
                e.g. ``"products"``.

        Returns:
            The decoded JSON body on HTTP 200; otherwise ``None``, after
            printing the error (non-200 status or transport failure).
        """
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json',
        }
        # NOTE(review): assumes SHOP_NAME is the full shop host
        # (e.g. "<store>.myshopify.com") — confirm in config.
        url = f"https://{SHOP_NAME}/admin/api/{self.API_VERSION}/{endpoint}.json"

        try:
            response = requests.get(url, headers=headers, timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException as err:
            # Same "print and return None" contract as the non-200 path below.
            print(f"Error fetching {endpoint}: {err}")
            return None

        if response.status_code == 200:
            return response.json()

        print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
        return None


# Example of how to use the SQLGenerator class
if __name__ == "__main__":
    sql_generator = SQLGenerator()

    # Example natural language query
    natural_language_query = "Show me shirts with red color"

    # Generate SQL query
    sql_query = sql_generator.generate_query(natural_language_query)
    print(f"Generated SQL Query: {sql_query}")

    # Fetch data from Shopify (example endpoint)
    shopify_data = sql_generator.fetch_shopify_data("products")
    print(f"Shopify Data: {shopify_data}")