Joker1212 commited on
Commit
8570e66
·
verified ·
1 Parent(s): a27d639

update app

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import time
2
 
3
  import cv2
@@ -6,10 +7,11 @@ from lineless_table_rec import LinelessTableRecognition
6
  from paddleocr import PPStructure
7
  from rapid_table import RapidTable
8
  from rapidocr_onnxruntime import RapidOCR
 
9
  from table_cls import TableCls
10
  from wired_table_rec import WiredTableRecognition
11
- from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1
12
 
 
13
  img_loader = LoadImage()
14
  table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"
15
  det_model_dir = {
@@ -21,7 +23,8 @@ rec_model_dir = {
21
  }
22
  table_engine_list = [
23
  "auto",
24
- "rapid_table",
 
25
  "wired_table_v2",
26
  "pp_table",
27
  "wired_table_v1",
@@ -41,6 +44,7 @@ example_images = [
41
  "images/wired6.jpg",
42
  ]
43
  rapid_table_engine = RapidTable(model_path=table_rec_path)
 
44
  wired_table_engine_v1 = WiredTableRecognition(version="v1")
45
  wired_table_engine_v2 = WiredTableRecognition(version="v2")
46
  lineless_table_engine = LinelessTableRecognition()
@@ -69,8 +73,10 @@ def select_ocr_model(det_model, rec_model):
69
 
70
 
71
  def select_table_model(img, table_engine_type, det_model, rec_model):
72
- if table_engine_type == "rapid_table":
73
  return rapid_table_engine, table_engine_type
 
 
74
  elif table_engine_type == "wired_table_v1":
75
  return wired_table_engine_v1, table_engine_type
76
  elif table_engine_type == "wired_table_v2":
@@ -105,7 +111,6 @@ def process_image(img, table_engine_type, det_model, rec_model):
105
  ocr_res, ocr_infer_elapse = ocr_engine(img)
106
  det_cost, cls_cost, rec_cost = ocr_infer_elapse
107
  ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
108
-
109
  if isinstance(table_engine, RapidTable):
110
  html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
111
  polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
@@ -176,4 +181,4 @@ def main():
176
 
177
 
178
  if __name__ == '__main__':
179
- main()
 
1
+ import threading
2
  import time
3
 
4
  import cv2
 
7
  from paddleocr import PPStructure
8
  from rapid_table import RapidTable
9
  from rapidocr_onnxruntime import RapidOCR
10
+ from slanet_plus_table import SLANetPlus
11
  from table_cls import TableCls
12
  from wired_table_rec import WiredTableRecognition
 
13
 
14
+ from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1
15
  img_loader = LoadImage()
16
  table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"
17
  det_model_dir = {
 
23
  }
24
  table_engine_list = [
25
  "auto",
26
+ "RapidTable(SLANet)",
27
+ "RapidTable(SLANet-plus)",
28
  "wired_table_v2",
29
  "pp_table",
30
  "wired_table_v1",
 
44
  "images/wired6.jpg",
45
  ]
46
  rapid_table_engine = RapidTable(model_path=table_rec_path)
47
+ SLANet_plus_table_Engine = RapidTable()
48
  wired_table_engine_v1 = WiredTableRecognition(version="v1")
49
  wired_table_engine_v2 = WiredTableRecognition(version="v2")
50
  lineless_table_engine = LinelessTableRecognition()
 
73
 
74
 
75
  def select_table_model(img, table_engine_type, det_model, rec_model):
76
+ if table_engine_type == "RapidTable(SLANet)":
77
  return rapid_table_engine, table_engine_type
78
+ elif table_engine_type == "RapidTable(SLANet-plus)":
79
+ return SLANet_plus_table_Engine, table_engine_type
80
  elif table_engine_type == "wired_table_v1":
81
  return wired_table_engine_v1, table_engine_type
82
  elif table_engine_type == "wired_table_v2":
 
111
  ocr_res, ocr_infer_elapse = ocr_engine(img)
112
  det_cost, cls_cost, rec_cost = ocr_infer_elapse
113
  ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
 
114
  if isinstance(table_engine, RapidTable):
115
  html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
116
  polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
 
181
 
182
 
183
  if __name__ == '__main__':
184
+ main()