Manoj Kumar committed
Commit 39179ce · 1 Parent(s): f753320

updated question structure
Files changed (2):
  1. README.md +1 -1
  2. app.py +3 -35
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
 colorTo: red
 sdk: gradio
 sdk_version: 5.11.0
-app_file: gpt_neo_db.py
+app_file: app.py
 pinned: false
 python: 3.9
 ---
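Note on the README change: for a Space with `sdk: gradio`, the `app_file` field tells Hugging Face Spaces which script to launch, so this commit points it at app.py instead of the old gpt_neo_db.py. The committed app.py (see the diff below) is a plain console loop rather than a Gradio app; purely as a hypothetical illustration of how an `app_file` entry point could expose the same question-to-SQL flow through Gradio, a minimal sketch might look like this. The `generate_sql_query` helper here is a stand-in and not part of the commit.

# Hypothetical sketch only, not part of this commit.
# Assumes the gradio package implied by `sdk: gradio` in the README.
import json

import gradio as gr

db_schema = {
    "customers": ["customer_id", "name", "email", "phone_number"],
}
schema_description = json.dumps(db_schema, indent=4)


def generate_sql_query(question: str) -> str:
    # Placeholder: a real implementation would prompt a language model with
    # the schema context, as the code removed from app.py in this commit did.
    return f"-- SQL for: {question}\n-- schema context:\n{schema_description}"


demo = gr.Interface(
    fn=generate_sql_query,
    inputs=gr.Textbox(label="Question"),
    outputs=gr.Textbox(label="Generated SQL"),
    title="Text-to-SQL demo",
)

if __name__ == "__main__":
    demo.launch()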
app.py CHANGED
@@ -9,41 +9,9 @@ db_schema = {
     "customers": ["customer_id", "name", "email", "phone_number"]
 }
 
-# Load the model and tokenizer
-model_name = "EleutherAI/gpt-neox-20b"  # You can also use "Llama-2-7b" or another model
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
 
-def generate_sql_query(context, question):
-    """
-    Generate an SQL query based on the question and context.
-
-    Args:
-        context (str): Description of the database schema or table relationships.
-        question (str): User's natural language query.
-
-    Returns:
-        str: Generated SQL query.
-    """
-    # Prepare the prompt
-    prompt = f"""
-    Context: {context}
-
-    Question: {question}
-
-    Write an SQL query to address the question based on the context.
-    Query:
-    """
-    # Tokenize input
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Generate SQL query
-    output = model.generate(inputs.input_ids, max_length=512, num_beams=5, early_stopping=True)
-    query = tokenizer.decode(output[0], skip_special_tokens=True)
-
-    # Extract query from the output
-    sql_query = query.split("Query:")[-1].strip()
-    return sql_query
+def dummy_function(schema_description, user_question):
+    print(user_question)
 
 # Schema as a context for the model
 schema_description = json.dumps(db_schema, indent=4)
@@ -57,5 +25,5 @@ while True:
         break
 
     # Generate SQL query
-    sql_query = generate_sql_query(schema_description, user_question)
+    sql_query = dummy_function(schema_description, user_question)
    print(f"Generated SQL Query:\n{sql_query}\n")