Spaces:
Runtime error
Runtime error
import streamlit as st | |
from predict import PaddleOCR | |
from pdf2image import convert_from_bytes | |
import cv2 | |
import PIL | |
import numpy as np | |
import os | |
import tempfile | |
import random | |
import string | |
from ultralyticsplus import YOLO | |
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
import io | |
import re | |
from dateutil.parser import parse | |
import datetime | |
from file_utils import ( | |
get_img, | |
save_excel_file, | |
concat_csv, | |
convert_pdf_to_image, | |
filter_color, | |
plot, | |
delete_file, | |
) | |
from process import ( | |
filter_columns, | |
extract_text_of_col, | |
prepare_cols, | |
process_cols, | |
finalize_data, | |
) | |
# YOLO detectors loaded from local weight files (tables and columns respectively,
# judging by the file names).
# NOTE(review): this runs at module level, so presumably every Streamlit rerun
# reloads both models -- consider caching them.
table_model, column_model = YOLO("table.pt"), YOLO("columns.pt")
def remove_dots(string):
    """Normalise the dots in an OCR'd amount so at most one decimal point remains.

    Leading and trailing dots are stripped, then every remaining dot except the
    last one is removed, e.g. ``"1.234.567.89"`` -> ``"1234567.89"``. Text with
    zero or one inner dot is returned unchanged apart from the stripping.

    Args:
        string: the amount text to normalise.

    Returns:
        The text with at most one dot, suitable for ``float()`` conversion when
        the remaining characters are digits.
    """
    string = string.strip('.')
    # Keep only the last dot as the decimal point; all earlier dots are dropped.
    # With no dot at all, rpartition returns ('', '', string) and nothing changes.
    head, sep, tail = string.rpartition('.')
    return head.replace('.', '') + sep + tail
def convert_df(df):
    """Serialise *df* to CSV (without the index) and return it as UTF-8 bytes.

    The bytes are what the script hands to ``st.download_button``.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
def PIL_to_cv(pil_img):
    """Return *pil_img* as a NumPy array in OpenCV's BGR channel order.

    The input is expected to be RGB; channels are swapped with
    ``cv2.COLOR_RGB2BGR``.
    """
    rgb = np.array(pil_img)
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return bgr
def cv_to_PIL(cv_img):
    """Convert an OpenCV-style BGR array into an RGB ``PIL.Image``."""
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
def visualize_ocr(pil_img, ocr_result):
    """Draw OCR boxes and their recognised text on top of an image.

    Args:
        pil_img: image passed to ``imshow`` (a PIL image or array).
        ocr_result: iterable of dicts, each with a ``'bbox'`` whose first four
            values are ``x1, y1, x2, y2`` and a ``'text'`` string.

    Returns:
        A ``PIL.Image`` opened from the rendered figure (saved with matplotlib's
        default savefig format at 150 dpi, tight bounding box).
    """
    # `import PIL` alone does not load the `PIL.Image` submodule.
    from PIL import Image

    # Draw on a figure we own rather than pyplot's implicit "current" figure, and
    # close it even if drawing fails so figures do not accumulate between calls.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(pil_img, interpolation='lanczos')
        for result in ocr_result:
            x1, y1, x2, y2 = result['bbox'][:4]
            ax.add_patch(patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor='red', facecolor='none', linestyle='-',
            ))
            ax.text(x1, y1, result['text'], horizontalalignment='left',
                    verticalalignment='bottom', color='blue', fontsize=7)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        img_buf = io.BytesIO()
        fig.savefig(img_buf, bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    # Rewind so the image is read from the start of the buffer.
    img_buf.seek(0)
    return Image.open(img_buf)
def filter_columns(columns: np.ndarray) -> np.ndarray:
    """Merge neighbouring boxes that overlap horizontally.

    NOTE: this definition replaces the ``filter_columns`` imported from
    ``process`` at the top of the file.

    Args:
        columns: boxes sorted by their first field, one per row, read as
            ``[x1, y1, x2, y2, ...]``. Any extra fields (presumably detector
            score/class) of the left box are kept when two boxes are merged.

    Returns:
        A new array in which every pair of adjacent boxes whose horizontal
        overlap (``left.x2 - right.x1``) exceeds half of their mean width has
        been merged into the left box (x1 kept, y1 = min, x2 = max, y2 = max).
        The input array is not modified.
    """
    merged = np.array(columns, copy=True)
    i = 0
    while i < len(merged) - 1:
        col, nxt = merged[i], merged[i + 1]
        mean_width = ((col[2] - col[0]) + (nxt[2] - nxt[0])) / 2
        if col[2] - nxt[0] > mean_width * 0.5:
            # `col` is a view into `merged`, so these writes update the row in place
            # before the right-hand box is deleted.
            col[1] = min(col[1], nxt[1])
            col[2] = max(col[2], nxt[2])  # max: a nested box must not shrink the merge
            col[3] = max(col[3], nxt[3])
            merged = np.delete(merged, i + 1, axis=0)
            # Stay on index i: the grown box may now overlap the following box too.
        else:
            i += 1
    return merged
# Characters stripped from the Deposit/Withdrawal cells before numeric parsing.
_AMOUNT_NOISE = str.maketrans('', '', "iE:M?t+;g^m/#'w\"%r-v,· *~V")
# Column names assigned to the concatenated extracted tables.
_RAW_COLUMNS = ['page', 'Date', 'Transaction_Details', 'Three', 'Deposit', 'Withdrawal', 'Balance']
# Column layout of the downloadable import CSV.
_IMPORT_COLUMNS = ['*Date', '*Amount', 'Payee', 'Description', 'Reference', 'Check Number']
# Date-cell text that marks repeated header rows ('口期' is presumably an OCR misreading of '日期').
_HEADER_PATTERN = 'Date|日期|口期'


def _load_pages(uploaded_file):
    """Return the uploaded statement as a list of page images (PIL)."""
    if uploaded_file.type == "application/pdf":
        # The second positional argument of convert_from_bytes is the render DPI.
        return convert_from_bytes(uploaded_file.read(), 500)
    # Image uploads are a single page; force 3-channel RGB (PNGs may carry alpha).
    from PIL import Image
    return [Image.open(uploaded_file).convert("RGB")]


def _detect_columns(page_img, table):
    """Crop *table* (x1, y1, x2, y2) out of *page_img* and detect its columns.

    Returns ``(sub_img, cols_data)``, where ``cols_data`` holds the column boxes
    sorted by x1, with overlapping boxes merged by ``filter_columns``.
    """
    x1, y1, x2, y2 = (int(v) for v in table[:4])
    sub_img = page_img[y1:y2, x1:x2]
    # Uses the module-level YOLO column detector loaded at the top of the file.
    detections = column_model(sub_img, conf=0.65)
    cols_data = detections[0].boxes.data.cpu().numpy()
    cols_data = np.array(sorted(cols_data, key=lambda box: box[0]), dtype=np.ndarray)
    return sub_img, filter_columns(cols_data)


def _extract_table_text(sub_img, cols_data, page_number, color_filter):
    """OCR every detected column of a cropped table and assemble a DataFrame of rows."""
    boxes = cols_data[:, 0:4]
    # NOTE(review): every column is cropped to the vertical span of the FIRST
    # (left-most) column box rather than the union of all boxes -- confirm intended.
    top, bottom = int(boxes[0][1]), int(boxes[0][3])
    column_texts = []
    threshold_sum = 0
    for box in boxes:
        image = sub_img[top:bottom, int(box[0]):int(box[2])]
        if color_filter:
            # keep only black color in the image
            image = filter_color(image)
        # extract the text of the column and get its length threshold
        res, threshold, _ocr_res = extract_text_of_col(image)
        threshold_sum += threshold
        # arrange the rows of the column with respect to the row length threshold
        column_texts.append(prepare_cols(res, threshold * 0.6))
    mean_threshold = threshold_sum / len(boxes)
    # put each element of each column in its row, then merge the related rows
    data = process_cols(column_texts, mean_threshold * 0.6)
    return finalize_data(data, page_number)


def _clean_amount(values):
    """Turn an OCR'd amount column into floats; cells that do not parse become NaN."""
    text = values.astype(str).str.translate(_AMOUNT_NOISE).apply(remove_dots)
    return pd.to_numeric(text, errors='coerce').astype(float)


def _parse_statement_date(text, year):
    """Parse a date cell that lacks the year by appending *year*.

    Returns None when dateutil cannot parse the text, so one bad row does not
    abort the whole export.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9 ]', '', text)
    try:
        # Separate the year with a space so it cannot fuse with a trailing number.
        return parse(cleaned + ' ' + str(year), fuzzy=True)
    except (ValueError, OverflowError):
        return None


def _to_import_format(table, year):
    """Build the import table from *table* (named as _RAW_COLUMNS, 'Date' already str).

    Header rows are dropped, dates get *year* appended and are formatted as
    dd-mm-YYYY, withdrawals are negated and added to deposits as '*Amount',
    and rows whose amount is 0 are dropped.
    """
    out = table[~table['Date'].str.contains(_HEADER_PATTERN)].copy()
    parsed = out['Date'].apply(lambda text: _parse_statement_date(text, year))
    out['*Date'] = pd.to_datetime(parsed).dt.strftime('%d-%m-%Y')
    withdrawal = _clean_amount(out['Withdrawal']) * -1
    deposit = _clean_amount(out['Deposit'])
    out['*Amount'] = withdrawal.fillna(0) + deposit.fillna(0)
    out['Payee'] = ''
    out['Description'] = out['Transaction_Details']
    out.loc[out['Three'].notnull(), 'Description'] += " " + out['Three']
    out['Reference'] = ''
    out['Check Number'] = ''
    result = out[_IMPORT_COLUMNS]
    return result[result['*Amount'] != 0]


st.title("Extract data from bank statements")
# NOTE(review): `model` is not used by this script; kept in case constructing
# PaddleOCR has side effects that the OCR helpers rely on.
model = PaddleOCR()
uploaded = st.file_uploader(
    "upload a bank statement image",
    type=["png", "jpg", "jpeg", "PNG", "JPG", "JPEG", "pdf", "PDF"],
)
# Year appended to the statement's dates before they are parsed.
number = st.number_input('Insert a number', value=2023, step=1)
apply_color_filter = st.checkbox("filter color")

if st.button('Analyze image'):
    if uploaded is None:
        st.write('Please upload an image')
    else:
        tabs = st.tabs(
            ['Pages', 'Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
        extracted_tables = []
        for page_number, page in enumerate(_load_pages(uploaded), start=1):
            with tabs[0]:
                st.header('Pages : ' + str(page_number))
                st.image(page)
            page_img = np.asarray(page)
            # Uses the module-level YOLO table detector loaded at the top of the file.
            detections = table_model(page_img, conf=0.60)
            table_boxes = detections[0].boxes.data.cpu().numpy()
            tables = detections[0].boxes.xyxy.cpu().numpy()
            with tabs[1]:
                st.header('Table Detection Page :' + str(page_number))
                str_cols = st.columns(4)
                str_cols[0].subheader('Table image')
                str_cols[1].subheader('Columns')
                str_cols[2].subheader('Structure result')
                str_cols[3].subheader('Cells result')
            if len(tables) > 0:
                # The page overview does not depend on the individual table: draw it once.
                try:
                    ordered = np.array(sorted(table_boxes, key=lambda box: box[0]), dtype=np.ndarray)
                    str_cols[0].image(plot(page_img, filter_columns(ordered)), channels="RGB")
                except Exception as e:
                    print(e)
                    st.warning("No Detection")
            for table in tables:
                try:
                    sub_img, cols_data = _detect_columns(page_img, table)
                    str_cols[1].image(plot(sub_img, cols_data), channels="RGB")
                except Exception as e:
                    print(e)
                    st.warning("No Detection")
                    # Without this table's column boxes there is nothing to OCR.
                    continue
                try:
                    data = _extract_table_text(sub_img, cols_data, page_number, apply_color_filter)
                except Exception as e:
                    # Not a bare `except:` -- that would also swallow BaseException
                    # subclasses such as KeyboardInterrupt.
                    print(e)
                    st.warning("Text Extraction Failed")
                    continue
                extracted_tables.append(data)
                with tabs[2]:
                    st.header('Extracted Table(s)')
                    st.dataframe(data)

        final_csv = (
            pd.concat(extracted_tables, ignore_index=True) if extracted_tables else pd.DataFrame()
        )
        with tabs[3]:
            st.dataframe(final_csv)
            st.download_button(
                "rough-csv",
                convert_df(final_csv),
                "file.csv",
                "text/csv",
                key='rough-csv'
            )
            if final_csv.shape[1] != len(_RAW_COLUMNS):
                # Renaming would raise on a column-count mismatch (e.g. nothing extracted).
                st.warning(
                    f"Expected {len(_RAW_COLUMNS)} extracted columns but got "
                    f"{final_csv.shape[1]}; only the rough CSV is available."
                )
            else:
                final_csv.columns = _RAW_COLUMNS
                final_csv['Date'] = final_csv['Date'].astype(str)
                st.dataframe(final_csv)
                df = _to_import_format(final_csv, number)
                st.dataframe(df)
                st.download_button(
                    "Press to Download",
                    convert_df(df),
                    "file.csv",
                    "text/csv",
                    key='download-csv'
                )