# File size: 4,941 Bytes
# e6f4fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import re

# Both the fine-tuned weights and the tokenizer files are read from the same
# local checkpoint directory.
_checkpoint_dir = "./t5_sql_finetuned"
tokenizer = T5Tokenizer.from_pretrained(_checkpoint_dir)
model = T5ForConditionalGeneration.from_pretrained(_checkpoint_dir)

# Define a simple function to check if the question is schema-related or SQL-related
def is_schema_question(question: str) -> bool:
    """Return True when *question* asks about the database schema.

    The check looks for schema vocabulary (columns, fields, tables, schema,
    structure, relations, relationships) as whole words, case-insensitively.
    Whole-word matching keeps ordinary data questions that merely contain a
    keyword as a substring (e.g. "vegetables", "restructure") from being
    routed to the schema handler.
    """
    schema_terms = (
        "columns",
        "fields",
        "tables",
        "schemas?",
        "structures?",
        "relations",
        "relationships",
    )
    pattern = r"\b(?:" + "|".join(schema_terms) + r")\b"
    return re.search(pattern, question, flags=re.IGNORECASE) is not None

# Helper function to extract table name from the question
def extract_table_name(question: str):
    """Guess which table a natural-language question refers to.

    Two phrasings are recognised, in this order:

    1. ``<name> table``, e.g. "the products table"  -> ``"products"``
    2. ``for|in|from [the|a|an] <name>``, e.g. "columns for orders"
       -> ``"orders"``

    Matching is case-insensitive and anchored on word boundaries, so a
    preposition embedded in a longer word ("main index") is not mistaken for
    "in <name>". Articles and other filler words are never returned.

    Returns the matched word exactly as it appears in *question*, or ``None``
    when no candidate is found.
    """
    # Words that can sit in the "<name>" slot without ever being a table name.
    filler_words = {
        "a", "an", "the", "this", "that", "these", "those",
        "which", "what", "each", "every", "any", "all",
        "for", "in", "from", "of", "on",
        "table", "tables",
    }

    # The "<name> table" phrasing is the most explicit, so try it first.
    for hit in re.finditer(r"\b(\w+)\s+table\b", question, flags=re.IGNORECASE):
        candidate = hit.group(1)
        if candidate.lower() not in filler_words:
            return candidate

    # Otherwise take the word following a preposition, skipping an article.
    preposition_pattern = r"\b(?:for|in|from)\s+(?:(?:the|a|an)\s+)?(\w+)"
    for hit in re.finditer(preposition_pattern, question, flags=re.IGNORECASE):
        candidate = hit.group(1)
        if candidate.lower() not in filler_words:
            return candidate

    return None


# Define a function to handle SQL generation
def generate_sql(question: str, schema: dict, model, tokenizer, device):
    """Generate an SQL query for *question* with the fine-tuned seq2seq model.

    Args:
        question: Natural-language question, e.g.
            "What is the price of the product with ID 123?".
        schema: Schema description. Accepted for interface compatibility; it
            is not used when building the model input.
        model: Model object exposing ``to(device)`` and ``generate(...)``.
        tokenizer: Tokenizer matching *model*; called on the question and used
            to decode the generated ids.
        device: Device on which generation runs.

    Returns:
        The decoded query string with special tokens removed.
    """
    encoded = tokenizer(question, return_tensors="pt")

    # The model and its inputs must live on the same device; moving only the
    # inputs fails as soon as `device` differs from where the model was loaded.
    model.to(device)
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=128,
        )

    # A single question was encoded, so only the first sequence is decoded.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Define a function to handle schema-related questions
def _find_known_table(question: str, schema: dict):
    """Return the schema key mentioned earliest in *question* as a whole word, or None."""
    earliest_key = None
    earliest_pos = None
    for key in schema:
        mention = re.search(
            rf"(?<!\w){re.escape(str(key).lower())}(?!\w)", question
        )
        if mention and (earliest_pos is None or mention.start() < earliest_pos):
            earliest_key, earliest_pos = key, mention.start()
    return earliest_key


def handle_schema_question(question: str, schema: dict):
    """Answer a question about the schema (tables, columns, relations).

    Args:
        question: Natural-language question, e.g.
            "What columns does the products table have?".
        schema: Mapping of table name to a dict with "columns" and
            "relations" entries.

    Returns:
        * the table's "columns" entry for column/field questions,
        * the table's "relations" entry for relation questions,
        * the list of table names for "tables" questions,
        * a "Table '<name>' not found in the schema." message when a table
          name was detected but is not a schema key,
        * a generic "couldn't understand" message otherwise.

    The table is identified by first looking for a schema table name in the
    question (case-insensitively); only when none is mentioned does it fall
    back to the phrase-based guess of ``extract_table_name``.
    """
    question = question.lower()

    # Decide which part of the table description is being asked for.
    if "columns" in question or "fields" in question:
        attribute = "columns"
    elif "relations" in question:  # also covers "relationships"
        attribute = "relations"
    elif "tables" in question:
        return list(schema)
    else:
        attribute = None

    if attribute is not None:
        table_name = _find_known_table(question, schema)

        if table_name is None:
            # No schema table is mentioned verbatim; use the heuristic guess so
            # the user can be told which name was not recognised.
            guess = extract_table_name(question)
            if guess:
                keys_by_lowercase = {str(key).lower(): key for key in schema}
                table_name = keys_by_lowercase.get(guess.lower())
                if table_name is None:
                    return f"Table '{guess}' not found in the schema."

        if table_name is not None:
            return schema[table_name][attribute]

    # If the question is too vague or doesn't match the expected patterns
    return "Sorry, I couldn't understand your schema question. Could you rephrase?"


# Example schema for your custom use case
# Each row is (table name, column names, relations); "relations" is kept exactly
# as given: a single string, a list of strings, or None when there are none.
custom_schema = {
    table: {"columns": columns, "relations": relations}
    for table, columns, relations in (
        (
            "products",
            ["product_id", "name", "price", "category_id"],
            "category_id -> categories.id",
        ),
        (
            "categories",
            ["id", "category_name"],
            None,
        ),
        (
            "orders",
            ["order_id", "user_id", "product_id", "order_date"],
            ["product_id -> products.product_id", "user_id -> users.user_id"],
        ),
        (
            "users",
            ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
            None,
        ),
    )
}

def answer_question(question: str, schema: dict, model, tokenizer, device):
    """Route *question* to the schema handler or to SQL generation.

    Questions classified by ``is_schema_question`` are answered from *schema*
    and prefixed with "Schema Information: "; all others are passed to
    ``generate_sql`` and prefixed with "Generated SQL Query: ".
    """
    if not is_schema_question(question):
        # Data question: let the model produce the query.
        generated = generate_sql(question, schema, model, tokenizer, device)
        return f"Generated SQL Query: {generated}"

    # Schema question: answered directly from the schema description.
    schema_answer = handle_schema_question(question, schema)
    return f"Schema Information: {schema_answer}"

# Example input questions
question_1 = "What columns does the products table have?"
question_2 = "What is the price of the product with product_id 123?"

# Use the GPU when one is available. The model is loaded on the CPU, so it has
# to be moved explicitly; otherwise generation fails with a device mismatch
# once the inputs are sent to CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Handle schema question
response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
# Prints "Schema Information: ..." — the products columns when the table name
# is recognised in the question, otherwise the handler's fallback message.
print(response_1)

# Handle SQL query question
response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
# Prints "Generated SQL Query: ..." followed by the model's decoded output.
print(response_2)