import torch
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import ACCESS_TOKEN, SHOP_NAME


class SQLGenerator:
    def __init__(self):
        # Small text-to-SQL model from Prem AI; tokenizer and weights are loaded once at startup
        self.model_name = "premai-io/prem-1B-SQL"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
    def generate_query(self, natural_language_query):
        # Table layout passed to the model as context for text-to-SQL generation
        schema_info = """
        CREATE TABLE products (
            id BIGINT PRIMARY KEY,
            title VARCHAR(255),
            body_html VARCHAR(255),
            vendor VARCHAR(255),
            product_type VARCHAR(255),
            created_at VARCHAR(255),
            handle VARCHAR(255),
            updated_at DATE,
            published_at VARCHAR(255),
            template_suffix VARCHAR(255),
            published_scope VARCHAR(255),
            tags VARCHAR(255),
            status VARCHAR(255),
            admin_graphql_api_id VARCHAR(255),
            variants VARCHAR(255),
            options VARCHAR(255),
            images VARCHAR(255),
            image VARCHAR(255)
        );
        """
prompt = f"""### Task: Generate a SQL query to answer the following question. | |
### Database Schema: | |
{schema_info} | |
### Question: {natural_language_query} | |
### SQL Query:""" | |
inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device) | |
outputs = self.model.generate( | |
inputs["input_ids"], | |
max_length=256, | |
do_sample=False, | |
num_return_sequences=1, | |
eos_token_id=self.tokenizer.eos_token_id, | |
pad_token_id=self.tokenizer.pad_token_id, | |
temperature=0.7, # Adjust temperature for more creative output | |
top_k=50 # Consider top k predictions for variability | |
) | |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip() | |
    def fetch_shopify_data(self, endpoint):
        # Call the Shopify Admin REST API for the given endpoint (e.g. "products")
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json'
        }
        url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
        else:
            print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
            return None
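

# Minimal usage sketch. It assumes config.py defines ACCESS_TOKEN (a Shopify Admin API
# token) and SHOP_NAME (e.g. "your-store.myshopify.com"), matching the import at the top
# of this file; the example question below is illustrative only.
if __name__ == "__main__":
    generator = SQLGenerator()

    # Turn a natural-language question into SQL against the products table
    sql = generator.generate_query("List the titles of all products from the vendor 'Acme'")
    print("Generated SQL:", sql)

    # Fetch raw product data from the Shopify Admin REST API
    products = generator.fetch_shopify_data("products")
    if products is not None:
        print(f"Fetched {len(products.get('products', []))} products")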