# Binder: nsql/database.py
import copy
import os
import sqlite3
import records
import sqlalchemy
import pandas as pd
from typing import Dict, List
import uuid
from utils.normalizer import convert_df_type, prepare_df_for_neuraldb_from_table
from utils.mmqa.image_stuff import get_caption


def check_in_and_return(key: str, source: dict):
    # A key wrapped in backticks (``) is treated as one whole key.
if key.startswith("`") and key.endswith("`"):
key = key[1:-1]
if key in source.keys():
return source[key]
else:
for _k, _v in source.items():
if _k.lower() == key.lower():
return _v
raise ValueError("{} not in {}".format(key, source))
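
# Example (hypothetical values): lookup is exact-match first, then case-insensitive,
# and surrounding backticks are stripped, e.g.
#   check_in_and_return("`Title`", {"title": "some text"})  ->  "some text"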


class NeuralDB(object):
def __init__(self, tables: List[Dict[str, Dict]], passages=None, images=None):
self.raw_tables = copy.deepcopy(tables)
self.passages = {}
self.images = {}
self.image_captions = {}
self.passage_linker = {} # The links from cell value to passage
self.image_linker = {} # The links from cell value to images
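        # e.g. self.passage_linker["some cell text"] = "Some Passage Title" (hypothetical values)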
# Get passages
if passages:
for passage in passages:
title, passage_content = passage['title'], passage['text']
self.passages[title] = passage_content
# Get images
if images:
for image in images:
_id, title, picture = image['id'], image['title'], image['pic']
self.images[title] = picture
self.image_captions[title] = get_caption(_id)
        # Link grounding resources from other modalities (passages, images).
if self.raw_tables[0]['table'].get('rows_with_links', None):
rows = self.raw_tables[0]['table']['rows']
rows_with_links = self.raw_tables[0]['table']['rows_with_links']
link_title2cell_map = {}
for row_id in range(len(rows)):
for col_id in range(len(rows[row_id])):
cell = rows_with_links[row_id][col_id]
for text, title, url in zip(cell[0], cell[1], cell[2]):
text = text.lower().strip()
link_title2cell_map[title] = text
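            # link_title2cell_map now maps each linked passage/image title to the
            # lower-cased cell text it was linked from.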
# Link Passages
for passage in passages:
title, passage_content = passage['title'], passage['text']
linked_cell = link_title2cell_map.get(title, None)
if linked_cell:
self.passage_linker[linked_cell] = title
# Images
for image in images:
title, picture = image['title'], image['pic']
linked_cell = link_title2cell_map.get(title, None)
if linked_cell:
self.image_linker[linked_cell] = title
for table_info in tables:
table_info['table'] = prepare_df_for_neuraldb_from_table(table_info['table'])
self.tables = tables
# Connect to SQLite database
self.tmp_path = "tmp"
os.makedirs(self.tmp_path, exist_ok=True)
# self.db_path = os.path.join(self.tmp_path, '{}.db'.format(hash(time.time())))
self.db_path = os.path.join(self.tmp_path, '{}.db'.format(uuid.uuid4()))
self.sqlite_conn = sqlite3.connect(self.db_path, check_same_thread=False)
# Create DB
assert len(tables) >= 1, "DB has no table inside"
table_0 = tables[0]
if len(tables) > 1:
            raise ValueError("More than one table is not supported yet.")
else:
table_0["table"].to_sql("w", self.sqlite_conn)
self.table_name = "w"
self.table_title = table_0.get('title', None)
# Records conn
self.db = records.Database('sqlite:///{}'.format(self.db_path))
self.records_conn = self.db.get_connection()
def __str__(self):
return str(self.execute_query("SELECT * FROM {}".format(self.table_name)))
def get_table(self, table_name=None):
table_name = self.table_name if not table_name else table_name
sql_query = "SELECT * FROM {}".format(table_name)
_table = self.execute_query(sql_query)
return _table
def get_header(self, table_name=None):
_table = self.get_table(table_name)
return _table['header']
def get_rows(self, table_name):
_table = self.get_table(table_name)
return _table['rows']
def get_table_df(self):
return self.tables[0]['table']
def get_table_raw(self):
return self.raw_tables[0]['table']
def get_table_title(self):
return self.tables[0]['title']
def get_passages_titles(self):
return list(self.passages.keys())
def get_images_titles(self):
return list(self.images.keys())
def get_passage_by_title(self, title: str):
return check_in_and_return(title, self.passages)
def get_image_by_title(self, title):
return check_in_and_return(title, self.images)
def get_image_caption_by_title(self, title):
return check_in_and_return(title, self.image_captions)
def get_image_linker(self):
return copy.deepcopy(self.image_linker)
def get_passage_linker(self):
return copy.deepcopy(self.passage_linker)

    def execute_query(self, sql_query: str):
        """
        Basic operation. Execute the sql query on the database we hold.
        @param sql_query: the SQL string to run against the stored table.
        @return: a dict with "header" (column names) and "rows" (the result rows).
        """
        # When the sql query is just a single column name (@deprecated: or a single value wrapped in '' / "").
if len(sql_query.split(' ')) == 1 or (sql_query.startswith('`') and sql_query.endswith('`')):
col_name = sql_query
new_sql_query = r"SELECT row_id, {} FROM {}".format(col_name, self.table_name)
            # Hack: if the single token is a quoted literal ('' or ""), this SELECT returns
            # a column filled with that literal; if it is a bare column name, it selects that column.
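            # e.g. sql_query = "`name`"  ->  SELECT row_id, `name` FROM w
            #      sql_query = "'foo'"   ->  a column where every row is the literal 'foo'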
out = self.records_conn.query(new_sql_query)
        # When the query already selects * or col_id, there is no need to add 'row_id'.
elif sql_query.lower().startswith("select *") or sql_query.startswith("select col_id"):
out = self.records_conn.query(sql_query)
else:
try:
                # Additionally SELECT row_id, which is needed to align the result with the original table.
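                # e.g. "SELECT name FROM w" -> "SELECT row_id, name FROM w"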
new_sql_query = "SELECT row_id, " + sql_query[7:]
out = self.records_conn.query(new_sql_query)
except sqlalchemy.exc.OperationalError as e:
                # Fall back to executing the SQL as-is; row_id is not needed in this case.
out = self.records_conn.query(sql_query)
results = out.all()
unmerged_results = []
merged_results = []
headers = out.dataset.headers
for i in range(len(results)):
unmerged_results.append(list(results[i].values()))
merged_results.extend(results[i].values())
return {"header": headers, "rows": unmerged_results}

    def add_sub_table(self, sub_table, table_name=None, verbose=True):
        """
        Add sub_table into the table by left-joining it on row_id.
        @return: None; the stored table is replaced in place.
        """
table_name = self.table_name if not table_name else table_name
sql_query = "SELECT * FROM {}".format(table_name)
        orig_table = self.execute_query(sql_query)
        old_table = pd.DataFrame(orig_table["rows"], columns=orig_table["header"])
        # Concatenate the new column(s) into the old table via a left join on row_id.
sub_table_df_normed = convert_df_type(pd.DataFrame(data=sub_table['rows'], columns=sub_table['header']))
new_table = old_table.merge(sub_table_df_normed,
how='left', on='row_id') # do left join
new_table.to_sql(table_name, self.sqlite_conn, if_exists='replace',
index=False)
if verbose:
print("Insert column(s) {} (dtypes: {}) into table.\n".format(', '.join([_ for _ in sub_table['header']]),
sub_table_df_normed.dtypes))
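

# A minimal usage sketch (not part of the original module). It assumes that
# prepare_df_for_neuraldb_from_table accepts a WikiTQ-style table dict with
# "header" and "rows" keys and that it produces the row_id column the SQL
# helpers above rely on; the table data below is purely hypothetical.
if __name__ == "__main__":
    example_table = {
        "title": "Example table",
        "table": {
            "header": ["name", "year"],
            "rows": [["Alice", "2019"], ["Bob", "2021"]],
        },
    }
    db = NeuralDB(tables=[example_table])
    print(db.get_table_df())
    print(db.execute_query("SELECT name FROM w"))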