# Source snapshot: revision e6f4fec, 4,941 bytes.
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import re
# Load the fine-tuned T5 model and its matching tokenizer from the local
# checkpoint directory (both are read from the same path).
_MODEL_DIR = "./t5_sql_finetuned"
model = T5ForConditionalGeneration.from_pretrained(_MODEL_DIR)
tokenizer = T5Tokenizer.from_pretrained(_MODEL_DIR)
# Define a simple function to check if the question is schema-related or SQL-related
def is_schema_question(question: str):
    """Return True if *question* mentions a schema-level keyword.

    The test is a case-insensitive substring match against a fixed keyword
    list, so e.g. "relationships" also counts because it contains
    "relations". Anything else is treated as a data (SQL) question.
    """
    lowered = question.lower()
    for keyword in ("columns", "tables", "structure", "schema", "relations", "fields"):
        if keyword in lowered:
            return True
    return False
# Helper function to extract table name from the question.
# Words that may sit where a table name is expected but never are one.
_TABLE_NAME_STOPWORDS = r"(?:the|a|an|this|that|each|every|which|what)"
_TABLE_NAME_PATTERNS = (
    # "... the products table ..." -> "products"
    re.compile(rf"\b(?!{_TABLE_NAME_STOPWORDS}\b)(\w+)\s+table\b", re.IGNORECASE),
    # "... for products", "... in the orders", "... from users", "... of orders"
    re.compile(
        rf"\b(?:for|in|from|of)\s+(?:{_TABLE_NAME_STOPWORDS}\s+)?"
        rf"(?!(?:tables?|{_TABLE_NAME_STOPWORDS})\b)(\w+)",
        re.IGNORECASE,
    ),
)


def extract_table_name(question: str):
    """Guess the table name mentioned in *question* using regex heuristics.

    Patterns are tried in order and the first hit wins:
      1. "<name> table"                  e.g. "the products table" -> "products"
      2. "for/in/from/of [the] <name>"   e.g. "columns for users"   -> "users"

    Keywords must be whole words, so "contain x" is not read as "in x".
    Determiners ("the", "a", ...) and the word "table(s)" are never returned.
    Matching is case-insensitive, and the name is returned as written in
    *question*. The result is not checked against any schema, so callers must
    still verify that it exists. Phrasings such as "describe table users" are
    not recognised: pattern 1 would pick "describe".

    Returns:
        The candidate table name, or None if no pattern matches.
    """
    for pattern in _TABLE_NAME_PATTERNS:
        match = pattern.search(question)
        if match:
            return match.group(1)
    # If no table name is detected, return None
    return None
# Define a function to handle SQL generation
def generate_sql(question: str, schema: dict, model, tokenizer, device, max_length: int = 128):
    """Translate a natural-language *question* into SQL with the seq2seq model.

    Args:
        question: Question text, passed to the tokenizer verbatim,
            e.g. "What is the price of the product with ID 123?".
        schema: Not used by this function. NOTE(review): if the model was
            fine-tuned with the schema in its prompt, it should be serialized
            into the input here. Confirm against the training script.
        model: Model exposing ``generate`` (the fine-tuned T5). Only the
            encoded inputs are moved to *device*, so the model should already
            be on *device*.
        tokenizer: Tokenizer matching *model*.
        device: Torch device the encoded inputs are moved to.
        max_length: Maximum length of the generated sequence, in tokens.

    Returns:
        The decoded SQL string with special tokens removed.
    """
    inputs = tokenizer(question, return_tensors="pt").to(device)
    with torch.no_grad():
        # Forward the attention mask explicitly instead of letting generate()
        # infer it from padding tokens.
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
        )
    # Decode the single generated sequence into the SQL text.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Define a function to handle schema-related questions
def _resolve_table_name(question: str, schema: dict):
    """Return the table that *question* refers to, or None.

    Each word of the question is compared case-insensitively with the table
    names in *schema*. The first hit is returned in the schema's own spelling,
    so it can be used directly as a key. Only when no known table is mentioned
    does this fall back to ``extract_table_name``, which may return a name
    that is not in *schema*.
    """
    known_tables = {name.lower(): name for name in schema}
    for word in re.findall(r"\w+", question.lower()):
        if word in known_tables:
            return known_tables[word]
    return extract_table_name(question)


def handle_schema_question(question: str, schema: dict):
    """Answer a question about the structure of *schema*.

    Args:
        question: Question text, e.g. "What columns does the products table have?".
        schema: Mapping of table name -> dict holding "columns" and
            "relations" entries.

    Returns:
        - the table's "columns" entry for column/field questions;
        - the table's "relations" entry for relation questions;
        - a list of all table names for "tables" questions;
        - otherwise an explanatory message string (unknown table, or a
          question that could not be interpreted).
    """
    question = question.lower()

    # The checks run in this order, so a question mentioning both "columns"
    # and "tables" is treated as a columns question.
    if "columns" in question or "fields" in question:
        attribute = "columns"
    elif "relations" in question or "relationships" in question:
        attribute = "relations"
    elif "tables" in question:
        # e.g. "Which tables are in the schema?"
        return list(schema.keys())
    else:
        attribute = None

    if attribute is not None:
        table_name = _resolve_table_name(question, schema)
        if table_name:
            if table_name in schema:
                return schema[table_name][attribute]
            return f"Table '{table_name}' not found in the schema."

    # If the question is too vague or doesn't match the expected patterns
    return "Sorry, I couldn't understand your schema question. Could you rephrase?"
# Example schema for your custom use case.
# Each table maps to:
#   "columns":   list of column names
#   "relations": list of reference strings of the form
#                "<column> -> <table>.<column>"; an empty list means the table
#                references nothing. Always a list, so consumers handle one type.
custom_schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": ["category_id -> categories.id"],
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": [],
    },
    "orders": {
        "columns": ["order_id", "user_id", "product_id", "order_date"],
        "relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
    },
    "users": {
        "columns": ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
        "relations": [],
    },
}
def answer_question(question: str, schema: dict, model, tokenizer, device):
    """Route *question* to the schema helper or to the SQL generator.

    Returns "Schema Information: <answer>" when ``is_schema_question`` flags
    the question. Otherwise returns "Generated SQL Query: <sql>", built from
    the model's output.
    """
    if not is_schema_question(question):
        # Data question: let the model write the SQL.
        sql_query = generate_sql(question, schema, model, tokenizer, device)
        return f"Generated SQL Query: {sql_query}"
    # Schema question: answer directly from the schema mapping.
    schema_answer = handle_schema_question(question, schema)
    return f"Schema Information: {schema_answer}"
# Example input questions
question_1 = "What columns does the products table have?"
question_2 = "What is the price of the product with product_id 123?"
# Use the GPU when available. The model has to be moved there too:
# generate_sql() moves the encoded inputs to `device`, and generation fails
# with a device mismatch if the model's weights stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Handle schema question
response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
print(response_1)  # This should give you the columns of the products table
# Handle SQL query question
response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
print(response_2)  # This should generate an SQL query for fetching the price