File size: 5,342 Bytes
a36a5bd
7263d32
 
 
 
 
089249c
 
76f8319
 
7263d32
 
 
 
 
 
 
c60ebd1
 
 
 
 
 
 
7263d32
c60ebd1
 
 
7263d32
 
 
 
 
 
 
 
 
 
 
 
 
 
76f8319
7263d32
089249c
 
7263d32
 
 
 
 
 
a36a5bd
7263d32
 
089249c
 
7263d32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
089249c
7263d32
 
76f8319
7263d32
 
 
 
 
 
 
 
 
 
 
 
a36a5bd
 
 
 
 
 
089249c
 
7263d32
089249c
 
7263d32
 
 
 
 
 
a36a5bd
 
 
60af1a7
 
 
a36a5bd
7263d32
 
 
 
 
c60ebd1
 
 
 
 
 
 
 
7263d32
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gradio as gr
from tqdm import tqdm

from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.xml_helper import XMLHelper

# Module-level cancellation flag. The loops in PipelineInferencer._process_regions
# and PipelineInferencer._process_lines stop iterating as soon as it is truthy.
# Nothing in this module sets it to True — presumably another module does;
# verify against callers.
terminate = False

# TODO: check why region processing is so slow to start. Is there an error when loading the model?


class PipelineInferencer:
    """Segment an image into regions and lines, transcribe them, and render XML.

    Detection and transcription are delegated to the ``inferencer`` object
    passed to :meth:`image_to_page_xml`; translating line coordinates into the
    page frame is delegated to ``process_seg_mask``; templating and XML
    escaping are delegated to ``xml_helper``.
    """

    # Lines are transcribed with ``inferencer.transcribe`` when this model is
    # selected; any other selection goes through
    # ``inferencer.transcribe_different_model``.
    DEFAULT_HTR_MODEL = "Riksarkivet/satrn_htr"

    # A region is kept only when the mean HTR score of its lines is greater
    # than ``htr_threshold`` plus this margin.
    REGION_SCORE_MARGIN = 0.1

    def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
        """Store the helpers used for coordinate translation and XML rendering.

        Args:
            process_seg_mask: Helper whose ``_translate_line_coords`` maps line
                polygons from a cropped region back onto the page.
            xml_helper: Helper providing ``prepare_template_data``, ``render``,
                ``escape_xml_chars`` and the ``xml_file_name`` attribute.
        """
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper

    def image_to_page_xml(
        self,
        image,
        htr_tool_transcriber_model_dropdown,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer,
    ):
        """Segment and transcribe *image*, returning the rendered XML.

        Args:
            image: Input page image, passed to
                ``xml_helper.prepare_template_data`` and
                ``inferencer.predict_regions``.
            htr_tool_transcriber_model_dropdown: Name of the selected HTR model.
                It is stored on the instance and read by ``_create_line_data``.
            pred_score_threshold_regions: Score threshold for region detection.
            pred_score_threshold_lines: Score threshold for line detection.
            containments_threshold: Containment threshold forwarded to both
                region and line prediction.
            inferencer: Object providing ``predict_regions``, ``predict_lines``,
                ``transcribe`` and ``transcribe_different_model``.

        Returns:
            The value of ``xml_helper.render`` for the filled-in template.
        """
        # Temporary workaround (originally noted "for trocr"): the selected
        # model name is kept on the instance because _create_line_data reads it
        # from there.
        # NOTE(review): this is shared mutable state. If one instance serves
        # concurrent requests with different models, they could interfere.
        self.htr_tool_transcriber_model_dropdown = htr_tool_transcriber_model_dropdown

        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        template_data["textRegions"] = self._process_regions(
            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
        )

        return self.xml_helper.render(template_data)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.6,
    ):
        """Detect text regions in *image* and build a data dict for each kept one.

        Iteration stops early, returning what has been collected so far, once
        the module-level ``terminate`` flag is truthy.

        Returns:
            list[dict]: Region dicts from ``_create_region_data``. Regions for
            which that method returns ``None`` are omitted.
        """
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        total_regions = len(regions_cropped_ordered)
        gr.Info(f"Found {total_regions} Regions to parse")

        region_items = enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))
        region_data_list = []
        # enumerate/zip objects have no len(), so tqdm needs the total passed
        # explicitly to show a real progress bar.
        for i, data in tqdm(region_items, total=total_regions):
            # `terminate` is only read here, so no `global` statement is needed.
            if terminate:
                break
            region_data = self._create_region_data(
                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if region_data:
                region_data_list.append(region_data)

        return region_data_list

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
        """Build the dict for one region, or return ``None`` if it should be dropped.

        Args:
            data: ``(cropped_region, polygon, mask)`` tuple, in the order zipped
                by ``_process_regions``.
            index: Position of the region; used to form its id ``region_<index>``.

        Returns:
            dict | None: ``{"id", "boundary", "textLines"}``. ``None`` is returned
            when no line passes ``htr_threshold``, or when the mean score of
            all transcribed lines is not above
            ``htr_threshold + REGION_SCORE_MARGIN``.
        """
        text_region, reg_pol, mask = data
        region_data = {"id": f"region_{index}", "boundary": reg_pol}

        text_lines, htr_scores = self._process_lines(
            text_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            mask,
            region_data["id"],
            htr_threshold,
        )

        if not text_lines:
            return None

        region_data["textLines"] = text_lines
        # The mean includes scores of lines that were individually rejected.
        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0

        return region_data if mean_htr_score > htr_threshold + self.REGION_SCORE_MARGIN else None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.6
    ):
        """Detect and transcribe the text lines inside one cropped region.

        Args:
            text_region: Cropped region image passed to ``inferencer.predict_lines``.
            mask: Region mask used to translate line polygons back to page
                coordinates.
            region_id: Region id of the form ``region_<n>``. It prefixes the line
                ids, and ``<n>`` is used in progress messages.

        Returns:
            tuple[list[dict], list]: Line dicts that passed ``htr_threshold``,
            and the HTR score of every transcribed line (including rejected
            ones). Both lists are empty when no lines are detected.
        """
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )

        if not lines_cropped_ordered:
            # Return empty lists (not None) so the return type is consistent.
            return [], []

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        text_lines = []
        htr_scores = []

        id_number = region_id.split("_")[1]
        total_lines_len = len(lines_cropped_ordered)

        gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            # `terminate` is only read here, so no `global` statement is needed.
            if terminate:
                break
            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)

            if line_data:
                text_lines.append(line_data)
            htr_scores.append(htr_score)

            remaining_lines = total_lines_len - index - 1

            # Report progress every 10 lines (index is 0-based), unless only a
            # few lines remain.
            if (index + 1) % 10 == 0 and remaining_lines > 5:
                gr.Info(
                    f"Region {id_number}, parsed {index + 1} lines. Still {remaining_lines} lines left to transcribe."
                )

        return text_lines, htr_scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
        """Transcribe one line image and build its dict.

        Uses ``inferencer.transcribe`` for ``DEFAULT_HTR_MODEL``, and
        ``inferencer.transcribe_different_model`` for any other selected model.
        Requires ``self.htr_tool_transcriber_model_dropdown``, which is set by
        ``image_to_page_xml``.

        Returns:
            tuple[dict | None, score]: The line dict
            (``{"id", "boundary", "unicode", "pred_score"}``), or ``None`` when
            the score is not above ``htr_threshold``, together with the raw
            HTR score.
        """
        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}

        # Temporary model dispatch based on the selected model name.
        if self.htr_tool_transcriber_model_dropdown == self.DEFAULT_HTR_MODEL:
            transcribed_text, htr_score = inferencer.transcribe(line)
        else:
            transcribed_text, htr_score = inferencer.transcribe_different_model(
                line, self.htr_tool_transcriber_model_dropdown
            )

        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
        line_data["pred_score"] = round(htr_score, 4)

        return line_data if htr_score > htr_threshold else None, htr_score


if __name__ == "__main__":
    # Running this module directly does nothing; it only defines
    # PipelineInferencer (presumably imported and driven by other code).
    pass