# htr_demo / tabs/htr_tool.py
# Last change: "Update tabs/htr_tool.py" (commit fcad386, verified), by Gabriel.
import os
import gradio as gr
from helper.examples.examples import DemoImages
from helper.utils import TrafficDataHandler
from src.htr_pipeline.gradio_backend import (
FastTrack,
SingletonModelLoader,
compare_diff_runs_highlight,
compute_cer_a_and_b_with_gt,
update_selected_tab_image_viewer,
update_selected_tab_model_compare,
update_selected_tab_output_and_setting,
upload_file,
)
# Backend objects created once, at import time. The UI below calls methods of
# `fast_track` and reads `images_for_demo.examples_list`.
model_loader = SingletonModelLoader()
fast_track = FastTrack(model_loader)  # wraps the loader; its methods are the event handlers below
images_for_demo = DemoImages()  # source of the gr.Examples rows
# NOTE(review): `terminate` is not read or written anywhere else in this file;
# stop_function sets `pipeline_inferencer.terminate` instead. Possibly dead --
# confirm no other module imports it before removing.
terminate = False
with gr.Blocks() as htr_tool_tab:
    # Layout: a narrow left column (input image plus the HTRFLOW / Visualize /
    # Compare tabs) and a wide right column holding three panels whose
    # visibility is updated by the tab `.select` listeners further down.
    # NOTE(review): container nesting was reconstructed from a copy whose
    # indentation was lost; confirm against the rendered UI.
    with gr.Row(equal_height=True):
        with gr.Column(scale=2):
            with gr.Row():
                fast_track_input_region_image = gr.Image(
                    label="Image to run HTR on", type="numpy", tool="editor", elem_id="image_upload", height=395
                )

            with gr.Row():
                with gr.Tab("HTRFLOW") as tab_output_and_setting_selector:
                    with gr.Row():
                        stop_htr_button = gr.Button(
                            value="Stop run",
                            variant="stop",
                        )

                        htr_pipeline_button = gr.Button(
                            "Run ",
                            variant="primary",
                            visible=True,
                            elem_id="run_pipeline_button",
                        )

                        # Constant fed to TrafficDataHandler.store_metric_data by the
                        # HUB_TOKEN-gated listener at the end of this block.
                        htr_pipeline_button_var = gr.State(value="htr_pipeline_button")

                        # Hidden (visible=False); its click listener below registers
                        # the "run_htr_pipeline" API endpoint.
                        htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)

                    fast_file_downlod = gr.File(
                        label="Download output file", visible=True, scale=1, height=100, elem_id="download_file"
                    )

                with gr.Tab("Visualize") as tab_image_viewer_selector:
                    with gr.Row():
                        gr.Markdown("")
                        run_image_visualizer_button = gr.Button(
                            value="Visualize results", variant="primary", interactive=True
                        )

                    selection_text_from_image_viewer = gr.Textbox(
                        interactive=False, label="Text Selector", info="Select a line on Image Viewer to return text"
                    )

                with gr.Tab("Compare") as tab_model_compare_selector:
                    with gr.Row():
                        diff_runs_button = gr.Button("Compare runs", variant="primary", visible=True)
                        calc_cer_button_fast = gr.Button("Calculate CER", variant="primary", visible=True)

                    with gr.Row():
                        cer_output_fast = gr.Textbox(
                            label="Character Error Rate:",
                            info="The percentage of characters that have been transcribed incorrectly",
                        )

        with gr.Column(scale=4):
            with gr.Box():
                # Panel 1 (visible initially): example images and run settings.
                with gr.Row(visible=True) as output_and_setting_tab:
                    with gr.Column(scale=2):
                        # Hidden target for the first field of each example row
                        # (presumably a file name -- confirm in DemoImages).
                        fast_name_files_placeholder = gr.Markdown(visible=False)
                        gr.Examples(
                            examples=images_for_demo.examples_list,
                            inputs=[fast_name_files_placeholder, fast_track_input_region_image],
                            label="Example images",
                            examples_per_page=5,
                        )
                        gr.Markdown(" ")

                    with gr.Column(scale=3):
                        with gr.Group():
                            gr.Markdown("   ⚙️ Settings ")
                            with gr.Row():
                                radio_file_input = gr.CheckboxGroup(
                                    choices=["Txt", "Page XML"],
                                    value=["Txt", "Page XML"],
                                    label="Output file extension",
                                    info="JSON and ALTO-XML will be added",
                                    scale=1,
                                )

                            # NOTE(review): the checkboxes, sliders and the region/line
                            # segmentation dropdowns below are not passed to any event
                            # listener in this file, so they do not affect a run yet.
                            with gr.Row():
                                gr.Checkbox(
                                    value=True,
                                    label="Binarize image",
                                    info="Binarize image to reduce background noise",
                                )
                                gr.Checkbox(
                                    value=True,
                                    label="Output prediction threshold",
                                    info="Output XML with prediction score",
                                )

                            with gr.Accordion("Advanced settings", open=False):
                                with gr.Group():
                                    with gr.Row():
                                        htr_tool_region_segment_model_dropdown = gr.Dropdown(
                                            choices=["Riksarkivet/rtmdet_region"],
                                            value="Riksarkivet/rtmdet_region",
                                            label="Region segmentation models",
                                            info="More models will be added",
                                        )

                                        gr.Slider(
                                            minimum=0.4,
                                            maximum=1,
                                            value=0.5,
                                            step=0.05,
                                            label="P-threshold",
                                            info="""Filter confidence score for a prediction score to be considered""",
                                        )

                                    with gr.Row():
                                        htr_tool_line_segment_model_dropdown = gr.Dropdown(
                                            choices=["Riksarkivet/rtmdet_lines"],
                                            value="Riksarkivet/rtmdet_lines",
                                            label="Line segmentation models",
                                            info="More models will be added",
                                        )

                                        gr.Slider(
                                            minimum=0.4,
                                            maximum=1,
                                            value=0.5,
                                            step=0.05,
                                            label="P-threshold",
                                            info="""Filter confidence score for a prediction score to be considered""",
                                        )

                                    with gr.Row():
                                        htr_tool_transcriber_model_dropdown = gr.Dropdown(
                                            choices=[
                                                "Riksarkivet/trocr-base-handwritten-swe",
                                                "Riksarkivet/satrn_htr",
                                                "microsoft/trocr-base-handwritten",
                                                "pstroe/bullinger-general-model",
                                            ],
                                            value="Riksarkivet/trocr-base-handwritten-swe",
                                            label="Text recognition models",
                                            info="More models will be added",
                                        )

                                        gr.Slider(
                                            value=0.6,
                                            minimum=0.5,
                                            maximum=1,
                                            label="HTR threshold",
                                            info="Prediction score threshold for transcribed lines",
                                            scale=1,
                                        )

                                    with gr.Row():
                                        gr.Markdown("   More settings will be added")

                # Panel 2 (hidden initially): rendered result image.
                with gr.Row(visible=False) as image_viewer_tab:
                    # Per-session store written by visualize_image_viewer and read by
                    # get_text_from_coords. gr.State replaces the deprecated gr.Variable alias.
                    text_polygon_dict = gr.State()
                    fast_track_output_image = gr.Image(
                        label="Image Viewer", type="numpy", height=600, interactive=False
                    )

                # Panel 3 (hidden initially): compare uploaded Page XML runs.
                with gr.Column(visible=False) as model_compare_selector:
                    with gr.Row():
                        gr.Markdown("Compare different runs (Page XML output) with Ground Truth (GT)")

                    with gr.Row():
                        with gr.Group():
                            upload_button_run_a = gr.UploadButton("A", file_types=[".xml"], file_count="single")
                            file_input_xml_run_a = gr.File(
                                label=None,
                                file_count="single",
                                height=100,
                                elem_id="download_file",
                                interactive=False,
                                visible=False,
                            )

                        with gr.Group():
                            upload_button_run_b = gr.UploadButton("B", file_types=[".xml"], file_count="single")
                            file_input_xml_run_b = gr.File(
                                label=None,
                                file_count="single",
                                height=100,
                                elem_id="download_file",
                                interactive=False,
                                visible=False,
                            )

                        with gr.Group():
                            upload_button_run_gt = gr.UploadButton("GT", file_types=[".xml"], file_count="single")
                            file_input_xml_run_gt = gr.File(
                                label=None,
                                file_count="single",
                                height=100,
                                elem_id="download_file",
                                interactive=False,
                                visible=False,
                            )

                    with gr.Tab("Comparing run A with B"):
                        text_diff_runs = gr.HighlightedText(
                            label="A with B",
                            combine_adjacent=True,
                            show_legend=True,
                            color_map={"+": "red", "-": "green"},
                        )

                    with gr.Tab("Compare run A with Ground Truth"):
                        text_diff_gt = gr.HighlightedText(
                            label="A with GT",
                            combine_adjacent=True,
                            show_legend=True,
                            color_map={"+": "red", "-": "green"},
                        )

    # Hidden output component for the "run_htr_pipeline" API endpoint.
    xml_rendered_placeholder_for_api = gr.Textbox(placeholder="XML", visible=False)

    # ---- Event wiring ------------------------------------------------------

    # NOTE(review): fast_file_downlod is listed twice in `outputs`, so
    # segment_to_xml must return two values for it (presumably the file and a
    # visibility update -- confirm in gradio_backend).
    htr_event_click_event = htr_pipeline_button.click(
        fast_track.segment_to_xml,
        inputs=[fast_track_input_region_image, radio_file_input, htr_tool_transcriber_model_dropdown],
        outputs=[fast_file_downlod, fast_file_downlod],
        api_name=False,
    )

    # The only named public endpoint: image in, result into the hidden XML
    # textbox; runs outside the queue.
    htr_pipeline_button_api.click(
        fast_track.segment_to_xml_api,
        inputs=[fast_track_input_region_image],
        outputs=[xml_rendered_placeholder_for_api],
        queue=False,
        api_name="run_htr_pipeline",
    )

    # Selecting a left-hand tab sends updates to all three right-hand panels
    # (presumably showing the matching panel and hiding the others).
    tab_output_and_setting_selector.select(
        fn=update_selected_tab_output_and_setting,
        outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
        api_name=False,
    )

    tab_image_viewer_selector.select(
        fn=update_selected_tab_image_viewer,
        outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
        api_name=False,
    )

    tab_model_compare_selector.select(
        fn=update_selected_tab_model_compare,
        outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
        api_name=False,
    )

    def stop_function():
        """Request that the running HTR pipeline halt.

        Sets ``pipeline_inferencer.terminate`` to True and shows an info
        notification via ``gr.Info``. ``pipeline_inferencer`` is imported
        inside the handler rather than at module level.

        NOTE(review): nothing in this file resets the flag to False; confirm
        the pipeline clears it before the next run.
        """
        from src.htr_pipeline.utils import pipeline_inferencer

        pipeline_inferencer.terminate = True
        gr.Info("The HTR execution was halted")

    # NOTE(review): Gradio-level cancellation (`cancels=[htr_event_click_event]`)
    # is intentionally not enabled; stopping relies on the flag set above.
    stop_htr_button.click(
        fn=stop_function,
        inputs=None,
        outputs=None,
        api_name=False,
    )

    run_image_visualizer_button.click(
        fn=fast_track.visualize_image_viewer,
        inputs=fast_track_input_region_image,
        outputs=[fast_track_output_image, text_polygon_dict],
        api_name=False,
    )

    # A click on the rendered image returns the text for the selected line.
    fast_track_output_image.select(
        fast_track.get_text_from_coords,
        inputs=text_polygon_dict,
        outputs=selection_text_from_image_viewer,
        api_name=False,
    )

    # Each upload writes two values into its initially hidden File component
    # (presumably the file and a visibility update -- confirm in upload_file).
    upload_button_run_a.upload(
        upload_file, inputs=upload_button_run_a, outputs=[file_input_xml_run_a, file_input_xml_run_a], api_name=False
    )
    upload_button_run_b.upload(
        upload_file, inputs=upload_button_run_b, outputs=[file_input_xml_run_b, file_input_xml_run_b], api_name=False
    )
    upload_button_run_gt.upload(
        upload_file, inputs=upload_button_run_gt, outputs=[file_input_xml_run_gt, file_input_xml_run_gt], api_name=False
    )

    diff_runs_button.click(
        fn=compare_diff_runs_highlight,
        inputs=[file_input_xml_run_a, file_input_xml_run_b, file_input_xml_run_gt],
        outputs=[text_diff_runs, text_diff_gt],
        api_name=False,
    )

    calc_cer_button_fast.click(
        fn=compute_cer_a_and_b_with_gt,
        inputs=[file_input_xml_run_a, file_input_xml_run_b, file_input_xml_run_gt],
        outputs=cer_output_fast,
        api_name=False,
    )

    # When a non-empty HUB_TOKEN is set, a second listener on the Run button
    # calls TrafficDataHandler.store_metric_data. The value is only used as a flag.
    SECRET_KEY = os.environ.get("HUB_TOKEN", False)
    if SECRET_KEY:
        htr_pipeline_button.click(
            fn=TrafficDataHandler.store_metric_data,
            inputs=htr_pipeline_button_var,
            # Keep this internal listener out of the public API docs, like every
            # other internal listener in this block.
            api_name=False,
        )