import streamlit as st
from predict import PaddleOCR
from pdf2image import convert_from_bytes
import cv2
import PIL
import PIL.Image  # explicit: PIL.Image is used below and `import PIL` alone does not load it
import numpy as np
import os
import tempfile
import random
import string
from ultralyticsplus import YOLO
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
import re
from dateutil.parser import parse
import datetime
from file_utils import (
    get_img,
    save_excel_file,
    concat_csv,
    convert_pdf_to_image,
    filter_color,
    plot,
    delete_file,
)
from process import (
    filter_columns,
    extract_text_of_col,
    prepare_cols,
    process_cols,
    finalize_data,
)

# NOTE(review): these module-level models are loaded but the pipeline below calls
# PaddleOCR.table_model / PaddleOCR.column_model instead -- confirm which is intended.
table_model = YOLO("table.pt")
column_model = YOLO("columns.pt")

# Names assigned to the concatenated OCR output; its width must match.
STATEMENT_COLUMNS = [
    'page', 'Date', 'Transaction_Details', 'Three', 'Deposit', 'Withdrawal', 'Balance'
]
# Date-column text of table-header rows to drop ('口期' is presumably an OCR
# misreading of '日期' -- kept from the original filters).
_HEADER_PATTERN = 'Date|日期|口期'
# Characters stripped from OCR'd amount cells before numeric conversion.
_AMOUNT_NOISE = str.maketrans("", "", "iE:M?t+;g^m/#'w\"%r-v,· *~V")


def remove_dots(string):
    """Normalise the dots in an OCR'd number string.

    Leading/trailing dots are stripped. If more than one dot remains, only the
    right-most one is kept as the decimal point, e.g. ``"1.234.56" -> "1234.56"``.
    """
    string = string.strip('.')
    if string.count('.') > 1:
        head, _, tail = string.rpartition('.')
        string = head.replace('.', '') + '.' + tail
    return string


def convert_df(df):
    """Serialise ``df`` to UTF-8 encoded CSV bytes without the index."""
    return df.to_csv(index=False).encode('utf-8')


def PIL_to_cv(pil_img):
    """Convert an RGB PIL image to a BGR OpenCV array."""
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    """Convert a BGR OpenCV array to an RGB PIL image."""
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def visualize_ocr(pil_img, ocr_result):
    """Draw OCR boxes and their recognised text on top of an image.

    Args:
        pil_img: image to draw on (anything ``imshow`` accepts).
        ocr_result: iterable of dicts holding a ``'bbox'`` laid out as
            ``(x0, y0, x1, y1)`` and the recognised ``'text'``.

    Returns:
        A PIL image of the rendered figure.
    """
    # Own figure instead of the implicit global one, so a stale figure is never reused.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(pil_img, interpolation='lanczos')
    for result in ocr_result:
        x0, y0, x1, y1 = result['bbox'][:4]
        ax.add_patch(
            patches.Rectangle(
                (x0, y0), x1 - x0, y1 - y0,
                linewidth=2, edgecolor='red', facecolor='none', linestyle='-',
            )
        )
        ax.text(
            x0, y0, result['text'],
            horizontalalignment='left', verticalalignment='bottom',
            color='blue', fontsize=7,
        )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    img_buf = io.BytesIO()
    fig.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close(fig)
    img_buf.seek(0)
    return PIL.Image.open(img_buf)


def filter_columns(columns: np.ndarray) -> np.ndarray:
    """Merge horizontally overlapping neighbours in an x-sorted array of boxes.

    Each row starts with ``x0, y0, x1, y1``. Two neighbours are merged when the
    current box's right edge extends past the next box's left edge by more than
    half of their mean width. The merged box is the union of the two in y and in
    the right edge; the left edge of the first box is kept. A widened box is
    compared against the following box again. Works on a copy; the input is not
    modified. (This definition shadows ``process.filter_columns``.)
    """
    boxes = np.array(columns, copy=True)
    i = 0
    while i < len(boxes) - 1:
        cur, nxt = boxes[i], boxes[i + 1]
        mean_width = ((cur[2] - cur[0]) + (nxt[2] - nxt[0])) / 2
        if (cur[2] - nxt[0]) > mean_width * 0.5:
            cur[1] = min(cur[1], nxt[1])
            cur[2] = max(cur[2], nxt[2])
            cur[3] = max(cur[3], nxt[3])
            boxes = np.delete(boxes, i + 1, axis=0)
            # stay on i: the widened box may now overlap the following one
        else:
            i += 1
    return boxes


def _load_pages(upload):
    """Return the uploaded statement as a list of RGB PIL images, one per page."""
    if upload.type == "application/pdf":
        return convert_from_bytes(upload.read(), 500)
    # The uploader also accepts images: treat them as a one-page document.
    return [PIL.Image.open(upload).convert("RGB")]


def _sort_by_x(boxes):
    """Return detection rows sorted by their left x coordinate (index 0)."""
    return np.array(sorted(boxes, key=lambda box: box[0]), dtype=object)


def _detect_columns(sub_img):
    """Detect, x-sort and de-duplicate the column boxes of a cropped table image."""
    detections = PaddleOCR.column_model(sub_img, conf=0.65)
    return filter_columns(_sort_by_x(detections[0].boxes.data.cpu().numpy()))


def _ocr_table(sub_img, cols_data, page_number, color_filter):
    """OCR every detected column of a table crop and assemble the rows into a DataFrame."""
    columns = cols_data[:, 0:4]
    # Every column is cropped to the vertical span of the first (left-most) column.
    top, bottom = int(columns[0][1]), int(columns[0][3])
    col_imgs = [sub_img[top:bottom, int(col[0]):int(col[2])] for col in columns]

    cols = []
    total_threshold = 0
    for image in col_imgs:
        if color_filter:
            # keep only black color in the image
            image = filter_color(image)
        # extract the text of the column and its row-length threshold
        res, threshold, _ocr_res = extract_text_of_col(image)
        total_threshold += threshold
        # arrange the rows of this column with respect to the row-length threshold
        cols.append(prepare_cols(res, threshold * 0.6))
    mean_threshold = total_threshold / len(col_imgs)
    # put each column element in its row of the table, then merge related rows
    data = process_cols(cols, mean_threshold * 0.6)
    return finalize_data(data, page_number)


def _parse_date(text):
    """Fuzzy-parse ``text`` as a datetime; return ``pd.NaT`` if it cannot be parsed."""
    try:
        return parse(text, fuzzy=True)
    except (ValueError, OverflowError):  # dateutil's ParserError subclasses ValueError
        return pd.NaT


def _clean_amount(values):
    """Convert an OCR'd amount column to floats; unparseable cells become NaN."""
    cleaned = values.astype(str).str.translate(_AMOUNT_NOISE).apply(remove_dots)
    return pd.to_numeric(cleaned, errors="coerce").astype(float)


def _to_export_frame(statement, year):
    """Turn the concatenated OCR table into the export layout.

    Renames the columns to ``STATEMENT_COLUMNS`` and displays the renamed frame.
    Drops header rows, parses dates (with ``year`` appended) and merges deposit
    and withdrawal into a signed ``*Amount``. Returns the
    ``*Date, *Amount, Payee, Description, Reference, Check Number`` frame,
    without zero-amount rows.
    """
    frame = statement.copy()
    frame.columns = STATEMENT_COLUMNS
    frame['Date'] = frame['Date'].astype(str)
    st.dataframe(frame)

    frame = frame[~frame['Date'].str.contains(_HEADER_PATTERN)].copy()
    # keep only alphanumerics/spaces and append the year before fuzzy parsing
    dates = frame['Date'].apply(lambda x: re.sub(r'[^a-zA-Z0-9 ]', '', x)) + str(year)
    frame['Date'] = dates.apply(_parse_date)
    frame['*Date'] = pd.to_datetime(frame['Date']).dt.strftime('%d-%m-%Y')

    # withdrawals are exported as negative amounts
    frame['Withdrawal'] = _clean_amount(frame['Withdrawal']) * -1
    frame['Deposit'] = _clean_amount(frame['Deposit'])
    frame['*Amount'] = frame['Withdrawal'].fillna(0) + frame['Deposit'].fillna(0)
    frame = frame.drop(['Withdrawal', 'Deposit'], axis=1)

    frame['Payee'] = ''
    frame['Description'] = frame['Transaction_Details']
    # append the third text column to the description where it is present
    frame.loc[frame['Three'].notnull(), 'Description'] += " " + frame['Three']
    frame = frame.drop(['Transaction_Details', 'Three'], axis=1)
    frame['Reference'] = ''
    frame['Check Number'] = ''

    export = frame[['*Date', '*Amount', 'Payee', 'Description', 'Reference', 'Check Number']]
    return export[export['*Amount'] != 0]


st.title("Extract data from bank statements")

# NOTE(review): this instance is never used below; kept in case PaddleOCR()
# has required initialisation side effects -- confirm and remove if not.
model = PaddleOCR()

uploaded = st.file_uploader(
    "upload a bank statement image",
    type=["png", "jpg", "jpeg", "PNG", "JPG", "JPEG", "pdf", "PDF"],
)
# appended to every Date cell before it is parsed
number = st.number_input('Insert a number', value=2023, step=1)
use_color_filter = st.checkbox("filter color")

if st.button('Analyze image'):
    if uploaded is None:
        st.write('Please upload an image')
    else:
        tabs = st.tabs(
            ['Pages', 'Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
        extracted_tables = []
        for page_number, page in enumerate(_load_pages(uploaded), start=1):
            with tabs[0]:
                st.header('Pages : ' + str(page_number))
                st.image(page)
            page_img = np.asarray(page)
            detections = PaddleOCR.table_model(page_img, conf=0.60)
            table_rows = detections[0].boxes.data.cpu().numpy()
            tables = detections[0].boxes.xyxy.cpu().numpy()

            with tabs[1]:
                st.header('Table Detection Page :' + str(page_number))
                str_cols = st.columns(4)
                str_cols[0].subheader('Table image')
                str_cols[1].subheader('Columns')
                str_cols[2].subheader('Structure result')
                str_cols[3].subheader('Cells result')

            if len(tables):
                # overview of every table detected on the page, drawn once per page
                try:
                    overview = filter_columns(_sort_by_x(table_rows))
                    str_cols[0].image(plot(page_img, overview), channels="RGB")
                except Exception as e:
                    print(e)
                    st.warning("No Detection")

            for table in tables:
                try:
                    # crop the table as an image from the original image
                    sub_img = page_img[
                        int(table[1]): int(table[3]),
                        int(table[0]): int(table[2]),
                    ]
                    cols_data = _detect_columns(sub_img)
                    str_cols[1].image(plot(sub_img, cols_data), channels="RGB")
                except Exception as e:
                    print(e)
                    st.warning("No Detection")
                    continue  # never OCR with missing or a previous table's columns
                if len(cols_data) == 0:
                    st.warning("No Detection")
                    continue

                try:
                    data = _ocr_table(sub_img, cols_data, page_number, use_color_filter)
                except Exception as e:
                    print(e)
                    st.warning("Text Extraction Failed")
                    continue
                with tabs[2]:
                    st.header('Extracted Table(s)')
                    st.dataframe(data)
                extracted_tables.append(data)

        final_csv = (
            pd.concat(extracted_tables, ignore_index=True)
            if extracted_tables
            else pd.DataFrame()
        )
        with tabs[3]:
            st.dataframe(final_csv)
            st.download_button(
                "rough-csv", convert_df(final_csv), "file.csv", "text/csv", key='rough-csv'
            )
            if final_csv.shape[1] != len(STATEMENT_COLUMNS):
                st.warning(
                    f"Expected {len(STATEMENT_COLUMNS)} extracted columns but got "
                    f"{final_csv.shape[1]}; cannot build the formatted CSV."
                )
            else:
                df = _to_export_frame(final_csv, number)
                st.dataframe(df)
                st.download_button(
                    "Press to Download", convert_df(df), "file.csv", "text/csv",
                    key='download-csv',
                )
#success = st.button("Extract", on_click=model, args=[uploaded, filter])