Spaces:
Running
Running
File size: 11,525 Bytes
3b5e13f 51d474a 3b5e13f 8570e66 fcd5742 3b5e13f 9e87980 3b5e13f 9e87980 3b5e13f 8570e66 51d474a 3b5e13f e0db3b9 3b5e13f e0db3b9 2a39c5a fcd5742 3b5e13f 51d474a 3b5e13f 9e1fe47 3b5e13f 8570e66 9e87980 8570e66 51d474a 3b5e13f 9e87980 3b5e13f 9e87980 3b5e13f 9e87980 3b5e13f 9e87980 3b5e13f 9e1fe47 3b5e13f 9e87980 3b5e13f e0db3b9 51d474a 3b5e13f e0db3b9 9e1fe47 e0db3b9 3b5e13f c887e15 fcd5742 c887e15 fcd5742 e0db3b9 fcd5742 e0db3b9 fcd5742 3b5e13f c887e15 fcd5742 c887e15 fcd5742 c887e15 fcd5742 c887e15 fcd5742 9e1fe47 c887e15 9e87980 3b5e13f 9e1fe47 3b5e13f fcd5742 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
import time
import cv2
import gradio as gr
from lineless_table_rec import LinelessTableRecognition
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
from rapidocr_onnxruntime import RapidOCR
from table_cls import TableCls
from wired_table_rec import WiredTableRecognition
from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1
# Shared image loader instance; process_image() calls it on the raw Gradio
# image input before running OCR and table recognition.
img_loader = LoadImage()
# Path to an SLANet table-recognition ONNX model.
# NOTE(review): not referenced anywhere else in the visible code.
table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"

# OCR model registries: short label -> ONNX model path.
det_model_dir = {"mobile_det": "models/ocr/ch_PP-OCRv4_det_infer.onnx"}
rec_model_dir = {"mobile_rec": "models/ocr/ch_PP-OCRv4_rec_infer.onnx"}

# Engine names offered in the UI dropdown; the first entry is the default.
table_engine_list = [
    "auto",
    "RapidTable(SLANet)",
    "RapidTable(SLANet-plus)",
    "RapidTable(unitable)",
    "wired_table_v2",
    "wired_table_v1",
    "lineless_table",
]

# Example image paths shown in the UI's example picker.
example_images = [
    "images/wired1.jpg",
    "images/wired2.png",
    "images/wired3.png",
    "images/lineless1.jpg",
    "images/wired4.jpg",
    "images/lineless2.png",
    "images/wired5.jpg",
    "images/lineless4.jpg",
    "images/wired7.jpg",
    "images/wired9.jpg",
]
# Table-structure engines, instantiated once at import time and looked up by
# select_table_model().
rapid_table_engine = RapidTable(RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH.value))
SLANet_plus_table_Engine = RapidTable(RapidTableInput(model_type=ModelType.SLANETPLUS.value))
unitable_table_Engine = RapidTable(RapidTableInput(model_type=ModelType.UNITABLE.value))
wired_table_engine_v1 = WiredTableRecognition(version="v1")
wired_table_engine_v2 = WiredTableRecognition(version="v2")
lineless_table_engine = LinelessTableRecognition()

# Classifier consulted by the "auto" engine choice.
table_cls = TableCls()

# One RapidOCR engine per (detector, recognizer) pair, keyed "<det>_<rec>".
ocr_engine_dict = {}
# NOTE(review): never populated in the visible code.
pp_engine_dict = {}
for det_model, det_model_path in det_model_dir.items():
    for rec_model, rec_model_path in rec_model_dir.items():
        key = f"{det_model}_{rec_model}"
        ocr_engine_dict[key] = RapidOCR(
            det_model_path=det_model_path,
            rec_model_path=rec_model_path,
        )
def trans_char_ocr_res(ocr_res):
    """Flatten OCR results into one ``[box, text, score]`` entry per word.

    Each item of ``ocr_res`` is read positionally: index 2 is the score,
    index 3 the word boxes and index 4 the words. Every word entry carries
    the score of the item it came from. Boxes and words are paired with
    ``zip``, so any unmatched extras on the longer side are dropped.

    Returns:
        A new list of ``[word_box, word, score]`` lists, in input order.
    """
    flattened = []
    for item in ocr_res:
        item_score = item[2]
        flattened.extend(
            [box, text, item_score] for box, text in zip(item[3], item[4])
        )
    return flattened
def select_ocr_model(det_model, rec_model):
    """Return the preloaded OCR engine for a detector/recognizer label pair.

    The engine is looked up in ``ocr_engine_dict`` under the key
    ``"<det_model>_<rec_model>"``; a missing key raises ``KeyError``.
    """
    engine_key = f"{det_model}_{rec_model}"
    return ocr_engine_dict[engine_key]
def select_table_model(img, table_engine_type, det_model, rec_model):
    """Pick the table-recognition engine for one request.

    Args:
        img: Image handed to the table classifier when ``table_engine_type``
            is ``"auto"``; ignored otherwise.
        table_engine_type: Engine name, one of the entries of
            ``table_engine_list``.
        det_model: Unused; kept so existing callers keep working.
        rec_model: Unused; kept so existing callers keep working.

    Returns:
        ``(engine, resolved_name)``. For an explicit choice the requested
        name is echoed back; for ``"auto"`` the name is ``"wired_table_v2"``
        when the classifier answers ``'wired'`` and ``"lineless_table"``
        otherwise.

    Raises:
        ValueError: If ``table_engine_type`` is not a supported engine name.
    """
    if table_engine_type == "auto":
        # The classifier's second return value is not needed here.
        table_kind, _ = table_cls(img)
        if table_kind == 'wired':
            return wired_table_engine_v2, "wired_table_v2"
        return lineless_table_engine, "lineless_table"

    explicit_engines = {
        "RapidTable(SLANet)": rapid_table_engine,
        "RapidTable(SLANet-plus)": SLANet_plus_table_Engine,
        "RapidTable(unitable)": unitable_table_Engine,
        "wired_table_v1": wired_table_engine_v1,
        "wired_table_v2": wired_table_engine_v2,
        "lineless_table": lineless_table_engine,
    }
    try:
        engine = explicit_engines[table_engine_type]
    except KeyError:
        # Fail loudly instead of handing the caller an unusable result that
        # only blows up later when it is unpacked.
        raise ValueError(
            f"Unsupported table engine type: {table_engine_type!r}"
        ) from None

    if table_engine_type == "wired_table_v2":
        print("使用v2 wired table")
    return engine, table_engine_type
def process_image(img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotated_fix, col_threshold, row_threshold):
    """Run OCR and table recognition on one image for the Gradio UI.

    Args:
        img_input: Image value from the Gradio image component; converted
            with ``img_loader``.
        small_box_cut_enhance: Forwarded as ``enhance_box_line`` to the
            wired/lineless engines.
        table_engine_type: Engine name understood by ``select_table_model``.
        char_ocr: Forwarded to the OCR engine as ``return_word_box``; when
            true the OCR result is flattened with ``trans_char_ocr_res``.
        rotated_fix: Forwarded to the wired/lineless engines.
        col_threshold: Forwarded to the wired/lineless engines.
        row_threshold: Forwarded to the wired/lineless engines.

    Returns:
        ``(complete_html, table_boxes_img, ocr_boxes_img, all_elapse)``:
        the formatted HTML table, the image with table-cell boxes drawn, the
        image with OCR boxes drawn, and a text summary of timings.

    Raises:
        gr.Error: If OCR produces no result for the image.
        TypeError: If the selected engine is not one of the supported engine
            classes.
    """
    det_model = "mobile_det"
    rec_model = "mobile_rec"
    img = img_loader(img_input)
    start = time.time()
    table_engine, table_type = select_table_model(img, table_engine_type, det_model, rec_model)
    ocr_engine = select_ocr_model(det_model, rec_model)
    ocr_res, ocr_infer_elapse = ocr_engine(img, return_word_box=char_ocr)
    # NOTE(review): RapidOCR is expected to report "no text found" as an empty
    # or None result (with no timings) - confirm against the installed version.
    # Stop with a readable message instead of failing on the unpack below.
    if ocr_res is None or len(ocr_res) == 0:
        raise gr.Error("No text was detected in the image, so the table cannot be recognized.")
    det_cost, cls_cost, rec_cost = ocr_infer_elapse
    if char_ocr:
        ocr_res = trans_char_ocr_res(ocr_res)
    ocr_boxes = [box_4_2_poly_to_box_4_1(ocr_item[0]) for ocr_item in ocr_res]

    if isinstance(table_engine, RapidTable):
        table_results = table_engine(img, ocr_res)
        html = table_results.pred_html
        table_rec_elapse = table_results.elapse
        # Keep values 0, 1, 4 and 5 of every cell polygon - presumably the
        # top-left and bottom-right corners of an 8-value polygon; TODO confirm.
        polygons = [
            [polygon[0], polygon[1], polygon[4], polygon[5]]
            for polygon in table_results.cell_bboxes
        ]
    elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
        # The last two returned values (logic points, OCR result) are unused.
        html, table_rec_elapse, polygons, _, _ = table_engine(
            img,
            ocr_result=ocr_res,
            enhance_box_line=small_box_cut_enhance,
            rotated_fix=rotated_fix,
            col_threshold=col_threshold,
            row_threshold=row_threshold,
        )
    else:
        raise TypeError(f"Unsupported table engine: {type(table_engine).__name__}")

    sum_elapse = time.time() - start
    all_elapse = f"- table_type: {table_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    table_boxes_img = plot_rec_box(img.copy(), polygons)
    ocr_boxes_img = plot_rec_box(img.copy(), ocr_boxes)
    complete_html = format_html(html)
    return complete_html, table_boxes_img, ocr_boxes_img, all_elapse
def main():
# Build the Gradio Blocks demo (option widgets, result panes, Run button wired
# to process_image) and launch it.
# NOTE(review): this file's indentation has been lost, so the nesting of the
# `with` layout contexts below cannot be read off the text; the code is left
# untouched and only comments were added or translated.
# The two label lists are only referenced by the commented-out OCR model
# dropdowns further down.
det_models_labels = list(det_model_dir.keys())
rec_models_labels = list(rec_model_dir.keys())
# Custom CSS: horizontal scrolling for the HTML result and centred header badges.
with gr.Blocks(css="""
.scrollable-container {
overflow-x: auto;
white-space: nowrap;
}
.header-links {
text-align: center;
}
.header-links a {
display: inline-block;
text-align: center;
margin-right: 10px; /* 调整间距 */
}
""") as demo:
# Page title linking to the two upstream projects.
gr.HTML(
"<h1 style='text-align: center;'><a href='https://github.com/RapidAI/TableStructureRec?tab=readme-ov-file'>TableStructureRec</a> & <a href='https://github.com/RapidAI/RapidTable'>RapidTable</a></h1>"
)
# Row of project badges.
gr.HTML('''
<div class="header-links">
<a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
<a href="https://pypi.org/project/lineless-table-rec/"><img alt="PyPI" src="https://img.shields.io/pypi/v/lineless-table-rec"></a>
<a href="https://pepy.tech/project/lineless-table-rec"><img src="https://static.pepy.tech/personalized-badge/lineless-table-rec?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Lineless"></a>
<a href="https://pepy.tech/project/wired-table-rec"><img src="https://static.pepy.tech/personalized-badge/wired-table-rec?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Wired"></a>
<a href="https://pepy.tech/project/rapid-table"><img src="https://static.pepy.tech/personalized-badge/rapid-table?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20RapidTable"></a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
<a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
</div>
''')
with gr.Row(): # two-column layout
with gr.Tab("Options"):
with gr.Column(variant="panel", scale=1): # sidebar, width ratio 1
img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/lineless3.jpg")
# Example image picker
examples = gr.Examples(
examples=example_images,
examples_per_page=len(example_images),
inputs=img_input,
fn=lambda x: x, # simply returns the image path
outputs=img_input,
cache_examples=False
)
# Engine dropdown; defaults to the first entry of table_engine_list.
table_engine_type = gr.Dropdown(table_engine_list, label="Select Recognition Table Engine",
value=table_engine_list[0])
small_box_cut_enhance = gr.Checkbox(
label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
value=True
)
char_ocr = gr.Checkbox(
label="char rec ocr",
value=False
)
# Passed to process_image as its `rotated_fix` argument.
rotate_adapt = gr.Checkbox(
label="Table Rotate Rec Enhancement",
value=False
)
col_threshold = gr.Slider(
label="col threshold(determine same col)",
minimum=5,
maximum=100,
value=15,
step=5
)
row_threshold = gr.Slider(
label="row threshold(determine same row)",
minimum=5,
maximum=100,
value=10,
step=5
)
# Disabled OCR model selectors; process_image currently hard-codes
# "mobile_det" / "mobile_rec".
# det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
# value=det_models_labels[0])
# rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
# value=rec_models_labels[0])
run_button = gr.Button("Run")
gr.Markdown("# Elapsed Time")
elapse_text = gr.Text(label="") # use a `gr.Text` component to display the string
with gr.Column(scale=2): # right column
# Use Markdown headings to separate the components
gr.Markdown("# Html Render")
html_output = gr.HTML(label="", elem_classes="scrollable-container")
gr.Markdown("# Table Boxes")
table_boxes_output = gr.Image(label="")
gr.Markdown("# OCR Boxes")
ocr_boxes_output = gr.Image(label="")
# The input order must match process_image's positional parameters, and the
# outputs match the order of its four return values.
run_button.click(
fn=process_image,
inputs=[img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotate_adapt, col_threshold, row_threshold],
outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
)
demo.launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|