Spaces:
Running
on
T4
Running
on
T4
File size: 5,283 Bytes
5ebeb73 4c85050 5ebeb73 417b347 4c85050 417b347 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import os
import re
from datetime import datetime
import jinja2
from tqdm import tqdm
from src.htr_pipeline.inferencer import InferencerInterface
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
class XMLHelper:
    """Render region/line/transcription predictions for an image as page XML.

    The helper asks an inferencer for text regions, then for the text lines
    inside each region, transcribes every line, and finally feeds the
    collected data to the ``page_xml_2013.xml`` Jinja2 template.
    """

    # Matches a bracketed pair such as "[12, 34]"; used to rewrite it as "12,34".
    _COORD_PAIR_PATTERN = re.compile(r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]")

    # The five predefined XML entities, keyed by the character they replace.
    # The entity strings are assembled from their names instead of being
    # written as literals, so the table is unaffected if this source text is
    # ever passed through an HTML-entity decoder.
    _XML_ENTITY_NAMES = {"&": "amp", "<": "lt", ">": "gt", "'": "apos", '"': "quot"}
    _XML_ESCAPES = str.maketrans(
        {char: "&" + name + ";" for char, name in _XML_ENTITY_NAMES.items()}
    )

    def __init__(self):
        self.process_seg_mask = SegMaskHelper()

    def image_to_page_xml(
        self,
        image,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer: "InferencerInterface",
        xml_file_name="page_xml.xml",
    ):
        """Run the full pipeline on ``image`` and return the rendered XML string.

        Args:
            image: Object exposing ``shape``; ``shape[0]`` is used as the
                height and ``shape[1]`` as the width.
            pred_score_threshold_regions: Forwarded to region prediction.
            pred_score_threshold_lines: Forwarded to line prediction.
            containments_threshold: Forwarded to both region and line
                prediction.
            inferencer: Provides ``predict_regions``, ``predict_lines`` and
                ``transcribe``.
            xml_file_name: Stored in the template data under
                ``imageFilename``.

        Returns:
            The rendered XML document as a string.
        """
        img_height = image.shape[0]
        img_width = image.shape[1]
        # The XML file name is what ends up in the "imageFilename" field.
        img_file_name = xml_file_name

        template_data = self.prepare_template_data(img_file_name, img_width, img_height)
        template_data["textRegions"] = self._process_regions(
            image,
            inferencer,
            pred_score_threshold_regions,
            pred_score_threshold_lines,
            containments_threshold,
        )

        return self._render_xml(template_data)

    def _transform_coords(self, input_string):
        """Rewrite every ``[x, y]`` pair in ``input_string`` as ``x,y``."""
        return self._COORD_PAIR_PATTERN.sub(r"\1,\2", input_string)

    def _render_xml(self, template_data):
        """Render ``template_data`` with the ``page_xml_2013.xml`` template.

        The template directory is given relative to the current working
        directory, so rendering depends on where the process was started.
        """
        template_loader = jinja2.FileSystemLoader(searchpath="./src/htr_pipeline/utils/templates")
        template_env = jinja2.Environment(loader=template_loader, trim_blocks=True)
        template = template_env.get_template("page_xml_2013.xml")
        rendered_xml = template.render(template_data)
        # NOTE(review): the rewrite runs over the whole document, so a
        # transcription that itself contains a "[a, b]" pair is rewritten too.
        return self._transform_coords(rendered_xml)

    def prepare_template_data(self, img_file_name, img_width, img_height):
        """Return the base template context with an empty ``textRegions`` list.

        Args:
            img_file_name: Value for the ``imageFilename`` key.
            img_width: Value for the ``imageWidth`` key.
            img_height: Value for the ``imageHeight`` key.

        Returns:
            A dict with the keys ``created``, ``imageFilename``,
            ``imageWidth``, ``imageHeight`` and ``textRegions``.
        """
        # NOTE(review): local time formatted as "YYYY-MM-DD, HH:MM:SS"; if the
        # template emits this verbatim into an xsd:dateTime field it will not
        # validate -- confirm against the template before changing.
        date_time = datetime.now().strftime("%Y-%m-%d, %H:%M:%S")
        return {
            "created": date_time,
            "imageFilename": img_file_name,
            "imageWidth": img_width,
            "imageHeight": img_height,
            "textRegions": [],
        }

    def _process_regions(
        self,
        image,
        inferencer: "InferencerInterface",
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.7,
    ):
        """Detect text regions and keep those whose mean line score is high enough.

        Args:
            image: Image handed to ``inferencer.predict_regions``.
            inferencer: Source of region, line and transcription predictions.
            pred_score_threshold_regions: Forwarded to region prediction.
            pred_score_threshold_lines: Forwarded to line prediction.
            containments_threshold: Forwarded to region and line prediction.
            htr_threshold: A region is kept only when the mean transcription
                score of all its lines is strictly greater than this value.
                It is not forwarded to ``_process_lines``, which applies its
                own default per line.

        Returns:
            A list of dicts with the keys ``id``, ``boundary`` and
            ``textLines``; empty when nothing qualifies.
        """
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )

        # Mirrors the None check done for lines; presumably the inferencer
        # signals "nothing found" the same way for regions -- TODO confirm.
        if regions_cropped_ordered is None:
            return []

        region_data_list = []
        for i, (text_region, reg_pol, mask) in tqdm(
            enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))
        ):
            region_id = "region_" + str(i)
            text_lines, htr_scores = self._process_lines(
                text_region,
                inferencer,
                pred_score_threshold_lines,
                containments_threshold,
                mask,
                region_id,
            )

            # No lines detected (None) or an empty detection result: there is
            # no score to average, so the region is skipped instead of
            # dividing by zero.
            if text_lines is None or not htr_scores:
                continue

            # The mean covers every transcribed line, including lines that
            # were individually dropped from text_lines.
            mean_htr_score = sum(htr_scores) / len(htr_scores)
            if mean_htr_score > htr_threshold:
                region_data_list.append(
                    {
                        "id": region_id,
                        "boundary": reg_pol,
                        "textLines": text_lines,
                    }
                )

        return region_data_list

    def _process_lines(
        self,
        text_region,
        inferencer: "InferencerInterface",
        pred_score_threshold_lines,
        containments_threshold,
        mask,
        region_id,
        htr_threshold=0.7,
    ):
        """Detect and transcribe the text lines of one region.

        Args:
            text_region: Region crop handed to ``inferencer.predict_lines``.
            inferencer: Source of line and transcription predictions.
            pred_score_threshold_lines: Forwarded to line prediction.
            containments_threshold: Forwarded to line prediction.
            mask: Region mask passed to ``_translate_line_coords`` together
                with the line polygons.
            region_id: Identifier of the enclosing region; embedded in each
                line id.
            htr_threshold: A line is kept only when its transcription score
                is strictly greater than this value.

        Returns:
            ``(text_lines, htr_scores)``. ``text_lines`` holds one dict per
            kept line with the keys ``id``, ``boundary``, ``unicode`` and
            ``pred_score``; ``htr_scores`` holds the raw score of every
            transcribed line, kept or not. ``(None, None)`` is returned when
            line prediction yields ``None``.
        """
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region,
            pred_score_threshold=pred_score_threshold_lines,
            containments_threshold=containments_threshold,
            visualize=False,
            custom_track=False,
        )

        if lines_cropped_ordered is None:
            return None, None

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        htr_scores = []
        text_lines = []
        for j, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            transcribed_text, htr_score = inferencer.transcribe(line)
            htr_scores.append(htr_score)

            if htr_score > htr_threshold:
                text_lines.append(
                    {
                        "id": "line_" + region_id + "_" + str(j),
                        "boundary": line_pol,
                        "unicode": self._escape_xml_chars(transcribed_text),
                        "pred_score": round(htr_score, 4),
                    }
                )

        return text_lines, htr_scores

    def _escape_xml_chars(self, textline):
        """Replace the XML special characters in ``textline`` with entities.

        The ampersand, less-than, greater-than, apostrophe and double-quote
        characters are replaced by the amp, lt, gt, apos and quot entities in
        a single pass, so an ampersand produced by one replacement is never
        escaped a second time.
        """
        return textline.translate(self._XML_ESCAPES)
if __name__ == "__main__":
    # Nothing to run: this module only provides XMLHelper for import.
    pass
|