Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ table_engine_list = [
|
|
33 |
# 示例图片路径
|
34 |
example_images = [
|
35 |
"images/wired1.png",
|
36 |
-
"images/wired2.
|
37 |
"images/wired3.png",
|
38 |
"images/lineless1.png",
|
39 |
"images/wired4.jpg",
|
@@ -67,6 +67,17 @@ for det_model in det_model_dir.keys():
|
|
67 |
rec_model_dir=rec_model_path
|
68 |
)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def select_ocr_model(det_model, rec_model):
    """Return the OCR engine stored for a detection/recognition model pair.

    The lookup key is the two model names joined by a single underscore,
    and the engine is read from the module-level ``ocr_engine_dict``.
    """
    engine_key = f"{det_model}_{rec_model}"
    return ocr_engine_dict[engine_key]
|
@@ -94,8 +105,10 @@ def select_table_model(img, table_engine_type, det_model, rec_model):
|
|
94 |
return lineless_table_engine, "lineless_table"
|
95 |
|
96 |
|
97 |
-
def process_image(
|
98 |
-
|
|
|
|
|
99 |
start = time.time()
|
100 |
table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
|
101 |
ocr_engine = select_ocr_model(det_model, rec_model)
|
@@ -108,24 +121,20 @@ def process_image(img, table_engine_type, det_model, rec_model, small_box_cut_en
|
|
108 |
ocr_boxes = result[0]['res']['boxes']
|
109 |
all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
|
110 |
else:
|
111 |
-
ocr_res, ocr_infer_elapse = ocr_engine(img)
|
112 |
det_cost, cls_cost, rec_cost = ocr_infer_elapse
|
|
|
|
|
113 |
ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
|
114 |
if isinstance(table_engine, RapidTable):
|
115 |
html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
|
116 |
polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
|
117 |
elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
|
118 |
-
html, table_rec_elapse, polygons,
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
)
|
124 |
-
else:
|
125 |
-
html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(
|
126 |
-
img, ocr_result=ocr_res
|
127 |
-
)
|
128 |
-
|
129 |
sum_elapse = time.time() - start
|
130 |
all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
|
131 |
|
@@ -191,10 +200,33 @@ def main():
|
|
191 |
label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
|
192 |
value=True
|
193 |
)
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
run_button = gr.Button("Run")
|
200 |
gr.Markdown("# Elapsed Time")
|
@@ -210,7 +242,7 @@ def main():
|
|
210 |
|
211 |
run_button.click(
|
212 |
fn=process_image,
|
213 |
-
inputs=[img_input, table_engine_type,
|
214 |
outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
|
215 |
)
|
216 |
|
|
|
33 |
# 示例图片路径
|
34 |
example_images = [
|
35 |
"images/wired1.png",
|
36 |
+
"images/wired2.jpg",
|
37 |
"images/wired3.png",
|
38 |
"images/lineless1.png",
|
39 |
"images/wired4.jpg",
|
|
|
67 |
rec_model_dir=rec_model_path
|
68 |
)
|
69 |
|
70 |
+
def trans_char_ocr_res(ocr_res):
    """Flatten line-level OCR results into one entry per word.

    Each item of ``ocr_res`` is read positionally: index 2 is taken as the
    line's score, index 3 as its sequence of word boxes, and index 4 as the
    matching sequence of words.  Boxes and words are paired with ``zip``,
    so if the two sequences differ in length the unpaired tail is dropped.

    Args:
        ocr_res: Iterable of indexable per-line results.

    Returns:
        A list of ``[word_box, word, score]`` lists in input order; every
        word carries the score of the line it came from.
    """
    word_result = []
    for res in ocr_res:
        # The score is per line, so read it once and reuse it for each word.
        score = res[2]
        word_result.extend(
            [word_box, word, score] for word_box, word in zip(res[3], res[4])
        )
    return word_result
|
81 |
|
82 |
def select_ocr_model(det_model, rec_model):
    """Return the OCR engine stored for a detection/recognition model pair.

    The lookup key is the two model names joined by a single underscore,
    and the engine is read from the module-level ``ocr_engine_dict``.
    """
    engine_key = f"{det_model}_{rec_model}"
    return ocr_engine_dict[engine_key]
|
|
|
105 |
return lineless_table_engine, "lineless_table"
|
106 |
|
107 |
|
108 |
+
def process_image(img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotated_fix, col_threshold, row_threshold):
|
109 |
+
det_model="mobile_det"
|
110 |
+
rec_model="mobile_rec"
|
111 |
+
img = img_loader(img_input)
|
112 |
start = time.time()
|
113 |
table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
|
114 |
ocr_engine = select_ocr_model(det_model, rec_model)
|
|
|
121 |
ocr_boxes = result[0]['res']['boxes']
|
122 |
all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
|
123 |
else:
|
124 |
+
ocr_res, ocr_infer_elapse = ocr_engine(img, return_word_box=char_ocr)
|
125 |
det_cost, cls_cost, rec_cost = ocr_infer_elapse
|
126 |
+
if char_ocr:
|
127 |
+
ocr_res = trans_char_ocr_res(ocr_res)
|
128 |
ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
|
129 |
if isinstance(table_engine, RapidTable):
|
130 |
html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
|
131 |
polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
|
132 |
elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
|
133 |
+
html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(img, ocr_result=ocr_res,
|
134 |
+
enhance_box_line=small_box_cut_enhance,
|
135 |
+
rotated_fix=rotated_fix,
|
136 |
+
col_threshold=col_threshold,
|
137 |
+
row_threshold=row_threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
sum_elapse = time.time() - start
|
139 |
all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
|
140 |
|
|
|
200 |
label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
|
201 |
value=True
|
202 |
)
|
203 |
+
char_ocr = gr.Checkbox(
|
204 |
+
label="char rec ocr",
|
205 |
+
value=False
|
206 |
+
)
|
207 |
+
rotate_adapt = gr.Checkbox(
|
208 |
+
label="Table Rotate Rec Enhancement",
|
209 |
+
value=False
|
210 |
+
)
|
211 |
+
col_threshold = gr.Slider(
|
212 |
+
label="col threshold(determine same col)",
|
213 |
+
minimum=5,
|
214 |
+
maximum=100,
|
215 |
+
value=15,
|
216 |
+
step=5
|
217 |
+
)
|
218 |
+
row_threshold = gr.Slider(
|
219 |
+
label="row threshold(determine same row)",
|
220 |
+
minimum=5,
|
221 |
+
maximum=100,
|
222 |
+
value=10,
|
223 |
+
step=5
|
224 |
+
)
|
225 |
+
|
226 |
+
# det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
|
227 |
+
# value=det_models_labels[0])
|
228 |
+
# rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
|
229 |
+
# value=rec_models_labels[0])
|
230 |
|
231 |
run_button = gr.Button("Run")
|
232 |
gr.Markdown("# Elapsed Time")
|
|
|
242 |
|
243 |
run_button.click(
|
244 |
fn=process_image,
|
245 |
+
inputs=[img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotate_adapt, col_threshold, row_threshold],
|
246 |
outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
|
247 |
)
|
248 |
|