# RapidUnwrap / app.py
# Author: Joker1212
# Last change: d82549a — "fix: add binarize_model"
import gradio as gr
from rapid_undistorted.inference import InferenceEngine
import numpy as np
# Initialise the model: build the inference engine once, at import time, from
# config.yaml. process_image() below calls this shared instance for every run.
engine = InferenceEngine("config.yaml")
# Example images offered by the example picker in the UI (relative paths).
example_images = [
    f"images/{file_name}"
    for file_name in ("demo.jpg", "demo1.jpg", "demo1.png", "demo2.png", "demo3.jpg")
]

# Model choices per task, used as dropdown options. A None entry lets the user
# skip that task: process_image only schedules tasks whose selection is truthy.
tasks = dict(
    unwrap=["UVDoc", None],
    unshadow=["GCDnet", None],
    unblur=["OpenCvBilateral", "NAFDPM", None],
    binarize=["UnetCnn", None],
)
def process_image(img_path, unwrap_model, unshadow_model, binarize_model, unblur_model):
    """Run the selected document-restoration steps on one image.

    Each ``*_model`` argument is the model name picked in the matching
    dropdown. A falsy selection (``None`` or an empty string) skips that
    step. The enabled steps are handed to ``engine`` in the fixed order
    unwrap -> unshadow -> binarize -> unblur.

    Args:
        img_path: The value of the ``gr.Image`` input. NOTE(review): despite
            the name, this is presumably a numpy array, because the input
            component does not set ``type=`` and Gradio's default is
            "numpy". Confirm this against the engine's expected input.
        unwrap_model: Model for the "unwrap" task, or a falsy value to skip it.
        unshadow_model: Model for the "unshadow" task, or a falsy value to skip it.
        binarize_model: Model for the "binarize" task, or a falsy value to skip it.
        unblur_model: Model for the "unblur" task, or a falsy value to skip it.

    Returns:
        A tuple ``(image, elapse)``. ``image`` is the engine's output cast
        to ``np.uint8``; ``elapse`` is the timing value reported by the engine.

    Raises:
        gr.Error: If no image was provided.
    """
    if img_path is None:
        # Without this guard, None would be passed to the engine, which
        # presumably cannot handle it. gr.Error instead shows a readable
        # message in the Gradio UI.
        raise gr.Error("Please upload or select an image first.")

    # NOTE(review): binarize is scheduled before unblur; confirm this order
    # is intended.
    selections = (
        ("unwrap", unwrap_model),
        ("unshadow", unshadow_model),
        ("binarize", binarize_model),
        ("unblur", unblur_model),
    )
    task_list = [(task, model) for task, model in selections if model]

    unwrapped_img, elapse = engine(img_path, task_list)
    print(f"doc unwrap elapse: {elapse}")
    return unwrapped_img.astype(np.uint8), elapse
def main():
    """Build the RapidUnDistort Gradio demo and start serving it.

    Layout:
        * Left column (scale 1): the image input, an example picker, one
          model dropdown per task, and a submit button.
        * Right column (scale 2): the processed image and the elapsed time
          returned by process_image.
    """
    # Define the Gradio interface.
    with gr.Blocks(css="""
.scrollable-container {
overflow-x: auto;
white-space: nowrap;
}
.header-links {
text-align: center;
}
.header-links a {
display: inline-block;
text-align: center;
margin-right: 10px; /* 调整间距 */
}
""") as demo:
        gr.HTML(
            "<h1 style='text-align: center;'><a href='https://github.com/Joker1212/RapidUnWrap'>RapidUnDistort</a></h1>"
        )
        # NOTE(review): the license badge below links to the LICENSE of
        # RapidAI/TableStructureRec, not this project's repository.
        # Confirm that this is intended.
        gr.HTML('''
<div class="header-links">
<a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.13-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
<a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
</div>
''')
        with gr.Row():
            with gr.Column(scale=1):  # left column: 1/3 of the width
                img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/demo1.jpg")
                # Example picker. Clicking an example only loads it into
                # img_input. No fn/outputs are given because, with
                # cache_examples=False and no run_on_click, they would never run.
                gr.Examples(
                    examples=example_images,
                    examples_per_page=len(example_images),
                    inputs=img_input,
                    cache_examples=False,
                )
                unwrap_dropdown = gr.Dropdown(label="Select Unwrap Model", choices=tasks["unwrap"], value="UVDoc")
                unshadow_dropdown = gr.Dropdown(label="Select Unshadow Model", choices=tasks["unshadow"], value="GCDnet")
                binarize_dropdown = gr.Dropdown(label="Select Binarize Model", choices=tasks["binarize"], value=None)
                unblur_dropdown = gr.Dropdown(label="Select Unblur Model", choices=tasks["unblur"], value="OpenCvBilateral")
                run_button = gr.Button("Submit")
            with gr.Column(scale=2):  # right column: 2/3 of the width
                output_image = gr.Image(label="output")
                elapse_textbox = gr.Textbox(label="Elapsed Time", interactive=False)
        # Bind the button click. Inputs are passed to process_image
        # positionally, so their order must match its signature
        # (image, unwrap, unshadow, binarize, unblur).
        run_button.click(
            fn=process_image,
            inputs=[img_input, unwrap_dropdown, unshadow_dropdown, binarize_dropdown, unblur_dropdown],
            outputs=[output_image, elapse_textbox],
        )
    # Launch the app once the Blocks context has been closed.
    demo.launch()
if __name__ == "__main__":
    # Start the demo only when run as a script, not when imported.
    main()