vaivskku commited on
Commit
36bca9d
ยท
1 Parent(s): f7411a0
Files changed (1) hide show
  1. app.py +936 -0
app.py ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, Pix2StructForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, Pix2StructProcessor
3
+ from PIL import Image
4
+ import torch
5
+ import warnings
6
+ import re
7
+ import json
8
+ import os
9
+ import numpy as np
10
+ import pandas as pd
11
+ from tqdm import tqdm
12
+ import argparse
13
+ from scipy import optimize
14
+ from typing import Optional
15
+ import dataclasses
16
+ import editdistance
17
+ import itertools
18
+ import sys
19
+ import time
20
+ import logging
21
+
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger()
24
+
25
+ warnings.filterwarnings('ignore')
26
+ MAX_PATCHES = 512
27
+ # Load the models and processor
28
+ #device = torch.device("cpu")
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ # Paths to the models
32
+ ko_deplot_model_path = './model_epoch_1_210000.bin'
33
+ aihub_deplot_model_path='./deplot_k.pt'
34
+ t5_model_path = './ke_t5.pt'
35
+
36
+ # Load first model ko-deplot
37
+ processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
38
+ model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
39
+ model1.load_state_dict(torch.load(ko_deplot_model_path, map_location=device))
40
+ model1.to(device)
41
+
42
+ # Load second model aihub-deplot
43
+ processor2 = AutoProcessor.from_pretrained("ybelkada/pix2struct-base")
44
+ model2 = Pix2StructForConditionalGeneration.from_pretrained("ybelkada/pix2struct-base")
45
+ model2.load_state_dict(torch.load(aihub_deplot_model_path, map_location=device))
46
+
47
+
48
+ tokenizer = T5Tokenizer.from_pretrained("KETI-AIR/ke-t5-base")
49
+ t5_model = T5ForConditionalGeneration.from_pretrained("KETI-AIR/ke-t5-base")
50
+ t5_model.load_state_dict(torch.load(t5_model_path, map_location=device))
51
+
52
+ model2.to(device)
53
+ t5_model.to(device)
54
+
55
+
56
+ #ko-deplot ์ถ”๋ก ํ•จ์ˆ˜
57
+ # Function to format output
58
+ def format_output(prediction):
59
+ return prediction.replace('<0x0A>', '\n')
60
+
61
+ # First model prediction ko-deplot
62
+ def predict_model1(image):
63
+ images = [image]
64
+ inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
65
+ inputs = {k: v.to(device) for k, v in inputs.items()} # Move to GPU
66
+
67
+ model1.eval()
68
+ with torch.no_grad():
69
+ predictions = model1.generate(**inputs, max_new_tokens=4096)
70
+ outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]
71
+
72
+ formatted_output = format_output(outputs[0])
73
+ return formatted_output
74
+
75
+
76
+ def replace_unk(text):
77
+ # 1. '์ œ๋ชฉ:', '์œ ํ˜•:' ๊ธ€์ž ์•ž์— ์žˆ๋Š” <unk>๋Š” \n๋กœ ๋ฐ”๊ฟˆ
78
+ text = re.sub(r'<unk>(?=์ œ๋ชฉ:|์œ ํ˜•:)', '\n', text)
79
+ # 2. '์„ธ๋กœ ' ๋˜๋Š” '๊ฐ€๋กœ '์™€ '๋Œ€ํ˜•' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ ""๋กœ ๋ฐ”๊ฟˆ
80
+ text = re.sub(r'(?<=์„ธ๋กœ |๊ฐ€๋กœ )<unk>(?=๋Œ€ํ˜•)', '', text)
81
+ # 3. ์ˆซ์ž์™€ ํ…์ŠคํŠธ ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
82
+ text = re.sub(r'(\d)<unk>([^\d])', r'\1\n\2', text)
83
+ # 4. %, ์›, ๊ฑด, ๋ช… ๋’ค์— ๋‚˜์˜ค๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
84
+ text = re.sub(r'(?<=[%์›๊ฑด๋ช…\)])<unk>', '\n', text)
85
+ # 5. ์ˆซ์ž์™€ ์ˆซ์ž ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
86
+ text = re.sub(r'(\d)<unk>(\d)', r'\1\n\2', text)
87
+ # 6. 'ํ˜•'์ด๋ผ๋Š” ๊ธ€์ž์™€ ' |' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
88
+ text = re.sub(r'ํ˜•<unk>(?= \|)', 'ํ˜•\n', text)
89
+ # 7. ๋‚˜๋จธ์ง€ <unk>๋ฅผ ๋ชจ๋‘ ""๋กœ ๋ฐ”๊ฟˆ
90
+ text = text.replace('<unk>', '')
91
+ return text
92
+
93
+ # Second model prediction aihub_deplot
94
+ def predict_model2(image):
95
+ image = image.convert("RGB")
96
+ inputs = processor2(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)
97
+
98
+ flattened_patches = inputs.flattened_patches.to(device)
99
+ attention_mask = inputs.attention_mask.to(device)
100
+
101
+ model2.eval()
102
+ t5_model.eval()
103
+ with torch.no_grad():
104
+ deplot_generated_ids = model2.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=1000)
105
+ generated_datatable = processor2.batch_decode(deplot_generated_ids, skip_special_tokens=False)[0]
106
+ generated_datatable = generated_datatable.replace("<pad>", "<unk>").replace("</s>", "<unk>")
107
+ refined_table = replace_unk(generated_datatable)
108
+ return refined_table
109
+
110
+ #function for converting aihub dataset labeling json file to ko-deplot data table
111
+ def process_json_file(input_file):
112
+ with open(input_file, 'r', encoding='utf-8') as file:
113
+ data = json.load(file)
114
+
115
+ # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
116
+ chart_type = data['metadata']['chart_sub']
117
+ title = data['annotations'][0]['title']
118
+ x_axis = data['annotations'][0]['axis_label']['x_axis']
119
+ y_axis = data['annotations'][0]['axis_label']['y_axis']
120
+ legend = data['annotations'][0]['legend']
121
+ data_labels = data['annotations'][0]['data_label']
122
+ is_legend = data['annotations'][0]['is_legend']
123
+
124
+ # ์›ํ•˜๋Š” ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜
125
+ formatted_string = f"TITLE | {title} <0x0A> "
126
+ if '๊ฐ€๋กœ' in chart_type:
127
+ if is_legend:
128
+ # ๊ฐ€๋กœ ์ฐจํŠธ ์ฒ˜๋ฆฌ
129
+ formatted_string += " | ".join(legend) + " <0x0A> "
130
+ for i in range(len(y_axis)):
131
+ row = [y_axis[i]]
132
+ for j in range(len(legend)):
133
+ if i < len(data_labels[j]):
134
+ row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
135
+ else:
136
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
137
+ formatted_string += " | ".join(row) + " <0x0A> "
138
+ else:
139
+ # is_legend๊ฐ€ False์ธ ๊ฒฝ์šฐ
140
+ for i in range(len(y_axis)):
141
+ row = [y_axis[i], str(data_labels[0][i])]
142
+ formatted_string += " | ".join(row) + " <0x0A> "
143
+ elif chart_type == "์›ํ˜•":
144
+ # ์›ํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
145
+ if legend:
146
+ used_labels = legend
147
+ else:
148
+ used_labels = x_axis
149
+
150
+ formatted_string += " | ".join(used_labels) + " <0x0A> "
151
+ row = [data_labels[0][i] for i in range(len(used_labels))]
152
+ formatted_string += " | ".join(row) + " <0x0A> "
153
+ elif chart_type == "ํ˜ผํ•ฉํ˜•":
154
+ # ํ˜ผํ•ฉํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
155
+ all_legends = [ann['legend'][0] for ann in data['annotations']]
156
+ formatted_string += " | ".join(all_legends) + " <0x0A> "
157
+
158
+ combined_data = []
159
+ for i in range(len(x_axis)):
160
+ row = [x_axis[i]]
161
+ for ann in data['annotations']:
162
+ if i < len(ann['data_label'][0]):
163
+ row.append(str(ann['data_label'][0][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
164
+ else:
165
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
166
+ combined_data.append(" | ".join(row))
167
+
168
+ formatted_string += " <0x0A> ".join(combined_data) + " <0x0A> "
169
+ else:
170
+ # ๊ธฐํƒ€ ์ฐจํŠธ ์ฒ˜๋ฆฌ
171
+ if is_legend:
172
+ formatted_string += " | ".join(legend) + " <0x0A> "
173
+ for i in range(len(x_axis)):
174
+ row = [x_axis[i]]
175
+ for j in range(len(legend)):
176
+ if i < len(data_labels[j]):
177
+ row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
178
+ else:
179
+ row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
180
+ formatted_string += " | ".join(row) + " <0x0A> "
181
+ else:
182
+ for i in range(len(x_axis)):
183
+ if i < len(data_labels[0]):
184
+ formatted_string += f"{x_axis[i]} | {str(data_labels[0][i])} <0x0A> "
185
+ else:
186
+ formatted_string += f"{x_axis[i]} | <0x0A> " # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
187
+
188
+ # ๋งˆ์ง€๋ง‰ "<0x0A> " ์ œ๊ฑฐ
189
+ formatted_string = formatted_string[:-8]
190
+ return format_output(formatted_string)
191
+
192
+ def chart_data(data):
193
+ datatable = []
194
+ num = len(data)
195
+ for n in range(num):
196
+ title = data[n]['title'] if data[n]['is_title'] else ''
197
+ legend = data[n]['legend'] if data[n]['is_legend'] else ''
198
+ datalabel = data[n]['data_label'] if data[n]['is_datalabel'] else [0]
199
+ unit = data[n]['unit'] if data[n]['is_unit'] else ''
200
+ base = data[n]['base'] if data[n]['is_base'] else ''
201
+ x_axis_title = data[n]['axis_title']['x_axis']
202
+ y_axis_title = data[n]['axis_title']['y_axis']
203
+ x_axis = data[n]['axis_label']['x_axis'] if data[n]['is_axis_label_x_axis'] else [0]
204
+ y_axis = data[n]['axis_label']['y_axis'] if data[n]['is_axis_label_y_axis'] else [0]
205
+
206
+ if len(legend) > 1:
207
+ datalabel = np.array(datalabel).transpose().tolist()
208
+
209
+ datatable.append([title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis])
210
+
211
+ return datatable
212
+
213
+ def datatable(data, chart_type):
214
+ data_table = ''
215
+ num = len(data)
216
+
217
+ if len(data) == 2:
218
+ temp = []
219
+ temp.append(f"๋Œ€์ƒ: {data[0][4]}")
220
+ temp.append(f"์ œ๋ชฉ: {data[0][0]}")
221
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
222
+ temp.append(f"{data[0][5]} | {data[0][1][0]}({data[0][3]}) | {data[1][1][0]}({data[1][3]})")
223
+
224
+ x_axis = data[0][7]
225
+ for idx, x in enumerate(x_axis):
226
+ temp.append(f"{x} | {data[0][2][0][idx]} | {data[1][2][0][idx]}")
227
+
228
+ data_table = '\n'.join(temp)
229
+ else:
230
+ for n in range(num):
231
+ temp = []
232
+
233
+ title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis = data[n]
234
+ legend = [element + f"({unit})" for element in legend]
235
+
236
+ if len(legend) > 1:
237
+ temp.append(f"๋Œ€์ƒ: {base}")
238
+ temp.append(f"์ œ๋ชฉ: {title}")
239
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
240
+ temp.append(f"{x_axis_title} | {' | '.join(legend)}")
241
+
242
+ if chart_type[2] == "์›ํ˜•":
243
+ datalabel = sum(datalabel, [])
244
+ temp.append(f"{' | '.join([str(d) for d in datalabel])}")
245
+ data_table = '\n'.join(temp)
246
+ else:
247
+ axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
248
+ for idx, (x, d) in enumerate(zip(axis, datalabel)):
249
+ temp_d = [str(e) for e in d]
250
+ temp_d = " | ".join(temp_d)
251
+ row = f"{x} | {temp_d}"
252
+ temp.append(row)
253
+ data_table = '\n'.join(temp)
254
+ else:
255
+ temp.append(f"๋Œ€์ƒ: {base}")
256
+ temp.append(f"์ œ๋ชฉ: {title}")
257
+ temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
258
+ temp.append(f"{x_axis_title} | {unit}")
259
+ axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
260
+ datalabel = datalabel[0]
261
+
262
+ for idx, x in enumerate(axis):
263
+ row = f"{x} | {str(datalabel[idx])}"
264
+ temp.append(row)
265
+ data_table = '\n'.join(temp)
266
+
267
+ return data_table
268
+
269
+ #function for converting aihub dataset labeling json file to aihub-deplot data table
270
+ def process_json_file2(input_file):
271
+ with open(input_file, 'r', encoding='utf-8') as file:
272
+ data = json.load(file)
273
+ # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
274
+ chart_multi = data['metadata']['chart_multi']
275
+ chart_main = data['metadata']['chart_main']
276
+ chart_sub = data['metadata']['chart_sub']
277
+ chart_type = [chart_multi, chart_sub, chart_main]
278
+ chart_annotations = data['annotations']
279
+
280
+ charData = chart_data(chart_annotations)
281
+ dataTable = datatable(charData, chart_type)
282
+ return dataTable
283
+
284
+ # RMS
285
+ def _to_float(text): # ๋‹จ์œ„ ๋–ผ๊ณ  ์ˆซ์ž๋งŒ..?
286
+ try:
287
+ if text.endswith("%"):
288
+ # Convert percentages to floats.
289
+ return float(text.rstrip("%")) / 100.0
290
+ else:
291
+ return float(text)
292
+ except ValueError:
293
+ return None
294
+
295
+
296
+ def _get_relative_distance(
297
+ target, prediction, theta = 1.0
298
+ ):
299
+ """Returns min(1, |target-prediction|/|target|)."""
300
+ if not target:
301
+ return int(not prediction)
302
+ distance = min(abs((target - prediction) / target), 1)
303
+ return distance if distance < theta else 1
304
+
305
+ def anls_metric(target: str, prediction: str, theta: float = 0.5):
306
+ edit_distance = editdistance.eval(target, prediction)
307
+ normalize_ld = edit_distance / max(len(target), len(prediction))
308
+ return 1 - normalize_ld if normalize_ld < theta else 0
309
+
310
+ def _permute(values, indexes):
311
+ return tuple(values[i] if i < len(values) else "" for i in indexes)
312
+
313
+
314
+ @dataclasses.dataclass(frozen=True)
315
+ class Table:
316
+ """Helper class for the content of a markdown table."""
317
+
318
+ base: Optional[str] = None
319
+ title: Optional[str] = None
320
+ chartType: Optional[str] = None
321
+ headers: tuple[str, Ellipsis] = dataclasses.field(default_factory=tuple)
322
+ rows: tuple[tuple[str, Ellipsis], Ellipsis] = dataclasses.field(default_factory=tuple)
323
+
324
+ def permuted(self, indexes):
325
+ """Builds a version of the table changing the column order."""
326
+ return Table(
327
+ base=self.base,
328
+ title=self.title,
329
+ chartType=self.chartType,
330
+ headers=_permute(self.headers, indexes),
331
+ rows=tuple(_permute(row, indexes) for row in self.rows),
332
+ )
333
+
334
+ def aligned(
335
+ self, headers, text_theta = 0.5
336
+ ):
337
+ """Builds a column permutation with headers in the most correct order."""
338
+ if len(headers) != len(self.headers):
339
+ raise ValueError(f"Header length {headers} must match {self.headers}.")
340
+ distance = []
341
+ for h2 in self.headers:
342
+ distance.append(
343
+ [
344
+ 1 - anls_metric(h1, h2, text_theta)
345
+ for h1 in headers
346
+ ]
347
+ )
348
+ cost_matrix = np.array(distance)
349
+ row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
350
+ permutation = [idx for _, idx in sorted(zip(col_ind, row_ind))]
351
+ score = (1 - cost_matrix)[permutation[1:], range(1, len(row_ind))].prod()
352
+ return self.permuted(permutation), score
353
+
354
+ def _parse_table(text, transposed = False): # ํ‘œ ์ œ๋ชฉ, ์—ด ์ด๋ฆ„, ํ–‰ ์ฐพ๊ธฐ
355
+ """Builds a table from a markdown representation."""
356
+ lines = text.lower().splitlines()
357
+ if not lines:
358
+ return Table()
359
+
360
+ if lines[0].startswith("๋Œ€์ƒ: "):
361
+ base = lines[0][len("๋Œ€์ƒ: ") :].strip()
362
+ offset = 1 #
363
+ else:
364
+ base = None
365
+ offset = 0
366
+ if lines[1].startswith("์ œ๋ชฉ: "):
367
+ title = lines[1][len("์ œ๋ชฉ: ") :].strip()
368
+ offset = 2 #
369
+ else:
370
+ title = None
371
+ offset = 1
372
+ if lines[2].startswith("์œ ํ˜•: "):
373
+ chartType = lines[2][len("์œ ํ˜•: ") :].strip()
374
+ offset = 3 #
375
+ else:
376
+ chartType = None
377
+
378
+ if len(lines) < offset + 1:
379
+ return Table(base=base, title=title, chartType=chartType)
380
+
381
+ rows = []
382
+ for line in lines[offset:]:
383
+ rows.append(tuple(v.strip() for v in line.split(" | ")))
384
+ if transposed:
385
+ rows = [tuple(row) for row in itertools.zip_longest(*rows, fillvalue="")]
386
+ return Table(base=base, title=title, chartType=chartType, headers=rows[0], rows=tuple(rows[1:]))
387
+
388
+ def _get_table_datapoints(table):
389
+ datapoints = {}
390
+ if table.base is not None:
391
+ datapoints["๋Œ€์ƒ"] = table.base
392
+ if table.title is not None:
393
+ datapoints["์ œ๋ชฉ"] = table.title
394
+ if table.chartType is not None:
395
+ datapoints["์œ ํ˜•"] = table.chartType
396
+ if not table.rows or len(table.headers) <= 1:
397
+ return datapoints
398
+ for row in table.rows:
399
+ for header, cell in zip(table.headers[1:], row[1:]):
400
+ #print(f"{row[0]} {header} >> {cell}")
401
+ datapoints[f"{row[0]} {header}"] = cell #
402
+ return datapoints
403
+
404
+ def _get_datapoint_metric( #
405
+ target,
406
+ prediction,
407
+ text_theta=0.5,
408
+ number_theta=0.1,
409
+ ):
410
+ """Computes a metric that scores how similar two datapoint pairs are."""
411
+ key_metric = anls_metric(
412
+ target[0], prediction[0], text_theta
413
+ )
414
+ pred_float = _to_float(prediction[1]) # ์ˆซ์ž์ธ์ง€ ํ™•์ธ
415
+ target_float = _to_float(target[1])
416
+ if pred_float is not None and target_float:
417
+ return key_metric * (
418
+ 1 - _get_relative_distance(target_float, pred_float, number_theta) # ์ˆซ์ž๋ฉด ์ƒ๋Œ€์  ๊ฑฐ๋ฆฌ๊ฐ’ ๊ณ„์‚ฐ
419
+ )
420
+ elif target[1] == prediction[1]:
421
+ return key_metric
422
+ else:
423
+ return key_metric * anls_metric(
424
+ target[1], prediction[1], text_theta
425
+ )
426
+
427
+ def _table_datapoints_precision_recall_f1( # ์ฐ ๊ณ„์‚ฐ
428
+ target_table,
429
+ prediction_table,
430
+ text_theta = 0.5,
431
+ number_theta = 0.1,
432
+ ):
433
+ """Calculates matching similarity between two tables as dicts."""
434
+ target_datapoints = list(_get_table_datapoints(target_table).items())
435
+ prediction_datapoints = list(_get_table_datapoints(prediction_table).items())
436
+ if not target_datapoints and not prediction_datapoints:
437
+ return 1, 1, 1
438
+ if not target_datapoints:
439
+ return 0, 1, 0
440
+ if not prediction_datapoints:
441
+ return 1, 0, 0
442
+ distance = []
443
+ for t, _ in target_datapoints:
444
+ distance.append(
445
+ [
446
+ 1 - anls_metric(t, p, text_theta)
447
+ for p, _ in prediction_datapoints
448
+ ]
449
+ )
450
+ cost_matrix = np.array(distance)
451
+ row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
452
+ score = 0
453
+ for r, c in zip(row_ind, col_ind):
454
+ score += _get_datapoint_metric(
455
+ target_datapoints[r], prediction_datapoints[c], text_theta, number_theta
456
+ )
457
+ if score == 0:
458
+ return 0, 0, 0
459
+ precision = score / len(prediction_datapoints)
460
+ recall = score / len(target_datapoints)
461
+ return precision, recall, 2 * precision * recall / (precision + recall)
462
+
463
+ def table_datapoints_precision_recall_per_point( # ๊ฐ๊ฐ ๊ณ„์‚ฐ...
464
+ targets,
465
+ predictions,
466
+ text_theta = 0.5,
467
+ number_theta = 0.1,
468
+ ):
469
+ """Computes precisin recall and F1 metrics given two flattened tables.
470
+
471
+ Parses each string into a dictionary of keys and values using row and column
472
+ headers. Then we match keys between the two dicts as long as their relative
473
+ levenshtein distance is below a threshold. Values are also compared with
474
+ ANLS if strings or relative distance if they are numeric.
475
+
476
+ Args:
477
+ targets: list of list of strings.
478
+ predictions: list of strings.
479
+ text_theta: relative edit distance above this is set to the maximum of 1.
480
+ number_theta: relative error rate above this is set to the maximum of 1.
481
+
482
+ Returns:
483
+ Dictionary with per-point precision, recall and F1
484
+ """
485
+ assert len(targets) == len(predictions)
486
+ per_point_scores = {"precision": [], "recall": [], "f1": []}
487
+ for pred, target in zip(predictions, targets):
488
+ all_metrics = []
489
+ for transposed in [True, False]:
490
+ pred_table = _parse_table(pred, transposed=transposed)
491
+ target_table = _parse_table(target, transposed=transposed)
492
+
493
+ all_metrics.extend([_table_datapoints_precision_recall_f1(target_table, pred_table, text_theta, number_theta)])
494
+
495
+ p, r, f = max(all_metrics, key=lambda x: x[-1])
496
+ per_point_scores["precision"].append(p)
497
+ per_point_scores["recall"].append(r)
498
+ per_point_scores["f1"].append(f)
499
+ return per_point_scores
500
+
501
+ def table_datapoints_precision_recall( # deplot ์„ฑ๋Šฅ์ง€ํ‘œ
502
+ targets,
503
+ predictions,
504
+ text_theta = 0.5,
505
+ number_theta = 0.1,
506
+ ):
507
+ """Aggregated version of table_datapoints_precision_recall_per_point().
508
+
509
+ Same as table_datapoints_precision_recall_per_point() but returning aggregated
510
+ scores instead of per-point scores.
511
+
512
+ Args:
513
+ targets: list of list of strings.
514
+ predictions: list of strings.
515
+ text_theta: relative edit distance above this is set to the maximum of 1.
516
+ number_theta: relative error rate above this is set to the maximum of 1.
517
+
518
+ Returns:
519
+ Dictionary with aggregated precision, recall and F1
520
+ """
521
+ score_dict = table_datapoints_precision_recall_per_point(
522
+ targets, predictions, text_theta, number_theta
523
+ )
524
+ return {
525
+ "table_datapoints_precision": (
526
+ sum(score_dict["precision"]) / len(targets)
527
+ ),
528
+ "table_datapoints_recall": (
529
+ sum(score_dict["recall"]) / len(targets)
530
+ ),
531
+ "table_datapoints_f1": sum(score_dict["f1"]) / len(targets),
532
+ }
533
+
534
+ def evaluate_rms(generated_table,label_table):
535
+ predictions=[generated_table]
536
+ targets=[label_table]
537
+ RMS = table_datapoints_precision_recall(targets, predictions)
538
+ return RMS
539
+
540
+ def is_float(s):
541
+ try:
542
+ float(s)
543
+ return True
544
+ except ValueError:
545
+ return False
546
+
547
+ def ko_deplot_convert_to_dataframe(table_str):
548
+ lines = table_str.strip().split("\n")
549
+ title=lines[0].split(" | ")[1]
550
+ if(len(lines[1].split(" | "))==len(lines[2].split(" | "))):
551
+ headers=["0","1"]
552
+ if(is_float(lines[1].split(" | ")[1]) or lines[1].split(" | ")[0]==""):
553
+ data=[line.split(" | ") for line in lines[1:]]
554
+ df=pd.DataFrame(data,columns=headers)
555
+ return df
556
+ else:
557
+ category=lines[1].split(" | ")
558
+ value=lines[2].split(" | ")
559
+ df=pd.DataFrame({"๋ฒ”๋ก€":category,"๊ฐ’":value})
560
+ return df
561
+ else:
562
+ headers=[]
563
+ data=[]
564
+ for i in range(len(lines[2].split(" | "))):
565
+ headers.append(f"{i}")
566
+ line1=lines[1].split(" | ")
567
+ line1.insert(0," ")
568
+ data.append(line1)
569
+ for line in lines[2:]:
570
+ data.append(line.split(" | "))
571
+ df = pd.DataFrame(data, columns=headers)
572
+ return df
573
+
574
+ def aihub_deplot_convert_to_dataframe(table_str):
575
+ lines = table_str.strip().split("\n")
576
+ headers = []
577
+ if(len(lines[3].split(" | "))>len(lines[4].split(" | "))):
578
+ category=lines[3].split(" | ")
579
+ del category[0]
580
+ value=lines[4].split(" | ")
581
+ df=pd.DataFrame({"๋ฒ”๋ก€":category,"๊ฐ’":value})
582
+ return df
583
+ else:
584
+ for i in range(len(lines[3].split(" | "))):
585
+ headers.append(f"{i}")
586
+ data = [line.split(" | ") for line in lines[3:]]
587
+ df = pd.DataFrame(data, columns=headers)
588
+ return df
589
+
590
+ class Highlighter:
591
+ def __init__(self):
592
+ self.row = 0
593
+ self.col = 0
594
+
595
+ def compare_and_highlight(self, pred_table_elem, target_table, pred_table_row, props=''):
596
+ if self.row >= pred_table_row:
597
+ self.col += 1
598
+ self.row = 0
599
+ if pred_table_elem != target_table.iloc[self.row, self.col]:
600
+ self.row += 1
601
+ return props
602
+ else:
603
+ self.row += 1
604
+ return None
605
+
606
+ # 1. ๋ฐ์ดํ„ฐ ๋กœ๋“œ
607
+ aihub_deplot_result_df = pd.read_csv('./aihub_deplot_result.csv')
608
+ ko_deplot_result= './ko_deplot_result.json'
609
+
610
+ # 2. ์ฒดํฌํ•ด์•ผ ํ•˜๋Š” ์ด๋ฏธ์ง€ ํŒŒ์ผ ๋กœ๋“œ
611
+ def load_image_checklist(file):
612
+ with open(file, 'r') as f:
613
+ #image_names = [f'"{line.strip()}"' for line in f]
614
+ image_names = f.read().splitlines()
615
+ return image_names
616
+
617
+ # 3. ํ˜„์žฌ ์ธ๋ฑ์Šค๋ฅผ ์ถ”์ ํ•˜๊ธฐ ์œ„ํ•œ ๋ณ€์ˆ˜
618
+ current_index = 0
619
+ image_names = []
620
+ def show_image(current_idx):
621
+ image_name=image_names[current_idx]
622
+ image_path = f"./images/{image_name}.jpg"
623
+ if not os.path.exists(image_path):
624
+ raise FileNotFoundError(f"Image file not found: {image_path}")
625
+ return Image.open(image_path)
626
+
627
+ # 4. ๋ฒ„ํŠผ ํด๋ฆญ ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
628
+ def non_real_time_check(file):
629
+ highlighter1 = Highlighter()
630
+ highlighter2 = Highlighter()
631
+ #global image_names, current_index
632
+ #image_names = load_image_checklist(file)
633
+ #current_index = 0
634
+ #image=show_image(current_index)
635
+ file_name =image_names[current_index].replace("Source","Label")
636
+
637
+ json_path="./ko_deplot_labeling_data.json"
638
+ with open(json_path, 'r', encoding='utf-8') as file:
639
+ json_data = json.load(file)
640
+ for key, value in json_data.items():
641
+ if key == file_name:
642
+ ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
643
+ ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].replace("TITLE | ","์ œ๋ชฉ:")
644
+ break
645
+
646
+ ko_deplot_rms_path="./ko_deplot_rms.txt"
647
+
648
+ with open(ko_deplot_rms_path,'r',encoding='utf-8') as file:
649
+ lines=file.readlines()
650
+ flag=0
651
+ for line in lines:
652
+ parts=line.strip().split(", ")
653
+ if(len(parts)==2 and parts[0]==image_names[current_index]):
654
+ ko_deplot_rms=parts[1]
655
+ flag=1
656
+ break
657
+ if(flag==0):
658
+ ko_deplot_rms="none"
659
+ ko_deplot_generated_title,ko_deplot_generated_table=ko_deplot_display_results(current_index)
660
+ aihub_deplot_generated_table,aihub_deplot_label_table,aihub_deplot_generated_title,aihub_deplot_label_title=aihub_deplot_display_results(current_index)
661
+ #ko_deplot_RMS=evaluate_rms(ko_deplot_generated_table,ko_deplot_labeling_str)
662
+ aihub_deplot_RMS=evaluate_rms(aihub_deplot_generated_table,aihub_deplot_label_table)
663
+
664
+
665
+ if flag == 1:
666
+ value = [round(float(ko_deplot_rms), 1)]
667
+ else:
668
+ value = [0]
669
+
670
+ ko_deplot_score_table = pd.DataFrame({
671
+ 'category': ['f1'],
672
+ 'value': value
673
+ })
674
+
675
+ aihub_deplot_score_table=pd.DataFrame({
676
+ 'category': ['precision', 'recall', 'f1'],
677
+ 'value': [
678
+ round(aihub_deplot_RMS['table_datapoints_precision'],1),
679
+ round(aihub_deplot_RMS['table_datapoints_recall'],1),
680
+ round(aihub_deplot_RMS['table_datapoints_f1'],1)
681
+ ]
682
+ })
683
+ ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
684
+ aihub_deplot_generated_df=aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table)
685
+ ko_deplot_labeling_df=ko_deplot_convert_to_dataframe(ko_deplot_labeling_str)
686
+ aihub_deplot_labeling_df=aihub_deplot_convert_to_dataframe(aihub_deplot_label_table)
687
+
688
+ ko_deplot_generated_df_row=ko_deplot_generated_df.shape[0]
689
+ aihub_deplot_generated_df_row=aihub_deplot_generated_df.shape[0]
690
+
691
+
692
+ styled_ko_deplot_table=ko_deplot_generated_df.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_labeling_df,pred_table_row=ko_deplot_generated_df_row,props='color:red')
693
+
694
+
695
+ styled_aihub_deplot_table=aihub_deplot_generated_df.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_labeling_df,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
696
+
697
+ #return ko_deplot_convert_to_dataframe(ko_deplot_generated_table), aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table), aihub_deplot_convert_to_dataframe(label_table), ko_deplot_score_table, aihub_deplot_score_table
698
+ return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(ko deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(styled_aihub_deplot_table,label=aihub_deplot_generated_title+"(aihub deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_labeling_df,label=ko_deplot_label_title+"(ko deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"), gr.DataFrame(aihub_deplot_labeling_df,label=aihub_deplot_label_title+"(aihub deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table
699
+
700
+ def ko_deplot_display_results(index):
701
+ filename=image_names[index]+".jpg"
702
+ with open(ko_deplot_result, 'r', encoding='utf-8') as f:
703
+ data = json.load(f)
704
+ for entry in data:
705
+ if entry['filename'].endswith(filename):
706
+ #return entry['table']
707
+ parts=entry['table'].split(" \n ",1)
708
+ return parts[0].replace("TITLE | ","์ œ๋ชฉ:"),entry['table']
709
+
710
+ def aihub_deplot_display_results(index):
711
+ if index < 0 or index >= len(image_names):
712
+ return "Index out of range", None, None
713
+ image_name = image_names[index]
714
+ image_row = aihub_deplot_result_df[aihub_deplot_result_df['data_id'] == image_name]
715
+ if not image_row.empty:
716
+ generated_table = image_row['generated_table'].values[0]
717
+ generated_title=generated_table.split("\n")[1]
718
+ label_table = image_row['label_table'].values[0]
719
+ label_title=label_table.split("\n")[1]
720
+ return generated_table, label_table, generated_title, label_title
721
+ else:
722
+ return "No results found for the image", None, None
723
+
724
+ def previous_image():
725
+ global current_index
726
+ if current_index>0:
727
+ current_index-=1
728
+ image=show_image(current_index)
729
+ return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
730
+
731
+ def next_image():
732
+ global current_index
733
+ if current_index<len(image_names)-1:
734
+ current_index+=1
735
+ image=show_image(current_index)
736
+ return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
737
+
738
+ def real_time_check(image_file):
739
+ highlighter1 = Highlighter()
740
+ highlighter2 = Highlighter()
741
+ image = Image.open(image_file)
742
+ result_model1 = predict_model1(image)
743
+ ko_deplot_generated_title=result_model1.split("\n")[0].split(" | ")[1]
744
+ ko_deplot_table=ko_deplot_convert_to_dataframe(result_model1)
745
+
746
+ result_model2 = predict_model2(image)
747
+ aihub_deplot_generated_title=result_model2.split("\n")[1].split(":")[1]
748
+ aihub_deplot_table=aihub_deplot_convert_to_dataframe(result_model2)
749
+ image_base_name = os.path.basename(image_file.name).replace("Source","Label")
750
+ file_name, _ = os.path.splitext(image_base_name)
751
+ aihub_labeling_data_json="./labeling_data/"+file_name+".json"
752
+ #aihub_labeling_data_json="./labeling_data/line_graph.json"
753
+ ko_deplot_labeling_str=process_json_file(aihub_labeling_data_json)
754
+ ko_deplot_label_title=ko_deplot_labeling_str.split("\n")[0].split(" | ")[1]
755
+ ko_deplot_label_table=ko_deplot_convert_to_dataframe(ko_deplot_labeling_str)
756
+
757
+ aihub_deplot_labeling_str=process_json_file2(aihub_labeling_data_json)
758
+ aihub_deplot_label_title=aihub_deplot_labeling_str.split("\n")[1].split(":")[1]
759
+ aihub_deplot_label_table=aihub_deplot_convert_to_dataframe(aihub_deplot_labeling_str)
760
+
761
+ ko_deplot_RMS=evaluate_rms(result_model1,ko_deplot_labeling_str)
762
+ aihub_deplot_RMS=evaluate_rms(result_model2,aihub_deplot_labeling_str)
763
+
764
+ ko_deplot_score_table=pd.DataFrame({
765
+ 'category': ['precision', 'recall', 'f1'],
766
+ 'value': [
767
+ round(ko_deplot_RMS['table_datapoints_precision'],1),
768
+ round(ko_deplot_RMS['table_datapoints_recall'],1),
769
+ round(ko_deplot_RMS['table_datapoints_f1'],1)
770
+ ]
771
+ })
772
+ aihub_deplot_score_table=pd.DataFrame({
773
+ 'category': ['precision', 'recall', 'f1'],
774
+ 'value': [
775
+ round(aihub_deplot_RMS['table_datapoints_precision'],1),
776
+ round(aihub_deplot_RMS['table_datapoints_recall'],1),
777
+ round(aihub_deplot_RMS['table_datapoints_f1'],1)
778
+ ]
779
+ })
780
+
781
+ ko_deplot_generated_df_row=ko_deplot_table.shape[0]
782
+ aihub_deplot_generated_df_row=aihub_deplot_table.shape[0]
783
+ styled_ko_deplot_table=ko_deplot_table.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_label_table,pred_table_row=ko_deplot_generated_df_row,props='color:red')
784
+ styled_aihub_deplot_table=aihub_deplot_table.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_label_table,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
785
+
786
+ return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(kodeplot ์ถ”๋ก ๊ฒฐ๊ณผ)") , gr.DataFrame(styled_aihub_deplot_table,label=aihub_deplot_generated_title+"(aihub deplot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_label_table,label=ko_deplot_label_title+"(kodeplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),gr.DataFrame(aihub_deplot_label_table,label=aihub_deplot_label_title+"(aihub deplot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table
787
+ #return ko_deplot_table,aihub_deplot_table,aihub_deplot_label_table,ko_deplot_score_table,aihub_deplot_score_table
788
+ def inference(mode,image_uploader,file_uploader):
789
+ if(mode=="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ"):
790
+ ko_deplot_table, aihub_deplot_table, ko_deplot_label_table,aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table = real_time_check(image_uploader)
791
+ return ko_deplot_table, aihub_deplot_table, ko_deplot_label_table, aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table
792
+ else:
793
+ styled_ko_deplot_table, styled_aihub_deplot_table, ko_deplot_label_table, aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table =non_real_time_check(file_uploader)
794
+ return styled_ko_deplot_table, styled_aihub_deplot_table, ko_deplot_label_table,aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table
795
+
796
+ def interface_selector(selector):
797
+ if selector == "์ด๋ฏธ์ง€ ์—…๋กœ๋“œ":
798
+ return gr.update(visible=True),gr.update(visible=False),gr.State("image_upload"),gr.update(visible=False),gr.update(visible=False)
799
+ elif selector == "ํŒŒ์ผ ์—…๋กœ๋“œ":
800
+ return gr.update(visible=False),gr.update(visible=True),gr.State("file_upload"), gr.update(visible=True),gr.update(visible=True)
801
+
802
+ def file_selector(selector):
803
+ if selector == "low score ์ฐจํŠธ":
804
+ return gr.File("./bottom_20_percent_images.txt")
805
+ elif selector == "high score ์ฐจํŠธ":
806
+ return gr.File("./top_20_percent_images.txt")
807
+
808
+ def update_results(model_type):
809
+ if "ko_deplot" == model_type:
810
+ return gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False)
811
+ elif "aihub_deplot" == model_type:
812
+ return gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=True)
813
+ else:
814
+ return gr.update(visible=True), gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True)
815
+
816
+ def display_image(image_file):
817
+ image=Image.open(image_file)
818
+ return image, os.path.basename(image_file)
819
+
820
+ def display_image_in_file(image_checklist):
821
+ global image_names, current_index
822
+ image_names = load_image_checklist(image_checklist)
823
+ image=show_image(current_index)
824
+ return image,image_names[current_index]
825
+
826
+ def update_file_based_on_chart_type(chart_type, all_file_path):
827
+ with open(all_file_path, 'r', encoding='utf-8') as file:
828
+ lines = file.readlines()
829
+ filtered_lines=[]
830
+ if chart_type == "์ „์ฒด":
831
+ filtered_lines = lines
832
+ elif chart_type == "์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
833
+ filtered_lines = [line for line in lines if "_horizontal bar_standard" in line]
834
+ elif chart_type=="๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
835
+ filtered_lines = [line for line in lines if "_horizontal bar_accumulation" in line]
836
+ elif chart_type=="100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
837
+ filtered_lines = [line for line in lines if "_horizontal bar_100per accumulation" in line]
838
+ elif chart_type=="์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
839
+ filtered_lines = [line for line in lines if "_vertical bar_standard" in line]
840
+ elif chart_type=="๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
841
+ filtered_lines = [line for line in lines if "_vertical bar_accumulation" in line]
842
+ elif chart_type=="100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
843
+ filtered_lines = [line for line in lines if "_vertical bar_100per accumulation" in line]
844
+ elif chart_type=="์„ ํ˜•":
845
+ filtered_lines = [line for line in lines if "_line_standard" in line]
846
+ elif chart_type=="์›ํ˜•":
847
+ filtered_lines = [line for line in lines if "_pie_standard" in line]
848
+ elif chart_type=="๊ธฐํƒ€ ๋ฐฉ์‚ฌ๏ฟฝ๏ฟฝ":
849
+ filtered_lines = [line for line in lines if "_etc_radial" in line]
850
+ elif chart_type=="๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•":
851
+ filtered_lines = [line for line in lines if "_etc_mix" in line]
852
+ # ์ƒˆ๋กœ์šด ํŒŒ์ผ์— ๊ธฐ๋ก
853
+ new_file_path = "./filtered_chart_images.txt"
854
+ with open(new_file_path, 'w', encoding='utf-8') as file:
855
+ file.writelines(filtered_lines)
856
+
857
+ return new_file_path
858
+
859
+ def handle_chart_type_change(chart_type,all_file_path):
860
+ new_file_path = update_file_based_on_chart_type(chart_type, all_file_path)
861
+ global image_names, current_index
862
+ image_names = load_image_checklist(new_file_path)
863
+ current_index=0
864
+ image=show_image(current_index)
865
+ return image,image_names[current_index]
866
+
867
+ with gr.Blocks() as iface:
868
+ mode=gr.State("image_upload")
869
+ with gr.Row():
870
+ with gr.Column():
871
+ #mode_label=gr.Text("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ๊ฐ€ ์„ ํƒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
872
+ upload_option = gr.Radio(choices=["์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", "ํŒŒ์ผ ์—…๋กœ๋“œ"], value="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", label="์—…๋กœ๋“œ ์˜ต์…˜")
873
+ #with gr.Row():
874
+ #image_button = gr.Button("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
875
+ #file_button = gr.Button("ํŒŒ์ผ ์—…๋กœ๋“œ")
876
+
877
+ # ์ด๋ฏธ์ง€์™€ ํŒŒ์ผ ์—…๋กœ๋“œ ์ปดํฌ๋„ŒํŠธ (์ดˆ๊ธฐ์—๋Š” ์ˆจ๊น€ ์ƒํƒœ)
878
+ # global image_uploader,file_uploader
879
+ image_uploader= gr.File(file_count="single",file_types=["image"],visible=True)
880
+ file_uploader= gr.File(file_count="single", file_types=[".txt"], visible=False)
881
+ file_upload_option=gr.Radio(choices=["low score ์ฐจํŠธ","high score ์ฐจํŠธ"],label="ํŒŒ์ผ ์—…๋กœ๋“œ ์˜ต์…˜",visible=False)
882
+ chart_type = gr.Dropdown(["์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•", "์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","์„ ํ˜•", "์›ํ˜•", "๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•", "๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•", "์ „์ฒด"], label="Chart Type", value="all")
883
+ model_type=gr.Dropdown(["ko_deplot","aihub_deplot","all"],label="model")
884
+ image_displayer=gr.Image(visible=True)
885
+ with gr.Row():
886
+ pre_button=gr.Button("์ด์ „",interactive="False")
887
+ next_button=gr.Button("๋‹ค์Œ")
888
+ image_name=gr.Text("์ด๋ฏธ์ง€ ์ด๋ฆ„",visible=False)
889
+ #image_button.click(interface_selector, inputs=gr.State("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ"), outputs=[image_uploader,file_uploader,mode,mode_label,image_name])
890
+ #file_button.click(interface_selector, inputs=gr.State("ํŒŒ์ผ ์—…๋กœ๋“œ"), outputs=[image_uploader, file_uploader,mode,mode_label,image_name])
891
+ inference_button=gr.Button("์ถ”๋ก ")
892
+ with gr.Column():
893
+ ko_deplot_generated_table=gr.DataFrame(visible=False,label="ko-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ")
894
+ aihub_deplot_generated_table=gr.DataFrame(visible=False,label="aihub-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ")
895
+ with gr.Column():
896
+ ko_deplot_label_table=gr.DataFrame(visible=False,label="ko-deplot ์ •๋‹ตํ…Œ์ด๋ธ”")
897
+ aihub_deplot_label_table=gr.DataFrame(visible=False,label="aihub-deplot ์ •๋‹ตํ…Œ์ด๋ธ”")
898
+ with gr.Column():
899
+ ko_deplot_score_table=gr.DataFrame(visible=False,label="ko_deplot ์ ์ˆ˜")
900
+ aihub_deplot_score_table=gr.DataFrame(visible=False,label="aihub_deplot ์ ์ˆ˜")
901
+ model_type.change(
902
+ update_results,
903
+ inputs=[model_type],
904
+ outputs=[ko_deplot_generated_table,ko_deplot_score_table,aihub_deplot_generated_table,aihub_deplot_score_table,ko_deplot_label_table,aihub_deplot_label_table]
905
+ )
906
+
907
+ upload_option.change(
908
+ interface_selector,
909
+ inputs=[upload_option],
910
+ outputs=[image_uploader, file_uploader, mode, image_name,file_upload_option]
911
+ )
912
+
913
+ file_upload_option.change(
914
+ file_selector,
915
+ inputs=[file_upload_option],
916
+ outputs=[file_uploader]
917
+ )
918
+
919
+ chart_type.change(handle_chart_type_change, inputs=[chart_type,file_uploader],outputs=[image_displayer,image_name])
920
+ image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
921
+ file_uploader.change(display_image_in_file,inputs=[file_uploader],outputs=[image_displayer,image_name])
922
+ pre_button.click(previous_image, outputs=[image_displayer,image_name,pre_button,next_button])
923
+ next_button.click(next_image, outputs=[image_displayer,image_name,pre_button,next_button])
924
+ inference_button.click(inference,inputs=[upload_option,image_uploader,file_uploader],outputs=[ko_deplot_generated_table, aihub_deplot_generated_table, ko_deplot_label_table, aihub_deplot_label_table,ko_deplot_score_table, aihub_deplot_score_table])
925
+
926
+
927
+ if __name__ == "__main__":
928
+ print("Launching Gradio interface...")
929
+ sys.stdout.flush() # stdout ๋ฒ„ํผ๋ฅผ ๋น„์›๋‹ˆ๋‹ค.
930
+ iface.launch(share=True)
931
+ time.sleep(2) # Gradio URL์ด ์ถœ๋ ฅ๋  ๋•Œ๊นŒ์ง€ ์ž ์‹œ ๊ธฐ๋‹ค๋ฆฝ๋‹ˆ๋‹ค.
932
+ sys.stdout.flush() # ๋‹ค์‹œ stdout ๋ฒ„ํผ๋ฅผ ๋น„์›๋‹ˆ๋‹ค.
933
+ # Gradio๊ฐ€ ์ œ๊ณตํ•˜๋Š” URLs์„ ํŒŒ์ผ์— ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค.
934
+ with open("gradio_url.log", "w") as f:
935
+ print(iface.local_url, file=f)
936
+ print(iface.share_url, file=f)