# chronos/wikiSQL.py
# Author: Manoj Kumar
# Phase 1: fine-tune T5 on WikiSQL for text-to-SQL generation.
import json
import os
import torch
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
def load_table_schemas(tables_file):
"""
Load table schemas from the tables.jsonl file.
Args:
tables_file: Path to the tables.jsonl file.
Returns:
A dictionary mapping table IDs to their column names.
"""
table_schemas = {}
with open(tables_file, 'r') as f:
for line in f:
table_data = json.loads(line)
table_id = table_data["id"]
table_columns = table_data["header"]
table_schemas[table_id] = table_columns
return table_schemas
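# For reference, each line of tables.jsonl is a JSON object; the loader above
# uses only the "id" and "header" fields (the WikiSQL release also ships extra
# fields such as "types" and "rows", which are ignored here). Illustrative
# line, with made-up values:
#   {"id": "1-1000181-1", "header": ["State/territory", "Current slogan", ...], ...}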
# Step 1: Load and Preprocess WikiSQL Data
def load_wikisql(data_dir):
"""
Load WikiSQL data and prepare it for training.
Args:
data_dir: Path to the WikiSQL dataset directory.
Returns:
List of examples with input and target text.
"""
def parse_file(file_path):
with open(file_path, 'r') as f:
return [json.loads(line) for line in f]
    train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
    dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))

    # Load column headers for both splits so each SQL dictionary can be
    # rendered as text against the correct table schema.
    table_schemas = load_table_schemas(os.path.join(data_dir, "train.tables.jsonl"))
    dev_tables_schema = load_table_schemas(os.path.join(data_dir, "dev.tables.jsonl"))
    def format_data(data, split):
        schemas = table_schemas if split == "train" else dev_tables_schema
        formatted = []
        for item in data:
            table_columns = schemas[item["table_id"]]
            question = item["question"]
            sql_query = sql_to_text(item["sql"], table_columns)
            formatted.append({"input": f"Question: {question}", "target": sql_query})
        return formatted

    return format_data(train_data, "train"), format_data(dev_data, "dev")
def sql_to_text(sql, table_columns):
"""
Convert SQL dictionary from WikiSQL to text representation.
Args:
sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
table_columns: List of column names corresponding to the table.
Returns:
SQL query as a string.
"""
# Aggregation functions mapping
agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
operators = ["=", ">", "<"]
# Get selected column
sel_column = table_columns[sql["sel"]]
agg_func = agg_functions[sql["agg"]]
select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"
# Get conditions
if sql["conds"]:
conditions = []
for cond in sql["conds"]:
col_idx, operator, value = cond
col_name = table_columns[col_idx]
conditions.append(f"{col_name} {operators[operator]} '{value}'")
where_clause = " WHERE " + " AND ".join(conditions)
else:
where_clause = ""
# Combine clauses into a full query
return select_clause + where_clause
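# Worked example with hypothetical columns (not from the dataset):
#   sql_to_text({"sel": 0, "agg": 3, "conds": [[2, 0, "2005"]]},
#               ["Player", "Team", "Year"])
#   -> "SELECT COUNT(Player) WHERE Year = '2005'"
# Aggregation index 3 maps to COUNT and condition operator 0 maps to "="
# per the lookup tables above.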
# Step 2: Tokenize the Data
def tokenize_data(data, tokenizer, max_length=128):
"""
Tokenize the input and target text.
Args:
data: List of examples with "input" and "target".
tokenizer: Pretrained tokenizer.
max_length: Maximum sequence length for the model.
Returns:
Tokenized dataset.
"""
inputs = [item["input"] for item in data]
targets = [item["target"] for item in data]
tokenized = tokenizer(
inputs,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
    labels = tokenizer(
        targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Replace padding token ids in the labels with -100 so the padded
    # positions are ignored by the cross-entropy loss during fine-tuning.
    label_ids = labels["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100
    tokenized["labels"] = label_ids
    return tokenized
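# Usage sketch (toy input; shapes assume the default max_length=128):
#   batch = tokenize_data(
#       [{"input": "Question: How many players are there?",
#         "target": "SELECT COUNT(Player)"}],
#       tokenizer,
#   )
#   batch["input_ids"].shape  # torch.Size([1, 128])
#   batch["labels"].shape     # torch.Size([1, 128])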
# Step 3: Load Model and Tokenizer
model_name = "t5-small" # Use "t5-small", "t5-base", or "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Step 4: Prepare Training and Validation Data
data_dir = "data" # Path to the WikiSQL dataset
train_data, dev_data = load_wikisql(data_dir)
# Tokenize Data
train_dataset = tokenize_data(train_data, tokenizer)
dev_dataset = tokenize_data(dev_data, tokenizer)
# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_dict(train_dataset)
dev_dataset = Dataset.from_dict(dev_dataset)
# # # Step 5: Define Training Arguments
# training_args = Seq2SeqTrainingArguments(
# output_dir="./t5_sql_finetuned",
# evaluation_strategy="steps",
# save_steps=1000,
# eval_steps=100,
# logging_steps=100,
# per_device_train_batch_size=16,
# per_device_eval_batch_size=16,
# num_train_epochs=3,
# save_total_limit=2,
# learning_rate=5e-5,
# predict_with_generate=True,
# fp16=torch.cuda.is_available(), # Enable mixed precision for faster training
# logging_dir="./logs",
# )
# # # Step 6: Define Trainer
# trainer = Seq2SeqTrainer(
# model=model,
# args=training_args,
# train_dataset=train_dataset,
# eval_dataset=dev_dataset,
# tokenizer=tokenizer,
# )
# # # Step 7: Train the Model
# trainer.train()
# # # Step 8: Save the Model
# trainer.save_model("./t5_sql_finetuned")
# tokenizer.save_pretrained("./t5_sql_finetuned")
# Step 9: Test the Model
test_question = "Find all orders with product_id greater than 5."
input_text = f"Question: {test_question}"
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(**inputs, max_length=128)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated SQL:", generated_sql)