Manoj Kumar committed on
Commit
9c8236d
·
1 Parent(s): 39179ce

updated question structure

Browse files
Files changed (1) hide show
  1. app.py +42 -11
app.py CHANGED
@@ -9,21 +9,52 @@ db_schema = {
9
  "customers": ["customer_id", "name", "email", "phone_number"]
10
  }
11
 
 
 
 
 
12
 
13
def dummy_function(schema_description, user_question):
    """Placeholder SQL generator: echoes the question and returns None.

    The schema description is accepted for interface parity with a real
    generator but is not used here.
    """
    question_text = user_question
    print(question_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
# Serialize the schema so the (placeholder) generator receives it as text.
schema_description = json.dumps(db_schema, indent=4)

# Interactive REPL: keep prompting until the user types "exit" or "quit".
print("Ask a question about the database schema.")
while True:
    user_question = input("Question: ")
    lowered = user_question.lower()
    if lowered in ("exit", "quit"):
        print("Exiting...")
        break

    # Hand the question to the stub generator and echo whatever it yields.
    sql_query = dummy_function(schema_description, user_question)
    print(f"Generated SQL Query:\n{sql_query}\n")
 
 
9
  "customers": ["customer_id", "name", "email", "phone_number"]
10
  }
11
 
12
# Hugging Face checkpoint to load.  # You can also use "Llama-2-7b" or another model
model_name = "EleutherAI/gpt-neox-20b"

# Tokenizer plus half-precision model sharded across available devices.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)
16
 
17
def generate_sql_query(context, question):
    """
    Generate an SQL query based on the question and context.

    Args:
        context (str): Description of the database schema or table relationships.
        question (str): User's natural language query.

    Returns:
        str: Generated SQL query (the text after the final "Query:" marker).
    """
    # Prepare the prompt
    prompt = f"""
    Context: {context}

    Question: {question}

    Write an SQL query to address the question based on the context.
    Query:
    """
    # Tokenize input; truncate so the prompt fits within the context budget.
    # NOTE(review): with device_map="auto" the first shard may not be on
    # "cuda:0"; inputs.to(model.device) would be more robust — confirm setup.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to("cuda" if torch.cuda.is_available() else "cpu")

    # Generate SQL query.
    # Fix: the original used max_length=512, which bounds prompt + output
    # combined — a prompt near the 1024-token truncation limit would leave
    # no room to generate (or raise). max_new_tokens bounds only the output.
    # Also pass attention_mask explicitly so padded/truncated inputs are
    # attended correctly instead of relying on generate() to guess it.
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=512,
        num_beams=5,
        early_stopping=True,
    )
    query = tokenizer.decode(output[0], skip_special_tokens=True)

    # Extract query from the output: everything after the last "Query:" label.
    sql_query = query.split("Query:")[-1].strip()
    return sql_query
47
 
48
# Serialize the schema so the model receives it as plain text context.
schema_description = json.dumps(db_schema, indent=4)

# Example questions, run as a fixed batch rather than interactively.
questions = [
    "Show all products that cost more than $50.",
    "List all customers who ordered a specific product.",
]

for user_question in questions:
    print(f"Question: {user_question}")
    sql_query = generate_sql_query(schema_description, user_question)
    print(f"Generated SQL Query:\n{sql_query}\n")