Spaces:
Runtime error
Runtime error
import os | |
import tempfile | |
import random | |
import string | |
from ultralyticsplus import YOLO | |
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
from process import ( | |
filter_columns, | |
extract_text_of_col, | |
prepare_cols, | |
process_cols, | |
finalize_data, | |
) | |
from file_utils import ( | |
get_img, | |
save_excel_file, | |
concat_csv, | |
convert_pdf_to_image, | |
filter_color, | |
plot, | |
delete_file, | |
) | |
def _extract_table(sub_img, cols_data, page_enumeration: int, keep_black: bool) -> pd.DataFrame:
    """OCR every detected column of one cropped table and assemble a DataFrame.

    Args:
        sub_img: The cropped table image.
        cols_data: Sorted, de-duplicated column detections; the first four
            entries of each row are used as that column's box, and entries
            0 and 2 give its horizontal extent.
        page_enumeration: Page number forwarded to ``finalize_data``.
        keep_black: When true, apply ``filter_color`` (keep only black) to
            each column crop before text extraction.

    Returns:
        The DataFrame produced by ``finalize_data``.

    Raises:
        ValueError: If no columns are available to extract from.
    """
    columns = cols_data[:, 0:4]
    # * Create list of cropped images for each column
    sub_imgs = [sub_img[:, int(column[0]): int(column[2])] for column in columns]
    if not sub_imgs:
        # Previously surfaced as a ZeroDivisionError in the averaging below.
        raise ValueError("no columns to extract text from")
    cols = []
    thr = 0
    for image in sub_imgs:
        if keep_black:
            # * keep only black color in the image
            image = filter_color(image)
        # * extract text of each column and get the length threshold
        res, threshold = extract_text_of_col(image)
        thr += threshold
        # * arrange the rows of each column with respect to row length threshold
        cols.append(prepare_cols(res, threshold * 0.6))
    # Mean of the per-column thresholds, used when placing cells into rows.
    thr = thr / len(sub_imgs)
    # * append each element in each column to its right place in the dataframe
    data = process_cols(cols, thr * 0.6)
    # * merge the related rows together
    return finalize_data(data, page_enumeration)


def process_img(
    img,
    page_enumeration: int = 0,
    filter=False,
    foldername: str = "",
    filename: str = "",
):
    """Detect the tables in one page image, OCR them, and save each one.

    Each table box from ``PaddleOCR.table_model`` is cropped from ``img``.
    Its columns are detected with ``PaddleOCR.column_model`` and drawn with
    ``st.image``. The column text is then extracted and merged into a
    DataFrame. Every successfully extracted table is written with
    ``save_excel_file``. A table whose detection or extraction fails is
    reported with a Streamlit warning and skipped.

    Args:
        img: Page image, indexed as ``img[y0:y1, x0:x1]`` (presumably an
            HxWxC RGB array — confirm against ``get_img``).
        page_enumeration: Page number forwarded to ``finalize_data`` and
            ``save_excel_file``. Callers in this file pass 0 for single
            images and 1-based numbers for PDF pages.
        filter: When true, keep only black pixels in each column crop
            before OCR.
        foldername: Directory passed through to ``save_excel_file``.
        filename: Base file name passed through to ``save_excel_file``.
    """
    detections = PaddleOCR.table_model(img, conf=0.75)
    # One (x0, y0, x1, y1) box per detected table.
    tables = detections[0].boxes.xyxy.cpu().numpy()
    results = []
    for table in tables:
        try:
            # * crop the table as an image from the original image
            sub_img = img[
                int(table[1].item()): int(table[3].item()),
                int(table[0].item()): int(table[2].item()),
            ]
            columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
            cols_data = columns_detect[0].boxes.data.cpu().numpy()
            # * Sort columns according to the x coordinate
            cols_data = np.array(
                sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray
            )
            # * merge the duplicated columns
            cols_data = filter_columns(cols_data)
            st.image(plot(sub_img, cols_data), channels="RGB")
        except Exception:
            # Skip this table. Without `continue`, extraction would run on
            # the previous table's `sub_img`/`cols_data` (or hit NameError).
            st.warning("No Detection")
            continue
        try:
            data = _extract_table(sub_img, cols_data, page_enumeration, filter)
        except Exception:
            st.warning("Text Extraction Failed")
            continue
        results.append(data)
    # * save every extracted table; its index within the page comes first
    for index, data in enumerate(results):
        save_excel_file(index, data, foldername, filename, page_enumeration)
class PaddleOCR:
    """Streamlit entry point that turns an uploaded PDF or image into a CSV.

    Despite the class name, table and column detection here use the two
    YOLO models below. Text extraction is delegated to ``process_img`` and
    the ``process`` helpers. The result is offered through a Streamlit
    download button.
    """

    # Load Image Detection model
    # NOTE: both models are loaded once, when this class body executes
    # (i.e. when the module is imported).
    table_model = YOLO("table.pt")
    column_model = YOLO("columns.pt")

    def __call__(self, uploaded, filter=False):
        """Extract the tables from ``uploaded`` and offer them as a CSV download.

        Args:
            uploaded: Uploaded file exposing ``.name`` and ``.read()``
                (presumably a Streamlit ``UploadedFile`` — confirm against
                the caller). Files whose extension is ``.pdf``
                (case-insensitive) are rasterised page by page. Anything
                else is loaded as a single image via ``get_img``.
            filter: When true, keep only black pixels before OCR.
        """
        # splitext keeps multi-dot names intact ("report.v2.pdf" -> ".pdf")
        # and does not raise for names without an extension.
        stem, extension = os.path.splitext(uploaded.name)
        temp_dir = tempfile.TemporaryDirectory(dir=os.getcwd())
        # `with` guarantees the scratch directory is removed even if
        # processing raises.
        with temp_dir:
            if extension.lower() == ".pdf":
                pdf_pages = convert_pdf_to_image(uploaded.read())
                for page_enumeration, page in enumerate(pdf_pages, start=1):
                    process_img(
                        np.asarray(page),
                        page_enumeration,
                        filter=filter,
                        foldername=temp_dir.name,
                        filename=stem,
                    )
            else:
                img = get_img(uploaded)
                process_img(
                    img,
                    filter=filter,
                    foldername=temp_dir.name,
                    filename=stem,
                )
            # * concatenate all csv files if many
            # Random suffix keeps concurrent uploads of the same file apart.
            extra = "".join(random.choices(string.ascii_uppercase, k=5))
            filename = f"{stem}_{extra}.csv"
            try:
                # NOTE(review): this passes the TemporaryDirectory object
                # itself, not `.name` as process_img receives — confirm
                # concat_csv expects that.
                concat_csv(temp_dir, filename)
            except Exception:
                st.warning("No results found")
        if os.path.exists(filename):
            with open(filename, "rb") as fp:
                st.download_button(
                    label="Download CSV file",
                    data=fp,
                    file_name=filename,
                    mime="text/csv",
                )
            delete_file(filename)
        else:
            st.warning("No results found")