import argparse
import json

# `process.extractOne` below needs a fuzzy string-matching backend; fuzzywuzzy
# is assumed here, since it provides exactly this interface.
from fuzzywuzzy import process

from utils.sql.process_sql import (
    tokenize, CLAUSE_KEYWORDS, WHERE_OPS, COND_OPS, UNIT_OPS, AGG_OPS,
    JOIN_KEYWORDS, ORDER_OPS, skip_semicolon, SQL_OPS)

KEPT_WHERE_OP = ('not', 'in', 'exists')


def parse_table_unit(toks, start_idx, tables_with_alias):
    idx = start_idx
    len_ = len(toks)
    key = toks[idx]

    if idx + 1 < len_ and toks[idx + 1] == "as":
        tables_with_alias[toks[idx + 2]] = toks[idx]
        idx += 3
    else:
        idx += 1

    return idx, key


def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, column id
    """
    tok = toks[start_idx]
    if tok == "*":
        return start_idx + 1

    if '.' in tok:  # if token is a composite
        alias, col = tok.split('.')
        # key = tables_with_alias[alias] + "." + col
        table = tables_with_alias[alias]
        """ Add schema """
        if table not in schema:
            schema[table] = []
        schema[table].append(col)
        # We also want to normalize the column
        toks[start_idx] = "{}.{}".format(table, col)
        """ END """
        return start_idx + 1

    assert default_tables is not None and len(default_tables) > 0, \
        "Default tables should not be None or empty"
    # assert len(default_tables) == 1, "Default tables should only have one table"

    """ Add schema """
    # Find the best table here
    def choose_best_table(default_tables, tok):
        lower_tok = tok.lower()
        candidate = process.extractOne(
            lower_tok, [table.lower() for table in default_tables])[0]
        return candidate

    if len(default_tables) != 1:
        # print(default_tables)
        table = choose_best_table(default_tables, tok)
        # assert len(default_tables) == 1, "Default tables should only have one table"
    else:
        table = default_tables[0]
    if table not in schema:
        schema[table] = []
    schema[table].append(tok)
    toks[start_idx] = "{}.{}".format(table, tok)
    return start_idx + 1

    # for alias in default_tables:
    #     table = tables_with_alias[alias]
    #     if tok in schema.schema[table]:
    #         key = table + "." + tok
    #         return start_idx + 1, schema.idMap[key]
    # assert False, "Error col: {}".format(tok)
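
# Minimal usage sketch for parse_col. The _demo_* helper below is added for
# illustration only and is not part of the original module: a bare column
# token is qualified against the single default table, recorded in the
# growing schema dict, and normalized in place.
def _demo_parse_col():
    schema = {}
    toks = ["name"]
    parse_col(toks, 0, tables_with_alias={}, schema=schema,
              default_tables=["singer"])
    assert toks == ["singer.name"]
    assert schema == {"singer": ["name"]}
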
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None, end_idx=None):
    """
    :returns next idx, (agg_op id, col_id)
    """
    idx = start_idx
    # The bound must stay an absolute index into toks; the original
    # len(toks[start_idx:end_idx]) computed a relative length, which breaks
    # the `idx < len_` checks below whenever start_idx > 0.
    if end_idx is not None:
        len_ = end_idx
    else:
        len_ = len(toks)
    isBlock = False
    isDistinct = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] in AGG_OPS:
        agg_id = AGG_OPS.index(toks[idx])
        idx += 1
        assert idx < len_ and toks[idx] == '('
        idx += 1
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        assert idx < len_ and toks[idx] == ')'
        idx += 1
        return idx

    if toks[idx] == "distinct":
        idx += 1
        isDistinct = True
    agg_id = AGG_OPS.index("none")
    idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'

    return idx


def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    col_unit1 = None
    col_unit2 = None
    unit_op = UNIT_OPS.index('none')

    idx = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
    if idx < len_ and toks[idx] in UNIT_OPS:
        unit_op = UNIT_OPS.index(toks[idx])
        idx += 1
        idx = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'

    return idx


def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] == 'select':
        idx = parse_sql(toks, idx, schema)
    elif "\"" in toks[idx]:  # token is a string value
        val = toks[idx]
        # Replace with placeholder
        toks[idx] = "_str_value_"
        idx += 1
    else:
        try:
            val = float(toks[idx])
            toks[idx] = "_num_value_"
            idx += 1
        except ValueError:
            # Not a literal: the value slot holds a column reference.
            end_idx = idx
            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')' \
                    and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS \
                    and toks[end_idx] not in JOIN_KEYWORDS:
                end_idx += 1

            # idx = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
            idx = parse_col_unit(toks, start_idx, tables_with_alias, schema,
                                 default_tables, end_idx=end_idx)
            idx = end_idx

    if isBlock:
        assert toks[idx] == ')'
        idx += 1

    return idx
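
# Minimal usage sketch for parse_value (illustrative helper, not part of the
# original module): literals are overwritten in place with type placeholders,
# which is what makes the extracted schemas and templates value-agnostic.
def _demo_parse_value():
    toks = ['"john"']  # a quoted string arrives as a single token
    idx = parse_value(toks, 0, tables_with_alias={}, schema={},
                      default_tables=["singer"])
    assert toks[0] == "_str_value_" and idx == 1

    toks = ['20']
    idx = parse_value(toks, 0, tables_with_alias={}, schema={},
                      default_tables=["singer"])
    assert toks[0] == "_num_value_" and idx == 1
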
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    # conds = []

    while idx < len_:
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        not_op = False
        if toks[idx] == 'not':
            not_op = True
            idx += 1

        assert idx < len_ and toks[idx] in WHERE_OPS, \
            "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1
        val1 = val2 = None
        if op_id == WHERE_OPS.index('between'):  # between..and.. special case: dual values
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'and'
            idx += 1
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:  # normal case: single value
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            val2 = None

        # conds.append((not_op, op_id, val_unit, val1, val2))

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")
                           or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            # conds.append(toks[idx])
            idx += 1  # skip and/or

    return idx  # , conds


def parse_from(toks, start_idx, schema):
    assert 'from' in toks[start_idx:], "'from' not found"
    tables_with_alias = {}

    len_ = len(toks)
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    table_units = []
    conds = []

    # print(idx, len_)
    while idx < len_:
        # print("idx", idx, toks[idx])
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1

        if toks[idx] == 'select':
            idx = parse_sql(toks, idx, schema)
            # table_units.append((TABLE_TYPE['sql'], sql))
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1  # skip join
            idx, table_name = parse_table_unit(toks, idx, tables_with_alias)
            # print(table_name)
            # table_units.append((TABLE_TYPE['table_unit'], table_unit))
            default_tables.append(table_name)
            """ Add schema """
            if table_name not in schema:
                schema[table_name] = []
            """ END """
        if idx < len_ and toks[idx] == "on":
            idx += 1  # skip on
            idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
            # if len(conds) > 0:
            #     conds.append('and')
            # conds.extend(this_conds)

        if isBlock:
            assert toks[idx] == ')'
            idx += 1

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break

    return idx, default_tables, tables_with_alias


def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    assert toks[idx] == 'select', "'select' not found"
    idx += 1
    isDistinct = False
    if idx < len_ and toks[idx] == 'distinct':
        idx += 1
        isDistinct = True
    val_units = []

    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = AGG_OPS.index("none")
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        # val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','

    return idx


def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'where':
        return idx

    idx += 1
    idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx


def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    col_units = []

    if idx >= len_ or toks[idx] != 'group':
        return idx

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        # col_units.append(col_unit)
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx


def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'having':
        return idx

    idx += 1
    idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx
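
# Minimal usage sketch for parse_from (illustrative helper, not part of the
# original module; assumes a Spider-style tokenize that lowercases and keeps
# alias.column as one token): it returns the index just past the FROM clause
# together with the default tables and alias map the other clause parsers
# rely on, and records any columns seen in ON conditions into the schema.
def _demo_parse_from():
    schema = {}
    toks = tokenize("select t1.name from singer as t1 "
                    "join concert as t2 on t1.id = t2.singer_id")
    _, default_tables, tables_with_alias = parse_from(toks, 0, schema)
    assert default_tables == ["singer", "concert"]
    assert tables_with_alias == {"t1": "singer", "t2": "concert"}
    assert schema == {"singer": ["id"], "concert": ["singer_id"]}
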
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    val_units = []
    order_type = 'asc'  # default type is 'asc'

    if idx >= len_ or toks[idx] != 'order':
        return idx

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        # val_units.append(val_unit)
        if idx < len_ and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx


def parse_limit(toks, start_idx):
    idx = start_idx
    len_ = len(toks)

    if idx < len_ and toks[idx] == 'limit':
        idx += 2
        # Replace the limit literal with a placeholder; we cannot assume a
        # fake constant such as 1 here.
        toks[idx - 1] = "_limit_value_"

    return idx


def parse_sql(toks, start_idx, schema):
    isBlock = False  # indicates whether this is a block of sql/sub-sql
    len_ = len(toks)
    idx = start_idx

    if toks[idx] == '(':
        isBlock = True
        idx += 1

    from_end_idx, default_tables, tables_with_alias = parse_from(toks, start_idx, schema)
    _ = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx
    idx = parse_where(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_having(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_limit(toks, idx)

    # idx = skip_semicolon(toks, idx)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    idx = skip_semicolon(toks, idx)

    # for op in SQL_OPS:  # initialize IUE
    #     sql[op] = None
    if idx < len_ and toks[idx] in SQL_OPS:
        sql_op = toks[idx]
        idx += 1
        idx = parse_sql(toks, idx, schema)
        # sql[sql_op] = IUE_sql

    return idx


def extract_schema_from_sql(schema, sql):
    toks = tokenize(sql)
    parse_sql(toks=toks, start_idx=0, schema=schema)
    return toks


def extract_template_from_sql(sql, schema={}):
    try:
        toks = tokenize(sql)
    except Exception:
        print("Tokenization error for {}".format(sql))
        toks = []
    # print(toks)
    template = []
    # ignore_follow_up_and = False
    len_ = len(toks)
    idx = 0
    while idx < len_:
        tok = toks[idx]
        if tok == "from":
            template.append(tok)
            if toks[idx + 1] != "(":
                template.append("[FROM_PART]")
                idx += 1
                while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
                    idx += 1
                continue
        elif tok in CLAUSE_KEYWORDS:
            template.append(tok)
        elif tok in AGG_OPS:
            template.append(tok)
        elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
            template.append(tok)
        elif tok in ["asc", "desc"]:
            template.append("[ORDER_DIRECTION]")
        elif tok in WHERE_OPS:
            if tok in KEPT_WHERE_OP:
                template.append(tok)
            else:
                template.append("[WHERE_OP]")
            if tok == "between":
                idx += 2
        elif tok in COND_OPS:
            template.append(tok)
        elif template[-1] == "[WHERE_OP]":
            template.append("[VALUE]")
        elif template[-1] == "limit":
            template.append("[LIMIT_VALUE]")
        elif template[-1] != "[MASK]":  # value, schema, join on as
            template.append("[MASK]")
        idx += 1
    return template
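
# Minimal usage sketch for extract_template_from_sql (illustrative, not part
# of the original module; assumes a Spider-style tokenize): SQL keywords
# survive, while everything schema- or value-specific is masked.
def _demo_extract_template():
    template = extract_template_from_sql("select * from singer where age = 20")
    assert template == ['select', '*', 'from', '[FROM_PART]', 'where',
                        '[MASK]', '[WHERE_OP]', '[VALUE]']
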
def extract_partial_template_from_sql(sql, schema={}):
    toks = tokenize(sql)
    # print(toks)
    template = []
    # ignore_follow_up_and = False
    len_ = len(toks)
    idx = 0
    while idx < len_:
        tok = toks[idx]
        if tok == "from":
            template.append(tok)
            if toks[idx + 1] != "(":
                # template.append("[FROM_PART]")
                idx += 1
                while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
                    template.append(toks[idx])
                    idx += 1
                continue
        elif tok in CLAUSE_KEYWORDS:
            template.append(tok)
        elif tok in AGG_OPS:
            template.append(tok)
        elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
            template.append(tok)
        elif tok in ["asc", "desc"]:
            template.append("[ORDER_DIRECTION]")
        elif tok in WHERE_OPS:
            if tok in KEPT_WHERE_OP:
                template.append(tok)
            else:
                template.append("[WHERE_OP]")
            if tok == "between":
                idx += 2
        elif tok in COND_OPS:
            template.append(tok)
        elif template[-1] == "[WHERE_OP]":
            template.append("[VALUE]")
        elif template[-1] == "limit":
            template.append("[LIMIT_VALUE]")
        else:
            template.append(tok)
        idx += 1
    return template


def is_valid_schema(schema):
    # No "." or whitespace is allowed in a table or column name, and a table
    # name must not collide with a SQL clause keyword.
    for table in schema:
        if "." in table:
            return False
        if any([keyword == table for keyword in CLAUSE_KEYWORDS]):
            return False
        for column in schema[table]:
            if "." in column or " " in column or '"' in column or "'" in column:
                return False
    return True


def clean_sql(sql):
    while "JOIN JOIN" in sql:
        sql = sql.replace("JOIN JOIN", "JOIN")
    if "JOIN WHERE" in sql:
        sql = sql.replace("JOIN WHERE", "WHERE")
    if "JOIN GROUP BY" in sql:
        sql = sql.replace("JOIN GROUP BY", "GROUP BY")
    return sql


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str)
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--mode", type=str, choices=["debug", "verbose", "silent"])
    parser.add_argument("--task", type=str, choices=["template_extraction", "schema_extraction"])
    args = parser.parse_args()

    if args.task == "schema_extraction":
        if args.mode == "debug":
            # Each assignment below overrides the previous one; keep the
            # query you want to debug last.
            sql = "SELECT count(*) FROM games"
            sql = sql + " INTERSECT " + "SELECT sacks, year FROM players"
            sql = sql + " EXCEPT " + 'SELECT T1.year, T1.sacks FROM players AS T1 JOIN tackles AS T2 ON T1.id = T2.player_id WHERE T2.manager = "A" and T2.season NOT IN (SELECT season FROM match WHERE match_name = "IVL" INTERSECT SELECT T1.year, T1.sacks FROM sack AS T1) GROUP BY T1.year, T1.sacks HAVING count(T1.coach) > 10 ORDER BY T2.score LIMIT 5'
            sql = "SELECT T1.pld FROM pld AS T1 JOIN games AS T2 ON T1.crs_code = T2.crs_code JOIN GROUP BY T1.pld WHERE T2.gf = '8' AND T2.gf = '9'"
            sql = 'select * from head where height = "6-0" or height = "6-0" order by height asc'
            schema = {}
            extract_schema_from_sql(schema, sql)
            print(schema, is_valid_schema(schema))
        elif args.mode == "verbose":
            fout = open(args.output_file, "w")
            with open(args.input_file) as fin:
                for line in fin:
                    example = json.loads(line)
                    schema = {}
                    try:
                        sql = example["sql"] if "sql" in example else example["pred"]
                        sql = clean_sql(sql)
                        example["sql"] = sql
                        extract_schema_from_sql(schema, sql)
                    except Exception:
                        # print(sql)
                        continue
                    for table in schema:
                        schema[table] = list(set(schema[table]))
                    if is_valid_schema(schema):
                        example["extracted_schema"] = schema
                        fout.write(json.dumps(example) + "\n")
        elif args.mode == "silent":
            # Like verbose, but parsing errors are not caught and will raise.
            fout = open(args.output_file, "w")
            with open(args.input_file) as fin:
                for line in fin:
                    example = json.loads(line)
                    schema = {}
                    sql = example["sql"] if "sql" in example else example["pred"]
                    sql = clean_sql(sql)
                    example["sql"] = sql
                    extract_schema_from_sql(schema, sql)
                    for table in schema:
                        schema[table] = list(set(schema[table]))
                    if is_valid_schema(schema):
                        example["extracted_schema"] = schema
                        fout.write(json.dumps(example) + "\n")
    elif args.task == "template_extraction":
        if args.mode == "debug":
            sql = "SELECT avg(T1.Votes) FROM seats AS T1 JOIN votes AS T2 ON T1.Seat_ID = T2.Seat_ID WHERE T1.seats BETWEEN 1 AND 2 and T1.Seats = 1 AND T2.Votes = 10"
            print(extract_template_from_sql(sql))
            print(extract_partial_template_from_sql(sql))
".json", "w") fout_txt = open(args.output_file + ".txt", "w") low_freq_txt = open(args.output_file + ".low_freq", "w") high_freq_txt = open(args.output_file + ".high_freq", "w") all_templates = set() # for input_file in args.input_file.split(","): templates = {} with open(args.input_file) as fin: for line in fin: example = json.loads(line) sql = example["sql"] if "sql" in example else example["pred"] if isinstance(sql, list): sql = sql[-1] template = extract_template_from_sql(sql) template_str = " ".join(template) if template_str not in templates: templates[template_str] = [] templates[template_str].append(sql) print("{} has template {}".format(args.input_file, len(templates))) json.dump(templates, fout_json) for template in sorted(templates.keys()): if len(templates[template]) > 1: high_freq_txt.write(template + "\n") else: low_freq_txt.write(template + "\n") fout_txt.write(template + "\n")