import gradio as gr
from tqdm import tqdm

from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.xml_helper import XMLHelper


class PipelineInferencer:
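    """Chains region segmentation, line segmentation and handwritten text recognition (HTR)
    to turn an input image into page XML rendered by the injected XMLHelper."""
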
    def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper

    def image_to_page_xml(
        self, image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, inferencer
    ):
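        """Segment `image` into text regions and lines, transcribe each line, and return the rendered page XML."""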
        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        template_data["textRegions"] = self._process_regions(
            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
        )

        return self.xml_helper.render(template_data)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.7,
    ):
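        """Detect text regions and return a list of region dicts.

        Regions with no usable lines, or whose mean HTR score does not exceed `htr_threshold`,
        are dropped from the result.
        """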
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        gr.Info(f"Found {len(regions_cropped_ordered)} Regions to parse")
        region_data_list = []
        for i, data in tqdm(enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))):
            region_data = self._create_region_data(
                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if region_data:
                region_data_list.append(region_data)

        return region_data_list

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
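        """Build the dict for one region (id, boundary, textLines).

        Returns None when the region has no usable lines or its mean HTR score does not
        exceed `htr_threshold`.
        """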
        text_region, reg_pol, mask = data
        region_data = {"id": f"region_{index}", "boundary": reg_pol}

        text_lines, htr_scores = self._process_lines(
            text_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            mask,
            region_data["id"],
            htr_threshold,
        )

        if not text_lines:
            return None

        region_data["textLines"] = text_lines
        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0

        return region_data if mean_htr_score > htr_threshold else None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.7
    ):
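        """Detect and transcribe the lines of a single region.

        Returns (text_lines, htr_scores). `text_lines` is None when no lines are detected.
        Lines whose score does not exceed `htr_threshold` are excluded from `text_lines`,
        but their scores are still collected so the caller can compute the region mean.
        """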
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )

        if not lines_cropped_ordered:
            return None, []

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        text_lines = []
        htr_scores = []

        id_number = region_id.split("_")[1]
        total_lines_len = len(lines_cropped_ordered)

        gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)

            if line_data:
                text_lines.append(line_data)
            htr_scores.append(htr_score)

            remaining_lines = total_lines_len - index - 1

            if (index + 1) % 10 == 0 and remaining_lines > 5:  # +1 because index starts at 0
                gr.Info(
                    f"Region {id_number}, parsed {index + 1} lines. Still {remaining_lines} lines left to transcribe."
                )

        return text_lines, htr_scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
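        """Transcribe one line image and return (line_data, htr_score).

        `line_data` is None when the score does not exceed `htr_threshold`.
        """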
        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}

        transcribed_text, htr_score = inferencer.transcribe(line)
        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
        line_data["pred_score"] = round(htr_score, 4)

        return (line_data if htr_score > htr_threshold else None), htr_score


if __name__ == "__main__":
    pass
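    # Minimal usage sketch, kept as comments because the constructor signatures of
    # SegMaskHelper, XMLHelper and the concrete segmentation/HTR inferencer live outside
    # this module and are assumptions here:
    #
    #   pipeline = PipelineInferencer(SegMaskHelper(), XMLHelper())
    #   page_xml = pipeline.image_to_page_xml(
    #       image,  # e.g. a loaded page scan
    #       pred_score_threshold_regions=0.4,
    #       pred_score_threshold_lines=0.4,
    #       containments_threshold=0.5,
    #       inferencer=inferencer,  # object exposing predict_regions, predict_lines and transcribe
    #   )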