# NOTE(review): removed non-source residue left over from the hosting page
# (Spaces header, commit hashes, file size, and line-number gutter), which
# made the file invalid Python.
import gradio as gr
from tqdm import tqdm
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.xml_helper import XMLHelper
terminate = False
# TODO check why region is so slow to start.. Is their error with loading the model?
class PipelineInferencer:
    """End-to-end HTR inference pipeline.

    Segments an image into text regions, segments each region into lines,
    transcribes every line, and renders the result as a PAGE-XML document
    via the injected :class:`XMLHelper`.
    """

    def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper

    def image_to_page_xml(
        self, image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, inferencer
    ):
        """Run the full pipeline on *image* and return the rendered PAGE-XML string."""
        page = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        page["textRegions"] = self._process_regions(
            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
        )
        return self.xml_helper.render(page)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.6,
    ):
        """Segment *image* into regions and return the list of region dicts.

        Regions whose lines fail the HTR-score filter are omitted. The loop
        honours the module-level ``terminate`` flag for cooperative abort.
        """
        global terminate

        _, cropped_regions, region_polygons, region_masks = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        gr.Info(f"Found {len(cropped_regions)} Regions to parse")

        regions = []
        for idx, bundle in tqdm(enumerate(zip(cropped_regions, region_polygons, region_masks))):
            if terminate:
                break
            entry = self._create_region_data(
                bundle, idx, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if entry:
                regions.append(entry)
        return regions

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
        """Build the dict for one region (id, boundary, text lines).

        Returns None when the region yields no lines, or when the mean HTR
        score across its accepted lines is not comfortably above threshold.
        """
        cropped_region, boundary, seg_mask = data
        region = {"id": f"region_{index}", "boundary": boundary}

        lines, scores = self._process_lines(
            cropped_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            seg_mask,
            region["id"],
            htr_threshold,
        )
        if not lines:
            return None

        region["textLines"] = lines
        mean_score = sum(scores) / len(scores) if scores else 0
        # Drop whole regions whose average transcription confidence is only
        # marginally above the per-line threshold.
        if mean_score > htr_threshold + 0.1:
            return region
        return None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.6
    ):
        """Segment *text_region* into lines and transcribe each one.

        Returns ``(text_lines, htr_scores)``; ``text_lines`` is None when no
        lines were detected at all. Scores are collected only for lines that
        pass the HTR threshold, mirroring the accepted entries.
        """
        global terminate

        _, cropped_lines, line_polygons = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )
        if not cropped_lines:
            return None, []

        # Line polygons are relative to the region crop; translate them back
        # into page coordinates using the region's segmentation mask.
        translated_polygons = self.process_seg_mask._translate_line_coords(mask, line_polygons)

        accepted = []
        scores = []
        region_number = region_id.split("_")[1]
        line_count = len(cropped_lines)
        gr.Info(f" Region {region_number}, found {line_count} lines to parse and transcribe.")

        for idx, (line_img, line_boundary) in enumerate(zip(cropped_lines, translated_polygons)):
            if terminate:
                break
            entry, score = self._create_line_data(line_img, line_boundary, idx, region_id, inferencer, htr_threshold)
            if entry:
                accepted.append(entry)
                scores.append(score)
            left = line_count - idx - 1
            # Progress toast every 10 lines, skipped near the end of the region.
            if (idx + 1) % 10 == 0 and left > 5:
                gr.Info(f"Region {region_number}, parsed {idx + 1} lines. Still {left} lines left to transcribe.")
        return accepted, scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
        """Transcribe one line image.

        Returns ``(line_dict, htr_score)``; the dict is None when the score
        does not exceed *htr_threshold*, but the raw score is always returned
        so callers can average it.
        """
        text, score = inferencer.transcribe(line)
        entry = {
            "id": f"line_{region_id}_{index}",
            "boundary": line_pol,
            "unicode": self.xml_helper.escape_xml_chars(text),
            "pred_score": round(score, 4),
        }
        if score > htr_threshold:
            return entry, score
        return None, score
if __name__ == "__main__":
pass