PaddleOCR fast and simplified inference

#4
by Goodsea - opened
This view is limited to 50 files because it contains too many changes.  See the raw diff here.
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. README.md +3 -2
  3. app.py +134 -84
  4. db_utils.py +0 -41
  5. .gitignore → ocr/.gitignore +3 -36
  6. ocr/README.md +1 -0
  7. ocr/__init__.py +0 -0
  8. ocr/ch_PP-OCRv3_det_infer/inference.pdiparams +3 -0
  9. ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info +0 -0
  10. ocr/ch_PP-OCRv3_det_infer/inference.pdmodel +3 -0
  11. ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams +3 -0
  12. ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info +0 -0
  13. ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel +3 -0
  14. ocr/detector.py +248 -0
  15. ocr/inference.py +68 -0
  16. ocr/postprocess/__init__.py +66 -0
  17. ocr/postprocess/cls_postprocess.py +30 -0
  18. ocr/postprocess/db_postprocess.py +207 -0
  19. ocr/postprocess/east_postprocess.py +122 -0
  20. ocr/postprocess/extract_textpoint_fast.py +464 -0
  21. ocr/postprocess/extract_textpoint_slow.py +608 -0
  22. ocr/postprocess/fce_postprocess.py +234 -0
  23. ocr/postprocess/locality_aware_nms.py +198 -0
  24. ocr/postprocess/pg_postprocess.py +189 -0
  25. ocr/postprocess/poly_nms.py +132 -0
  26. ocr/postprocess/pse_postprocess/__init__.py +1 -0
  27. ocr/postprocess/pse_postprocess/pse/__init__.py +20 -0
  28. ocr/postprocess/pse_postprocess/pse/pse.pyx +72 -0
  29. ocr/postprocess/pse_postprocess/pse/setup.py +19 -0
  30. ocr/postprocess/pse_postprocess/pse_postprocess.py +100 -0
  31. ocr/postprocess/rec_postprocess.py +731 -0
  32. ocr/postprocess/sast_postprocess.py +355 -0
  33. ocr/postprocess/vqa_token_re_layoutlm_postprocess.py +36 -0
  34. ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py +96 -0
  35. ocr/ppocr/__init__.py +0 -0
  36. ocr/ppocr/data/__init__.py +79 -0
  37. ocr/ppocr/data/collate_fn.py +59 -0
  38. ocr/ppocr/data/imaug/ColorJitter.py +14 -0
  39. ocr/ppocr/data/imaug/__init__.py +61 -0
  40. ocr/ppocr/data/imaug/copy_paste.py +167 -0
  41. ocr/ppocr/data/imaug/east_process.py +427 -0
  42. ocr/ppocr/data/imaug/fce_aug.py +563 -0
  43. ocr/ppocr/data/imaug/fce_targets.py +671 -0
  44. ocr/ppocr/data/imaug/gen_table_mask.py +228 -0
  45. ocr/ppocr/data/imaug/iaa_augment.py +72 -0
  46. ocr/ppocr/data/imaug/label_ops.py +1046 -0
  47. ocr/ppocr/data/imaug/make_border_map.py +155 -0
  48. ocr/ppocr/data/imaug/make_pse_gt.py +88 -0
  49. ocr/ppocr/data/imaug/make_shrink_map.py +100 -0
  50. ocr/ppocr/data/imaug/operators.py +458 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.pdiparams filter=lfs diff=lfs merge=lfs -text
36
+ *.pdmodel filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Deprem OCR
3
  emoji: 👀
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.17.0
8
  app_file: app.py
9
- pinned: true
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Deprem Ocr 2
3
  emoji: 👀
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.17.0
8
  app_file: app.py
9
+ pinned: false
10
+ duplicated_from: mertcobanov/deprem-ocr-2
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,81 +1,157 @@
1
- from PIL import ImageFilter, Image
2
- from easyocr import Reader
3
  import gradio as gr
4
- import numpy as np
 
5
  import openai
6
  import ast
7
- from transformers import pipeline
8
  import os
 
 
 
 
 
 
9
 
10
- from openai_api import OpenAI_API
11
- import utils
 
 
12
 
13
  openai.api_key = os.getenv("API_KEY")
14
- reader = Reader(["tr"])
15
 
 
 
 
16
 
17
- def get_text(input_img):
18
- img = Image.fromarray(input_img)
19
- detailed = np.asarray(img.filter(ImageFilter.DETAIL))
20
- result = reader.readtext(detailed, detail=0, paragraph=True)
21
- return " ".join(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
- # Submit button
25
  def get_parsed_address(input_img):
26
 
27
  address_full_text = get_text(input_img)
28
- return ner_response(address_full_text)
29
 
30
 
31
- def save_deta_db(input):
32
- eval_result = ast.literal_eval(input)
33
- utils.write_db(eval_result)
34
- return
 
35
 
36
 
37
- def update_component():
38
- return gr.update(value="Gönderildi, teşekkürler.", visible=True)
39
 
 
 
 
 
40
 
41
- def clear_textbox(value):
42
- return gr.update(value="")
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  def text_dict(input):
46
  eval_result = ast.literal_eval(input)
 
 
47
  return (
48
- str(eval_result["il"]),
49
- str(eval_result["ilce"]),
50
- str(eval_result["mahalle"]),
51
- str(eval_result["sokak"]),
52
- str(eval_result["Apartman/site"]),
 
 
53
  str(eval_result["no"]),
54
- str(eval_result["ad-soyad"]),
55
- str(eval_result["dis kapi no"]),
56
  )
57
 
58
 
59
- def ner_response(ocr_input):
60
-
61
- ner_pipe = pipeline("token-classification","deprem-ml/deprem-ner", aggregation_strategy="first")
62
- predictions = ner_pipe(ocr_input)
63
- resp = {}
 
 
 
 
 
 
 
64
 
65
- for item in predictions:
66
- print(item)
67
- key = item["entity_group"]
68
- resp[key] = item["word"]
69
-
 
 
 
 
 
 
 
 
70
  resp["input"] = ocr_input
71
- dict_keys = ["il", "ilce", "mahalle", "sokak", "Apartman/site", "no", "ad-soyad", "dis kapi no"]
 
 
 
 
 
 
 
 
 
 
72
  for key in dict_keys:
73
  if key not in resp.keys():
74
  resp[key] = ""
75
  return resp
76
 
77
 
78
- # User Interface
79
  with gr.Blocks() as demo:
80
  gr.Markdown(
81
  """
@@ -86,68 +162,42 @@ with gr.Blocks() as demo:
86
  "Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın."
87
  )
88
  with gr.Row():
89
- with gr.Column():
90
- img_area = gr.Image(label="Ekran Görüntüsü yükleyin 👇")
91
- img_area_button = gr.Button(value="Görüntüyü İşle", label="Submit")
92
-
93
- with gr.Column():
94
- text_area = gr.Textbox(label="Metin yükleyin 👇 ", lines=8)
95
- text_area_button = gr.Button(value="Metni Yükle", label="Submit")
96
-
97
  open_api_text = gr.Textbox(label="Tam Adres")
98
-
99
  with gr.Column():
100
  with gr.Row():
101
- il = gr.Textbox(label="İl", interactive=True, show_progress=False)
102
- ilce = gr.Textbox(label="İlçe", interactive=True, show_progress=False)
103
  with gr.Row():
104
- mahalle = gr.Textbox(
105
- label="Mahalle", interactive=True, show_progress=False
106
- )
107
- sokak = gr.Textbox(
108
- label="Sokak/Cadde/Bulvar", interactive=True, show_progress=False
109
- )
110
  with gr.Row():
111
- no = gr.Textbox(label="Telefon", interactive=True, show_progress=False)
112
  with gr.Row():
113
- ad_soyad = gr.Textbox(
114
- label="İsim Soyisim", interactive=True, show_progress=False
115
- )
116
- apartman = gr.Textbox(label="apartman", interactive=True, show_progress=False)
117
  with gr.Row():
118
- dis_kapi_no = gr.Textbox(label="Kapı No", interactive=True, show_progress=False)
119
 
120
- img_area_button.click(
121
  get_parsed_address,
122
  inputs=img_area,
123
  outputs=open_api_text,
124
- api_name="upload-image",
125
  )
126
 
127
- text_area_button.click(
128
- ner_response, text_area, open_api_text, api_name="upload-text"
129
  )
130
 
131
-
132
  open_api_text.change(
133
  text_dict,
134
  open_api_text,
135
- [il, ilce, mahalle, sokak, no, apartman, ad_soyad, dis_kapi_no],
136
- )
137
- ocr_button = gr.Button(value="Sadece OCR kullan")
138
- ocr_button.click(
139
- get_text,
140
- inputs=img_area,
141
- outputs=text_area,
142
- api_name="get-ocr-output",
143
  )
144
- submit_button = gr.Button(value="Veriyi Birimlere Yolla")
145
- submit_button.click(save_deta_db, open_api_text)
146
- done_text = gr.Textbox(label="Done", value="Not Done", visible=False)
147
- submit_button.click(update_component, outputs=done_text)
148
- for txt in [il, ilce, mahalle, sokak, apartman, no, ad_soyad, dis_kapi_no]:
149
- submit_button.click(fn=clear_textbox, inputs=txt, outputs=txt)
150
 
151
 
152
  if __name__ == "__main__":
153
- demo.launch()
 
 
 
1
  import gradio as gr
2
+ import json
3
+ import csv
4
  import openai
5
  import ast
 
6
  import os
7
+ from deta import Deta
8
+
9
+ import numpy as np
10
+ from ocr import utility
11
+ from ocr.detector import TextDetector
12
+ from ocr.recognizer import TextRecognizer
13
 
14
+ # Global Detector and Recognizer
15
+ args = utility.parse_args()
16
+ text_recognizer = TextRecognizer(args)
17
+ text_detector = TextDetector(args)
18
 
19
  openai.api_key = os.getenv("API_KEY")
 
20
 
21
+ args = utility.parse_args()
22
+ text_recognizer = TextRecognizer(args)
23
+ text_detector = TextDetector(args)
24
 
25
+
26
+ def apply_ocr(img):
27
+ # Detect text regions
28
+ dt_boxes, _ = text_detector(img)
29
+
30
+ boxes = []
31
+ for box in dt_boxes:
32
+ p1, p2, p3, p4 = box
33
+ x1 = min(p1[0], p2[0], p3[0], p4[0])
34
+ y1 = min(p1[1], p2[1], p3[1], p4[1])
35
+ x2 = max(p1[0], p2[0], p3[0], p4[0])
36
+ y2 = max(p1[1], p2[1], p3[1], p4[1])
37
+ boxes.append([x1, y1, x2, y2])
38
+
39
+ # Recognize text
40
+ img_list = []
41
+ for i in range(len(boxes)):
42
+ x1, y1, x2, y2 = map(int, boxes[i])
43
+ img_list.append(img.copy()[y1:y2, x1:x2])
44
+ img_list.reverse()
45
+
46
+ rec_res, _ = text_recognizer(img_list)
47
+
48
+ # Postprocess
49
+ total_text = ""
50
+ for i in range(len(rec_res)):
51
+ total_text += rec_res[i][0] + " "
52
+
53
+ total_text = total_text.strip()
54
+ return total_text
55
 
56
 
 
57
  def get_parsed_address(input_img):
58
 
59
  address_full_text = get_text(input_img)
60
+ return openai_response(address_full_text)
61
 
62
 
63
+ def get_text(input_img):
64
+ input_img = np.array(input_img)
65
+ result = apply_ocr(input_img)
66
+ print(result)
67
+ return " ".join(result)
68
 
69
 
70
+ def save_csv(mahalle, il, sokak, apartman):
71
+ adres_full = [mahalle, il, sokak, apartman]
72
 
73
+ with open("adress_book.csv", "a", encoding="utf-8") as f:
74
+ write = csv.writer(f)
75
+ write.writerow(adres_full)
76
+ return adres_full
77
 
78
+
79
+ def get_json(mahalle, il, sokak, apartman):
80
+ adres = {"mahalle": mahalle, "il": il, "sokak": sokak, "apartman": apartman}
81
+ dump = json.dumps(adres, indent=4, ensure_ascii=False)
82
+ return dump
83
+
84
+
85
+ def write_db(data_dict):
86
+ # 2) initialize with a project key
87
+ deta_key = os.getenv("DETA_KEY")
88
+ deta = Deta(deta_key)
89
+
90
+ # 3) create and use as many DBs as you want!
91
+ users = deta.Base("deprem-ocr")
92
+ users.insert(data_dict)
93
 
94
 
95
  def text_dict(input):
96
  eval_result = ast.literal_eval(input)
97
+ write_db(eval_result)
98
+
99
  return (
100
+ str(eval_result["city"]),
101
+ str(eval_result["distinct"]),
102
+ str(eval_result["neighbourhood"]),
103
+ str(eval_result["street"]),
104
+ str(eval_result["address"]),
105
+ str(eval_result["tel"]),
106
+ str(eval_result["name_surname"]),
107
  str(eval_result["no"]),
 
 
108
  )
109
 
110
 
111
+ def openai_response(ocr_input):
112
+ prompt = f"""Tabular Data Extraction You are a highly intelligent and accurate tabular data extractor from
113
+ plain text input and especially from emergency text that carries address information, your inputs can be text
114
+ of arbitrary size, but the output should be in [{{'tabular': {{'entity_type': 'entity'}} }}] JSON format Force it
115
+ to only extract keys that are shared as an example in the examples section, if a key value is not found in the
116
+ text input, then it should be ignored. Have only city, distinct, neighbourhood,
117
+ street, no, tel, name_surname, address Examples: Input: Deprem sırasında evimizde yer alan adresimiz: İstanbul,
118
+ Beşiktaş, Yıldız Mahallesi, Cumhuriyet Caddesi No: 35, cep telefonu numaram 5551231256, adim Ahmet Yilmaz
119
+ Output: {{'city': 'İstanbul', 'distinct': 'Beşiktaş', 'neighbourhood': 'Yıldız Mahallesi', 'street': 'Cumhuriyet Caddesi', 'no': '35', 'tel': '5551231256', 'name_surname': 'Ahmet Yılmaz', 'address': 'İstanbul, Beşiktaş, Yıldız Mahallesi, Cumhuriyet Caddesi No: 35'}}
120
+ Input: {ocr_input}
121
+ Output:
122
+ """
123
 
124
+ response = openai.Completion.create(
125
+ model="text-davinci-003",
126
+ prompt=prompt,
127
+ temperature=0,
128
+ max_tokens=300,
129
+ top_p=1,
130
+ frequency_penalty=0.0,
131
+ presence_penalty=0.0,
132
+ stop=["\n"],
133
+ )
134
+ resp = response["choices"][0]["text"]
135
+ print(resp)
136
+ resp = eval(resp.replace("'{", "{").replace("}'", "}"))
137
  resp["input"] = ocr_input
138
+ dict_keys = [
139
+ "city",
140
+ "distinct",
141
+ "neighbourhood",
142
+ "street",
143
+ "no",
144
+ "tel",
145
+ "name_surname",
146
+ "address",
147
+ "input",
148
+ ]
149
  for key in dict_keys:
150
  if key not in resp.keys():
151
  resp[key] = ""
152
  return resp
153
 
154
 
 
155
  with gr.Blocks() as demo:
156
  gr.Markdown(
157
  """
 
162
  "Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın."
163
  )
164
  with gr.Row():
165
+ img_area = gr.Image(label="Ekran Görüntüsü yükleyin 👇")
166
+ ocr_result = gr.Textbox(label="Metin yükleyin 👇 ")
 
 
 
 
 
 
167
  open_api_text = gr.Textbox(label="Tam Adres")
168
+ submit_button = gr.Button(label="Yükle")
169
  with gr.Column():
170
  with gr.Row():
171
+ city = gr.Textbox(label="İl")
172
+ distinct = gr.Textbox(label="İlçe")
173
  with gr.Row():
174
+ neighbourhood = gr.Textbox(label="Mahalle")
175
+ street = gr.Textbox(label="Sokak/Cadde/Bulvar")
 
 
 
 
176
  with gr.Row():
177
+ tel = gr.Textbox(label="Telefon")
178
  with gr.Row():
179
+ name_surname = gr.Textbox(label="İsim Soyisim")
180
+ address = gr.Textbox(label="Adres")
 
 
181
  with gr.Row():
182
+ no = gr.Textbox(label="Kapı No")
183
 
184
+ submit_button.click(
185
  get_parsed_address,
186
  inputs=img_area,
187
  outputs=open_api_text,
188
+ api_name="upload_image",
189
  )
190
 
191
+ ocr_result.change(
192
+ openai_response, ocr_result, open_api_text, api_name="upload-text"
193
  )
194
 
 
195
  open_api_text.change(
196
  text_dict,
197
  open_api_text,
198
+ [city, distinct, neighbourhood, street, address, tel, name_surname, no],
 
 
 
 
 
 
 
199
  )
 
 
 
 
 
 
200
 
201
 
202
  if __name__ == "__main__":
203
+ demo.launch()
db_utils.py DELETED
@@ -1,41 +0,0 @@
1
- from deta import Deta # Import Deta
2
- from pprint import pprint
3
- import os
4
-
5
- deta_key = os.getenv("DETA_KEY")
6
- deta = Deta(deta_key)
7
- db = deta.Base("deprem-ocr")
8
-
9
-
10
- def get_users_by_city(city_name, limit=10):
11
-
12
- user = db.fetch({"city": city_name.capitalize()}, limit=limit).items
13
- return user
14
-
15
-
16
- def get_all():
17
- res = db.fetch()
18
- all_items = res.items
19
-
20
- # fetch until last is 'None'
21
- while res.last:
22
- res = db.fetch(last=res.last)
23
- all_items += res.items
24
- return all_items
25
-
26
-
27
- def write_db(data_dict):
28
- # 2) initialize with a project key
29
- deta_key = os.getenv("DETA_KEY")
30
- deta = Deta(deta_key)
31
-
32
- # 3) create and use as many DBs as you want!
33
- users = deta.Base("deprem-ocr")
34
- users.insert(data_dict)
35
- print("Pushed to db")
36
-
37
-
38
- def get_latest_row(last):
39
- all_items = get_all()
40
- latest_items = all_items[-last:]
41
- return latest_items
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore → ocr/.gitignore RENAMED
@@ -20,6 +20,7 @@ parts/
20
  sdist/
21
  var/
22
  wheels/
 
23
  share/python-wheels/
24
  *.egg-info/
25
  .installed.cfg
@@ -49,7 +50,6 @@ coverage.xml
49
  *.py,cover
50
  .hypothesis/
51
  .pytest_cache/
52
- cover/
53
 
54
  # Translations
55
  *.mo
@@ -72,7 +72,6 @@ instance/
72
  docs/_build/
73
 
74
  # PyBuilder
75
- .pybuilder/
76
  target/
77
 
78
  # Jupyter Notebook
@@ -83,9 +82,7 @@ profile_default/
83
  ipython_config.py
84
 
85
  # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
 
90
  # pipenv
91
  # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
@@ -94,22 +91,7 @@ ipython_config.py
94
  # install all needed dependencies.
95
  #Pipfile.lock
96
 
97
- # poetry
98
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
- #poetry.lock
103
-
104
- # pdm
105
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
- #pdm.lock
107
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
- # in version control.
109
- # https://pdm.fming.dev/#use-with-ide
110
- .pdm.toml
111
-
112
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
  __pypackages__/
114
 
115
  # Celery stuff
@@ -145,18 +127,3 @@ dmypy.json
145
 
146
  # Pyre type checker
147
  .pyre/
148
-
149
- # pytype static type analyzer
150
- .pytype/
151
-
152
- # Cython debug symbols
153
- cython_debug/
154
-
155
- # PyCharm
156
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
- # and can be added to the global gitignore or merged into this file. For a more nuclear
159
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
- #.idea/
161
-
162
- .DS_Store
 
20
  sdist/
21
  var/
22
  wheels/
23
+ pip-wheel-metadata/
24
  share/python-wheels/
25
  *.egg-info/
26
  .installed.cfg
 
50
  *.py,cover
51
  .hypothesis/
52
  .pytest_cache/
 
53
 
54
  # Translations
55
  *.mo
 
72
  docs/_build/
73
 
74
  # PyBuilder
 
75
  target/
76
 
77
  # Jupyter Notebook
 
82
  ipython_config.py
83
 
84
  # pyenv
85
+ .python-version
 
 
86
 
87
  # pipenv
88
  # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 
91
  # install all needed dependencies.
92
  #Pipfile.lock
93
 
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  __pypackages__/
96
 
97
  # Celery stuff
 
127
 
128
  # Pyre type checker
129
  .pyre/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ocr/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # deprem-ocr
ocr/__init__.py ADDED
File without changes
ocr/ch_PP-OCRv3_det_infer/inference.pdiparams ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e9518c6ab706fe87842a8de1c098f990e67f9212b67c9ef8bc4bca6dc17b91a
3
+ size 2377917
ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info ADDED
Binary file (26.4 kB). View file
 
ocr/ch_PP-OCRv3_det_infer/inference.pdmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74b075e6cfbc8206dab2eee86a6a8bd015a7be612b2bf6d1a1ef878d31df84f7
3
+ size 1413260
ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d99d4279f7c64471b8f0be426ee09a46c0f1ecb344406bf0bb9571f670e8d0c7
3
+ size 10614098
ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info ADDED
Binary file (22 kB). View file
 
ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9beb0b9520d34bde2a0f92581ed64db7e4d6c76abead8b859189ea72db9ee20
3
+ size 1266415
ocr/detector.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
5
+ sys.path.append(__dir__)
6
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
7
+
8
+ os.environ["FLAGS_allocator_strategy"] = "auto_growth"
9
+
10
+ import json
11
+ import sys
12
+ import time
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+ import utility
18
+ from postprocess import build_post_process
19
+ from ppocr.data import create_operators, transform
20
+
21
+
22
+ class TextDetector(object):
23
+ def __init__(self, args):
24
+ self.args = args
25
+ self.det_algorithm = args.det_algorithm
26
+ self.use_onnx = args.use_onnx
27
+ pre_process_list = [
28
+ {
29
+ "DetResizeForTest": {
30
+ "limit_side_len": args.det_limit_side_len,
31
+ "limit_type": args.det_limit_type,
32
+ }
33
+ },
34
+ {
35
+ "NormalizeImage": {
36
+ "std": [0.229, 0.224, 0.225],
37
+ "mean": [0.485, 0.456, 0.406],
38
+ "scale": "1./255.",
39
+ "order": "hwc",
40
+ }
41
+ },
42
+ {"ToCHWImage": None},
43
+ {"KeepKeys": {"keep_keys": ["image", "shape"]}},
44
+ ]
45
+ postprocess_params = {}
46
+ if self.det_algorithm == "DB":
47
+ postprocess_params["name"] = "DBPostProcess"
48
+ postprocess_params["thresh"] = args.det_db_thresh
49
+ postprocess_params["box_thresh"] = args.det_db_box_thresh
50
+ postprocess_params["max_candidates"] = 1000
51
+ postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
52
+ postprocess_params["use_dilation"] = args.use_dilation
53
+ postprocess_params["score_mode"] = args.det_db_score_mode
54
+ elif self.det_algorithm == "EAST":
55
+ postprocess_params["name"] = "EASTPostProcess"
56
+ postprocess_params["score_thresh"] = args.det_east_score_thresh
57
+ postprocess_params["cover_thresh"] = args.det_east_cover_thresh
58
+ postprocess_params["nms_thresh"] = args.det_east_nms_thresh
59
+ elif self.det_algorithm == "SAST":
60
+ pre_process_list[0] = {
61
+ "DetResizeForTest": {"resize_long": args.det_limit_side_len}
62
+ }
63
+ postprocess_params["name"] = "SASTPostProcess"
64
+ postprocess_params["score_thresh"] = args.det_sast_score_thresh
65
+ postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
66
+ self.det_sast_polygon = args.det_sast_polygon
67
+ if self.det_sast_polygon:
68
+ postprocess_params["sample_pts_num"] = 6
69
+ postprocess_params["expand_scale"] = 1.2
70
+ postprocess_params["shrink_ratio_of_width"] = 0.2
71
+ else:
72
+ postprocess_params["sample_pts_num"] = 2
73
+ postprocess_params["expand_scale"] = 1.0
74
+ postprocess_params["shrink_ratio_of_width"] = 0.3
75
+ elif self.det_algorithm == "PSE":
76
+ postprocess_params["name"] = "PSEPostProcess"
77
+ postprocess_params["thresh"] = args.det_pse_thresh
78
+ postprocess_params["box_thresh"] = args.det_pse_box_thresh
79
+ postprocess_params["min_area"] = args.det_pse_min_area
80
+ postprocess_params["box_type"] = args.det_pse_box_type
81
+ postprocess_params["scale"] = args.det_pse_scale
82
+ self.det_pse_box_type = args.det_pse_box_type
83
+ elif self.det_algorithm == "FCE":
84
+ pre_process_list[0] = {"DetResizeForTest": {"rescale_img": [1080, 736]}}
85
+ postprocess_params["name"] = "FCEPostProcess"
86
+ postprocess_params["scales"] = args.scales
87
+ postprocess_params["alpha"] = args.alpha
88
+ postprocess_params["beta"] = args.beta
89
+ postprocess_params["fourier_degree"] = args.fourier_degree
90
+ postprocess_params["box_type"] = args.det_fce_box_type
91
+
92
+ self.preprocess_op = create_operators(pre_process_list)
93
+ self.postprocess_op = build_post_process(postprocess_params)
94
+ (
95
+ self.predictor,
96
+ self.input_tensor,
97
+ self.output_tensors,
98
+ self.config,
99
+ ) = utility.create_predictor(args, "det")
100
+
101
+ if self.use_onnx:
102
+ img_h, img_w = self.input_tensor.shape[2:]
103
+ if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
104
+ pre_process_list[0] = {
105
+ "DetResizeForTest": {"image_shape": [img_h, img_w]}
106
+ }
107
+ self.preprocess_op = create_operators(pre_process_list)
108
+
109
+ def order_points_clockwise(self, pts):
110
+ rect = np.zeros((4, 2), dtype="float32")
111
+ s = pts.sum(axis=1)
112
+ rect[0] = pts[np.argmin(s)]
113
+ rect[2] = pts[np.argmax(s)]
114
+ diff = np.diff(pts, axis=1)
115
+ rect[1] = pts[np.argmin(diff)]
116
+ rect[3] = pts[np.argmax(diff)]
117
+ return rect
118
+
119
+ def clip_det_res(self, points, img_height, img_width):
120
+ for pno in range(points.shape[0]):
121
+ points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
122
+ points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
123
+ return points
124
+
125
+ def filter_tag_det_res(self, dt_boxes, image_shape):
126
+ img_height, img_width = image_shape[0:2]
127
+ dt_boxes_new = []
128
+ for box in dt_boxes:
129
+ box = self.order_points_clockwise(box)
130
+ box = self.clip_det_res(box, img_height, img_width)
131
+ rect_width = int(np.linalg.norm(box[0] - box[1]))
132
+ rect_height = int(np.linalg.norm(box[0] - box[3]))
133
+ if rect_width <= 3 or rect_height <= 3:
134
+ continue
135
+ dt_boxes_new.append(box)
136
+ dt_boxes = np.array(dt_boxes_new)
137
+ return dt_boxes
138
+
139
+ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
140
+ img_height, img_width = image_shape[0:2]
141
+ dt_boxes_new = []
142
+ for box in dt_boxes:
143
+ box = self.clip_det_res(box, img_height, img_width)
144
+ dt_boxes_new.append(box)
145
+ dt_boxes = np.array(dt_boxes_new)
146
+ return dt_boxes
147
+
148
+ def __call__(self, img):
149
+ ori_im = img.copy()
150
+ data = {"image": img}
151
+
152
+ st = time.time()
153
+
154
+ data = transform(data, self.preprocess_op)
155
+ img, shape_list = data
156
+ if img is None:
157
+ return None, 0
158
+ img = np.expand_dims(img, axis=0)
159
+ shape_list = np.expand_dims(shape_list, axis=0)
160
+ img = img.copy()
161
+
162
+ if self.use_onnx:
163
+ input_dict = {}
164
+ input_dict[self.input_tensor.name] = img
165
+ outputs = self.predictor.run(self.output_tensors, input_dict)
166
+ else:
167
+ self.input_tensor.copy_from_cpu(img)
168
+ self.predictor.run()
169
+ outputs = []
170
+ for output_tensor in self.output_tensors:
171
+ output = output_tensor.copy_to_cpu()
172
+ outputs.append(output)
173
+
174
+ preds = {}
175
+ if self.det_algorithm == "EAST":
176
+ preds["f_geo"] = outputs[0]
177
+ preds["f_score"] = outputs[1]
178
+ elif self.det_algorithm == "SAST":
179
+ preds["f_border"] = outputs[0]
180
+ preds["f_score"] = outputs[1]
181
+ preds["f_tco"] = outputs[2]
182
+ preds["f_tvo"] = outputs[3]
183
+ elif self.det_algorithm in ["DB", "PSE"]:
184
+ preds["maps"] = outputs[0]
185
+ elif self.det_algorithm == "FCE":
186
+ for i, output in enumerate(outputs):
187
+ preds["level_{}".format(i)] = output
188
+ else:
189
+ raise NotImplementedError
190
+
191
+ # self.predictor.try_shrink_memory()
192
+ post_result = self.postprocess_op(preds, shape_list)
193
+ dt_boxes = post_result[0]["points"]
194
+ if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
195
+ self.det_algorithm in ["PSE", "FCE"]
196
+ and self.postprocess_op.box_type == "poly"
197
+ ):
198
+ dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
199
+ else:
200
+ dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
201
+
202
+ et = time.time()
203
+ return dt_boxes, et - st
204
+
205
+
206
+ if __name__ == "__main__":
207
+ args = utility.parse_args()
208
+ image_file_list = ["images/y.png"]
209
+ text_detector = TextDetector(args)
210
+ count = 0
211
+ total_time = 0
212
+ draw_img_save = "./inference_results"
213
+
214
+ if args.warmup:
215
+ img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
216
+ for i in range(2):
217
+ res = text_detector(img)
218
+
219
+ if not os.path.exists(draw_img_save):
220
+ os.makedirs(draw_img_save)
221
+
222
+ save_results = []
223
+ for image_file in image_file_list:
224
+ img = cv2.imread(image_file)
225
+
226
+ for _ in range(10):
227
+ st = time.time()
228
+ dt_boxes, _ = text_detector(img)
229
+ elapse = time.time() - st
230
+ print(elapse * 1000)
231
+ if count > 0:
232
+ total_time += elapse
233
+ count += 1
234
+ save_pred = (
235
+ os.path.basename(image_file)
236
+ + "\t"
237
+ + str(json.dumps([x.tolist() for x in dt_boxes]))
238
+ + "\n"
239
+ )
240
+ save_results.append(save_pred)
241
+ src_im = utility.draw_text_det_res(dt_boxes, image_file)
242
+ img_name_pure = os.path.split(image_file)[-1]
243
+ img_path = os.path.join(draw_img_save, "det_res_{}".format(img_name_pure))
244
+ cv2.imwrite(img_path, src_im)
245
+
246
+ with open(os.path.join(draw_img_save, "det_results.txt"), "w") as f:
247
+ f.writelines(save_results)
248
+ f.close()
ocr/inference.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
4
+
5
+ import time
6
+ import requests
7
+ from io import BytesIO
8
+
9
+ import utility
10
+ from detector import *
11
+ from recognizer import *
12
+
13
+ # Global Detector and Recognizer
14
+ args = utility.parse_args()
15
+ text_recognizer = TextRecognizer(args)
16
+ text_detector = TextDetector(args)
17
+
18
+
19
+ def apply_ocr(img):
20
+ # Detect text regions
21
+ dt_boxes, _ = text_detector(img)
22
+
23
+ boxes = []
24
+ for box in dt_boxes:
25
+ p1, p2, p3, p4 = box
26
+ x1 = min(p1[0], p2[0], p3[0], p4[0])
27
+ y1 = min(p1[1], p2[1], p3[1], p4[1])
28
+ x2 = max(p1[0], p2[0], p3[0], p4[0])
29
+ y2 = max(p1[1], p2[1], p3[1], p4[1])
30
+ boxes.append([x1, y1, x2, y2])
31
+
32
+ # Recognize text
33
+ img_list = []
34
+ for i in range(len(boxes)):
35
+ x1, y1, x2, y2 = map(int, boxes[i])
36
+ img_list.append(img.copy()[y1:y2, x1:x2])
37
+ img_list.reverse()
38
+
39
+ rec_res, _ = text_recognizer(img_list)
40
+
41
+ # Postprocess
42
+ total_text = ""
43
+ table = dict()
44
+ for i in range(len(rec_res)):
45
+ table[i] = {
46
+ "text": rec_res[i][0],
47
+ }
48
+ total_text += rec_res[i][0] + " "
49
+
50
+ total_text = total_text.strip()
51
+ return total_text
52
+
53
+
54
+ def main():
55
+ image_url = "https://i.ibb.co/kQvHGjj/aewrg.png"
56
+ response = requests.get(image_url)
57
+ img = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
58
+
59
+ t0 = time.time()
60
+ epoch = 1
61
+ for _ in range(epoch):
62
+ ocr_text = apply_ocr(img)
63
+ print("Elapsed time:", (time.time() - t0) * 1000 / epoch, "ms")
64
+ print("Output:", ocr_text)
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
ocr/postprocess/__init__.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import copy
4
+
5
+ __all__ = ["build_post_process"]
6
+
7
+ from .cls_postprocess import ClsPostProcess
8
+ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
9
+ from .east_postprocess import EASTPostProcess
10
+ from .fce_postprocess import FCEPostProcess
11
+ from .pg_postprocess import PGPostProcess
12
+ from .rec_postprocess import (
13
+ AttnLabelDecode,
14
+ CTCLabelDecode,
15
+ DistillationCTCLabelDecode,
16
+ NRTRLabelDecode,
17
+ PRENLabelDecode,
18
+ SARLabelDecode,
19
+ SEEDLabelDecode,
20
+ SRNLabelDecode,
21
+ TableLabelDecode,
22
+ )
23
+ from .sast_postprocess import SASTPostProcess
24
+ from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
25
+ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
26
+
27
+
28
+ def build_post_process(config, global_config=None):
29
+ support_dict = [
30
+ "DBPostProcess",
31
+ "EASTPostProcess",
32
+ "SASTPostProcess",
33
+ "FCEPostProcess",
34
+ "CTCLabelDecode",
35
+ "AttnLabelDecode",
36
+ "ClsPostProcess",
37
+ "SRNLabelDecode",
38
+ "PGPostProcess",
39
+ "DistillationCTCLabelDecode",
40
+ "TableLabelDecode",
41
+ "DistillationDBPostProcess",
42
+ "NRTRLabelDecode",
43
+ "SARLabelDecode",
44
+ "SEEDLabelDecode",
45
+ "VQASerTokenLayoutLMPostProcess",
46
+ "VQAReTokenLayoutLMPostProcess",
47
+ "PRENLabelDecode",
48
+ "DistillationSARLabelDecode",
49
+ ]
50
+
51
+ if config["name"] == "PSEPostProcess":
52
+ from .pse_postprocess import PSEPostProcess
53
+
54
+ support_dict.append("PSEPostProcess")
55
+
56
+ config = copy.deepcopy(config)
57
+ module_name = config.pop("name")
58
+ if module_name == "None":
59
+ return
60
+ if global_config is not None:
61
+ config.update(global_config)
62
+ assert module_name in support_dict, Exception(
63
+ "post process only support {}".format(support_dict)
64
+ )
65
+ module_class = eval(module_name)(**config)
66
+ return module_class
ocr/postprocess/cls_postprocess.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import paddle
2
+
3
+
4
+ class ClsPostProcess(object):
5
+ """Convert between text-label and text-index"""
6
+
7
+ def __init__(self, label_list=None, key=None, **kwargs):
8
+ super(ClsPostProcess, self).__init__()
9
+ self.label_list = label_list
10
+ self.key = key
11
+
12
+ def __call__(self, preds, label=None, *args, **kwargs):
13
+ if self.key is not None:
14
+ preds = preds[self.key]
15
+
16
+ label_list = self.label_list
17
+ if label_list is None:
18
+ label_list = {idx: idx for idx in range(preds.shape[-1])}
19
+
20
+ if isinstance(preds, paddle.Tensor):
21
+ preds = preds.numpy()
22
+
23
+ pred_idxs = preds.argmax(axis=1)
24
+ decode_out = [
25
+ (label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
26
+ ]
27
+ if label is None:
28
+ return decode_out
29
+ label = [(label_list[idx], 1.0) for idx in label]
30
+ return decode_out, label
ocr/postprocess/db_postprocess.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import paddle
6
+ import pyclipper
7
+ from shapely.geometry import Polygon
8
+
9
+
10
+ class DBPostProcess(object):
11
+ """
12
+ The post process for Differentiable Binarization (DB).
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ thresh=0.3,
18
+ box_thresh=0.7,
19
+ max_candidates=1000,
20
+ unclip_ratio=2.0,
21
+ use_dilation=False,
22
+ score_mode="fast",
23
+ **kwargs
24
+ ):
25
+ self.thresh = thresh
26
+ self.box_thresh = box_thresh
27
+ self.max_candidates = max_candidates
28
+ self.unclip_ratio = unclip_ratio
29
+ self.min_size = 3
30
+ self.score_mode = score_mode
31
+ assert score_mode in [
32
+ "slow",
33
+ "fast",
34
+ ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
35
+
36
+ self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
37
+
38
+ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
39
+ """
40
+ _bitmap: single map with shape (1, H, W),
41
+ whose values are binarized as {0, 1}
42
+ """
43
+
44
+ bitmap = _bitmap
45
+ height, width = bitmap.shape
46
+
47
+ outs = cv2.findContours(
48
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
49
+ )
50
+ if len(outs) == 3:
51
+ img, contours, _ = outs[0], outs[1], outs[2]
52
+ elif len(outs) == 2:
53
+ contours, _ = outs[0], outs[1]
54
+
55
+ num_contours = min(len(contours), self.max_candidates)
56
+
57
+ boxes = []
58
+ scores = []
59
+ for index in range(num_contours):
60
+ contour = contours[index]
61
+ points, sside = self.get_mini_boxes(contour)
62
+ if sside < self.min_size:
63
+ continue
64
+ points = np.array(points)
65
+ if self.score_mode == "fast":
66
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
67
+ else:
68
+ score = self.box_score_slow(pred, contour)
69
+ if self.box_thresh > score:
70
+ continue
71
+
72
+ box = self.unclip(points).reshape(-1, 1, 2)
73
+ box, sside = self.get_mini_boxes(box)
74
+ if sside < self.min_size + 2:
75
+ continue
76
+ box = np.array(box)
77
+
78
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
79
+ box[:, 1] = np.clip(
80
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
81
+ )
82
+ boxes.append(box.astype(np.int16))
83
+ scores.append(score)
84
+ return np.array(boxes, dtype=np.int16), scores
85
+
86
+ def unclip(self, box):
87
+ unclip_ratio = self.unclip_ratio
88
+ poly = Polygon(box)
89
+ distance = poly.area * unclip_ratio / poly.length
90
+ offset = pyclipper.PyclipperOffset()
91
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
92
+ expanded = np.array(offset.Execute(distance))
93
+ return expanded
94
+
95
+ def get_mini_boxes(self, contour):
96
+ bounding_box = cv2.minAreaRect(contour)
97
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
98
+
99
+ index_1, index_2, index_3, index_4 = 0, 1, 2, 3
100
+ if points[1][1] > points[0][1]:
101
+ index_1 = 0
102
+ index_4 = 1
103
+ else:
104
+ index_1 = 1
105
+ index_4 = 0
106
+ if points[3][1] > points[2][1]:
107
+ index_2 = 2
108
+ index_3 = 3
109
+ else:
110
+ index_2 = 3
111
+ index_3 = 2
112
+
113
+ box = [points[index_1], points[index_2], points[index_3], points[index_4]]
114
+ return box, min(bounding_box[1])
115
+
116
+ def box_score_fast(self, bitmap, _box):
117
+ """
118
+ box_score_fast: use bbox mean score as the mean score
119
+ """
120
+ h, w = bitmap.shape[:2]
121
+ box = _box.copy()
122
+ xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
123
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
124
+ ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
125
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
126
+
127
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
128
+ box[:, 0] = box[:, 0] - xmin
129
+ box[:, 1] = box[:, 1] - ymin
130
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
131
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
132
+
133
+ def box_score_slow(self, bitmap, contour):
134
+ """
135
+ box_score_slow: use polyon mean score as the mean score
136
+ """
137
+ h, w = bitmap.shape[:2]
138
+ contour = contour.copy()
139
+ contour = np.reshape(contour, (-1, 2))
140
+
141
+ xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
142
+ xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
143
+ ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
144
+ ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
145
+
146
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
147
+
148
+ contour[:, 0] = contour[:, 0] - xmin
149
+ contour[:, 1] = contour[:, 1] - ymin
150
+
151
+ cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
152
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
153
+
154
+ def __call__(self, outs_dict, shape_list):
155
+ pred = outs_dict["maps"]
156
+ if isinstance(pred, paddle.Tensor):
157
+ pred = pred.numpy()
158
+ pred = pred[:, 0, :, :]
159
+ segmentation = pred > self.thresh
160
+
161
+ boxes_batch = []
162
+ for batch_index in range(pred.shape[0]):
163
+ src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
164
+ if self.dilation_kernel is not None:
165
+ mask = cv2.dilate(
166
+ np.array(segmentation[batch_index]).astype(np.uint8),
167
+ self.dilation_kernel,
168
+ )
169
+ else:
170
+ mask = segmentation[batch_index]
171
+ boxes, scores = self.boxes_from_bitmap(
172
+ pred[batch_index], mask, src_w, src_h
173
+ )
174
+
175
+ boxes_batch.append({"points": boxes})
176
+ return boxes_batch
177
+
178
+
179
+ class DistillationDBPostProcess(object):
180
+ def __init__(
181
+ self,
182
+ model_name=["student"],
183
+ key=None,
184
+ thresh=0.3,
185
+ box_thresh=0.6,
186
+ max_candidates=1000,
187
+ unclip_ratio=1.5,
188
+ use_dilation=False,
189
+ score_mode="fast",
190
+ **kwargs
191
+ ):
192
+ self.model_name = model_name
193
+ self.key = key
194
+ self.post_process = DBPostProcess(
195
+ thresh=thresh,
196
+ box_thresh=box_thresh,
197
+ max_candidates=max_candidates,
198
+ unclip_ratio=unclip_ratio,
199
+ use_dilation=use_dilation,
200
+ score_mode=score_mode,
201
+ )
202
+
203
+ def __call__(self, predicts, shape_list):
204
+ results = {}
205
+ for k in self.model_name:
206
+ results[k] = self.post_process(predicts[k], shape_list=shape_list)
207
+ return results
ocr/postprocess/east_postprocess.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import paddle
6
+
7
+ from .locality_aware_nms import nms_locality
8
+
9
+
10
+ class EASTPostProcess(object):
11
+ """
12
+ The post process for EAST.
13
+ """
14
+
15
+ def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
16
+
17
+ self.score_thresh = score_thresh
18
+ self.cover_thresh = cover_thresh
19
+ self.nms_thresh = nms_thresh
20
+
21
+ def restore_rectangle_quad(self, origin, geometry):
22
+ """
23
+ Restore rectangle from quadrangle.
24
+ """
25
+ # quad
26
+ origin_concat = np.concatenate(
27
+ (origin, origin, origin, origin), axis=1
28
+ ) # (n, 8)
29
+ pred_quads = origin_concat - geometry
30
+ pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
31
+ return pred_quads
32
+
33
+ def detect(
34
+ self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
35
+ ):
36
+ """
37
+ restore text boxes from score map and geo map
38
+ """
39
+
40
+ score_map = score_map[0]
41
+ geo_map = np.swapaxes(geo_map, 1, 0)
42
+ geo_map = np.swapaxes(geo_map, 1, 2)
43
+ # filter the score map
44
+ xy_text = np.argwhere(score_map > score_thresh)
45
+ if len(xy_text) == 0:
46
+ return []
47
+ # sort the text boxes via the y axis
48
+ xy_text = xy_text[np.argsort(xy_text[:, 0])]
49
+ # restore quad proposals
50
+ text_box_restored = self.restore_rectangle_quad(
51
+ xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
52
+ )
53
+ boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
54
+ boxes[:, :8] = text_box_restored.reshape((-1, 8))
55
+ boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
56
+
57
+ try:
58
+ import lanms
59
+
60
+ boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
61
+ except:
62
+ print(
63
+ "you should install lanms by pip3 install lanms-nova to speed up nms_locality"
64
+ )
65
+ boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
66
+ if boxes.shape[0] == 0:
67
+ return []
68
+ # Here we filter some low score boxes by the average score map,
69
+ # this is different from the orginal paper.
70
+ for i, box in enumerate(boxes):
71
+ mask = np.zeros_like(score_map, dtype=np.uint8)
72
+ cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
73
+ boxes[i, 8] = cv2.mean(score_map, mask)[0]
74
+ boxes = boxes[boxes[:, 8] > cover_thresh]
75
+ return boxes
76
+
77
+ def sort_poly(self, p):
78
+ """
79
+ Sort polygons.
80
+ """
81
+ min_axis = np.argmin(np.sum(p, axis=1))
82
+ p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
83
+ if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
84
+ return p
85
+ else:
86
+ return p[[0, 3, 2, 1]]
87
+
88
+ def __call__(self, outs_dict, shape_list):
89
+ score_list = outs_dict["f_score"]
90
+ geo_list = outs_dict["f_geo"]
91
+ if isinstance(score_list, paddle.Tensor):
92
+ score_list = score_list.numpy()
93
+ geo_list = geo_list.numpy()
94
+ img_num = len(shape_list)
95
+ dt_boxes_list = []
96
+ for ino in range(img_num):
97
+ score = score_list[ino]
98
+ geo = geo_list[ino]
99
+ boxes = self.detect(
100
+ score_map=score,
101
+ geo_map=geo,
102
+ score_thresh=self.score_thresh,
103
+ cover_thresh=self.cover_thresh,
104
+ nms_thresh=self.nms_thresh,
105
+ )
106
+ boxes_norm = []
107
+ if len(boxes) > 0:
108
+ h, w = score.shape[1:]
109
+ src_h, src_w, ratio_h, ratio_w = shape_list[ino]
110
+ boxes = boxes[:, :8].reshape((-1, 4, 2))
111
+ boxes[:, :, 0] /= ratio_w
112
+ boxes[:, :, 1] /= ratio_h
113
+ for i_box, box in enumerate(boxes):
114
+ box = self.sort_poly(box.astype(np.int32))
115
+ if (
116
+ np.linalg.norm(box[0] - box[1]) < 5
117
+ or np.linalg.norm(box[3] - box[0]) < 5
118
+ ):
119
+ continue
120
+ boxes_norm.append(box)
121
+ dt_boxes_list.append({"points": np.array(boxes_norm)})
122
+ return dt_boxes_list
ocr/postprocess/extract_textpoint_fast.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ from itertools import groupby
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from skimage.morphology._skeletonize import thin
8
+
9
+
10
+ def get_dict(character_dict_path):
11
+ character_str = ""
12
+ with open(character_dict_path, "rb") as fin:
13
+ lines = fin.readlines()
14
+ for line in lines:
15
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
16
+ character_str += line
17
+ dict_character = list(character_str)
18
+ return dict_character
19
+
20
+
21
+ def softmax(logits):
22
+ """
23
+ logits: N x d
24
+ """
25
+ max_value = np.max(logits, axis=1, keepdims=True)
26
+ exp = np.exp(logits - max_value)
27
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
28
+ dist = exp / exp_sum
29
+ return dist
30
+
31
+
32
+ def get_keep_pos_idxs(labels, remove_blank=None):
33
+ """
34
+ Remove duplicate and get pos idxs of keep items.
35
+ The value of keep_blank should be [None, 95].
36
+ """
37
+ duplicate_len_list = []
38
+ keep_pos_idx_list = []
39
+ keep_char_idx_list = []
40
+ for k, v_ in groupby(labels):
41
+ current_len = len(list(v_))
42
+ if k != remove_blank:
43
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
44
+ keep_pos_idx_list.append(current_idx)
45
+ keep_char_idx_list.append(k)
46
+ duplicate_len_list.append(current_len)
47
+ return keep_char_idx_list, keep_pos_idx_list
48
+
49
+
50
+ def remove_blank(labels, blank=0):
51
+ new_labels = [x for x in labels if x != blank]
52
+ return new_labels
53
+
54
+
55
+ def insert_blank(labels, blank=0):
56
+ new_labels = [blank]
57
+ for l in labels:
58
+ new_labels += [l, blank]
59
+ return new_labels
60
+
61
+
62
+ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
63
+ """
64
+ CTC greedy (best path) decoder.
65
+ """
66
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
67
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
68
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
69
+ raw_str, remove_blank=remove_blank_in_pos
70
+ )
71
+ dst_str = remove_blank(dedup_str, blank=blank)
72
+ return dst_str, keep_idx_list
73
+
74
+
75
+ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
76
+ _, _, C = logits_map.shape
77
+ ys, xs = zip(*gather_info)
78
+ logits_seq = logits_map[list(ys), list(xs)]
79
+ probs_seq = logits_seq
80
+ labels = np.argmax(probs_seq, axis=1)
81
+ dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
82
+ detal = len(gather_info) // (pts_num - 1)
83
+ keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
84
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
85
+ return dst_str, keep_gather_list
86
+
87
+
88
+ def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, pts_num=6):
89
+ """
90
+ CTC decoder using multiple processes.
91
+ """
92
+ decoder_str = []
93
+ decoder_xys = []
94
+ for gather_info in gather_info_list:
95
+ if len(gather_info) < pts_num:
96
+ continue
97
+ dst_str, xys_list = instance_ctc_greedy_decoder(
98
+ gather_info, logits_map, pts_num=pts_num
99
+ )
100
+ dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
101
+ if len(dst_str_readable) < 2:
102
+ continue
103
+ decoder_str.append(dst_str_readable)
104
+ decoder_xys.append(xys_list)
105
+ return decoder_str, decoder_xys
106
+
107
+
108
+ def sort_with_direction(pos_list, f_direction):
109
+ """
110
+ f_direction: h x w x 2
111
+ pos_list: [[y, x], [y, x], [y, x] ...]
112
+ """
113
+
114
+ def sort_part_with_direction(pos_list, point_direction):
115
+ pos_list = np.array(pos_list).reshape(-1, 2)
116
+ point_direction = np.array(point_direction).reshape(-1, 2)
117
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
118
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
119
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
120
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
121
+ return sorted_list, sorted_direction
122
+
123
+ pos_list = np.array(pos_list).reshape(-1, 2)
124
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
125
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
126
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
127
+
128
+ point_num = len(sorted_point)
129
+ if point_num >= 16:
130
+ middle_num = point_num // 2
131
+ first_part_point = sorted_point[:middle_num]
132
+ first_point_direction = sorted_direction[:middle_num]
133
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
134
+ first_part_point, first_point_direction
135
+ )
136
+
137
+ last_part_point = sorted_point[middle_num:]
138
+ last_point_direction = sorted_direction[middle_num:]
139
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
140
+ last_part_point, last_point_direction
141
+ )
142
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
143
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
144
+
145
+ return sorted_point, np.array(sorted_direction)
146
+
147
+
148
+ def add_id(pos_list, image_id=0):
149
+ """
150
+ Add id for gather feature, for inference.
151
+ """
152
+ new_list = []
153
+ for item in pos_list:
154
+ new_list.append((image_id, item[0], item[1]))
155
+ return new_list
156
+
157
+
158
+ def sort_and_expand_with_direction(pos_list, f_direction):
159
+ """
160
+ f_direction: h x w x 2
161
+ pos_list: [[y, x], [y, x], [y, x] ...]
162
+ """
163
+ h, w, _ = f_direction.shape
164
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
165
+
166
+ point_num = len(sorted_list)
167
+ sub_direction_len = max(point_num // 3, 2)
168
+ left_direction = point_direction[:sub_direction_len, :]
169
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
170
+
171
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
172
+ left_average_len = np.linalg.norm(left_average_direction)
173
+ left_start = np.array(sorted_list[0])
174
+ left_step = left_average_direction / (left_average_len + 1e-6)
175
+
176
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
177
+ right_average_len = np.linalg.norm(right_average_direction)
178
+ right_step = right_average_direction / (right_average_len + 1e-6)
179
+ right_start = np.array(sorted_list[-1])
180
+
181
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
182
+ left_list = []
183
+ right_list = []
184
+ for i in range(append_num):
185
+ ly, lx = (
186
+ np.round(left_start + left_step * (i + 1))
187
+ .flatten()
188
+ .astype("int32")
189
+ .tolist()
190
+ )
191
+ if ly < h and lx < w and (ly, lx) not in left_list:
192
+ left_list.append((ly, lx))
193
+ ry, rx = (
194
+ np.round(right_start + right_step * (i + 1))
195
+ .flatten()
196
+ .astype("int32")
197
+ .tolist()
198
+ )
199
+ if ry < h and rx < w and (ry, rx) not in right_list:
200
+ right_list.append((ry, rx))
201
+
202
+ all_list = left_list[::-1] + sorted_list + right_list
203
+ return all_list
204
+
205
+
206
+ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
207
+ """
208
+ f_direction: h x w x 2
209
+ pos_list: [[y, x], [y, x], [y, x] ...]
210
+ binary_tcl_map: h x w
211
+ """
212
+ h, w, _ = f_direction.shape
213
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
214
+
215
+ point_num = len(sorted_list)
216
+ sub_direction_len = max(point_num // 3, 2)
217
+ left_direction = point_direction[:sub_direction_len, :]
218
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
219
+
220
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
221
+ left_average_len = np.linalg.norm(left_average_direction)
222
+ left_start = np.array(sorted_list[0])
223
+ left_step = left_average_direction / (left_average_len + 1e-6)
224
+
225
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
226
+ right_average_len = np.linalg.norm(right_average_direction)
227
+ right_step = right_average_direction / (right_average_len + 1e-6)
228
+ right_start = np.array(sorted_list[-1])
229
+
230
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
231
+ max_append_num = 2 * append_num
232
+
233
+ left_list = []
234
+ right_list = []
235
+ for i in range(max_append_num):
236
+ ly, lx = (
237
+ np.round(left_start + left_step * (i + 1))
238
+ .flatten()
239
+ .astype("int32")
240
+ .tolist()
241
+ )
242
+ if ly < h and lx < w and (ly, lx) not in left_list:
243
+ if binary_tcl_map[ly, lx] > 0.5:
244
+ left_list.append((ly, lx))
245
+ else:
246
+ break
247
+
248
+ for i in range(max_append_num):
249
+ ry, rx = (
250
+ np.round(right_start + right_step * (i + 1))
251
+ .flatten()
252
+ .astype("int32")
253
+ .tolist()
254
+ )
255
+ if ry < h and rx < w and (ry, rx) not in right_list:
256
+ if binary_tcl_map[ry, rx] > 0.5:
257
+ right_list.append((ry, rx))
258
+ else:
259
+ break
260
+
261
+ all_list = left_list[::-1] + sorted_list + right_list
262
+ return all_list
263
+
264
+
265
+ def point_pair2poly(point_pair_list):
266
+ """
267
+ Transfer vertical point_pairs into poly point in clockwise.
268
+ """
269
+ point_num = len(point_pair_list) * 2
270
+ point_list = [0] * point_num
271
+ for idx, point_pair in enumerate(point_pair_list):
272
+ point_list[idx] = point_pair[0]
273
+ point_list[point_num - 1 - idx] = point_pair[1]
274
+ return np.array(point_list).reshape(-1, 2)
275
+
276
+
277
+ def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
278
+ ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
279
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
280
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
281
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
282
+
283
+
284
+ def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
285
+ """
286
+ expand poly along width.
287
+ """
288
+ point_num = poly.shape[0]
289
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
290
+ left_ratio = (
291
+ -shrink_ratio_of_width
292
+ * np.linalg.norm(left_quad[0] - left_quad[3])
293
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
294
+ )
295
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
296
+ right_quad = np.array(
297
+ [
298
+ poly[point_num // 2 - 2],
299
+ poly[point_num // 2 - 1],
300
+ poly[point_num // 2],
301
+ poly[point_num // 2 + 1],
302
+ ],
303
+ dtype=np.float32,
304
+ )
305
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
306
+ right_quad[0] - right_quad[3]
307
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
308
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
309
+ poly[0] = left_quad_expand[0]
310
+ poly[-1] = left_quad_expand[-1]
311
+ poly[point_num // 2 - 1] = right_quad_expand[1]
312
+ poly[point_num // 2] = right_quad_expand[2]
313
+ return poly
314
+
315
+
316
+ def restore_poly(
317
+ instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, valid_set
318
+ ):
319
+ poly_list = []
320
+ keep_str_list = []
321
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
322
+ if len(keep_str) < 2:
323
+ print("--> too short, {}".format(keep_str))
324
+ continue
325
+
326
+ offset_expand = 1.0
327
+ if valid_set == "totaltext":
328
+ offset_expand = 1.2
329
+
330
+ point_pair_list = []
331
+ for y, x in yx_center_line:
332
+ offset = p_border[:, y, x].reshape(2, 2) * offset_expand
333
+ ori_yx = np.array([y, x], dtype=np.float32)
334
+ point_pair = (
335
+ (ori_yx + offset)[:, ::-1]
336
+ * 4.0
337
+ / np.array([ratio_w, ratio_h]).reshape(-1, 2)
338
+ )
339
+ point_pair_list.append(point_pair)
340
+
341
+ detected_poly = point_pair2poly(point_pair_list)
342
+ detected_poly = expand_poly_along_width(
343
+ detected_poly, shrink_ratio_of_width=0.2
344
+ )
345
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
346
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
347
+
348
+ keep_str_list.append(keep_str)
349
+ if valid_set == "partvgg":
350
+ middle_point = len(detected_poly) // 2
351
+ detected_poly = detected_poly[[0, middle_point - 1, middle_point, -1], :]
352
+ poly_list.append(detected_poly)
353
+ elif valid_set == "totaltext":
354
+ poly_list.append(detected_poly)
355
+ else:
356
+ print("--> Not supported format.")
357
+ exit(-1)
358
+ return poly_list, keep_str_list
359
+
360
+
361
+ def generate_pivot_list_fast(
362
+ p_score, p_char_maps, f_direction, Lexicon_Table, score_thresh=0.5
363
+ ):
364
+ """
365
+ return center point and end point of TCL instance; filter with the char maps;
366
+ """
367
+ p_score = p_score[0]
368
+ f_direction = f_direction.transpose(1, 2, 0)
369
+ p_tcl_map = (p_score > score_thresh) * 1.0
370
+ skeleton_map = thin(p_tcl_map.astype(np.uint8))
371
+ instance_count, instance_label_map = cv2.connectedComponents(
372
+ skeleton_map.astype(np.uint8), connectivity=8
373
+ )
374
+
375
+ # get TCL Instance
376
+ all_pos_yxs = []
377
+ if instance_count > 0:
378
+ for instance_id in range(1, instance_count):
379
+ pos_list = []
380
+ ys, xs = np.where(instance_label_map == instance_id)
381
+ pos_list = list(zip(ys, xs))
382
+
383
+ if len(pos_list) < 3:
384
+ continue
385
+
386
+ pos_list_sorted = sort_and_expand_with_direction_v2(
387
+ pos_list, f_direction, p_tcl_map
388
+ )
389
+ all_pos_yxs.append(pos_list_sorted)
390
+
391
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
392
+ decoded_str, keep_yxs_list = ctc_decoder_for_image(
393
+ all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table
394
+ )
395
+ return keep_yxs_list, decoded_str
396
+
397
+
398
+ def extract_main_direction(pos_list, f_direction):
399
+ """
400
+ f_direction: h x w x 2
401
+ pos_list: [[y, x], [y, x], [y, x] ...]
402
+ """
403
+ pos_list = np.array(pos_list)
404
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
405
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
406
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
407
+ average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
408
+ return average_direction
409
+
410
+
411
+ def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
412
+ """
413
+ f_direction: h x w x 2
414
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
415
+ """
416
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
417
+ pos_list = pos_list_full[:, 1:]
418
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
419
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
420
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
421
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
422
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
423
+ return sorted_list
424
+
425
+
426
+ def sort_by_direction_with_image_id(pos_list, f_direction):
427
+ """
428
+ f_direction: h x w x 2
429
+ pos_list: [[y, x], [y, x], [y, x] ...]
430
+ """
431
+
432
+ def sort_part_with_direction(pos_list_full, point_direction):
433
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
434
+ pos_list = pos_list_full[:, 1:]
435
+ point_direction = np.array(point_direction).reshape(-1, 2)
436
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
437
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
438
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
439
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
440
+ return sorted_list, sorted_direction
441
+
442
+ pos_list = np.array(pos_list).reshape(-1, 3)
443
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
444
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
445
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
446
+
447
+ point_num = len(sorted_point)
448
+ if point_num >= 16:
449
+ middle_num = point_num // 2
450
+ first_part_point = sorted_point[:middle_num]
451
+ first_point_direction = sorted_direction[:middle_num]
452
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
453
+ first_part_point, first_point_direction
454
+ )
455
+
456
+ last_part_point = sorted_point[middle_num:]
457
+ last_point_direction = sorted_direction[middle_num:]
458
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
459
+ last_part_point, last_point_direction
460
+ )
461
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
462
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
463
+
464
+ return sorted_point
ocr/postprocess/extract_textpoint_slow.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import math
4
+ from itertools import groupby
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from skimage.morphology._skeletonize import thin
9
+
10
+
11
+ def get_dict(character_dict_path):
12
+ character_str = ""
13
+ with open(character_dict_path, "rb") as fin:
14
+ lines = fin.readlines()
15
+ for line in lines:
16
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
17
+ character_str += line
18
+ dict_character = list(character_str)
19
+ return dict_character
20
+
21
+
22
+ def point_pair2poly(point_pair_list):
23
+ """
24
+ Transfer vertical point_pairs into poly point in clockwise.
25
+ """
26
+ pair_length_list = []
27
+ for point_pair in point_pair_list:
28
+ pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
29
+ pair_length_list.append(pair_length)
30
+ pair_length_list = np.array(pair_length_list)
31
+ pair_info = (
32
+ pair_length_list.max(),
33
+ pair_length_list.min(),
34
+ pair_length_list.mean(),
35
+ )
36
+
37
+ point_num = len(point_pair_list) * 2
38
+ point_list = [0] * point_num
39
+ for idx, point_pair in enumerate(point_pair_list):
40
+ point_list[idx] = point_pair[0]
41
+ point_list[point_num - 1 - idx] = point_pair[1]
42
+ return np.array(point_list).reshape(-1, 2), pair_info
43
+
44
+
45
+ def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
46
+ """
47
+ Generate shrink_quad_along_width.
48
+ """
49
+ ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
50
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
51
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
52
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
53
+
54
+
55
+ def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
56
+ """
57
+ expand poly along width.
58
+ """
59
+ point_num = poly.shape[0]
60
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
61
+ left_ratio = (
62
+ -shrink_ratio_of_width
63
+ * np.linalg.norm(left_quad[0] - left_quad[3])
64
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
65
+ )
66
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
67
+ right_quad = np.array(
68
+ [
69
+ poly[point_num // 2 - 2],
70
+ poly[point_num // 2 - 1],
71
+ poly[point_num // 2],
72
+ poly[point_num // 2 + 1],
73
+ ],
74
+ dtype=np.float32,
75
+ )
76
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
77
+ right_quad[0] - right_quad[3]
78
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
79
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
80
+ poly[0] = left_quad_expand[0]
81
+ poly[-1] = left_quad_expand[-1]
82
+ poly[point_num // 2 - 1] = right_quad_expand[1]
83
+ poly[point_num // 2] = right_quad_expand[2]
84
+ return poly
85
+
86
+
87
+ def softmax(logits):
88
+ """
89
+ logits: N x d
90
+ """
91
+ max_value = np.max(logits, axis=1, keepdims=True)
92
+ exp = np.exp(logits - max_value)
93
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
94
+ dist = exp / exp_sum
95
+ return dist
96
+
97
+
98
+ def get_keep_pos_idxs(labels, remove_blank=None):
99
+ """
100
+ Remove duplicate and get pos idxs of keep items.
101
+ The value of keep_blank should be [None, 95].
102
+ """
103
+ duplicate_len_list = []
104
+ keep_pos_idx_list = []
105
+ keep_char_idx_list = []
106
+ for k, v_ in groupby(labels):
107
+ current_len = len(list(v_))
108
+ if k != remove_blank:
109
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
110
+ keep_pos_idx_list.append(current_idx)
111
+ keep_char_idx_list.append(k)
112
+ duplicate_len_list.append(current_len)
113
+ return keep_char_idx_list, keep_pos_idx_list
114
+
115
+
116
+ def remove_blank(labels, blank=0):
117
+ new_labels = [x for x in labels if x != blank]
118
+ return new_labels
119
+
120
+
121
+ def insert_blank(labels, blank=0):
122
+ new_labels = [blank]
123
+ for l in labels:
124
+ new_labels += [l, blank]
125
+ return new_labels
126
+
127
+
128
+ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
129
+ """
130
+ CTC greedy (best path) decoder.
131
+ """
132
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
133
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
134
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
135
+ raw_str, remove_blank=remove_blank_in_pos
136
+ )
137
+ dst_str = remove_blank(dedup_str, blank=blank)
138
+ return dst_str, keep_idx_list
139
+
140
+
141
+ def instance_ctc_greedy_decoder(gather_info, logits_map, keep_blank_in_idxs=True):
142
+ """
143
+ gather_info: [[x, y], [x, y] ...]
144
+ logits_map: H x W X (n_chars + 1)
145
+ """
146
+ _, _, C = logits_map.shape
147
+ ys, xs = zip(*gather_info)
148
+ logits_seq = logits_map[list(ys), list(xs)] # n x 96
149
+ probs_seq = softmax(logits_seq)
150
+ dst_str, keep_idx_list = ctc_greedy_decoder(
151
+ probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs
152
+ )
153
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
154
+ return dst_str, keep_gather_list
155
+
156
+
157
+ def ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True):
158
+ """
159
+ CTC decoder using multiple processes.
160
+ """
161
+ decoder_results = []
162
+ for gather_info in gather_info_list:
163
+ res = instance_ctc_greedy_decoder(
164
+ gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs
165
+ )
166
+ decoder_results.append(res)
167
+ return decoder_results
168
+
169
+
170
+ def sort_with_direction(pos_list, f_direction):
171
+ """
172
+ f_direction: h x w x 2
173
+ pos_list: [[y, x], [y, x], [y, x] ...]
174
+ """
175
+
176
+ def sort_part_with_direction(pos_list, point_direction):
177
+ pos_list = np.array(pos_list).reshape(-1, 2)
178
+ point_direction = np.array(point_direction).reshape(-1, 2)
179
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
180
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
181
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
182
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
183
+ return sorted_list, sorted_direction
184
+
185
+ pos_list = np.array(pos_list).reshape(-1, 2)
186
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
187
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
188
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
189
+
190
+ point_num = len(sorted_point)
191
+ if point_num >= 16:
192
+ middle_num = point_num // 2
193
+ first_part_point = sorted_point[:middle_num]
194
+ first_point_direction = sorted_direction[:middle_num]
195
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
196
+ first_part_point, first_point_direction
197
+ )
198
+
199
+ last_part_point = sorted_point[middle_num:]
200
+ last_point_direction = sorted_direction[middle_num:]
201
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
202
+ last_part_point, last_point_direction
203
+ )
204
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
205
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
206
+
207
+ return sorted_point, np.array(sorted_direction)
208
+
209
+
210
+ def add_id(pos_list, image_id=0):
211
+ """
212
+ Add id for gather feature, for inference.
213
+ """
214
+ new_list = []
215
+ for item in pos_list:
216
+ new_list.append((image_id, item[0], item[1]))
217
+ return new_list
218
+
219
+
220
+ def sort_and_expand_with_direction(pos_list, f_direction):
221
+ """
222
+ f_direction: h x w x 2
223
+ pos_list: [[y, x], [y, x], [y, x] ...]
224
+ """
225
+ h, w, _ = f_direction.shape
226
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
227
+
228
+ # expand along
229
+ point_num = len(sorted_list)
230
+ sub_direction_len = max(point_num // 3, 2)
231
+ left_direction = point_direction[:sub_direction_len, :]
232
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
233
+
234
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
235
+ left_average_len = np.linalg.norm(left_average_direction)
236
+ left_start = np.array(sorted_list[0])
237
+ left_step = left_average_direction / (left_average_len + 1e-6)
238
+
239
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
240
+ right_average_len = np.linalg.norm(right_average_direction)
241
+ right_step = right_average_direction / (right_average_len + 1e-6)
242
+ right_start = np.array(sorted_list[-1])
243
+
244
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
245
+ left_list = []
246
+ right_list = []
247
+ for i in range(append_num):
248
+ ly, lx = (
249
+ np.round(left_start + left_step * (i + 1))
250
+ .flatten()
251
+ .astype("int32")
252
+ .tolist()
253
+ )
254
+ if ly < h and lx < w and (ly, lx) not in left_list:
255
+ left_list.append((ly, lx))
256
+ ry, rx = (
257
+ np.round(right_start + right_step * (i + 1))
258
+ .flatten()
259
+ .astype("int32")
260
+ .tolist()
261
+ )
262
+ if ry < h and rx < w and (ry, rx) not in right_list:
263
+ right_list.append((ry, rx))
264
+
265
+ all_list = left_list[::-1] + sorted_list + right_list
266
+ return all_list
267
+
268
+
269
+ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
270
+ """
271
+ f_direction: h x w x 2
272
+ pos_list: [[y, x], [y, x], [y, x] ...]
273
+ binary_tcl_map: h x w
274
+ """
275
+ h, w, _ = f_direction.shape
276
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
277
+
278
+ # expand along
279
+ point_num = len(sorted_list)
280
+ sub_direction_len = max(point_num // 3, 2)
281
+ left_direction = point_direction[:sub_direction_len, :]
282
+ right_dirction = point_direction[point_num - sub_direction_len :, :]
283
+
284
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
285
+ left_average_len = np.linalg.norm(left_average_direction)
286
+ left_start = np.array(sorted_list[0])
287
+ left_step = left_average_direction / (left_average_len + 1e-6)
288
+
289
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
290
+ right_average_len = np.linalg.norm(right_average_direction)
291
+ right_step = right_average_direction / (right_average_len + 1e-6)
292
+ right_start = np.array(sorted_list[-1])
293
+
294
+ append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
295
+ max_append_num = 2 * append_num
296
+
297
+ left_list = []
298
+ right_list = []
299
+ for i in range(max_append_num):
300
+ ly, lx = (
301
+ np.round(left_start + left_step * (i + 1))
302
+ .flatten()
303
+ .astype("int32")
304
+ .tolist()
305
+ )
306
+ if ly < h and lx < w and (ly, lx) not in left_list:
307
+ if binary_tcl_map[ly, lx] > 0.5:
308
+ left_list.append((ly, lx))
309
+ else:
310
+ break
311
+
312
+ for i in range(max_append_num):
313
+ ry, rx = (
314
+ np.round(right_start + right_step * (i + 1))
315
+ .flatten()
316
+ .astype("int32")
317
+ .tolist()
318
+ )
319
+ if ry < h and rx < w and (ry, rx) not in right_list:
320
+ if binary_tcl_map[ry, rx] > 0.5:
321
+ right_list.append((ry, rx))
322
+ else:
323
+ break
324
+
325
+ all_list = left_list[::-1] + sorted_list + right_list
326
+ return all_list
327
+
328
+
329
+ def generate_pivot_list_curved(
330
+ p_score,
331
+ p_char_maps,
332
+ f_direction,
333
+ score_thresh=0.5,
334
+ is_expand=True,
335
+ is_backbone=False,
336
+ image_id=0,
337
+ ):
338
+ """
339
+ return center point and end point of TCL instance; filter with the char maps;
340
+ """
341
+ p_score = p_score[0]
342
+ f_direction = f_direction.transpose(1, 2, 0)
343
+ p_tcl_map = (p_score > score_thresh) * 1.0
344
+ skeleton_map = thin(p_tcl_map)
345
+ instance_count, instance_label_map = cv2.connectedComponents(
346
+ skeleton_map.astype(np.uint8), connectivity=8
347
+ )
348
+
349
+ # get TCL Instance
350
+ all_pos_yxs = []
351
+ center_pos_yxs = []
352
+ end_points_yxs = []
353
+ instance_center_pos_yxs = []
354
+ pred_strs = []
355
+ if instance_count > 0:
356
+ for instance_id in range(1, instance_count):
357
+ pos_list = []
358
+ ys, xs = np.where(instance_label_map == instance_id)
359
+ pos_list = list(zip(ys, xs))
360
+
361
+ ### FIX-ME, eliminate outlier
362
+ if len(pos_list) < 3:
363
+ continue
364
+
365
+ if is_expand:
366
+ pos_list_sorted = sort_and_expand_with_direction_v2(
367
+ pos_list, f_direction, p_tcl_map
368
+ )
369
+ else:
370
+ pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
371
+ all_pos_yxs.append(pos_list_sorted)
372
+
373
+ # use decoder to filter backgroud points.
374
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
375
+ decode_res = ctc_decoder_for_image(
376
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
377
+ )
378
+ for decoded_str, keep_yxs_list in decode_res:
379
+ if is_backbone:
380
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
381
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
382
+ pred_strs.append(decoded_str)
383
+ else:
384
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
385
+ center_pos_yxs.extend(keep_yxs_list)
386
+
387
+ if is_backbone:
388
+ return pred_strs, instance_center_pos_yxs
389
+ else:
390
+ return center_pos_yxs, end_points_yxs
391
+
392
+
393
+ def generate_pivot_list_horizontal(
394
+ p_score, p_char_maps, f_direction, score_thresh=0.5, is_backbone=False, image_id=0
395
+ ):
396
+ """
397
+ return center point and end point of TCL instance; filter with the char maps;
398
+ """
399
+ p_score = p_score[0]
400
+ f_direction = f_direction.transpose(1, 2, 0)
401
+ p_tcl_map_bi = (p_score > score_thresh) * 1.0
402
+ instance_count, instance_label_map = cv2.connectedComponents(
403
+ p_tcl_map_bi.astype(np.uint8), connectivity=8
404
+ )
405
+
406
+ # get TCL Instance
407
+ all_pos_yxs = []
408
+ center_pos_yxs = []
409
+ end_points_yxs = []
410
+ instance_center_pos_yxs = []
411
+
412
+ if instance_count > 0:
413
+ for instance_id in range(1, instance_count):
414
+ pos_list = []
415
+ ys, xs = np.where(instance_label_map == instance_id)
416
+ pos_list = list(zip(ys, xs))
417
+
418
+ ### FIX-ME, eliminate outlier
419
+ if len(pos_list) < 5:
420
+ continue
421
+
422
+ # add rule here
423
+ main_direction = extract_main_direction(pos_list, f_direction) # y x
424
+ reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
425
+ is_h_angle = abs(np.sum(main_direction * reference_directin)) < math.cos(
426
+ math.pi / 180 * 70
427
+ )
428
+
429
+ point_yxs = np.array(pos_list)
430
+ max_y, max_x = np.max(point_yxs, axis=0)
431
+ min_y, min_x = np.min(point_yxs, axis=0)
432
+ is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
433
+
434
+ pos_list_final = []
435
+ if is_h_len:
436
+ xs = np.unique(xs)
437
+ for x in xs:
438
+ ys = instance_label_map[:, x].copy().reshape((-1,))
439
+ y = int(np.where(ys == instance_id)[0].mean())
440
+ pos_list_final.append((y, x))
441
+ else:
442
+ ys = np.unique(ys)
443
+ for y in ys:
444
+ xs = instance_label_map[y, :].copy().reshape((-1,))
445
+ x = int(np.where(xs == instance_id)[0].mean())
446
+ pos_list_final.append((y, x))
447
+
448
+ pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction)
449
+ all_pos_yxs.append(pos_list_sorted)
450
+
451
+ # use decoder to filter backgroud points.
452
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
453
+ decode_res = ctc_decoder_for_image(
454
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
455
+ )
456
+ for decoded_str, keep_yxs_list in decode_res:
457
+ if is_backbone:
458
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
459
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
460
+ else:
461
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
462
+ center_pos_yxs.extend(keep_yxs_list)
463
+
464
+ if is_backbone:
465
+ return instance_center_pos_yxs
466
+ else:
467
+ return center_pos_yxs, end_points_yxs
468
+
469
+
470
+ def generate_pivot_list_slow(
471
+ p_score,
472
+ p_char_maps,
473
+ f_direction,
474
+ score_thresh=0.5,
475
+ is_backbone=False,
476
+ is_curved=True,
477
+ image_id=0,
478
+ ):
479
+ """
480
+ Warp all the function together.
481
+ """
482
+ if is_curved:
483
+ return generate_pivot_list_curved(
484
+ p_score,
485
+ p_char_maps,
486
+ f_direction,
487
+ score_thresh=score_thresh,
488
+ is_expand=True,
489
+ is_backbone=is_backbone,
490
+ image_id=image_id,
491
+ )
492
+ else:
493
+ return generate_pivot_list_horizontal(
494
+ p_score,
495
+ p_char_maps,
496
+ f_direction,
497
+ score_thresh=score_thresh,
498
+ is_backbone=is_backbone,
499
+ image_id=image_id,
500
+ )
501
+
502
+
503
+ # for refine module
504
+ def extract_main_direction(pos_list, f_direction):
505
+ """
506
+ f_direction: h x w x 2
507
+ pos_list: [[y, x], [y, x], [y, x] ...]
508
+ """
509
+ pos_list = np.array(pos_list)
510
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
511
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
512
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
513
+ average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
514
+ return average_direction
515
+
516
+
517
+ def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
518
+ """
519
+ f_direction: h x w x 2
520
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
521
+ """
522
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
523
+ pos_list = pos_list_full[:, 1:]
524
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
525
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
526
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
527
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
528
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
529
+ return sorted_list
530
+
531
+
532
+ def sort_by_direction_with_image_id(pos_list, f_direction):
533
+ """
534
+ f_direction: h x w x 2
535
+ pos_list: [[y, x], [y, x], [y, x] ...]
536
+ """
537
+
538
+ def sort_part_with_direction(pos_list_full, point_direction):
539
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
540
+ pos_list = pos_list_full[:, 1:]
541
+ point_direction = np.array(point_direction).reshape(-1, 2)
542
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
543
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
544
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
545
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
546
+ return sorted_list, sorted_direction
547
+
548
+ pos_list = np.array(pos_list).reshape(-1, 3)
549
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
550
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
551
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
552
+
553
+ point_num = len(sorted_point)
554
+ if point_num >= 16:
555
+ middle_num = point_num // 2
556
+ first_part_point = sorted_point[:middle_num]
557
+ first_point_direction = sorted_direction[:middle_num]
558
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
559
+ first_part_point, first_point_direction
560
+ )
561
+
562
+ last_part_point = sorted_point[middle_num:]
563
+ last_point_direction = sorted_direction[middle_num:]
564
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
565
+ last_part_point, last_point_direction
566
+ )
567
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
568
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
569
+
570
+ return sorted_point
571
+
572
+
573
+ def generate_pivot_list_tt_inference(
574
+ p_score,
575
+ p_char_maps,
576
+ f_direction,
577
+ score_thresh=0.5,
578
+ is_backbone=False,
579
+ is_curved=True,
580
+ image_id=0,
581
+ ):
582
+ """
583
+ return center point and end point of TCL instance; filter with the char maps;
584
+ """
585
+ p_score = p_score[0]
586
+ f_direction = f_direction.transpose(1, 2, 0)
587
+ p_tcl_map = (p_score > score_thresh) * 1.0
588
+ skeleton_map = thin(p_tcl_map)
589
+ instance_count, instance_label_map = cv2.connectedComponents(
590
+ skeleton_map.astype(np.uint8), connectivity=8
591
+ )
592
+
593
+ # get TCL Instance
594
+ all_pos_yxs = []
595
+ if instance_count > 0:
596
+ for instance_id in range(1, instance_count):
597
+ pos_list = []
598
+ ys, xs = np.where(instance_label_map == instance_id)
599
+ pos_list = list(zip(ys, xs))
600
+ ### FIX-ME, eliminate outlier
601
+ if len(pos_list) < 3:
602
+ continue
603
+ pos_list_sorted = sort_and_expand_with_direction_v2(
604
+ pos_list, f_direction, p_tcl_map
605
+ )
606
+ pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
607
+ all_pos_yxs.append(pos_list_sorted_with_id)
608
+ return all_pos_yxs
ocr/postprocess/fce_postprocess.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import paddle
4
+ from numpy.fft import ifft
5
+
6
+ from .poly_nms import *
7
+
8
+
9
+ def fill_hole(input_mask):
10
+ h, w = input_mask.shape
11
+ canvas = np.zeros((h + 2, w + 2), np.uint8)
12
+ canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()
13
+
14
+ mask = np.zeros((h + 4, w + 4), np.uint8)
15
+
16
+ cv2.floodFill(canvas, mask, (0, 0), 1)
17
+ canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool)
18
+
19
+ return ~canvas | input_mask
20
+
21
+
22
+ def fourier2poly(fourier_coeff, num_reconstr_points=50):
23
+ """Inverse Fourier transform
24
+ Args:
25
+ fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
26
+ with n and k being candidates number and Fourier degree
27
+ respectively.
28
+ num_reconstr_points (int): Number of reconstructed polygon points.
29
+ Returns:
30
+ Polygons (ndarray): The reconstructed polygons shaped (n, n')
31
+ """
32
+
33
+ a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype="complex")
34
+ k = (len(fourier_coeff[0]) - 1) // 2
35
+
36
+ a[:, 0 : k + 1] = fourier_coeff[:, k:]
37
+ a[:, -k:] = fourier_coeff[:, :k]
38
+
39
+ poly_complex = ifft(a) * num_reconstr_points
40
+ polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
41
+ polygon[:, :, 0] = poly_complex.real
42
+ polygon[:, :, 1] = poly_complex.imag
43
+ return polygon.astype("int32").reshape((len(fourier_coeff), -1))
44
+
45
+
46
+ class FCEPostProcess(object):
47
+ """
48
+ The post process for FCENet.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ scales,
54
+ fourier_degree=5,
55
+ num_reconstr_points=50,
56
+ decoding_type="fcenet",
57
+ score_thr=0.3,
58
+ nms_thr=0.1,
59
+ alpha=1.0,
60
+ beta=1.0,
61
+ box_type="poly",
62
+ **kwargs
63
+ ):
64
+
65
+ self.scales = scales
66
+ self.fourier_degree = fourier_degree
67
+ self.num_reconstr_points = num_reconstr_points
68
+ self.decoding_type = decoding_type
69
+ self.score_thr = score_thr
70
+ self.nms_thr = nms_thr
71
+ self.alpha = alpha
72
+ self.beta = beta
73
+ self.box_type = box_type
74
+
75
+ def __call__(self, preds, shape_list):
76
+ score_maps = []
77
+ for key, value in preds.items():
78
+ if isinstance(value, paddle.Tensor):
79
+ value = value.numpy()
80
+ cls_res = value[:, :4, :, :]
81
+ reg_res = value[:, 4:, :, :]
82
+ score_maps.append([cls_res, reg_res])
83
+
84
+ return self.get_boundary(score_maps, shape_list)
85
+
86
+ def resize_boundary(self, boundaries, scale_factor):
87
+ """Rescale boundaries via scale_factor.
88
+
89
+ Args:
90
+ boundaries (list[list[float]]): The boundary list. Each boundary
91
+ with size 2k+1 with k>=4.
92
+ scale_factor(ndarray): The scale factor of size (4,).
93
+
94
+ Returns:
95
+ boundaries (list[list[float]]): The scaled boundaries.
96
+ """
97
+ boxes = []
98
+ scores = []
99
+ for b in boundaries:
100
+ sz = len(b)
101
+ valid_boundary(b, True)
102
+ scores.append(b[-1])
103
+ b = (
104
+ (
105
+ np.array(b[: sz - 1])
106
+ * (np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1))
107
+ )
108
+ .flatten()
109
+ .tolist()
110
+ )
111
+ boxes.append(np.array(b).reshape([-1, 2]))
112
+
113
+ return np.array(boxes, dtype=np.float32), scores
114
+
115
+ def get_boundary(self, score_maps, shape_list):
116
+ assert len(score_maps) == len(self.scales)
117
+ boundaries = []
118
+ for idx, score_map in enumerate(score_maps):
119
+ scale = self.scales[idx]
120
+ boundaries = boundaries + self._get_boundary_single(score_map, scale)
121
+
122
+ # nms
123
+ boundaries = poly_nms(boundaries, self.nms_thr)
124
+ boundaries, scores = self.resize_boundary(
125
+ boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]
126
+ )
127
+
128
+ boxes_batch = [dict(points=boundaries, scores=scores)]
129
+ return boxes_batch
130
+
131
+ def _get_boundary_single(self, score_map, scale):
132
+ assert len(score_map) == 2
133
+ assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
134
+
135
+ return self.fcenet_decode(
136
+ preds=score_map,
137
+ fourier_degree=self.fourier_degree,
138
+ num_reconstr_points=self.num_reconstr_points,
139
+ scale=scale,
140
+ alpha=self.alpha,
141
+ beta=self.beta,
142
+ box_type=self.box_type,
143
+ score_thr=self.score_thr,
144
+ nms_thr=self.nms_thr,
145
+ )
146
+
147
+ def fcenet_decode(
148
+ self,
149
+ preds,
150
+ fourier_degree,
151
+ num_reconstr_points,
152
+ scale,
153
+ alpha=1.0,
154
+ beta=2.0,
155
+ box_type="poly",
156
+ score_thr=0.3,
157
+ nms_thr=0.1,
158
+ ):
159
+ """Decoding predictions of FCENet to instances.
160
+
161
+ Args:
162
+ preds (list(Tensor)): The head output tensors.
163
+ fourier_degree (int): The maximum Fourier transform degree k.
164
+ num_reconstr_points (int): The points number of the polygon
165
+ reconstructed from predicted Fourier coefficients.
166
+ scale (int): The down-sample scale of the prediction.
167
+ alpha (float) : The parameter to calculate final scores. Score_{final}
168
+ = (Score_{text region} ^ alpha)
169
+ * (Score_{text center region}^ beta)
170
+ beta (float) : The parameter to calculate final score.
171
+ box_type (str): Boundary encoding type 'poly' or 'quad'.
172
+ score_thr (float) : The threshold used to filter out the final
173
+ candidates.
174
+ nms_thr (float) : The threshold of nms.
175
+
176
+ Returns:
177
+ boundaries (list[list[float]]): The instance boundary and confidence
178
+ list.
179
+ """
180
+ assert isinstance(preds, list)
181
+ assert len(preds) == 2
182
+ assert box_type in ["poly", "quad"]
183
+
184
+ cls_pred = preds[0][0]
185
+ tr_pred = cls_pred[0:2]
186
+ tcl_pred = cls_pred[2:]
187
+
188
+ reg_pred = preds[1][0].transpose([1, 2, 0])
189
+ x_pred = reg_pred[:, :, : 2 * fourier_degree + 1]
190
+ y_pred = reg_pred[:, :, 2 * fourier_degree + 1 :]
191
+
192
+ score_pred = (tr_pred[1] ** alpha) * (tcl_pred[1] ** beta)
193
+ tr_pred_mask = (score_pred) > score_thr
194
+ tr_mask = fill_hole(tr_pred_mask)
195
+
196
+ tr_contours, _ = cv2.findContours(
197
+ tr_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
198
+ ) # opencv4
199
+
200
+ mask = np.zeros_like(tr_mask)
201
+ boundaries = []
202
+ for cont in tr_contours:
203
+ deal_map = mask.copy().astype(np.int8)
204
+ cv2.drawContours(deal_map, [cont], -1, 1, -1)
205
+
206
+ score_map = score_pred * deal_map
207
+ score_mask = score_map > 0
208
+ xy_text = np.argwhere(score_mask)
209
+ dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
210
+
211
+ x, y = x_pred[score_mask], y_pred[score_mask]
212
+ c = x + y * 1j
213
+ c[:, fourier_degree] = c[:, fourier_degree] + dxy
214
+ c *= scale
215
+
216
+ polygons = fourier2poly(c, num_reconstr_points)
217
+ score = score_map[score_mask].reshape(-1, 1)
218
+ polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)
219
+
220
+ boundaries = boundaries + polygons
221
+
222
+ boundaries = poly_nms(boundaries, nms_thr)
223
+
224
+ if box_type == "quad":
225
+ new_boundaries = []
226
+ for boundary in boundaries:
227
+ poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
228
+ score = boundary[-1]
229
+ points = cv2.boxPoints(cv2.minAreaRect(poly))
230
+ points = np.int0(points)
231
+ new_boundaries.append(points.reshape(-1).tolist() + [score])
232
+ boundaries = new_boundaries
233
+
234
+ return boundaries
ocr/postprocess/locality_aware_nms.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Locality aware nms.
3
+ This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
4
+ """
5
+
6
+ import numpy as np
7
+ from shapely.geometry import Polygon
8
+
9
+
10
+ def intersection(g, p):
11
+ """
12
+ Intersection.
13
+ """
14
+ g = Polygon(g[:8].reshape((4, 2)))
15
+ p = Polygon(p[:8].reshape((4, 2)))
16
+ g = g.buffer(0)
17
+ p = p.buffer(0)
18
+ if not g.is_valid or not p.is_valid:
19
+ return 0
20
+ inter = Polygon(g).intersection(Polygon(p)).area
21
+ union = g.area + p.area - inter
22
+ if union == 0:
23
+ return 0
24
+ else:
25
+ return inter / union
26
+
27
+
28
+ def intersection_iog(g, p):
29
+ """
30
+ Intersection_iog.
31
+ """
32
+ g = Polygon(g[:8].reshape((4, 2)))
33
+ p = Polygon(p[:8].reshape((4, 2)))
34
+ if not g.is_valid or not p.is_valid:
35
+ return 0
36
+ inter = Polygon(g).intersection(Polygon(p)).area
37
+ # union = g.area + p.area - inter
38
+ union = p.area
39
+ if union == 0:
40
+ print("p_area is very small")
41
+ return 0
42
+ else:
43
+ return inter / union
44
+
45
+
46
+ def weighted_merge(g, p):
47
+ """
48
+ Weighted merge.
49
+ """
50
+ g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
51
+ g[8] = g[8] + p[8]
52
+ return g
53
+
54
+
55
+ def standard_nms(S, thres):
56
+ """
57
+ Standard nms.
58
+ """
59
+ order = np.argsort(S[:, 8])[::-1]
60
+ keep = []
61
+ while order.size > 0:
62
+ i = order[0]
63
+ keep.append(i)
64
+ ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
65
+
66
+ inds = np.where(ovr <= thres)[0]
67
+ order = order[inds + 1]
68
+
69
+ return S[keep]
70
+
71
+
72
+ def standard_nms_inds(S, thres):
73
+ """
74
+ Standard nms, retun inds.
75
+ """
76
+ order = np.argsort(S[:, 8])[::-1]
77
+ keep = []
78
+ while order.size > 0:
79
+ i = order[0]
80
+ keep.append(i)
81
+ ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
82
+
83
+ inds = np.where(ovr <= thres)[0]
84
+ order = order[inds + 1]
85
+
86
+ return keep
87
+
88
+
89
+ def nms(S, thres):
90
+ """
91
+ nms.
92
+ """
93
+ order = np.argsort(S[:, 8])[::-1]
94
+ keep = []
95
+ while order.size > 0:
96
+ i = order[0]
97
+ keep.append(i)
98
+ ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
99
+
100
+ inds = np.where(ovr <= thres)[0]
101
+ order = order[inds + 1]
102
+
103
+ return keep
104
+
105
+
106
+ def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
107
+ """
108
+ soft_nms
109
+ :para boxes_in, N x 9 (coords + score)
110
+ :para threshould, eliminate cases min score(0.001)
111
+ :para Nt_thres, iou_threshi
112
+ :para sigma, gaussian weght
113
+ :method, linear or gaussian
114
+ """
115
+ boxes = boxes_in.copy()
116
+ N = boxes.shape[0]
117
+ if N is None or N < 1:
118
+ return np.array([])
119
+ pos, maxpos = 0, 0
120
+ weight = 0.0
121
+ inds = np.arange(N)
122
+ tbox, sbox = boxes[0].copy(), boxes[0].copy()
123
+ for i in range(N):
124
+ maxscore = boxes[i, 8]
125
+ maxpos = i
126
+ tbox = boxes[i].copy()
127
+ ti = inds[i]
128
+ pos = i + 1
129
+ # get max box
130
+ while pos < N:
131
+ if maxscore < boxes[pos, 8]:
132
+ maxscore = boxes[pos, 8]
133
+ maxpos = pos
134
+ pos = pos + 1
135
+ # add max box as a detection
136
+ boxes[i, :] = boxes[maxpos, :]
137
+ inds[i] = inds[maxpos]
138
+ # swap
139
+ boxes[maxpos, :] = tbox
140
+ inds[maxpos] = ti
141
+ tbox = boxes[i].copy()
142
+ pos = i + 1
143
+ # NMS iteration
144
+ while pos < N:
145
+ sbox = boxes[pos].copy()
146
+ ts_iou_val = intersection(tbox, sbox)
147
+ if ts_iou_val > 0:
148
+ if method == 1:
149
+ if ts_iou_val > Nt_thres:
150
+ weight = 1 - ts_iou_val
151
+ else:
152
+ weight = 1
153
+ elif method == 2:
154
+ weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
155
+ else:
156
+ if ts_iou_val > Nt_thres:
157
+ weight = 0
158
+ else:
159
+ weight = 1
160
+ boxes[pos, 8] = weight * boxes[pos, 8]
161
+ # if box score falls below thresold, discard the box by
162
+ # swaping last box update N
163
+ if boxes[pos, 8] < threshold:
164
+ boxes[pos, :] = boxes[N - 1, :]
165
+ inds[pos] = inds[N - 1]
166
+ N = N - 1
167
+ pos = pos - 1
168
+ pos = pos + 1
169
+
170
+ return boxes[:N]
171
+
172
+
173
+ def nms_locality(polys, thres=0.3):
174
+ """
175
+ locality aware nms of EAST
176
+ :param polys: a N*9 numpy array. first 8 coordinates, then prob
177
+ :return: boxes after nms
178
+ """
179
+ S = []
180
+ p = None
181
+ for g in polys:
182
+ if p is not None and intersection(g, p) > thres:
183
+ p = weighted_merge(g, p)
184
+ else:
185
+ if p is not None:
186
+ S.append(p)
187
+ p = g
188
+ if p is not None:
189
+ S.append(p)
190
+
191
+ if len(S) == 0:
192
+ return np.array([])
193
+ return standard_nms(np.array(S), thres)
194
+
195
+
196
+ if __name__ == "__main__":
197
+ # 343,350,448,135,474,143,369,359
198
+ print(Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])).area)
ocr/postprocess/pg_postprocess.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import os
4
+ import sys
5
+
6
+ import paddle
7
+
8
+ from .extract_textpoint_fast import *
9
+ from .extract_textpoint_slow import *
10
+
11
+ __dir__ = os.path.dirname(__file__)
12
+ sys.path.append(__dir__)
13
+ sys.path.append(os.path.join(__dir__, ".."))
14
+
15
+
16
+ class PGNet_PostProcess(object):
17
+ # two different post-process
18
+ def __init__(
19
+ self, character_dict_path, valid_set, score_thresh, outs_dict, shape_list
20
+ ):
21
+ self.Lexicon_Table = get_dict(character_dict_path)
22
+ self.valid_set = valid_set
23
+ self.score_thresh = score_thresh
24
+ self.outs_dict = outs_dict
25
+ self.shape_list = shape_list
26
+
27
+ def pg_postprocess_fast(self):
28
+ p_score = self.outs_dict["f_score"]
29
+ p_border = self.outs_dict["f_border"]
30
+ p_char = self.outs_dict["f_char"]
31
+ p_direction = self.outs_dict["f_direction"]
32
+ if isinstance(p_score, paddle.Tensor):
33
+ p_score = p_score[0].numpy()
34
+ p_border = p_border[0].numpy()
35
+ p_direction = p_direction[0].numpy()
36
+ p_char = p_char[0].numpy()
37
+ else:
38
+ p_score = p_score[0]
39
+ p_border = p_border[0]
40
+ p_direction = p_direction[0]
41
+ p_char = p_char[0]
42
+
43
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
44
+ instance_yxs_list, seq_strs = generate_pivot_list_fast(
45
+ p_score,
46
+ p_char,
47
+ p_direction,
48
+ self.Lexicon_Table,
49
+ score_thresh=self.score_thresh,
50
+ )
51
+ poly_list, keep_str_list = restore_poly(
52
+ instance_yxs_list,
53
+ seq_strs,
54
+ p_border,
55
+ ratio_w,
56
+ ratio_h,
57
+ src_w,
58
+ src_h,
59
+ self.valid_set,
60
+ )
61
+ data = {
62
+ "points": poly_list,
63
+ "texts": keep_str_list,
64
+ }
65
+ return data
66
+
67
+ def pg_postprocess_slow(self):
68
+ p_score = self.outs_dict["f_score"]
69
+ p_border = self.outs_dict["f_border"]
70
+ p_char = self.outs_dict["f_char"]
71
+ p_direction = self.outs_dict["f_direction"]
72
+ if isinstance(p_score, paddle.Tensor):
73
+ p_score = p_score[0].numpy()
74
+ p_border = p_border[0].numpy()
75
+ p_direction = p_direction[0].numpy()
76
+ p_char = p_char[0].numpy()
77
+ else:
78
+ p_score = p_score[0]
79
+ p_border = p_border[0]
80
+ p_direction = p_direction[0]
81
+ p_char = p_char[0]
82
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
83
+ is_curved = self.valid_set == "totaltext"
84
+ char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
85
+ p_score,
86
+ p_char,
87
+ p_direction,
88
+ score_thresh=self.score_thresh,
89
+ is_backbone=True,
90
+ is_curved=is_curved,
91
+ )
92
+ seq_strs = []
93
+ for char_idx_set in char_seq_idx_set:
94
+ pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
95
+ seq_strs.append(pr_str)
96
+ poly_list = []
97
+ keep_str_list = []
98
+ all_point_list = []
99
+ all_point_pair_list = []
100
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
101
+ if len(yx_center_line) == 1:
102
+ yx_center_line.append(yx_center_line[-1])
103
+
104
+ offset_expand = 1.0
105
+ if self.valid_set == "totaltext":
106
+ offset_expand = 1.2
107
+
108
+ point_pair_list = []
109
+ for batch_id, y, x in yx_center_line:
110
+ offset = p_border[:, y, x].reshape(2, 2)
111
+ if offset_expand != 1.0:
112
+ offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
113
+ expand_length = np.clip(
114
+ offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
115
+ )
116
+ offset_detal = offset / offset_length * expand_length
117
+ offset = offset + offset_detal
118
+ ori_yx = np.array([y, x], dtype=np.float32)
119
+ point_pair = (
120
+ (ori_yx + offset)[:, ::-1]
121
+ * 4.0
122
+ / np.array([ratio_w, ratio_h]).reshape(-1, 2)
123
+ )
124
+ point_pair_list.append(point_pair)
125
+
126
+ all_point_list.append(
127
+ [int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h))]
128
+ )
129
+ all_point_pair_list.append(point_pair.round().astype(np.int32).tolist())
130
+
131
+ detected_poly, pair_length_info = point_pair2poly(point_pair_list)
132
+ detected_poly = expand_poly_along_width(
133
+ detected_poly, shrink_ratio_of_width=0.2
134
+ )
135
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
136
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
137
+
138
+ if len(keep_str) < 2:
139
+ continue
140
+
141
+ keep_str_list.append(keep_str)
142
+ detected_poly = np.round(detected_poly).astype("int32")
143
+ if self.valid_set == "partvgg":
144
+ middle_point = len(detected_poly) // 2
145
+ detected_poly = detected_poly[
146
+ [0, middle_point - 1, middle_point, -1], :
147
+ ]
148
+ poly_list.append(detected_poly)
149
+ elif self.valid_set == "totaltext":
150
+ poly_list.append(detected_poly)
151
+ else:
152
+ print("--> Not supported format.")
153
+ exit(-1)
154
+ data = {
155
+ "points": poly_list,
156
+ "texts": keep_str_list,
157
+ }
158
+ return data
159
+
160
+
161
+ class PGPostProcess(object):
162
+ """
163
+ The post process for PGNet.
164
+ """
165
+
166
+ def __init__(self, character_dict_path, valid_set, score_thresh, mode, **kwargs):
167
+ self.character_dict_path = character_dict_path
168
+ self.valid_set = valid_set
169
+ self.score_thresh = score_thresh
170
+ self.mode = mode
171
+
172
+ # c++ la-nms is faster, but only support python 3.5
173
+ self.is_python35 = False
174
+ if sys.version_info.major == 3 and sys.version_info.minor == 5:
175
+ self.is_python35 = True
176
+
177
+ def __call__(self, outs_dict, shape_list):
178
+ post = PGNet_PostProcess(
179
+ self.character_dict_path,
180
+ self.valid_set,
181
+ self.score_thresh,
182
+ outs_dict,
183
+ shape_list,
184
+ )
185
+ if self.mode == "fast":
186
+ data = post.pg_postprocess_fast()
187
+ else:
188
+ data = post.pg_postprocess_slow()
189
+ return data
ocr/postprocess/poly_nms.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from shapely.geometry import Polygon
3
+
4
+
5
+ def points2polygon(points):
6
+ """Convert k points to 1 polygon.
7
+
8
+ Args:
9
+ points (ndarray or list): A ndarray or a list of shape (2k)
10
+ that indicates k points.
11
+
12
+ Returns:
13
+ polygon (Polygon): A polygon object.
14
+ """
15
+ if isinstance(points, list):
16
+ points = np.array(points)
17
+
18
+ assert isinstance(points, np.ndarray)
19
+ assert (points.size % 2 == 0) and (points.size >= 8)
20
+
21
+ point_mat = points.reshape([-1, 2])
22
+ return Polygon(point_mat)
23
+
24
+
25
+ def poly_intersection(poly_det, poly_gt, buffer=0.0001):
26
+ """Calculate the intersection area between two polygon.
27
+
28
+ Args:
29
+ poly_det (Polygon): A polygon predicted by detector.
30
+ poly_gt (Polygon): A gt polygon.
31
+
32
+ Returns:
33
+ intersection_area (float): The intersection area between two polygons.
34
+ """
35
+ assert isinstance(poly_det, Polygon)
36
+ assert isinstance(poly_gt, Polygon)
37
+
38
+ if buffer == 0:
39
+ poly_inter = poly_det & poly_gt
40
+ else:
41
+ poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
42
+ return poly_inter.area, poly_inter
43
+
44
+
45
+ def poly_union(poly_det, poly_gt):
46
+ """Calculate the union area between two polygon.
47
+
48
+ Args:
49
+ poly_det (Polygon): A polygon predicted by detector.
50
+ poly_gt (Polygon): A gt polygon.
51
+
52
+ Returns:
53
+ union_area (float): The union area between two polygons.
54
+ """
55
+ assert isinstance(poly_det, Polygon)
56
+ assert isinstance(poly_gt, Polygon)
57
+
58
+ area_det = poly_det.area
59
+ area_gt = poly_gt.area
60
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
61
+ return area_det + area_gt - area_inters
62
+
63
+
64
+ def valid_boundary(x, with_score=True):
65
+ num = len(x)
66
+ if num < 8:
67
+ return False
68
+ if num % 2 == 0 and (not with_score):
69
+ return True
70
+ if num % 2 == 1 and with_score:
71
+ return True
72
+
73
+ return False
74
+
75
+
76
+ def boundary_iou(src, target):
77
+ """Calculate the IOU between two boundaries.
78
+
79
+ Args:
80
+ src (list): Source boundary.
81
+ target (list): Target boundary.
82
+
83
+ Returns:
84
+ iou (float): The iou between two boundaries.
85
+ """
86
+ assert valid_boundary(src, False)
87
+ assert valid_boundary(target, False)
88
+ src_poly = points2polygon(src)
89
+ target_poly = points2polygon(target)
90
+
91
+ return poly_iou(src_poly, target_poly)
92
+
93
+
94
+ def poly_iou(poly_det, poly_gt):
95
+ """Calculate the IOU between two polygons.
96
+
97
+ Args:
98
+ poly_det (Polygon): A polygon predicted by detector.
99
+ poly_gt (Polygon): A gt polygon.
100
+
101
+ Returns:
102
+ iou (float): The IOU between two polygons.
103
+ """
104
+ assert isinstance(poly_det, Polygon)
105
+ assert isinstance(poly_gt, Polygon)
106
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
107
+ area_union = poly_union(poly_det, poly_gt)
108
+ if area_union == 0:
109
+ return 0.0
110
+ return area_inters / area_union
111
+
112
+
113
+ def poly_nms(polygons, threshold):
114
+ assert isinstance(polygons, list)
115
+
116
+ polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
117
+
118
+ keep_poly = []
119
+ index = [i for i in range(polygons.shape[0])]
120
+
121
+ while len(index) > 0:
122
+ keep_poly.append(polygons[index[-1]].tolist())
123
+ A = polygons[index[-1]][:-1]
124
+ index = np.delete(index, -1)
125
+ iou_list = np.zeros((len(index),))
126
+ for i in range(len(index)):
127
+ B = polygons[index[i]][:-1]
128
+ iou_list[i] = boundary_iou(A, B)
129
+ remove_index = np.where(iou_list > threshold)
130
+ index = np.delete(index, remove_index)
131
+
132
+ return keep_poly
ocr/postprocess/pse_postprocess/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .pse_postprocess import PSEPostProcess
ocr/postprocess/pse_postprocess/pse/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
+ python_path = sys.executable
6
+
7
+ ori_path = os.getcwd()
8
+ os.chdir("ppocr/postprocess/pse_postprocess/pse")
9
+ if (
10
+ subprocess.call("{} setup.py build_ext --inplace".format(python_path), shell=True)
11
+ != 0
12
+ ):
13
+ raise RuntimeError(
14
+ "Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+".format(
15
+ os.path.dirname(os.path.realpath(__file__))
16
+ )
17
+ )
18
+ os.chdir(ori_path)
19
+
20
+ from .pse import pse
ocr/postprocess/pse_postprocess/pse/pse.pyx ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ import numpy as np
4
+
5
+ cimport cython
6
+ cimport libcpp
7
+ cimport libcpp.pair
8
+ cimport libcpp.queue
9
+ cimport numpy as np
10
+ from libcpp.pair cimport *
11
+ from libcpp.queue cimport *
12
+
13
+
14
+ @cython.boundscheck(False)
15
+ @cython.wraparound(False)
16
+ cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
17
+ np.ndarray[np.int32_t, ndim=2] label,
18
+ int kernel_num,
19
+ int label_num,
20
+ float min_area=0):
21
+ cdef np.ndarray[np.int32_t, ndim=2] pred
22
+ pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
23
+
24
+ for label_idx in range(1, label_num):
25
+ if np.sum(label == label_idx) < min_area:
26
+ label[label == label_idx] = 0
27
+
28
+ cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
29
+ queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
30
+ cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
31
+ queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
32
+ cdef np.int16_t* dx = [-1, 1, 0, 0]
33
+ cdef np.int16_t* dy = [0, 0, -1, 1]
34
+ cdef np.int16_t tmpx, tmpy
35
+
36
+ points = np.array(np.where(label > 0)).transpose((1, 0))
37
+ for point_idx in range(points.shape[0]):
38
+ tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
39
+ que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
40
+ pred[tmpx, tmpy] = label[tmpx, tmpy]
41
+
42
+ cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
43
+ cdef int cur_label
44
+ for kernel_idx in range(kernel_num - 1, -1, -1):
45
+ while not que.empty():
46
+ cur = que.front()
47
+ que.pop()
48
+ cur_label = pred[cur.first, cur.second]
49
+
50
+ is_edge = True
51
+ for j in range(4):
52
+ tmpx = cur.first + dx[j]
53
+ tmpy = cur.second + dy[j]
54
+ if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
55
+ continue
56
+ if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
57
+ continue
58
+
59
+ que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
60
+ pred[tmpx, tmpy] = cur_label
61
+ is_edge = False
62
+ if is_edge:
63
+ nxt_que.push(cur)
64
+
65
+ que, nxt_que = nxt_que, que
66
+
67
+ return pred
68
+
69
+ def pse(kernels, min_area):
70
+ kernel_num = kernels.shape[0]
71
+ label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
72
+ return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
ocr/postprocess/pse_postprocess/pse/setup.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.core import Extension, setup
2
+
3
+ import numpy
4
+ from Cython.Build import cythonize
5
+
6
+ setup(
7
+ ext_modules=cythonize(
8
+ Extension(
9
+ "pse",
10
+ sources=["pse.pyx"],
11
+ language="c++",
12
+ include_dirs=[numpy.get_include()],
13
+ library_dirs=[],
14
+ libraries=[],
15
+ extra_compile_args=["-O3"],
16
+ extra_link_args=[],
17
+ )
18
+ )
19
+ )
ocr/postprocess/pse_postprocess/pse_postprocess.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import paddle
6
+ from paddle.nn import functional as F
7
+
8
+ from .pse import pse
9
+
10
+
11
+ class PSEPostProcess(object):
12
+ """
13
+ The post process for PSE.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ thresh=0.5,
19
+ box_thresh=0.85,
20
+ min_area=16,
21
+ box_type="quad",
22
+ scale=4,
23
+ **kwargs
24
+ ):
25
+ assert box_type in ["quad", "poly"], "Only quad and poly is supported"
26
+ self.thresh = thresh
27
+ self.box_thresh = box_thresh
28
+ self.min_area = min_area
29
+ self.box_type = box_type
30
+ self.scale = scale
31
+
32
+ def __call__(self, outs_dict, shape_list):
33
+ pred = outs_dict["maps"]
34
+ if not isinstance(pred, paddle.Tensor):
35
+ pred = paddle.to_tensor(pred)
36
+ pred = F.interpolate(pred, scale_factor=4 // self.scale, mode="bilinear")
37
+
38
+ score = F.sigmoid(pred[:, 0, :, :])
39
+
40
+ kernels = (pred > self.thresh).astype("float32")
41
+ text_mask = kernels[:, 0, :, :]
42
+ kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
43
+
44
+ score = score.numpy()
45
+ kernels = kernels.numpy().astype(np.uint8)
46
+
47
+ boxes_batch = []
48
+ for batch_index in range(pred.shape[0]):
49
+ boxes, scores = self.boxes_from_bitmap(
50
+ score[batch_index], kernels[batch_index], shape_list[batch_index]
51
+ )
52
+
53
+ boxes_batch.append({"points": boxes, "scores": scores})
54
+ return boxes_batch
55
+
56
+ def boxes_from_bitmap(self, score, kernels, shape):
57
+ label = pse(kernels, self.min_area)
58
+ return self.generate_box(score, label, shape)
59
+
60
+ def generate_box(self, score, label, shape):
61
+ src_h, src_w, ratio_h, ratio_w = shape
62
+ label_num = np.max(label) + 1
63
+
64
+ boxes = []
65
+ scores = []
66
+ for i in range(1, label_num):
67
+ ind = label == i
68
+ points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]
69
+
70
+ if points.shape[0] < self.min_area:
71
+ label[ind] = 0
72
+ continue
73
+
74
+ score_i = np.mean(score[ind])
75
+ if score_i < self.box_thresh:
76
+ label[ind] = 0
77
+ continue
78
+
79
+ if self.box_type == "quad":
80
+ rect = cv2.minAreaRect(points)
81
+ bbox = cv2.boxPoints(rect)
82
+ elif self.box_type == "poly":
83
+ box_height = np.max(points[:, 1]) + 10
84
+ box_width = np.max(points[:, 0]) + 10
85
+
86
+ mask = np.zeros((box_height, box_width), np.uint8)
87
+ mask[points[:, 1], points[:, 0]] = 255
88
+
89
+ contours, _ = cv2.findContours(
90
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
91
+ )
92
+ bbox = np.squeeze(contours[0], 1)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w)
97
+ bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h)
98
+ boxes.append(bbox)
99
+ scores.append(score_i)
100
+ return boxes, scores
ocr/postprocess/rec_postprocess.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import numpy as np
4
+ import paddle
5
+ from paddle.nn import functional as F
6
+
7
+
8
+ class BaseRecLabelDecode(object):
9
+ """Convert between text-label and text-index"""
10
+
11
+ def __init__(self, character_dict_path=None, use_space_char=False):
12
+ self.beg_str = "sos"
13
+ self.end_str = "eos"
14
+
15
+ self.character_str = []
16
+ if character_dict_path is None:
17
+ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
18
+ dict_character = list(self.character_str)
19
+ else:
20
+ with open(character_dict_path, "rb") as fin:
21
+ lines = fin.readlines()
22
+ for line in lines:
23
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
24
+ self.character_str.append(line)
25
+ if use_space_char:
26
+ self.character_str.append(" ")
27
+ dict_character = list(self.character_str)
28
+
29
+ dict_character = self.add_special_char(dict_character)
30
+ self.dict = {}
31
+ for i, char in enumerate(dict_character):
32
+ self.dict[char] = i
33
+ self.character = dict_character
34
+
35
+ def add_special_char(self, dict_character):
36
+ return dict_character
37
+
38
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
39
+ """convert text-index into text-label."""
40
+ result_list = []
41
+ ignored_tokens = self.get_ignored_tokens()
42
+ batch_size = len(text_index)
43
+ for batch_idx in range(batch_size):
44
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
45
+ if is_remove_duplicate:
46
+ selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
47
+ for ignored_token in ignored_tokens:
48
+ selection &= text_index[batch_idx] != ignored_token
49
+
50
+ char_list = [
51
+ self.character[text_id] for text_id in text_index[batch_idx][selection]
52
+ ]
53
+ if text_prob is not None:
54
+ conf_list = text_prob[batch_idx][selection]
55
+ else:
56
+ conf_list = [1] * len(selection)
57
+ if len(conf_list) == 0:
58
+ conf_list = [0]
59
+
60
+ text = "".join(char_list)
61
+ result_list.append((text, np.mean(conf_list).tolist()))
62
+ return result_list
63
+
64
+ def get_ignored_tokens(self):
65
+ return [0] # for ctc blank
66
+
67
+
68
+ class CTCLabelDecode(BaseRecLabelDecode):
69
+ """Convert between text-label and text-index"""
70
+
71
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
72
+ super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
73
+
74
+ def __call__(self, preds, label=None, *args, **kwargs):
75
+ if isinstance(preds, tuple) or isinstance(preds, list):
76
+ preds = preds[-1]
77
+ if isinstance(preds, paddle.Tensor):
78
+ preds = preds.numpy()
79
+ preds_idx = preds.argmax(axis=2)
80
+ preds_prob = preds.max(axis=2)
81
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
82
+ if label is None:
83
+ return text
84
+ label = self.decode(label)
85
+ return text, label
86
+
87
+ def add_special_char(self, dict_character):
88
+ dict_character = ["blank"] + dict_character
89
+ return dict_character
90
+
91
+
92
+ class DistillationCTCLabelDecode(CTCLabelDecode):
93
+ """
94
+ Convert
95
+ Convert between text-label and text-index
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ character_dict_path=None,
101
+ use_space_char=False,
102
+ model_name=["student"],
103
+ key=None,
104
+ multi_head=False,
105
+ **kwargs
106
+ ):
107
+ super(DistillationCTCLabelDecode, self).__init__(
108
+ character_dict_path, use_space_char
109
+ )
110
+ if not isinstance(model_name, list):
111
+ model_name = [model_name]
112
+ self.model_name = model_name
113
+
114
+ self.key = key
115
+ self.multi_head = multi_head
116
+
117
+ def __call__(self, preds, label=None, *args, **kwargs):
118
+ output = dict()
119
+ for name in self.model_name:
120
+ pred = preds[name]
121
+ if self.key is not None:
122
+ pred = pred[self.key]
123
+ if self.multi_head and isinstance(pred, dict):
124
+ pred = pred["ctc"]
125
+ output[name] = super().__call__(pred, label=label, *args, **kwargs)
126
+ return output
127
+
128
+
129
+ class NRTRLabelDecode(BaseRecLabelDecode):
130
+ """Convert between text-label and text-index"""
131
+
132
+ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
133
+ super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)
134
+
135
+ def __call__(self, preds, label=None, *args, **kwargs):
136
+
137
+ if len(preds) == 2:
138
+ preds_id = preds[0]
139
+ preds_prob = preds[1]
140
+ if isinstance(preds_id, paddle.Tensor):
141
+ preds_id = preds_id.numpy()
142
+ if isinstance(preds_prob, paddle.Tensor):
143
+ preds_prob = preds_prob.numpy()
144
+ if preds_id[0][0] == 2:
145
+ preds_idx = preds_id[:, 1:]
146
+ preds_prob = preds_prob[:, 1:]
147
+ else:
148
+ preds_idx = preds_id
149
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
150
+ if label is None:
151
+ return text
152
+ label = self.decode(label[:, 1:])
153
+ else:
154
+ if isinstance(preds, paddle.Tensor):
155
+ preds = preds.numpy()
156
+ preds_idx = preds.argmax(axis=2)
157
+ preds_prob = preds.max(axis=2)
158
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
159
+ if label is None:
160
+ return text
161
+ label = self.decode(label[:, 1:])
162
+ return text, label
163
+
164
+ def add_special_char(self, dict_character):
165
+ dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
166
+ return dict_character
167
+
168
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
169
+ """convert text-index into text-label."""
170
+ result_list = []
171
+ batch_size = len(text_index)
172
+ for batch_idx in range(batch_size):
173
+ char_list = []
174
+ conf_list = []
175
+ for idx in range(len(text_index[batch_idx])):
176
+ if text_index[batch_idx][idx] == 3: # end
177
+ break
178
+ try:
179
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
180
+ except:
181
+ continue
182
+ if text_prob is not None:
183
+ conf_list.append(text_prob[batch_idx][idx])
184
+ else:
185
+ conf_list.append(1)
186
+ text = "".join(char_list)
187
+ result_list.append((text.lower(), np.mean(conf_list).tolist()))
188
+ return result_list
189
+
190
+
191
+ class AttnLabelDecode(BaseRecLabelDecode):
192
+ """Convert between text-label and text-index"""
193
+
194
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
195
+ super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)
196
+
197
+ def add_special_char(self, dict_character):
198
+ self.beg_str = "sos"
199
+ self.end_str = "eos"
200
+ dict_character = dict_character
201
+ dict_character = [self.beg_str] + dict_character + [self.end_str]
202
+ return dict_character
203
+
204
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
205
+ """convert text-index into text-label."""
206
+ result_list = []
207
+ ignored_tokens = self.get_ignored_tokens()
208
+ [beg_idx, end_idx] = self.get_ignored_tokens()
209
+ batch_size = len(text_index)
210
+ for batch_idx in range(batch_size):
211
+ char_list = []
212
+ conf_list = []
213
+ for idx in range(len(text_index[batch_idx])):
214
+ if text_index[batch_idx][idx] in ignored_tokens:
215
+ continue
216
+ if int(text_index[batch_idx][idx]) == int(end_idx):
217
+ break
218
+ if is_remove_duplicate:
219
+ # only for predict
220
+ if (
221
+ idx > 0
222
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
223
+ ):
224
+ continue
225
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
226
+ if text_prob is not None:
227
+ conf_list.append(text_prob[batch_idx][idx])
228
+ else:
229
+ conf_list.append(1)
230
+ text = "".join(char_list)
231
+ result_list.append((text, np.mean(conf_list).tolist()))
232
+ return result_list
233
+
234
+ def __call__(self, preds, label=None, *args, **kwargs):
235
+ """
236
+ text = self.decode(text)
237
+ if label is None:
238
+ return text
239
+ else:
240
+ label = self.decode(label, is_remove_duplicate=False)
241
+ return text, label
242
+ """
243
+ if isinstance(preds, paddle.Tensor):
244
+ preds = preds.numpy()
245
+
246
+ preds_idx = preds.argmax(axis=2)
247
+ preds_prob = preds.max(axis=2)
248
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
249
+ if label is None:
250
+ return text
251
+ label = self.decode(label, is_remove_duplicate=False)
252
+ return text, label
253
+
254
+ def get_ignored_tokens(self):
255
+ beg_idx = self.get_beg_end_flag_idx("beg")
256
+ end_idx = self.get_beg_end_flag_idx("end")
257
+ return [beg_idx, end_idx]
258
+
259
+ def get_beg_end_flag_idx(self, beg_or_end):
260
+ if beg_or_end == "beg":
261
+ idx = np.array(self.dict[self.beg_str])
262
+ elif beg_or_end == "end":
263
+ idx = np.array(self.dict[self.end_str])
264
+ else:
265
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
266
+ return idx
267
+
268
+
269
+ class SEEDLabelDecode(BaseRecLabelDecode):
270
+ """Convert between text-label and text-index"""
271
+
272
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
273
+ super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)
274
+
275
+ def add_special_char(self, dict_character):
276
+ self.padding_str = "padding"
277
+ self.end_str = "eos"
278
+ self.unknown = "unknown"
279
+ dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
280
+ return dict_character
281
+
282
+ def get_ignored_tokens(self):
283
+ end_idx = self.get_beg_end_flag_idx("eos")
284
+ return [end_idx]
285
+
286
+ def get_beg_end_flag_idx(self, beg_or_end):
287
+ if beg_or_end == "sos":
288
+ idx = np.array(self.dict[self.beg_str])
289
+ elif beg_or_end == "eos":
290
+ idx = np.array(self.dict[self.end_str])
291
+ else:
292
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
293
+ return idx
294
+
295
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
296
+ """convert text-index into text-label."""
297
+ result_list = []
298
+ [end_idx] = self.get_ignored_tokens()
299
+ batch_size = len(text_index)
300
+ for batch_idx in range(batch_size):
301
+ char_list = []
302
+ conf_list = []
303
+ for idx in range(len(text_index[batch_idx])):
304
+ if int(text_index[batch_idx][idx]) == int(end_idx):
305
+ break
306
+ if is_remove_duplicate:
307
+ # only for predict
308
+ if (
309
+ idx > 0
310
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
311
+ ):
312
+ continue
313
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
314
+ if text_prob is not None:
315
+ conf_list.append(text_prob[batch_idx][idx])
316
+ else:
317
+ conf_list.append(1)
318
+ text = "".join(char_list)
319
+ result_list.append((text, np.mean(conf_list).tolist()))
320
+ return result_list
321
+
322
+ def __call__(self, preds, label=None, *args, **kwargs):
323
+ """
324
+ text = self.decode(text)
325
+ if label is None:
326
+ return text
327
+ else:
328
+ label = self.decode(label, is_remove_duplicate=False)
329
+ return text, label
330
+ """
331
+ preds_idx = preds["rec_pred"]
332
+ if isinstance(preds_idx, paddle.Tensor):
333
+ preds_idx = preds_idx.numpy()
334
+ if "rec_pred_scores" in preds:
335
+ preds_idx = preds["rec_pred"]
336
+ preds_prob = preds["rec_pred_scores"]
337
+ else:
338
+ preds_idx = preds["rec_pred"].argmax(axis=2)
339
+ preds_prob = preds["rec_pred"].max(axis=2)
340
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
341
+ if label is None:
342
+ return text
343
+ label = self.decode(label, is_remove_duplicate=False)
344
+ return text, label
345
+
346
+
347
+ class SRNLabelDecode(BaseRecLabelDecode):
348
+ """Convert between text-label and text-index"""
349
+
350
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
351
+ super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
352
+ self.max_text_length = kwargs.get("max_text_length", 25)
353
+
354
+ def __call__(self, preds, label=None, *args, **kwargs):
355
+ pred = preds["predict"]
356
+ char_num = len(self.character_str) + 2
357
+ if isinstance(pred, paddle.Tensor):
358
+ pred = pred.numpy()
359
+ pred = np.reshape(pred, [-1, char_num])
360
+
361
+ preds_idx = np.argmax(pred, axis=1)
362
+ preds_prob = np.max(pred, axis=1)
363
+
364
+ preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
365
+
366
+ preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
367
+
368
+ text = self.decode(preds_idx, preds_prob)
369
+
370
+ if label is None:
371
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
372
+ return text
373
+ label = self.decode(label)
374
+ return text, label
375
+
376
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
377
+ """convert text-index into text-label."""
378
+ result_list = []
379
+ ignored_tokens = self.get_ignored_tokens()
380
+ batch_size = len(text_index)
381
+
382
+ for batch_idx in range(batch_size):
383
+ char_list = []
384
+ conf_list = []
385
+ for idx in range(len(text_index[batch_idx])):
386
+ if text_index[batch_idx][idx] in ignored_tokens:
387
+ continue
388
+ if is_remove_duplicate:
389
+ # only for predict
390
+ if (
391
+ idx > 0
392
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
393
+ ):
394
+ continue
395
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
396
+ if text_prob is not None:
397
+ conf_list.append(text_prob[batch_idx][idx])
398
+ else:
399
+ conf_list.append(1)
400
+
401
+ text = "".join(char_list)
402
+ result_list.append((text, np.mean(conf_list).tolist()))
403
+ return result_list
404
+
405
+ def add_special_char(self, dict_character):
406
+ dict_character = dict_character + [self.beg_str, self.end_str]
407
+ return dict_character
408
+
409
+ def get_ignored_tokens(self):
410
+ beg_idx = self.get_beg_end_flag_idx("beg")
411
+ end_idx = self.get_beg_end_flag_idx("end")
412
+ return [beg_idx, end_idx]
413
+
414
+ def get_beg_end_flag_idx(self, beg_or_end):
415
+ if beg_or_end == "beg":
416
+ idx = np.array(self.dict[self.beg_str])
417
+ elif beg_or_end == "end":
418
+ idx = np.array(self.dict[self.end_str])
419
+ else:
420
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
421
+ return idx
422
+
423
+
424
+ class TableLabelDecode(object):
425
+ """ """
426
+
427
+ def __init__(self, character_dict_path, **kwargs):
428
+ list_character, list_elem = self.load_char_elem_dict(character_dict_path)
429
+ list_character = self.add_special_char(list_character)
430
+ list_elem = self.add_special_char(list_elem)
431
+ self.dict_character = {}
432
+ self.dict_idx_character = {}
433
+ for i, char in enumerate(list_character):
434
+ self.dict_idx_character[i] = char
435
+ self.dict_character[char] = i
436
+ self.dict_elem = {}
437
+ self.dict_idx_elem = {}
438
+ for i, elem in enumerate(list_elem):
439
+ self.dict_idx_elem[i] = elem
440
+ self.dict_elem[elem] = i
441
+
442
+ def load_char_elem_dict(self, character_dict_path):
443
+ list_character = []
444
+ list_elem = []
445
+ with open(character_dict_path, "rb") as fin:
446
+ lines = fin.readlines()
447
+ substr = lines[0].decode("utf-8").strip("\n").strip("\r\n").split("\t")
448
+ character_num = int(substr[0])
449
+ elem_num = int(substr[1])
450
+ for cno in range(1, 1 + character_num):
451
+ character = lines[cno].decode("utf-8").strip("\n").strip("\r\n")
452
+ list_character.append(character)
453
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
454
+ elem = lines[eno].decode("utf-8").strip("\n").strip("\r\n")
455
+ list_elem.append(elem)
456
+ return list_character, list_elem
457
+
458
+ def add_special_char(self, list_character):
459
+ self.beg_str = "sos"
460
+ self.end_str = "eos"
461
+ list_character = [self.beg_str] + list_character + [self.end_str]
462
+ return list_character
463
+
464
+ def __call__(self, preds):
465
+ structure_probs = preds["structure_probs"]
466
+ loc_preds = preds["loc_preds"]
467
+ if isinstance(structure_probs, paddle.Tensor):
468
+ structure_probs = structure_probs.numpy()
469
+ if isinstance(loc_preds, paddle.Tensor):
470
+ loc_preds = loc_preds.numpy()
471
+ structure_idx = structure_probs.argmax(axis=2)
472
+ structure_probs = structure_probs.max(axis=2)
473
+ (
474
+ structure_str,
475
+ structure_pos,
476
+ result_score_list,
477
+ result_elem_idx_list,
478
+ ) = self.decode(structure_idx, structure_probs, "elem")
479
+ res_html_code_list = []
480
+ res_loc_list = []
481
+ batch_num = len(structure_str)
482
+ for bno in range(batch_num):
483
+ res_loc = []
484
+ for sno in range(len(structure_str[bno])):
485
+ text = structure_str[bno][sno]
486
+ if text in ["<td>", "<td"]:
487
+ pos = structure_pos[bno][sno]
488
+ res_loc.append(loc_preds[bno, pos])
489
+ res_html_code = "".join(structure_str[bno])
490
+ res_loc = np.array(res_loc)
491
+ res_html_code_list.append(res_html_code)
492
+ res_loc_list.append(res_loc)
493
+ return {
494
+ "res_html_code": res_html_code_list,
495
+ "res_loc": res_loc_list,
496
+ "res_score_list": result_score_list,
497
+ "res_elem_idx_list": result_elem_idx_list,
498
+ "structure_str_list": structure_str,
499
+ }
500
+
501
+ def decode(self, text_index, structure_probs, char_or_elem):
502
+ """convert text-label into text-index."""
503
+ if char_or_elem == "char":
504
+ current_dict = self.dict_idx_character
505
+ else:
506
+ current_dict = self.dict_idx_elem
507
+ ignored_tokens = self.get_ignored_tokens("elem")
508
+ beg_idx, end_idx = ignored_tokens
509
+
510
+ result_list = []
511
+ result_pos_list = []
512
+ result_score_list = []
513
+ result_elem_idx_list = []
514
+ batch_size = len(text_index)
515
+ for batch_idx in range(batch_size):
516
+ char_list = []
517
+ elem_pos_list = []
518
+ elem_idx_list = []
519
+ score_list = []
520
+ for idx in range(len(text_index[batch_idx])):
521
+ tmp_elem_idx = int(text_index[batch_idx][idx])
522
+ if idx > 0 and tmp_elem_idx == end_idx:
523
+ break
524
+ if tmp_elem_idx in ignored_tokens:
525
+ continue
526
+
527
+ char_list.append(current_dict[tmp_elem_idx])
528
+ elem_pos_list.append(idx)
529
+ score_list.append(structure_probs[batch_idx, idx])
530
+ elem_idx_list.append(tmp_elem_idx)
531
+ result_list.append(char_list)
532
+ result_pos_list.append(elem_pos_list)
533
+ result_score_list.append(score_list)
534
+ result_elem_idx_list.append(elem_idx_list)
535
+ return result_list, result_pos_list, result_score_list, result_elem_idx_list
536
+
537
+ def get_ignored_tokens(self, char_or_elem):
538
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
539
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
540
+ return [beg_idx, end_idx]
541
+
542
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
543
+ if char_or_elem == "char":
544
+ if beg_or_end == "beg":
545
+ idx = self.dict_character[self.beg_str]
546
+ elif beg_or_end == "end":
547
+ idx = self.dict_character[self.end_str]
548
+ else:
549
+ assert False, (
550
+ "Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end
551
+ )
552
+ elif char_or_elem == "elem":
553
+ if beg_or_end == "beg":
554
+ idx = self.dict_elem[self.beg_str]
555
+ elif beg_or_end == "end":
556
+ idx = self.dict_elem[self.end_str]
557
+ else:
558
+ assert False, (
559
+ "Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end
560
+ )
561
+ else:
562
+ assert False, "Unsupport type %s in char_or_elem" % char_or_elem
563
+ return idx
564
+
565
+
566
+ class SARLabelDecode(BaseRecLabelDecode):
567
+ """Convert between text-label and text-index"""
568
+
569
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
570
+ super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)
571
+
572
+ self.rm_symbol = kwargs.get("rm_symbol", False)
573
+
574
+ def add_special_char(self, dict_character):
575
+ beg_end_str = "<BOS/EOS>"
576
+ unknown_str = "<UKN>"
577
+ padding_str = "<PAD>"
578
+ dict_character = dict_character + [unknown_str]
579
+ self.unknown_idx = len(dict_character) - 1
580
+ dict_character = dict_character + [beg_end_str]
581
+ self.start_idx = len(dict_character) - 1
582
+ self.end_idx = len(dict_character) - 1
583
+ dict_character = dict_character + [padding_str]
584
+ self.padding_idx = len(dict_character) - 1
585
+ return dict_character
586
+
587
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
588
+ """convert text-index into text-label."""
589
+ result_list = []
590
+ ignored_tokens = self.get_ignored_tokens()
591
+
592
+ batch_size = len(text_index)
593
+ for batch_idx in range(batch_size):
594
+ char_list = []
595
+ conf_list = []
596
+ for idx in range(len(text_index[batch_idx])):
597
+ if text_index[batch_idx][idx] in ignored_tokens:
598
+ continue
599
+ if int(text_index[batch_idx][idx]) == int(self.end_idx):
600
+ if text_prob is None and idx == 0:
601
+ continue
602
+ else:
603
+ break
604
+ if is_remove_duplicate:
605
+ # only for predict
606
+ if (
607
+ idx > 0
608
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
609
+ ):
610
+ continue
611
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
612
+ if text_prob is not None:
613
+ conf_list.append(text_prob[batch_idx][idx])
614
+ else:
615
+ conf_list.append(1)
616
+ text = "".join(char_list)
617
+ if self.rm_symbol:
618
+ comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")
619
+ text = text.lower()
620
+ text = comp.sub("", text)
621
+ result_list.append((text, np.mean(conf_list).tolist()))
622
+ return result_list
623
+
624
+ def __call__(self, preds, label=None, *args, **kwargs):
625
+ if isinstance(preds, paddle.Tensor):
626
+ preds = preds.numpy()
627
+ preds_idx = preds.argmax(axis=2)
628
+ preds_prob = preds.max(axis=2)
629
+
630
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
631
+
632
+ if label is None:
633
+ return text
634
+ label = self.decode(label, is_remove_duplicate=False)
635
+ return text, label
636
+
637
+ def get_ignored_tokens(self):
638
+ return [self.padding_idx]
639
+
640
+
641
+ class DistillationSARLabelDecode(SARLabelDecode):
642
+ """
643
+ Convert
644
+ Convert between text-label and text-index
645
+ """
646
+
647
+ def __init__(
648
+ self,
649
+ character_dict_path=None,
650
+ use_space_char=False,
651
+ model_name=["student"],
652
+ key=None,
653
+ multi_head=False,
654
+ **kwargs
655
+ ):
656
+ super(DistillationSARLabelDecode, self).__init__(
657
+ character_dict_path, use_space_char
658
+ )
659
+ if not isinstance(model_name, list):
660
+ model_name = [model_name]
661
+ self.model_name = model_name
662
+
663
+ self.key = key
664
+ self.multi_head = multi_head
665
+
666
+ def __call__(self, preds, label=None, *args, **kwargs):
667
+ output = dict()
668
+ for name in self.model_name:
669
+ pred = preds[name]
670
+ if self.key is not None:
671
+ pred = pred[self.key]
672
+ if self.multi_head and isinstance(pred, dict):
673
+ pred = pred["sar"]
674
+ output[name] = super().__call__(pred, label=label, *args, **kwargs)
675
+ return output
676
+
677
+
678
+ class PRENLabelDecode(BaseRecLabelDecode):
679
+ """Convert between text-label and text-index"""
680
+
681
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
682
+ super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)
683
+
684
+ def add_special_char(self, dict_character):
685
+ padding_str = "<PAD>" # 0
686
+ end_str = "<EOS>" # 1
687
+ unknown_str = "<UNK>" # 2
688
+
689
+ dict_character = [padding_str, end_str, unknown_str] + dict_character
690
+ self.padding_idx = 0
691
+ self.end_idx = 1
692
+ self.unknown_idx = 2
693
+
694
+ return dict_character
695
+
696
+ def decode(self, text_index, text_prob=None):
697
+ """convert text-index into text-label."""
698
+ result_list = []
699
+ batch_size = len(text_index)
700
+
701
+ for batch_idx in range(batch_size):
702
+ char_list = []
703
+ conf_list = []
704
+ for idx in range(len(text_index[batch_idx])):
705
+ if text_index[batch_idx][idx] == self.end_idx:
706
+ break
707
+ if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]:
708
+ continue
709
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
710
+ if text_prob is not None:
711
+ conf_list.append(text_prob[batch_idx][idx])
712
+ else:
713
+ conf_list.append(1)
714
+
715
+ text = "".join(char_list)
716
+ if len(text) > 0:
717
+ result_list.append((text, np.mean(conf_list).tolist()))
718
+ else:
719
+ # here confidence of empty recog result is 1
720
+ result_list.append(("", 1))
721
+ return result_list
722
+
723
+ def __call__(self, preds, label=None, *args, **kwargs):
724
+ preds = preds.numpy()
725
+ preds_idx = preds.argmax(axis=2)
726
+ preds_prob = preds.max(axis=2)
727
+ text = self.decode(preds_idx, preds_prob)
728
+ if label is None:
729
+ return text
730
+ label = self.decode(label)
731
+ return text, label
ocr/postprocess/sast_postprocess.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import os
4
+ import sys
5
+
6
+ __dir__ = os.path.dirname(__file__)
7
+ sys.path.append(__dir__)
8
+ sys.path.append(os.path.join(__dir__, ".."))
9
+
10
+ import time
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import paddle
15
+
16
+ from .locality_aware_nms import nms_locality
17
+
18
+
19
+ class SASTPostProcess(object):
20
+ """
21
+ The post process for SAST.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ score_thresh=0.5,
27
+ nms_thresh=0.2,
28
+ sample_pts_num=2,
29
+ shrink_ratio_of_width=0.3,
30
+ expand_scale=1.0,
31
+ tcl_map_thresh=0.5,
32
+ **kwargs
33
+ ):
34
+
35
+ self.score_thresh = score_thresh
36
+ self.nms_thresh = nms_thresh
37
+ self.sample_pts_num = sample_pts_num
38
+ self.shrink_ratio_of_width = shrink_ratio_of_width
39
+ self.expand_scale = expand_scale
40
+ self.tcl_map_thresh = tcl_map_thresh
41
+
42
+ # c++ la-nms is faster, but only support python 3.5
43
+ self.is_python35 = False
44
+ if sys.version_info.major == 3 and sys.version_info.minor == 5:
45
+ self.is_python35 = True
46
+
47
+ def point_pair2poly(self, point_pair_list):
48
+ """
49
+ Transfer vertical point_pairs into poly point in clockwise.
50
+ """
51
+ # constract poly
52
+ point_num = len(point_pair_list) * 2
53
+ point_list = [0] * point_num
54
+ for idx, point_pair in enumerate(point_pair_list):
55
+ point_list[idx] = point_pair[0]
56
+ point_list[point_num - 1 - idx] = point_pair[1]
57
+ return np.array(point_list).reshape(-1, 2)
58
+
59
+ def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
60
+ """
61
+ Generate shrink_quad_along_width.
62
+ """
63
+ ratio_pair = np.array(
64
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
65
+ )
66
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
67
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
68
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
69
+
70
+ def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
71
+ """
72
+ expand poly along width.
73
+ """
74
+ point_num = poly.shape[0]
75
+ left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
76
+ left_ratio = (
77
+ -shrink_ratio_of_width
78
+ * np.linalg.norm(left_quad[0] - left_quad[3])
79
+ / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
80
+ )
81
+ left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
82
+ right_quad = np.array(
83
+ [
84
+ poly[point_num // 2 - 2],
85
+ poly[point_num // 2 - 1],
86
+ poly[point_num // 2],
87
+ poly[point_num // 2 + 1],
88
+ ],
89
+ dtype=np.float32,
90
+ )
91
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
92
+ right_quad[0] - right_quad[3]
93
+ ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
94
+ right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
95
+ poly[0] = left_quad_expand[0]
96
+ poly[-1] = left_quad_expand[-1]
97
+ poly[point_num // 2 - 1] = right_quad_expand[1]
98
+ poly[point_num // 2] = right_quad_expand[2]
99
+ return poly
100
+
101
+ def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
102
+ """Restore quad."""
103
+ xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
104
+ xy_text = xy_text[:, ::-1] # (n, 2)
105
+
106
+ # Sort the text boxes via the y axis
107
+ xy_text = xy_text[np.argsort(xy_text[:, 1])]
108
+
109
+ scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
110
+ scores = scores[:, np.newaxis]
111
+
112
+ # Restore
113
+ point_num = int(tvo_map.shape[-1] / 2)
114
+ assert point_num == 4
115
+ tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
116
+ xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
117
+ quads = xy_text_tile - tvo_map
118
+
119
+ return scores, quads, xy_text
120
+
121
+ def quad_area(self, quad):
122
+ """
123
+ compute area of a quad.
124
+ """
125
+ edge = [
126
+ (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
127
+ (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
128
+ (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
129
+ (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
130
+ ]
131
+ return np.sum(edge) / 2.0
132
+
133
+ def nms(self, dets):
134
+ if self.is_python35:
135
+ import lanms
136
+
137
+ dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
138
+ else:
139
+ dets = nms_locality(dets, self.nms_thresh)
140
+ return dets
141
+
142
+ def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
143
+ """
144
+ Cluster pixels in tcl_map based on quads.
145
+ """
146
+ instance_count = quads.shape[0] + 1 # contain background
147
+ instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
148
+ if instance_count == 1:
149
+ return instance_count, instance_label_map
150
+
151
+ # predict text center
152
+ xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
153
+ n = xy_text.shape[0]
154
+ xy_text = xy_text[:, ::-1] # (n, 2)
155
+ tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
156
+ pred_tc = xy_text - tco
157
+
158
+ # get gt text center
159
+ m = quads.shape[0]
160
+ gt_tc = np.mean(quads, axis=1) # (m, 2)
161
+
162
+ pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
163
+ gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
164
+ dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
165
+ xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
166
+
167
+ instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
168
+ return instance_count, instance_label_map
169
+
170
+ def estimate_sample_pts_num(self, quad, xy_text):
171
+ """
172
+ Estimate sample points number.
173
+ """
174
+ eh = (
175
+ np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
176
+ ) / 2.0
177
+ ew = (
178
+ np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
179
+ ) / 2.0
180
+
181
+ dense_sample_pts_num = max(2, int(ew))
182
+ dense_xy_center_line = xy_text[
183
+ np.linspace(
184
+ 0,
185
+ xy_text.shape[0] - 1,
186
+ dense_sample_pts_num,
187
+ endpoint=True,
188
+ dtype=np.float32,
189
+ ).astype(np.int32)
190
+ ]
191
+
192
+ dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
193
+ estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
194
+
195
+ sample_pts_num = max(2, int(estimate_arc_len / eh))
196
+ return sample_pts_num
197
+
198
+ def detect_sast(
199
+ self,
200
+ tcl_map,
201
+ tvo_map,
202
+ tbo_map,
203
+ tco_map,
204
+ ratio_w,
205
+ ratio_h,
206
+ src_w,
207
+ src_h,
208
+ shrink_ratio_of_width=0.3,
209
+ tcl_map_thresh=0.5,
210
+ offset_expand=1.0,
211
+ out_strid=4.0,
212
+ ):
213
+ """
214
+ first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
215
+ """
216
+ # restore quad
217
+ scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
218
+ dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
219
+ dets = self.nms(dets)
220
+ if dets.shape[0] == 0:
221
+ return []
222
+ quads = dets[:, :-1].reshape(-1, 4, 2)
223
+
224
+ # Compute quad area
225
+ quad_areas = []
226
+ for quad in quads:
227
+ quad_areas.append(-self.quad_area(quad))
228
+
229
+ # instance segmentation
230
+ # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
231
+ instance_count, instance_label_map = self.cluster_by_quads_tco(
232
+ tcl_map, tcl_map_thresh, quads, tco_map
233
+ )
234
+
235
+ # restore single poly with tcl instance.
236
+ poly_list = []
237
+ for instance_idx in range(1, instance_count):
238
+ xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
239
+ quad = quads[instance_idx - 1]
240
+ q_area = quad_areas[instance_idx - 1]
241
+ if q_area < 5:
242
+ continue
243
+
244
+ #
245
+ len1 = float(np.linalg.norm(quad[0] - quad[1]))
246
+ len2 = float(np.linalg.norm(quad[1] - quad[2]))
247
+ min_len = min(len1, len2)
248
+ if min_len < 3:
249
+ continue
250
+
251
+ # filter small CC
252
+ if xy_text.shape[0] <= 0:
253
+ continue
254
+
255
+ # filter low confidence instance
256
+ xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
257
+ if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
258
+ # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
259
+ continue
260
+
261
+ # sort xy_text
262
+ left_center_pt = np.array(
263
+ [[(quad[0, 0] + quad[-1, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]]
264
+ ) # (1, 2)
265
+ right_center_pt = np.array(
266
+ [[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[1, 1] + quad[2, 1]) / 2.0]]
267
+ ) # (1, 2)
268
+ proj_unit_vec = (right_center_pt - left_center_pt) / (
269
+ np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
270
+ )
271
+ proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
272
+ xy_text = xy_text[np.argsort(proj_value)]
273
+
274
+ # Sample pts in tcl map
275
+ if self.sample_pts_num == 0:
276
+ sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
277
+ else:
278
+ sample_pts_num = self.sample_pts_num
279
+ xy_center_line = xy_text[
280
+ np.linspace(
281
+ 0,
282
+ xy_text.shape[0] - 1,
283
+ sample_pts_num,
284
+ endpoint=True,
285
+ dtype=np.float32,
286
+ ).astype(np.int32)
287
+ ]
288
+
289
+ point_pair_list = []
290
+ for x, y in xy_center_line:
291
+ # get corresponding offset
292
+ offset = tbo_map[y, x, :].reshape(2, 2)
293
+ if offset_expand != 1.0:
294
+ offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
295
+ expand_length = np.clip(
296
+ offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
297
+ )
298
+ offset_detal = offset / offset_length * expand_length
299
+ offset = offset + offset_detal
300
+ # original point
301
+ ori_yx = np.array([y, x], dtype=np.float32)
302
+ point_pair = (
303
+ (ori_yx + offset)[:, ::-1]
304
+ * out_strid
305
+ / np.array([ratio_w, ratio_h]).reshape(-1, 2)
306
+ )
307
+ point_pair_list.append(point_pair)
308
+
309
+ # ndarry: (x, 2), expand poly along width
310
+ detected_poly = self.point_pair2poly(point_pair_list)
311
+ detected_poly = self.expand_poly_along_width(
312
+ detected_poly, shrink_ratio_of_width
313
+ )
314
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
315
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
316
+ poly_list.append(detected_poly)
317
+
318
+ return poly_list
319
+
320
+ def __call__(self, outs_dict, shape_list):
321
+ score_list = outs_dict["f_score"]
322
+ border_list = outs_dict["f_border"]
323
+ tvo_list = outs_dict["f_tvo"]
324
+ tco_list = outs_dict["f_tco"]
325
+ if isinstance(score_list, paddle.Tensor):
326
+ score_list = score_list.numpy()
327
+ border_list = border_list.numpy()
328
+ tvo_list = tvo_list.numpy()
329
+ tco_list = tco_list.numpy()
330
+
331
+ img_num = len(shape_list)
332
+ poly_lists = []
333
+ for ino in range(img_num):
334
+ p_score = score_list[ino].transpose((1, 2, 0))
335
+ p_border = border_list[ino].transpose((1, 2, 0))
336
+ p_tvo = tvo_list[ino].transpose((1, 2, 0))
337
+ p_tco = tco_list[ino].transpose((1, 2, 0))
338
+ src_h, src_w, ratio_h, ratio_w = shape_list[ino]
339
+
340
+ poly_list = self.detect_sast(
341
+ p_score,
342
+ p_tvo,
343
+ p_border,
344
+ p_tco,
345
+ ratio_w,
346
+ ratio_h,
347
+ src_w,
348
+ src_h,
349
+ shrink_ratio_of_width=self.shrink_ratio_of_width,
350
+ tcl_map_thresh=self.tcl_map_thresh,
351
+ offset_expand=self.expand_scale,
352
+ )
353
+ poly_lists.append({"points": np.array(poly_list)})
354
+
355
+ return poly_lists
ocr/postprocess/vqa_token_re_layoutlm_postprocess.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class VQAReTokenLayoutLMPostProcess(object):
2
+ """Convert between text-label and text-index"""
3
+
4
+ def __init__(self, **kwargs):
5
+ super(VQAReTokenLayoutLMPostProcess, self).__init__()
6
+
7
+ def __call__(self, preds, label=None, *args, **kwargs):
8
+ if label is not None:
9
+ return self._metric(preds, label)
10
+ else:
11
+ return self._infer(preds, *args, **kwargs)
12
+
13
+ def _metric(self, preds, label):
14
+ return preds["pred_relations"], label[6], label[5]
15
+
16
+ def _infer(self, preds, *args, **kwargs):
17
+ ser_results = kwargs["ser_results"]
18
+ entity_idx_dict_batch = kwargs["entity_idx_dict_batch"]
19
+ pred_relations = preds["pred_relations"]
20
+
21
+ # merge relations and ocr info
22
+ results = []
23
+ for pred_relation, ser_result, entity_idx_dict in zip(
24
+ pred_relations, ser_results, entity_idx_dict_batch
25
+ ):
26
+ result = []
27
+ used_tail_id = []
28
+ for relation in pred_relation:
29
+ if relation["tail_id"] in used_tail_id:
30
+ continue
31
+ used_tail_id.append(relation["tail_id"])
32
+ ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]]
33
+ ocr_info_tail = ser_result[entity_idx_dict[relation["tail_id"]]]
34
+ result.append((ocr_info_head, ocr_info_tail))
35
+ results.append(result)
36
+ return results
ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import paddle
3
+
4
+
5
+ def load_vqa_bio_label_maps(label_map_path):
6
+ with open(label_map_path, "r", encoding="utf-8") as fin:
7
+ lines = fin.readlines()
8
+ lines = [line.strip() for line in lines]
9
+ if "O" not in lines:
10
+ lines.insert(0, "O")
11
+ labels = []
12
+ for line in lines:
13
+ if line == "O":
14
+ labels.append("O")
15
+ else:
16
+ labels.append("B-" + line)
17
+ labels.append("I-" + line)
18
+ label2id_map = {label: idx for idx, label in enumerate(labels)}
19
+ id2label_map = {idx: label for idx, label in enumerate(labels)}
20
+ return label2id_map, id2label_map
21
+
22
+
23
+ class VQASerTokenLayoutLMPostProcess(object):
24
+ """Convert between text-label and text-index"""
25
+
26
+ def __init__(self, class_path, **kwargs):
27
+ super(VQASerTokenLayoutLMPostProcess, self).__init__()
28
+ label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
29
+
30
+ self.label2id_map_for_draw = dict()
31
+ for key in label2id_map:
32
+ if key.startswith("I-"):
33
+ self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
34
+ else:
35
+ self.label2id_map_for_draw[key] = label2id_map[key]
36
+
37
+ self.id2label_map_for_show = dict()
38
+ for key in self.label2id_map_for_draw:
39
+ val = self.label2id_map_for_draw[key]
40
+ if key == "O":
41
+ self.id2label_map_for_show[val] = key
42
+ if key.startswith("B-") or key.startswith("I-"):
43
+ self.id2label_map_for_show[val] = key[2:]
44
+ else:
45
+ self.id2label_map_for_show[val] = key
46
+
47
+ def __call__(self, preds, batch=None, *args, **kwargs):
48
+ if isinstance(preds, paddle.Tensor):
49
+ preds = preds.numpy()
50
+
51
+ if batch is not None:
52
+ return self._metric(preds, batch[1])
53
+ else:
54
+ return self._infer(preds, **kwargs)
55
+
56
+ def _metric(self, preds, label):
57
+ pred_idxs = preds.argmax(axis=2)
58
+ decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
59
+ label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
60
+
61
+ for i in range(pred_idxs.shape[0]):
62
+ for j in range(pred_idxs.shape[1]):
63
+ if label[i, j] != -100:
64
+ label_decode_out_list[i].append(self.id2label_map[label[i, j]])
65
+ decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]])
66
+ return decode_out_list, label_decode_out_list
67
+
68
+ def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
69
+ results = []
70
+
71
+ for pred, attention_mask, segment_offset_id, ocr_info in zip(
72
+ preds, attention_masks, segment_offset_ids, ocr_infos
73
+ ):
74
+ pred = np.argmax(pred, axis=1)
75
+ pred = [self.id2label_map[idx] for idx in pred]
76
+
77
+ for idx in range(len(segment_offset_id)):
78
+ if idx == 0:
79
+ start_id = 0
80
+ else:
81
+ start_id = segment_offset_id[idx - 1]
82
+
83
+ end_id = segment_offset_id[idx]
84
+
85
+ curr_pred = pred[start_id:end_id]
86
+ curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
87
+
88
+ if len(curr_pred) <= 0:
89
+ pred_id = 0
90
+ else:
91
+ counts = np.bincount(curr_pred)
92
+ pred_id = np.argmax(counts)
93
+ ocr_info[idx]["pred_id"] = int(pred_id)
94
+ ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
95
+ results.append(ocr_info)
96
+ return results
ocr/ppocr/__init__.py ADDED
File without changes
ocr/ppocr/data/__init__.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import os
4
+ import signal
5
+ import sys
6
+
7
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
8
+ sys.path.append(os.path.abspath(os.path.join(__dir__, "../..")))
9
+
10
+ import copy
11
+
12
+ from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
13
+
14
+ from .imaug import create_operators, transform
15
+
16
+ __all__ = ["build_dataloader", "transform", "create_operators"]
17
+
18
+
19
+ def term_mp(sig_num, frame):
20
+ """kill all child processes"""
21
+ pid = os.getpid()
22
+ pgid = os.getpgid(os.getpid())
23
+ print("main proc {} exit, kill process group " "{}".format(pid, pgid))
24
+ os.killpg(pgid, signal.SIGKILL)
25
+
26
+
27
+ def build_dataloader(config, mode, device, logger, seed=None):
28
+ config = copy.deepcopy(config)
29
+
30
+ support_dict = ["SimpleDataSet", "LMDBDataSet", "PGDataSet", "PubTabDataSet"]
31
+ module_name = config[mode]["dataset"]["name"]
32
+ assert module_name in support_dict, Exception(
33
+ "DataSet only support {}".format(support_dict)
34
+ )
35
+ assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
36
+
37
+ dataset = eval(module_name)(config, mode, logger, seed)
38
+ loader_config = config[mode]["loader"]
39
+ batch_size = loader_config["batch_size_per_card"]
40
+ drop_last = loader_config["drop_last"]
41
+ shuffle = loader_config["shuffle"]
42
+ num_workers = loader_config["num_workers"]
43
+ if "use_shared_memory" in loader_config.keys():
44
+ use_shared_memory = loader_config["use_shared_memory"]
45
+ else:
46
+ use_shared_memory = True
47
+
48
+ if mode == "Train":
49
+ # Distribute data to multiple cards
50
+ batch_sampler = DistributedBatchSampler(
51
+ dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
52
+ )
53
+ else:
54
+ # Distribute data to single card
55
+ batch_sampler = BatchSampler(
56
+ dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
57
+ )
58
+
59
+ if "collate_fn" in loader_config:
60
+ from . import collate_fn
61
+
62
+ collate_fn = getattr(collate_fn, loader_config["collate_fn"])()
63
+ else:
64
+ collate_fn = None
65
+ data_loader = DataLoader(
66
+ dataset=dataset,
67
+ batch_sampler=batch_sampler,
68
+ places=device,
69
+ num_workers=num_workers,
70
+ return_list=True,
71
+ use_shared_memory=use_shared_memory,
72
+ collate_fn=collate_fn,
73
+ )
74
+
75
+ # support exit using ctrl+c
76
+ signal.signal(signal.SIGINT, term_mp)
77
+ signal.signal(signal.SIGTERM, term_mp)
78
+
79
+ return data_loader
ocr/ppocr/data/collate_fn.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from collections import defaultdict
3
+
4
+ import numpy as np
5
+ import paddle
6
+
7
+
8
+ class DictCollator(object):
9
+ """
10
+ data batch
11
+ """
12
+
13
+ def __call__(self, batch):
14
+ # todo:support batch operators
15
+ data_dict = defaultdict(list)
16
+ to_tensor_keys = []
17
+ for sample in batch:
18
+ for k, v in sample.items():
19
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
20
+ if k not in to_tensor_keys:
21
+ to_tensor_keys.append(k)
22
+ data_dict[k].append(v)
23
+ for k in to_tensor_keys:
24
+ data_dict[k] = paddle.to_tensor(data_dict[k])
25
+ return data_dict
26
+
27
+
28
+ class ListCollator(object):
29
+ """
30
+ data batch
31
+ """
32
+
33
+ def __call__(self, batch):
34
+ # todo:support batch operators
35
+ data_dict = defaultdict(list)
36
+ to_tensor_idxs = []
37
+ for sample in batch:
38
+ for idx, v in enumerate(sample):
39
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
40
+ if idx not in to_tensor_idxs:
41
+ to_tensor_idxs.append(idx)
42
+ data_dict[idx].append(v)
43
+ for idx in to_tensor_idxs:
44
+ data_dict[idx] = paddle.to_tensor(data_dict[idx])
45
+ return list(data_dict.values())
46
+
47
+
48
+ class SSLRotateCollate(object):
49
+ """
50
+ bach: [
51
+ [(4*3xH*W), (4,)]
52
+ [(4*3xH*W), (4,)]
53
+ ...
54
+ ]
55
+ """
56
+
57
+ def __call__(self, batch):
58
+ output = [np.concatenate(d, axis=0) for d in zip(*batch)]
59
+ return output
ocr/ppocr/data/imaug/ColorJitter.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from paddle.vision.transforms import ColorJitter as pp_ColorJitter
2
+
3
+ __all__ = ["ColorJitter"]
4
+
5
+
6
+ class ColorJitter(object):
7
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
8
+ self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
9
+
10
+ def __call__(self, data):
11
+ image = data["image"]
12
+ image = self.aug(image)
13
+ data["image"] = image
14
+ return data
ocr/ppocr/data/imaug/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ from .ColorJitter import ColorJitter
4
+ from .copy_paste import CopyPaste
5
+ from .east_process import *
6
+ from .fce_aug import *
7
+ from .fce_targets import FCENetTargets
8
+ from .gen_table_mask import *
9
+ from .iaa_augment import IaaAugment
10
+ from .label_ops import *
11
+ from .make_border_map import MakeBorderMap
12
+ from .make_pse_gt import MakePseGt
13
+ from .make_shrink_map import MakeShrinkMap
14
+ from .operators import *
15
+ from .pg_process import *
16
+ from .randaugment import RandAugment
17
+ from .random_crop_data import EastRandomCropData, RandomCropImgMask
18
+ from .rec_img_aug import (
19
+ ClsResizeImg,
20
+ NRTRRecResizeImg,
21
+ PRENResizeImg,
22
+ RecAug,
23
+ RecConAug,
24
+ RecResizeImg,
25
+ SARRecResizeImg,
26
+ SRNRecResizeImg,
27
+ )
28
+ from .sast_process import *
29
+ from .ssl_img_aug import SSLRotateResize
30
+ from .vqa import *
31
+
32
+
33
+ def transform(data, ops=None):
34
+ """transform"""
35
+ if ops is None:
36
+ ops = []
37
+ for op in ops:
38
+ data = op(data)
39
+ if data is None:
40
+ return None
41
+ return data
42
+
43
+
44
+ def create_operators(op_param_list, global_config=None):
45
+ """
46
+ create operators based on the config
47
+
48
+ Args:
49
+ params(list): a dict list, used to create some operators
50
+ """
51
+ assert isinstance(op_param_list, list), "operator config should be a list"
52
+ ops = []
53
+ for operator in op_param_list:
54
+ assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
55
+ op_name = list(operator)[0]
56
+ param = {} if operator[op_name] is None else operator[op_name]
57
+ if global_config is not None:
58
+ param.update(global_config)
59
+ op = eval(op_name)(**param)
60
+ ops.append(op)
61
+ return ops
ocr/ppocr/data/imaug/copy_paste.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from shapely.geometry import Polygon
8
+
9
+ from ppocr.data.imaug.iaa_augment import IaaAugment
10
+ from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
11
+ from utility import get_rotate_crop_image
12
+
13
+
14
+ class CopyPaste(object):
15
+ def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
16
+ self.ext_data_num = 1
17
+ self.objects_paste_ratio = objects_paste_ratio
18
+ self.limit_paste = limit_paste
19
+ augmenter_args = [{"type": "Resize", "args": {"size": [0.5, 3]}}]
20
+ self.aug = IaaAugment(augmenter_args)
21
+
22
+ def __call__(self, data):
23
+ point_num = data["polys"].shape[1]
24
+ src_img = data["image"]
25
+ src_polys = data["polys"].tolist()
26
+ src_texts = data["texts"]
27
+ src_ignores = data["ignore_tags"].tolist()
28
+ ext_data = data["ext_data"][0]
29
+ ext_image = ext_data["image"]
30
+ ext_polys = ext_data["polys"]
31
+ ext_texts = ext_data["texts"]
32
+ ext_ignores = ext_data["ignore_tags"]
33
+
34
+ indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
35
+ select_num = max(1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
36
+
37
+ random.shuffle(indexs)
38
+ select_idxs = indexs[:select_num]
39
+ select_polys = ext_polys[select_idxs]
40
+ select_ignores = ext_ignores[select_idxs]
41
+
42
+ src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
43
+ ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
44
+ src_img = Image.fromarray(src_img).convert("RGBA")
45
+ for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
46
+ box_img = get_rotate_crop_image(ext_image, poly)
47
+
48
+ src_img, box = self.paste_img(src_img, box_img, src_polys)
49
+ if box is not None:
50
+ box = box.tolist()
51
+ for _ in range(len(box), point_num):
52
+ box.append(box[-1])
53
+ src_polys.append(box)
54
+ src_texts.append(ext_texts[idx])
55
+ src_ignores.append(tag)
56
+ src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
57
+ h, w = src_img.shape[:2]
58
+ src_polys = np.array(src_polys)
59
+ src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
60
+ src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
61
+ data["image"] = src_img
62
+ data["polys"] = src_polys
63
+ data["texts"] = src_texts
64
+ data["ignore_tags"] = np.array(src_ignores)
65
+ return data
66
+
67
+ def paste_img(self, src_img, box_img, src_polys):
68
+ box_img_pil = Image.fromarray(box_img).convert("RGBA")
69
+ src_w, src_h = src_img.size
70
+ box_w, box_h = box_img_pil.size
71
+
72
+ angle = np.random.randint(0, 360)
73
+ box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
74
+ box = rotate_bbox(box_img, box, angle)[0]
75
+ box_img_pil = box_img_pil.rotate(angle, expand=1)
76
+ box_w, box_h = box_img_pil.width, box_img_pil.height
77
+ if src_w - box_w < 0 or src_h - box_h < 0:
78
+ return src_img, None
79
+
80
+ paste_x, paste_y = self.select_coord(
81
+ src_polys, box, src_w - box_w, src_h - box_h
82
+ )
83
+ if paste_x is None:
84
+ return src_img, None
85
+ box[:, 0] += paste_x
86
+ box[:, 1] += paste_y
87
+ r, g, b, A = box_img_pil.split()
88
+ src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
89
+
90
+ return src_img, box
91
+
92
+ def select_coord(self, src_polys, box, endx, endy):
93
+ if self.limit_paste:
94
+ xmin, ymin, xmax, ymax = (
95
+ box[:, 0].min(),
96
+ box[:, 1].min(),
97
+ box[:, 0].max(),
98
+ box[:, 1].max(),
99
+ )
100
+ for _ in range(50):
101
+ paste_x = random.randint(0, endx)
102
+ paste_y = random.randint(0, endy)
103
+ xmin1 = xmin + paste_x
104
+ xmax1 = xmax + paste_x
105
+ ymin1 = ymin + paste_y
106
+ ymax1 = ymax + paste_y
107
+
108
+ num_poly_in_rect = 0
109
+ for poly in src_polys:
110
+ if not is_poly_outside_rect(
111
+ poly, xmin1, ymin1, xmax1 - xmin1, ymax1 - ymin1
112
+ ):
113
+ num_poly_in_rect += 1
114
+ break
115
+ if num_poly_in_rect == 0:
116
+ return paste_x, paste_y
117
+ return None, None
118
+ else:
119
+ paste_x = random.randint(0, endx)
120
+ paste_y = random.randint(0, endy)
121
+ return paste_x, paste_y
122
+
123
+
124
+ def get_union(pD, pG):
125
+ return Polygon(pD).union(Polygon(pG)).area
126
+
127
+
128
+ def get_intersection_over_union(pD, pG):
129
+ return get_intersection(pD, pG) / get_union(pD, pG)
130
+
131
+
132
+ def get_intersection(pD, pG):
133
+ return Polygon(pD).intersection(Polygon(pG)).area
134
+
135
+
136
+ def rotate_bbox(img, text_polys, angle, scale=1):
137
+ """
138
+ from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
139
+ Args:
140
+ img: np.ndarray
141
+ text_polys: np.ndarray N*4*2
142
+ angle: int
143
+ scale: int
144
+
145
+ Returns:
146
+
147
+ """
148
+ w = img.shape[1]
149
+ h = img.shape[0]
150
+
151
+ rangle = np.deg2rad(angle)
152
+ nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
153
+ nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
154
+ rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
155
+ rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
156
+ rot_mat[0, 2] += rot_move[0]
157
+ rot_mat[1, 2] += rot_move[1]
158
+
159
+ # ---------------------- rotate box ----------------------
160
+ rot_text_polys = list()
161
+ for bbox in text_polys:
162
+ point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
163
+ point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
164
+ point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
165
+ point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
166
+ rot_text_polys.append([point1, point2, point3, point4])
167
+ return np.array(rot_text_polys, dtype=np.float32)
ocr/ppocr/data/imaug/east_process.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ __all__ = ["EASTProcessTrain"]
7
+
8
+
9
+ class EASTProcessTrain(object):
10
+ def __init__(
11
+ self,
12
+ image_shape=[512, 512],
13
+ background_ratio=0.125,
14
+ min_crop_side_ratio=0.1,
15
+ min_text_size=10,
16
+ **kwargs
17
+ ):
18
+ self.input_size = image_shape[1]
19
+ self.random_scale = np.array([0.5, 1, 2.0, 3.0])
20
+ self.background_ratio = background_ratio
21
+ self.min_crop_side_ratio = min_crop_side_ratio
22
+ self.min_text_size = min_text_size
23
+
24
+ def preprocess(self, im):
25
+ input_size = self.input_size
26
+ im_shape = im.shape
27
+ im_size_min = np.min(im_shape[0:2])
28
+ im_size_max = np.max(im_shape[0:2])
29
+ im_scale = float(input_size) / float(im_size_max)
30
+ im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
31
+ img_mean = [0.485, 0.456, 0.406]
32
+ img_std = [0.229, 0.224, 0.225]
33
+ # im = im[:, :, ::-1].astype(np.float32)
34
+ im = im / 255
35
+ im -= img_mean
36
+ im /= img_std
37
+ new_h, new_w, _ = im.shape
38
+ im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
39
+ im_padded[:new_h, :new_w, :] = im
40
+ im_padded = im_padded.transpose((2, 0, 1))
41
+ im_padded = im_padded[np.newaxis, :]
42
+ return im_padded, im_scale
43
+
44
+ def rotate_im_poly(self, im, text_polys):
45
+ """
46
+ rotate image with 90 / 180 / 270 degre
47
+ """
48
+ im_w, im_h = im.shape[1], im.shape[0]
49
+ dst_im = im.copy()
50
+ dst_polys = []
51
+ rand_degree_ratio = np.random.rand()
52
+ rand_degree_cnt = 1
53
+ if 0.333 < rand_degree_ratio < 0.666:
54
+ rand_degree_cnt = 2
55
+ elif rand_degree_ratio > 0.666:
56
+ rand_degree_cnt = 3
57
+ for i in range(rand_degree_cnt):
58
+ dst_im = np.rot90(dst_im)
59
+ rot_degree = -90 * rand_degree_cnt
60
+ rot_angle = rot_degree * math.pi / 180.0
61
+ n_poly = text_polys.shape[0]
62
+ cx, cy = 0.5 * im_w, 0.5 * im_h
63
+ ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
64
+ for i in range(n_poly):
65
+ wordBB = text_polys[i]
66
+ poly = []
67
+ for j in range(4):
68
+ sx, sy = wordBB[j][0], wordBB[j][1]
69
+ dx = (
70
+ math.cos(rot_angle) * (sx - cx)
71
+ - math.sin(rot_angle) * (sy - cy)
72
+ + ncx
73
+ )
74
+ dy = (
75
+ math.sin(rot_angle) * (sx - cx)
76
+ + math.cos(rot_angle) * (sy - cy)
77
+ + ncy
78
+ )
79
+ poly.append([dx, dy])
80
+ dst_polys.append(poly)
81
+ dst_polys = np.array(dst_polys, dtype=np.float32)
82
+ return dst_im, dst_polys
83
+
84
+ def polygon_area(self, poly):
85
+ """
86
+ compute area of a polygon
87
+ :param poly:
88
+ :return:
89
+ """
90
+ edge = [
91
+ (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
92
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
93
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
94
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
95
+ ]
96
+ return np.sum(edge) / 2.0
97
+
98
+ def check_and_validate_polys(self, polys, tags, img_height, img_width):
99
+ """
100
+ check so that the text poly is in the same direction,
101
+ and also filter some invalid polygons
102
+ :param polys:
103
+ :param tags:
104
+ :return:
105
+ """
106
+ h, w = img_height, img_width
107
+ if polys.shape[0] == 0:
108
+ return polys
109
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
110
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
111
+
112
+ validated_polys = []
113
+ validated_tags = []
114
+ for poly, tag in zip(polys, tags):
115
+ p_area = self.polygon_area(poly)
116
+ # invalid poly
117
+ if abs(p_area) < 1:
118
+ continue
119
+ if p_area > 0:
120
+ #'poly in wrong direction'
121
+ if not tag:
122
+ tag = True # reversed cases should be ignore
123
+ poly = poly[(0, 3, 2, 1), :]
124
+ validated_polys.append(poly)
125
+ validated_tags.append(tag)
126
+ return np.array(validated_polys), np.array(validated_tags)
127
+
128
+ def draw_img_polys(self, img, polys):
129
+ if len(img.shape) == 4:
130
+ img = np.squeeze(img, axis=0)
131
+ if img.shape[0] == 3:
132
+ img = img.transpose((1, 2, 0))
133
+ img[:, :, 2] += 123.68
134
+ img[:, :, 1] += 116.78
135
+ img[:, :, 0] += 103.94
136
+ cv2.imwrite("tmp.jpg", img)
137
+ img = cv2.imread("tmp.jpg")
138
+ for box in polys:
139
+ box = box.astype(np.int32).reshape((-1, 1, 2))
140
+ cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
141
+ import random
142
+
143
+ ino = random.randint(0, 100)
144
+ cv2.imwrite("tmp_%d.jpg" % ino, img)
145
+ return
146
+
147
+ def shrink_poly(self, poly, r):
148
+ """
149
+ fit a poly inside the origin poly, maybe bugs here...
150
+ used for generate the score map
151
+ :param poly: the text poly
152
+ :param r: r in the paper
153
+ :return: the shrinked poly
154
+ """
155
+ # shrink ratio
156
+ R = 0.3
157
+ # find the longer pair
158
+ dist0 = np.linalg.norm(poly[0] - poly[1])
159
+ dist1 = np.linalg.norm(poly[2] - poly[3])
160
+ dist2 = np.linalg.norm(poly[0] - poly[3])
161
+ dist3 = np.linalg.norm(poly[1] - poly[2])
162
+ if dist0 + dist1 > dist2 + dist3:
163
+ # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
164
+ ## p0, p1
165
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
166
+ poly[0][0] += R * r[0] * np.cos(theta)
167
+ poly[0][1] += R * r[0] * np.sin(theta)
168
+ poly[1][0] -= R * r[1] * np.cos(theta)
169
+ poly[1][1] -= R * r[1] * np.sin(theta)
170
+ ## p2, p3
171
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
172
+ poly[3][0] += R * r[3] * np.cos(theta)
173
+ poly[3][1] += R * r[3] * np.sin(theta)
174
+ poly[2][0] -= R * r[2] * np.cos(theta)
175
+ poly[2][1] -= R * r[2] * np.sin(theta)
176
+ ## p0, p3
177
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
178
+ poly[0][0] += R * r[0] * np.sin(theta)
179
+ poly[0][1] += R * r[0] * np.cos(theta)
180
+ poly[3][0] -= R * r[3] * np.sin(theta)
181
+ poly[3][1] -= R * r[3] * np.cos(theta)
182
+ ## p1, p2
183
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
184
+ poly[1][0] += R * r[1] * np.sin(theta)
185
+ poly[1][1] += R * r[1] * np.cos(theta)
186
+ poly[2][0] -= R * r[2] * np.sin(theta)
187
+ poly[2][1] -= R * r[2] * np.cos(theta)
188
+ else:
189
+ ## p0, p3
190
+ # print poly
191
+ theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
192
+ poly[0][0] += R * r[0] * np.sin(theta)
193
+ poly[0][1] += R * r[0] * np.cos(theta)
194
+ poly[3][0] -= R * r[3] * np.sin(theta)
195
+ poly[3][1] -= R * r[3] * np.cos(theta)
196
+ ## p1, p2
197
+ theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
198
+ poly[1][0] += R * r[1] * np.sin(theta)
199
+ poly[1][1] += R * r[1] * np.cos(theta)
200
+ poly[2][0] -= R * r[2] * np.sin(theta)
201
+ poly[2][1] -= R * r[2] * np.cos(theta)
202
+ ## p0, p1
203
+ theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
204
+ poly[0][0] += R * r[0] * np.cos(theta)
205
+ poly[0][1] += R * r[0] * np.sin(theta)
206
+ poly[1][0] -= R * r[1] * np.cos(theta)
207
+ poly[1][1] -= R * r[1] * np.sin(theta)
208
+ ## p2, p3
209
+ theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
210
+ poly[3][0] += R * r[3] * np.cos(theta)
211
+ poly[3][1] += R * r[3] * np.sin(theta)
212
+ poly[2][0] -= R * r[2] * np.cos(theta)
213
+ poly[2][1] -= R * r[2] * np.sin(theta)
214
+ return poly
215
+
216
+ def generate_quad(self, im_size, polys, tags):
217
+ """
218
+ Generate quadrangle.
219
+ """
220
+ h, w = im_size
221
+ poly_mask = np.zeros((h, w), dtype=np.uint8)
222
+ score_map = np.zeros((h, w), dtype=np.uint8)
223
+ # (x1, y1, ..., x4, y4, short_edge_norm)
224
+ geo_map = np.zeros((h, w, 9), dtype=np.float32)
225
+ # mask used during traning, to ignore some hard areas
226
+ training_mask = np.ones((h, w), dtype=np.uint8)
227
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
228
+ poly = poly_tag[0]
229
+ tag = poly_tag[1]
230
+
231
+ r = [None, None, None, None]
232
+ for i in range(4):
233
+ dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
234
+ dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
235
+ r[i] = min(dist1, dist2)
236
+ # score map
237
+ shrinked_poly = self.shrink_poly(poly.copy(), r).astype(np.int32)[
238
+ np.newaxis, :, :
239
+ ]
240
+ cv2.fillPoly(score_map, shrinked_poly, 1)
241
+ cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
242
+ # if the poly is too small, then ignore it during training
243
+ poly_h = min(
244
+ np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2])
245
+ )
246
+ poly_w = min(
247
+ np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])
248
+ )
249
+ if min(poly_h, poly_w) < self.min_text_size:
250
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
251
+
252
+ if tag:
253
+ cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
254
+
255
+ xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
256
+ # geo map.
257
+ y_in_poly = xy_in_poly[:, 0]
258
+ x_in_poly = xy_in_poly[:, 1]
259
+ poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
260
+ poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
261
+ for pno in range(4):
262
+ geo_channel_beg = pno * 2
263
+ geo_map[y_in_poly, x_in_poly, geo_channel_beg] = (
264
+ x_in_poly - poly[pno, 0]
265
+ )
266
+ geo_map[y_in_poly, x_in_poly, geo_channel_beg + 1] = (
267
+ y_in_poly - poly[pno, 1]
268
+ )
269
+ geo_map[y_in_poly, x_in_poly, 8] = 1.0 / max(min(poly_h, poly_w), 1.0)
270
+ return score_map, geo_map, training_mask
271
+
272
+ def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
273
+ """
274
+ make random crop from the input image
275
+ :param im:
276
+ :param polys:
277
+ :param tags:
278
+ :param crop_background:
279
+ :param max_tries:
280
+ :return:
281
+ """
282
+ h, w, _ = im.shape
283
+ pad_h = h // 10
284
+ pad_w = w // 10
285
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
286
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
287
+ for poly in polys:
288
+ poly = np.round(poly, decimals=0).astype(np.int32)
289
+ minx = np.min(poly[:, 0])
290
+ maxx = np.max(poly[:, 0])
291
+ w_array[minx + pad_w : maxx + pad_w] = 1
292
+ miny = np.min(poly[:, 1])
293
+ maxy = np.max(poly[:, 1])
294
+ h_array[miny + pad_h : maxy + pad_h] = 1
295
+ # ensure the cropped area not across a text
296
+ h_axis = np.where(h_array == 0)[0]
297
+ w_axis = np.where(w_array == 0)[0]
298
+ if len(h_axis) == 0 or len(w_axis) == 0:
299
+ return im, polys, tags
300
+
301
+ for i in range(max_tries):
302
+ xx = np.random.choice(w_axis, size=2)
303
+ xmin = np.min(xx) - pad_w
304
+ xmax = np.max(xx) - pad_w
305
+ xmin = np.clip(xmin, 0, w - 1)
306
+ xmax = np.clip(xmax, 0, w - 1)
307
+ yy = np.random.choice(h_axis, size=2)
308
+ ymin = np.min(yy) - pad_h
309
+ ymax = np.max(yy) - pad_h
310
+ ymin = np.clip(ymin, 0, h - 1)
311
+ ymax = np.clip(ymax, 0, h - 1)
312
+ if (
313
+ xmax - xmin < self.min_crop_side_ratio * w
314
+ or ymax - ymin < self.min_crop_side_ratio * h
315
+ ):
316
+ # area too small
317
+ continue
318
+ if polys.shape[0] != 0:
319
+ poly_axis_in_area = (
320
+ (polys[:, :, 0] >= xmin)
321
+ & (polys[:, :, 0] <= xmax)
322
+ & (polys[:, :, 1] >= ymin)
323
+ & (polys[:, :, 1] <= ymax)
324
+ )
325
+ selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
326
+ else:
327
+ selected_polys = []
328
+
329
+ if len(selected_polys) == 0:
330
+ # no text in this area
331
+ if crop_background:
332
+ im = im[ymin : ymax + 1, xmin : xmax + 1, :]
333
+ polys = []
334
+ tags = []
335
+ return im, polys, tags
336
+ else:
337
+ continue
338
+
339
+ im = im[ymin : ymax + 1, xmin : xmax + 1, :]
340
+ polys = polys[selected_polys]
341
+ tags = tags[selected_polys]
342
+ polys[:, :, 0] -= xmin
343
+ polys[:, :, 1] -= ymin
344
+ return im, polys, tags
345
+ return im, polys, tags
346
+
347
+ def crop_background_infor(self, im, text_polys, text_tags):
348
+ im, text_polys, text_tags = self.crop_area(
349
+ im, text_polys, text_tags, crop_background=True
350
+ )
351
+
352
+ if len(text_polys) > 0:
353
+ return None
354
+ # pad and resize image
355
+ input_size = self.input_size
356
+ im, ratio = self.preprocess(im)
357
+ score_map = np.zeros((input_size, input_size), dtype=np.float32)
358
+ geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
359
+ training_mask = np.ones((input_size, input_size), dtype=np.float32)
360
+ return im, score_map, geo_map, training_mask
361
+
362
+ def crop_foreground_infor(self, im, text_polys, text_tags):
363
+ im, text_polys, text_tags = self.crop_area(
364
+ im, text_polys, text_tags, crop_background=False
365
+ )
366
+
367
+ if text_polys.shape[0] == 0:
368
+ return None
369
+ # continue for all ignore case
370
+ if np.sum((text_tags * 1.0)) >= text_tags.size:
371
+ return None
372
+ # pad and resize image
373
+ input_size = self.input_size
374
+ im, ratio = self.preprocess(im)
375
+ text_polys[:, :, 0] *= ratio
376
+ text_polys[:, :, 1] *= ratio
377
+ _, _, new_h, new_w = im.shape
378
+ # print(im.shape)
379
+ # self.draw_img_polys(im, text_polys)
380
+ score_map, geo_map, training_mask = self.generate_quad(
381
+ (new_h, new_w), text_polys, text_tags
382
+ )
383
+ return im, score_map, geo_map, training_mask
384
+
385
+ def __call__(self, data):
386
+ im = data["image"]
387
+ text_polys = data["polys"]
388
+ text_tags = data["ignore_tags"]
389
+ if im is None:
390
+ return None
391
+ if text_polys.shape[0] == 0:
392
+ return None
393
+
394
+ # add rotate cases
395
+ if np.random.rand() < 0.5:
396
+ im, text_polys = self.rotate_im_poly(im, text_polys)
397
+ h, w, _ = im.shape
398
+ text_polys, text_tags = self.check_and_validate_polys(
399
+ text_polys, text_tags, h, w
400
+ )
401
+ if text_polys.shape[0] == 0:
402
+ return None
403
+
404
+ # random scale this image
405
+ rd_scale = np.random.choice(self.random_scale)
406
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
407
+ text_polys *= rd_scale
408
+ if np.random.rand() < self.background_ratio:
409
+ outs = self.crop_background_infor(im, text_polys, text_tags)
410
+ else:
411
+ outs = self.crop_foreground_infor(im, text_polys, text_tags)
412
+
413
+ if outs is None:
414
+ return None
415
+ im, score_map, geo_map, training_mask = outs
416
+ score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
417
+ geo_map = np.swapaxes(geo_map, 1, 2)
418
+ geo_map = np.swapaxes(geo_map, 1, 0)
419
+ geo_map = geo_map[:, ::4, ::4].astype(np.float32)
420
+ training_mask = training_mask[np.newaxis, ::4, ::4]
421
+ training_mask = training_mask.astype(np.float32)
422
+
423
+ data["image"] = im[0]
424
+ data["score_map"] = score_map
425
+ data["geo_map"] = geo_map
426
+ data["training_mask"] = training_mask
427
+ return data
ocr/ppocr/data/imaug/fce_aug.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw
7
+ from shapely.geometry import Polygon
8
+
9
+ from postprocess.poly_nms import poly_intersection
10
+
11
+
12
+ class RandomScaling:
13
+ def __init__(self, size=800, scale=(3.0 / 4, 5.0 / 2), **kwargs):
14
+ """Random scale the image while keeping aspect.
15
+
16
+ Args:
17
+ size (int) : Base size before scaling.
18
+ scale (tuple(float)) : The range of scaling.
19
+ """
20
+ assert isinstance(size, int)
21
+ assert isinstance(scale, float) or isinstance(scale, tuple)
22
+ self.size = size
23
+ self.scale = scale if isinstance(scale, tuple) else (1 - scale, 1 + scale)
24
+
25
+ def __call__(self, data):
26
+ image = data["image"]
27
+ text_polys = data["polys"]
28
+ h, w, _ = image.shape
29
+
30
+ aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
31
+ scales = self.size * 1.0 / max(h, w) * aspect_ratio
32
+ scales = np.array([scales, scales])
33
+ out_size = (int(h * scales[1]), int(w * scales[0]))
34
+ image = cv2.resize(image, out_size[::-1])
35
+
36
+ data["image"] = image
37
+ text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
38
+ text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
39
+ data["polys"] = text_polys
40
+
41
+ return data
42
+
43
+
44
+ class RandomCropFlip:
45
+ def __init__(
46
+ self, pad_ratio=0.1, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2, **kwargs
47
+ ):
48
+ """Random crop and flip a patch of the image.
49
+
50
+ Args:
51
+ crop_ratio (float): The ratio of cropping.
52
+ iter_num (int): Number of operations.
53
+ min_area_ratio (float): Minimal area ratio between cropped patch
54
+ and original image.
55
+ """
56
+ assert isinstance(crop_ratio, float)
57
+ assert isinstance(iter_num, int)
58
+ assert isinstance(min_area_ratio, float)
59
+
60
+ self.pad_ratio = pad_ratio
61
+ self.epsilon = 1e-2
62
+ self.crop_ratio = crop_ratio
63
+ self.iter_num = iter_num
64
+ self.min_area_ratio = min_area_ratio
65
+
66
+ def __call__(self, results):
67
+ for i in range(self.iter_num):
68
+ results = self.random_crop_flip(results)
69
+
70
+ return results
71
+
72
+ def random_crop_flip(self, results):
73
+ image = results["image"]
74
+ polygons = results["polys"]
75
+ ignore_tags = results["ignore_tags"]
76
+ if len(polygons) == 0:
77
+ return results
78
+
79
+ if np.random.random() >= self.crop_ratio:
80
+ return results
81
+
82
+ h, w, _ = image.shape
83
+ area = h * w
84
+ pad_h = int(h * self.pad_ratio)
85
+ pad_w = int(w * self.pad_ratio)
86
+ h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h, pad_w)
87
+ if len(h_axis) == 0 or len(w_axis) == 0:
88
+ return results
89
+
90
+ attempt = 0
91
+ while attempt < 50:
92
+ attempt += 1
93
+ polys_keep = []
94
+ polys_new = []
95
+ ignore_tags_keep = []
96
+ ignore_tags_new = []
97
+ xx = np.random.choice(w_axis, size=2)
98
+ xmin = np.min(xx) - pad_w
99
+ xmax = np.max(xx) - pad_w
100
+ xmin = np.clip(xmin, 0, w - 1)
101
+ xmax = np.clip(xmax, 0, w - 1)
102
+ yy = np.random.choice(h_axis, size=2)
103
+ ymin = np.min(yy) - pad_h
104
+ ymax = np.max(yy) - pad_h
105
+ ymin = np.clip(ymin, 0, h - 1)
106
+ ymax = np.clip(ymax, 0, h - 1)
107
+ if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
108
+ # area too small
109
+ continue
110
+
111
+ pts = np.stack(
112
+ [[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]
113
+ ).T.astype(np.int32)
114
+ pp = Polygon(pts)
115
+ fail_flag = False
116
+ for polygon, ignore_tag in zip(polygons, ignore_tags):
117
+ ppi = Polygon(polygon.reshape(-1, 2))
118
+ ppiou, _ = poly_intersection(ppi, pp, buffer=0)
119
+ if (
120
+ np.abs(ppiou - float(ppi.area)) > self.epsilon
121
+ and np.abs(ppiou) > self.epsilon
122
+ ):
123
+ fail_flag = True
124
+ break
125
+ elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
126
+ polys_new.append(polygon)
127
+ ignore_tags_new.append(ignore_tag)
128
+ else:
129
+ polys_keep.append(polygon)
130
+ ignore_tags_keep.append(ignore_tag)
131
+
132
+ if fail_flag:
133
+ continue
134
+ else:
135
+ break
136
+
137
+ cropped = image[ymin:ymax, xmin:xmax, :]
138
+ select_type = np.random.randint(3)
139
+ if select_type == 0:
140
+ img = np.ascontiguousarray(cropped[:, ::-1])
141
+ elif select_type == 1:
142
+ img = np.ascontiguousarray(cropped[::-1, :])
143
+ else:
144
+ img = np.ascontiguousarray(cropped[::-1, ::-1])
145
+ image[ymin:ymax, xmin:xmax, :] = img
146
+ results["img"] = image
147
+
148
+ if len(polys_new) != 0:
149
+ height, width, _ = cropped.shape
150
+ if select_type == 0:
151
+ for idx, polygon in enumerate(polys_new):
152
+ poly = polygon.reshape(-1, 2)
153
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
154
+ polys_new[idx] = poly
155
+ elif select_type == 1:
156
+ for idx, polygon in enumerate(polys_new):
157
+ poly = polygon.reshape(-1, 2)
158
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
159
+ polys_new[idx] = poly
160
+ else:
161
+ for idx, polygon in enumerate(polys_new):
162
+ poly = polygon.reshape(-1, 2)
163
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
164
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
165
+ polys_new[idx] = poly
166
+ polygons = polys_keep + polys_new
167
+ ignore_tags = ignore_tags_keep + ignore_tags_new
168
+ results["polys"] = np.array(polygons)
169
+ results["ignore_tags"] = ignore_tags
170
+
171
+ return results
172
+
173
+ def generate_crop_target(self, image, all_polys, pad_h, pad_w):
174
+ """Generate crop target and make sure not to crop the polygon
175
+ instances.
176
+
177
+ Args:
178
+ image (ndarray): The image waited to be crop.
179
+ all_polys (list[list[ndarray]]): All polygons including ground
180
+ truth polygons and ground truth ignored polygons.
181
+ pad_h (int): Padding length of height.
182
+ pad_w (int): Padding length of width.
183
+ Returns:
184
+ h_axis (ndarray): Vertical cropping range.
185
+ w_axis (ndarray): Horizontal cropping range.
186
+ """
187
+ h, w, _ = image.shape
188
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
189
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
190
+
191
+ text_polys = []
192
+ for polygon in all_polys:
193
+ rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
194
+ box = cv2.boxPoints(rect)
195
+ box = np.int0(box)
196
+ text_polys.append([box[0], box[1], box[2], box[3]])
197
+
198
+ polys = np.array(text_polys, dtype=np.int32)
199
+ for poly in polys:
200
+ poly = np.round(poly, decimals=0).astype(np.int32)
201
+ minx = np.min(poly[:, 0])
202
+ maxx = np.max(poly[:, 0])
203
+ w_array[minx + pad_w : maxx + pad_w] = 1
204
+ miny = np.min(poly[:, 1])
205
+ maxy = np.max(poly[:, 1])
206
+ h_array[miny + pad_h : maxy + pad_h] = 1
207
+
208
+ h_axis = np.where(h_array == 0)[0]
209
+ w_axis = np.where(w_array == 0)[0]
210
+ return h_axis, w_axis
211
+
212
+
213
+ class RandomCropPolyInstances:
214
+ """Randomly crop images and make sure to contain at least one intact
215
+ instance."""
216
+
217
+ def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
218
+ super().__init__()
219
+ self.crop_ratio = crop_ratio
220
+ self.min_side_ratio = min_side_ratio
221
+
222
+ def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
223
+
224
+ assert isinstance(min_len, int)
225
+ assert len(valid_array) > min_len
226
+
227
+ start_array = valid_array.copy()
228
+ max_start = min(len(start_array) - min_len, max_start)
229
+ start_array[max_start:] = 0
230
+ start_array[0] = 1
231
+ diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
232
+ region_starts = np.where(diff_array < 0)[0]
233
+ region_ends = np.where(diff_array > 0)[0]
234
+ region_ind = np.random.randint(0, len(region_starts))
235
+ start = np.random.randint(region_starts[region_ind], region_ends[region_ind])
236
+
237
+ end_array = valid_array.copy()
238
+ min_end = max(start + min_len, min_end)
239
+ end_array[:min_end] = 0
240
+ end_array[-1] = 1
241
+ diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
242
+ region_starts = np.where(diff_array < 0)[0]
243
+ region_ends = np.where(diff_array > 0)[0]
244
+ region_ind = np.random.randint(0, len(region_starts))
245
+ end = np.random.randint(region_starts[region_ind], region_ends[region_ind])
246
+ return start, end
247
+
248
+ def sample_crop_box(self, img_size, results):
249
+ """Generate crop box and make sure not to crop the polygon instances.
250
+
251
+ Args:
252
+ img_size (tuple(int)): The image size (h, w).
253
+ results (dict): The results dict.
254
+ """
255
+
256
+ assert isinstance(img_size, tuple)
257
+ h, w = img_size[:2]
258
+
259
+ key_masks = results["polys"]
260
+
261
+ x_valid_array = np.ones(w, dtype=np.int32)
262
+ y_valid_array = np.ones(h, dtype=np.int32)
263
+
264
+ selected_mask = key_masks[np.random.randint(0, len(key_masks))]
265
+ selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
266
+ max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
267
+ min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
268
+ max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
269
+ min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
270
+
271
+ for mask in key_masks:
272
+ mask = mask.reshape((-1, 2)).astype(np.int32)
273
+ clip_x = np.clip(mask[:, 0], 0, w - 1)
274
+ clip_y = np.clip(mask[:, 1], 0, h - 1)
275
+ min_x, max_x = np.min(clip_x), np.max(clip_x)
276
+ min_y, max_y = np.min(clip_y), np.max(clip_y)
277
+
278
+ x_valid_array[min_x - 2 : max_x + 3] = 0
279
+ y_valid_array[min_y - 2 : max_y + 3] = 0
280
+
281
+ min_w = int(w * self.min_side_ratio)
282
+ min_h = int(h * self.min_side_ratio)
283
+
284
+ x1, x2 = self.sample_valid_start_end(
285
+ x_valid_array, min_w, max_x_start, min_x_end
286
+ )
287
+ y1, y2 = self.sample_valid_start_end(
288
+ y_valid_array, min_h, max_y_start, min_y_end
289
+ )
290
+
291
+ return np.array([x1, y1, x2, y2])
292
+
293
+ def crop_img(self, img, bbox):
294
+ assert img.ndim == 3
295
+ h, w, _ = img.shape
296
+ assert 0 <= bbox[1] < bbox[3] <= h
297
+ assert 0 <= bbox[0] < bbox[2] <= w
298
+ return img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
299
+
300
+ def __call__(self, results):
301
+ image = results["image"]
302
+ polygons = results["polys"]
303
+ ignore_tags = results["ignore_tags"]
304
+ if len(polygons) < 1:
305
+ return results
306
+
307
+ if np.random.random_sample() < self.crop_ratio:
308
+
309
+ crop_box = self.sample_crop_box(image.shape, results)
310
+ img = self.crop_img(image, crop_box)
311
+ results["image"] = img
312
+ # crop and filter masks
313
+ x1, y1, x2, y2 = crop_box
314
+ w = max(x2 - x1, 1)
315
+ h = max(y2 - y1, 1)
316
+ polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
317
+ polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1
318
+
319
+ valid_masks_list = []
320
+ valid_tags_list = []
321
+ for ind, polygon in enumerate(polygons):
322
+ if (
323
+ (polygon[:, ::2] > -4).all()
324
+ and (polygon[:, ::2] < w + 4).all()
325
+ and (polygon[:, 1::2] > -4).all()
326
+ and (polygon[:, 1::2] < h + 4).all()
327
+ ):
328
+ polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
329
+ polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
330
+ valid_masks_list.append(polygon)
331
+ valid_tags_list.append(ignore_tags[ind])
332
+
333
+ results["polys"] = np.array(valid_masks_list)
334
+ results["ignore_tags"] = valid_tags_list
335
+
336
+ return results
337
+
338
+ def __repr__(self):
339
+ repr_str = self.__class__.__name__
340
+ return repr_str
341
+
342
+
343
+ class RandomRotatePolyInstances:
344
+ def __init__(
345
+ self,
346
+ rotate_ratio=0.5,
347
+ max_angle=10,
348
+ pad_with_fixed_color=False,
349
+ pad_value=(0, 0, 0),
350
+ **kwargs
351
+ ):
352
+ """Randomly rotate images and polygon masks.
353
+
354
+ Args:
355
+ rotate_ratio (float): The ratio of samples to operate rotation.
356
+ max_angle (int): The maximum rotation angle.
357
+ pad_with_fixed_color (bool): The flag for whether to pad rotated
358
+ image with fixed value. If set to False, the rotated image will
359
+ be padded onto cropped image.
360
+ pad_value (tuple(int)): The color value for padding rotated image.
361
+ """
362
+ self.rotate_ratio = rotate_ratio
363
+ self.max_angle = max_angle
364
+ self.pad_with_fixed_color = pad_with_fixed_color
365
+ self.pad_value = pad_value
366
+
367
+ def rotate(self, center, points, theta, center_shift=(0, 0)):
368
+ # rotate points.
369
+ (center_x, center_y) = center
370
+ center_y = -center_y
371
+ x, y = points[:, ::2], points[:, 1::2]
372
+ y = -y
373
+
374
+ theta = theta / 180 * math.pi
375
+ cos = math.cos(theta)
376
+ sin = math.sin(theta)
377
+
378
+ x = x - center_x
379
+ y = y - center_y
380
+
381
+ _x = center_x + x * cos - y * sin + center_shift[0]
382
+ _y = -(center_y + x * sin + y * cos) + center_shift[1]
383
+
384
+ points[:, ::2], points[:, 1::2] = _x, _y
385
+ return points
386
+
387
+ def cal_canvas_size(self, ori_size, degree):
388
+ assert isinstance(ori_size, tuple)
389
+ angle = degree * math.pi / 180.0
390
+ h, w = ori_size[:2]
391
+
392
+ cos = math.cos(angle)
393
+ sin = math.sin(angle)
394
+ canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
395
+ canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
396
+
397
+ canvas_size = (canvas_h, canvas_w)
398
+ return canvas_size
399
+
400
+ def sample_angle(self, max_angle):
401
+ angle = np.random.random_sample() * 2 * max_angle - max_angle
402
+ return angle
403
+
404
+ def rotate_img(self, img, angle, canvas_size):
405
+ h, w = img.shape[:2]
406
+ rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
407
+ rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
408
+ rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
409
+
410
+ if self.pad_with_fixed_color:
411
+ target_img = cv2.warpAffine(
412
+ img,
413
+ rotation_matrix,
414
+ (canvas_size[1], canvas_size[0]),
415
+ flags=cv2.INTER_NEAREST,
416
+ borderValue=self.pad_value,
417
+ )
418
+ else:
419
+ mask = np.zeros_like(img)
420
+ (h_ind, w_ind) = (
421
+ np.random.randint(0, h * 7 // 8),
422
+ np.random.randint(0, w * 7 // 8),
423
+ )
424
+ img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
425
+ img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
426
+
427
+ mask = cv2.warpAffine(
428
+ mask,
429
+ rotation_matrix,
430
+ (canvas_size[1], canvas_size[0]),
431
+ borderValue=[1, 1, 1],
432
+ )
433
+ target_img = cv2.warpAffine(
434
+ img,
435
+ rotation_matrix,
436
+ (canvas_size[1], canvas_size[0]),
437
+ borderValue=[0, 0, 0],
438
+ )
439
+ target_img = target_img + img_cut * mask
440
+
441
+ return target_img
442
+
443
+ def __call__(self, results):
444
+ if np.random.random_sample() < self.rotate_ratio:
445
+ image = results["image"]
446
+ polygons = results["polys"]
447
+ h, w = image.shape[:2]
448
+
449
+ angle = self.sample_angle(self.max_angle)
450
+ canvas_size = self.cal_canvas_size((h, w), angle)
451
+ center_shift = (
452
+ int((canvas_size[1] - w) / 2),
453
+ int((canvas_size[0] - h) / 2),
454
+ )
455
+ image = self.rotate_img(image, angle, canvas_size)
456
+ results["image"] = image
457
+ # rotate polygons
458
+ rotated_masks = []
459
+ for mask in polygons:
460
+ rotated_mask = self.rotate((w / 2, h / 2), mask, angle, center_shift)
461
+ rotated_masks.append(rotated_mask)
462
+ results["polys"] = np.array(rotated_masks)
463
+
464
+ return results
465
+
466
+ def __repr__(self):
467
+ repr_str = self.__class__.__name__
468
+ return repr_str
469
+
470
+
471
+ class SquareResizePad:
472
+ def __init__(
473
+ self,
474
+ target_size,
475
+ pad_ratio=0.6,
476
+ pad_with_fixed_color=False,
477
+ pad_value=(0, 0, 0),
478
+ **kwargs
479
+ ):
480
+ """Resize or pad images to be square shape.
481
+
482
+ Args:
483
+ target_size (int): The target size of square shaped image.
484
+ pad_with_fixed_color (bool): The flag for whether to pad rotated
485
+ image with fixed value. If set to False, the rescales image will
486
+ be padded onto cropped image.
487
+ pad_value (tuple(int)): The color value for padding rotated image.
488
+ """
489
+ assert isinstance(target_size, int)
490
+ assert isinstance(pad_ratio, float)
491
+ assert isinstance(pad_with_fixed_color, bool)
492
+ assert isinstance(pad_value, tuple)
493
+
494
+ self.target_size = target_size
495
+ self.pad_ratio = pad_ratio
496
+ self.pad_with_fixed_color = pad_with_fixed_color
497
+ self.pad_value = pad_value
498
+
499
+ def resize_img(self, img, keep_ratio=True):
500
+ h, w, _ = img.shape
501
+ if keep_ratio:
502
+ t_h = self.target_size if h >= w else int(h * self.target_size / w)
503
+ t_w = self.target_size if h <= w else int(w * self.target_size / h)
504
+ else:
505
+ t_h = t_w = self.target_size
506
+ img = cv2.resize(img, (t_w, t_h))
507
+ return img, (t_h, t_w)
508
+
509
+ def square_pad(self, img):
510
+ h, w = img.shape[:2]
511
+ if h == w:
512
+ return img, (0, 0)
513
+ pad_size = max(h, w)
514
+ if self.pad_with_fixed_color:
515
+ expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
516
+ expand_img[:] = self.pad_value
517
+ else:
518
+ (h_ind, w_ind) = (
519
+ np.random.randint(0, h * 7 // 8),
520
+ np.random.randint(0, w * 7 // 8),
521
+ )
522
+ img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)]
523
+ expand_img = cv2.resize(img_cut, (pad_size, pad_size))
524
+ if h > w:
525
+ y0, x0 = 0, (h - w) // 2
526
+ else:
527
+ y0, x0 = (w - h) // 2, 0
528
+ expand_img[y0 : y0 + h, x0 : x0 + w] = img
529
+ offset = (x0, y0)
530
+
531
+ return expand_img, offset
532
+
533
+ def square_pad_mask(self, points, offset):
534
+ x0, y0 = offset
535
+ pad_points = points.copy()
536
+ pad_points[::2] = pad_points[::2] + x0
537
+ pad_points[1::2] = pad_points[1::2] + y0
538
+ return pad_points
539
+
540
+ def __call__(self, results):
541
+ image = results["image"]
542
+ polygons = results["polys"]
543
+ h, w = image.shape[:2]
544
+
545
+ if np.random.random_sample() < self.pad_ratio:
546
+ image, out_size = self.resize_img(image, keep_ratio=True)
547
+ image, offset = self.square_pad(image)
548
+ else:
549
+ image, out_size = self.resize_img(image, keep_ratio=False)
550
+ offset = (0, 0)
551
+ results["image"] = image
552
+ try:
553
+ polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[0]
554
+ polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[1]
555
+ except:
556
+ pass
557
+ results["polys"] = polygons
558
+
559
+ return results
560
+
561
+ def __repr__(self):
562
+ repr_str = self.__class__.__name__
563
+ return repr_str
ocr/ppocr/data/imaug/fce_targets.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from numpy.fft import fft
4
+ from numpy.linalg import norm
5
+
6
+
7
+ def vector_slope(vec):
8
+ assert len(vec) == 2
9
+ return abs(vec[1] / (vec[0] + 1e-8))
10
+
11
+
12
+ class FCENetTargets:
13
+ """Generate the ground truth targets of FCENet: Fourier Contour Embedding
14
+ for Arbitrary-Shaped Text Detection.
15
+
16
+ [https://arxiv.org/abs/2104.10442]
17
+
18
+ Args:
19
+ fourier_degree (int): The maximum Fourier transform degree k.
20
+ resample_step (float): The step size for resampling the text center
21
+ line (TCL). It's better not to exceed half of the minimum width.
22
+ center_region_shrink_ratio (float): The shrink ratio of text center
23
+ region.
24
+ level_size_divisors (tuple(int)): The downsample ratio on each level.
25
+ level_proportion_range (tuple(tuple(int))): The range of text sizes
26
+ assigned to each level.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ fourier_degree=5,
32
+ resample_step=4.0,
33
+ center_region_shrink_ratio=0.3,
34
+ level_size_divisors=(8, 16, 32),
35
+ level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
36
+ orientation_thr=2.0,
37
+ **kwargs
38
+ ):
39
+
40
+ super().__init__()
41
+ assert isinstance(level_size_divisors, tuple)
42
+ assert isinstance(level_proportion_range, tuple)
43
+ assert len(level_size_divisors) == len(level_proportion_range)
44
+ self.fourier_degree = fourier_degree
45
+ self.resample_step = resample_step
46
+ self.center_region_shrink_ratio = center_region_shrink_ratio
47
+ self.level_size_divisors = level_size_divisors
48
+ self.level_proportion_range = level_proportion_range
49
+
50
+ self.orientation_thr = orientation_thr
51
+
52
+ def vector_angle(self, vec1, vec2):
53
+ if vec1.ndim > 1:
54
+ unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
55
+ else:
56
+ unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
57
+ if vec2.ndim > 1:
58
+ unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
59
+ else:
60
+ unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
61
+ return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
62
+
63
+ def resample_line(self, line, n):
64
+ """Resample n points on a line.
65
+
66
+ Args:
67
+ line (ndarray): The points composing a line.
68
+ n (int): The resampled points number.
69
+
70
+ Returns:
71
+ resampled_line (ndarray): The points composing the resampled line.
72
+ """
73
+
74
+ assert line.ndim == 2
75
+ assert line.shape[0] >= 2
76
+ assert line.shape[1] == 2
77
+ assert isinstance(n, int)
78
+ assert n > 0
79
+
80
+ length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
81
+ total_length = sum(length_list)
82
+ length_cumsum = np.cumsum([0.0] + length_list)
83
+ delta_length = total_length / (float(n) + 1e-8)
84
+
85
+ current_edge_ind = 0
86
+ resampled_line = [line[0]]
87
+
88
+ for i in range(1, n):
89
+ current_line_len = i * delta_length
90
+
91
+ while current_line_len >= length_cumsum[current_edge_ind + 1]:
92
+ current_edge_ind += 1
93
+ current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
94
+ end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
95
+ current_point = (
96
+ line[current_edge_ind]
97
+ + (line[current_edge_ind + 1] - line[current_edge_ind])
98
+ * end_shift_ratio
99
+ )
100
+ resampled_line.append(current_point)
101
+
102
+ resampled_line.append(line[-1])
103
+ resampled_line = np.array(resampled_line)
104
+
105
+ return resampled_line
106
+
107
+ def reorder_poly_edge(self, points):
108
+ """Get the respective points composing head edge, tail edge, top
109
+ sideline and bottom sideline.
110
+
111
+ Args:
112
+ points (ndarray): The points composing a text polygon.
113
+
114
+ Returns:
115
+ head_edge (ndarray): The two points composing the head edge of text
116
+ polygon.
117
+ tail_edge (ndarray): The two points composing the tail edge of text
118
+ polygon.
119
+ top_sideline (ndarray): The points composing top curved sideline of
120
+ text polygon.
121
+ bot_sideline (ndarray): The points composing bottom curved sideline
122
+ of text polygon.
123
+ """
124
+
125
+ assert points.ndim == 2
126
+ assert points.shape[0] >= 4
127
+ assert points.shape[1] == 2
128
+
129
+ head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
130
+ head_edge, tail_edge = points[head_inds], points[tail_inds]
131
+
132
+ pad_points = np.vstack([points, points])
133
+ if tail_inds[1] < 1:
134
+ tail_inds[1] = len(points)
135
+ sideline1 = pad_points[head_inds[1] : tail_inds[1]]
136
+ sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
137
+ sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0)
138
+
139
+ if sideline_mean_shift[1] > 0:
140
+ top_sideline, bot_sideline = sideline2, sideline1
141
+ else:
142
+ top_sideline, bot_sideline = sideline1, sideline2
143
+
144
+ return head_edge, tail_edge, top_sideline, bot_sideline
145
+
146
+ def find_head_tail(self, points, orientation_thr):
147
+ """Find the head edge and tail edge of a text polygon.
148
+
149
+ Args:
150
+ points (ndarray): The points composing a text polygon.
151
+ orientation_thr (float): The threshold for distinguishing between
152
+ head edge and tail edge among the horizontal and vertical edges
153
+ of a quadrangle.
154
+
155
+ Returns:
156
+ head_inds (list): The indexes of two points composing head edge.
157
+ tail_inds (list): The indexes of two points composing tail edge.
158
+ """
159
+
160
+ assert points.ndim == 2
161
+ assert points.shape[0] >= 4
162
+ assert points.shape[1] == 2
163
+ assert isinstance(orientation_thr, float)
164
+
165
+ if len(points) > 4:
166
+ pad_points = np.vstack([points, points[0]])
167
+ edge_vec = pad_points[1:] - pad_points[:-1]
168
+
169
+ theta_sum = []
170
+ adjacent_vec_theta = []
171
+ for i, edge_vec1 in enumerate(edge_vec):
172
+ adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
173
+ adjacent_edge_vec = edge_vec[adjacent_ind]
174
+ temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
175
+ temp_adjacent_theta = self.vector_angle(
176
+ adjacent_edge_vec[0], adjacent_edge_vec[1]
177
+ )
178
+ theta_sum.append(temp_theta_sum)
179
+ adjacent_vec_theta.append(temp_adjacent_theta)
180
+ theta_sum_score = np.array(theta_sum) / np.pi
181
+ adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
182
+ poly_center = np.mean(points, axis=0)
183
+ edge_dist = np.maximum(
184
+ norm(pad_points[1:] - poly_center, axis=-1),
185
+ norm(pad_points[:-1] - poly_center, axis=-1),
186
+ )
187
+ dist_score = edge_dist / np.max(edge_dist)
188
+ position_score = np.zeros(len(edge_vec))
189
+ score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
190
+ score += 0.35 * dist_score
191
+ if len(points) % 2 == 0:
192
+ position_score[(len(score) // 2 - 1)] += 1
193
+ position_score[-1] += 1
194
+ score += 0.1 * position_score
195
+ pad_score = np.concatenate([score, score])
196
+ score_matrix = np.zeros((len(score), len(score) - 3))
197
+ x = np.arange(len(score) - 3) / float(len(score) - 4)
198
+ gaussian = (
199
+ 1.0
200
+ / (np.sqrt(2.0 * np.pi) * 0.5)
201
+ * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
202
+ )
203
+ gaussian = gaussian / np.max(gaussian)
204
+ for i in range(len(score)):
205
+ score_matrix[i, :] = (
206
+ score[i]
207
+ + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
208
+ )
209
+
210
+ head_start, tail_increment = np.unravel_index(
211
+ score_matrix.argmax(), score_matrix.shape
212
+ )
213
+ tail_start = (head_start + tail_increment + 2) % len(points)
214
+ head_end = (head_start + 1) % len(points)
215
+ tail_end = (tail_start + 1) % len(points)
216
+
217
+ if head_end > tail_end:
218
+ head_start, tail_start = tail_start, head_start
219
+ head_end, tail_end = tail_end, head_end
220
+ head_inds = [head_start, head_end]
221
+ tail_inds = [tail_start, tail_end]
222
+ else:
223
+ if vector_slope(points[1] - points[0]) + vector_slope(
224
+ points[3] - points[2]
225
+ ) < vector_slope(points[2] - points[1]) + vector_slope(
226
+ points[0] - points[3]
227
+ ):
228
+ horizontal_edge_inds = [[0, 1], [2, 3]]
229
+ vertical_edge_inds = [[3, 0], [1, 2]]
230
+ else:
231
+ horizontal_edge_inds = [[3, 0], [1, 2]]
232
+ vertical_edge_inds = [[0, 1], [2, 3]]
233
+
234
+ vertical_len_sum = norm(
235
+ points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
236
+ ) + norm(
237
+ points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
238
+ )
239
+ horizontal_len_sum = norm(
240
+ points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
241
+ ) + norm(
242
+ points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
243
+ )
244
+
245
+ if vertical_len_sum > horizontal_len_sum * orientation_thr:
246
+ head_inds = horizontal_edge_inds[0]
247
+ tail_inds = horizontal_edge_inds[1]
248
+ else:
249
+ head_inds = vertical_edge_inds[0]
250
+ tail_inds = vertical_edge_inds[1]
251
+
252
+ return head_inds, tail_inds
253
+
254
+ def resample_sidelines(self, sideline1, sideline2, resample_step):
255
+ """Resample two sidelines to be of the same points number according to
256
+ step size.
257
+
258
+ Args:
259
+ sideline1 (ndarray): The points composing a sideline of a text
260
+ polygon.
261
+ sideline2 (ndarray): The points composing another sideline of a
262
+ text polygon.
263
+ resample_step (float): The resampled step size.
264
+
265
+ Returns:
266
+ resampled_line1 (ndarray): The resampled line 1.
267
+ resampled_line2 (ndarray): The resampled line 2.
268
+ """
269
+
270
+ assert sideline1.ndim == sideline2.ndim == 2
271
+ assert sideline1.shape[1] == sideline2.shape[1] == 2
272
+ assert sideline1.shape[0] >= 2
273
+ assert sideline2.shape[0] >= 2
274
+ assert isinstance(resample_step, float)
275
+
276
+ length1 = sum(
277
+ [norm(sideline1[i + 1] - sideline1[i]) for i in range(len(sideline1) - 1)]
278
+ )
279
+ length2 = sum(
280
+ [norm(sideline2[i + 1] - sideline2[i]) for i in range(len(sideline2) - 1)]
281
+ )
282
+
283
+ total_length = (length1 + length2) / 2
284
+ resample_point_num = max(int(float(total_length) / resample_step), 1)
285
+
286
+ resampled_line1 = self.resample_line(sideline1, resample_point_num)
287
+ resampled_line2 = self.resample_line(sideline2, resample_point_num)
288
+
289
+ return resampled_line1, resampled_line2
290
+
291
+ def generate_center_region_mask(self, img_size, text_polys):
292
+ """Generate text center region mask.
293
+
294
+ Args:
295
+ img_size (tuple): The image size of (height, width).
296
+ text_polys (list[list[ndarray]]): The list of text polygons.
297
+
298
+ Returns:
299
+ center_region_mask (ndarray): The text center region mask.
300
+ """
301
+
302
+ assert isinstance(img_size, tuple)
303
+ # assert check_argument.is_2dlist(text_polys)
304
+
305
+ h, w = img_size
306
+
307
+ center_region_mask = np.zeros((h, w), np.uint8)
308
+
309
+ center_region_boxes = []
310
+ for poly in text_polys:
311
+ # assert len(poly) == 1
312
+ polygon_points = poly.reshape(-1, 2)
313
+ _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
314
+ resampled_top_line, resampled_bot_line = self.resample_sidelines(
315
+ top_line, bot_line, self.resample_step
316
+ )
317
+ resampled_bot_line = resampled_bot_line[::-1]
318
+ center_line = (resampled_top_line + resampled_bot_line) / 2
319
+
320
+ line_head_shrink_len = (
321
+ norm(resampled_top_line[0] - resampled_bot_line[0]) / 4.0
322
+ )
323
+ line_tail_shrink_len = (
324
+ norm(resampled_top_line[-1] - resampled_bot_line[-1]) / 4.0
325
+ )
326
+ head_shrink_num = int(line_head_shrink_len // self.resample_step)
327
+ tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
328
+ if len(center_line) > head_shrink_num + tail_shrink_num + 2:
329
+ center_line = center_line[
330
+ head_shrink_num : len(center_line) - tail_shrink_num
331
+ ]
332
+ resampled_top_line = resampled_top_line[
333
+ head_shrink_num : len(resampled_top_line) - tail_shrink_num
334
+ ]
335
+ resampled_bot_line = resampled_bot_line[
336
+ head_shrink_num : len(resampled_bot_line) - tail_shrink_num
337
+ ]
338
+
339
+ for i in range(0, len(center_line) - 1):
340
+ tl = (
341
+ center_line[i]
342
+ + (resampled_top_line[i] - center_line[i])
343
+ * self.center_region_shrink_ratio
344
+ )
345
+ tr = (
346
+ center_line[i + 1]
347
+ + (resampled_top_line[i + 1] - center_line[i + 1])
348
+ * self.center_region_shrink_ratio
349
+ )
350
+ br = (
351
+ center_line[i + 1]
352
+ + (resampled_bot_line[i + 1] - center_line[i + 1])
353
+ * self.center_region_shrink_ratio
354
+ )
355
+ bl = (
356
+ center_line[i]
357
+ + (resampled_bot_line[i] - center_line[i])
358
+ * self.center_region_shrink_ratio
359
+ )
360
+ current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
361
+ center_region_boxes.append(current_center_box)
362
+
363
+ cv2.fillPoly(center_region_mask, center_region_boxes, 1)
364
+ return center_region_mask
365
+
366
+ def resample_polygon(self, polygon, n=400):
367
+ """Resample one polygon with n points on its boundary.
368
+
369
+ Args:
370
+ polygon (list[float]): The input polygon.
371
+ n (int): The number of resampled points.
372
+ Returns:
373
+ resampled_polygon (list[float]): The resampled polygon.
374
+ """
375
+ length = []
376
+
377
+ for i in range(len(polygon)):
378
+ p1 = polygon[i]
379
+ if i == len(polygon) - 1:
380
+ p2 = polygon[0]
381
+ else:
382
+ p2 = polygon[i + 1]
383
+ length.append(((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5)
384
+
385
+ total_length = sum(length)
386
+ n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
387
+ n_on_each_line = n_on_each_line.astype(np.int32)
388
+ new_polygon = []
389
+
390
+ for i in range(len(polygon)):
391
+ num = n_on_each_line[i]
392
+ p1 = polygon[i]
393
+ if i == len(polygon) - 1:
394
+ p2 = polygon[0]
395
+ else:
396
+ p2 = polygon[i + 1]
397
+
398
+ if num == 0:
399
+ continue
400
+
401
+ dxdy = (p2 - p1) / num
402
+ for j in range(num):
403
+ point = p1 + dxdy * j
404
+ new_polygon.append(point)
405
+
406
+ return np.array(new_polygon)
407
+
408
+ def normalize_polygon(self, polygon):
409
+ """Normalize one polygon so that its start point is at right most.
410
+
411
+ Args:
412
+ polygon (list[float]): The origin polygon.
413
+ Returns:
414
+ new_polygon (lost[float]): The polygon with start point at right.
415
+ """
416
+ temp_polygon = polygon - polygon.mean(axis=0)
417
+ x = np.abs(temp_polygon[:, 0])
418
+ y = temp_polygon[:, 1]
419
+ index_x = np.argsort(x)
420
+ index_y = np.argmin(y[index_x[:8]])
421
+ index = index_x[index_y]
422
+ new_polygon = np.concatenate([polygon[index:], polygon[:index]])
423
+ return new_polygon
424
+
425
+ def poly2fourier(self, polygon, fourier_degree):
426
+ """Perform Fourier transformation to generate Fourier coefficients ck
427
+ from polygon.
428
+
429
+ Args:
430
+ polygon (ndarray): An input polygon.
431
+ fourier_degree (int): The maximum Fourier degree K.
432
+ Returns:
433
+ c (ndarray(complex)): Fourier coefficients.
434
+ """
435
+ points = polygon[:, 0] + polygon[:, 1] * 1j
436
+ c_fft = fft(points) / len(points)
437
+ c = np.hstack((c_fft[-fourier_degree:], c_fft[: fourier_degree + 1]))
438
+ return c
439
+
440
+ def clockwise(self, c, fourier_degree):
441
+ """Make sure the polygon reconstructed from Fourier coefficients c in
442
+ the clockwise direction.
443
+
444
+ Args:
445
+ polygon (list[float]): The origin polygon.
446
+ Returns:
447
+ new_polygon (lost[float]): The polygon in clockwise point order.
448
+ """
449
+ if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
450
+ return c
451
+ elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
452
+ return c[::-1]
453
+ else:
454
+ if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
455
+ return c
456
+ else:
457
+ return c[::-1]
458
+
459
+ def cal_fourier_signature(self, polygon, fourier_degree):
460
+ """Calculate Fourier signature from input polygon.
461
+
462
+ Args:
463
+ polygon (ndarray): The input polygon.
464
+ fourier_degree (int): The maximum Fourier degree K.
465
+ Returns:
466
+ fourier_signature (ndarray): An array shaped (2k+1, 2) containing
467
+ real part and image part of 2k+1 Fourier coefficients.
468
+ """
469
+ resampled_polygon = self.resample_polygon(polygon)
470
+ resampled_polygon = self.normalize_polygon(resampled_polygon)
471
+
472
+ fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
473
+ fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
474
+
475
+ real_part = np.real(fourier_coeff).reshape((-1, 1))
476
+ image_part = np.imag(fourier_coeff).reshape((-1, 1))
477
+ fourier_signature = np.hstack([real_part, image_part])
478
+
479
+ return fourier_signature
480
+
481
+ def generate_fourier_maps(self, img_size, text_polys):
482
+ """Generate Fourier coefficient maps.
483
+
484
+ Args:
485
+ img_size (tuple): The image size of (height, width).
486
+ text_polys (list[list[ndarray]]): The list of text polygons.
487
+
488
+ Returns:
489
+ fourier_real_map (ndarray): The Fourier coefficient real part maps.
490
+ fourier_image_map (ndarray): The Fourier coefficient image part
491
+ maps.
492
+ """
493
+
494
+ assert isinstance(img_size, tuple)
495
+
496
+ h, w = img_size
497
+ k = self.fourier_degree
498
+ real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
499
+ imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
500
+
501
+ for poly in text_polys:
502
+ mask = np.zeros((h, w), dtype=np.uint8)
503
+ polygon = np.array(poly).reshape((1, -1, 2))
504
+ cv2.fillPoly(mask, polygon.astype(np.int32), 1)
505
+ fourier_coeff = self.cal_fourier_signature(polygon[0], k)
506
+ for i in range(-k, k + 1):
507
+ if i != 0:
508
+ real_map[i + k, :, :] = (
509
+ mask * fourier_coeff[i + k, 0]
510
+ + (1 - mask) * real_map[i + k, :, :]
511
+ )
512
+ imag_map[i + k, :, :] = (
513
+ mask * fourier_coeff[i + k, 1]
514
+ + (1 - mask) * imag_map[i + k, :, :]
515
+ )
516
+ else:
517
+ yx = np.argwhere(mask > 0.5)
518
+ k_ind = np.ones((len(yx)), dtype=np.int64) * k
519
+ y, x = yx[:, 0], yx[:, 1]
520
+ real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
521
+ imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
522
+
523
+ return real_map, imag_map
524
+
525
+ def generate_text_region_mask(self, img_size, text_polys):
526
+ """Generate text center region mask and geometry attribute maps.
527
+
528
+ Args:
529
+ img_size (tuple): The image size (height, width).
530
+ text_polys (list[list[ndarray]]): The list of text polygons.
531
+
532
+ Returns:
533
+ text_region_mask (ndarray): The text region mask.
534
+ """
535
+
536
+ assert isinstance(img_size, tuple)
537
+
538
+ h, w = img_size
539
+ text_region_mask = np.zeros((h, w), dtype=np.uint8)
540
+
541
+ for poly in text_polys:
542
+ polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
543
+ cv2.fillPoly(text_region_mask, polygon, 1)
544
+
545
+ return text_region_mask
546
+
547
+ def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
548
+ """Generate effective mask by setting the ineffective regions to 0 and
549
+ effective regions to 1.
550
+
551
+ Args:
552
+ mask_size (tuple): The mask size.
553
+ polygons_ignore (list[[ndarray]]: The list of ignored text
554
+ polygons.
555
+
556
+ Returns:
557
+ mask (ndarray): The effective mask of (height, width).
558
+ """
559
+
560
+ mask = np.ones(mask_size, dtype=np.uint8)
561
+
562
+ for poly in polygons_ignore:
563
+ instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
564
+ cv2.fillPoly(mask, instance, 0)
565
+
566
+ return mask
567
+
568
+ def generate_level_targets(self, img_size, text_polys, ignore_polys):
569
+ """Generate ground truth target on each level.
570
+
571
+ Args:
572
+ img_size (list[int]): Shape of input image.
573
+ text_polys (list[list[ndarray]]): A list of ground truth polygons.
574
+ ignore_polys (list[list[ndarray]]): A list of ignored polygons.
575
+ Returns:
576
+ level_maps (list(ndarray)): A list of ground target on each level.
577
+ """
578
+ h, w = img_size
579
+ lv_size_divs = self.level_size_divisors
580
+ lv_proportion_range = self.level_proportion_range
581
+ lv_text_polys = [[] for i in range(len(lv_size_divs))]
582
+ lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
583
+ level_maps = []
584
+ for poly in text_polys:
585
+ polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
586
+ _, _, box_w, box_h = cv2.boundingRect(polygon)
587
+ proportion = max(box_h, box_w) / (h + 1e-8)
588
+
589
+ for ind, proportion_range in enumerate(lv_proportion_range):
590
+ if proportion_range[0] < proportion < proportion_range[1]:
591
+ lv_text_polys[ind].append(poly / lv_size_divs[ind])
592
+
593
+ for ignore_poly in ignore_polys:
594
+ polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
595
+ _, _, box_w, box_h = cv2.boundingRect(polygon)
596
+ proportion = max(box_h, box_w) / (h + 1e-8)
597
+
598
+ for ind, proportion_range in enumerate(lv_proportion_range):
599
+ if proportion_range[0] < proportion < proportion_range[1]:
600
+ lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])
601
+
602
+ for ind, size_divisor in enumerate(lv_size_divs):
603
+ current_level_maps = []
604
+ level_img_size = (h // size_divisor, w // size_divisor)
605
+
606
+ text_region = self.generate_text_region_mask(
607
+ level_img_size, lv_text_polys[ind]
608
+ )[None]
609
+ current_level_maps.append(text_region)
610
+
611
+ center_region = self.generate_center_region_mask(
612
+ level_img_size, lv_text_polys[ind]
613
+ )[None]
614
+ current_level_maps.append(center_region)
615
+
616
+ effective_mask = self.generate_effective_mask(
617
+ level_img_size, lv_ignore_polys[ind]
618
+ )[None]
619
+ current_level_maps.append(effective_mask)
620
+
621
+ fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
622
+ level_img_size, lv_text_polys[ind]
623
+ )
624
+ current_level_maps.append(fourier_real_map)
625
+ current_level_maps.append(fourier_image_maps)
626
+
627
+ level_maps.append(np.concatenate(current_level_maps))
628
+
629
+ return level_maps
630
+
631
+ def generate_targets(self, results):
632
+ """Generate the ground truth targets for FCENet.
633
+
634
+ Args:
635
+ results (dict): The input result dictionary.
636
+
637
+ Returns:
638
+ results (dict): The output result dictionary.
639
+ """
640
+
641
+ assert isinstance(results, dict)
642
+ image = results["image"]
643
+ polygons = results["polys"]
644
+ ignore_tags = results["ignore_tags"]
645
+ h, w, _ = image.shape
646
+
647
+ polygon_masks = []
648
+ polygon_masks_ignore = []
649
+ for tag, polygon in zip(ignore_tags, polygons):
650
+ if tag is True:
651
+ polygon_masks_ignore.append(polygon)
652
+ else:
653
+ polygon_masks.append(polygon)
654
+
655
+ level_maps = self.generate_level_targets(
656
+ (h, w), polygon_masks, polygon_masks_ignore
657
+ )
658
+
659
+ mapping = {
660
+ "p3_maps": level_maps[0],
661
+ "p4_maps": level_maps[1],
662
+ "p5_maps": level_maps[2],
663
+ }
664
+ for key, value in mapping.items():
665
+ results[key] = value
666
+
667
+ return results
668
+
669
+ def __call__(self, results):
670
+ results = self.generate_targets(results)
671
+ return results
ocr/ppocr/data/imaug/gen_table_mask.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+
7
+ class GenTableMask(object):
8
+ """gen table mask"""
9
+
10
+ def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
11
+ self.shrink_h_max = 5
12
+ self.shrink_w_max = 5
13
+ self.mask_type = mask_type
14
+
15
+ def projection(self, erosion, h, w, spilt_threshold=0):
16
+ # 水平投影
17
+ projection_map = np.ones_like(erosion)
18
+ project_val_array = [0 for _ in range(0, h)]
19
+
20
+ for j in range(0, h):
21
+ for i in range(0, w):
22
+ if erosion[j, i] == 255:
23
+ project_val_array[j] += 1
24
+ # 根据数组,获取切割点
25
+ start_idx = 0 # 记录进入字符区的索引
26
+ end_idx = 0 # 记录进入空白区域的索引
27
+ in_text = False # 是否遍历到了字符区内
28
+ box_list = []
29
+ for i in range(len(project_val_array)):
30
+ if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
31
+ in_text = True
32
+ start_idx = i
33
+ elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
34
+ end_idx = i
35
+ in_text = False
36
+ if end_idx - start_idx <= 2:
37
+ continue
38
+ box_list.append((start_idx, end_idx + 1))
39
+
40
+ if in_text:
41
+ box_list.append((start_idx, h - 1))
42
+ # 绘制投影直方图
43
+ for j in range(0, h):
44
+ for i in range(0, project_val_array[j]):
45
+ projection_map[j, i] = 0
46
+ return box_list, projection_map
47
+
48
+ def projection_cx(self, box_img):
49
+ box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
50
+ h, w = box_gray_img.shape
51
+ # 灰度图片进行二值化处理
52
+ ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
53
+ # 纵向腐蚀
54
+ if h < w:
55
+ kernel = np.ones((2, 1), np.uint8)
56
+ erode = cv2.erode(thresh1, kernel, iterations=1)
57
+ else:
58
+ erode = thresh1
59
+ # 水平膨胀
60
+ kernel = np.ones((1, 5), np.uint8)
61
+ erosion = cv2.dilate(erode, kernel, iterations=1)
62
+ # 水平投影
63
+ projection_map = np.ones_like(erosion)
64
+ project_val_array = [0 for _ in range(0, h)]
65
+
66
+ for j in range(0, h):
67
+ for i in range(0, w):
68
+ if erosion[j, i] == 255:
69
+ project_val_array[j] += 1
70
+ # 根据数组,获取切割点
71
+ start_idx = 0 # 记录进入字符区的索引
72
+ end_idx = 0 # 记录进入空白区域的索引
73
+ in_text = False # 是否遍历到了字符区内
74
+ box_list = []
75
+ spilt_threshold = 0
76
+ for i in range(len(project_val_array)):
77
+ if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
78
+ in_text = True
79
+ start_idx = i
80
+ elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
81
+ end_idx = i
82
+ in_text = False
83
+ if end_idx - start_idx <= 2:
84
+ continue
85
+ box_list.append((start_idx, end_idx + 1))
86
+
87
+ if in_text:
88
+ box_list.append((start_idx, h - 1))
89
+ # 绘制投影直方图
90
+ for j in range(0, h):
91
+ for i in range(0, project_val_array[j]):
92
+ projection_map[j, i] = 0
93
+ split_bbox_list = []
94
+ if len(box_list) > 1:
95
+ for i, (h_start, h_end) in enumerate(box_list):
96
+ if i == 0:
97
+ h_start = 0
98
+ if i == len(box_list):
99
+ h_end = h
100
+ word_img = erosion[h_start : h_end + 1, :]
101
+ word_h, word_w = word_img.shape
102
+ w_split_list, w_projection_map = self.projection(
103
+ word_img.T, word_w, word_h
104
+ )
105
+ w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
106
+ if h_start > 0:
107
+ h_start -= 1
108
+ h_end += 1
109
+ word_img = box_img[h_start : h_end + 1 :, w_start : w_end + 1, :]
110
+ split_bbox_list.append([w_start, h_start, w_end, h_end])
111
+ else:
112
+ split_bbox_list.append([0, 0, w, h])
113
+ return split_bbox_list
114
+
115
+ def shrink_bbox(self, bbox):
116
+ left, top, right, bottom = bbox
117
+ sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
118
+ sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
119
+ left_new = left + sh_w
120
+ right_new = right - sh_w
121
+ top_new = top + sh_h
122
+ bottom_new = bottom - sh_h
123
+ if left_new >= right_new:
124
+ left_new = left
125
+ right_new = right
126
+ if top_new >= bottom_new:
127
+ top_new = top
128
+ bottom_new = bottom
129
+ return [left_new, top_new, right_new, bottom_new]
130
+
131
+ def __call__(self, data):
132
+ img = data["image"]
133
+ cells = data["cells"]
134
+ height, width = img.shape[0:2]
135
+ if self.mask_type == 1:
136
+ mask_img = np.zeros((height, width), dtype=np.float32)
137
+ else:
138
+ mask_img = np.zeros((height, width, 3), dtype=np.float32)
139
+ cell_num = len(cells)
140
+ for cno in range(cell_num):
141
+ if "bbox" in cells[cno]:
142
+ bbox = cells[cno]["bbox"]
143
+ left, top, right, bottom = bbox
144
+ box_img = img[top:bottom, left:right, :].copy()
145
+ split_bbox_list = self.projection_cx(box_img)
146
+ for sno in range(len(split_bbox_list)):
147
+ split_bbox_list[sno][0] += left
148
+ split_bbox_list[sno][1] += top
149
+ split_bbox_list[sno][2] += left
150
+ split_bbox_list[sno][3] += top
151
+
152
+ for sno in range(len(split_bbox_list)):
153
+ left, top, right, bottom = split_bbox_list[sno]
154
+ left, top, right, bottom = self.shrink_bbox(
155
+ [left, top, right, bottom]
156
+ )
157
+ if self.mask_type == 1:
158
+ mask_img[top:bottom, left:right] = 1.0
159
+ data["mask_img"] = mask_img
160
+ else:
161
+ mask_img[top:bottom, left:right, :] = (255, 255, 255)
162
+ data["image"] = mask_img
163
+ return data
164
+
165
+
166
+ class ResizeTableImage(object):
167
+ def __init__(self, max_len, **kwargs):
168
+ super(ResizeTableImage, self).__init__()
169
+ self.max_len = max_len
170
+
171
+ def get_img_bbox(self, cells):
172
+ bbox_list = []
173
+ if len(cells) == 0:
174
+ return bbox_list
175
+ cell_num = len(cells)
176
+ for cno in range(cell_num):
177
+ if "bbox" in cells[cno]:
178
+ bbox = cells[cno]["bbox"]
179
+ bbox_list.append(bbox)
180
+ return bbox_list
181
+
182
+ def resize_img_table(self, img, bbox_list, max_len):
183
+ height, width = img.shape[0:2]
184
+ ratio = max_len / (max(height, width) * 1.0)
185
+ resize_h = int(height * ratio)
186
+ resize_w = int(width * ratio)
187
+ img_new = cv2.resize(img, (resize_w, resize_h))
188
+ bbox_list_new = []
189
+ for bno in range(len(bbox_list)):
190
+ left, top, right, bottom = bbox_list[bno].copy()
191
+ left = int(left * ratio)
192
+ top = int(top * ratio)
193
+ right = int(right * ratio)
194
+ bottom = int(bottom * ratio)
195
+ bbox_list_new.append([left, top, right, bottom])
196
+ return img_new, bbox_list_new
197
+
198
+ def __call__(self, data):
199
+ img = data["image"]
200
+ if "cells" not in data:
201
+ cells = []
202
+ else:
203
+ cells = data["cells"]
204
+ bbox_list = self.get_img_bbox(cells)
205
+ img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
206
+ data["image"] = img_new
207
+ cell_num = len(cells)
208
+ bno = 0
209
+ for cno in range(cell_num):
210
+ if "bbox" in data["cells"][cno]:
211
+ data["cells"][cno]["bbox"] = bbox_list_new[bno]
212
+ bno += 1
213
+ data["max_len"] = self.max_len
214
+ return data
215
+
216
+
217
+ class PaddingTableImage(object):
218
+ def __init__(self, **kwargs):
219
+ super(PaddingTableImage, self).__init__()
220
+
221
+ def __call__(self, data):
222
+ img = data["image"]
223
+ max_len = data["max_len"]
224
+ padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
225
+ height, width = img.shape[0:2]
226
+ padding_img[0:height, 0:width, :] = img.copy()
227
+ data["image"] = padding_img
228
+ return data
ocr/ppocr/data/imaug/iaa_augment.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import imgaug
4
+ import imgaug.augmenters as iaa
5
+ import numpy as np
6
+
7
+
8
+ class AugmenterBuilder(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+ def build(self, args, root=True):
13
+ if args is None or len(args) == 0:
14
+ return None
15
+ elif isinstance(args, list):
16
+ if root:
17
+ sequence = [self.build(value, root=False) for value in args]
18
+ return iaa.Sequential(sequence)
19
+ else:
20
+ return getattr(iaa, args[0])(
21
+ *[self.to_tuple_if_list(a) for a in args[1:]]
22
+ )
23
+ elif isinstance(args, dict):
24
+ cls = getattr(iaa, args["type"])
25
+ return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
26
+ else:
27
+ raise RuntimeError("unknown augmenter arg: " + str(args))
28
+
29
+ def to_tuple_if_list(self, obj):
30
+ if isinstance(obj, list):
31
+ return tuple(obj)
32
+ return obj
33
+
34
+
35
+ class IaaAugment:
36
+ def __init__(self, augmenter_args=None, **kwargs):
37
+ if augmenter_args is None:
38
+ augmenter_args = [
39
+ {"type": "Fliplr", "args": {"p": 0.5}},
40
+ {"type": "Affine", "args": {"rotate": [-10, 10]}},
41
+ {"type": "Resize", "args": {"size": [0.5, 3]}},
42
+ ]
43
+ self.augmenter = AugmenterBuilder().build(augmenter_args)
44
+
45
+ def __call__(self, data):
46
+ image = data["image"]
47
+ shape = image.shape
48
+
49
+ if self.augmenter:
50
+ aug = self.augmenter.to_deterministic()
51
+ data["image"] = aug.augment_image(image)
52
+ data = self.may_augment_annotation(aug, data, shape)
53
+ return data
54
+
55
+ def may_augment_annotation(self, aug, data, shape):
56
+ if aug is None:
57
+ return data
58
+
59
+ line_polys = []
60
+ for poly in data["polys"]:
61
+ new_poly = self.may_augment_poly(aug, shape, poly)
62
+ line_polys.append(new_poly)
63
+ data["polys"] = np.array(line_polys)
64
+ return data
65
+
66
+ def may_augment_poly(self, aug, img_shape, poly):
67
+ keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
68
+ keypoints = aug.augment_keypoints(
69
+ [imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
70
+ )[0].keypoints
71
+ poly = [(p.x, p.y) for p in keypoints]
72
+ return poly
ocr/ppocr/data/imaug/label_ops.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import copy
4
+ import json
5
+
6
+ import numpy as np
7
+ from shapely.geometry import LineString, Point, Polygon
8
+
9
+
10
+ class ClsLabelEncode(object):
11
+ def __init__(self, label_list, **kwargs):
12
+ self.label_list = label_list
13
+
14
+ def __call__(self, data):
15
+ label = data["label"]
16
+ if label not in self.label_list:
17
+ return None
18
+ label = self.label_list.index(label)
19
+ data["label"] = label
20
+ return data
21
+
22
+
23
+ class DetLabelEncode(object):
24
+ def __init__(self, **kwargs):
25
+ pass
26
+
27
+ def __call__(self, data):
28
+ label = data["label"]
29
+ label = json.loads(label)
30
+ nBox = len(label)
31
+ boxes, txts, txt_tags = [], [], []
32
+ for bno in range(0, nBox):
33
+ box = label[bno]["points"]
34
+ txt = label[bno]["transcription"]
35
+ boxes.append(box)
36
+ txts.append(txt)
37
+ if txt in ["*", "###"]:
38
+ txt_tags.append(True)
39
+ else:
40
+ txt_tags.append(False)
41
+ if len(boxes) == 0:
42
+ return None
43
+ boxes = self.expand_points_num(boxes)
44
+ boxes = np.array(boxes, dtype=np.float32)
45
+ txt_tags = np.array(txt_tags, dtype=np.bool)
46
+
47
+ data["polys"] = boxes
48
+ data["texts"] = txts
49
+ data["ignore_tags"] = txt_tags
50
+ return data
51
+
52
+ def order_points_clockwise(self, pts):
53
+ rect = np.zeros((4, 2), dtype="float32")
54
+ s = pts.sum(axis=1)
55
+ rect[0] = pts[np.argmin(s)]
56
+ rect[2] = pts[np.argmax(s)]
57
+ diff = np.diff(pts, axis=1)
58
+ rect[1] = pts[np.argmin(diff)]
59
+ rect[3] = pts[np.argmax(diff)]
60
+ return rect
61
+
62
+ def expand_points_num(self, boxes):
63
+ max_points_num = 0
64
+ for box in boxes:
65
+ if len(box) > max_points_num:
66
+ max_points_num = len(box)
67
+ ex_boxes = []
68
+ for box in boxes:
69
+ ex_box = box + [box[-1]] * (max_points_num - len(box))
70
+ ex_boxes.append(ex_box)
71
+ return ex_boxes
72
+
73
+
74
+ class BaseRecLabelEncode(object):
75
+ """Convert between text-label and text-index"""
76
+
77
+ def __init__(self, max_text_length, character_dict_path=None, use_space_char=False):
78
+
79
+ self.max_text_len = max_text_length
80
+ self.beg_str = "sos"
81
+ self.end_str = "eos"
82
+ self.lower = False
83
+
84
+ if character_dict_path is None:
85
+ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
86
+ dict_character = list(self.character_str)
87
+ self.lower = True
88
+ else:
89
+ self.character_str = []
90
+ with open(character_dict_path, "rb") as fin:
91
+ lines = fin.readlines()
92
+ for line in lines:
93
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
94
+ self.character_str.append(line)
95
+ if use_space_char:
96
+ self.character_str.append(" ")
97
+ dict_character = list(self.character_str)
98
+ dict_character = self.add_special_char(dict_character)
99
+ self.dict = {}
100
+ for i, char in enumerate(dict_character):
101
+ self.dict[char] = i
102
+ self.character = dict_character
103
+
104
+ def add_special_char(self, dict_character):
105
+ return dict_character
106
+
107
+ def encode(self, text):
108
+ """convert text-label into text-index.
109
+ input:
110
+ text: text labels of each image. [batch_size]
111
+
112
+ output:
113
+ text: concatenated text index for CTCLoss.
114
+ [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
115
+ length: length of each text. [batch_size]
116
+ """
117
+ if len(text) == 0 or len(text) > self.max_text_len:
118
+ return None
119
+ if self.lower:
120
+ text = text.lower()
121
+ text_list = []
122
+ for char in text:
123
+ if char not in self.dict:
124
+ continue
125
+ text_list.append(self.dict[char])
126
+ if len(text_list) == 0:
127
+ return None
128
+ return text_list
129
+
130
+
131
+ class NRTRLabelEncode(BaseRecLabelEncode):
132
+ """Convert between text-label and text-index"""
133
+
134
+ def __init__(
135
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
136
+ ):
137
+
138
+ super(NRTRLabelEncode, self).__init__(
139
+ max_text_length, character_dict_path, use_space_char
140
+ )
141
+
142
+ def __call__(self, data):
143
+ text = data["label"]
144
+ text = self.encode(text)
145
+ if text is None:
146
+ return None
147
+ if len(text) >= self.max_text_len - 1:
148
+ return None
149
+ data["length"] = np.array(len(text))
150
+ text.insert(0, 2)
151
+ text.append(3)
152
+ text = text + [0] * (self.max_text_len - len(text))
153
+ data["label"] = np.array(text)
154
+ return data
155
+
156
+ def add_special_char(self, dict_character):
157
+ dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
158
+ return dict_character
159
+
160
+
161
+ class CTCLabelEncode(BaseRecLabelEncode):
162
+ """Convert between text-label and text-index"""
163
+
164
+ def __init__(
165
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
166
+ ):
167
+ super(CTCLabelEncode, self).__init__(
168
+ max_text_length, character_dict_path, use_space_char
169
+ )
170
+
171
+ def __call__(self, data):
172
+ text = data["label"]
173
+ text = self.encode(text)
174
+ if text is None:
175
+ return None
176
+ data["length"] = np.array(len(text))
177
+ text = text + [0] * (self.max_text_len - len(text))
178
+ data["label"] = np.array(text)
179
+
180
+ label = [0] * len(self.character)
181
+ for x in text:
182
+ label[x] += 1
183
+ data["label_ace"] = np.array(label)
184
+ return data
185
+
186
+ def add_special_char(self, dict_character):
187
+ dict_character = ["blank"] + dict_character
188
+ return dict_character
189
+
190
+
191
+ class E2ELabelEncodeTest(BaseRecLabelEncode):
192
+ def __init__(
193
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
194
+ ):
195
+ super(E2ELabelEncodeTest, self).__init__(
196
+ max_text_length, character_dict_path, use_space_char
197
+ )
198
+
199
+ def __call__(self, data):
200
+ import json
201
+
202
+ padnum = len(self.dict)
203
+ label = data["label"]
204
+ label = json.loads(label)
205
+ nBox = len(label)
206
+ boxes, txts, txt_tags = [], [], []
207
+ for bno in range(0, nBox):
208
+ box = label[bno]["points"]
209
+ txt = label[bno]["transcription"]
210
+ boxes.append(box)
211
+ txts.append(txt)
212
+ if txt in ["*", "###"]:
213
+ txt_tags.append(True)
214
+ else:
215
+ txt_tags.append(False)
216
+ boxes = np.array(boxes, dtype=np.float32)
217
+ txt_tags = np.array(txt_tags, dtype=np.bool)
218
+ data["polys"] = boxes
219
+ data["ignore_tags"] = txt_tags
220
+ temp_texts = []
221
+ for text in txts:
222
+ text = text.lower()
223
+ text = self.encode(text)
224
+ if text is None:
225
+ return None
226
+ text = text + [padnum] * (self.max_text_len - len(text)) # use 36 to pad
227
+ temp_texts.append(text)
228
+ data["texts"] = np.array(temp_texts)
229
+ return data
230
+
231
+
232
+ class E2ELabelEncodeTrain(object):
233
+ def __init__(self, **kwargs):
234
+ pass
235
+
236
+ def __call__(self, data):
237
+ import json
238
+
239
+ label = data["label"]
240
+ label = json.loads(label)
241
+ nBox = len(label)
242
+ boxes, txts, txt_tags = [], [], []
243
+ for bno in range(0, nBox):
244
+ box = label[bno]["points"]
245
+ txt = label[bno]["transcription"]
246
+ boxes.append(box)
247
+ txts.append(txt)
248
+ if txt in ["*", "###"]:
249
+ txt_tags.append(True)
250
+ else:
251
+ txt_tags.append(False)
252
+ boxes = np.array(boxes, dtype=np.float32)
253
+ txt_tags = np.array(txt_tags, dtype=np.bool)
254
+
255
+ data["polys"] = boxes
256
+ data["texts"] = txts
257
+ data["ignore_tags"] = txt_tags
258
+ return data
259
+
260
+
261
+ class KieLabelEncode(object):
262
+ def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
263
+ super(KieLabelEncode, self).__init__()
264
+ self.dict = dict({"": 0})
265
+ with open(character_dict_path, "r", encoding="utf-8") as fr:
266
+ idx = 1
267
+ for line in fr:
268
+ char = line.strip()
269
+ self.dict[char] = idx
270
+ idx += 1
271
+ self.norm = norm
272
+ self.directed = directed
273
+
274
+ def compute_relation(self, boxes):
275
+ """Compute relation between every two boxes."""
276
+ x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
277
+ x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
278
+ ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
279
+ dxs = (x1s[:, 0][None] - x1s) / self.norm
280
+ dys = (y1s[:, 0][None] - y1s) / self.norm
281
+ xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
282
+ whs = ws / hs + np.zeros_like(xhhs)
283
+ relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
284
+ bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
285
+ return relations, bboxes
286
+
287
+ def pad_text_indices(self, text_inds):
288
+ """Pad text index to same length."""
289
+ max_len = 300
290
+ recoder_len = max([len(text_ind) for text_ind in text_inds])
291
+ padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
292
+ for idx, text_ind in enumerate(text_inds):
293
+ padded_text_inds[idx, : len(text_ind)] = np.array(text_ind)
294
+ return padded_text_inds, recoder_len
295
+
296
+ def list_to_numpy(self, ann_infos):
297
+ """Convert bboxes, relations, texts and labels to ndarray."""
298
+ boxes, text_inds = ann_infos["points"], ann_infos["text_inds"]
299
+ boxes = np.array(boxes, np.int32)
300
+ relations, bboxes = self.compute_relation(boxes)
301
+
302
+ labels = ann_infos.get("labels", None)
303
+ if labels is not None:
304
+ labels = np.array(labels, np.int32)
305
+ edges = ann_infos.get("edges", None)
306
+ if edges is not None:
307
+ labels = labels[:, None]
308
+ edges = np.array(edges)
309
+ edges = (edges[:, None] == edges[None, :]).astype(np.int32)
310
+ if self.directed:
311
+ edges = (edges & labels == 1).astype(np.int32)
312
+ np.fill_diagonal(edges, -1)
313
+ labels = np.concatenate([labels, edges], -1)
314
+ padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
315
+ max_num = 300
316
+ temp_bboxes = np.zeros([max_num, 4])
317
+ h, _ = bboxes.shape
318
+ temp_bboxes[:h, :] = bboxes
319
+
320
+ temp_relations = np.zeros([max_num, max_num, 5])
321
+ temp_relations[:h, :h, :] = relations
322
+
323
+ temp_padded_text_inds = np.zeros([max_num, max_num])
324
+ temp_padded_text_inds[:h, :] = padded_text_inds
325
+
326
+ temp_labels = np.zeros([max_num, max_num])
327
+ temp_labels[:h, : h + 1] = labels
328
+
329
+ tag = np.array([h, recoder_len])
330
+ return dict(
331
+ image=ann_infos["image"],
332
+ points=temp_bboxes,
333
+ relations=temp_relations,
334
+ texts=temp_padded_text_inds,
335
+ labels=temp_labels,
336
+ tag=tag,
337
+ )
338
+
339
+ def convert_canonical(self, points_x, points_y):
340
+
341
+ assert len(points_x) == 4
342
+ assert len(points_y) == 4
343
+
344
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
345
+
346
+ polygon = Polygon([(p.x, p.y) for p in points])
347
+ min_x, min_y, _, _ = polygon.bounds
348
+ points_to_lefttop = [
349
+ LineString([points[i], Point(min_x, min_y)]) for i in range(4)
350
+ ]
351
+ distances = np.array([line.length for line in points_to_lefttop])
352
+ sort_dist_idx = np.argsort(distances)
353
+ lefttop_idx = sort_dist_idx[0]
354
+
355
+ if lefttop_idx == 0:
356
+ point_orders = [0, 1, 2, 3]
357
+ elif lefttop_idx == 1:
358
+ point_orders = [1, 2, 3, 0]
359
+ elif lefttop_idx == 2:
360
+ point_orders = [2, 3, 0, 1]
361
+ else:
362
+ point_orders = [3, 0, 1, 2]
363
+
364
+ sorted_points_x = [points_x[i] for i in point_orders]
365
+ sorted_points_y = [points_y[j] for j in point_orders]
366
+
367
+ return sorted_points_x, sorted_points_y
368
+
369
+ def sort_vertex(self, points_x, points_y):
370
+
371
+ assert len(points_x) == 4
372
+ assert len(points_y) == 4
373
+
374
+ x = np.array(points_x)
375
+ y = np.array(points_y)
376
+ center_x = np.sum(x) * 0.25
377
+ center_y = np.sum(y) * 0.25
378
+
379
+ x_arr = np.array(x - center_x)
380
+ y_arr = np.array(y - center_y)
381
+
382
+ angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
383
+ sort_idx = np.argsort(angle)
384
+
385
+ sorted_points_x, sorted_points_y = [], []
386
+ for i in range(4):
387
+ sorted_points_x.append(points_x[sort_idx[i]])
388
+ sorted_points_y.append(points_y[sort_idx[i]])
389
+
390
+ return self.convert_canonical(sorted_points_x, sorted_points_y)
391
+
392
+ def __call__(self, data):
393
+ import json
394
+
395
+ label = data["label"]
396
+ annotations = json.loads(label)
397
+ boxes, texts, text_inds, labels, edges = [], [], [], [], []
398
+ for ann in annotations:
399
+ box = ann["points"]
400
+ x_list = [box[i][0] for i in range(4)]
401
+ y_list = [box[i][1] for i in range(4)]
402
+ sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
403
+ sorted_box = []
404
+ for x, y in zip(sorted_x_list, sorted_y_list):
405
+ sorted_box.append(x)
406
+ sorted_box.append(y)
407
+ boxes.append(sorted_box)
408
+ text = ann["transcription"]
409
+ texts.append(ann["transcription"])
410
+ text_ind = [self.dict[c] for c in text if c in self.dict]
411
+ text_inds.append(text_ind)
412
+ if "label" in ann.keys():
413
+ labels.append(ann["label"])
414
+ elif "key_cls" in ann.keys():
415
+ labels.append(ann["key_cls"])
416
+ else:
417
+ raise ValueError(
418
+ "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
419
+ )
420
+ edges.append(ann.get("edge", 0))
421
+ ann_infos = dict(
422
+ image=data["image"],
423
+ points=boxes,
424
+ texts=texts,
425
+ text_inds=text_inds,
426
+ edges=edges,
427
+ labels=labels,
428
+ )
429
+
430
+ return self.list_to_numpy(ann_infos)
431
+
432
+
433
+ class AttnLabelEncode(BaseRecLabelEncode):
434
+ """Convert between text-label and text-index"""
435
+
436
+ def __init__(
437
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
438
+ ):
439
+ super(AttnLabelEncode, self).__init__(
440
+ max_text_length, character_dict_path, use_space_char
441
+ )
442
+
443
+ def add_special_char(self, dict_character):
444
+ self.beg_str = "sos"
445
+ self.end_str = "eos"
446
+ dict_character = [self.beg_str] + dict_character + [self.end_str]
447
+ return dict_character
448
+
449
+ def __call__(self, data):
450
+ text = data["label"]
451
+ text = self.encode(text)
452
+ if text is None:
453
+ return None
454
+ if len(text) >= self.max_text_len:
455
+ return None
456
+ data["length"] = np.array(len(text))
457
+ text = (
458
+ [0]
459
+ + text
460
+ + [len(self.character) - 1]
461
+ + [0] * (self.max_text_len - len(text) - 2)
462
+ )
463
+ data["label"] = np.array(text)
464
+ return data
465
+
466
+ def get_ignored_tokens(self):
467
+ beg_idx = self.get_beg_end_flag_idx("beg")
468
+ end_idx = self.get_beg_end_flag_idx("end")
469
+ return [beg_idx, end_idx]
470
+
471
+ def get_beg_end_flag_idx(self, beg_or_end):
472
+ if beg_or_end == "beg":
473
+ idx = np.array(self.dict[self.beg_str])
474
+ elif beg_or_end == "end":
475
+ idx = np.array(self.dict[self.end_str])
476
+ else:
477
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
478
+ return idx
479
+
480
+
481
+ class SEEDLabelEncode(BaseRecLabelEncode):
482
+ """Convert between text-label and text-index"""
483
+
484
+ def __init__(
485
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
486
+ ):
487
+ super(SEEDLabelEncode, self).__init__(
488
+ max_text_length, character_dict_path, use_space_char
489
+ )
490
+
491
+ def add_special_char(self, dict_character):
492
+ self.padding = "padding"
493
+ self.end_str = "eos"
494
+ self.unknown = "unknown"
495
+ dict_character = dict_character + [self.end_str, self.padding, self.unknown]
496
+ return dict_character
497
+
498
+ def __call__(self, data):
499
+ text = data["label"]
500
+ text = self.encode(text)
501
+ if text is None:
502
+ return None
503
+ if len(text) >= self.max_text_len:
504
+ return None
505
+ data["length"] = np.array(len(text)) + 1 # conclude eos
506
+ text = (
507
+ text
508
+ + [len(self.character) - 3]
509
+ + [len(self.character) - 2] * (self.max_text_len - len(text) - 1)
510
+ )
511
+ data["label"] = np.array(text)
512
+ return data
513
+
514
+
515
+ class SRNLabelEncode(BaseRecLabelEncode):
516
+ """Convert between text-label and text-index"""
517
+
518
+ def __init__(
519
+ self,
520
+ max_text_length=25,
521
+ character_dict_path=None,
522
+ use_space_char=False,
523
+ **kwargs
524
+ ):
525
+ super(SRNLabelEncode, self).__init__(
526
+ max_text_length, character_dict_path, use_space_char
527
+ )
528
+
529
+ def add_special_char(self, dict_character):
530
+ dict_character = dict_character + [self.beg_str, self.end_str]
531
+ return dict_character
532
+
533
+ def __call__(self, data):
534
+ text = data["label"]
535
+ text = self.encode(text)
536
+ char_num = len(self.character)
537
+ if text is None:
538
+ return None
539
+ if len(text) > self.max_text_len:
540
+ return None
541
+ data["length"] = np.array(len(text))
542
+ text = text + [char_num - 1] * (self.max_text_len - len(text))
543
+ data["label"] = np.array(text)
544
+ return data
545
+
546
+ def get_ignored_tokens(self):
547
+ beg_idx = self.get_beg_end_flag_idx("beg")
548
+ end_idx = self.get_beg_end_flag_idx("end")
549
+ return [beg_idx, end_idx]
550
+
551
+ def get_beg_end_flag_idx(self, beg_or_end):
552
+ if beg_or_end == "beg":
553
+ idx = np.array(self.dict[self.beg_str])
554
+ elif beg_or_end == "end":
555
+ idx = np.array(self.dict[self.end_str])
556
+ else:
557
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end
558
+ return idx
559
+
560
+
561
+ class TableLabelEncode(object):
562
+ """Convert between text-label and text-index"""
563
+
564
+ def __init__(
565
+ self,
566
+ max_text_length,
567
+ max_elem_length,
568
+ max_cell_num,
569
+ character_dict_path,
570
+ span_weight=1.0,
571
+ **kwargs
572
+ ):
573
+ self.max_text_length = max_text_length
574
+ self.max_elem_length = max_elem_length
575
+ self.max_cell_num = max_cell_num
576
+ list_character, list_elem = self.load_char_elem_dict(character_dict_path)
577
+ list_character = self.add_special_char(list_character)
578
+ list_elem = self.add_special_char(list_elem)
579
+ self.dict_character = {}
580
+ for i, char in enumerate(list_character):
581
+ self.dict_character[char] = i
582
+ self.dict_elem = {}
583
+ for i, elem in enumerate(list_elem):
584
+ self.dict_elem[elem] = i
585
+ self.span_weight = span_weight
586
+
587
+ def load_char_elem_dict(self, character_dict_path):
588
+ list_character = []
589
+ list_elem = []
590
+ with open(character_dict_path, "rb") as fin:
591
+ lines = fin.readlines()
592
+ substr = lines[0].decode("utf-8").strip("\r\n").split("\t")
593
+ character_num = int(substr[0])
594
+ elem_num = int(substr[1])
595
+ for cno in range(1, 1 + character_num):
596
+ character = lines[cno].decode("utf-8").strip("\r\n")
597
+ list_character.append(character)
598
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
599
+ elem = lines[eno].decode("utf-8").strip("\r\n")
600
+ list_elem.append(elem)
601
+ return list_character, list_elem
602
+
603
+ def add_special_char(self, list_character):
604
+ self.beg_str = "sos"
605
+ self.end_str = "eos"
606
+ list_character = [self.beg_str] + list_character + [self.end_str]
607
+ return list_character
608
+
609
+ def get_span_idx_list(self):
610
+ span_idx_list = []
611
+ for elem in self.dict_elem:
612
+ if "span" in elem:
613
+ span_idx_list.append(self.dict_elem[elem])
614
+ return span_idx_list
615
+
616
+ def __call__(self, data):
617
+ cells = data["cells"]
618
+ structure = data["structure"]["tokens"]
619
+ structure = self.encode(structure, "elem")
620
+ if structure is None:
621
+ return None
622
+ elem_num = len(structure)
623
+ structure = [0] + structure + [len(self.dict_elem) - 1]
624
+ structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
625
+ structure = np.array(structure)
626
+ data["structure"] = structure
627
+ elem_char_idx1 = self.dict_elem["<td>"]
628
+ elem_char_idx2 = self.dict_elem["<td"]
629
+ span_idx_list = self.get_span_idx_list()
630
+ td_idx_list = np.logical_or(
631
+ structure == elem_char_idx1, structure == elem_char_idx2
632
+ )
633
+ td_idx_list = np.where(td_idx_list)[0]
634
+
635
+ structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
636
+ bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
637
+ bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
638
+ img_height, img_width, img_ch = data["image"].shape
639
+ if len(span_idx_list) > 0:
640
+ span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
641
+ span_weight = min(max(span_weight, 1.0), self.span_weight)
642
+ for cno in range(len(cells)):
643
+ if "bbox" in cells[cno]:
644
+ bbox = cells[cno]["bbox"].copy()
645
+ bbox[0] = bbox[0] * 1.0 / img_width
646
+ bbox[1] = bbox[1] * 1.0 / img_height
647
+ bbox[2] = bbox[2] * 1.0 / img_width
648
+ bbox[3] = bbox[3] * 1.0 / img_height
649
+ td_idx = td_idx_list[cno]
650
+ bbox_list[td_idx] = bbox
651
+ bbox_list_mask[td_idx] = 1.0
652
+ cand_span_idx = td_idx + 1
653
+ if cand_span_idx < (self.max_elem_length + 2):
654
+ if structure[cand_span_idx] in span_idx_list:
655
+ structure_mask[cand_span_idx] = span_weight
656
+
657
+ data["bbox_list"] = bbox_list
658
+ data["bbox_list_mask"] = bbox_list_mask
659
+ data["structure_mask"] = structure_mask
660
+ char_beg_idx = self.get_beg_end_flag_idx("beg", "char")
661
+ char_end_idx = self.get_beg_end_flag_idx("end", "char")
662
+ elem_beg_idx = self.get_beg_end_flag_idx("beg", "elem")
663
+ elem_end_idx = self.get_beg_end_flag_idx("end", "elem")
664
+ data["sp_tokens"] = np.array(
665
+ [
666
+ char_beg_idx,
667
+ char_end_idx,
668
+ elem_beg_idx,
669
+ elem_end_idx,
670
+ elem_char_idx1,
671
+ elem_char_idx2,
672
+ self.max_text_length,
673
+ self.max_elem_length,
674
+ self.max_cell_num,
675
+ elem_num,
676
+ ]
677
+ )
678
+ return data
679
+
680
+ def encode(self, text, char_or_elem):
681
+ """convert text-label into text-index."""
682
+ if char_or_elem == "char":
683
+ max_len = self.max_text_length
684
+ current_dict = self.dict_character
685
+ else:
686
+ max_len = self.max_elem_length
687
+ current_dict = self.dict_elem
688
+ if len(text) > max_len:
689
+ return None
690
+ if len(text) == 0:
691
+ if char_or_elem == "char":
692
+ return [self.dict_character["space"]]
693
+ else:
694
+ return None
695
+ text_list = []
696
+ for char in text:
697
+ if char not in current_dict:
698
+ return None
699
+ text_list.append(current_dict[char])
700
+ if len(text_list) == 0:
701
+ if char_or_elem == "char":
702
+ return [self.dict_character["space"]]
703
+ else:
704
+ return None
705
+ return text_list
706
+
707
+ def get_ignored_tokens(self, char_or_elem):
708
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
709
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
710
+ return [beg_idx, end_idx]
711
+
712
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
713
+ if char_or_elem == "char":
714
+ if beg_or_end == "beg":
715
+ idx = np.array(self.dict_character[self.beg_str])
716
+ elif beg_or_end == "end":
717
+ idx = np.array(self.dict_character[self.end_str])
718
+ else:
719
+ assert False, (
720
+ "Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end
721
+ )
722
+ elif char_or_elem == "elem":
723
+ if beg_or_end == "beg":
724
+ idx = np.array(self.dict_elem[self.beg_str])
725
+ elif beg_or_end == "end":
726
+ idx = np.array(self.dict_elem[self.end_str])
727
+ else:
728
+ assert False, (
729
+ "Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end
730
+ )
731
+ else:
732
+ assert False, "Unsupport type %s in char_or_elem" % char_or_elem
733
+ return idx
734
+
735
+
736
+ class SARLabelEncode(BaseRecLabelEncode):
737
+ """Convert between text-label and text-index"""
738
+
739
+ def __init__(
740
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
741
+ ):
742
+ super(SARLabelEncode, self).__init__(
743
+ max_text_length, character_dict_path, use_space_char
744
+ )
745
+
746
+ def add_special_char(self, dict_character):
747
+ beg_end_str = "<BOS/EOS>"
748
+ unknown_str = "<UKN>"
749
+ padding_str = "<PAD>"
750
+ dict_character = dict_character + [unknown_str]
751
+ self.unknown_idx = len(dict_character) - 1
752
+ dict_character = dict_character + [beg_end_str]
753
+ self.start_idx = len(dict_character) - 1
754
+ self.end_idx = len(dict_character) - 1
755
+ dict_character = dict_character + [padding_str]
756
+ self.padding_idx = len(dict_character) - 1
757
+
758
+ return dict_character
759
+
760
+ def __call__(self, data):
761
+ text = data["label"]
762
+ text = self.encode(text)
763
+ if text is None:
764
+ return None
765
+ if len(text) >= self.max_text_len - 1:
766
+ return None
767
+ data["length"] = np.array(len(text))
768
+ target = [self.start_idx] + text + [self.end_idx]
769
+ padded_text = [self.padding_idx for _ in range(self.max_text_len)]
770
+
771
+ padded_text[: len(target)] = target
772
+ data["label"] = np.array(padded_text)
773
+ return data
774
+
775
+ def get_ignored_tokens(self):
776
+ return [self.padding_idx]
777
+
778
+
779
+ class PRENLabelEncode(BaseRecLabelEncode):
780
+ def __init__(
781
+ self, max_text_length, character_dict_path, use_space_char=False, **kwargs
782
+ ):
783
+ super(PRENLabelEncode, self).__init__(
784
+ max_text_length, character_dict_path, use_space_char
785
+ )
786
+
787
+ def add_special_char(self, dict_character):
788
+ padding_str = "<PAD>" # 0
789
+ end_str = "<EOS>" # 1
790
+ unknown_str = "<UNK>" # 2
791
+
792
+ dict_character = [padding_str, end_str, unknown_str] + dict_character
793
+ self.padding_idx = 0
794
+ self.end_idx = 1
795
+ self.unknown_idx = 2
796
+
797
+ return dict_character
798
+
799
+ def encode(self, text):
800
+ if len(text) == 0 or len(text) >= self.max_text_len:
801
+ return None
802
+ if self.lower:
803
+ text = text.lower()
804
+ text_list = []
805
+ for char in text:
806
+ if char not in self.dict:
807
+ text_list.append(self.unknown_idx)
808
+ else:
809
+ text_list.append(self.dict[char])
810
+ text_list.append(self.end_idx)
811
+ if len(text_list) < self.max_text_len:
812
+ text_list += [self.padding_idx] * (self.max_text_len - len(text_list))
813
+ return text_list
814
+
815
+ def __call__(self, data):
816
+ text = data["label"]
817
+ encoded_text = self.encode(text)
818
+ if encoded_text is None:
819
+ return None
820
+ data["label"] = np.array(encoded_text)
821
+ return data
822
+
823
+
824
+ class VQATokenLabelEncode(object):
825
+ """
826
+ Label encode for NLP VQA methods
827
+ """
828
+
829
+ def __init__(
830
+ self,
831
+ class_path,
832
+ contains_re=False,
833
+ add_special_ids=False,
834
+ algorithm="LayoutXLM",
835
+ infer_mode=False,
836
+ ocr_engine=None,
837
+ **kwargs
838
+ ):
839
+ super(VQATokenLabelEncode, self).__init__()
840
+ from paddlenlp.transformers import (
841
+ LayoutLMTokenizer,
842
+ LayoutLMv2Tokenizer,
843
+ LayoutXLMTokenizer,
844
+ )
845
+
846
+ from ppocr.utils.utility import load_vqa_bio_label_maps
847
+
848
+ tokenizer_dict = {
849
+ "LayoutXLM": {
850
+ "class": LayoutXLMTokenizer,
851
+ "pretrained_model": "layoutxlm-base-uncased",
852
+ },
853
+ "LayoutLM": {
854
+ "class": LayoutLMTokenizer,
855
+ "pretrained_model": "layoutlm-base-uncased",
856
+ },
857
+ "LayoutLMv2": {
858
+ "class": LayoutLMv2Tokenizer,
859
+ "pretrained_model": "layoutlmv2-base-uncased",
860
+ },
861
+ }
862
+ self.contains_re = contains_re
863
+ tokenizer_config = tokenizer_dict[algorithm]
864
+ self.tokenizer = tokenizer_config["class"].from_pretrained(
865
+ tokenizer_config["pretrained_model"]
866
+ )
867
+ self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
868
+ self.add_special_ids = add_special_ids
869
+ self.infer_mode = infer_mode
870
+ self.ocr_engine = ocr_engine
871
+
872
+ def __call__(self, data):
873
+ # load bbox and label info
874
+ ocr_info = self._load_ocr_info(data)
875
+
876
+ height, width, _ = data["image"].shape
877
+
878
+ words_list = []
879
+ bbox_list = []
880
+ input_ids_list = []
881
+ token_type_ids_list = []
882
+ segment_offset_id = []
883
+ gt_label_list = []
884
+
885
+ entities = []
886
+
887
+ # for re
888
+ train_re = self.contains_re and not self.infer_mode
889
+ if train_re:
890
+ relations = []
891
+ id2label = {}
892
+ entity_id_to_index_map = {}
893
+ empty_entity = set()
894
+
895
+ data["ocr_info"] = copy.deepcopy(ocr_info)
896
+
897
+ for info in ocr_info:
898
+ if train_re:
899
+ # for re
900
+ if len(info["text"]) == 0:
901
+ empty_entity.add(info["id"])
902
+ continue
903
+ id2label[info["id"]] = info["label"]
904
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
905
+ # smooth_box
906
+ bbox = self._smooth_box(info["bbox"], height, width)
907
+
908
+ text = info["text"]
909
+ encode_res = self.tokenizer.encode(
910
+ text, pad_to_max_seq_len=False, return_attention_mask=True
911
+ )
912
+
913
+ if not self.add_special_ids:
914
+ # TODO: use tok.all_special_ids to remove
915
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
916
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
917
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
918
+ # parse label
919
+ if not self.infer_mode:
920
+ label = info["label"]
921
+ gt_label = self._parse_label(label, encode_res)
922
+
923
+ # construct entities for re
924
+ if train_re:
925
+ if gt_label[0] != self.label2id_map["O"]:
926
+ entity_id_to_index_map[info["id"]] = len(entities)
927
+ label = label.upper()
928
+ entities.append(
929
+ {
930
+ "start": len(input_ids_list),
931
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
932
+ "label": label.upper(),
933
+ }
934
+ )
935
+ else:
936
+ entities.append(
937
+ {
938
+ "start": len(input_ids_list),
939
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
940
+ "label": "O",
941
+ }
942
+ )
943
+ input_ids_list.extend(encode_res["input_ids"])
944
+ token_type_ids_list.extend(encode_res["token_type_ids"])
945
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
946
+ words_list.append(text)
947
+ segment_offset_id.append(len(input_ids_list))
948
+ if not self.infer_mode:
949
+ gt_label_list.extend(gt_label)
950
+
951
+ data["input_ids"] = input_ids_list
952
+ data["token_type_ids"] = token_type_ids_list
953
+ data["bbox"] = bbox_list
954
+ data["attention_mask"] = [1] * len(input_ids_list)
955
+ data["labels"] = gt_label_list
956
+ data["segment_offset_id"] = segment_offset_id
957
+ data["tokenizer_params"] = dict(
958
+ padding_side=self.tokenizer.padding_side,
959
+ pad_token_type_id=self.tokenizer.pad_token_type_id,
960
+ pad_token_id=self.tokenizer.pad_token_id,
961
+ )
962
+ data["entities"] = entities
963
+
964
+ if train_re:
965
+ data["relations"] = relations
966
+ data["id2label"] = id2label
967
+ data["empty_entity"] = empty_entity
968
+ data["entity_id_to_index_map"] = entity_id_to_index_map
969
+ return data
970
+
971
+ def _load_ocr_info(self, data):
972
+ def trans_poly_to_bbox(poly):
973
+ x1 = np.min([p[0] for p in poly])
974
+ x2 = np.max([p[0] for p in poly])
975
+ y1 = np.min([p[1] for p in poly])
976
+ y2 = np.max([p[1] for p in poly])
977
+ return [x1, y1, x2, y2]
978
+
979
+ if self.infer_mode:
980
+ ocr_result = self.ocr_engine.ocr(data["image"], cls=False)
981
+ ocr_info = []
982
+ for res in ocr_result:
983
+ ocr_info.append(
984
+ {
985
+ "text": res[1][0],
986
+ "bbox": trans_poly_to_bbox(res[0]),
987
+ "poly": res[0],
988
+ }
989
+ )
990
+ return ocr_info
991
+ else:
992
+ info = data["label"]
993
+ # read text info
994
+ info_dict = json.loads(info)
995
+ return info_dict["ocr_info"]
996
+
997
+ def _smooth_box(self, bbox, height, width):
998
+ bbox[0] = int(bbox[0] * 1000.0 / width)
999
+ bbox[2] = int(bbox[2] * 1000.0 / width)
1000
+ bbox[1] = int(bbox[1] * 1000.0 / height)
1001
+ bbox[3] = int(bbox[3] * 1000.0 / height)
1002
+ return bbox
1003
+
1004
+ def _parse_label(self, label, encode_res):
1005
+ gt_label = []
1006
+ if label.lower() == "other":
1007
+ gt_label.extend([0] * len(encode_res["input_ids"]))
1008
+ else:
1009
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
1010
+ gt_label.extend(
1011
+ [self.label2id_map[("i-" + label).upper()]]
1012
+ * (len(encode_res["input_ids"]) - 1)
1013
+ )
1014
+ return gt_label
1015
+
1016
+
1017
+ class MultiLabelEncode(BaseRecLabelEncode):
1018
+ def __init__(
1019
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
1020
+ ):
1021
+ super(MultiLabelEncode, self).__init__(
1022
+ max_text_length, character_dict_path, use_space_char
1023
+ )
1024
+
1025
+ self.ctc_encode = CTCLabelEncode(
1026
+ max_text_length, character_dict_path, use_space_char, **kwargs
1027
+ )
1028
+ self.sar_encode = SARLabelEncode(
1029
+ max_text_length, character_dict_path, use_space_char, **kwargs
1030
+ )
1031
+
1032
+ def __call__(self, data):
1033
+
1034
+ data_ctc = copy.deepcopy(data)
1035
+ data_sar = copy.deepcopy(data)
1036
+ data_out = dict()
1037
+ data_out["img_path"] = data.get("img_path", None)
1038
+ data_out["image"] = data["image"]
1039
+ ctc = self.ctc_encode.__call__(data_ctc)
1040
+ sar = self.sar_encode.__call__(data_sar)
1041
+ if ctc is None or sar is None:
1042
+ return None
1043
+ data_out["label_ctc"] = ctc["label"]
1044
+ data_out["label_sar"] = sar["label"]
1045
+ data_out["length"] = ctc["length"]
1046
+ return data_out
ocr/ppocr/data/imaug/make_border_map.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ np.seterr(divide="ignore", invalid="ignore")
7
+ import warnings
8
+
9
+ import pyclipper
10
+ from shapely.geometry import Polygon
11
+
12
+ warnings.simplefilter("ignore")
13
+
14
+ __all__ = ["MakeBorderMap"]
15
+
16
+
17
+ class MakeBorderMap(object):
18
+ def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7, **kwargs):
19
+ self.shrink_ratio = shrink_ratio
20
+ self.thresh_min = thresh_min
21
+ self.thresh_max = thresh_max
22
+
23
+ def __call__(self, data):
24
+
25
+ img = data["image"]
26
+ text_polys = data["polys"]
27
+ ignore_tags = data["ignore_tags"]
28
+
29
+ canvas = np.zeros(img.shape[:2], dtype=np.float32)
30
+ mask = np.zeros(img.shape[:2], dtype=np.float32)
31
+
32
+ for i in range(len(text_polys)):
33
+ if ignore_tags[i]:
34
+ continue
35
+ self.draw_border_map(text_polys[i], canvas, mask=mask)
36
+ canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
37
+
38
+ data["threshold_map"] = canvas
39
+ data["threshold_mask"] = mask
40
+ return data
41
+
42
+ def draw_border_map(self, polygon, canvas, mask):
43
+ polygon = np.array(polygon)
44
+ assert polygon.ndim == 2
45
+ assert polygon.shape[1] == 2
46
+
47
+ polygon_shape = Polygon(polygon)
48
+ if polygon_shape.area <= 0:
49
+ return
50
+ distance = (
51
+ polygon_shape.area
52
+ * (1 - np.power(self.shrink_ratio, 2))
53
+ / polygon_shape.length
54
+ )
55
+ subject = [tuple(l) for l in polygon]
56
+ padding = pyclipper.PyclipperOffset()
57
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
58
+
59
+ padded_polygon = np.array(padding.Execute(distance)[0])
60
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
61
+
62
+ xmin = padded_polygon[:, 0].min()
63
+ xmax = padded_polygon[:, 0].max()
64
+ ymin = padded_polygon[:, 1].min()
65
+ ymax = padded_polygon[:, 1].max()
66
+ width = xmax - xmin + 1
67
+ height = ymax - ymin + 1
68
+
69
+ polygon[:, 0] = polygon[:, 0] - xmin
70
+ polygon[:, 1] = polygon[:, 1] - ymin
71
+
72
+ xs = np.broadcast_to(
73
+ np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
74
+ )
75
+ ys = np.broadcast_to(
76
+ np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
77
+ )
78
+
79
+ distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
80
+ for i in range(polygon.shape[0]):
81
+ j = (i + 1) % polygon.shape[0]
82
+ absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
83
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
84
+ distance_map = distance_map.min(axis=0)
85
+
86
+ xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
87
+ xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
88
+ ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
89
+ ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
90
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
91
+ 1
92
+ - distance_map[
93
+ ymin_valid - ymin : ymax_valid - ymax + height,
94
+ xmin_valid - xmin : xmax_valid - xmax + width,
95
+ ],
96
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
97
+ )
98
+
99
+ def _distance(self, xs, ys, point_1, point_2):
100
+ """
101
+ compute the distance from point to a line
102
+ ys: coordinates in the first axis
103
+ xs: coordinates in the second axis
104
+ point_1, point_2: (x, y), the end of the line
105
+ """
106
+ height, width = xs.shape[:2]
107
+ square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
108
+ square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
109
+ square_distance = np.square(point_1[0] - point_2[0]) + np.square(
110
+ point_1[1] - point_2[1]
111
+ )
112
+
113
+ cosin = (square_distance - square_distance_1 - square_distance_2) / (
114
+ 2 * np.sqrt(square_distance_1 * square_distance_2)
115
+ )
116
+ square_sin = 1 - np.square(cosin)
117
+ square_sin = np.nan_to_num(square_sin)
118
+ result = np.sqrt(
119
+ square_distance_1 * square_distance_2 * square_sin / square_distance
120
+ )
121
+
122
+ result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
123
+ cosin < 0
124
+ ]
125
+ # self.extend_line(point_1, point_2, result)
126
+ return result
127
+
128
+ def extend_line(self, point_1, point_2, result, shrink_ratio):
129
+ ex_point_1 = (
130
+ int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
131
+ int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))),
132
+ )
133
+ cv2.line(
134
+ result,
135
+ tuple(ex_point_1),
136
+ tuple(point_1),
137
+ 4096.0,
138
+ 1,
139
+ lineType=cv2.LINE_AA,
140
+ shift=0,
141
+ )
142
+ ex_point_2 = (
143
+ int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
144
+ int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))),
145
+ )
146
+ cv2.line(
147
+ result,
148
+ tuple(ex_point_2),
149
+ tuple(point_2),
150
+ 4096.0,
151
+ 1,
152
+ lineType=cv2.LINE_AA,
153
+ shift=0,
154
+ )
155
+ return ex_point_1, ex_point_2
ocr/ppocr/data/imaug/make_pse_gt.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import pyclipper
6
+ from shapely.geometry import Polygon
7
+
8
+ __all__ = ["MakePseGt"]
9
+
10
+
11
+ class MakePseGt(object):
12
+ def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
13
+ self.kernel_num = kernel_num
14
+ self.min_shrink_ratio = min_shrink_ratio
15
+ self.size = size
16
+
17
+ def __call__(self, data):
18
+
19
+ image = data["image"]
20
+ text_polys = data["polys"]
21
+ ignore_tags = data["ignore_tags"]
22
+
23
+ h, w, _ = image.shape
24
+ short_edge = min(h, w)
25
+ if short_edge < self.size:
26
+ # keep short_size >= self.size
27
+ scale = self.size / short_edge
28
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
29
+ text_polys *= scale
30
+
31
+ gt_kernels = []
32
+ for i in range(1, self.kernel_num + 1):
33
+ # s1->sn, from big to small
34
+ rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
35
+ text_kernel, ignore_tags = self.generate_kernel(
36
+ image.shape[0:2], rate, text_polys, ignore_tags
37
+ )
38
+ gt_kernels.append(text_kernel)
39
+
40
+ training_mask = np.ones(image.shape[0:2], dtype="uint8")
41
+ for i in range(text_polys.shape[0]):
42
+ if ignore_tags[i]:
43
+ cv2.fillPoly(
44
+ training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0
45
+ )
46
+
47
+ gt_kernels = np.array(gt_kernels)
48
+ gt_kernels[gt_kernels > 0] = 1
49
+
50
+ data["image"] = image
51
+ data["polys"] = text_polys
52
+ data["gt_kernels"] = gt_kernels[0:]
53
+ data["gt_text"] = gt_kernels[0]
54
+ data["mask"] = training_mask.astype("float32")
55
+ return data
56
+
57
+ def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
58
+ """
59
+ Refer to part of the code:
60
+ https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
61
+ """
62
+
63
+ h, w = img_size
64
+ text_kernel = np.zeros((h, w), dtype=np.float32)
65
+ for i, poly in enumerate(text_polys):
66
+ polygon = Polygon(poly)
67
+ distance = (
68
+ polygon.area
69
+ * (1 - shrink_ratio * shrink_ratio)
70
+ / (polygon.length + 1e-6)
71
+ )
72
+ subject = [tuple(l) for l in poly]
73
+ pco = pyclipper.PyclipperOffset()
74
+ pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
75
+ shrinked = np.array(pco.Execute(-distance))
76
+
77
+ if len(shrinked) == 0 or shrinked.size == 0:
78
+ if ignore_tags is not None:
79
+ ignore_tags[i] = True
80
+ continue
81
+ try:
82
+ shrinked = np.array(shrinked[0]).reshape(-1, 2)
83
+ except:
84
+ if ignore_tags is not None:
85
+ ignore_tags[i] = True
86
+ continue
87
+ cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
88
+ return text_kernel, ignore_tags
ocr/ppocr/data/imaug/make_shrink_map.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import pyclipper
6
+ from shapely.geometry import Polygon
7
+
8
+ __all__ = ["MakeShrinkMap"]
9
+
10
+
11
+ class MakeShrinkMap(object):
12
+ r"""
13
+ Making binary mask from detection data with ICDAR format.
14
+ Typically following the process of class `MakeICDARData`.
15
+ """
16
+
17
+ def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
18
+ self.min_text_size = min_text_size
19
+ self.shrink_ratio = shrink_ratio
20
+
21
+ def __call__(self, data):
22
+ image = data["image"]
23
+ text_polys = data["polys"]
24
+ ignore_tags = data["ignore_tags"]
25
+
26
+ h, w = image.shape[:2]
27
+ text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
28
+ gt = np.zeros((h, w), dtype=np.float32)
29
+ mask = np.ones((h, w), dtype=np.float32)
30
+ for i in range(len(text_polys)):
31
+ polygon = text_polys[i]
32
+ height = max(polygon[:, 1]) - min(polygon[:, 1])
33
+ width = max(polygon[:, 0]) - min(polygon[:, 0])
34
+ if ignore_tags[i] or min(height, width) < self.min_text_size:
35
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
36
+ ignore_tags[i] = True
37
+ else:
38
+ polygon_shape = Polygon(polygon)
39
+ subject = [tuple(l) for l in polygon]
40
+ padding = pyclipper.PyclipperOffset()
41
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
42
+ shrinked = []
43
+
44
+ # Increase the shrink ratio every time we get multiple polygon returned back
45
+ possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio)
46
+ np.append(possible_ratios, 1)
47
+ # print(possible_ratios)
48
+ for ratio in possible_ratios:
49
+ # print(f"Change shrink ratio to {ratio}")
50
+ distance = (
51
+ polygon_shape.area
52
+ * (1 - np.power(ratio, 2))
53
+ / polygon_shape.length
54
+ )
55
+ shrinked = padding.Execute(-distance)
56
+ if len(shrinked) == 1:
57
+ break
58
+
59
+ if shrinked == []:
60
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
61
+ ignore_tags[i] = True
62
+ continue
63
+
64
+ for each_shirnk in shrinked:
65
+ shirnk = np.array(each_shirnk).reshape(-1, 2)
66
+ cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
67
+
68
+ data["shrink_map"] = gt
69
+ data["shrink_mask"] = mask
70
+ return data
71
+
72
+ def validate_polygons(self, polygons, ignore_tags, h, w):
73
+ """
74
+ polygons (numpy.array, required): of shape (num_instances, num_points, 2)
75
+ """
76
+ if len(polygons) == 0:
77
+ return polygons, ignore_tags
78
+ assert len(polygons) == len(ignore_tags)
79
+ for polygon in polygons:
80
+ polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
81
+ polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
82
+
83
+ for i in range(len(polygons)):
84
+ area = self.polygon_area(polygons[i])
85
+ if abs(area) < 1:
86
+ ignore_tags[i] = True
87
+ if area > 0:
88
+ polygons[i] = polygons[i][::-1, :]
89
+ return polygons, ignore_tags
90
+
91
+ def polygon_area(self, polygon):
92
+ """
93
+ compute polygon area
94
+ """
95
+ area = 0
96
+ q = polygon[-1]
97
+ for p in polygon:
98
+ area += p[0] * q[1] - p[1] * q[0]
99
+ q = p
100
+ return area / 2.0
ocr/ppocr/data/imaug/operators.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+
3
+ import math
4
+ import sys
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import six
9
+
10
+
11
+ class DecodeImage(object):
12
+ """decode image"""
13
+
14
+ def __init__(
15
+ self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs
16
+ ):
17
+ self.img_mode = img_mode
18
+ self.channel_first = channel_first
19
+ self.ignore_orientation = ignore_orientation
20
+
21
+ def __call__(self, data):
22
+ img = data["image"]
23
+ if six.PY2:
24
+ assert (
25
+ type(img) is str and len(img) > 0
26
+ ), "invalid input 'img' in DecodeImage"
27
+ else:
28
+ assert (
29
+ type(img) is bytes and len(img) > 0
30
+ ), "invalid input 'img' in DecodeImage"
31
+ img = np.frombuffer(img, dtype="uint8")
32
+ if self.ignore_orientation:
33
+ img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
34
+ else:
35
+ img = cv2.imdecode(img, 1)
36
+ if img is None:
37
+ return None
38
+ if self.img_mode == "GRAY":
39
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
40
+ elif self.img_mode == "RGB":
41
+ assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
42
+ img = img[:, :, ::-1]
43
+
44
+ if self.channel_first:
45
+ img = img.transpose((2, 0, 1))
46
+
47
+ data["image"] = img
48
+ return data
49
+
50
+
51
+ class NRTRDecodeImage(object):
52
+ """decode image"""
53
+
54
+ def __init__(self, img_mode="RGB", channel_first=False, **kwargs):
55
+ self.img_mode = img_mode
56
+ self.channel_first = channel_first
57
+
58
+ def __call__(self, data):
59
+ img = data["image"]
60
+ if six.PY2:
61
+ assert (
62
+ type(img) is str and len(img) > 0
63
+ ), "invalid input 'img' in DecodeImage"
64
+ else:
65
+ assert (
66
+ type(img) is bytes and len(img) > 0
67
+ ), "invalid input 'img' in DecodeImage"
68
+ img = np.frombuffer(img, dtype="uint8")
69
+
70
+ img = cv2.imdecode(img, 1)
71
+
72
+ if img is None:
73
+ return None
74
+ if self.img_mode == "GRAY":
75
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
76
+ elif self.img_mode == "RGB":
77
+ assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
78
+ img = img[:, :, ::-1]
79
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
80
+ if self.channel_first:
81
+ img = img.transpose((2, 0, 1))
82
+ data["image"] = img
83
+ return data
84
+
85
+
86
+ class NormalizeImage(object):
87
+ """normalize image such as substract mean, divide std"""
88
+
89
+ def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
90
+ if isinstance(scale, str):
91
+ scale = eval(scale)
92
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
93
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
94
+ std = std if std is not None else [0.229, 0.224, 0.225]
95
+
96
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
97
+ self.mean = np.array(mean).reshape(shape).astype("float32")
98
+ self.std = np.array(std).reshape(shape).astype("float32")
99
+
100
+ def __call__(self, data):
101
+ img = data["image"]
102
+ from PIL import Image
103
+
104
+ if isinstance(img, Image.Image):
105
+ img = np.array(img)
106
+ assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
107
+ data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
108
+ return data
109
+
110
+
111
+ class ToCHWImage(object):
112
+ """convert hwc image to chw image"""
113
+
114
+ def __init__(self, **kwargs):
115
+ pass
116
+
117
+ def __call__(self, data):
118
+ img = data["image"]
119
+ from PIL import Image
120
+
121
+ if isinstance(img, Image.Image):
122
+ img = np.array(img)
123
+ data["image"] = img.transpose((2, 0, 1))
124
+ return data
125
+
126
+
127
+ class Fasttext(object):
128
+ def __init__(self, path="None", **kwargs):
129
+ import fasttext
130
+
131
+ self.fast_model = fasttext.load_model(path)
132
+
133
+ def __call__(self, data):
134
+ label = data["label"]
135
+ fast_label = self.fast_model[label]
136
+ data["fast_label"] = fast_label
137
+ return data
138
+
139
+
140
+ class KeepKeys(object):
141
+ def __init__(self, keep_keys, **kwargs):
142
+ self.keep_keys = keep_keys
143
+
144
+ def __call__(self, data):
145
+ data_list = []
146
+ for key in self.keep_keys:
147
+ data_list.append(data[key])
148
+ return data_list
149
+
150
+
151
+ class Pad(object):
152
+ def __init__(self, size=None, size_div=32, **kwargs):
153
+ if size is not None and not isinstance(size, (int, list, tuple)):
154
+ raise TypeError(
155
+ "Type of target_size is invalid. Now is {}".format(type(size))
156
+ )
157
+ if isinstance(size, int):
158
+ size = [size, size]
159
+ self.size = size
160
+ self.size_div = size_div
161
+
162
+ def __call__(self, data):
163
+
164
+ img = data["image"]
165
+ img_h, img_w = img.shape[0], img.shape[1]
166
+ if self.size:
167
+ resize_h2, resize_w2 = self.size
168
+ assert (
169
+ img_h < resize_h2 and img_w < resize_w2
170
+ ), "(h, w) of target size should be greater than (img_h, img_w)"
171
+ else:
172
+ resize_h2 = max(
173
+ int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
174
+ self.size_div,
175
+ )
176
+ resize_w2 = max(
177
+ int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
178
+ self.size_div,
179
+ )
180
+ img = cv2.copyMakeBorder(
181
+ img,
182
+ 0,
183
+ resize_h2 - img_h,
184
+ 0,
185
+ resize_w2 - img_w,
186
+ cv2.BORDER_CONSTANT,
187
+ value=0,
188
+ )
189
+ data["image"] = img
190
+ return data
191
+
192
+
193
+ class Resize(object):
194
+ def __init__(self, size=(640, 640), **kwargs):
195
+ self.size = size
196
+
197
+ def resize_image(self, img):
198
+ resize_h, resize_w = self.size
199
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
200
+ ratio_h = float(resize_h) / ori_h
201
+ ratio_w = float(resize_w) / ori_w
202
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
203
+ return img, [ratio_h, ratio_w]
204
+
205
+ def __call__(self, data):
206
+ img = data["image"]
207
+ if "polys" in data:
208
+ text_polys = data["polys"]
209
+
210
+ img_resize, [ratio_h, ratio_w] = self.resize_image(img)
211
+ if "polys" in data:
212
+ new_boxes = []
213
+ for box in text_polys:
214
+ new_box = []
215
+ for cord in box:
216
+ new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
217
+ new_boxes.append(new_box)
218
+ data["polys"] = np.array(new_boxes, dtype=np.float32)
219
+ data["image"] = img_resize
220
+ return data
221
+
222
+
223
+ class DetResizeForTest(object):
224
+ def __init__(self, **kwargs):
225
+ super(DetResizeForTest, self).__init__()
226
+ self.resize_type = 0
227
+ if "image_shape" in kwargs:
228
+ self.image_shape = kwargs["image_shape"]
229
+ self.resize_type = 1
230
+ elif "limit_side_len" in kwargs:
231
+ self.limit_side_len = kwargs["limit_side_len"]
232
+ self.limit_type = kwargs.get("limit_type", "min")
233
+ elif "resize_long" in kwargs:
234
+ self.resize_type = 2
235
+ self.resize_long = kwargs.get("resize_long", 960)
236
+ else:
237
+ self.limit_side_len = 736
238
+ self.limit_type = "min"
239
+
240
+ def __call__(self, data):
241
+ img = data["image"]
242
+ src_h, src_w, _ = img.shape
243
+
244
+ if self.resize_type == 0:
245
+ # img, shape = self.resize_image_type0(img)
246
+ img, [ratio_h, ratio_w] = self.resize_image_type0(img)
247
+ elif self.resize_type == 2:
248
+ img, [ratio_h, ratio_w] = self.resize_image_type2(img)
249
+ else:
250
+ # img, shape = self.resize_image_type1(img)
251
+ img, [ratio_h, ratio_w] = self.resize_image_type1(img)
252
+ data["image"] = img
253
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
254
+ return data
255
+
256
+ def resize_image_type1(self, img):
257
+ resize_h, resize_w = self.image_shape
258
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
259
+ ratio_h = float(resize_h) / ori_h
260
+ ratio_w = float(resize_w) / ori_w
261
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
262
+ # return img, np.array([ori_h, ori_w])
263
+ return img, [ratio_h, ratio_w]
264
+
265
+ def resize_image_type0(self, img):
266
+ """
267
+ resize image to a size multiple of 32 which is required by the network
268
+ args:
269
+ img(array): array with shape [h, w, c]
270
+ return(tuple):
271
+ img, (ratio_h, ratio_w)
272
+ """
273
+ limit_side_len = self.limit_side_len
274
+ h, w, c = img.shape
275
+
276
+ # limit the max side
277
+ if self.limit_type == "max":
278
+ if max(h, w) > limit_side_len:
279
+ if h > w:
280
+ ratio = float(limit_side_len) / h
281
+ else:
282
+ ratio = float(limit_side_len) / w
283
+ else:
284
+ ratio = 1.0
285
+ elif self.limit_type == "min":
286
+ if min(h, w) < limit_side_len:
287
+ if h < w:
288
+ ratio = float(limit_side_len) / h
289
+ else:
290
+ ratio = float(limit_side_len) / w
291
+ else:
292
+ ratio = 1.0
293
+ elif self.limit_type == "resize_long":
294
+ ratio = float(limit_side_len) / max(h, w)
295
+ else:
296
+ raise Exception("not support limit type, image ")
297
+ resize_h = int(h * ratio)
298
+ resize_w = int(w * ratio)
299
+
300
+ resize_h = max(int(round(resize_h / 32) * 32), 32)
301
+ resize_w = max(int(round(resize_w / 32) * 32), 32)
302
+
303
+ try:
304
+ if int(resize_w) <= 0 or int(resize_h) <= 0:
305
+ return None, (None, None)
306
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
307
+ except:
308
+ print(img.shape, resize_w, resize_h)
309
+ sys.exit(0)
310
+ ratio_h = resize_h / float(h)
311
+ ratio_w = resize_w / float(w)
312
+ return img, [ratio_h, ratio_w]
313
+
314
+ def resize_image_type2(self, img):
315
+ h, w, _ = img.shape
316
+
317
+ resize_w = w
318
+ resize_h = h
319
+
320
+ if resize_h > resize_w:
321
+ ratio = float(self.resize_long) / resize_h
322
+ else:
323
+ ratio = float(self.resize_long) / resize_w
324
+
325
+ resize_h = int(resize_h * ratio)
326
+ resize_w = int(resize_w * ratio)
327
+
328
+ max_stride = 128
329
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
330
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
331
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
332
+ ratio_h = resize_h / float(h)
333
+ ratio_w = resize_w / float(w)
334
+
335
+ return img, [ratio_h, ratio_w]
336
+
337
+
338
+ class E2EResizeForTest(object):
339
+ def __init__(self, **kwargs):
340
+ super(E2EResizeForTest, self).__init__()
341
+ self.max_side_len = kwargs["max_side_len"]
342
+ self.valid_set = kwargs["valid_set"]
343
+
344
+ def __call__(self, data):
345
+ img = data["image"]
346
+ src_h, src_w, _ = img.shape
347
+ if self.valid_set == "totaltext":
348
+ im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
349
+ img, max_side_len=self.max_side_len
350
+ )
351
+ else:
352
+ im_resized, (ratio_h, ratio_w) = self.resize_image(
353
+ img, max_side_len=self.max_side_len
354
+ )
355
+ data["image"] = im_resized
356
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
357
+ return data
358
+
359
+ def resize_image_for_totaltext(self, im, max_side_len=512):
360
+
361
+ h, w, _ = im.shape
362
+ resize_w = w
363
+ resize_h = h
364
+ ratio = 1.25
365
+ if h * ratio > max_side_len:
366
+ ratio = float(max_side_len) / resize_h
367
+ resize_h = int(resize_h * ratio)
368
+ resize_w = int(resize_w * ratio)
369
+
370
+ max_stride = 128
371
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
372
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
373
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
374
+ ratio_h = resize_h / float(h)
375
+ ratio_w = resize_w / float(w)
376
+ return im, (ratio_h, ratio_w)
377
+
378
+ def resize_image(self, im, max_side_len=512):
379
+ """
380
+ resize image to a size multiple of max_stride which is required by the network
381
+ :param im: the resized image
382
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
383
+ :return: the resized image and the resize ratio
384
+ """
385
+ h, w, _ = im.shape
386
+
387
+ resize_w = w
388
+ resize_h = h
389
+
390
+ # Fix the longer side
391
+ if resize_h > resize_w:
392
+ ratio = float(max_side_len) / resize_h
393
+ else:
394
+ ratio = float(max_side_len) / resize_w
395
+
396
+ resize_h = int(resize_h * ratio)
397
+ resize_w = int(resize_w * ratio)
398
+
399
+ max_stride = 128
400
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
401
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
402
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
403
+ ratio_h = resize_h / float(h)
404
+ ratio_w = resize_w / float(w)
405
+
406
+ return im, (ratio_h, ratio_w)
407
+
408
+
409
+ class KieResize(object):
410
+ def __init__(self, **kwargs):
411
+ super(KieResize, self).__init__()
412
+ self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1]
413
+
414
+ def __call__(self, data):
415
+ img = data["image"]
416
+ points = data["points"]
417
+ src_h, src_w, _ = img.shape
418
+ (
419
+ im_resized,
420
+ scale_factor,
421
+ [ratio_h, ratio_w],
422
+ [new_h, new_w],
423
+ ) = self.resize_image(img)
424
+ resize_points = self.resize_boxes(img, points, scale_factor)
425
+ data["ori_image"] = img
426
+ data["ori_boxes"] = points
427
+ data["points"] = resize_points
428
+ data["image"] = im_resized
429
+ data["shape"] = np.array([new_h, new_w])
430
+ return data
431
+
432
+ def resize_image(self, img):
433
+ norm_img = np.zeros([1024, 1024, 3], dtype="float32")
434
+ scale = [512, 1024]
435
+ h, w = img.shape[:2]
436
+ max_long_edge = max(scale)
437
+ max_short_edge = min(scale)
438
+ scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
439
+ resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(
440
+ h * float(scale_factor) + 0.5
441
+ )
442
+ max_stride = 32
443
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
444
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
445
+ im = cv2.resize(img, (resize_w, resize_h))
446
+ new_h, new_w = im.shape[:2]
447
+ w_scale = new_w / w
448
+ h_scale = new_h / h
449
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
450
+ norm_img[:new_h, :new_w, :] = im
451
+ return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
452
+
453
+ def resize_boxes(self, im, points, scale_factor):
454
+ points = points * scale_factor
455
+ img_shape = im.shape[:2]
456
+ points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
457
+ points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
458
+ return points