Raj-Master committed

Commit 4911ff5 · Parent(s): 5c7c680

Create app.py

Files changed (1):
  app.py +301 -0

app.py ADDED
@@ -0,0 +1,301 @@
import io
import re

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
import streamlit as st
from dateutil.parser import parse
from pdf2image import convert_from_bytes
from ultralyticsplus import YOLO

from predict import PaddleOCR
from file_utils import filter_color, plot
from process import (
    extract_text_of_col,
    prepare_cols,
    process_cols,
    finalize_data,
)
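# predict.py, file_utils.py and process.py are local modules expected to sit
# next to this file. Launch the app locally with:
#   streamlit run app.py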

# YOLO weights for table and column detection; the .pt files are expected
# in the working directory.
table_model = YOLO("table.pt")
column_model = YOLO("columns.pt")


def remove_dots(string):
    # Remove dots from the first and last positions of the string.
    string = string.strip('.')

    # If more than one dot remains, drop the first one (left to right).
    if string.count('.') > 1:
        string = string.replace(".", "", 1)

    return string
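# e.g. remove_dots("1.2.34") -> "12.34" and remove_dots(".12.34.") -> "12.34"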


def convert_df(df):
    # Encode a DataFrame as UTF-8 CSV bytes for st.download_button.
    return df.to_csv(index=False).encode('utf-8')


def PIL_to_cv(pil_img):
    # PIL images are RGB, OpenCV arrays are BGR; these helpers convert both ways.
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def visualize_ocr(pil_img, ocr_result):
    # Draw each OCR bounding box and its text on top of the image and
    # return the annotated figure as a PIL image.
    plt.imshow(pil_img, interpolation='lanczos')
    ax = plt.gca()

    for result in ocr_result:
        bbox = result['bbox']
        text = result['text']
        rect = patches.Rectangle(
            bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1],
            linewidth=2, edgecolor='red', facecolor='none', linestyle='-',
        )
        ax.add_patch(rect)
        ax.text(
            bbox[0], bbox[1], text,
            horizontalalignment='left', verticalalignment='bottom',
            color='blue', fontsize=7,
        )

    plt.xticks([], [])
    plt.yticks([], [])

    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()

    return PIL.Image.open(img_buf)


def filter_columns(columns: np.ndarray):
    # Merge neighbouring boxes (xyxy) that overlap horizontally by more
    # than half of their average width.
    idx = 0
    while idx < len(columns) - 1:
        col, nxt = columns[idx], columns[idx + 1]
        threshold = ((col[2] - col[0]) + (nxt[2] - nxt[0])) / 2
        if (col[2] - nxt[0]) > threshold * 0.5:
            # Absorb the next box into the current one and drop it.
            col[1], col[2], col[3] = min(col[1], nxt[1]), nxt[2], max(col[3], nxt[3])
            columns = np.delete(columns, idx + 1, 0)
        else:
            idx += 1
    return columns

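# Example with illustrative numbers: [0, 0, 100, 50] and [40, 0, 140, 50]
# overlap by 60px against a 50px threshold, so they merge into
# [0, 0, 140, 50]; [0, 0, 100, 50] and [80, 5, 180, 55] only overlap by
# 20px and stay separate.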

st.title("Extract data from bank statements")

# OCR engine from the local predict module.
model = PaddleOCR()

uploaded = st.file_uploader(
    "Upload a bank statement image",
    type=["png", "jpg", "jpeg", "PNG", "JPG", "JPEG", "pdf", "PDF"],
)
filter = st.checkbox("filter color")

if st.button('Analyze image'):
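    # Everything below runs once per click: detect tables on each page,
    # split them into columns, OCR the columns, and assemble a CSV.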
    final_csv = pd.DataFrame()
    first_flag_dataframe = 0
    if uploaded is None:
        st.write('Please upload an image')

    else:
        tabs = st.tabs(
            ['Pages', 'Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
        print(uploaded.type)
        if uploaded.type == "application/pdf":
            # Rasterise every PDF page; 500 dpi keeps small print legible
            # but produces large images, so lower it if memory is tight.
            pdf_pages = convert_from_bytes(uploaded.read(), dpi=500)
        else:
            # Plain images go through the same loop as a one-page document.
            pdf_pages = [PIL.Image.open(uploaded).convert("RGB")]
        for page_enumeration, page in enumerate(pdf_pages, start=1):
            with tabs[0]:
                st.header('Page: ' + str(page_enumeration))
                st.image(page)

            page_img = np.asarray(page)
            # Detect tables on the full page with the YOLO table model.
            tables = table_model(page_img, conf=0.75)
            tabel_datas = tables[0].boxes.data.cpu().numpy()
            tables = tables[0].boxes.xyxy.cpu().numpy()

            with tabs[1]:
                st.header('Table Detection Page: ' + str(page_enumeration))
                str_cols = st.columns(4)
                str_cols[0].subheader('Table image')
                str_cols[1].subheader('Columns')
                str_cols[2].subheader('Structure result')
                str_cols[3].subheader('Cells result')

            results = []
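            # For each detected table: crop it, detect and merge its columns,
            # then OCR each column strip in the second try-block below.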
            for table in tables:
                try:
                    # * Sort the detected table boxes by x coordinate
                    tabel_data = np.array(
                        sorted(tabel_datas, key=lambda x: x[0]), dtype=np.ndarray
                    )
                    tabel_data = filter_columns(tabel_data)
                    str_cols[0].image(plot(page_img, tabel_data), channels="RGB")

                    # * crop the table as an image from the original page
                    sub_img = page_img[
                        int(table[1].item()): int(table[3].item()),
                        int(table[0].item()): int(table[2].item()),
                    ]

                    columns_detect = column_model(sub_img, conf=0.75)
                    cols_data = columns_detect[0].boxes.data.cpu().numpy()

                    # * Sort columns according to the x coordinate
                    cols_data = np.array(
                        sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray
                    )

                    # * merge the duplicated columns
                    cols_data = filter_columns(cols_data)
                    str_cols[1].image(plot(sub_img, cols_data), channels="RGB")

                except Exception as e:
                    print(e)
                    st.warning("No Detection")
                try:
                    columns = cols_data[:, 0:4]
                    sub_imgs = []

                    # * find the largest top/bottom y coordinates over all
                    # * columns so every strip is cropped to the same band
                    maxcol1 = int(columns[0][1])
                    maxcol3 = int(columns[0][3])
                    for column in columns:
                        maxcol1 = max(maxcol1, int(column[1]))
                        maxcol3 = max(maxcol3, int(column[3]))

                    for column in columns:
                        # * Create a cropped image for each column
                        sub_imgs.append(sub_img[maxcol1:maxcol3, int(column[0]): int(column[2])])

                    cols = []
                    thr = 0
                    for image in sub_imgs:
                        if filter:
                            # * keep only black color in the image
                            image = filter_color(image)

                        # * extract text of each column and get the length threshold
                        res, threshold, ocr_res = extract_text_of_col(image)
                        thr += threshold

                        # * arrange the rows of each column with respect to the row length threshold
                        cols.append(prepare_cols(res, threshold * 0.6))

                    thr = thr / len(sub_imgs)

                    # * append each element in each column to its right place in the dataframe
                    data = process_cols(cols, thr * 0.6)

                    # * merge the related rows together
                    data: pd.DataFrame = finalize_data(data, page_enumeration)
                    results.append(data)

                    with tabs[2]:
                        st.header('Extracted Table(s)')
                        st.dataframe(data)
                    print("data : ", data)
                    print("results : ", results)

                    # Accumulate the per-table frames into one export frame.
                    if first_flag_dataframe == 0:
                        first_flag_dataframe = 1
                        final_csv = data
                    else:
                        final_csv = pd.concat([final_csv, data], ignore_index=True)
                    csv = convert_df(data)
                    print(csv)

                except Exception:
                    st.warning("Text Extraction Failed")
                    continue
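        # Post-process the combined table into a clean export: normalise
        # dates, strip OCR noise from the amount columns, and merge the
        # free-text columns into a single description.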
        with tabs[3]:
            st.dataframe(final_csv)
            st.dataframe(final_csv.keys())
            print(final_csv.head())

            # Assumes the extracted tables always yield these seven columns.
            final_csv.columns = ['page', 'Date', 'Transaction_Details', 'Three', 'Deposit', 'Withdrawal', 'Balance']

            # Drop repeated header rows (English and Chinese "Date") that
            # were extracted as data.
            final_csv['Date'] = final_csv['Date'].astype(str)
            st.dataframe(final_csv)
            final_csv = final_csv[~final_csv['Date'].str.contains('Date')]
            final_csv = final_csv[~final_csv['Date'].str.contains('日期')]

            # Normalise dates: strip punctuation, append the statement year
            # (hard-coded to 2023), and parse fuzzily.
            final_csv['Date'] = final_csv['Date'].apply(lambda x: re.sub(r'[^a-zA-Z0-9 ]', '', x))
            final_csv['Date'] = final_csv['Date'].apply(lambda x: x + ' 2023')
            final_csv['Date'] = final_csv['Date'].apply(lambda x: parse(x, fuzzy=True))
            final_csv['*Date'] = pd.to_datetime(final_csv['Date']).dt.strftime('%d-%m-%Y')
            # Strip characters the OCR commonly misreads into numeric cells,
            # then fix stray dots and convert to float.
            ocr_noise = r"""[iE:M?t+;g^m/#'w"%rv,·*~V \-]"""
            final_csv['Withdrawal'] = (
                final_csv['Withdrawal'].astype(str)
                .str.replace(ocr_noise, '', regex=True)
                .apply(remove_dots)
                .astype(float) * -1
            )
            final_csv['Deposit'] = (
                final_csv['Deposit'].astype(str)
                .str.replace(ocr_noise, '', regex=True)
                .apply(remove_dots)
                .astype(float)
            )
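            # e.g. an OCR'd withdrawal of "1,234.56·" becomes "1234.56" after
            # the noise strip and remove_dots, i.e. -1234.56 as a float.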
            # A row carries either a withdrawal (negative) or a deposit.
            final_csv['*Amount'] = final_csv['Withdrawal'].fillna(0) + final_csv['Deposit'].fillna(0)
            final_csv = final_csv.drop(['Withdrawal', 'Deposit'], axis=1)

            final_csv['Payee'] = ''
            final_csv['Description'] = final_csv['Transaction_Details']
            final_csv.loc[final_csv['Three'].notnull(), 'Description'] += " " + final_csv['Three']
            final_csv = final_csv.drop(['Transaction_Details', 'Three'], axis=1)
            final_csv['Reference'] = ''
            final_csv['Check Number'] = ''

            df = final_csv[['*Date', '*Amount', 'Payee', 'Description', 'Reference', 'Check Number']]
            df = df[df['*Amount'] != 0]
            csv = convert_df(df)
            st.dataframe(df)
            st.download_button(
                "Press to Download",
                csv,
                "file.csv",
                "text/csv",
                key='download-csv',
            )