amlpai04 commited on
Commit
2a0d582
·
1 Parent(s): c115883

fixed some minior bugs in visuazliser

Browse files
Files changed (3) hide show
  1. app/main.py +7 -7
  2. app/tabs/submit.py +4 -0
  3. app/tabs/visualizer.py +135 -61
app/main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import os
3
  from app.gradio_config import css, theme
@@ -5,7 +6,6 @@ from app.tabs.submit import (
5
  submit,
6
  custom_template_yaml,
7
  collection_submit_state,
8
- batch_image_gallery,
9
  )
10
  from app.tabs.visualizer import visualizer, collection_viz_state, viz_image_gallery
11
  from app.tabs.templating import (
@@ -53,6 +53,12 @@ with gr.Blocks(title="HTRflow", theme=theme, css=css) as demo:
53
  with gr.Tab(label="Visualize Result") as tab_visualizer:
54
  visualizer.render()
55
 
 
 
 
 
 
 
56
  @demo.load(
57
  inputs=[template_output_yaml_code],
58
  outputs=[template_output_yaml_code],
@@ -85,12 +91,6 @@ with gr.Blocks(title="HTRflow", theme=theme, css=css) as demo:
85
  fn=sync_gradio_objects,
86
  )
87
 
88
- # tab_visualizer.select(
89
- # inputs=[batch_image_gallery, viz_image_gallery],
90
- # outputs=[viz_image_gallery],
91
- # fn=sync_gradio_objects,
92
- # )
93
-
94
  tab_visualizer.select(
95
  inputs=[collection_submit_state, collection_viz_state],
96
  outputs=[collection_viz_state],
 
1
+ import shutil
2
  import gradio as gr
3
  import os
4
  from app.gradio_config import css, theme
 
6
  submit,
7
  custom_template_yaml,
8
  collection_submit_state,
 
9
  )
10
  from app.tabs.visualizer import visualizer, collection_viz_state, viz_image_gallery
11
  from app.tabs.templating import (
 
53
  with gr.Tab(label="Visualize Result") as tab_visualizer:
54
  visualizer.render()
55
 
56
+ @demo.load()
57
+ def inital_yaml_code():
58
+ tmp_dir = "tmp/"
59
+ if os.path.exists(tmp_dir) and os.path.isdir(tmp_dir):
60
+ shutil.rmtree(tmp_dir)
61
+
62
  @demo.load(
63
  inputs=[template_output_yaml_code],
64
  outputs=[template_output_yaml_code],
 
91
  fn=sync_gradio_objects,
92
  )
93
 
 
 
 
 
 
 
94
  tab_visualizer.select(
95
  inputs=[collection_submit_state, collection_viz_state],
96
  outputs=[collection_viz_state],
app/tabs/submit.py CHANGED
@@ -93,6 +93,8 @@ def run_htrflow(custom_template_yaml, batch_image_gallery, progress=gr.Progress(
93
  progress(0, desc="HTRflow: Starting")
94
  time.sleep(0.3)
95
 
 
 
96
  if batch_image_gallery is None:
97
  gr.Warning("HTRflow: You must upload atleast 1 image or more")
98
 
@@ -135,6 +137,8 @@ def tracking_exported_files(tmp_output_paths):
135
 
136
  exported_files = set()
137
 
 
 
138
  for tmp_folder in tmp_output_paths:
139
  for ext in accepted_extensions:
140
  search_pattern = os.path.join(tmp_folder, "**", f"*{ext}")
 
93
  progress(0, desc="HTRflow: Starting")
94
  time.sleep(0.3)
95
 
96
+ print(temp_config)
97
+
98
  if batch_image_gallery is None:
99
  gr.Warning("HTRflow: You must upload atleast 1 image or more")
100
 
 
137
 
138
  exported_files = set()
139
 
140
+ print(tmp_output_paths)
141
+
142
  for tmp_folder in tmp_output_paths:
143
  for ext in accepted_extensions:
144
  search_pattern = os.path.join(tmp_folder, "**", f"*{ext}")
app/tabs/visualizer.py CHANGED
@@ -4,10 +4,65 @@ import numpy as np
4
  from htrflow.volume.volume import Collection
5
  from htrflow.utils.draw import draw_polygons
6
  from htrflow.utils import imgproc
7
-
8
  from htrflow.results import Segment
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  with gr.Blocks() as visualizer:
12
  with gr.Column(variant="panel"):
13
  with gr.Row():
@@ -28,73 +83,92 @@ with gr.Blocks() as visualizer:
28
  "Visualize", scale=0, min_width=200, variant="secondary"
29
  )
30
 
31
- with gr.Column():
32
- # image_visualizer_annotation = gr.Image(
33
- # interactive=False,
34
- # )
35
 
36
- line2 = gr.Gallery(
 
 
 
 
 
 
 
 
 
37
  interactive=False,
38
  )
39
- textlines = gr.Dataframe()
40
 
41
- # @viz_image_gallery.select(outputs=image_visualizer_annotation)
42
- # def return_image_from_gallery(evt: gr.SelectData):
43
- # return evt.value["image"]["path"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- @visualize_button.click(
46
- outputs=[result_collection_viz_state, viz_image_gallery, line2, textlines]
 
47
  )
48
- def testie_load_pickle():
49
- col = Collection.from_pickle(".cache/HTRflow_demo_output.pickle")
 
 
50
 
51
- results = []
52
- for page_idx, page_node in enumerate(col):
53
- page_image = page_node.image.copy()
 
54
 
55
- lines = list(page_node.traverse(lambda node: node.is_line()))
 
56
 
57
- recog_conf_values = {
58
- i: list(zip(tr.texts, tr.scores)) if (tr := ln.text_result) else []
59
- for i, ln in enumerate(lines)
60
- }
61
 
62
- recog_df = pd.DataFrame(
63
- [
64
- {"Transcription": text, "Confidence Score": f"{score:.4f}"}
65
- for values in recog_conf_values.values()
66
- for text, score in values
67
- ]
68
- )
69
-
70
- line_polygons = []
71
- line_crops = []
72
- for ln in lines:
73
- seg: Segment = ln.data.get("segment")
74
- if not seg:
75
- continue
76
-
77
- cropped_line_img = imgproc.crop(page_image, seg.bbox)
78
- cropped_line_img = np.clip(cropped_line_img, 0, 255).astype(np.uint8)
79
- line_crops.append(cropped_line_img)
80
-
81
- if seg.polygon is not None:
82
- line_polygons.append(seg.polygon)
83
-
84
- annotated_image = draw_polygons(page_image, line_polygons)
85
- annotated_page_node = np.clip(annotated_image, 0, 255).astype(np.uint8)
86
- results.append(
87
- {
88
- "page_image": page_node,
89
- "annotated_page_node": annotated_page_node,
90
- "line_crops": line_crops,
91
- "recog_conf_values": recog_df,
92
- }
93
- )
94
-
95
- return (
96
- results,
97
- [results[0]["annotated_page_node"]],
98
- results[0]["line_crops"],
99
- results[0]["recog_conf_values"],
100
- )
 
4
  from htrflow.volume.volume import Collection
5
  from htrflow.utils.draw import draw_polygons
6
  from htrflow.utils import imgproc
7
+ import time
8
  from htrflow.results import Segment
9
 
10
 
11
+ def load_visualize_state_from_submit(col: Collection, progress):
12
+ results = []
13
+
14
+ time.sleep(1)
15
+
16
+ total_steps = len(col.pages)
17
+
18
+ for page_idx, page_node in enumerate(col):
19
+ page_image = page_node.image.copy()
20
+
21
+ progress((page_idx + 1) / total_steps, desc=f"Running Visualizer")
22
+
23
+ lines = list(page_node.traverse(lambda node: node.is_line()))
24
+
25
+ recog_conf_values = {
26
+ i: list(zip(tr.texts, tr.scores)) if (tr := ln.text_result) else []
27
+ for i, ln in enumerate(lines)
28
+ }
29
+
30
+ recog_df = pd.DataFrame(
31
+ [
32
+ {"Transcription": text, "Confidence Score": f"{score:.4f}"}
33
+ for values in recog_conf_values.values()
34
+ for text, score in values
35
+ ]
36
+ )
37
+
38
+ line_polygons = []
39
+ line_crops = []
40
+ for ln in lines:
41
+ seg: Segment = ln.data.get("segment")
42
+ if not seg:
43
+ continue
44
+
45
+ cropped_line_img = imgproc.crop(page_image, seg.bbox)
46
+ cropped_line_img = np.clip(cropped_line_img, 0, 255).astype(np.uint8)
47
+ line_crops.append(cropped_line_img)
48
+
49
+ if seg.polygon is not None:
50
+ line_polygons.append(seg.polygon)
51
+
52
+ annotated_image = draw_polygons(page_image, line_polygons)
53
+ annotated_page_node = np.clip(annotated_image, 0, 255).astype(np.uint8)
54
+ results.append(
55
+ {
56
+ "page_image": page_node,
57
+ "annotated_page_node": annotated_page_node,
58
+ "line_crops": line_crops,
59
+ "recog_conf_values": recog_df,
60
+ }
61
+ )
62
+
63
+ return results
64
+
65
+
66
  with gr.Blocks() as visualizer:
67
  with gr.Column(variant="panel"):
68
  with gr.Row():
 
83
  "Visualize", scale=0, min_width=200, variant="secondary"
84
  )
85
 
86
+ progress_bar = gr.Textbox(visible=False, show_label=False)
 
 
 
87
 
88
+ with gr.Column():
89
+ cropped_image_gallery = gr.Gallery(
90
+ interactive=False,
91
+ preview=True,
92
+ label="Cropped Polygons",
93
+ height=200,
94
+ )
95
+ df_for_cropped_images = gr.Dataframe(
96
+ label="Cropped Transcriptions",
97
+ headers=["Transcription", "Confidence Score"],
98
  interactive=False,
99
  )
 
100
 
101
+ def on_visualize_button_clicked(collection_viz, progress=gr.Progress()):
102
+ """
103
+ This function:
104
+ - Receives the collection (collection_viz).
105
+ - Processes it into 'results' (list of dicts with annotated_page_node, line_crops, dataframe).
106
+ - Returns:
107
+ 1) 'results' as state
108
+ 2) List of annotated_page_node images (one per page) to populate viz_image_gallery
109
+ """
110
+ if not collection_viz:
111
+ return None, []
112
+
113
+ results = load_visualize_state_from_submit(collection_viz, progress)
114
+ annotated_images = [r["annotated_page_node"] for r in results]
115
+ return results, annotated_images, gr.skip()
116
+
117
+ visualize_button.click(lambda: gr.update(visible=True), outputs=progress_bar).then(
118
+ fn=on_visualize_button_clicked,
119
+ inputs=collection_viz_state,
120
+ outputs=[result_collection_viz_state, viz_image_gallery, progress_bar],
121
+ ).then(lambda: gr.update(visible=False), outputs=progress_bar)
122
+
123
+ @viz_image_gallery.change(
124
+ inputs=result_collection_viz_state,
125
+ outputs=[cropped_image_gallery, df_for_cropped_images],
126
+ )
127
+ def update_c_gallery_and_dataframe(results):
128
+ selected = results[0]
129
+ return selected["line_crops"], selected["recog_conf_values"]
130
 
131
+ @viz_image_gallery.select(
132
+ inputs=result_collection_viz_state,
133
+ outputs=[cropped_image_gallery, df_for_cropped_images],
134
  )
135
+ def on_dataframe_select(evt: gr.SelectData, results):
136
+ """
137
+ evt.index => the index of the selected image in the gallery
138
+ results => the state object from result_collection_viz_state
139
 
140
+ Return the line crops and the recognized text for that index.
141
+ """
142
+ if results is None or evt.index is None:
143
+ return [], pd.DataFrame(columns=["Transcription", "Confidence Score"])
144
 
145
+ idx = evt.index
146
+ selected = results[idx]
147
 
148
+ return selected["line_crops"], selected["recog_conf_values"]
 
 
 
149
 
150
+ @df_for_cropped_images.select(
151
+ outputs=[cropped_image_gallery],
152
+ )
153
+ def on_dataframe_select(evt: gr.SelectData):
154
+ return gr.update(selected_index=evt.index[0])
155
+
156
+ @cropped_image_gallery.select(
157
+ inputs=df_for_cropped_images, outputs=df_for_cropped_images
158
+ )
159
+ def return_image_from_gallery(df, evt: gr.SelectData):
160
+ selected_index = evt.index
161
+
162
+ def highlight_row(row):
163
+ return [
164
+ (
165
+ "border: 1px solid blue; font-weight: bold"
166
+ if row.name == selected_index
167
+ else ""
168
+ )
169
+ for _ in row
170
+ ]
171
+
172
+ styler = df.style.apply(highlight_row, axis=1)
173
+
174
+ return styler