Spaces:
Sleeping
Sleeping
File size: 3,092 Bytes
6126ba8 18a1e7e 6126ba8 42de116 6126ba8 56abc73 6126ba8 8f0ce1c 6126ba8 8f0ce1c 42de116 8f0ce1c 42de116 8f0ce1c 42de116 6126ba8 42de116 6126ba8 8f0ce1c 42de116 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import requests
from config import ACCESS_TOKEN, SHOP_NAME
class SQLGenerator:
    """Turn natural-language questions into SQL and fetch Shopify data.

    The SQL side wraps the ``premai-io/prem-1B-SQL`` causal language model:
    a prompt containing a fixed ``products`` table schema and the user's
    question is fed to the model, and the model's completion is returned.
    The Shopify side is a thin GET helper for the 2023-10 Admin REST API.
    """

    def __init__(self):
        """Load the tokenizer and model weights named by ``self.model_name``."""
        self.model_name = "premai-io/prem-1B-SQL"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

    @staticmethod
    def _build_prompt(natural_language_query):
        """Return the full model prompt for ``natural_language_query``.

        The prompt is: a task header, the ``products`` schema, the question,
        and a trailing ``### SQL Query:`` cue for the model to continue from.
        """
        # The literal's lines are deliberately flush-left: any leading
        # whitespace here would become part of the text sent to the model.
        # NOTE(review): `id` and `admin_graphql_api_id` are declared as
        # DECIMAL(8,2), which looks auto-inferred rather than intentional --
        # confirm against the real data before relying on numeric filters.
        schema_info = """
CREATE TABLE products (
id DECIMAL(8,2) PRIMARY KEY,
title VARCHAR(255),
body_html VARCHAR(255),
vendor VARCHAR(255),
product_type VARCHAR(255),
created_at VARCHAR(255),
handle VARCHAR(255),
updated_at DATE,
published_at VARCHAR(255),
template_suffix VARCHAR(255),
published_scope VARCHAR(255),
tags VARCHAR(255),
status VARCHAR(255),
admin_graphql_api_id DECIMAL(8,2),
variants VARCHAR(255),
options VARCHAR(255),
images VARCHAR(255),
image VARCHAR(255)
);
"""
        return f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{schema_info}
### Question: {natural_language_query}
### SQL Query:"""

    def generate_query(self, natural_language_query, max_new_tokens=256):
        """Generate a SQL query answering ``natural_language_query``.

        Args:
            natural_language_query: The question, in plain language.
            max_new_tokens: Upper bound on the number of tokens generated
                *after* the prompt.

        Returns:
            The model's completion as a stripped string. The prompt (task
            header, schema and question) is not included. Sampling is
            enabled, so repeated calls may return different text.
        """
        prompt = self._build_prompt(natural_language_query)
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # Some tokenizers define no pad token; reuse EOS so generate() gets
        # a concrete pad id instead of None.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # max_new_tokens bounds only the completion. A max_length bound
            # would also count the (long) schema prompt and could leave no
            # room for the query itself.
            max_new_tokens=max_new_tokens,
            do_sample=True,  # required for temperature / top_k to apply
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            temperature=0.7,
            top_k=50,
        )

        # A causal LM returns prompt + completion in one sequence; keep only
        # the tokens produced after the prompt.
        completion_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(
            completion_ids, skip_special_tokens=True
        ).strip()

    def fetch_shopify_data(self, endpoint, timeout=10):
        """GET ``/admin/api/2023-10/{endpoint}.json`` from the configured shop.

        ``SHOP_NAME`` is interpolated as the URL host and ``ACCESS_TOKEN`` is
        sent in the ``X-Shopify-Access-Token`` header.

        Args:
            endpoint: Resource path without the ``.json`` suffix,
                e.g. ``"products"``.
            timeout: Seconds to wait for the server before giving up, so an
                unresponsive host cannot block the caller indefinitely.

        Returns:
            The decoded JSON body on HTTP 200; otherwise ``None`` after
            printing the status code and response text.

        Raises:
            requests.RequestException: On connection failure or timeout.
        """
        headers = {
            'X-Shopify-Access-Token': ACCESS_TOKEN,
            'Content-Type': 'application/json'
        }
        url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
        response = requests.get(url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
            return None
        return response.json()
# Demo entry point: runs only when this file is executed as a script.
if __name__ == "__main__":
    generator = SQLGenerator()

    # Ask the model to translate a plain-language question into SQL.
    question = "Show me shirts with red color"
    sql_query = generator.generate_query(question)
    print(f"Generated SQL Query: {sql_query}")

    # Request the "products" endpoint from Shopify and print the result
    # (None is printed if the request did not return HTTP 200).
    shopify_data = generator.fetch_shopify_data("products")
    print(f"Shopify Data: {shopify_data}")
|