Spaces:
Running
on
T4
Running
on
T4
File size: 4,577 Bytes
a36a5bd 7263d32 a36a5bd 7263d32 a36a5bd 7263d32 a36a5bd 60af1a7 a36a5bd 7263d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
from tqdm import tqdm
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.xml_helper import XMLHelper
class PipelineInferencer:
    """Orchestrate region/line segmentation and HTR, and render the result as PAGE XML.

    Model calls are delegated to an ``inferencer`` object supplied per call
    (``predict_regions``, ``predict_lines``, ``transcribe``). This class wires
    their outputs together, drops transcriptions whose HTR score does not exceed
    a threshold, and hands the collected data to ``xml_helper`` for rendering.
    """

    # Quoted (forward-reference) annotations: not evaluated at class creation.
    def __init__(self, process_seg_mask: "SegMaskHelper", xml_helper: "XMLHelper"):
        """Store the helpers used for line-coordinate translation and XML output.

        Args:
            process_seg_mask: Helper whose ``_translate_line_coords(mask, polygons)``
                is applied to the line polygons detected inside each region.
            xml_helper: Helper used to prepare the template data, escape
                transcribed text and render the final XML.
        """
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper

    def image_to_page_xml(
        self,
        image,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer,
        htr_threshold=0.7,
    ):
        """Segment and transcribe ``image`` and return the rendered PAGE XML.

        Args:
            image: Input page image; passed to ``xml_helper.prepare_template_data``
                and ``inferencer.predict_regions``.
            pred_score_threshold_regions: Passed to ``predict_regions`` as
                ``pred_score_threshold``.
            pred_score_threshold_lines: Passed to ``predict_lines`` for every region.
            containments_threshold: Passed to both ``predict_regions`` and
                ``predict_lines``.
            inferencer: Object providing ``predict_regions``, ``predict_lines``
                and ``transcribe``.
            htr_threshold: A transcribed line is kept only if its HTR score is
                strictly greater than this value (default 0.7, the value that
                was previously hard-coded).

        Returns:
            Whatever ``xml_helper.render`` returns for the populated template.
        """
        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        template_data["textRegions"] = self._process_regions(
            image,
            inferencer,
            pred_score_threshold_regions,
            pred_score_threshold_lines,
            containments_threshold,
            htr_threshold,
        )
        return self.xml_helper.render(template_data)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.7,
    ):
        """Detect text regions in ``image`` and build template data for each kept region.

        Returns:
            List of region dicts as built by ``_create_region_data``; regions for
            which that method returns ``None`` are omitted.
        """
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        gr.Info(f"Found {len(regions_cropped_ordered)} Regions to parse")

        region_data_list = []
        regions = enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))
        # enumerate/zip have no len(), so give tqdm the total explicitly.
        for i, data in tqdm(regions, total=len(regions_cropped_ordered)):
            region_data = self._create_region_data(
                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if region_data:
                region_data_list.append(region_data)
        return region_data_list

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
        """Build the template dict for one region, or ``None`` if it has no accepted lines.

        Args:
            data: ``(cropped_region, region_polygon, region_mask)`` tuple.
            index: Position of the region; used to build the id ``region_{index}``.

        Returns:
            ``{"id", "boundary", "textLines"}`` dict, or ``None``.
        """
        text_region, reg_pol, mask = data
        region_data = {"id": f"region_{index}", "boundary": reg_pol}

        text_lines, htr_scores = self._process_lines(
            text_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            mask,
            region_data["id"],
            htr_threshold,
        )
        if not text_lines:
            return None

        region_data["textLines"] = text_lines
        # NOTE(review): htr_scores only holds scores of lines that already passed
        # htr_threshold (see _process_lines), so this mean always exceeds the
        # threshold when text_lines is non-empty. If a region-level filter over
        # *all* lines was intended, _process_lines must also return rejected scores.
        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0
        return region_data if mean_htr_score > htr_threshold else None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.7
    ):
        """Detect and transcribe the lines of one region.

        Args:
            region_id: Id of the form ``region_<n>``; ``<n>`` is used in progress messages.

        Returns:
            ``(text_lines, htr_scores)``: the line dicts whose HTR score exceeds
            ``htr_threshold`` and their scores, in detection order. Both are
            empty lists when the inferencer finds no lines.
        """
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )
        if not lines_cropped_ordered:
            # Empty list (not None) keeps the return type consistent; callers
            # only test truthiness.
            return [], []

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        text_lines = []
        htr_scores = []
        id_number = region_id.split("_")[1]
        total_lines_len = len(lines_cropped_ordered)
        gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)
            if line_data:
                text_lines.append(line_data)
                htr_scores.append(htr_score)

            remaining_lines = total_lines_len - index - 1
            # Progress toast every 10th line, skipped when only a few lines remain.
            if (index + 1) % 10 == 0 and remaining_lines > 5:  # +1 because index starts at 0
                gr.Info(
                    f"Region {id_number}, parsed {index + 1} lines. Still {remaining_lines} lines left to transcribe."
                )
        return text_lines, htr_scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
        """Transcribe one line crop and build its template dict.

        Returns:
            ``(line_data, htr_score)``. ``line_data`` holds ``id``, ``boundary``,
            XML-escaped ``unicode`` text and ``pred_score`` (rounded to 4
            decimals); it is ``None`` when ``htr_score`` does not exceed
            ``htr_threshold``. The raw score is returned either way.
        """
        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
        transcribed_text, htr_score = inferencer.transcribe(line)
        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
        line_data["pred_score"] = round(htr_score, 4)
        return (line_data if htr_score > htr_threshold else None), htr_score
if __name__ == "__main__":
    # Module is meant to be imported; running it directly is a no-op.
    pass
|