Spaces:
Runtime error
Runtime error
Raj-Master
committed on
Commit
·
4635598
1
Parent(s):
e1422df
Create predict.py
Browse files- predict.py +151 -0
predict.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import random
|
4 |
+
import string
|
5 |
+
from ultralyticsplus import YOLO
|
6 |
+
import streamlit as st
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from process import (
|
10 |
+
filter_columns,
|
11 |
+
extract_text_of_col,
|
12 |
+
prepare_cols,
|
13 |
+
process_cols,
|
14 |
+
finalize_data,
|
15 |
+
)
|
16 |
+
from file_utils import (
|
17 |
+
get_img,
|
18 |
+
save_excel_file,
|
19 |
+
concat_csv,
|
20 |
+
convert_pdf_to_image,
|
21 |
+
filter_color,
|
22 |
+
plot,
|
23 |
+
delete_file,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def process_img(
    img,
    page_enumeration: int = 0,
    filter=False,
    foldername: str = "",
    filename: str = "",
):
    """Detect tables in *img*, split each table into columns, OCR the text,
    and save one spreadsheet file per detected table.

    Args:
        img: Page image as a numpy array (H x W x C) — assumed RGB since it
            is shown with ``st.image(..., channels="RGB")``.
        page_enumeration: 1-based page number when *img* comes from a PDF;
            0 for a standalone image.
        filter: When True, keep only the dark pixels before OCR (via
            ``filter_color``). NOTE: shadows the builtin ``filter`` — the
            name is kept for backward compatibility with existing callers.
        foldername: Directory where the per-table result files are written.
        filename: Base name used for the saved result files.
    """
    # Detect table bounding boxes (xyxy) with the class-level YOLO model.
    tables = PaddleOCR.table_model(img, conf=0.75)
    tables = tables[0].boxes.xyxy.cpu().numpy()
    results = []
    for table in tables:
        try:
            # * crop the table as an image from the original image
            sub_img = img[
                int(table[1].item()): int(table[3].item()),
                int(table[0].item()): int(table[2].item()),
            ]
            columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
            cols_data = columns_detect[0].boxes.data.cpu().numpy()

            # * Sort columns according to the x coordinate
            cols_data = np.array(
                sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray
            )

            # * merge the duplicated columns
            cols_data = filter_columns(cols_data)
            st.image(plot(sub_img, cols_data), channels="RGB")
        except Exception:
            st.warning("No Detection")
            # BUGFIX: skip this table — without the continue, the extraction
            # step below ran on stale sub_img/cols_data from a previous
            # iteration (or hit a NameError on the first one).
            continue

        try:
            columns = cols_data[:, 0:4]
            # * Create list of cropped images for each column
            sub_imgs = [
                sub_img[:, int(column[0]): int(column[2])]
                for column in columns
            ]
            cols = []
            thr = 0
            for image in sub_imgs:
                if filter:
                    # * keep only black color in the image
                    image = filter_color(image)

                # * extract text of each column and get the length threshold
                res, threshold = extract_text_of_col(image)
                thr += threshold

                # * arrange the rows of each column with respect to row length threshold
                cols.append(prepare_cols(res, threshold * 0.6))

            # Average the per-column thresholds over this table.
            thr = thr / len(sub_imgs)

            # * append each element in each column to its right place in the dataframe
            data = process_cols(cols, thr * 0.6)

            # * merge the related rows together
            data: pd.DataFrame = finalize_data(data, page_enumeration)
            results.append(data)
        except Exception:
            st.warning("Text Extraction Failed")
            continue

    # Persist every extracted table; the index distinguishes multiple
    # tables found on the same page.
    for table_index, data in enumerate(results):
        save_excel_file(
            table_index,
            data,
            foldername,
            filename,
            page_enumeration,
        )
|
102 |
+
|
103 |
+
|
104 |
+
class PaddleOCR:
    """Streamlit table-extraction pipeline.

    Detects tables in an uploaded image or PDF with two YOLO models,
    extracts their text into per-table spreadsheet files, concatenates the
    results into a single CSV, and offers it for download.
    """

    # Load Image Detection models once, at class-definition (import) time.
    table_model = YOLO("table.pt")
    column_model = YOLO("columns.pt")

    def __call__(self, uploaded, filter=False):
        """Run the full pipeline on an uploaded file.

        Args:
            uploaded: Streamlit UploadedFile; must expose ``.name`` and
                ``.read()``.
            filter: Forwarded to ``process_img`` — when True only the dark
                pixels are kept before OCR. (Shadows the builtin ``filter``;
                name kept for backward compatibility with callers.)
        """
        # Per-run scratch directory for the intermediate per-table files.
        foldername = tempfile.TemporaryDirectory(dir=os.getcwd())
        filename = uploaded.name.split(".")[0]
        # BUGFIX: take the extension from the LAST dot so names such as
        # "report.v2.pdf" are still recognised as PDFs (split(".")[1]
        # returned "v2" for such names).
        if uploaded.name.split(".")[-1].lower() == "pdf":
            # One image per PDF page; pages are numbered from 1.
            pdf_pages = convert_pdf_to_image(uploaded.read())
            for page_enumeration, page in enumerate(pdf_pages, start=1):
                process_img(
                    np.asarray(page),
                    page_enumeration,
                    filter=filter,
                    foldername=foldername.name,
                    filename=filename,
                )
        else:
            img = get_img(uploaded)
            process_img(
                img,
                filter=filter,
                foldername=foldername.name,
                filename=filename,
            )

        # * concatenate all csv files if many
        # Random suffix avoids filename clashes between concurrent runs.
        extra = "".join(random.choices(string.ascii_uppercase, k=5))
        filename = f"{filename}_{extra}.csv"
        try:
            # NOTE(review): concat_csv receives the TemporaryDirectory
            # object itself (not .name) — presumably unwrapped inside the
            # helper; confirm against file_utils.
            concat_csv(foldername, filename)
        except Exception:
            st.warning("No results found")

        foldername.cleanup()

        if os.path.exists(filename):
            # BUGFIX: open the generated CSV itself — the original opened a
            # hard-coded placeholder path instead of ``filename``, so the
            # download button could never serve the file just checked above.
            with open(filename, "rb") as fp:
                st.download_button(
                    label="Download CSV file",
                    data=fp,
                    file_name=filename,
                    mime="text/csv",
                )
            delete_file(filename)
        else:
            st.warning("No results found")
|