Raj-Master committed on
Commit
4635598
·
1 Parent(s): e1422df

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +151 -0
predict.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import random
4
+ import string
5
+ from ultralyticsplus import YOLO
6
+ import streamlit as st
7
+ import numpy as np
8
+ import pandas as pd
9
+ from process import (
10
+ filter_columns,
11
+ extract_text_of_col,
12
+ prepare_cols,
13
+ process_cols,
14
+ finalize_data,
15
+ )
16
+ from file_utils import (
17
+ get_img,
18
+ save_excel_file,
19
+ concat_csv,
20
+ convert_pdf_to_image,
21
+ filter_color,
22
+ plot,
23
+ delete_file,
24
+ )
25
+
26
+
27
def process_img(
    img,
    page_enumeration: int = 0,
    filter=False,
    foldername: str = "",
    filename: str = "",
):
    """Detect tables in *img*, OCR each table column by column, and save
    one spreadsheet file per successfully extracted table.

    Args:
        img: page image as a numpy array (H x W x C), indexable as img[y0:y1, x0:x1].
        page_enumeration: 1-based page number for multi-page PDFs (0 for single images).
        filter: when True, run filter_color() on each column crop before OCR
            (keeps only dark/black text).  Shadows the builtin `filter`; name
            kept for interface compatibility.
        foldername: directory where save_excel_file() writes its output.
        filename: base name used by save_excel_file() for the output files.
    """
    tables = PaddleOCR.table_model(img, conf=0.75)
    tables = tables[0].boxes.xyxy.cpu().numpy()
    results = []
    for table in tables:
        try:
            # Crop the detected table from the page: rows = y range, cols = x range.
            sub_img = img[
                int(table[1].item()): int(table[3].item()),
                int(table[0].item()): int(table[2].item()),
            ]
            columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
            cols_data = columns_detect[0].boxes.data.cpu().numpy()

            # Sort column boxes left-to-right by their x coordinate.
            cols_data = np.array(
                sorted(cols_data, key=lambda box: box[0]), dtype=np.ndarray
            )

            # Merge duplicated / overlapping column detections.
            cols_data = filter_columns(cols_data)
            st.image(plot(sub_img, cols_data), channels="RGB")
        except Exception:
            st.warning("No Detection")
            # BUG FIX: the original fell through to the extraction step with
            # `cols_data`/`sub_img` unbound (first table) or stale (later
            # tables); skip this table instead.
            continue

        try:
            columns = cols_data[:, 0:4]

            # One cropped image per detected column (full height, column x span).
            sub_imgs = [
                sub_img[:, int(column[0]): int(column[2])] for column in columns
            ]

            cols = []
            thr = 0
            for image in sub_imgs:
                if filter:
                    # Keep only black text pixels before OCR.
                    image = filter_color(image)

                # OCR one column; `threshold` is that column's row-length threshold.
                res, threshold = extract_text_of_col(image)
                thr += threshold

                # Arrange the column's rows using 60% of its own threshold.
                cols.append(prepare_cols(res, threshold * 0.6))

            # Average threshold across all columns of this table.
            thr = thr / len(sub_imgs)

            # Place every cell into its row/column slot in one dataframe.
            data = process_cols(cols, thr * 0.6)

            # Merge rows that belong together and tag with the page number.
            data: pd.DataFrame = finalize_data(data, page_enumeration)
            results.append(data)
        except Exception:
            st.warning("Text Extraction Failed")
            continue

    # Persist one file per extracted table.  A plain loop replaces the
    # original `list(map(lambda ...))`, which built a throwaway list purely
    # for side effects.
    for index, data in enumerate(results):
        save_excel_file(index, data, foldername, filename, page_enumeration)
102
+
103
+
104
class PaddleOCR:
    """Table-extraction pipeline: YOLO table/column detection feeding OCR.

    Calling an instance with a Streamlit upload processes the file (image or
    PDF), writes per-table spreadsheets into a temp directory, concatenates
    them into one CSV, and offers that CSV for download.
    """

    # Detection models are loaded once at class-definition time and shared
    # by every instance and by process_img().
    table_model = YOLO("table.pt")
    column_model = YOLO("columns.pt")

    def __call__(self, uploaded, filter=False):
        """Run the full pipeline on a Streamlit UploadedFile.

        Args:
            uploaded: Streamlit UploadedFile (has .name and .read()).
            filter: forwarded to process_img(); when True, keep only black
                text before OCR.
        """
        foldername = tempfile.TemporaryDirectory(dir=os.getcwd())

        # BUG FIX: the original used uploaded.name.split(".")[0]/[1], which
        # picks the wrong "extension" for names containing more than one dot
        # (e.g. "report.v2.pdf" -> "v2").  splitext takes the last suffix.
        base, ext = os.path.splitext(uploaded.name)
        filename = base

        if ext.lstrip(".").lower() == "pdf":
            # One process_img() call per rendered PDF page, pages numbered from 1.
            pdf_pages = convert_pdf_to_image(uploaded.read())
            for page_enumeration, page in enumerate(pdf_pages, start=1):
                process_img(
                    np.asarray(page),
                    page_enumeration,
                    filter=filter,
                    foldername=foldername.name,
                    filename=filename,
                )
        else:
            img = get_img(uploaded)
            process_img(
                img,
                filter=filter,
                foldername=foldername.name,
                filename=filename,
            )

        # Concatenate all per-table CSV files into one download.  The random
        # suffix avoids collisions between concurrent runs.
        extra = "".join(random.choices(string.ascii_uppercase, k=5))
        out_name = f"{filename}_{extra}.csv"
        try:
            # NOTE(review): concat_csv receives the TemporaryDirectory object
            # while process_img receives foldername.name — confirm concat_csv
            # expects the object (and writes out_name into the CWD, since it
            # is read below after cleanup()).
            concat_csv(foldername, out_name)
        except Exception:
            st.warning("No results found")

        foldername.cleanup()

        if os.path.exists(out_name):
            # BUG FIX: the original opened a hard-coded literal path instead
            # of the file whose existence was just checked.
            with open(out_name, "rb") as fp:
                st.download_button(
                    label="Download CSV file",
                    data=fp,
                    file_name=out_name,
                    mime="text/csv",
                )
            delete_file(out_name)
        else:
            st.warning("No results found")