tabel_ocr / app.py
Raj-Master's picture
Update app.py
af13121
import streamlit as st
from predict import PaddleOCR
from pdf2image import convert_from_bytes
import cv2
import PIL
import numpy as np
import os
import tempfile
import random
import string
from ultralyticsplus import YOLO
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
import re
from dateutil.parser import parse
import datetime
from file_utils import (
get_img,
save_excel_file,
concat_csv,
convert_pdf_to_image,
filter_color,
plot,
delete_file,
)
from process import (
filter_columns,
extract_text_of_col,
prepare_cols,
process_cols,
finalize_data,
)
# Instantiate two YOLO models from the local files "table.pt" and "columns.pt".
# This happens at import time, i.e. on every run of the script.
# NOTE(review): the script body below calls PaddleOCR.table_model and
# PaddleOCR.column_model rather than these module-level names — confirm
# whether these two instances are still needed.
table_model = YOLO("table.pt")
column_model = YOLO("columns.pt")
def remove_dots(string):
    """Reduce a numeric string to at most one dot.

    Leading and trailing dots are stripped first. Of the dots that remain,
    only the last one is kept and every earlier dot is removed, so a value
    such as ``"1.234.567.89"`` becomes ``"1234567.89"``.

    Args:
        string: Text to clean.

    Returns:
        The cleaned string, containing zero or one dot.
    """
    string = string.strip('.')
    # Split on the last dot: the part left of it loses all of its dots,
    # the part right of it cannot contain any.
    whole, dot, fraction = string.rpartition('.')
    return whole.replace('.', '') + dot + fraction
def convert_df(df):
    """Serialise *df* to CSV text without the index and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
def PIL_to_cv(pil_img):
    """Convert *pil_img* to a NumPy array and swap its channels from RGB to BGR."""
    rgb_pixels = np.array(pil_img)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Swap the channels of *cv_img* from BGR to RGB and wrap it in a PIL image."""
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb_pixels)
def visualize_ocr(pil_img, ocr_result):
    """Draw OCR boxes and their text over an image and return the rendering.

    Args:
        pil_img: Image passed to ``plt.imshow`` as the background.
        ocr_result: Iterable of mappings with a ``'bbox'`` entry (indexed as
            x0, y0, x1, y1) and a ``'text'`` entry.

    Returns:
        A PIL image opened from the in-memory buffer the figure was saved to.
    """
    plt.imshow(pil_img, interpolation='lanczos')
    figure = plt.gcf()
    figure.set_size_inches(20, 20)
    axes = plt.gca()

    for entry in ocr_result:
        box = entry['bbox']
        label = entry['text']
        box_width = box[2] - box[0]
        box_height = box[3] - box[1]
        outline = patches.Rectangle(
            box[:2],
            box_width,
            box_height,
            linewidth=2,
            edgecolor='red',
            facecolor='none',
            linestyle='-',
        )
        axes.add_patch(outline)
        # The label is anchored at the box's (x0, y0) corner.
        axes.text(
            box[0],
            box[1],
            label,
            horizontalalignment='left',
            verticalalignment='bottom',
            color='blue',
            fontsize=7,
        )

    plt.xticks([], [])
    plt.yticks([], [])
    # This second resize overrides the 20x20 set above before saving.
    figure.set_size_inches(10, 10)
    plt.axis('off')

    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(buffer)
def filter_columns(columns: np.ndarray):
    """Merge neighbouring boxes that overlap horizontally.

    Each row is read as ``[x0, y0, x1, y1, ...]`` and rows are expected to be
    ordered by ``x0`` already. Two neighbours are merged when their
    horizontal overlap exceeds half of their average width. The merged box is
    the union of both extents; values after the first four are taken from the
    left box. After a merge the grown box is compared with its new neighbour
    again, so chains of overlapping boxes collapse into a single row.

    Args:
        columns: Array of boxes, one per row. It is not modified.

    Returns:
        A new array with the same dtype containing the merged boxes. Inputs
        with fewer than two rows are returned as an unchanged copy.
    """
    # Work on a copy so the caller's array is never mutated.
    merged = np.array(columns, copy=True)
    idx = 0
    while idx < len(merged) - 1:
        col = merged[idx]
        nxt = merged[idx + 1]
        threshold = ((col[2] - col[0]) + (nxt[2] - nxt[0])) / 2
        if (col[2] - nxt[0]) > threshold * 0.5:
            col[1] = min(col[1], nxt[1])
            # max() keeps the right edge from shrinking when the next box
            # lies entirely inside the current one.
            col[2] = max(col[2], nxt[2])
            col[3] = max(col[3], nxt[3])
            merged = np.delete(merged, idx + 1, 0)
            # Do not advance: the grown box may overlap its new neighbour.
        else:
            idx += 1
    return merged
# Characters deleted from amount cells before they are parsed as numbers
# (presumably OCR noise — TODO confirm the list is still complete).
_AMOUNT_NOISE = str.maketrans('', '', "iE:M?t+;g^m/#'w\"%r-v,· *~V")

# Column names assigned to the combined table; post-processing requires the
# combined table to have exactly this many columns.
_RAW_COLUMNS = ['page', 'Date', 'Transaction_Details', 'Three', 'Deposit', 'Withdrawal', 'Balance']

# Resolution passed to pdf2image when rendering PDF pages.
_PDF_DPI = 500


def _load_pages(upload):
    """Return the uploaded document as a list of PIL page images.

    PDFs are rendered page by page; any other accepted upload is opened as a
    single RGB image so it goes through the same pipeline.
    """
    if upload.type == "application/pdf":
        return convert_from_bytes(upload.read(), _PDF_DPI)
    return [PIL.Image.open(upload).convert("RGB")]


def _merged_boxes(detections):
    """Sort detector rows by their x coordinate and merge duplicated boxes."""
    ordered = np.array(sorted(detections, key=lambda row: row[0]), dtype=np.ndarray)
    return filter_columns(ordered)


def _extract_table(table_img, cols_data, apply_color_filter, page_number):
    """Extract the text of every detected column of one table.

    Args:
        table_img: Cropped image of the table.
        cols_data: Non-empty array of column boxes; the first four values of
            each row are used as x0, y0, x1, y1.
        apply_color_filter: When true, each column image is passed through
            ``filter_color`` before text extraction.
        page_number: Page number forwarded to ``finalize_data``.

    Returns:
        The value returned by ``finalize_data`` for this table.
    """
    boxes = cols_data[:, 0:4]
    # Every column is cropped to the vertical extent of the first column.
    top, bottom = int(boxes[0][1]), int(boxes[0][3])
    column_imgs = [table_img[top:bottom, int(box[0]): int(box[2])] for box in boxes]

    cols = []
    thr = 0
    for image in column_imgs:
        if apply_color_filter:
            image = filter_color(image)
        # Extract the text of the column together with its length threshold.
        res, threshold, _ocr_res = extract_text_of_col(image)
        thr += threshold
        # Arrange the rows of the column with respect to the threshold.
        cols.append(prepare_cols(res, threshold * 0.6))
    thr = thr / len(column_imgs)

    # Place each element of each column in the table, then merge related rows.
    data = process_cols(cols, thr * 0.6)
    return finalize_data(data, page_number)


def _clean_amount(series):
    """Convert a column of raw amount cells to floats.

    Cells are stringified, stripped of the characters in ``_AMOUNT_NOISE`` and
    reduced to a single dot by ``remove_dots``. Cells that still cannot be
    parsed become NaN instead of aborting the whole run.
    """
    text = series.astype(str).map(lambda cell: cell.translate(_AMOUNT_NOISE))
    return pd.to_numeric(text.apply(remove_dots), errors='coerce').astype(float)


def _to_import_frame(statement, year):
    """Reshape the combined statement table into the downloadable layout.

    Args:
        statement: Frame whose columns are named as in ``_RAW_COLUMNS``.
        year: Value appended to every date cell before it is parsed.

    Returns:
        A frame with the columns ``*Date``, ``*Amount``, ``Payee``,
        ``Description``, ``Reference`` and ``Check Number``; rows whose
        amount is zero are dropped.
    """
    rows = statement.copy()
    rows['Date'] = rows['Date'].astype(str)
    # Drop repeated header rows (including two spellings of the Chinese header).
    rows = rows[~rows['Date'].str.contains('Date|日期|口期')].copy()
    rows['Date'] = rows['Date'].apply(
        lambda text: parse(re.sub(r'[^a-zA-Z0-9 ]', '', text) + str(year), fuzzy=True)
    )
    rows['*Date'] = pd.to_datetime(rows['Date']).dt.strftime('%d-%m-%Y')

    # Withdrawals count negative, deposits positive; missing values count as 0.
    withdrawal = _clean_amount(rows['Withdrawal']) * -1
    deposit = _clean_amount(rows['Deposit'])
    rows['*Amount'] = withdrawal.fillna(0) + deposit.fillna(0)

    rows['Payee'] = ''
    rows['Description'] = rows['Transaction_Details']
    rows.loc[rows['Three'].notnull(), 'Description'] += " " + rows['Three']
    rows['Reference'] = ''
    rows['Check Number'] = ''

    result = rows[['*Date', '*Amount', 'Payee', 'Description', 'Reference', 'Check Number']]
    return result[result['*Amount'] != 0]


st.title("Extract data from bank statements")
# NOTE(review): this instance is not referenced below; it is kept in case
# constructing PaddleOCR has side effects the pipeline relies on — confirm.
model = PaddleOCR()
uploaded = st.file_uploader(
    "upload a bank statement image",
    type=["png", "jpg", "jpeg", "PNG", "JPG", "JPEG", "pdf", "PDF"],
)
# The entered number is appended to every date cell before parsing (the year).
number = st.number_input('Insert a number', value=2023, step=1)
use_color_filter = st.checkbox("filter color")

if st.button('Analyze image'):
    if uploaded is None:
        st.write('Please upload an image')
    else:
        tabs = st.tabs(
            ['Pages', 'Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
        extracted_frames = []
        for page_enumeration, page in enumerate(_load_pages(uploaded), start=1):
            with tabs[0]:
                st.header('Pages : ' + str(page_enumeration))
                st.image(page)
            page_img = np.asarray(page)
            detection = PaddleOCR.table_model(page_img, conf=0.60)
            tabel_datas = detection[0].boxes.data.cpu().numpy()
            tables = detection[0].boxes.xyxy.cpu().numpy()
            # The page-level boxes are the same for every table on the page.
            page_boxes = _merged_boxes(tabel_datas)
            with tabs[1]:
                st.header('Table Detection Page :' + str(page_enumeration))
                str_cols = st.columns(4)
                str_cols[0].subheader('Table image')
                str_cols[1].subheader('Columns')
                str_cols[2].subheader('Structure result')
                str_cols[3].subheader('Cells result')
                for table in tables:
                    try:
                        str_cols[0].image(plot(page_img, page_boxes), channels="RGB")
                        # Crop the table out of the page image.
                        sub_img = page_img[
                            int(table[1].item()): int(table[3].item()),
                            int(table[0].item()): int(table[2].item()),
                        ]
                        columns_detect = PaddleOCR.column_model(sub_img, conf=0.65)
                        cols_data = _merged_boxes(columns_detect[0].boxes.data.cpu().numpy())
                        str_cols[1].image(plot(sub_img, cols_data), channels="RGB")
                    except Exception as e:
                        print(e)
                        st.warning("No Detection")
                        # Without column boxes there is nothing to extract.
                        continue
                    if len(cols_data) == 0:
                        st.warning("No Detection")
                        continue
                    try:
                        data: pd.DataFrame = _extract_table(
                            sub_img, cols_data, use_color_filter, page_enumeration
                        )
                        with tabs[2]:
                            st.header('Extracted Table(s)')
                            st.dataframe(data)
                    except Exception as e:
                        print(e)
                        st.warning("Text Extraction Failed")
                        continue
                    extracted_frames.append(data)

        with tabs[3]:
            if not extracted_frames:
                st.warning("No table data could be extracted")
            else:
                final_csv = pd.concat(extracted_frames, ignore_index=True)
                st.dataframe(final_csv)
                st.download_button(
                    "rough-csv",
                    convert_df(final_csv),
                    "file.csv",
                    "text/csv",
                    key='rough-csv'
                )
                if final_csv.shape[1] != len(_RAW_COLUMNS):
                    st.warning(
                        "Expected " + str(len(_RAW_COLUMNS)) + " columns but found "
                        + str(final_csv.shape[1]) + "; only the rough CSV is available"
                    )
                else:
                    final_csv.columns = _RAW_COLUMNS
                    final_csv['Date'] = final_csv['Date'].astype(str)
                    st.dataframe(final_csv)
                    df = _to_import_frame(final_csv, number)
                    st.dataframe(df)
                    st.download_button(
                        "Press to Download",
                        convert_df(df),
                        "file.csv",
                        "text/csv",
                        key='download-csv'
                    )