Spaces:
Running
on
T4
Running
on
T4
File size: 5,342 Bytes
a36a5bd 7263d32 089249c 76f8319 7263d32 c60ebd1 7263d32 c60ebd1 7263d32 76f8319 7263d32 089249c 7263d32 a36a5bd 7263d32 089249c 7263d32 089249c 7263d32 76f8319 7263d32 a36a5bd 089249c 7263d32 089249c 7263d32 a36a5bd 60af1a7 a36a5bd 7263d32 c60ebd1 7263d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
from tqdm import tqdm
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.xml_helper import XMLHelper
# Module-level stop flag. The region loop and the line loop in this file check
# it at the start of every iteration and break out early once it is truthy.
terminate = False
# TODO: check why region is so slow to start. Is there an error with loading the model?
class PipelineInferencer:
    """Turn an image into rendered XML via region, line and text prediction.

    The class only orchestrates. The ``inferencer`` object passed into each
    call supplies ``predict_regions``, ``predict_lines`` and the transcribe
    methods; ``process_seg_mask`` translates line polygons; ``xml_helper``
    prepares the template data, escapes text and renders the final document.
    """

    # Margin added on top of ``htr_threshold`` when deciding whether a whole
    # region is kept, based on the mean transcription score of its lines.
    REGION_SCORE_MARGIN = 0.1

    def __init__(self, process_seg_mask: "SegMaskHelper", xml_helper: "XMLHelper"):
        """Store the two helpers used while building the XML data."""
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper

    def image_to_page_xml(
        self,
        image,
        htr_tool_transcriber_model_dropdown,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer,
    ):
        """Run the full pipeline on ``image`` and return the rendered XML.

        Args:
            image: Input image, passed on to ``inferencer.predict_regions``
                and ``xml_helper.prepare_template_data``.
            htr_tool_transcriber_model_dropdown: Name of the transcription
                model. ``"Riksarkivet/satrn_htr"`` selects
                ``inferencer.transcribe``; any other value is forwarded to
                ``inferencer.transcribe_different_model``.
            pred_score_threshold_regions: Score threshold for region prediction.
            pred_score_threshold_lines: Score threshold for line prediction.
            containments_threshold: Forwarded to both prediction steps.
            inferencer: Object providing the prediction and transcription methods.

        Returns:
            Whatever ``xml_helper.render`` returns for the filled template data.
        """
        # temporary solutions.. for trocr..
        # NOTE: the model choice is kept on the instance so that
        # ``_create_line_data`` can read it. Calls that share one instance
        # therefore overwrite each other's choice.
        self.htr_tool_transcriber_model_dropdown = htr_tool_transcriber_model_dropdown

        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        template_data["textRegions"] = self._process_regions(
            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
        )
        return self.xml_helper.render(template_data)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.6,
    ):
        """Predict regions in ``image`` and build a data dict for each kept region.

        Returns:
            A list of region dicts (see ``_create_region_data``). Regions for
            which ``_create_region_data`` returns ``None`` are left out. The
            list may be partial if the module-level ``terminate`` flag becomes
            truthy while looping.
        """
        # Only read here; the flag is expected to be flipped from outside.
        global terminate

        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        region_count = len(regions_cropped_ordered)
        gr.Info(f"Found {region_count} Regions to parse")

        region_data_list = []
        region_inputs = zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered)
        # enumerate/zip have no length, so tqdm needs an explicit total to
        # show progress and an estimate of the remaining time.
        for i, data in tqdm(enumerate(region_inputs), total=region_count):
            if terminate:
                break
            region_data = self._create_region_data(
                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if region_data:
                region_data_list.append(region_data)
        return region_data_list

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
        """Build the dict for one region, or return ``None`` if it is dropped.

        Args:
            data: Tuple ``(text_region, reg_pol, mask)`` for this region.
            index: Position of the region; used to form the id ``region_<index>``.
            inferencer: Object providing ``predict_lines`` and the transcribe methods.
            pred_score_threshold_lines: Score threshold for line prediction.
            containments_threshold: Forwarded to line prediction.
            htr_threshold: Minimum transcription score for a single line.

        Returns:
            A dict with ``id``, ``boundary`` and ``textLines``, or ``None`` when
            the region has no kept lines or its mean line score does not
            exceed ``htr_threshold + REGION_SCORE_MARGIN``.
        """
        text_region, reg_pol, mask = data
        region_data = {"id": f"region_{index}", "boundary": reg_pol}

        text_lines, htr_scores = self._process_lines(
            text_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            mask,
            region_data["id"],
            htr_threshold,
        )
        if not text_lines:
            return None

        region_data["textLines"] = text_lines
        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0
        # A region must clear a stricter bar than a single line.
        return region_data if mean_htr_score > htr_threshold + self.REGION_SCORE_MARGIN else None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.6
    ):
        """Predict and transcribe the lines of one region.

        Args:
            text_region: Cropped region passed to ``inferencer.predict_lines``.
            inferencer: Object providing ``predict_lines`` and the transcribe methods.
            pred_score_threshold: Score threshold for line prediction.
            containments_threshold: Forwarded to line prediction.
            mask: Region mask handed to ``_translate_line_coords`` together
                with the line polygons.
            region_id: Id of the form ``region_<n>``; the part after the first
                underscore is used in progress messages.
            htr_threshold: Minimum transcription score for a line to be kept.

        Returns:
            ``(None, [])`` when no lines are predicted. Otherwise
            ``(text_lines, htr_scores)``: the kept line dicts and the
            transcription scores collected while looping.
        """
        # Only read here; the flag is expected to be flipped from outside.
        global terminate

        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )
        if not lines_cropped_ordered:
            return None, []

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        text_lines = []
        htr_scores = []
        id_number = region_id.split("_")[1]
        total_lines_len = len(lines_cropped_ordered)
        gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            if terminate:
                break
            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)

            if line_data:
                text_lines.append(line_data)
            # NOTE(review): the score is recorded for every line, including
            # dropped ones, so the region mean reflects all of its lines.
            # Confirm this matches the intended nesting.
            htr_scores.append(htr_score)

            remaining_lines = total_lines_len - index - 1
            if (index + 1) % 10 == 0 and remaining_lines > 5:  # +1 because index starts at 0
                gr.Info(
                    f"Region {id_number}, parsed {index + 1} lines. Still {remaining_lines} lines left to transcribe."
                )

        return text_lines, htr_scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
        """Transcribe one line and build its data dict.

        Args:
            line: Cropped line passed to the transcribe method.
            line_pol: Polygon stored as the line's ``boundary``.
            index: Position of the line inside its region.
            region_id: Id of the enclosing region; part of the line id.
            inferencer: Object providing ``transcribe`` and
                ``transcribe_different_model``.
            htr_threshold: The line dict is returned only if the score is
                strictly greater than this value.

        Returns:
            ``(line_data, htr_score)``, where ``line_data`` is ``None`` when the
            score does not exceed ``htr_threshold``. The score is returned
            either way.
        """
        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}

        # temporary solution..
        # Reads the model choice stored by ``image_to_page_xml``.
        if self.htr_tool_transcriber_model_dropdown == "Riksarkivet/satrn_htr":
            transcribed_text, htr_score = inferencer.transcribe(line)
        else:
            transcribed_text, htr_score = inferencer.transcribe_different_model(
                line, self.htr_tool_transcriber_model_dropdown
            )

        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
        line_data["pred_score"] = round(htr_score, 4)

        return line_data if htr_score > htr_threshold else None, htr_score
if __name__ == "__main__":
    # Intentionally empty: running this module directly does nothing.
    pass
|