import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

# Shopify API credentials
API_KEY = 'c7c07dac8f8eaebd316ac4aa16350472'
API_SECRET = '046b2c0bfb7879db2bba86f24a281ec2'
ACCESS_TOKEN = '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'
SHOP_NAME = '6znwwf-77.myshopify.com'

# Load the SQL generation model
model_name = "premai-io/prem-1B-SQL"
tokenizer_sql = AutoTokenizer.from_pretrained(model_name)
model_sql = AutoModelForCausalLM.from_pretrained(model_name)

# Load the DistilGPT-2 model and tokenizer for summarizing results
distil_gpt2_model_name = "distilgpt2"
tokenizer_gpt2 = AutoTokenizer.from_pretrained(distil_gpt2_model_name)
model_gpt2 = AutoModelForCausalLM.from_pretrained(distil_gpt2_model_name)


def fetch_data_from_shopify(endpoint):
    """Fetch data from the Shopify Admin REST API."""
    headers = {
        'X-Shopify-Access-Token': ACCESS_TOKEN,
        'Content-Type': 'application/json'
    }
    url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
    response = requests.get(url, headers=headers)
    if response.status_code == 200:
        return response.json()
    else:
        print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
        return None


def generate_sql(natural_language_query):
    """Generate a SQL query from natural language and retrieve data from Shopify."""
    # Schema information given to the SQL model
    schema_info = """
    CREATE TABLE products (
        id DECIMAL(8,2) PRIMARY KEY,
        title VARCHAR(255),
        body_html VARCHAR(255),
        vendor VARCHAR(255),
        product_type VARCHAR(255),
        created_at VARCHAR(255),
        handle VARCHAR(255),
        updated_at DATE,
        published_at VARCHAR(255),
        template_suffix VARCHAR(255),
        published_scope VARCHAR(255),
        tags VARCHAR(255),
        status VARCHAR(255),
        admin_graphql_api_id DECIMAL(8,2),
        variants VARCHAR(255),
        options VARCHAR(255),
        images VARCHAR(255),
        image VARCHAR(255)
    );
    """

    # Construct the prompt for SQL generation
    prompt = f"""### Task: Generate a SQL query to answer the following question.

### Database Schema:
{schema_info}

### Question: {natural_language_query}

### SQL Query:"""

    # Tokenize and generate the SQL query
    inputs = tokenizer_sql(prompt, return_tensors="pt", add_special_tokens=False).to(model_sql.device)
    outputs = model_sql.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=256,   # Cap generated tokens only; the prompt itself is longer than 256 tokens
        do_sample=False,      # Greedy decoding for faster, deterministic results
        num_return_sequences=1,
        eos_token_id=tokenizer_sql.eos_token_id,
        # Fall back to EOS when the tokenizer defines no pad token
        pad_token_id=tokenizer_sql.pad_token_id if tokenizer_sql.pad_token_id is not None else tokenizer_sql.eos_token_id,
    )

    # Decode only the newly generated tokens so the prompt is not echoed back
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    sql_query = tokenizer_sql.decode(generated_ids, skip_special_tokens=True).strip()

    # Use the generated SQL query to determine the endpoint and parameters.
    # Hard-coded for now; see the endpoint_from_sql sketch at the end of the file
    # for one way to make this dynamic based on your SQL parsing logic.
    endpoint = "products"
    products = fetch_data_from_shopify(endpoint)

    # Prepare results based on the Shopify data
    if products and 'products' in products:
        results = products['products']

        # Format the product data for the DistilGPT-2 prompt
        result_text = "\n".join(
            f"Title: {product['title']}, Vendor: {product['vendor']}" for product in results
        )
        gpt2_prompt = (
            f"Here are the products based on your query:\n{result_text}\n\n"
            "Provide a summary of these products."
        )
        # Tokenize and generate the summary, truncating so the prompt plus new tokens
        # stay within DistilGPT-2's 1024-token context window
        inputs_gpt2 = tokenizer_gpt2(
            gpt2_prompt,
            return_tensors="pt",
            add_special_tokens=False,
            truncation=True,
            max_length=800,
        ).to(model_gpt2.device)
        outputs_gpt2 = model_gpt2.generate(
            inputs_gpt2["input_ids"],
            attention_mask=inputs_gpt2["attention_mask"],
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1,
            eos_token_id=tokenizer_gpt2.eos_token_id,
            # DistilGPT-2 has no pad token, so reuse EOS for padding
            pad_token_id=tokenizer_gpt2.eos_token_id,
        )

        # Decode only the newly generated tokens as the summary
        generated_ids_gpt2 = outputs_gpt2[0][inputs_gpt2["input_ids"].shape[1]:]
        gpt2_response = tokenizer_gpt2.decode(generated_ids_gpt2, skip_special_tokens=True)

        return f"SQL Query: {sql_query}\n\nResults:\n{gpt2_response}"
    else:
        return "No results found or error fetching data from Shopify."


def main():
    # Gradio interface setup
    iface = gr.Interface(
        fn=generate_sql,
        inputs="text",
        outputs="text",
        title="Natural Language to SQL Query Generator",
        description="Enter a natural language query to generate the corresponding SQL query and display results."
    )
    iface.launch()


if __name__ == "__main__":
    main()
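# Optional: a minimal sketch of the "dynamic endpoint" idea noted in generate_sql.
# Nothing below is part of the original script; the helper name (endpoint_from_sql)
# and the TABLE_TO_ENDPOINT mapping are illustrative assumptions. The sketch pulls the
# table name out of the generated SQL with a regex and maps it to a Shopify REST
# resource, falling back to "products" when nothing matches. In practice it would be
# defined above generate_sql and used as: endpoint = endpoint_from_sql(sql_query)

import re

TABLE_TO_ENDPOINT = {
    "products": "products",
    "orders": "orders",
    "customers": "customers",
}


def endpoint_from_sql(sql_query, default="products"):
    """Guess the Shopify REST endpoint from the table referenced in the FROM clause."""
    match = re.search(r"\bFROM\s+([A-Za-z_][A-Za-z0-9_]*)", sql_query, re.IGNORECASE)
    if match:
        return TABLE_TO_ENDPOINT.get(match.group(1).lower(), default)
    return default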