import json
import os

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def load_table_schemas(tables_file):
    """
    Load table schemas from a tables.jsonl file.

    Args:
        tables_file: Path to the tables.jsonl file.

    Returns:
        A dictionary mapping table IDs to their column names.
    """
    table_schemas = {}
    with open(tables_file, "r") as f:
        for line in f:
            table_data = json.loads(line)
            table_id = table_data["id"]
            table_columns = table_data["header"]
            table_schemas[table_id] = table_columns
    return table_schemas


def load_wikisql(data_dir):
    """
    Load WikiSQL data and prepare it for training.

    Args:
        data_dir: Path to the WikiSQL dataset directory.

    Returns:
        A tuple (train_examples, dev_examples), each a list of dicts
        with "input" and "target" text.
    """
    def parse_file(file_path):
        with open(file_path, "r") as f:
            return [json.loads(line) for line in f]

    train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
    dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))

    # Map table IDs to column names for each split.
    train_schemas = load_table_schemas(os.path.join(data_dir, "train.tables.jsonl"))
    dev_schemas = load_table_schemas(os.path.join(data_dir, "dev.tables.jsonl"))

    def format_data(data, split):
        formatted = []
        for item in data:
            table_id = item["table_id"]
            table_columns = train_schemas[table_id] if split == "train" else dev_schemas[table_id]
            question = item["question"]
            sql_query = sql_to_text(item["sql"], table_columns)
            formatted.append({"input": f"Question: {question}", "target": sql_query})
        return formatted

    return format_data(train_data, "train"), format_data(dev_data, "dev")


def sql_to_text(sql, table_columns):
    """
    Convert a WikiSQL SQL dictionary to a text representation.

    Args:
        sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
        table_columns: List of column names for the table.

    Returns:
        SQL query as a string.
    """
    agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
    operators = ["=", ">", "<"]

    # SELECT clause: wrap the selected column in an aggregate when one is given.
    sel_column = table_columns[sql["sel"]]
    agg_func = agg_functions[sql["agg"]]
    select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"

    # WHERE clause: each condition is a (column index, operator index, value) triple.
    if sql["conds"]:
        conditions = []
        for cond in sql["conds"]:
            col_idx, operator, value = cond
            col_name = table_columns[col_idx]
            conditions.append(f"{col_name} {operators[operator]} '{value}'")
        where_clause = " WHERE " + " AND ".join(conditions)
    else:
        where_clause = ""

    return select_clause + where_clause


def tokenize_data(data, tokenizer, max_length=128):
    """
    Tokenize the input and target text.

    Args:
        data: List of examples with "input" and "target".
        tokenizer: Pretrained tokenizer.
        max_length: Maximum sequence length for the model.

    Returns:
        Tokenized dataset.
    """
    inputs = [item["input"] for item in data]
    targets = [item["target"] for item in data]

    tokenized = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenizer(
        targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    label_ids = labels["input_ids"]
    # Mask padding tokens in the labels with -100 so the loss ignores them.
    label_ids[label_ids == tokenizer.pad_token_id] = -100
    tokenized["labels"] = label_ids
    return tokenized


# Load the pretrained model and tokenizer.
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Load and format the WikiSQL data.
data_dir = "data"
train_data, dev_data = load_wikisql(data_dir)

# Tokenize and wrap in Hugging Face Datasets.
train_dataset = tokenize_data(train_data, tokenizer)
dev_dataset = tokenize_data(dev_data, tokenizer)

train_dataset = Dataset.from_dict(train_dataset)
dev_dataset = Dataset.from_dict(dev_dataset)

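# --- Training ---
# The script imports Seq2SeqTrainer and Seq2SeqTrainingArguments but the training
# step itself is missing at this point. The block below is a minimal sketch using the
# standard Hugging Face Seq2SeqTrainer API; the output directory and all
# hyperparameters (batch size, epochs, learning rate, logging/save settings) are
# illustrative assumptions, not values from the original script.
training_args = Seq2SeqTrainingArguments(
    output_dir="./t5-wikisql",        # assumed output directory
    per_device_train_batch_size=16,   # assumed hyperparameters
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    learning_rate=3e-4,
    logging_steps=100,
    save_total_limit=1,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
)

trainer.train()
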
# Sanity check: generate SQL for an example question with the trained model.
test_question = "Find all orders with product_id greater than 5."
input_text = f"Question: {test_question}"
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

outputs = model.generate(**inputs, max_length=128)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated SQL:", generated_sql)