Spaces:
Runtime error
Runtime error
import copy | |
import os | |
import sqlite3 | |
import records | |
import sqlalchemy | |
import pandas as pd | |
from typing import Dict, List | |
import uuid | |
from utils.normalizer import convert_df_type, prepare_df_for_neuraldb_from_table | |
from utils.mmqa.image_stuff import get_caption | |
def check_in_and_return(key: str, source: dict):
    """
    Fetch the value stored under `key` in `source`, tolerating case differences.

    A key wrapped in backticks (e.g. "`Some Title`") is treated as a whole and the
    backticks are removed before the lookup. An exact match is preferred; if there is
    none, the first entry whose key equals `key` ignoring case is returned.

    @param key: the name to look up, optionally wrapped in backticks.
    @param source: the mapping to search.
    @return: the value stored under the matching key.
    @raise ValueError: when no key matches, exactly or case-insensitively.
    """
    # `` wrapped means as a whole. The length guard keeps a lone "`" from being
    # stripped down to the empty string (it both starts and ends with a backtick).
    if len(key) >= 2 and key.startswith("`") and key.endswith("`"):
        key = key[1:-1]

    if key in source:
        return source[key]

    # Fall back to a case-insensitive scan; lower the needle once, not per entry.
    lowered_key = key.lower()
    for candidate, value in source.items():
        # Non-string keys can never match a string key and have no .lower().
        if isinstance(candidate, str) and candidate.lower() == lowered_key:
            return value

    raise ValueError("{} not in {}".format(key, source))
class NeuralDB(object):
    """
    A single table loaded into a throw-away SQLite database, together with the
    passages and images that its cells link to.

    The table is stored under the fixed SQL name "w". Queries go through a `records`
    connection opened on the same database file; writes go through the raw `sqlite3`
    connection.
    """

    def __init__(self, tables: List[Dict[str, Dict]], passages=None, images=None):
        """
        Build the database from exactly one table and optional grounding resources.

        NOTE: each entry of `tables` is modified in place: its 'table' value is replaced
        by the result of `prepare_df_for_neuraldb_from_table`. An untouched deep copy is
        kept in `self.raw_tables`.

        @param tables: a list holding one dict with a 'table' entry and, optionally, a 'title'.
        @param passages: optional iterable of dicts with 'title' and 'text'.
        @param images: optional iterable of dicts with 'id', 'title' and 'pic'.
        @raise ValueError: when `tables` is empty or holds more than one table.
        """
        # Validate first, so that a bad call neither mutates the caller's tables nor
        # leaves a database file and an open connection behind.
        if not tables:
            raise ValueError("DB has no table inside")
        if len(tables) > 1:
            raise ValueError("More than one table not support yet.")

        # None and "empty" mean the same thing; normalising once lets every loop below
        # run unconditionally (the linking loops used to fail on None).
        passages = passages or []
        images = images or []

        self.raw_tables = copy.deepcopy(tables)
        self.passages = {}
        self.images = {}
        self.image_captions = {}
        self.passage_linker = {}  # The links from cell value to passage
        self.image_linker = {}  # The links from cell value to images

        # Get passages
        for passage in passages:
            self.passages[passage['title']] = passage['text']

        # Get images
        for image in images:
            _id, title, picture = image['id'], image['title'], image['pic']
            self.images[title] = picture
            self.image_captions[title] = get_caption(_id)

        # Link grounding resources from other modalities(passages, images).
        raw_table = self.raw_tables[0]['table']
        rows_with_links = raw_table.get('rows_with_links', None)
        if rows_with_links:
            # Each linked cell is indexed as (texts, titles, urls); map every link title
            # to the normalised (lower-cased, stripped) text it was attached to.
            # A title that occurs several times keeps the last text seen.
            link_title2cell_map = {}
            for row_id, row in enumerate(raw_table['rows']):
                for col_id in range(len(row)):
                    cell = rows_with_links[row_id][col_id]
                    for text, title, _url in zip(cell[0], cell[1], cell[2]):
                        link_title2cell_map[title] = text.lower().strip()

            # Link Passages
            for passage in passages:
                title = passage['title']
                linked_cell = link_title2cell_map.get(title, None)
                if linked_cell:
                    self.passage_linker[linked_cell] = title

            # Images
            for image in images:
                title = image['title']
                linked_cell = link_title2cell_map.get(title, None)
                if linked_cell:
                    self.image_linker[linked_cell] = title

        for table_info in tables:
            table_info['table'] = prepare_df_for_neuraldb_from_table(table_info['table'])
        self.tables = tables

        # Connect to SQLite database; a random file name keeps instances apart.
        self.tmp_path = "tmp"
        os.makedirs(self.tmp_path, exist_ok=True)
        self.db_path = os.path.join(self.tmp_path, '{}.db'.format(uuid.uuid4()))
        self.sqlite_conn = sqlite3.connect(self.db_path, check_same_thread=False)

        # Create DB
        table_0 = tables[0]
        # NOTE(review): the queries below rely on a `row_id` column; presumably the
        # prepared frame provides it -- confirm against prepare_df_for_neuraldb_from_table.
        table_0["table"].to_sql("w", self.sqlite_conn)
        self.table_name = "w"
        self.table_title = table_0.get('title', None)

        # Records conn
        self.db = records.Database('sqlite:///{}'.format(self.db_path))
        self.records_conn = self.db.get_connection()

    def __str__(self):
        """Return the string form of the full content of the main table."""
        return str(self.execute_query("SELECT * FROM {}".format(self.table_name)))

    def get_table(self, table_name=None):
        """
        Read a whole table.
        @param table_name: SQL table name; defaults to the main table.
        @return: a dict with 'header' and 'rows', as produced by `execute_query`.
        """
        table_name = self.table_name if not table_name else table_name
        return self.execute_query("SELECT * FROM {}".format(table_name))

    def get_header(self, table_name=None):
        """
        @param table_name: SQL table name; defaults to the main table.
        @return: the column names of the table.
        """
        return self.get_table(table_name)['header']

    def get_rows(self, table_name=None):
        """
        @param table_name: SQL table name; defaults to the main table.
        @return: the rows of the table, each as a list of values.
        """
        return self.get_table(table_name)['rows']

    def get_table_df(self):
        """Return the prepared 'table' object of the first table."""
        return self.tables[0]['table']

    def get_table_raw(self):
        """Return the 'table' entry of the first table as it was passed in (deep-copied)."""
        return self.raw_tables[0]['table']

    def get_table_title(self):
        """Return the title of the first table, or None when it has none."""
        # The constructor accepts a table without a title, so do the same here.
        return self.tables[0].get('title', None)

    def get_passages_titles(self):
        """Return the titles of all known passages."""
        return list(self.passages.keys())

    def get_images_titles(self):
        """Return the titles of all known images."""
        return list(self.images.keys())

    def get_passage_by_title(self, title: str):
        """Return the passage text stored under `title` (see `check_in_and_return`)."""
        return check_in_and_return(title, self.passages)

    def get_image_by_title(self, title):
        """Return the picture stored under `title` (see `check_in_and_return`)."""
        return check_in_and_return(title, self.images)

    def get_image_caption_by_title(self, title):
        """Return the caption stored under `title` (see `check_in_and_return`)."""
        return check_in_and_return(title, self.image_captions)

    def get_image_linker(self):
        """Return a copy of the cell-text -> image-title links."""
        return copy.deepcopy(self.image_linker)

    def get_passage_linker(self):
        """Return a copy of the cell-text -> passage-title links."""
        return copy.deepcopy(self.passage_linker)

    def execute_query(self, sql_query: str):
        """
        Basic operation. Execute the sql query on the database we hold.

        Three shapes of input are recognised:
          * a bare column name (one token, or wrapped in backticks) is expanded to
            "SELECT row_id, <column> FROM <main table>";
          * a query starting with "select *" or "select col_id" (any letter case) is
            executed unchanged;
          * any other query gets `row_id` prepended to its select list, and is retried
            unchanged if the database rejects the rewritten form.

        @param sql_query: a column name or a SQL SELECT statement.
        @return: a dict with 'header' (column names) and 'rows' (list of value lists).
        """
        is_bare_column = len(sql_query.split(' ')) == 1 or (
            sql_query.startswith('`') and sql_query.endswith('`')
        )
        if is_bare_column:
            # Here we use a hack that when a value is surrounded by '' or "", the sql will
            # return a column of the value, while for variable, no ''/"" surrounded, this
            # sql will query for the column.
            out = self.records_conn.query(
                "SELECT row_id, {} FROM {}".format(sql_query, self.table_name)
            )
        elif sql_query.lower().startswith(("select *", "select col_id")):
            # The query wants all cols or col_id, so there is no need to add 'row_id'.
            # Both prefixes are matched case-insensitively.
            out = self.records_conn.query(sql_query)
        else:
            try:
                # SELECT row_id in addition, needed for result and old table alignment.
                # The slice drops the first 7 characters, i.e. a leading "SELECT ".
                out = self.records_conn.query("SELECT row_id, " + sql_query[7:])
            except sqlalchemy.exc.OperationalError:
                # Execute normal SQL, and in this case the row_id is actually in no need.
                out = self.records_conn.query(sql_query)

        results = out.all()
        headers = out.dataset.headers
        rows = [list(record.values()) for record in results]
        return {"header": headers, "rows": rows}

    def add_sub_table(self, sub_table, table_name=None, verbose=True):
        """
        Add sub_table into the table.

        The sub table is left-joined onto the stored table on `row_id`, and the stored
        table is then replaced by the joined result.

        @param sub_table: a dict with 'header' and 'rows'; must contain a `row_id` column.
        @param table_name: SQL table name; defaults to the main table.
        @param verbose: when true, print the inserted columns and their dtypes.
        @return:
        """
        table_name = self.table_name if not table_name else table_name
        stored = self.get_table(table_name)
        old_table = pd.DataFrame(stored["rows"], columns=stored["header"])

        # concat the new column into old table
        sub_table_df_normed = convert_df_type(
            pd.DataFrame(data=sub_table['rows'], columns=sub_table['header'])
        )
        new_table = old_table.merge(sub_table_df_normed, how='left', on='row_id')  # do left join
        new_table.to_sql(table_name, self.sqlite_conn, if_exists='replace', index=False)

        if verbose:
            print("Insert column(s) {} (dtypes: {}) into table.\n".format(
                ', '.join(sub_table['header']), sub_table_df_normed.dtypes))