Joker1212 commited on
Commit
9e1fe47
·
verified ·
1 Parent(s): a67cd51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -20
app.py CHANGED
@@ -33,7 +33,7 @@ table_engine_list = [
33
  # 示例图片路径
34
  example_images = [
35
  "images/wired1.png",
36
- "images/wired2.png",
37
  "images/wired3.png",
38
  "images/lineless1.png",
39
  "images/wired4.jpg",
@@ -67,6 +67,17 @@ for det_model in det_model_dir.keys():
67
  rec_model_dir=rec_model_path
68
  )
69
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def select_ocr_model(det_model, rec_model):
72
  return ocr_engine_dict[f"{det_model}_{rec_model}"]
@@ -94,8 +105,10 @@ def select_table_model(img, table_engine_type, det_model, rec_model):
94
  return lineless_table_engine, "lineless_table"
95
 
96
 
97
- def process_image(img, table_engine_type, det_model, rec_model, small_box_cut_enhance):
98
- img = img_loader(img)
 
 
99
  start = time.time()
100
  table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
101
  ocr_engine = select_ocr_model(det_model, rec_model)
@@ -108,24 +121,20 @@ def process_image(img, table_engine_type, det_model, rec_model, small_box_cut_en
108
  ocr_boxes = result[0]['res']['boxes']
109
  all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
110
  else:
111
- ocr_res, ocr_infer_elapse = ocr_engine(img)
112
  det_cost, cls_cost, rec_cost = ocr_infer_elapse
 
 
113
  ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
114
  if isinstance(table_engine, RapidTable):
115
  html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
116
  polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
117
  elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
118
- html, table_rec_elapse, polygons, _, _ = table_engine(img, ocr_result=ocr_res)
119
- if not small_box_cut_enhance:
120
- html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(
121
- img, ocr_result=ocr_res,
122
- morph_close=False, more_h_lines=False, more_v_lines=False, extend_line=False
123
- )
124
- else:
125
- html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(
126
- img, ocr_result=ocr_res
127
- )
128
-
129
  sum_elapse = time.time() - start
130
  all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
131
 
@@ -191,10 +200,33 @@ def main():
191
  label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
192
  value=True
193
  )
194
- det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
195
- value=det_models_labels[0])
196
- rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
197
- value=rec_models_labels[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  run_button = gr.Button("Run")
200
  gr.Markdown("# Elapsed Time")
@@ -210,7 +242,7 @@ def main():
210
 
211
  run_button.click(
212
  fn=process_image,
213
- inputs=[img_input, table_engine_type, det_model, rec_model, small_box_cut_enhance],
214
  outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
215
  )
216
 
 
33
  # 示例图片路径
34
  example_images = [
35
  "images/wired1.png",
36
+ "images/wired2.jpg",
37
  "images/wired3.png",
38
  "images/lineless1.png",
39
  "images/wired4.jpg",
 
67
  rec_model_dir=rec_model_path
68
  )
69
 
70
+ def trans_char_ocr_res(ocr_res):
71
+ word_result = []
72
+ for res in ocr_res:
73
+ score = res[2]
74
+ for word_box, word in zip(res[3], res[4]):
75
+ word_res = []
76
+ word_res.append(word_box)
77
+ word_res.append(word)
78
+ word_res.append(score)
79
+ word_result.append(word_res)
80
+ return word_result
81
 
82
  def select_ocr_model(det_model, rec_model):
83
  return ocr_engine_dict[f"{det_model}_{rec_model}"]
 
105
  return lineless_table_engine, "lineless_table"
106
 
107
 
108
+ def process_image(img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotated_fix, col_threshold, row_threshold):
109
+ det_model="mobile_det"
110
+ rec_model="mobile_rec"
111
+ img = img_loader(img_input)
112
  start = time.time()
113
  table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
114
  ocr_engine = select_ocr_model(det_model, rec_model)
 
121
  ocr_boxes = result[0]['res']['boxes']
122
  all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
123
  else:
124
+ ocr_res, ocr_infer_elapse = ocr_engine(img, return_word_box=char_ocr)
125
  det_cost, cls_cost, rec_cost = ocr_infer_elapse
126
+ if char_ocr:
127
+ ocr_res = trans_char_ocr_res(ocr_res)
128
  ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
129
  if isinstance(table_engine, RapidTable):
130
  html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
131
  polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
132
  elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
133
+ html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(img, ocr_result=ocr_res,
134
+ enhance_box_line=small_box_cut_enhance,
135
+ rotated_fix=rotated_fix,
136
+ col_threshold=col_threshold,
137
+ row_threshold=row_threshold)
 
 
 
 
 
 
138
  sum_elapse = time.time() - start
139
  all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
140
 
 
200
  label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
201
  value=True
202
  )
203
+ char_ocr = gr.Checkbox(
204
+ label="char rec ocr",
205
+ value=False
206
+ )
207
+ rotate_adapt = gr.Checkbox(
208
+ label="Table Rotate Rec Enhancement",
209
+ value=False
210
+ )
211
+ col_threshold = gr.Slider(
212
+ label="col threshold(determine same col)",
213
+ minimum=5,
214
+ maximum=100,
215
+ value=15,
216
+ step=5
217
+ )
218
+ row_threshold = gr.Slider(
219
+ label="row threshold(determine same row)",
220
+ minimum=5,
221
+ maximum=100,
222
+ value=10,
223
+ step=5
224
+ )
225
+
226
+ # det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
227
+ # value=det_models_labels[0])
228
+ # rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
229
+ # value=rec_models_labels[0])
230
 
231
  run_button = gr.Button("Run")
232
  gr.Markdown("# Elapsed Time")
 
242
 
243
  run_button.click(
244
  fn=process_image,
245
+ inputs=[img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotate_adapt, col_threshold, row_threshold],
246
  outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
247
  )
248