This is an upgraded version of https://huggingface.co/juierror/flan-t5-text2sql-with-schema.
Unlike the original, it can generate the '<' (less-than) sign and handle schemas with multiple tables.
How to use
from typing import Dict, List

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hub id of the fine-tuned Flan-T5 text-to-SQL checkpoint. The tokenizer and the
# model are both loaded from this single name so their vocabularies always match.
MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def get_prompt(tables, question):
    """Build the instruction prompt the text-to-SQL model is fed.

    Args:
        tables: Schema text to embed in the prompt; ``prepare_input`` passes a
            string such as ``"people_name(id,name), people_age(people_id,age)"``.
        question: Natural-language question to translate into SQL.

    Returns:
        The prompt string, in the exact wording the model was fine-tuned with.
    """
    return (
        "convert question and table into SQL query. "
        f"tables: {tables}. question: {question}"
    )
def prepare_input(question: str, tables: Dict[str, List[str]], max_length: int = 512):
    """Tokenize a question together with its table schema into model input ids.

    Each table is rendered as ``name(col1,col2,...)`` and the tables are joined
    with ``", "``; for example ``{"people_name": ["id", "name"]}`` becomes
    ``"people_name(id,name)"``.

    Args:
        question: Natural-language question to translate into SQL.
        tables: Mapping of table name to the list of its column names.
        max_length: Maximum number of prompt tokens; longer prompts are
            truncated. Defaults to 512.

    Returns:
        The tokenizer's ``input_ids`` as a PyTorch tensor (``return_tensors="pt"``).
    """
    # Keep the formatted schema in its own local instead of rebinding the
    # ``tables`` parameter from dict -> list -> str.
    schema = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema, question)
    # truncation=True makes the cut at max_length explicit. With only max_length
    # set, transformers warns that truncation was not explicitly activated and
    # falls back to truncating anyway.
    encoded = tokenizer(
        prompt,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
def inference(
    question: str,
    tables: Dict[str, List[str]],
    *,
    num_beams: int = 10,
    max_length: int = 512,
) -> str:
    """Translate a natural-language question into an SQL query.

    Args:
        question: Natural-language question, e.g. ``"how many people with name jui"``.
        tables: Mapping of table name to the list of its column names.
        num_beams: Beam width passed to ``model.generate``. Defaults to 10.
        max_length: ``max_length`` passed to ``model.generate``. Defaults to 512.

    Returns:
        The generated SQL query, decoded with special tokens removed.
    """
    input_ids = prepare_input(question=question, tables=tables).to(model.device)
    # top_k is deliberately not passed. It only affects sampling (do_sample=True),
    # and this call requests plain beam search, so top_k=10 was a no-op that newer
    # transformers versions warn about.
    # NOTE(review): if the checkpoint's generation_config enables do_sample,
    # reconsider adding top_k back.
    outputs = model.generate(inputs=input_ids, num_beams=num_beams, max_length=max_length)
    # outputs[0] is the generated sequence for the single input question.
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
# Example 1: a question over a schema made of two tables.
print(inference("how many people with name jui and age less than 25", {
    "people_name": ["id", "name"],
    "people_age": ["people_id", "age"],
}))

# Example 2: a similar question over a schema with a single table.
print(inference("what is id with name jui and age less than 25", {
    "people_name": ["id", "name", "age"],
}))
Dataset
- Downloads last month: 643
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.