Spaces:
Sleeping
Sleeping
import gradio as gr | |
from tqdm import tqdm | |
from src.htr_pipeline.utils.process_segmask import SegMaskHelper | |
from src.htr_pipeline.utils.xml_helper import XMLHelper | |
# Module-level cancellation flag. PipelineInferencer's region and line loops
# check it on every iteration and stop early once it is True. It is presumably
# toggled from elsewhere (e.g. a UI "stop" action); that code is not shown here.
terminate: bool = False

# TODO: check why the region step is so slow to start. Is there an error when loading the model?
class PipelineInferencer:
    """Turn a page image into rendered XML by chaining region detection,
    line detection and line transcription.

    Detection and transcription are delegated to the ``inferencer`` object
    passed to :meth:`image_to_page_xml`. Line-coordinate translation is
    delegated to ``process_seg_mask``, and templating/escaping to
    ``xml_helper``.
    """

    # Model name handled by ``inferencer.transcribe``; any other model name is
    # routed through ``inferencer.transcribe_different_model``.
    DEFAULT_TRANSCRIBER_MODEL = "Riksarkivet/satrn_htr"
    # A region is kept only if the mean score of its accepted lines exceeds
    # ``htr_threshold`` by more than this margin.
    REGION_SCORE_MARGIN = 0.1

    def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper
        # Overwritten on every image_to_page_xml call. It is initialised here so
        # the private helpers do not raise AttributeError when used directly.
        self.htr_tool_transcriber_model_dropdown = self.DEFAULT_TRANSCRIBER_MODEL

    def image_to_page_xml(
        self,
        image,
        htr_tool_transcriber_model_dropdown,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer,
    ):
        """Detect regions and lines in ``image``, transcribe them, and render XML.

        Args:
            image: Input page image. It is forwarded to
                ``xml_helper.prepare_template_data`` and ``inferencer.predict_regions``.
            htr_tool_transcriber_model_dropdown: Name of the transcription model.
                ``DEFAULT_TRANSCRIBER_MODEL`` selects ``inferencer.transcribe``;
                any other value selects ``inferencer.transcribe_different_model``.
            pred_score_threshold_regions: Score threshold for region detection.
            pred_score_threshold_lines: Score threshold for line detection.
            containments_threshold: Containment threshold forwarded to both
                detection steps.
            inferencer: Object providing ``predict_regions``, ``predict_lines``,
                ``transcribe`` and ``transcribe_different_model``.

        Returns:
            Whatever ``xml_helper.render`` returns for the filled template.
        """
        # Temporary solution (for TrOCR): the chosen model is stored on the
        # instance so that _create_line_data can pick the transcription call.
        self.htr_tool_transcriber_model_dropdown = htr_tool_transcriber_model_dropdown
        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        template_data["textRegions"] = self._process_regions(
            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
        )
        return self.xml_helper.render(template_data)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.6,
    ):
        """Detect text regions in ``image`` and build template data for each kept region.

        Stops early as soon as the module-level ``terminate`` flag is True.

        Returns:
            A list of region dicts (see ``_create_region_data``). Regions with no
            accepted line, or with too low a mean line score, are omitted.
        """
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        total_regions = len(regions_cropped_ordered)
        gr.Info(f"Found {total_regions} Regions to parse")

        region_data_list = []
        regions = zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered)
        # enumerate(zip(...)) has no len(), so pass the total explicitly.
        # Without it, tqdm cannot show a real progress bar.
        for i, data in tqdm(enumerate(regions), total=total_regions):
            # Read-only access to the module-level flag; no `global` needed.
            if terminate:
                break
            region_data = self._create_region_data(
                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if region_data:
                region_data_list.append(region_data)
        return region_data_list

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
        """Build the template dict for one region, or return ``None`` to drop it.

        Args:
            data: ``(cropped_region, region_polygon, region_mask)`` tuple as
                produced by ``_process_regions``.
            index: Region index, used to build the ``region_{index}`` id.

        Returns:
            A dict with ``id``, ``boundary`` and ``textLines`` keys. Returns
            ``None`` if no line passed ``htr_threshold``, or if the mean
            accepted-line score does not exceed
            ``htr_threshold + REGION_SCORE_MARGIN``.
        """
        text_region, reg_pol, mask = data
        region_data = {"id": f"region_{index}", "boundary": reg_pol}
        text_lines, htr_scores = self._process_lines(
            text_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            mask,
            region_data["id"],
            htr_threshold,
        )
        if not text_lines:
            return None
        region_data["textLines"] = text_lines
        # htr_scores only holds scores of accepted lines, so it is non-empty here.
        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0
        return region_data if mean_htr_score > htr_threshold + self.REGION_SCORE_MARGIN else None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.6
    ):
        """Detect and transcribe the text lines of one region.

        Line polygons are translated through
        ``process_seg_mask._translate_line_coords`` using the region ``mask``.
        Stops early as soon as the module-level ``terminate`` flag is True.

        Returns:
            ``(text_lines, htr_scores)``: the accepted line dicts and their scores,
            in the same order. Both are empty lists when no lines are detected.
        """
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )
        if not lines_cropped_ordered:
            # Empty lists (not None) keep the return type consistent for callers.
            return [], []
        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        text_lines = []
        htr_scores = []
        id_number = region_id.split("_")[1]
        total_lines_len = len(lines_cropped_ordered)
        gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            # Read-only access to the module-level flag; no `global` needed.
            if terminate:
                break
            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)
            if line_data:
                text_lines.append(line_data)
                htr_scores.append(htr_score)
            remaining_lines = total_lines_len - index - 1
            # Report progress every 10 lines (index is 0-based, hence +1), but
            # skip the message when only a few lines remain.
            if (index + 1) % 10 == 0 and remaining_lines > 5:
                gr.Info(
                    f"Region {id_number}, parsed {index + 1} lines. Still {remaining_lines} lines left to transcribe."
                )
        return text_lines, htr_scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
        """Transcribe one cropped line and build its template dict.

        Returns:
            ``(line_data, htr_score)``. ``line_data`` holds ``id``, ``boundary``,
            XML-escaped ``unicode`` text and ``pred_score`` (rounded to 4
            decimals). It is ``None`` when ``htr_score`` does not exceed
            ``htr_threshold``; the raw score is returned either way.
        """
        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
        # Temporary solution: the default model uses the inferencer's own
        # transcribe(); other models go through transcribe_different_model().
        if self.htr_tool_transcriber_model_dropdown == self.DEFAULT_TRANSCRIBER_MODEL:
            transcribed_text, htr_score = inferencer.transcribe(line)
        else:
            transcribed_text, htr_score = inferencer.transcribe_different_model(
                line, self.htr_tool_transcriber_model_dropdown
            )
        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
        line_data["pred_score"] = round(htr_score, 4)
        return line_data if htr_score > htr_threshold else None, htr_score
if __name__ == "__main__":
    # Running this module directly intentionally does nothing.
    ...