Joker1212 commited on
Commit
51d474a
·
verified ·
1 Parent(s): d6a3490

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -3,7 +3,8 @@ import time
3
  import cv2
4
  import gradio as gr
5
  from lineless_table_rec import LinelessTableRecognition
6
- from rapid_table import RapidTable
 
7
  from rapidocr_onnxruntime import RapidOCR
8
  from table_cls import TableCls
9
  from wired_table_rec import WiredTableRecognition
@@ -23,6 +24,7 @@ table_engine_list = [
23
  "auto",
24
  "RapidTable(SLANet)",
25
  "RapidTable(SLANet-plus)",
 
26
  "wired_table_v2",
27
  "wired_table_v1",
28
  "lineless_table"
@@ -41,8 +43,9 @@ example_images = [
41
  "images/wired7.jpg",
42
  "images/wired9.jpg",
43
  ]
44
- rapid_table_engine = RapidTable(model_path=table_rec_path)
45
- SLANet_plus_table_Engine = RapidTable()
 
46
  wired_table_engine_v1 = WiredTableRecognition(version="v1")
47
  wired_table_engine_v2 = WiredTableRecognition(version="v2")
48
  lineless_table_engine = LinelessTableRecognition()
@@ -77,6 +80,8 @@ def select_table_model(img, table_engine_type, det_model, rec_model):
77
  return rapid_table_engine, table_engine_type
78
  elif table_engine_type == "RapidTable(SLANet-plus)":
79
  return SLANet_plus_table_Engine, table_engine_type
 
 
80
  elif table_engine_type == "wired_table_v1":
81
  return wired_table_engine_v1, table_engine_type
82
  elif table_engine_type == "wired_table_v2":
@@ -106,7 +111,8 @@ def process_image(img_input, small_box_cut_enhance, table_engine_type, char_ocr,
106
  ocr_res = trans_char_ocr_res(ocr_res)
107
  ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
108
  if isinstance(table_engine, RapidTable):
109
- html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
 
110
  polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
111
  elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
112
  html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(img, ocr_result=ocr_res,
 
3
  import cv2
4
  import gradio as gr
5
  from lineless_table_rec import LinelessTableRecognition
6
+ from rapid_table import RapidTable, RapidTableInput
7
+ from rapid_table.main import ModelType
8
  from rapidocr_onnxruntime import RapidOCR
9
  from table_cls import TableCls
10
  from wired_table_rec import WiredTableRecognition
 
24
  "auto",
25
  "RapidTable(SLANet)",
26
  "RapidTable(SLANet-plus)",
27
+ "RapidTable(unitable)",
28
  "wired_table_v2",
29
  "wired_table_v1",
30
  "lineless_table"
 
43
  "images/wired7.jpg",
44
  "images/wired9.jpg",
45
  ]
46
+ rapid_table_engine = RapidTable(RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH.value))
47
+ SLANet_plus_table_Engine = RapidTable(RapidTableInput(model_type=ModelType.SLANETPLUS.value))
48
+ unitable_table_Engine = RapidTable(RapidTableInput(model_type=ModelType.UNITABLE.value))
49
  wired_table_engine_v1 = WiredTableRecognition(version="v1")
50
  wired_table_engine_v2 = WiredTableRecognition(version="v2")
51
  lineless_table_engine = LinelessTableRecognition()
 
80
  return rapid_table_engine, table_engine_type
81
  elif table_engine_type == "RapidTable(SLANet-plus)":
82
  return SLANet_plus_table_Engine, table_engine_type
83
+ elif table_engine_type == "RapidTable(unitable)":
84
+ return unitable_table_Engine, table_engine_type
85
  elif table_engine_type == "wired_table_v1":
86
  return wired_table_engine_v1, table_engine_type
87
  elif table_engine_type == "wired_table_v2":
 
111
  ocr_res = trans_char_ocr_res(ocr_res)
112
  ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
113
  if isinstance(table_engine, RapidTable):
114
+ table_results = table_engine(img, ocr_res)
115
+ html, polygons, table_rec_elapse = table_results.pred_html, table_results.cell_bboxes,table_results.elapse
116
  polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
117
  elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
118
  html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(img, ocr_result=ocr_res,