vaivskku commited on
Commit
b3befe4
·
1 Parent(s): a4c5668
Files changed (1) hide show
  1. app.py +1063 -0
app.py ADDED
@@ -0,0 +1,1063 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, Pix2StructForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, Pix2StructProcessor, BartConfig,ViTConfig,VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, AutoTokenizer, AutoModel
3
+ from PIL import Image
4
+ import torch
5
+ import warnings
6
+ import re
7
+ import json
8
+ import os
9
+ import numpy as np
10
+ import pandas as pd
11
+ from tqdm import tqdm
12
+ import argparse
13
+ from scipy import optimize
14
+ from typing import Optional
15
+ import dataclasses
16
+ import editdistance
17
+ import itertools
18
+ import sys
19
+ import time
20
+ import logging
21
# Configure console logging for the app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

warnings.filterwarnings('ignore')
# Maximum number of image patches fed to the Pix2Struct processors.
MAX_PATCHES = 512
# Load the models and processor
#device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths to the locally fine-tuned model weight files.
ko_deplot_model_path = './deplot_model_ver_kor_24.7.25_refinetuning_epoch1.bin'
aihub_deplot_model_path='./deplot_k.pt'
t5_model_path = './ke_t5.pt'

# Load first model ko-deplot: base weights from the hub, then overlay the local fine-tune.
processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
model1.load_state_dict(torch.load(ko_deplot_model_path, map_location=device))
model1.to(device)

# Load second model aihub-deplot (plus the ke-t5 model used alongside it).
processor2 = AutoProcessor.from_pretrained("ybelkada/pix2struct-base")
model2 = Pix2StructForConditionalGeneration.from_pretrained("ybelkada/pix2struct-base")
model2.load_state_dict(torch.load(aihub_deplot_model_path, map_location=device))


tokenizer = T5Tokenizer.from_pretrained("KETI-AIR/ke-t5-base")
t5_model = T5ForConditionalGeneration.from_pretrained("KETI-AIR/ke-t5-base")
t5_model.load_state_dict(torch.load(t5_model_path, map_location=device))

model2.to(device)
t5_model.to(device)

# Load third model unichart from a local directory.
unichart_model_path = "./unichart"
model3 = VisionEncoderDecoderModel.from_pretrained(unichart_model_path)
processor3 = DonutProcessor.from_pretrained(unichart_model_path)
# NOTE(review): `device` is re-assigned here with an identical expression — redundant but harmless.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model3.to(device)
# ko-deplot inference helpers
def format_output(prediction):
    """Replace every literal '<0x0A>' newline token in *prediction* with '\n'."""
    newline_token = '<0x0A>'
    return prediction.replace(newline_token, '\n')
66
+
67
+ # First model prediction ko-deplot
68
+ def predict_model1(image):
69
+ images = [image]
70
+ inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
71
+ inputs = {k: v.to(device) for k, v in inputs.items()} # Move to GPU
72
+
73
+ model1.eval()
74
+ with torch.no_grad():
75
+ predictions = model1.generate(**inputs, max_new_tokens=4096)
76
+ outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]
77
+
78
+ formatted_output = format_output(outputs[0])
79
+ return formatted_output
80
+
81
+
82
+ def replace_unk(text):
83
+ # 1. '์ œ๋ชฉ:', '์œ ํ˜•:' ๊ธ€์ž ์•ž์— ์žˆ๋Š” <unk>๋Š” \n๋กœ ๋ฐ”๊ฟˆ
84
+ text = re.sub(r'<unk>(?=์ œ๋ชฉ:|์œ ํ˜•:)', '\n', text)
85
+ # 2. '์„ธ๋กœ ' ๋˜๋Š” '๊ฐ€๋กœ '์™€ '๋Œ€ํ˜•' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ ""๋กœ ๋ฐ”๊ฟˆ
86
+ text = re.sub(r'(?<=์„ธ๋กœ |๊ฐ€๋กœ )<unk>(?=๋Œ€ํ˜•)', '', text)
87
+ # 3. ์ˆซ์ž์™€ ํ…์ŠคํŠธ ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
88
+ text = re.sub(r'(\d)<unk>([^\d])', r'\1\n\2', text)
89
+ # 4. %, ์›, ๊ฑด, ๋ช… ๋’ค์— ๋‚˜์˜ค๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
90
+ text = re.sub(r'(?<=[%์›๊ฑด๋ช…\)])<unk>', '\n', text)
91
+ # 5. ์ˆซ์ž์™€ ์ˆซ์ž ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
92
+ text = re.sub(r'(\d)<unk>(\d)', r'\1\n\2', text)
93
+ # 6. 'ํ˜•'์ด๋ผ๋Š” ๊ธ€์ž์™€ ' |' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
94
+ text = re.sub(r'ํ˜•<unk>(?= \|)', 'ํ˜•\n', text)
95
+ # 7. ๋‚˜๋จธ์ง€ <unk>๋ฅผ ๋ชจ๋‘ ""๋กœ ๋ฐ”๊ฟˆ
96
+ text = text.replace('<unk>', '')
97
+ return text
98
+
99
+ # Second model prediction aihub_deplot
100
+ def predict_model2(image):
101
+ image = image.convert("RGB")
102
+ inputs = processor2(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)
103
+
104
+ flattened_patches = inputs.flattened_patches.to(device)
105
+ attention_mask = inputs.attention_mask.to(device)
106
+
107
+ model2.eval()
108
+ t5_model.eval()
109
+ with torch.no_grad():
110
+ deplot_generated_ids = model2.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=1000)
111
+ generated_datatable = processor2.batch_decode(deplot_generated_ids, skip_special_tokens=False)[0]
112
+ generated_datatable = generated_datatable.replace("<pad>", "<unk>").replace("</s>", "<unk>")
113
+ refined_table = replace_unk(generated_datatable)
114
+ return refined_table
115
+
116
+ def predict_model3(image):
117
+ image=image.convert("RGB")
118
+ input_prompt = "<extract_data_table> <s_answer>"
119
+ decoder_input_ids = processor3.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
120
+ pixel_values = processor3(image, return_tensors="pt").pixel_values
121
+ outputs = model3.generate(
122
+ pixel_values.to(device),
123
+ decoder_input_ids=decoder_input_ids.to(device),
124
+ max_length=model3.decoder.config.max_position_embeddings,
125
+ early_stopping=True,
126
+ pad_token_id=processor3.tokenizer.pad_token_id,
127
+ eos_token_id=processor3.tokenizer.eos_token_id,
128
+ use_cache=True,
129
+ num_beams=4,
130
+ bad_words_ids=[[processor3.tokenizer.unk_token_id]],
131
+ return_dict_in_generate=True,
132
+ )
133
+ sequence = processor3.batch_decode(outputs.sequences)[0]
134
+ sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
135
+ sequence = sequence.split("<s_answer>")[-1].strip()
136
+
137
+ return sequence
138
+ #function for converting aihub dataset labeling json file to ko-deplot data table
139
+ def process_json_file(input_file):
140
+ with open(input_file, 'r', encoding='utf-8') as file:
141
+ data = json.load(file)
142
+
143
+ # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
144
+ chart_type = data['metadata']['chart_sub']
145
+ title = data['annotations'][0]['title']
146
+ x_axis = data['annotations'][0]['axis_label']['x_axis']
147
+ y_axis = data['annotations'][0]['axis_label']['y_axis']
148
+ legend = data['annotations'][0]['legend']
149
+ data_labels = data['annotations'][0]['data_label']
150
+ is_legend = data['annotations'][0]['is_legend']
151
+
152
+ # ์›ํ•˜๋Š” ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜
153
+ formatted_string = f"TITLE | {title} <0x0A> "
154
+ if '๊ฐ€๋กœ' in chart_type:
155
+ if is_legend:
156
+ # ๊ฐ€๋กœ ์ฐจํŠธ ์ฒ˜๋ฆฌ
157
+ formatted_string += " | ".join(legend) + " <0x0A> "
158
+ for i in range(len(y_axis)):
159
+ row = [y_axis[i]]
160
+ for j in range(len(legend)):
161
+ if i < len(data_labels[j]):
162
+ row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
163
+ else:
164
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
165
+ formatted_string += " | ".join(row) + " <0x0A> "
166
+ else:
167
+ # is_legend๊ฐ€ False์ธ ๊ฒฝ์šฐ
168
+ for i in range(len(y_axis)):
169
+ row = [y_axis[i], str(data_labels[0][i])]
170
+ formatted_string += " | ".join(row) + " <0x0A> "
171
+ elif chart_type == "์›ํ˜•":
172
+ # ์›ํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
173
+ if legend:
174
+ used_labels = legend
175
+ else:
176
+ used_labels = x_axis
177
+
178
+ formatted_string += " | ".join(used_labels) + " <0x0A> "
179
+ row = [data_labels[0][i] for i in range(len(used_labels))]
180
+ formatted_string += " | ".join(row) + " <0x0A> "
181
+ elif chart_type == "ํ˜ผํ•ฉํ˜•":
182
+ # ํ˜ผํ•ฉํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
183
+ all_legends = [ann['legend'][0] for ann in data['annotations']]
184
+ formatted_string += " | ".join(all_legends) + " <0x0A> "
185
+
186
+ combined_data = []
187
+ for i in range(len(x_axis)):
188
+ row = [x_axis[i]]
189
+ for ann in data['annotations']:
190
+ if i < len(ann['data_label'][0]):
191
+ row.append(str(ann['data_label'][0][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
192
+ else:
193
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
194
+ combined_data.append(" | ".join(row))
195
+
196
+ formatted_string += " <0x0A> ".join(combined_data) + " <0x0A> "
197
+ else:
198
+ # ๊ธฐํƒ€ ์ฐจํŠธ ์ฒ˜๋ฆฌ
199
+ if is_legend:
200
+ formatted_string += " | ".join(legend) + " <0x0A> "
201
+ for i in range(len(x_axis)):
202
+ row = [x_axis[i]]
203
+ for j in range(len(legend)):
204
+ if i < len(data_labels[j]):
205
+ row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
206
+ else:
207
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
208
+ formatted_string += " | ".join(row) + " <0x0A> "
209
+ else:
210
+ for i in range(len(x_axis)):
211
+ if i < len(data_labels[0]):
212
+ formatted_string += f"{x_axis[i]} | {str(data_labels[0][i])} <0x0A> "
213
+ else:
214
+ formatted_string += f"{x_axis[i]} | <0x0A> " # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
215
+
216
+ # ๋งˆ์ง€๋ง‰ "<0x0A> " ์ œ๊ฑฐ
217
+ formatted_string = formatted_string[:-8]
218
+ return format_output(formatted_string)
219
+
220
+ def chart_data(data):
221
+ datatable = []
222
+ num = len(data)
223
+ for n in range(num):
224
+ title = data[n]['title'] if data[n]['is_title'] else ''
225
+ legend = data[n]['legend'] if data[n]['is_legend'] else ''
226
+ datalabel = data[n]['data_label'] if data[n]['is_datalabel'] else [0]
227
+ unit = data[n]['unit'] if data[n]['is_unit'] else ''
228
+ base = data[n]['base'] if data[n]['is_base'] else ''
229
+ x_axis_title = data[n]['axis_title']['x_axis']
230
+ y_axis_title = data[n]['axis_title']['y_axis']
231
+ x_axis = data[n]['axis_label']['x_axis'] if data[n]['is_axis_label_x_axis'] else [0]
232
+ y_axis = data[n]['axis_label']['y_axis'] if data[n]['is_axis_label_y_axis'] else [0]
233
+
234
+ if len(legend) > 1:
235
+ datalabel = np.array(datalabel).transpose().tolist()
236
+
237
+ datatable.append([title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis])
238
+
239
+ return datatable
240
+
241
+ def datatable(data, chart_type):
242
+ data_table = ''
243
+ num = len(data)
244
+
245
+ if len(data) == 2:
246
+ temp = []
247
+ temp.append(f"๋Œ€์ƒ: {data[0][4]}")
248
+ temp.append(f"์ œ๋ชฉ: {data[0][0]}")
249
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
250
+ temp.append(f"{data[0][5]} | {data[0][1][0]}({data[0][3]}) | {data[1][1][0]}({data[1][3]})")
251
+
252
+ x_axis = data[0][7]
253
+ for idx, x in enumerate(x_axis):
254
+ temp.append(f"{x} | {data[0][2][0][idx]} | {data[1][2][0][idx]}")
255
+
256
+ data_table = '\n'.join(temp)
257
+ else:
258
+ for n in range(num):
259
+ temp = []
260
+
261
+ title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis = data[n]
262
+ legend = [element + f"({unit})" for element in legend]
263
+
264
+ if len(legend) > 1:
265
+ temp.append(f"๋Œ€์ƒ: {base}")
266
+ temp.append(f"์ œ๋ชฉ: {title}")
267
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
268
+ temp.append(f"{x_axis_title} | {' | '.join(legend)}")
269
+
270
+ if chart_type[2] == "์›ํ˜•":
271
+ datalabel = sum(datalabel, [])
272
+ temp.append(f"{' | '.join([str(d) for d in datalabel])}")
273
+ data_table = '\n'.join(temp)
274
+ else:
275
+ axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
276
+ for idx, (x, d) in enumerate(zip(axis, datalabel)):
277
+ temp_d = [str(e) for e in d]
278
+ temp_d = " | ".join(temp_d)
279
+ row = f"{x} | {temp_d}"
280
+ temp.append(row)
281
+ data_table = '\n'.join(temp)
282
+ else:
283
+ temp.append(f"๋Œ€์ƒ: {base}")
284
+ temp.append(f"์ œ๋ชฉ: {title}")
285
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
286
+ temp.append(f"{x_axis_title} | {unit}")
287
+ axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
288
+ datalabel = datalabel[0]
289
+
290
+ for idx, x in enumerate(axis):
291
+ row = f"{x} | {str(datalabel[idx])}"
292
+ temp.append(row)
293
+ data_table = '\n'.join(temp)
294
+
295
+ return data_table
296
+
297
+ #function for converting aihub dataset labeling json file to aihub-deplot data table
298
+ def process_json_file2(input_file):
299
+ with open(input_file, 'r', encoding='utf-8') as file:
300
+ data = json.load(file)
301
+ # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
302
+ chart_multi = data['metadata']['chart_multi']
303
+ chart_main = data['metadata']['chart_main']
304
+ chart_sub = data['metadata']['chart_sub']
305
+ chart_type = [chart_multi, chart_sub, chart_main]
306
+ chart_annotations = data['annotations']
307
+
308
+ charData = chart_data(chart_annotations)
309
+ dataTable = datatable(charData, chart_type)
310
+ return dataTable
311
+
312
+ # RMS
313
+ def _to_float(text): # ๋‹จ์œ„ ๋–ผ๊ณ  ์ˆซ์ž๋งŒ..?
314
+ try:
315
+ if text.endswith("%"):
316
+ # Convert percentages to floats.
317
+ return float(text.rstrip("%")) / 100.0
318
+ else:
319
+ return float(text)
320
+ except ValueError:
321
+ return None
322
+
323
+
324
+ def _get_relative_distance(
325
+ target, prediction, theta = 1.0
326
+ ):
327
+ """Returns min(1, |target-prediction|/|target|)."""
328
+ if not target:
329
+ return int(not prediction)
330
+ distance = min(abs((target - prediction) / target), 1)
331
+ return distance if distance < theta else 1
332
+
333
+ def anls_metric(target: str, prediction: str, theta: float = 0.5):
334
+ edit_distance = editdistance.eval(target, prediction)
335
+ normalize_ld = edit_distance / max(len(target), len(prediction))
336
+ return 1 - normalize_ld if normalize_ld < theta else 0
337
+
338
+ def _permute(values, indexes):
339
+ return tuple(values[i] if i < len(values) else "" for i in indexes)
340
+
341
+
342
+ @dataclasses.dataclass(frozen=True)
343
+ class Table:
344
+ """Helper class for the content of a markdown table."""
345
+
346
+ base: Optional[str] = None
347
+ title: Optional[str] = None
348
+ chartType: Optional[str] = None
349
+ headers: tuple[str, Ellipsis] = dataclasses.field(default_factory=tuple)
350
+ rows: tuple[tuple[str, Ellipsis], Ellipsis] = dataclasses.field(default_factory=tuple)
351
+
352
+ def permuted(self, indexes):
353
+ """Builds a version of the table changing the column order."""
354
+ return Table(
355
+ base=self.base,
356
+ title=self.title,
357
+ chartType=self.chartType,
358
+ headers=_permute(self.headers, indexes),
359
+ rows=tuple(_permute(row, indexes) for row in self.rows),
360
+ )
361
+
362
+ def aligned(
363
+ self, headers, text_theta = 0.5
364
+ ):
365
+ """Builds a column permutation with headers in the most correct order."""
366
+ if len(headers) != len(self.headers):
367
+ raise ValueError(f"Header length {headers} must match {self.headers}.")
368
+ distance = []
369
+ for h2 in self.headers:
370
+ distance.append(
371
+ [
372
+ 1 - anls_metric(h1, h2, text_theta)
373
+ for h1 in headers
374
+ ]
375
+ )
376
+ cost_matrix = np.array(distance)
377
+ row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
378
+ permutation = [idx for _, idx in sorted(zip(col_ind, row_ind))]
379
+ score = (1 - cost_matrix)[permutation[1:], range(1, len(row_ind))].prod()
380
+ return self.permuted(permutation), score
381
+
382
+ def _parse_table(text, transposed = False): # ํ‘œ ์ œ๋ชฉ, ์—ด ์ด๋ฆ„, ํ–‰ ์ฐพ๊ธฐ
383
+ """Builds a table from a markdown representation."""
384
+ lines = text.lower().splitlines()
385
+ if not lines:
386
+ return Table()
387
+
388
+ if lines[0].startswith("๋Œ€์ƒ: "):
389
+ base = lines[0][len("๋Œ€์ƒ: ") :].strip()
390
+ offset = 1 #
391
+ else:
392
+ base = None
393
+ offset = 0
394
+ if lines[1].startswith("์ œ๋ชฉ: "):
395
+ title = lines[1][len("์ œ๋ชฉ: ") :].strip()
396
+ offset = 2 #
397
+ else:
398
+ title = None
399
+ offset = 1
400
+ if lines[2].startswith("์œ ํ˜•: "):
401
+ chartType = lines[2][len("์œ ํ˜•: ") :].strip()
402
+ offset = 3 #
403
+ else:
404
+ chartType = None
405
+
406
+ if len(lines) < offset + 1:
407
+ return Table(base=base, title=title, chartType=chartType)
408
+
409
+ rows = []
410
+ for line in lines[offset:]:
411
+ rows.append(tuple(v.strip() for v in line.split(" | ")))
412
+ if transposed:
413
+ rows = [tuple(row) for row in itertools.zip_longest(*rows, fillvalue="")]
414
+ return Table(base=base, title=title, chartType=chartType, headers=rows[0], rows=tuple(rows[1:]))
415
+
416
+ def _get_table_datapoints(table):
417
+ datapoints = {}
418
+ if table.base is not None:
419
+ datapoints["๋Œ€์ƒ"] = table.base
420
+ if table.title is not None:
421
+ datapoints["์ œ๋ชฉ"] = table.title
422
+ if table.chartType is not None:
423
+ datapoints["์œ ํ˜•"] = table.chartType
424
+ if not table.rows or len(table.headers) <= 1:
425
+ return datapoints
426
+ for row in table.rows:
427
+ for header, cell in zip(table.headers[1:], row[1:]):
428
+ #print(f"{row[0]} {header} >> {cell}")
429
+ datapoints[f"{row[0]} {header}"] = cell #
430
+ return datapoints
431
+
432
+ def _get_datapoint_metric( #
433
+ target,
434
+ prediction,
435
+ text_theta=0.5,
436
+ number_theta=0.1,
437
+ ):
438
+ """Computes a metric that scores how similar two datapoint pairs are."""
439
+ key_metric = anls_metric(
440
+ target[0], prediction[0], text_theta
441
+ )
442
+ pred_float = _to_float(prediction[1]) # ์ˆซ์ž์ธ์ง€ ํ™•์ธ
443
+ target_float = _to_float(target[1])
444
+ if pred_float is not None and target_float:
445
+ return key_metric * (
446
+ 1 - _get_relative_distance(target_float, pred_float, number_theta) # ์ˆซ์ž๋ฉด ์ƒ๋Œ€์  ๊ฑฐ๋ฆฌ๊ฐ’ ๊ณ„์‚ฐ
447
+ )
448
+ elif target[1] == prediction[1]:
449
+ return key_metric
450
+ else:
451
+ return key_metric * anls_metric(
452
+ target[1], prediction[1], text_theta
453
+ )
454
+
455
+ def _table_datapoints_precision_recall_f1( # ์ฐ ๊ณ„์‚ฐ
456
+ target_table,
457
+ prediction_table,
458
+ text_theta = 0.5,
459
+ number_theta = 0.1,
460
+ ):
461
+ """Calculates matching similarity between two tables as dicts."""
462
+ target_datapoints = list(_get_table_datapoints(target_table).items())
463
+ prediction_datapoints = list(_get_table_datapoints(prediction_table).items())
464
+ if not target_datapoints and not prediction_datapoints:
465
+ return 1, 1, 1
466
+ if not target_datapoints:
467
+ return 0, 1, 0
468
+ if not prediction_datapoints:
469
+ return 1, 0, 0
470
+ distance = []
471
+ for t, _ in target_datapoints:
472
+ distance.append(
473
+ [
474
+ 1 - anls_metric(t, p, text_theta)
475
+ for p, _ in prediction_datapoints
476
+ ]
477
+ )
478
+ cost_matrix = np.array(distance)
479
+ row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
480
+ score = 0
481
+ for r, c in zip(row_ind, col_ind):
482
+ score += _get_datapoint_metric(
483
+ target_datapoints[r], prediction_datapoints[c], text_theta, number_theta
484
+ )
485
+ if score == 0:
486
+ return 0, 0, 0
487
+ precision = score / len(prediction_datapoints)
488
+ recall = score / len(target_datapoints)
489
+ return precision, recall, 2 * precision * recall / (precision + recall)
490
+
491
+ def table_datapoints_precision_recall_per_point( # ๊ฐ๊ฐ ๊ณ„์‚ฐ...
492
+ targets,
493
+ predictions,
494
+ text_theta = 0.5,
495
+ number_theta = 0.1,
496
+ ):
497
+ """Computes precisin recall and F1 metrics given two flattened tables.
498
+
499
+ Parses each string into a dictionary of keys and values using row and column
500
+ headers. Then we match keys between the two dicts as long as their relative
501
+ levenshtein distance is below a threshold. Values are also compared with
502
+ ANLS if strings or relative distance if they are numeric.
503
+
504
+ Args:
505
+ targets: list of list of strings.
506
+ predictions: list of strings.
507
+ text_theta: relative edit distance above this is set to the maximum of 1.
508
+ number_theta: relative error rate above this is set to the maximum of 1.
509
+
510
+ Returns:
511
+ Dictionary with per-point precision, recall and F1
512
+ """
513
+ assert len(targets) == len(predictions)
514
+ per_point_scores = {"precision": [], "recall": [], "f1": []}
515
+ for pred, target in zip(predictions, targets):
516
+ all_metrics = []
517
+ for transposed in [True, False]:
518
+ pred_table = _parse_table(pred, transposed=transposed)
519
+ target_table = _parse_table(target, transposed=transposed)
520
+
521
+ all_metrics.extend([_table_datapoints_precision_recall_f1(target_table, pred_table, text_theta, number_theta)])
522
+
523
+ p, r, f = max(all_metrics, key=lambda x: x[-1])
524
+ per_point_scores["precision"].append(p)
525
+ per_point_scores["recall"].append(r)
526
+ per_point_scores["f1"].append(f)
527
+ return per_point_scores
528
+
529
+ def table_datapoints_precision_recall( # deplot ์„ฑ๋Šฅ์ง€ํ‘œ
530
+ targets,
531
+ predictions,
532
+ text_theta = 0.5,
533
+ number_theta = 0.1,
534
+ ):
535
+ """Aggregated version of table_datapoints_precision_recall_per_point().
536
+
537
+ Same as table_datapoints_precision_recall_per_point() but returning aggregated
538
+ scores instead of per-point scores.
539
+
540
+ Args:
541
+ targets: list of list of strings.
542
+ predictions: list of strings.
543
+ text_theta: relative edit distance above this is set to the maximum of 1.
544
+ number_theta: relative error rate above this is set to the maximum of 1.
545
+
546
+ Returns:
547
+ Dictionary with aggregated precision, recall and F1
548
+ """
549
+ score_dict = table_datapoints_precision_recall_per_point(
550
+ targets, predictions, text_theta, number_theta
551
+ )
552
+ return {
553
+ "table_datapoints_precision": (
554
+ sum(score_dict["precision"]) / len(targets)
555
+ ),
556
+ "table_datapoints_recall": (
557
+ sum(score_dict["recall"]) / len(targets)
558
+ ),
559
+ "table_datapoints_f1": sum(score_dict["f1"]) / len(targets),
560
+ }
561
+
562
+ def evaluate_rms(generated_table,label_table):
563
+ predictions=[generated_table]
564
+ targets=[label_table]
565
+ RMS = table_datapoints_precision_recall(targets, predictions)
566
+ return RMS
567
+
568
+ def ko_deplot_convert_to_dataframe(generated_table_str):
569
+ lines = generated_table_str.strip().split(" \n")
570
+ headers=[]
571
+ data=[]
572
+ for i in range(len(lines[1].split(" | "))):
573
+ headers.append(f"{i}")
574
+ for line in lines[1:len(lines)-1]:
575
+ data.append(line.split("| "))
576
+ df = pd.DataFrame(data, columns=headers)
577
+ return df
578
+
579
+ def ko_deplot_convert_to_dataframe2(label_table_str):
580
+ lines = label_table_str.strip().split(" \n")
581
+ headers=[]
582
+ data=[]
583
+ for i in range(len(lines[1].split(" | "))):
584
+ headers.append(f"{i}")
585
+ for line in lines[1:]:
586
+ data.append(line.split("| "))
587
+ df = pd.DataFrame(data, columns=headers)
588
+ return df
589
+
590
+ def aihub_deplot_convert_to_dataframe(table_str):
591
+ lines = table_str.strip().split("\n")
592
+ headers = []
593
+ if(len(lines[3].split(" | "))>len(lines[4].split(" | "))):
594
+ category=lines[3].split(" | ")
595
+ del category[0]
596
+ value=lines[4].split(" | ")
597
+ df=pd.DataFrame({"๋ฒ”๋ก€":category,"๊ฐ’":value})
598
+ return df
599
+ else:
600
+ for i in range(len(lines[3].split(" | "))):
601
+ headers.append(f"{i}")
602
+ data = [line.split(" | ") for line in lines[3:]]
603
+ df = pd.DataFrame(data, columns=headers)
604
+ return df
605
+ def unichart_convert_to_dataframe(table_str):
606
+ lines=table_str.split(" & ")
607
+ headers=[]
608
+ data=[]
609
+ del lines[0]
610
+ for i in range(len(lines[1].split(" | "))):
611
+ headers.append(f"{i}")
612
+ if lines[0]=="value":
613
+ for line in lines[1:]:
614
+ data.append(line.split(" | "))
615
+ else:
616
+ category=lines[0].split(" | ")
617
+ category.insert(0," ")
618
+ data.append(category)
619
+ for line in lines[1:]:
620
+ data.append(line.split(" | "))
621
+ df=pd.DataFrame(data,columns=headers)
622
+ return df
623
+
624
+ class Highlighter:
625
+ def __init__(self):
626
+ self.row = 0
627
+ self.col = 0
628
+
629
+ def compare_and_highlight(self, pred_table_elem, target_table, pred_table_row, props=''):
630
+ if self.row >= pred_table_row:
631
+ self.col += 1
632
+ self.row = 0
633
+ if pred_table_elem != target_table.iloc[self.row, self.col]:
634
+ self.row += 1
635
+ return props
636
+ else:
637
+ self.row += 1
638
+ return None
639
+
640
+ # 1. ๋ฐ์ดํ„ฐ ๋กœ๋“œ
641
+ aihub_deplot_result_df = pd.read_csv('./aihub_deplot_result.csv')
642
+ ko_deplot_result= './ko-deplot-base-pred-epoch1-refinetuning.json'
643
+ unichart_result='./unichart_results.json'
644
+
# 2. Load the checklist of image names to review.
def load_image_checklist(file):
    """Read *file* and return its lines (image names) as a list of strings."""
    with open(file, 'r') as handle:
        return handle.read().splitlines()
651
+
652
+ # 3. ํ˜„์žฌ ์ธ๋ฑ์Šค๋ฅผ ์ถ”์ ํ•˜๊ธฐ ์œ„ํ•œ ๋ณ€์ˆ˜
653
+ current_index = 0
654
+ image_names = []
655
+ def show_image(current_idx):
656
+ image_name=image_names[current_idx]
657
+ image_path = f"./images/{image_name}.jpg"
658
+ if not os.path.exists(image_path):
659
+ raise FileNotFoundError(f"Image file not found: {image_path}")
660
+ return Image.open(image_path)
661
+
662
+ # 4. ๋ฒ„ํŠผ ํด๋ฆญ ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
663
+ def non_real_time_check(file):
664
+ highlighter1 = Highlighter()
665
+ highlighter2 = Highlighter()
666
+ highlighter3 = Highlighter()
667
+ #global image_names, current_index
668
+ #image_names = load_image_checklist(file)
669
+ #current_index = 0
670
+ #image=show_image(current_index)
671
+ file_name =image_names[current_index].replace("Source","Label")
672
+
673
+ json_path="./ko_deplot_labeling_data.json"
674
+ with open(json_path, 'r', encoding='utf-8') as file:
675
+ json_data = json.load(file)
676
+ for key, value in json_data.items():
677
+ if key == file_name:
678
+ ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
679
+ ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].replace("TITLE | ","์ œ๋ชฉ:")
680
+ break
681
+
682
+ ko_deplot_rms_path="./ko_deplot_rms.txt"
683
+ unichart_rms_path="./unichart_rms.txt"
684
+
685
+ json_path="./unichart_labeling_data.json"
686
+ with open(json_path, 'r', encoding='utf-8') as file:
687
+ json_data = json.load(file)
688
+ for entry in json_data:
689
+ if entry["imgname"]==image_names[current_index]+".jpg":
690
+ unichart_labeling_str=entry["label"]
691
+ unichart_label_title=entry["label"].split(" & ")[0].split(" | ")[1]
692
+
693
+ with open(ko_deplot_rms_path,'r',encoding='utf-8') as file:
694
+ lines=file.readlines()
695
+ flag=0
696
+ for line in lines:
697
+ parts=line.strip().split(", ")
698
+ if(len(parts)==2 and parts[0]==image_names[current_index]):
699
+ ko_deplot_rms=parts[1]
700
+ flag=1
701
+ break
702
+ if(flag==0):
703
+ ko_deplot_rms="none"
704
+
705
+ with open(unichart_rms_path,'r',encoding='utf-8') as file:
706
+ lines=file.readlines()
707
+ flag=0
708
+ for line in lines:
709
+ parts=line.strip().split(": ")
710
+ if(len(parts)==2 and parts[0]==image_names[current_index]+".jpg"):
711
+ unichart_rms=parts[1]
712
+ flag=1
713
+ break
714
+ if(flag==0):
715
+ unichart_rms="none"
716
+
717
+
718
+
719
+ ko_deplot_generated_title,ko_deplot_generated_table=ko_deplot_display_results(current_index)
720
+ aihub_deplot_generated_table,aihub_deplot_label_table,aihub_deplot_generated_title,aihub_deplot_label_title=aihub_deplot_display_results(current_index)
721
+ unichart_generated_table,unichart_generated_title=unichart_display_results(current_index)
722
+ #ko_deplot_RMS=evaluate_rms(ko_deplot_generated_table,ko_deplot_labeling_str)
723
+ aihub_deplot_RMS=evaluate_rms(aihub_deplot_generated_table,aihub_deplot_label_table)
724
+
725
+
726
+ if flag == 1:
727
+ value = [round(float(ko_deplot_rms), 1)]
728
+ else:
729
+ value = [0]
730
+
731
+ ko_deplot_score_table = pd.DataFrame({
732
+ 'category': ['f1'],
733
+ 'value': value
734
+ })
735
+
736
+ value=[round(float(unichart_rms)/100,1)]
737
+ unichart_score_table=pd.DataFrame({
738
+ 'category':['f1'],
739
+ 'value':value
740
+ })
741
+ aihub_deplot_score_table=pd.DataFrame({
742
+ 'category': ['precision', 'recall', 'f1'],
743
+ 'value': [
744
+ round(aihub_deplot_RMS['table_datapoints_precision'],1),
745
+ round(aihub_deplot_RMS['table_datapoints_recall'],1),
746
+ round(aihub_deplot_RMS['table_datapoints_f1'],1)
747
+ ]
748
+ })
749
+
750
+ ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
751
+ aihub_deplot_generated_df=aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table)
752
+ unichart_generated_df=unichart_convert_to_dataframe(unichart_generated_table)
753
+ ko_deplot_labeling_df=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
754
+ aihub_deplot_labeling_df=aihub_deplot_convert_to_dataframe(aihub_deplot_label_table)
755
+ unichart_labeling_df=unichart_convert_to_dataframe(unichart_labeling_str)
756
+
757
+ ko_deplot_generated_df_row=ko_deplot_generated_df.shape[0]
758
+ aihub_deplot_generated_df_row=aihub_deplot_generated_df.shape[0]
759
+ unichart_generated_df_row=unichart_generated_df.shape[0]
760
+
761
+
762
+ styled_ko_deplot_table=ko_deplot_generated_df.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_labeling_df,pred_table_row=ko_deplot_generated_df_row,props='color:red')
763
+
764
+
765
+ styled_aihub_deplot_table=aihub_deplot_generated_df.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_labeling_df,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
766
+
767
+
768
+ styled_unichart_table=unichart_generated_df.style.applymap(highlighter3.compare_and_highlight,target_table=unichart_labeling_df,pred_table_row=unichart_generated_df_row,props='color:red')
769
+
770
+ #return ko_deplot_convert_to_dataframe(ko_deplot_generated_table), aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table), aihub_deplot_convert_to_dataframe(label_table), ko_deplot_score_table, aihub_deplot_score_table
771
+ return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(ko deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(styled_aihub_deplot_table,label=aihub_deplot_generated_title+"(aihub deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(styled_unichart_table,label="์ œ๋ชฉ:"+unichart_generated_title+"(unichart ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_labeling_df,label=ko_deplot_label_title+"(ko deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"), gr.DataFrame(aihub_deplot_labeling_df,label=aihub_deplot_label_title+"(aihub deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),gr.DataFrame(unichart_labeling_df,label="์ œ๋ชฉ:"+unichart_label_title+"(unichart ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table
772
+
773
+
774
def ko_deplot_display_results(index):
    """Look up the ko-deplot inference result for the image at *index*.

    Scans the ko_deplot result JSON for the entry whose filename ends with
    the selected image name and returns a (title, table) pair; returns
    None implicitly when no entry matches.
    """
    target = image_names[index] + ".jpg"
    with open(ko_deplot_result, 'r', encoding='utf-8') as fp:
        entries = json.load(fp)
    for record in entries:
        if record['filename'].endswith(target):
            # First line of the table carries the title: "TITLE | <name>".
            title_line = record['table'].split("\n", 1)[0]
            return title_line.replace("TITLE | ", "์ œ๋ชฉ:"), record['table']
def aihub_deplot_display_results(index):
    """Fetch aihub-deplot generated/label tables for the image at *index*.

    Returns (generated_table, label_table, generated_title, label_title).
    BUG FIX: the error branches previously returned only three values while
    the success branch returned four, so any caller unpacking four values
    crashed on failure — error branches now pad to the same arity.
    """
    if index < 0 or index >= len(image_names):
        return "Index out of range", None, None, None
    image_name = image_names[index]
    image_row = aihub_deplot_result_df[aihub_deplot_result_df['data_id'] == image_name]
    if image_row.empty:
        return "No results found for the image", None, None, None
    generated_table = image_row['generated_table'].values[0]
    # Titles live on the second line of each stored table string.
    generated_title = generated_table.split("\n")[1]
    label_table = image_row['label_table'].values[0]
    label_title = label_table.split("\n")[1]
    return generated_table, label_table, generated_title, label_title
def unichart_display_results(index):
    """Return the unichart label string and its title for the image at *index*.

    Returns None implicitly when the image has no entry in the result file.
    """
    wanted = image_names[index] + ".jpg"
    with open(unichart_result, 'r', encoding='utf-8') as fp:
        records = json.load(fp)
    for record in records:
        if record['imgname'] == wanted:
            # Title is the second " | " field of the first " & " record.
            title = record['label'].split(" & ")[0].split(" | ")[1]
            return record['label'], title
def previous_image():
    """Step to the previous image and refresh the nav-button states."""
    global current_index
    if current_index > 0:
        current_index -= 1
    img = show_image(current_index)
    at_start = current_index == 0
    at_end = current_index == len(image_names) - 1
    # Disable "previous" at the first image and "next" at the last one.
    return (img, image_names[current_index],
            gr.update(interactive=not at_start),
            gr.update(interactive=not at_end))
def next_image():
    """Step to the next image and refresh the nav-button states."""
    global current_index
    if current_index < len(image_names) - 1:
        current_index += 1
    img = show_image(current_index)
    at_start = current_index == 0
    at_end = current_index == len(image_names) - 1
    # Disable "previous" at the first image and "next" at the last one.
    return (img, image_names[current_index],
            gr.update(interactive=not at_start),
            gr.update(interactive=not at_end))
def real_time_check(image_file):
    """Run all three chart-to-table models on one uploaded chart image.

    Pipeline: predict tables with ko-deplot, aihub-deplot and unichart,
    load the matching ground-truth tables from the local labeling files,
    score each prediction with RMS precision/recall/F1 via evaluate_rms,
    then return nine outputs: three cell-highlighted prediction
    gr.DataFrames, three label gr.DataFrames and three score tables.

    NOTE(review): if the image has no entry in ko_deplot_labeling_data.json
    or unichart_labeling_data.json, the corresponding *_labeling_str local
    is never bound and a later line raises UnboundLocalError — confirm all
    uploads are covered by the labeling files.
    """
    # One Highlighter per model so each keeps its own comparison state.
    highlighter1 = Highlighter()
    highlighter2 = Highlighter()
    highlighter3=Highlighter()
    image = Image.open(image_file)
    # --- model 1: ko-deplot ---
    result_model1 = predict_model1(image)
    # Drop the trailing element after the final newline of the prediction
    # (presumably an empty/partial line — TODO confirm model output format).
    parts=result_model1.split("\n")
    del parts[-1]
    result_model1="\n".join(parts)
    # First line looks like "TITLE | <title>"; keep only the title part.
    ko_deplot_generated_title=result_model1.split("\n")[0].split(" | ")[1]
    ko_deplot_table=ko_deplot_convert_to_dataframe2(result_model1)

    # --- model 2: aihub-deplot ---
    result_model2 = predict_model2(image)
    # Second line looks like "<key>:<title>"; keep only the title part.
    aihub_deplot_generated_title=result_model2.split("\n")[1].split(":")[1]
    aihub_deplot_table=aihub_deplot_convert_to_dataframe(result_model2)
    # Labeling files are keyed by the "Label" variant of the source name.
    image_base_name = os.path.basename(image_file.name).replace("Source","Label")
    file_name, _ = os.path.splitext(image_base_name)

    # --- model 3: unichart ---
    result_model3=predict_model3(image)
    unichart_table=unichart_convert_to_dataframe(result_model3)
    # Title is the second " | " field of the first " & "-separated record.
    unichart_generated_title=result_model3.split(" & ")[0].split(" | ")[1]

    #aihub_labeling_data_json="./labeling_data/"+file_name+".json"

    # Ground truth for ko-deplot: JSON keyed by file name; "<0x0A>" encodes
    # newlines inside the stored table text.
    json_path="./ko_deplot_labeling_data.json"
    with open(json_path, 'r', encoding='utf-8') as file:
        json_data = json.load(file)
    for key, value in json_data.items():
        if key == file_name:
            ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
            ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].split(" | ")[1]
            break

    ko_deplot_label_table=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)

    #aihub_deplot_labeling_str=process_json_file2(aihub_labeling_data_json)
    #aihub_deplot_label_title=aihub_deplot_labeling_str.split("\n")[1].split(":")[1]

    # Ground truth for aihub-deplot comes from the precomputed result frame,
    # keyed by the "Source" variant of the name.
    image_row = aihub_deplot_result_df[aihub_deplot_result_df['data_id'] == file_name.replace("Label","Source")]
    label_table=""
    label_title=""
    if not image_row.empty:
        label_table = image_row['label_table'].values[0]
        label_title=label_table.split("\n")[1]

    aihub_deplot_label_table=aihub_deplot_convert_to_dataframe(label_table)

    # Ground truth for unichart: list of {"imgname", "label"} records.
    json_path="./unichart_labeling_data.json"
    with open(json_path, 'r', encoding='utf-8') as file:
        json_data = json.load(file)
    for entry in json_data:
        if entry["imgname"]==os.path.basename(image_file.name):
            unichart_labeling_str=entry["label"]
            unichart_label_title=entry["label"].split(" & ")[0].split(" | ")[1]
    unichart_label_table=unichart_convert_to_dataframe(unichart_labeling_str)

    # Score each model against its ground truth. Unichart text is first
    # normalized ("Characteristic"->"Title", "&"->newline) so evaluate_rms
    # sees the same layout as the deplot-style tables.
    ko_deplot_RMS=evaluate_rms(result_model1,ko_deplot_labeling_str)
    aihub_deplot_RMS=evaluate_rms(result_model2,label_table)
    unichart_RMS=evaluate_rms(result_model3.replace("Characteristic","Title").replace("&","\n"),unichart_labeling_str.replace("Characteristic","Title").replace("&","\n"))
    ko_deplot_score_table=pd.DataFrame({
        'category': ['precision', 'recall', 'f1'],
        'value': [
            round(ko_deplot_RMS['table_datapoints_precision'],1),
            round(ko_deplot_RMS['table_datapoints_recall'],1),
            round(ko_deplot_RMS['table_datapoints_f1'],1)
        ]
    })
    aihub_deplot_score_table=pd.DataFrame({
        'category': ['precision', 'recall', 'f1'],
        'value': [
            round(aihub_deplot_RMS['table_datapoints_precision'],1),
            round(aihub_deplot_RMS['table_datapoints_recall'],1),
            round(aihub_deplot_RMS['table_datapoints_f1'],1)
        ]
    })

    unichart_score_table=pd.DataFrame({
        'category': ['precision', 'recall', 'f1'],
        'value': [
            round(unichart_RMS['table_datapoints_precision'],1),
            round(unichart_RMS['table_datapoints_recall'],1),
            round(unichart_RMS['table_datapoints_f1'],1)
        ]
    })

    # Highlight predicted cells that disagree with the ground truth in red.
    ko_deplot_generated_df_row=ko_deplot_table.shape[0]
    aihub_deplot_generated_df_row=aihub_deplot_table.shape[0]
    unichart_generated_df_row=unichart_table.shape[0]
    styled_ko_deplot_table=ko_deplot_table.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_label_table,pred_table_row=ko_deplot_generated_df_row,props='color:red')
    styled_aihub_deplot_table=aihub_deplot_table.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_label_table,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
    styled_unichart_table=unichart_table.style.applymap(highlighter3.compare_and_highlight,target_table=unichart_label_table,pred_table_row=unichart_generated_df_row,props='color:red')
    return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(kodeplot ์ถ”๋ก  ๊ฒฐ๊ณผ)") , gr.DataFrame(styled_aihub_deplot_table,label=aihub_deplot_generated_title+"(aihub deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(styled_unichart_table,label=unichart_generated_title+"(unichart ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_label_table,label=ko_deplot_label_title+"(kodeplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),gr.DataFrame(aihub_deplot_label_table,label=label_title+"(aihub deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),gr.DataFrame(unichart_label_table,label=unichart_label_title+"(unichart ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table
    #return ko_deplot_table,aihub_deplot_table,aihub_deplot_label_table,ko_deplot_score_table,aihub_deplot_score_table
def inference(mode, image_uploader, file_uploader):
    """Dispatch to the single-image or batch-checklist evaluation path.

    Both paths produce the same 9-tuple: three generated tables, three
    label tables and three score tables, so the result is returned as-is.
    """
    if mode == "์ด๋ฏธ์ง€ ์—…๋กœ๋“œ":
        return real_time_check(image_uploader)
    return non_real_time_check(file_uploader)
def interface_selector(selector):
    """Toggle the UI between image-upload and file-upload modes.

    Returns visibility updates for (image_uploader, file_uploader),
    the new mode state, and visibility updates for the file-mode widgets.
    Unknown selectors fall through and return None, as before.
    """
    if selector == "์ด๋ฏธ์ง€ ์—…๋กœ๋“œ":
        return (gr.update(visible=True), gr.update(visible=False),
                gr.State("image_upload"),
                gr.update(visible=False), gr.update(visible=False))
    if selector == "ํŒŒ์ผ ์—…๋กœ๋“œ":
        return (gr.update(visible=False), gr.update(visible=True),
                gr.State("file_upload"),
                gr.update(visible=True), gr.update(visible=True))
def file_selector(selector):
    """Return the precomputed checklist file for the chosen score bucket."""
    # Score bucket -> checklist file on disk.
    checklists = {
        "low score ์ฐจํŠธ": "./new_bottom_20_percent_images.txt",
        "high score ์ฐจํŠธ": "./new_top_20_percent_images.txt",
    }
    path = checklists.get(selector)
    if path is not None:
        return gr.File(path)
def update_results(model_type):
    """Show only the result/score/label tables of the selected model.

    Output order matches the model_type.change wiring:
    (ko_deplot generated, ko_deplot score, aihub generated, aihub score,
     unichart generated, unichart score, ko_deplot label, aihub label,
     unichart label).  Any other value (e.g. "all") shows everything.
    """
    # model -> visibility flags in the output order above.
    visibility = {
        "ko_deplot":    (True, True, False, False, False, False, True, False, False),
        "aihub_deplot": (False, False, True, True, False, False, False, True, False),
        "unichart":     (False, False, False, False, True, True, False, False, True),
    }
    flags = visibility.get(model_type, (True,) * 9)
    return tuple(gr.update(visible=flag) for flag in flags)
def display_image(image_file):
    """Open the uploaded image and return it together with its base name."""
    return Image.open(image_file), os.path.basename(image_file)
def display_image_in_file(image_checklist):
    """Load a new image checklist and display the image at the current index."""
    global image_names, current_index
    image_names = load_image_checklist(image_checklist)
    return show_image(current_index), image_names[current_index]
def update_file_based_on_chart_type(chart_type, all_file_path):
    """Filter the checklist at *all_file_path* down to one chart type.

    Matching lines are written to ./filtered_chart_images.txt and that
    path is returned.  "์ „์ฒด" keeps every line; an unrecognized
    *chart_type* yields an empty file, matching the original if/elif
    chain, which had no default branch.
    """
    # Korean chart-type label -> marker substring inside the file name.
    markers = {
        "์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•": "_horizontal bar_standard",
        "๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•": "_horizontal bar_accumulation",
        "100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•": "_horizontal bar_100per accumulation",
        "์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•": "_vertical bar_standard",
        "๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•": "_vertical bar_accumulation",
        "100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•": "_vertical bar_100per accumulation",
        "์„ ํ˜•": "_line_standard",
        "์›ํ˜•": "_pie_standard",
        "๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•": "_etc_radial",
        "๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•": "_etc_mix",
    }
    with open(all_file_path, 'r', encoding='utf-8') as src:
        lines = src.readlines()

    if chart_type == "์ „์ฒด":
        filtered_lines = lines
    elif chart_type in markers:
        marker = markers[chart_type]
        filtered_lines = [line for line in lines if marker in line]
    else:
        filtered_lines = []

    # Write the filtered checklist to a new file.
    new_file_path = "./filtered_chart_images.txt"
    with open(new_file_path, 'w', encoding='utf-8') as dst:
        dst.writelines(filtered_lines)

    return new_file_path
def handle_chart_type_change(chart_type, all_file_path):
    """Re-filter the checklist by *chart_type* and jump to its first image."""
    global image_names, current_index
    filtered_path = update_file_based_on_chart_type(chart_type, all_file_path)
    image_names = load_image_checklist(filtered_path)
    current_index = 0
    return show_image(current_index), image_names[current_index]
# Gradio UI: left column holds the upload controls, the other columns hold
# per-model prediction / ground-truth / score tables.
with gr.Blocks() as iface:
    # Upload mode shared across callbacks ("image_upload" / "file_upload").
    mode = gr.State("image_upload")
    with gr.Row():
        with gr.Column():
            upload_option = gr.Radio(choices=["์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", "ํŒŒ์ผ ์—…๋กœ๋“œ"], value="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", label="์—…๋กœ๋“œ ์˜ต์…˜")
            # Image and checklist-file upload widgets; only one is visible
            # at a time, toggled by interface_selector.
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            file_uploader = gr.File(file_count="single", file_types=[".txt"], visible=False)
            file_upload_option = gr.Radio(choices=["low score ์ฐจํŠธ", "high score ์ฐจํŠธ"], label="ํŒŒ์ผ ์—…๋กœ๋“œ ์˜ต์…˜", visible=False)
            # BUG FIX: the default value must be one of the choices;
            # "all" was not in the list — the matching choice is "์ „์ฒด".
            chart_type = gr.Dropdown(["์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•", "์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","์„ ํ˜•", "์›ํ˜•", "๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•", "๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•", "์ „์ฒด"], label="Chart Type", value="์ „์ฒด")
            model_type = gr.Dropdown(["ko_deplot","aihub_deplot","unichart","all"], label="model")
            image_displayer = gr.Image(visible=True)
            with gr.Row():
                # BUG FIX: interactive expects a bool; the string "False"
                # is truthy, so the button was never actually disabled.
                pre_button = gr.Button("์ด์ „", interactive=False)
                next_button = gr.Button("๋‹ค์Œ")
            image_name = gr.Text("์ด๋ฏธ์ง€ ์ด๋ฆ„", visible=False)
            inference_button = gr.Button("์ถ”๋ก ")
        with gr.Column():
            # Per-model prediction tables (hidden until a model is chosen).
            ko_deplot_generated_table = gr.DataFrame(visible=False, label="ko-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ")
            aihub_deplot_generated_table = gr.DataFrame(visible=False, label="aihub-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ")
            unichart_generated_table = gr.DataFrame(visible=False, label="unichart ์ถ”๋ก  ๊ฒฐ๊ณผ")
        with gr.Column():
            # Per-model ground-truth tables.
            ko_deplot_label_table = gr.DataFrame(visible=False, label="ko-deplot ์ •๋‹ตํ…Œ์ด๋ธ”")
            aihub_deplot_label_table = gr.DataFrame(visible=False, label="aihub-deplot ์ •๋‹ตํ…Œ์ด๋ธ”")
            unichart_label_table = gr.DataFrame(visible=False, label="unichart ์ •๋‹ตํ…Œ์ด๋ธ”")
        with gr.Column():
            # Per-model precision/recall/F1 score tables.
            ko_deplot_score_table = gr.DataFrame(visible=False, label="ko_deplot ์ ์ˆ˜")
            aihub_deplot_score_table = gr.DataFrame(visible=False, label="aihub_deplot ์ ์ˆ˜")
            unichart_score_table = gr.DataFrame(visible=False, label="unichart ์ ์ˆ˜")
    # --- event wiring ---
    model_type.change(
        update_results,
        inputs=[model_type],
        outputs=[ko_deplot_generated_table, ko_deplot_score_table, aihub_deplot_generated_table, aihub_deplot_score_table, unichart_generated_table, unichart_score_table, ko_deplot_label_table, aihub_deplot_label_table, unichart_label_table]
    )

    upload_option.change(
        interface_selector,
        inputs=[upload_option],
        outputs=[image_uploader, file_uploader, mode, image_name, file_upload_option]
    )

    file_upload_option.change(
        file_selector,
        inputs=[file_upload_option],
        outputs=[file_uploader]
    )

    chart_type.change(handle_chart_type_change, inputs=[chart_type, file_uploader], outputs=[image_displayer, image_name])
    image_uploader.upload(display_image, inputs=[image_uploader], outputs=[image_displayer, image_name])
    file_uploader.change(display_image_in_file, inputs=[file_uploader], outputs=[image_displayer, image_name])
    pre_button.click(previous_image, outputs=[image_displayer, image_name, pre_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, pre_button, next_button])
    inference_button.click(inference, inputs=[upload_option, image_uploader, file_uploader], outputs=[ko_deplot_generated_table, aihub_deplot_generated_table, unichart_generated_table, ko_deplot_label_table, aihub_deplot_label_table, unichart_label_table, ko_deplot_score_table, aihub_deplot_score_table, unichart_score_table])
# Entry point: launch the Gradio app and persist its URLs for other tooling.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    sys.stdout.flush()  # flush stdout so the message appears before launch blocks
    iface.launch(share=True)
    time.sleep(2)  # wait briefly so Gradio has printed its URL
    sys.stdout.flush()  # flush stdout again
    # Record the local and public share URLs Gradio exposes to a log file.
    with open("gradio_url.log", "w") as f:
        print(iface.local_url, file=f)
        print(iface.share_url, file=f)