initial commit via hf
Browse files- .gitignore +2 -0
- algo_cfgs_all.json +51 -0
- analysis_funcs.py +355 -0
- app.py +1453 -0
- classic_correction_algos.py +546 -0
- eyekit_measures.py +178 -0
- loss_functions.py +179 -0
- models.py +897 -0
- models/BERT_20240104-223349_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00430.ckpt +3 -0
- models/BERT_20240104-233803_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00719.ckpt +3 -0
- models/BERT_20240107-152040_loop_restrict_sim_data_to_4000_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00515.ckpt +3 -0
- models/BERT_20240108-000344_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00706.ckpt +3 -0
- models/BERT_20240108-011230_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00560.ckpt +3 -0
- models/BERT_20240109-090419_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00518.ckpt +3 -0
- models/BERT_20240122-183729_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00523.ckpt +3 -0
- models/BERT_20240122-194041_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00462.ckpt +3 -0
- models/BERT_fin_exp_20240104-223349.yaml +100 -0
- models/BERT_fin_exp_20240104-233803.yaml +100 -0
- models/BERT_fin_exp_20240107-152040.yaml +100 -0
- models/BERT_fin_exp_20240108-000344.yaml +100 -0
- models/BERT_fin_exp_20240108-011230.yaml +100 -0
- models/BERT_fin_exp_20240109-090419.yaml +100 -0
- models/BERT_fin_exp_20240122-183729.yaml +102 -0
- models/BERT_fin_exp_20240122-194041.yaml +102 -0
- requirements.txt +25 -0
- run_in_notebook.ipynb +0 -0
- utils.py +2016 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.gitignore
|
algo_cfgs_all.json
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"compare": {
|
3 |
+
"x_thresh": 512,
|
4 |
+
"n_nearest_lines": 3
|
5 |
+
},
|
6 |
+
"attach": {},
|
7 |
+
"segment": {},
|
8 |
+
"split": {},
|
9 |
+
"stretch": {
|
10 |
+
"stretch_bounds": [
|
11 |
+
0.9,
|
12 |
+
1.1
|
13 |
+
],
|
14 |
+
"offset_bounds": [
|
15 |
+
-50,
|
16 |
+
50
|
17 |
+
]
|
18 |
+
},
|
19 |
+
"slice": {
|
20 |
+
"x_thresh": 192,
|
21 |
+
"y_thresh": 32,
|
22 |
+
"w_thresh": 32,
|
23 |
+
"n_thresh": 90
|
24 |
+
|
25 |
+
},
|
26 |
+
"warp": {},
|
27 |
+
"chain": {
|
28 |
+
"x_thresh": 192,
|
29 |
+
"y_thresh": 55
|
30 |
+
},
|
31 |
+
"regress": {
|
32 |
+
"slope_bounds": [
|
33 |
+
-0.1,
|
34 |
+
0.1
|
35 |
+
],
|
36 |
+
"offset_bounds": [
|
37 |
+
-50,
|
38 |
+
50
|
39 |
+
],
|
40 |
+
"std_bounds": [
|
41 |
+
1,
|
42 |
+
20
|
43 |
+
]
|
44 |
+
},
|
45 |
+
"cluster": {},
|
46 |
+
"merge": {
|
47 |
+
"y_thresh": 32,
|
48 |
+
"gradient_thresh": 0.1,
|
49 |
+
"error_thresh": 20
|
50 |
+
}
|
51 |
+
}
|
analysis_funcs.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Partially taken and adapted from: https://github.com/jwcarr/eyekit/blob/1db1913411327b108b87e097a00278b6e50d0751/eyekit/measure.py
|
3 |
+
Functions for calculating common reading measures, such as gaze duration or
|
4 |
+
initial landing position.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
|
10 |
+
def fix_in_ia(fix_x, fix_y, ia_x_min, ia_x_max, ia_y_min, ia_y_max):
    """Return True if the fixation point (fix_x, fix_y) lies inside the
    rectangular interest area bounded inclusively by [ia_x_min, ia_x_max]
    and [ia_y_min, ia_y_max].

    All arguments are scalar coordinates in the same (pixel) space.
    """
    # The chained comparisons already yield booleans; no need for an
    # explicit if/else returning True/False.
    return ia_x_min <= fix_x <= ia_x_max and ia_y_min <= fix_y <= ia_y_max
|
17 |
+
|
18 |
+
|
19 |
+
def fix_in_ia_default(fixation, ia_row, prefix):
    """Check whether *fixation* (with .x/.y attributes) lands inside the
    interest area described by *ia_row*, whose bound columns are named
    f"{prefix}_xmin", f"{prefix}_xmax", f"{prefix}_ymin", f"{prefix}_ymax".
    """
    bounds = [ia_row[f"{prefix}_{edge}"] for edge in ("xmin", "xmax", "ymin", "ymax")]
    return fix_in_ia(fixation.x, fixation.y, *bounds)
|
28 |
+
|
29 |
+
|
30 |
+
def number_of_fixations_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the number of
    fixations on that interest area.

    Parameters
    ----------
    trial : mapping with key f"{prefix}s_list" holding interest-area records.
    dffix : pandas.DataFrame of fixations with .x and .y columns.
    prefix : column-name prefix of the interest areas ("word" or "char").

    Returns
    -------
    pandas.DataFrame with columns f"{prefix}_index", prefix and
    "number_of_fixations".
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    counts = []
    for cidx, ia_row in ia_df.iterrows():
        # Use the shared bounds-check helper for consistency with the other
        # measures in this module (was an inlined six-argument fix_in_ia call).
        count = sum(fix_in_ia_default(fixation, ia_row, prefix) for _, fixation in dffix.iterrows())
        counts.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "number_of_fixations": count,
            }
        )
    return pd.DataFrame(counts)
|
57 |
+
|
58 |
+
|
59 |
+
def initial_fixation_duration_own(trial, dffix, prefix="word"):
    """
    Return, for every interest area in the trial, the duration of the first
    fixation that landed on it (0 if it was never fixated), as a DataFrame
    with columns f"{prefix}_index", prefix and "initial_fixation_duration".
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    durations = []

    for area_idx, area in ia_df.iterrows():
        first_duration = 0
        for _, fix in dffix.iterrows():
            if fix_in_ia_default(fix, area, prefix):
                first_duration = fix.duration
                # Only the very first fixation on this area counts.
                break
        durations.append(
            {
                f"{prefix}_index": area_idx,
                prefix: area[f"{prefix}"],
                "initial_fixation_duration": first_duration,
            }
        )

    return pd.DataFrame(durations)
|
82 |
+
|
83 |
+
|
84 |
+
def first_of_many_duration_own(trial, dffix, prefix="word"):
    """
    For every interest area that received more than one fixation, return the
    duration of its first fixation ("first-of-many"). Areas with zero or one
    fixation are omitted; an empty DataFrame is returned when no area
    qualifies.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for area_idx, area in ia_df.iterrows():
        durs = [fix.duration for _, fix in dffix.iterrows() if fix_in_ia_default(fix, area, prefix)]
        if len(durs) > 1:
            records.append(
                {
                    f"{prefix}_index": area_idx,
                    prefix: area[f"{prefix}"],
                    "first_of_many_duration": durs[0],
                }
            )
    if records:
        return pd.DataFrame(records)
    else:
        return pd.DataFrame()
|
104 |
+
|
105 |
+
|
106 |
+
def total_fixation_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the sum duration of
    all fixations on that interest area, one row per area.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for area_idx, area in ia_df.iterrows():
        # Sum every fixation duration that falls inside this area's bounds.
        total = sum(fix.duration for _, fix in dffix.iterrows() if fix_in_ia_default(fix, area, prefix))
        records.append(
            {
                f"{prefix}_index": area_idx,
                prefix: area[f"{prefix}"],
                "total_fixation_duration": total,
            }
        )
    return pd.DataFrame(records)
|
126 |
+
|
127 |
+
|
128 |
+
def gaze_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the gaze duration on
    that interest area: the summed duration of all fixations inside the area
    until the area is exited for the first time.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for area_idx, area in ia_df.iterrows():
        gaze = 0
        visited = False
        for _, fix in dffix.iterrows():
            if fix_in_ia_default(fix, area, prefix):
                gaze += fix.duration
                visited = True
            elif visited:
                # First fixation outside the area after entering it ends
                # the first pass.
                break
        records.append(
            {
                f"{prefix}_index": area_idx,
                prefix: area[f"{prefix}"],
                "gaze_duration": gaze,
            }
        )
    return pd.DataFrame(records)
|
153 |
+
|
154 |
+
|
155 |
+
def go_past_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the go-past time on
    that interest area. Go-past time is the sum duration of all fixations from
    when the interest area is first entered until when it is first exited to
    the right, including any regressions to the left that occur during that
    time period (and vice versa in the case of right-to-left text).

    Returns a DataFrame with columns f"{prefix}_index", prefix and
    "go_past_duration" (0 for areas never entered).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    results = []

    for cidx, ia_row in ia_df.iterrows():
        entered = False  # becomes True once a fixation lands inside the area
        go_past_time = 0

        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                if not entered:
                    entered = True
                go_past_time += fixation.duration
            elif entered:
                # Fixation outside the area after entry: stop if it is to the
                # right of the area, otherwise count it as a regression that
                # is still part of the go-past period.
                # NOTE(review): only left-to-right exit is handled here; the
                # right-to-left case mentioned in the docstring is not
                # implemented — confirm whether that is intended.
                if ia_row[f"{prefix}_xmax"] < fixation.x:  # Interest area has been exited to the right
                    break
                go_past_time += fixation.duration

        results.append({f"{prefix}_index": cidx, prefix: ia_row[f"{prefix}"], "go_past_duration": go_past_time})

    return pd.DataFrame(results)
|
183 |
+
|
184 |
+
|
185 |
+
def second_pass_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the second pass
    duration on that interest area for each word.

    A "pass" is a maximal run of consecutive fixations inside the area; only
    durations belonging to the second such run are summed (0 if the area was
    entered fewer than two times).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    durations = []

    for cidx, ia_row in ia_df.iterrows():
        current_pass = None  # pass number we are currently inside, or None
        next_pass = 1  # number the next entry into the area will get
        pass_duration = 0
        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                if current_pass is None:  # first fixation in a new pass
                    current_pass = next_pass
                if current_pass == 2:
                    pass_duration += fixation.duration
            elif current_pass == 1:  # first fixation to exit the first pass
                current_pass = None
                next_pass += 1
            elif current_pass == 2:  # first fixation to exit the second pass
                break
        durations.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "second_pass_duration": pass_duration,
            }
        )

    return pd.DataFrame(durations)
|
217 |
+
|
218 |
+
|
219 |
+
def initial_landing_position_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the initial landing
    position (expressed in character positions) on that interest area.
    Counting is from 1. If the interest area represents right-to-left text,
    the first character is the rightmost one. Returns `None` if no fixation
    landed on the interest area.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    if prefix == "word":
        # Word-level areas need the character boxes to resolve which letter
        # inside the word was hit.
        chars_df = pd.DataFrame(trial[f"chars_list"])
    else:
        chars_df = None
    results = []
    for cidx, ia_row in ia_df.iterrows():
        landing_position = None
        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                if prefix == "char":
                    # Character-level area: landing on it is position 1.
                    landing_position = 1
                else:
                    prefix_temp = "char"
                    # Characters whose boxes lie entirely within this area's
                    # bounding box are taken to be the area's letters.
                    matched_chars_df = chars_df.loc[
                        (chars_df.char_xmin >= ia_row[f"{prefix}_xmin"])
                        & (chars_df.char_xmax <= ia_row[f"{prefix}_xmax"])
                        & (chars_df.char_ymin >= ia_row[f"{prefix}_ymin"])
                        & (chars_df.char_ymax <= ia_row[f"{prefix}_ymax"]),
                        :,
                    ]  # need to find way to count correct letter number
                    for char_idx, (rowidx, char_row) in enumerate(matched_chars_df.iterrows()):
                        if fix_in_ia_default(fixation, char_row, prefix_temp):
                            landing_position = char_idx + 1  # starts at 1
                            break
                # Only the first fixation on the area determines the landing
                # position, so stop scanning fixations for this area.
                break
        results.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "initial_landing_position": landing_position,
            }
        )
    return pd.DataFrame(results)
|
261 |
+
|
262 |
+
|
263 |
+
def initial_landing_distance_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the initial landing
    distance on that interest area: the pixel distance between the first
    fixation to land in the area and the area's left edge. Returns `None`
    for areas no fixation landed on.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for area_idx, area in ia_df.iterrows():
        first_distance = None
        for _, fix in dffix.iterrows():
            if fix_in_ia_default(fix, area, prefix):
                # Distance from the area's left edge to the first fixation
                # that lands on it; later fixations are ignored.
                first_distance = abs(area[f"{prefix}_xmin"] - fix.x)
                break
        records.append(
            {
                f"{prefix}_index": area_idx,
                prefix: area[f"{prefix}"],
                "initial_landing_distance": first_distance,
            }
        )
    return pd.DataFrame(records)
|
291 |
+
|
292 |
+
|
293 |
+
def landing_distances_own(trial, dffix, prefix="word"):
    """
    Return, per interest area, the list of pixel distances (rounded to two
    decimals) between the area's left edge and each fixation landing on it.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for area_idx, area in ia_df.iterrows():
        dists = [
            round(abs(area[f"{prefix}_xmin"] - fix.x), ndigits=2)
            for _, fix in dffix.iterrows()
            if fix_in_ia_default(fix, area, prefix)
        ]
        records.append({f"{prefix}_index": area_idx, prefix: area[f"{prefix}"], "landing_distances": dists})
    return pd.DataFrame(records)
|
308 |
+
|
309 |
+
|
310 |
+
def number_of_regressions_in_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the number of
    regressions back to that interest area after the interest area was read
    for the first time. In other words, find the first fixation to exit the
    interest area and then count how many times the reader returns to the
    interest area from the right (or from the left in the case of
    right-to-left text).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    counts = []
    for cidx, ia_row in ia_df.iterrows():
        entered_interest_area = False  # area fixated at least once
        first_exit_index = None  # fixation index of the first exit, once known
        count = 0
        prev_fixation = None
        regression_counted = False  # guards against counting one re-entry twice

        for fixidx, (rowidx, fixation) in enumerate(dffix.iterrows()):
            # A regression-in: we are back inside the area after having
            # exited it once, coming from the right (previous fixation
            # further right than this one).
            if (
                entered_interest_area
                and first_exit_index is not None
                and fix_in_ia_default(fixation, ia_row, prefix)
                and not regression_counted
            ):
                if prev_fixation.x > fixation.x:
                    count += 1
                    regression_counted = True

            # State update for the next iteration.
            if fix_in_ia_default(fixation, ia_row, prefix):
                entered_interest_area = True
            elif entered_interest_area and first_exit_index is None:
                first_exit_index = fixidx
            else:
                # NOTE(review): regression_counted is reset only on fixations
                # outside the area that are not the first exit; confirm this
                # matches the intended one-count-per-return semantics.
                regression_counted = False
            prev_fixation = fixation

        counts.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "number_of_regressions_in": count,
            }
        )

    return pd.DataFrame(counts)
|
app.py
ADDED
@@ -0,0 +1,1453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from PIL import Image
|
3 |
+
from io import StringIO
|
4 |
+
import streamlit as st
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import re
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
|
11 |
+
from matplotlib.font_manager import FontProperties
|
12 |
+
from matplotlib.patches import Rectangle
|
13 |
+
from matplotlib import pyplot as plt
|
14 |
+
import plotly.graph_objects as go
|
15 |
+
import plotly.express as px
|
16 |
+
import numpy as np
|
17 |
+
import pandas as pd
|
18 |
+
import pathlib as pl
|
19 |
+
import json
|
20 |
+
import logging
|
21 |
+
import zipfile
|
22 |
+
from stqdm import stqdm
|
23 |
+
import jellyfish as jf
|
24 |
+
import lovely_tensors
|
25 |
+
import shutil
|
26 |
+
import eyekit_measures as ekm
|
27 |
+
import zipfile
|
28 |
+
|
29 |
+
import utils as ut
|
30 |
+
|
31 |
+
os.environ["MPLCONFIGDIR"] = os.getcwd() + "/configs/"
|
32 |
+
|
33 |
+
st.set_page_config("Correction", page_icon=":eye:", layout="wide")
|
34 |
+
|
35 |
+
AVAILABLE_FONTS = st.session_state["AVAILABLE_FONTS"] = ut.AVAILABLE_FONTS
|
36 |
+
|
37 |
+
DEFAULT_PLOT_FONT = "DejaVu Sans Mono"
|
38 |
+
EXAMPLES_FOLDER = "./testfiles/"
|
39 |
+
EXAMPLES_ASC_ZIP_FILENAME = "asc_files.zip"
|
40 |
+
OSF_DOWNLAOD_LINK = "https://osf.io/download/us97f/"
|
41 |
+
EXAMPLES_FOLDER_PATH = pl.Path(EXAMPLES_FOLDER)
|
42 |
+
|
43 |
+
|
44 |
+
lovely_tensors.monkey_patch()
|
45 |
+
|
46 |
+
|
47 |
+
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    """Thin wrapper delegating to ut.make_folders (creates the temp, unzip and plot folders)."""
    return ut.make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots)
|
49 |
+
|
50 |
+
|
51 |
+
TEMP_FOLDER = st.session_state["TEMP_FOLDER"] = ut.TEMP_FOLDER
|
52 |
+
gradio_temp_unzipped_folder = st.session_state["gradio_temp_unzipped_folder"] = pl.Path("unzipped")
|
53 |
+
|
54 |
+
PLOTS_FOLDER = st.session_state["PLOTS_FOLDER"] = pl.Path("plots")
|
55 |
+
TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER.joinpath("temp_matplotlib_plot_stimulus.png")
|
56 |
+
make_folders(TEMP_FOLDER, gradio_temp_unzipped_folder, PLOTS_FOLDER)
|
57 |
+
|
58 |
+
|
59 |
+
@st.cache_data
def get_classic_cfg(fname):
    """Cached wrapper around ut.get_classic_cfg: load the classic-algorithm config from *fname*."""
    return ut.get_classic_cfg(fname)
|
62 |
+
|
63 |
+
|
64 |
+
classic_algos_cfg = st.session_state["classic_algos_cfg"] = get_classic_cfg("algo_cfgs_all.json")
|
65 |
+
|
66 |
+
DIST_MODELS_FOLDER = st.session_state["DIST_MODELS_FOLDER"] = pl.Path("models")
|
67 |
+
COLORS = st.session_state["COLORS"] = px.colors.qualitative.Alphabet
|
68 |
+
ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [
|
69 |
+
"warp",
|
70 |
+
"regress",
|
71 |
+
"compare",
|
72 |
+
"attach",
|
73 |
+
"segment",
|
74 |
+
"split",
|
75 |
+
"stretch",
|
76 |
+
"chain",
|
77 |
+
"slice",
|
78 |
+
"cluster",
|
79 |
+
"merge",
|
80 |
+
"Wisdom_of_Crowds",
|
81 |
+
"DIST",
|
82 |
+
"DIST-Ensemble",
|
83 |
+
"Wisdom_of_Crowds_with_DIST",
|
84 |
+
"Wisdom_of_Crowds_with_DIST_Ensemble",
|
85 |
+
]
|
86 |
+
|
87 |
+
|
88 |
+
st.session_state["colnames_custom_csv_fix"] = {
|
89 |
+
"x_col_name_fix": "x",
|
90 |
+
"y_col_name_fix": "y",
|
91 |
+
"x_col_name_fix_stim": "char_x_center",
|
92 |
+
"x_start_col_name_fix_stim": "char_xmin",
|
93 |
+
"x_end_col_name_fix_stim": "char_xmax",
|
94 |
+
"y_col_name_fix_stim": "char_y_center",
|
95 |
+
"y_start_col_name_fix_stim": "char_ymin",
|
96 |
+
"y_end_col_name_fix_stim": "char_ymax",
|
97 |
+
"char_col_name_fix_stim": "char",
|
98 |
+
"trial_id_col_name_fix": "trial_id",
|
99 |
+
"trial_id_col_name_stim": "trial_id",
|
100 |
+
"subject_col_name_fix": "subid",
|
101 |
+
"subject_col_name_stim": "subid",
|
102 |
+
"line_num_col_name_stim": "assigned_line",
|
103 |
+
"time_start_col_name_fix": "start",
|
104 |
+
"time_stop_col_name_fix": "stop",
|
105 |
+
}
|
106 |
+
|
107 |
+
if "results" not in st.session_state:
|
108 |
+
st.session_state["results"] = {}
|
109 |
+
|
110 |
+
|
111 |
+
@st.cache_resource
def load_model(model_file, cfg):
    """Cached-resource wrapper around ut.load_model (loads a model checkpoint once per session)."""
    return ut.load_model(model_file, cfg)
|
114 |
+
|
115 |
+
|
116 |
+
@st.cache_resource
def find_and_load_model(model_date="20240104-223349"):
    """Cached-resource wrapper around ut.find_and_load_model, keyed by the model's date string."""
    return ut.find_and_load_model(model_date)
|
119 |
+
|
120 |
+
|
121 |
+
def create_logger(name, level="DEBUG", file=None):
    """Create (or fetch) a non-propagating logger with a console handler and,
    optionally, a file handler.

    Handlers are only attached when no handler of that type exists yet, so
    repeated calls (e.g. Streamlit reruns) do not duplicate output.

    Parameters
    ----------
    name : str
        Logger name passed to logging.getLogger.
    level : str or int
        Logging level for the logger (default "DEBUG").
    file : str or None
        If given, also log to this file (opened in write mode, truncating).

    Returns
    -------
    logging.Logger
    """

    def _make_formatter():
        # Single place for the (previously duplicated) format definition.
        return logging.Formatter(
            "%(asctime)s.%(msecs)03d-%(name)s-p%(process)s-{%(pathname)s:%(lineno)d}-%(levelname)s >>> %(message)s",
            "%m-%d %H:%M:%S",
        )

    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(level)
    if not any(isinstance(handler, logging.StreamHandler) for handler in logger.handlers):
        ch = logging.StreamHandler()
        ch.setFormatter(_make_formatter())
        logger.addHandler(ch)
    if file is not None:
        if not any(isinstance(handler, logging.FileHandler) for handler in logger.handlers):
            fh = logging.FileHandler(file, "w")
            fh.setFormatter(_make_formatter())
            logger.addHandler(fh)
            logger.debug("Logger added")
    return logger
|
146 |
+
|
147 |
+
|
148 |
+
if "logger" not in st.session_state:
|
149 |
+
st.session_state["logger"] = create_logger(name="app", level="DEBUG", file="log_for_app.log")
|
150 |
+
|
151 |
+
|
152 |
+
@st.cache_data
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH):
    """Cached wrapper around ut.download_example_ascs: fetch and unpack example .asc files."""
    return ut.download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH)
|
155 |
+
|
156 |
+
|
157 |
+
EXAMPLE_ASC_FILES = download_example_ascs(
|
158 |
+
EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    """Thin wrapper delegating to ut.asc_to_trial_ids."""
    return ut.asc_to_trial_ids(asc_file, close_gap_between_words)
|
164 |
+
|
165 |
+
|
166 |
+
@st.cache_data
def get_trials_list(asc_file=None, close_gap_between_words=True):
    """Cached wrapper around ut.get_trials_list for the given .asc file."""
    return ut.get_trials_list(asc_file, close_gap_between_words)
|
169 |
+
|
170 |
+
|
171 |
+
@st.cache_data
def prep_data_for_dist(model_cfg, dffix, trial=None):
    """Cached wrapper around ut.prep_data_for_dist (prepares fixation data for the DIST model)."""
    return ut.prep_data_for_dist(model_cfg, dffix, trial)
|
174 |
+
|
175 |
+
|
176 |
+
def save_trial_to_json(trial, savename):
    """Thin wrapper delegating to ut.save_trial_to_json."""
    return ut.save_trial_to_json(trial, savename)
|
178 |
+
|
179 |
+
|
180 |
+
def export_csv(dffix, trial):
    """Thin wrapper delegating to ut.export_csv."""
    return ut.export_csv(dffix, trial)
|
182 |
+
|
183 |
+
|
184 |
+
@st.cache_data
def get_DIST_preds(dffix, trial):
    """Cached wrapper around ut.get_DIST_preds (DIST model predictions for one trial)."""
    return ut.get_DIST_preds(dffix, trial)
|
187 |
+
|
188 |
+
|
189 |
+
@st.cache_data
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None):
    """Cached wrapper around ut.get_EDIST_preds_with_model_check (DIST-Ensemble predictions)."""
    return ut.get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg)
|
192 |
+
|
193 |
+
|
194 |
+
def get_all_classic_preds(dffix, trial):
    """Thin wrapper delegating to ut.get_all_classic_preds (all classic correction algorithms)."""
    return ut.get_all_classic_preds(dffix, trial)
|
196 |
+
|
197 |
+
|
198 |
+
def apply_woc(dffix, trial, corrections, algo_choice):
    """Thin wrapper delegating to ut.apply_woc (Wisdom-of-Crowds combination of corrections)."""
    return ut.apply_woc(dffix, trial, corrections, algo_choice)
|
200 |
+
|
201 |
+
|
202 |
+
@st.cache_data
def correct_df(
    dffix,
    algo_choice,
    trial=None,
    for_multi=False,
    ensemble_model_avg=None,
):
    """Cached wrapper around ut.correct_df: apply the chosen correction
    algorithm(s) to the fixation dataframe for one trial."""
    return ut.correct_df(
        dffix,
        algo_choice,
        trial,
        for_multi,
        ensemble_model_avg,
    )
|
217 |
+
|
218 |
+
|
219 |
+
@st.cache_data
def get_font_and_font_size_from_trial(trial):
    """Cached wrapper around ut.get_font_and_font_size_from_trial."""
    return ut.get_font_and_font_size_from_trial(trial)
|
222 |
+
|
223 |
+
|
224 |
+
@st.cache_data
def add_default_font_and_character_props_to_state(trial):
    """Cached wrapper around ut.add_default_font_and_character_props_to_state."""
    return ut.add_default_font_and_character_props_to_state(trial)
|
227 |
+
|
228 |
+
|
229 |
+
@st.cache_data
def get_plot_props(trial, available_fonts):
    """Cached wrapper around ut.get_plot_props (font, font size, dpi, screen resolution for plotting)."""
    return ut.get_plot_props(trial, available_fonts)
|
232 |
+
|
233 |
+
|
234 |
+
def process_trial_choice(trial_id, algo_choice):
    """Resolve the trial selected in the UI, build its fixation dataframe and
    apply the chosen correction algorithm.

    trial_id may be a plain id or a {"value": id} dict (selectbox payload).
    Returns (dffix, trial, dpi, screen_res, font, font_size).
    """
    if isinstance(trial_id, dict):
        trial_id = trial_id["value"]
    trials_by_ids = st.session_state["trials_by_ids"]
    trial = trials_by_ids[trial_id]
    if "chars_list" in trial:
        # Cache stimulus geometry in session_state for the eyekit export path.
        (
            y_diff,
            x_txt_start,
            y_txt_start,
            font_face,
            _,
            line_height,
        ) = add_default_font_and_character_props_to_state(trial)
        font_size = ut.set_font_from_chars_list(trial)

        st.session_state["y_diff_for_eyekit"] = y_diff
        st.session_state["x_txt_start_for_eyekit"] = x_txt_start
        st.session_state["y_txt_start_for_eyekit"] = y_txt_start
        st.session_state["font_face_for_eyekit"] = font_face
        st.session_state["font_size_for_eyekit"] = font_size
        st.session_state["line_height_for_eyekit"] = line_height

    if "dffix" in trial:
        dffix = trial["dffix"]
    else:
        # No precomputed fixations: parse them from the loaded .asc file.
        asc_file = st.session_state["asc_file"]
        trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{asc_file.stem}_{trial_id}_2ndInput_chars_channel_sep.png"))
        trial["fname"] = str(asc_file.name).split(".")[0]
        df, dffix, trial = ut.trial_to_dfs(trial, st.session_state["lines"], use_synctime=True)
        st.session_state["logger"].info(f"dffix.columns after trial_to_dfs {dffix.columns}")

    # Overwrites the font_size derived above from the chars_list, when present.
    font, font_size, dpi, screen_res = ut.get_plot_props(trial, AVAILABLE_FONTS)
    st.session_state["trial"] = trial
    if "chars_list" in trial:
        chars_df = pd.DataFrame(trial["chars_list"])
        trial["chars_df"] = chars_df.to_dict()
        trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
    if algo_choice is not None and ("chars_list" in trial or "words_list" in trial):
        # NOTE(review): correct_df's return is unpacked as a 2-tuple here but
        # used as a single value in process_trial — confirm ut.correct_df's contract.
        dffix, _ = correct_df(dffix, algo_choice, trial)
    else:
        st.warning("🚨 Stimulus information needed for fixation correction 🚨")

    return dffix, trial, dpi, screen_res, font, font_size
|
278 |
+
|
279 |
+
|
280 |
+
@st.cache_data
def process_trial_choice_single_csv(trial, algo_choice, file=None):
    """Cached wrapper: process a trial loaded from a single custom CSV file."""
    return ut.process_trial_choice_single_csv(trial, algo_choice, file=file)
|
283 |
+
|
284 |
+
|
285 |
+
def quick_dffix_save(dffix, savename):
    """Write the fixation dataframe to CSV and log the save location."""
    dffix.to_csv(savename)
    st.session_state["logger"].info(f"Saved processed data as {savename}")
|
288 |
+
|
289 |
+
|
290 |
+
def save_trial_to_json(trial, savename):
    """Serialize a trial dict to pretty-printed UTF-8 JSON.

    Removes the non-serializable "dffix" dataframe from the dict first
    (mutates the caller's dict) and uses ut.NumpyEncoder for numpy values.
    """
    trial.pop("dffix", None)
    with open(savename, "w", encoding="utf-8") as f:
        json.dump(trial, f, ensure_ascii=False, indent=4, cls=ut.NumpyEncoder)
|
295 |
+
|
296 |
+
|
297 |
+
@st.cache_data
def process_trial(trial, asc_file_stem, lines, algo_choice, for_multi=False):
    """Turn one parsed .asc trial into a fixation dataframe plus enriched trial dict.

    Adds plot metadata (font, dpi, screen_res) to the trial, parses fixations
    with ut.trial_to_dfs, records character geometry, and runs the selected
    correction algorithm if any.

    Returns (dffix, trial); dffix is an empty DataFrame when no fixations parsed.
    """
    trial_id = trial["trial_id"]
    trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}_2ndInput_chars_channel_sep.png"))
    trial["fname"] = str(asc_file_stem)
    font, font_size, dpi, screen_res = ut.get_plot_props(trial, AVAILABLE_FONTS)
    trial["font"] = font
    trial["font_size"] = font_size
    trial["dpi"] = dpi
    trial["screen_res"] = screen_res
    df, dffix, trial = ut.trial_to_dfs(trial, lines, use_synctime=True)
    if dffix.empty:
        return pd.DataFrame(), trial

    chars_df = pd.DataFrame(trial["chars_list"])
    trial["chars_df"] = chars_df.to_dict()
    # fix: y_char_unique was assigned twice with the same value; once suffices.
    trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
    if algo_choice is not None:
        dffix = correct_df(dffix, algo_choice, trial, for_multi)

    return dffix, trial
|
320 |
+
|
321 |
+
|
322 |
+
def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
    """Draw stimulus text and/or bounding boxes on a matplotlib axis.

    Thin pass-through to ut.add_text_to_ax so the app module exposes a
    stable local name with the same defaults.
    """
    return ut.add_text_to_ax(
        chars_list, ax, font_to_use=font_to_use, fontsize=fontsize, prefix=prefix,
        plot_boxes=plot_boxes, plot_text=plot_text, box_annotations=box_annotations,
    )
|
342 |
+
|
343 |
+
|
344 |
+
@st.cache_data
def matplotlib_plot_df(
    dffix,
    trial,
    algo_choice,
    stimulus_prefix="word",
    desired_dpi=300,
    fix_to_plot=[],
    stim_info_to_plot=["Words", "Word boxes"],
    box_annotations=None,
):
    """Cached wrapper: static matplotlib figure of fixations over the stimulus."""
    return ut.matplotlib_plot_df(
        dffix, trial, algo_choice, stimulus_prefix=stimulus_prefix,
        desired_dpi=desired_dpi, fix_to_plot=fix_to_plot,
        stim_info_to_plot=stim_info_to_plot, box_annotations=box_annotations,
    )
|
365 |
+
|
366 |
+
|
367 |
+
def sigmoid(x):
    """Elementwise logistic function 1 / (1 + e^-x)."""
    return 1.0 / (1.0 + np.exp(-x))
|
369 |
+
|
370 |
+
|
371 |
+
@st.cache_data
def plotly_plot_with_image(
    dffix,
    trial,
    algo_choice,
    to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"],
    scale_factor=0.5,
):
    """Cached wrapper: interactive plotly figure with the stimulus image backdrop."""
    return ut.plotly_plot_with_image(
        dffix, trial, algo_choice, to_plot_list=to_plot_list, scale_factor=scale_factor
    )
|
386 |
+
|
387 |
+
|
388 |
+
@st.cache_data
def plot_y_corr(dffix, algo_choice):
    """Cached wrapper: plot corrected-vs-raw y coordinates per algorithm."""
    return ut.plot_y_corr(dffix, algo_choice)
|
391 |
+
|
392 |
+
|
393 |
+
def plotly_df(
    dffix=None, trial=None, algo_choice=None, to_plot_list=["fixations", "characters", "corrected fixations"], title=""
):
    """Build a plotly figure of raw/corrected fixations and stimulus characters.

    Falls back to st.session_state for dffix/algo_choice/trial when not given.
    NOTE(review): the default to_plot_list entries are lowercase but the
    checks below test for "Uncorrected Fixations"/"Corrected Fixations"/
    "Characters" — with the defaults nothing is drawn; confirm intended.
    NOTE(review): mutable default argument (list) — harmless here since it is
    never mutated, but fragile.
    """
    if dffix is None:
        dffix = st.session_state["dffix"]
    if algo_choice is None:
        algo_choice = st.session_state["algo_choice"]

    st.session_state["logger"].info(f"Plotting {to_plot_list}")

    # Index values are used as per-fixation text labels.
    num_datapoints = dffix.index
    if trial is None:
        # Pull character info from stored results (multi-file) or current trial.
        if title in st.session_state["results"]:
            chars_df = pd.DataFrame(st.session_state["results"][title]["trial"]["chars_list"])
        else:
            chars_df = pd.DataFrame(st.session_state["trial"]["chars_df"])
    else:
        chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None
    if chars_df is not None:
        # NOTE(review): when trial is None but chars_df was recovered from
        # session_state, trial is still None here — confirm
        # get_font_and_font_size_from_trial tolerates that (NameError/TypeError risk).
        font_face, font_size = get_font_and_font_size_from_trial(trial)
        font_size = font_size * 0.65  # guess for scaling
        xmin = chars_df.char_x_center.min()
        xmax = chars_df.char_x_center.max()
        ymin = chars_df.char_y_center.min()
        ymax = chars_df.char_y_center.max()
    else:
        st.warning("No character or word information available to plot")
        xmin = dffix.x.min()
        xmax = dffix.x.max()
        ymin = dffix.y.min()
        ymax = dffix.y.max()

    # y-axis range is inverted (ymax first) so the plot matches screen coordinates.
    layout = dict(
        plot_bgcolor="white",
        autosize=True,
        margin=dict(t=1, l=10, r=10, b=1),
        xaxis=dict(
            title="x-coordinate",
            linecolor="black",
            range=[xmin - 100, xmax + 100],
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        yaxis=dict(
            title="y-coordinate",
            range=[ymax + 100, ymin - 100],
            linecolor="black",
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
    )

    fig = go.Figure(layout=layout)

    if "Uncorrected Fixations" in to_plot_list:
        # Marker size encodes fixation duration, rescaled relative to the median.
        duration_scaled = dffix.duration - dffix.duration.min()
        duration = ((duration_scaled + 0.1) / duration_scaled.median()) * 5
        fig.add_trace(
            go.Scatter(
                x=dffix.x,
                y=dffix.y,
                mode="markers+lines+text",
                name="Raw fixations",
                marker=dict(
                    symbol="arrow",
                    size=duration.values,
                    angleref="previous",
                ),
                line_width=1.2,
                text=num_datapoints,
                textposition="middle right",
                textfont=dict(
                    family="sans serif",
                    size=9,
                ),
                hoverinfo="text+x+y",
                opacity=0.6,
            )
        )
    if "Corrected Fixations" in to_plot_list:
        # Accept a single algorithm or a list; one trace per algorithm.
        if isinstance(algo_choice, list):
            algo_choices = algo_choice
            repeats = range(len(algo_choice))
        else:
            algo_choices = [algo_choice]
            repeats = range(1)
        for algoIdx in repeats:
            algo_choice = algo_choices[algoIdx]
            if f"y_{algo_choice}" in dffix.columns:
                fig.add_trace(
                    go.Scatter(
                        x=dffix.x,
                        y=dffix.loc[:, f"y_{algo_choice}"],
                        mode="markers",
                        name=f"{algo_choice} corrected",
                        marker_color=st.session_state["COLORS"][algoIdx],
                        marker_size=5,
                        hoverinfo="text+x+y",
                        opacity=0.75,
                    )
                )
    if "Characters" in to_plot_list and chars_df is not None:
        fig.add_trace(
            go.Scatter(
                x=chars_df.char_x_center,
                y=chars_df.char_y_center,
                mode="markers+text",
                name="",
                showlegend=False,
                text=chars_df.char,
                textposition="middle center",
                marker=dict(color="black", size=0.1),
                textfont=dict(family=font_face, size=font_size, color="Black"),
            )
        )

    if "Character boxes (slow to plot)" in to_plot_list and chars_df is not None:
        # One rectangle per character; slow for long texts, hence the label.
        num = 0
        for k, row in stqdm(chars_df.iterrows(), "Adding boxes"):
            fig.add_shape(
                type="rect",
                x0=row.char_xmin,
                y0=row.char_ymin,
                x1=row.char_xmax,
                y1=row.char_ymax,
                line=dict(color=st.session_state["COLORS"][-1], width=1),
            )
            num += 1
    return fig
|
525 |
+
|
526 |
+
|
527 |
+
def save_to_zips(folder, pattern, savename):
    """Add every file in `folder` matching glob `pattern` to TEMP_FOLDER/savename.

    Appends when the archive already exists, otherwise creates it.

    fix: the original reopened the ZipFile inside the loop and only switched
    from mode "w" to "a" after idx == 1, so a fresh archive was truncated on
    the second file and the first entry was lost. Open the archive once.
    """
    zip_path = TEMP_FOLDER.joinpath(savename)
    mode = "a" if os.path.exists(zip_path) else "w"
    files = list(folder.glob(pattern))
    if files:
        with zipfile.ZipFile(zip_path, mode=mode) as archive:
            for f in files:
                archive.write(f)
                st.session_state["logger"].info(f"Written {f} to zip {zip_path}")
    st.session_state["logger"].info("Done zipping")
|
539 |
+
|
540 |
+
|
541 |
+
def process_multiple_asc(asc_files):
    """Process a batch of .asc files: correct every trial, save per-trial
    CSV/JSON/PNG artifacts, and bundle them into one zip per input file.

    Returns (zipfiles_with_results, results_keys) and mirrors both into
    st.session_state for the multi-file tab.
    """
    algo_choice = st.session_state["algo_choice_multi"]
    if algo_choice is not None and "DIST" in algo_choice:
        # NOTE(review): the model loaded here is immediately overwritten by the
        # session_state copies below — the find_and_load_model call looks dead;
        # confirm before removing.
        model, model_cfg = find_and_load_model(model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"])
        model = st.session_state["single_DIST_model"]
        model_cfg = st.session_state["single_DIST_model_cfg"]
        st.session_state["logger"].info(f"process_multiple_asc loaded model")
    else:
        model, model_cfg = None, None
    zipfiles_with_results = []
    st.session_state["logger"].info(f"found asc_files {asc_files}")

    for asc_file in stqdm(asc_files, desc="Processing asc files"):
        st.session_state["logger"].info(f"processing asc_file {asc_file}")
        asc_file_stem = pl.Path(asc_file.name).stem
        trials_by_ids, lines = asc_to_trial_ids(asc_file)
        for trial_id, trial in stqdm(trials_by_ids.items(), desc=f"\nProcessing trials in {asc_file_stem}"):
            # for_multi=True so process_trial uses the batch-mode return shape.
            dffix, trial = process_trial(
                trial,
                asc_file_stem,
                lines,
                algo_choice,
                True,
            )

            st.session_state["logger"].debug(f"dffix.columns after process trial {dffix.columns}")
            if dffix.empty:
                st.session_state["logger"].warning(f"Dataframe for {trial_id} is empty, skipping")
                continue
            st.session_state["results"][f"{asc_file_stem}_{trial_id}"] = {
                "trial": trial,
                "dffix": dffix,
            }
            st.session_state["logger"].debug(f"Added {asc_file_stem}_{trial_id} to st.session_state")
            # Persist per-trial artifacts for the downloadable zip.
            quick_dffix_save(dffix, TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.csv"))
            save_trial_to_json(trial, TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.json"))
            ut.plot_fixations_and_text(
                dffix,
                trial,
                save=True,
                savelocation=TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.png"),
                algo_choice=algo_choice,
                turn_axis_on=False,
            )
        # Rebuild the per-file zip from scratch each run.
        if os.path.exists(TEMP_FOLDER.joinpath(f"{asc_file_stem}.zip")):
            os.remove(TEMP_FOLDER.joinpath(f"{asc_file_stem}.zip"))
        save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.csv", f"{asc_file_stem}.zip")
        save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.json", f"{asc_file_stem}.zip")
        save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.png", f"{asc_file_stem}.zip")
        zipfiles_with_results += [str(x) for x in TEMP_FOLDER.glob(f"{asc_file_stem}*.zip")]
    results_keys = list(st.session_state["results"].keys())
    st.session_state["logger"].debug(f"results_keys are {results_keys}")
    st.session_state["trial_choices_multi"] = results_keys
    st.session_state["zipfiles_with_results"] = zipfiles_with_results
    return (zipfiles_with_results, results_keys)
|
596 |
+
|
597 |
+
|
598 |
+
@st.cache_data
def get_trials_and_lines_from_asc_files(asc_files):
    """Collect .asc files from the uploaded batch (unpacking .zip/.tar archives),
    parse them into trials, stash results in session_state and kick off
    process_multiple_asc.
    """
    list_of_trial_lists = []
    list_of_lines = []
    total_num_trials = 0

    asc_files_to_do = []
    for filename_full in asc_files:
        # Uploaded-file objects carry .name; pathlib Paths are used directly.
        if hasattr(filename_full, "name") and not isinstance(filename_full, pl.Path):
            file = filename_full.name
            st.session_state["logger"].info(f"Filename is {file}, filename_full is {filename_full}")
        else:
            file = filename_full
        if not isinstance(file, str):
            file_stem = pl.Path(file.name).stem
        else:
            file_stem = pl.Path(file).stem
        savefolder = gradio_temp_unzipped_folder.joinpath(file_stem)
        st.session_state["logger"].info(f"Operating on file {file}")
        # NOTE(review): the `in` checks below assume `file` is a str; a pathlib
        # Path reaching this point would raise TypeError — confirm input types.
        if ".zip" in file:
            with zipfile.ZipFile(filename_full, "r") as z:
                z.extractall(str(savefolder))
        elif ".tar" in file:
            shutil.unpack_archive(file, savefolder, "tar")
        elif ".asc" in file:
            asc_files_to_do.append(filename_full)
        else:
            st.session_state["logger"].warning(f"Unsopported file format found in files")
        # Pick up any .asc files extracted from an archive.
        newfiles = [str(x) for x in savefolder.glob(f"*.asc")]
        asc_files_to_do += newfiles
    st.session_state["logger"].info(f"asc_files_to_do is {asc_files_to_do}")

    for asc_file in asc_files_to_do:
        trials_by_ids, lines = asc_to_trial_ids(asc_file)
        total_num_trials += len(trials_by_ids)
        list_of_trial_lists.append(trials_by_ids)
        list_of_lines.append(lines)
    st.session_state["list_of_trial_lists"] = list_of_trial_lists
    st.session_state["list_of_lines"] = list_of_lines
    process_multiple_asc(st.session_state["multi_asc_filelist"])
|
638 |
+
|
639 |
+
|
640 |
+
def process_trial_choice_and_update_df_multi():
    """Widget callback for the multi-file tab: copy the selected trial's
    dataframe and trial dict from results into the *_multi session keys,
    dropping raw start/end time columns when present."""
    chosen_id = st.session_state["trial_id_multi"]
    result = st.session_state["results"][chosen_id]
    dffix = result["dffix"]
    if "start_time" in dffix.columns:
        dffix = dffix.drop(axis=1, labels=["start_time", "end_time"])
    st.session_state["dffix_multi"] = dffix
    st.session_state["trial_multi"] = result["trial"]
|
647 |
+
|
648 |
+
|
649 |
+
@st.cache_data
def convert_df(df):
    """Encode a dataframe as UTF-8 CSV bytes for st.download_button."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
|
652 |
+
|
653 |
+
|
654 |
+
def make_trial_from_stimulus_df(
    stim_plot_df,
    filename,
    trial_id,
):
    """Build a trial dict (chars_list, words_list, geometry) from a per-character
    stimulus dataframe whose column names come from the user's mapping in
    st.session_state.
    """
    chars_list = []
    words_list = []
    word_start_idx = 0
    for idx, row in stim_plot_df.reset_index().iterrows():
        char_dict = dict(
            char_xmin=row[st.session_state["x_start_col_name_fix_stim"]],
            char_xmax=row[st.session_state["x_end_col_name_fix_stim"]],
            char_ymin=row[st.session_state["y_start_col_name_fix_stim"]],
            char_ymax=row[st.session_state["y_end_col_name_fix_stim"]],
            char_x_center=row[st.session_state["x_col_name_fix_stim"]],
            char_y_center=row[st.session_state["y_col_name_fix_stim"]],
            char=row[st.session_state["char_col_name_fix_stim"]],
            assigned_line=int(row[st.session_state["line_num_col_name_stim"]]),
        )
        chars_list.append(char_dict)

        # Word boundary heuristic: a space, or an x-coordinate that jumps
        # backwards (line wrap), closes the word that ended at the previous char.
        if len(chars_list) > 1 and (
            char_dict["char"] == " "
            or (len(chars_list) > 2 and (chars_list[-1]["char_xmin"] < chars_list[-2]["char_xmin"]))
        ):
            word_dict = dict(
                word_xmin=chars_list[word_start_idx]["char_xmin"],
                word_xmax=chars_list[-2]["char_xmax"],
                word_ymin=chars_list[word_start_idx]["char_ymin"],
                word_ymax=chars_list[word_start_idx]["char_ymax"],
                word_x_center=(chars_list[-2]["char_xmax"] - chars_list[word_start_idx]["char_xmin"]) / 2
                + chars_list[word_start_idx]["char_xmin"],
                word_y_center=(chars_list[word_start_idx]["char_ymax"] - chars_list[word_start_idx]["char_ymin"]) / 2
                + chars_list[word_start_idx]["char_ymin"],
                word="".join([chars_list[idx]["char"] for idx in range(word_start_idx, len(chars_list) - 1)]),
            )

            # Next word starts at this char (wrap) or after the space.
            if char_dict["char"] != " ":
                word_start_idx = idx
            else:
                word_start_idx = idx + 1
            words_list.append(word_dict)
    # NOTE(review): the final word is never flushed after the loop — the last
    # word of the text appears to be missing from words_list; confirm intended.

    line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
    line_xcoords_all = [x["char_x_center"] for x in chars_list]
    line_xcoords_no_pad = np.unique(line_xcoords_all)

    line_ycoords_all = [x["char_y_center"] for x in chars_list]
    line_ycoords_no_pad = np.unique(line_ycoords_all)

    trial = dict(
        filename=filename,
        y_midline=[float(x) for x in list(stim_plot_df[st.session_state["y_col_name_fix_stim"]].unique())],
        num_char_lines=len(stim_plot_df[st.session_state["y_col_name_fix_stim"]].unique()),
        y_diff=[
            float(x) for x in list(np.unique(np.diff(stim_plot_df[st.session_state["y_start_col_name_fix_stim"]])))
        ],
        trial_id=trial_id,
        chars_list=chars_list,
        words_list=words_list,
        trial_is="paragraph",
        text="".join([x["char"] for x in chars_list]),
    )

    trial["x_char_unique"] = [float(x) for x in list(line_xcoords_no_pad)]
    trial["y_char_unique"] = list(map(float, list(line_ycoords_no_pad)))
    x_diff, y_diff = ut.calc_xdiff_ydiff(
        line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
    )
    trial["x_diff"] = float(x_diff)
    trial["y_diff"] = float(y_diff)
    trial["num_char_lines"] = len(line_ycoords_no_pad)
    trial["line_heights"] = list(map(float, line_heights))
    trial["chars_list"] = chars_list

    return trial
|
730 |
+
|
731 |
+
|
732 |
+
@st.cache_data
def get_fixations_file_trials_list(fixations_df, stimulus):
    """Normalize user-supplied fixation and stimulus tables to the internal
    column names and split fixations into per-trial dicts.

    stimulus is either a per-character DataFrame or an already-built trial dict.
    Returns (trials_by_ids, trial_keys).
    """
    if isinstance(stimulus, pd.DataFrame):
        # Zero-base line numbers and rename stimulus columns in place
        # (mutates the caller's dataframe).
        stimulus[st.session_state["line_num_col_name_stim"]] -= stimulus[
            st.session_state["line_num_col_name_stim"]
        ].min()
        stimulus.rename(
            {
                st.session_state["x_col_name_fix_stim"]: "char_x_center",
                st.session_state["x_start_col_name_fix_stim"]: "char_xmin",
                st.session_state["x_end_col_name_fix_stim"]: "char_xmax",
                st.session_state["y_col_name_fix_stim"]: "char_y_center",
                st.session_state["y_start_col_name_fix_stim"]: "char_ymin",
                st.session_state["y_end_col_name_fix_stim"]: "char_ymax",
                st.session_state["char_col_name_fix_stim"]: "char",
                st.session_state["trial_id_col_name_stim"]: "trial_id",
            },
            axis=1,
            inplace=True,
        )

    fixations_df.rename(
        mapper={
            st.session_state["x_col_name_fix"]: "x",
            st.session_state["y_col_name_fix"]: "y",
            st.session_state["time_start_col_name_fix"]: "corrected_start_time",
            st.session_state["time_stop_col_name_fix"]: "corrected_end_time",
            st.session_state["trial_id_col_name_fix"]: "trial_id",
        },
        axis=1,
        inplace=True,
    )

    fixations_df["duration"] = fixations_df.corrected_end_time - fixations_df.corrected_start_time
    if "trial_id" in stimulus:
        fixations_df["trial_id"] = stimulus["trial_id"]
    if "trial_id" in fixations_df:
        if st.session_state["has_multiple_subject"]:
            # Disambiguate trial ids across subjects: "<subject>_<trial>".
            fixations_df["trial_id"] = [
                f"{id}_{num}"
                for id, num in zip(
                    fixations_df[st.session_state["subject_col_name_fix"]],
                    fixations_df[st.session_state["trial_id_col_name_fix"]],
                )
            ]
        trial_keys = list(fixations_df[st.session_state["trial_id_col_name_fix"]].unique())
        st.session_state["logger"].info(f"Found keys {trial_keys} for {st.session_state['single_csv_file'].name}")
    else:
        # NOTE(review): trial_keys is never assigned on this branch, so the
        # final `return trials_by_ids, trial_keys` would raise NameError —
        # confirm and set trial_keys = ["trial_0"] here.
        st.session_state["logger"].warning(f"trial id column not found assigning trial id trial_0.")
        st.warning(f"trial id column not found assigning trial id trial_0.")
        fixations_df["trial_id"] = "trial_0"
    st.session_state["fixations_df"] = fixations_df
    trials_by_ids = {}

    for trial_id, subdf in fixations_df.groupby("trial_id"):
        if isinstance(stimulus, pd.DataFrame):
            stim_df = stimulus[stimulus.trial_id == trial_id]

            stim_df = stim_df.dropna(axis=0, how="any")
            subdf = subdf.dropna(axis=0, how="any")
            subdf = subdf.reset_index(drop=True)
            stim_df = stim_df.reset_index(drop=True)
            assert not stim_df.empty, "stimulus df is empty"
            trial = make_trial_from_stimulus_df(
                stim_df,
                st.session_state["single_csv_file_stim"].name,
                trial_id,
            )
        else:
            # NOTE(review): all trials share the same stimulus dict here and it
            # is mutated below — later trials overwrite earlier ones' dffix.
            trial = stimulus
        trial["dffix"] = subdf
        trial["fname"] = f"{trial_id}"
        trial["plot_file"] = str(
            st.session_state["PLOTS_FOLDER"].joinpath(f"{trial_id}_2ndInput_chars_channel_sep.png")
        )
        trials_by_ids[trial_id] = trial

    return trials_by_ids, trial_keys
|
810 |
+
|
811 |
+
|
812 |
+
def try_reading_csv(file):
    """Parse an uploaded file as CSV, trying comma then tab as delimiter.

    Returns a DataFrame with more than one column, or None when neither
    delimiter yields one.

    fix: logger.warn is a deprecated alias of logger.warning; the two nearly
    identical attempts are collapsed into one loop.
    """
    for delimiter in (",", "\t"):
        try:
            df = pd.read_csv(StringIO(file.getvalue().decode("utf-8")), delimiter=delimiter)
            st.session_state["logger"].info(f"\n{df.head()}")
            # assert drives the fallback: a one-column parse means wrong delimiter.
            assert len(df.columns) > 1
            return df
        except Exception as e:
            st.session_state["logger"].warning(e)
    return None
|
831 |
+
|
832 |
+
|
833 |
+
@st.cache_data
def guess_col_names_fix(file=None):
    """Read the fixation CSV (defaulting to the uploaded single file) and
    score its header against the expected column names.

    Returns the parsed DataFrame, or None when no file is available or
    parsing fails.

    fix: try_reading_csv can return None, which made `df.shape` raise
    AttributeError; guard against it.
    """
    if file is None:
        file = st.session_state["single_csv_file"]
    if file is None:
        return None

    first_line = next(iter(StringIO(file.getvalue().decode("utf-8"))))
    res = re.findall(r"[^()0-9-]+", first_line)
    # Split the header on the first delimiter that produces multiple fields.
    for delim in [",", "\t", ";"]:
        first_line = first_line.split(delim)
        if len(first_line) > 2:
            break
        else:
            first_line = first_line[0]
    # NOTE(review): the similarity scores below (and `res` above) are computed
    # but never used — looks like an unfinished column-guessing feature.
    scores_lists = {}
    for k, v in st.session_state["colnames_custom_csv_fix"].items():
        scores_lists[v] = []
        for word in first_line:
            scores_lists[v].append(jf.levenshtein_distance(v, word))
    scores_df = pd.DataFrame(scores_lists)
    scores_df.idxmin(axis=0)
    df = try_reading_csv(file)
    if df is not None and df.shape[1] > 1:
        return df
    return None
|
860 |
+
|
861 |
+
|
862 |
+
@st.cache_data
def guess_col_names_stim(file=None):
    """Load the stimulus file: a trial dict from .json, or a DataFrame from CSV.

    Returns None when no file is available or CSV parsing fails.

    fix: try_reading_csv can return None, which made `df.shape` raise
    AttributeError; guard against it.
    """
    if file is None:
        file = st.session_state["single_csv_file_stim"]
    if file is None:
        return None
    if ".json" in file.name:
        json_string = file.getvalue().decode("utf-8")
        trial = json.loads(json_string)
        return trial
    df = try_reading_csv(file)
    if df is not None and df.shape[1] > 1:
        return df
    return None
|
879 |
+
|
880 |
+
|
881 |
+
@st.cache_resource
def set_up_models(dist_models_folder):
    """Load the DIST models once per session (st.cache_resource) via ut.set_up_models."""
    return ut.set_up_models(dist_models_folder)
|
884 |
+
|
885 |
+
@st.cache_data
def get_eyekit_measures(_txt, _seq, get_char_measures=False):
    """Cached wrapper: reading measures from eyekit text/sequence objects.

    Leading underscores on _txt/_seq exempt them from Streamlit hashing.
    """
    return ekm.get_eyekit_measures(_txt, _seq, get_char_measures=get_char_measures)
|
888 |
+
|
889 |
+
|
890 |
+
@st.cache_data
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
    """Cached wrapper: compute all word/character reading measures for a trial."""
    return ut.get_all_measures(
        trial, dffix, prefix, use_corrected_fixations=use_corrected_fixations, correction_algo=correction_algo
    )
|
893 |
+
|
894 |
+
|
895 |
+
# Guard: earlier app setup must have populated session_state before this runs.
assert "ALGO_CHOICES" in st.session_state, f"st.session_state not initialized\n{list(st.session_state.keys())}"

# Load DIST models once (cached) and expose them via session_state.
set_up_models_out = set_up_models(DIST_MODELS_FOLDER)
st.session_state.update(set_up_models_out)


# Page header and external links.
st.title("Fixation data vertical alignment")
st.header("👀 Read asc file or files and plot fixations 👀")
st.markdown("[Contact Us](mailto:[email protected])")
st.markdown("[Read about DIST model](https://arxiv.org/abs/2311.06095)")

# Top-level layout: single-file vs multi-file workflows, with the single-file
# tab further split into .asc and custom-CSV sub-tabs.
single_file_tab, multi_file_tab = st.tabs(["Single File 📁", "Multiple Files 📁 📁"])

single_file_tab_asc_tab, single_file_tab_csv_tab = single_file_tab.tabs([".asc files", "custom files"])

single_file_tab_asc_tab.subheader(
    "Upload an .asc file and select a trial. Then select a correction algorithm and plot/download the results"
)
|
913 |
+
|
914 |
+
|
915 |
+
def change_which_file_is_used_and_clear_results():
    """Reset cached correction results and record which .asc source to use.

    Removes any previously computed ``dffix``/``trial`` entries from the
    session state, then stores either the chosen example file or the uploaded
    file under ``single_asc_file_asc`` depending on the radio-button choice.
    """
    # Drop stale results so the UI never shows output from the previous file.
    for stale_key in ("dffix", "trial"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
    use_example = (
        st.session_state["single_file_tab_asc_tab_example_use_example_or_uploaded_file_choice"] == "Example File"
    )
    chosen_source = (
        st.session_state["single_file_tab_asc_tab_example_file_choice"]
        if use_example
        else st.session_state["single_asc_uploaded_file"]
    )
    st.session_state["single_asc_file_asc"] = chosen_source
|
924 |
+
|
925 |
+
|
926 |
+
# Form for choosing the data source of the single-.asc-file workflow.
# NOTE(review): indentation reconstructed from a diff rendering; the radio
# button is placed outside the example-file guard because its session key is
# read unconditionally by change_which_file_is_used_and_clear_results --
# confirm against the original layout.
with single_file_tab_asc_tab.form("single_file_tab_asc_tab_load_example_form"):
    single_asc_file_asc_uploaded = st.file_uploader(
        "Select .asc File", accept_multiple_files=False, key="single_asc_uploaded_file", type=["asc"]
    )
    close_gap_between_words_single_asc = st.checkbox(
        label="Should spaces between words be included in word bounding box?",
        value=False,
        key="close_gap_between_words_single_asc",
    )

    # Offer bundled example files only when they actually exist on disk.
    if os.path.isfile(EXAMPLE_ASC_FILES[0]):
        example_file_choice = st.selectbox(
            "Select example file", options=EXAMPLE_ASC_FILES, key="single_file_tab_asc_tab_example_file_choice"
        )
    use_example_or_uploaded_file_choice = st.radio(
        "Should the uploaded file be used or the selected example file?",
        index=1,
        options=["Uploaded File", "Example File"],
        key="single_file_tab_asc_tab_example_use_example_or_uploaded_file_choice",
    )

    # Submitting clears stale results and records which file should be used.
    upload_file_button = st.form_submit_button(
        label="Load selected data.", on_click=change_which_file_is_used_and_clear_results
    )
|
950 |
+
|
951 |
+
|
952 |
+
# --- Single .asc workflow: trial selection, correction, plotting, analysis ---
# NOTE(review): indentation reconstructed from a diff rendering -- confirm the
# nesting against the original file.
if "single_asc_file_asc" in st.session_state and st.session_state["single_asc_file_asc"] is not None:
    # Parse the selected .asc file into per-trial structures.
    trial_choices_single_asc, trials_by_ids, lines, asc_file = get_trials_list(
        st.session_state["single_asc_file_asc"], close_gap_between_words=close_gap_between_words_single_asc
    )
    st.session_state["trials_by_ids"] = trials_by_ids
    st.session_state["trial_choices"] = trial_choices_single_asc
    st.session_state["lines"] = lines
    st.session_state["asc_file"] = asc_file
    if trial_choices_single_asc:
        # Trial + algorithm selection form.
        with single_file_tab_asc_tab.form(key="single_file_tab_asc_tab_trial_select_form"):
            col_a1, col_a2 = st.columns((1, 2))
            with col_a1:
                trial_choice = st.selectbox(
                    "Which trial should be corrected?",
                    trial_choices_single_asc,
                    key="trial_id",
                    index=0,
                )
            with col_a2:
                st.multiselect(
                    "Choose correction algorithm",
                    ALGO_CHOICES,
                    key="algo_choice",
                    default=[ALGO_CHOICES[0]],
                )
            process_trial_btn = st.form_submit_button("Load and correct trial")

        if process_trial_btn:
            # Run the selected correction algorithm(s) on the chosen trial and
            # cache every output in session state for the widgets below.
            single_file_tab_asc_tab.write(f'You selected: {st.session_state["trial_id"]}')
            dffix, trial, dpi, screen_res, font, font_size = process_trial_choice(
                trial_choice, st.session_state["algo_choice"]
            )

            st.session_state["dffix"] = dffix
            st.session_state["trial"] = trial
            st.session_state["dpi"] = dpi
            st.session_state["screen_res"] = screen_res
            st.session_state["font"] = font
            st.session_state["font_size"] = font_size

            export_csv(dffix, trial)

        if "dffix" in st.session_state and "trial" in st.session_state:
            # Results: dataframe preview, download, and plots.
            df_expander_single = single_file_tab_asc_tab.expander("Show Dataframe", False)
            plot_expander_single = single_file_tab_asc_tab.expander("Show Plots", True)
            df_expander_single.dataframe(st.session_state["dffix"])

            csv = convert_df(st.session_state["dffix"])

            df_expander_single.download_button(
                "Download fixation dataframe",
                csv,
                f'{st.session_state["trial_id"]}.csv',
                "text/csv",
                key="download-csv-single",
            )

            plotting_checkboxes_single = plot_expander_single.multiselect(
                "Select what gets plotted",
                ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
                key="plotting_checkboxes_single",
                default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
            )
            scale_factor_single_asc = plot_expander_single.number_input(
                label="Scale factor for stimulus image", min_value=0.01, max_value=3.0, value=0.5, step=0.1
            )
            plot_expander_single.plotly_chart(
                plotly_plot_with_image(
                    st.session_state["dffix"],
                    st.session_state["trial"],
                    to_plot_list=plotting_checkboxes_single,
                    algo_choice=st.session_state["algo_choice"],
                    scale_factor=scale_factor_single_asc,
                ),
                use_container_width=False,
            )
            plot_expander_single.plotly_chart(
                plot_y_corr(st.session_state["dffix"], st.session_state["algo_choice"]), use_container_width=True
            )

            # Analysis requires character-level stimulus information.
            if "chars_list" in st.session_state["trial"]:
                analysis_expander_single_asc = single_file_tab_asc_tab.expander("Show Analysis results", True)
                use_corrected_fixations_tickbox = analysis_expander_single_asc.checkbox(
                    "Use corrected",
                    True,
                    "use_corrected_fixations_tickbox",
                    help="Whether to use the corrected or uncorrected fixations for the analysis.",
                )
                eyekit_tab, own_analysis_tab = analysis_expander_single_asc.tabs(
                    ["Analysis using eyekit", "Analysis without eyekit"]
                )
                with eyekit_tab:
                    st.markdown("Analysis powered by [eyekit](https://jwcarr.github.io/eyekit/)")
                    # NOTE(review): runtime string below is missing a space
                    # after "sliders." -- left unchanged here.
                    st.markdown(
                        "Please adjust parameters below to align fixations with stimulus using the sliders.Eyekit analysis is based on this alignment."
                    )
                    a_c1, a_c2, a_c3, a_c4, a_c5, a_c6 = st.columns(6)
                    # Prefer a monospace font when one is installed.
                    if "Consolas" in AVAILABLE_FONTS:
                        font_index = AVAILABLE_FONTS.index("Consolas")
                    elif "Courier New" in AVAILABLE_FONTS:
                        font_index = AVAILABLE_FONTS.index("Courier New")
                    elif "DejaVu Sans Mono" in AVAILABLE_FONTS:
                        font_index = AVAILABLE_FONTS.index("DejaVu Sans Mono")
                    else:
                        font_index = 0
                    font_face = a_c1.selectbox(
                        label="Select Font",
                        options=AVAILABLE_FONTS,
                        index=font_index,
                        key="font_face_for_eyekit_single_asc",
                    )
                    algo_choice_single_asc_eyekit = a_c1.selectbox(
                        "Algorithm", st.session_state["algo_choice"], index=0, key="algo_choice_single_asc_eyekit"
                    )
                    sliders_on_tickbox = a_c6.checkbox(
                        "Sliders", True, "single_asc_eyekit_sliders_checkbox", help="Turns sliders on and off"
                    )

                    # Seed default alignment parameters from the trial once.
                    if "font_size_for_eyekit" not in st.session_state:
                        (
                            y_diff,
                            x_txt_start,
                            y_txt_start,
                            _,
                            _,
                            line_height,
                        ) = add_default_font_and_character_props_to_state(st.session_state["trial"])
                        font_size = ut.set_font_from_chars_list(st.session_state["trial"])
                        st.session_state["y_diff_for_eyekit"] = y_diff
                        st.session_state["x_txt_start_for_eyekit"] = x_txt_start
                        st.session_state["y_txt_start_for_eyekit"] = y_txt_start
                        st.session_state["font_face_for_eyekit"] = font_face
                        st.session_state["font_size_for_eyekit"] = font_size
                        st.session_state["line_height_for_eyekit"] = line_height
                    # Same four parameters via sliders or free number inputs.
                    if sliders_on_tickbox:
                        font_size = a_c2.select_slider(
                            "Font Size",
                            np.arange(5, 36, 0.25),
                            st.session_state["font_size_for_eyekit"],
                            key="font_size_for_eyekit_single_asc",
                        )
                        x_txt_start = a_c3.select_slider(
                            "x",
                            np.arange(300, 601, 1),
                            round(st.session_state["x_txt_start_for_eyekit"]),
                            key="x_txt_start_for_eyekit_single_asc",
                            help="x coordinate of first character",
                        )
                        y_txt_start = a_c4.select_slider(
                            "y",
                            np.arange(100, 501, 1),
                            round(st.session_state["y_txt_start_for_eyekit"]),
                            key="y_txt_start_for_eyekit_single_asc",
                            help="y coordinate of first character",
                        )
                        line_height = a_c5.select_slider(
                            "Line height",
                            np.arange(0, 151, 1),
                            round(st.session_state["line_height_for_eyekit"]),
                            key="line_height_for_eyekit_single_asc",
                        )
                    else:
                        font_size = a_c2.number_input(
                            "Font Size",
                            None,
                            None,
                            round(st.session_state["font_size_for_eyekit"], ndigits=0),
                            key="font_size_for_eyekit_single_asc",
                        )
                        x_txt_start = a_c3.number_input(
                            "x",
                            None,
                            None,
                            round(st.session_state["x_txt_start_for_eyekit"]),
                            key="x_txt_start_for_eyekit_single_asc",
                            help="x coordinate of first character",
                        )
                        y_txt_start = a_c4.number_input(
                            "y",
                            None,
                            None,
                            round(st.session_state["y_txt_start_for_eyekit"]),
                            key="y_txt_start_for_eyekit_single_asc",
                            help="y coordinate of first character",
                        )
                        line_height = a_c5.number_input(
                            "Line height",
                            None,
                            None,
                            round(st.session_state["line_height_for_eyekit"]),
                            key="line_height_for_eyekit_single_asc",
                        )

                    # Build eyekit objects from the (possibly shifted) layout.
                    fixation_sequence, textblock, screen_size = ekm.get_fix_seq_and_text_block(
                        st.session_state["dffix"],
                        st.session_state["trial"],
                        x_txt_start=st.session_state["x_txt_start_for_eyekit_single_asc"],
                        y_txt_start=st.session_state["y_txt_start_for_eyekit_single_asc"],
                        font_face=st.session_state["font_face_for_eyekit_single_asc"],
                        font_size=st.session_state["font_size_for_eyekit_single_asc"],
                        line_height=line_height,
                        use_corrected_fixations=st.session_state["use_corrected_fixations_tickbox"],
                        correction_algo=st.session_state["algo_choice_single_asc_eyekit"],
                    )
                    eyekitplot_img = ekm.eyekit_plot(textblock, fixation_sequence, screen_size)
                    # NOTE(review): "anaylsis" typo in the caption below is a
                    # runtime string; left unchanged here.
                    st.image(eyekitplot_img, "Fixations and stimulus as used for anaylsis")

                    # get_fix_seq_and_text_block appears to have written these
                    # JSON files to results/ -- re-read them for download.
                    with open(
                        f'results/fixation_sequence_eyekit_{st.session_state["trial"]["trial_id"]}.json', "r"
                    ) as f:
                        fixation_sequence_json = json.load(f)
                    fixation_sequence_json_str = json.dumps(fixation_sequence_json)

                    st.download_button(
                        "Download fixations in eyekits format",
                        fixation_sequence_json_str,
                        f'fixation_sequence_eyekit_{st.session_state["trial"]["trial_id"]}.json',
                        "json",
                        key="download_eyekit_fix_json_single_asc",
                    )

                    with open(f'results/textblock_eyekit_{st.session_state["trial"]["trial_id"]}.json', "r") as f:
                        textblock_json = json.load(f)
                    textblock_json_str = json.dumps(textblock_json)

                    st.download_button(
                        "Download stimulus in eyekits format",
                        textblock_json_str,
                        f'textblock_eyekit_{st.session_state["trial"]["trial_id"]}.json',
                        "json",
                        key="download_eyekit_text_json_single_asc",
                    )

                    # character_measures_df is unused because
                    # get_char_measures=False here.
                    word_measures_df, character_measures_df = get_eyekit_measures(
                        textblock, fixation_sequence, get_char_measures=False
                    )

                    st.dataframe(word_measures_df, use_container_width=True, hide_index=True)
                    word_measures_df_csv = convert_df(word_measures_df)

                    word_measures_df_download_btn = st.download_button(
                        "Download word measures data",
                        word_measures_df_csv,
                        f'{st.session_state["trial"]["trial_id"]}_word_measures_df.csv',
                        "text/csv",
                        key="word_measures_df_download_btn",
                    )
                    measure_words = st.selectbox(
                        "Select measure to visualize", list(ekm.MEASURES_DICT.keys()), key="measure_words"
                    )
                    st.image(ekm.plot_with_measure(textblock, fixation_sequence, screen_size, measure_words))
                with own_analysis_tab:
                    st.markdown(
                        "This analysis method does not require manual alignment and works when the automated stimulus coordinates are correct."
                    )
                    own_word_measures = get_all_measures(
                        st.session_state["trial"],
                        st.session_state["dffix"],
                        prefix="word",
                        use_corrected_fixations=st.session_state["use_corrected_fixations_tickbox"],
                        correction_algo=st.session_state["algo_choice_single_asc_eyekit"],
                    )
                    st.dataframe(own_word_measures, use_container_width=True, hide_index=True)
                    own_word_measures_csv = convert_df(own_word_measures)

                    word_measures_df_download_btn = st.download_button(
                        "Download word measures data",
                        own_word_measures_csv,
                        f'{st.session_state["trial"]["trial_id"]}_own_word_measures_df.csv',
                        "text/csv",
                        key="own_word_measures_df_download_btn",
                    )
                    fix_to_plot = (
                        ["Corrected Fixations"]
                        if st.session_state["use_corrected_fixations_tickbox"]
                        else ["Uncorrected Fixations"]
                    )
                    # measure_words comes from the eyekit tab above; both tabs
                    # render on the same script run, so it is always defined.
                    own_word_measures_fig, desired_width_in_pixels, desired_height_in_pixels = matplotlib_plot_df(
                        st.session_state["dffix"],
                        st.session_state["trial"],
                        st.session_state["algo_choice"],
                        stimulus_prefix="word",
                        box_annotations=own_word_measures[measure_words],
                        fix_to_plot=fix_to_plot,
                    )
                    st.pyplot(own_word_measures_fig)
            else:
                single_file_tab_asc_tab.warning("🚨 Stimulus information needed for analysis 🚨")
|
1240 |
+
|
1241 |
+
# --- Custom csv/json workflow ------------------------------------------------
# NOTE(review): indentation reconstructed from a diff rendering.
single_file_tab_csv_tab.markdown(
    "#### Upload one .csv file for the fixations and one .json or .csv file for the stimulus information and select a trial. Then select a correction algorithm and plot/download the results"
)

with single_file_tab_csv_tab.expander("Upload and preview data", expanded=True):
    csv_upl_col1, csv_upl_col2 = st.columns(2)
    single_csv_file = csv_upl_col1.file_uploader(
        "Select .csv file containing the fixation data",
        accept_multiple_files=False,
        key="single_csv_file",
        type={"csv", "txt", "dat"},
    )
    single_csv_stim_file = csv_upl_col2.file_uploader(
        "Select .csv or .json file containing the stimulus data",
        accept_multiple_files=False,
        key="single_csv_file_stim",
        type={"json", "csv", "txt", "dat"},
    )

    # Preview whatever could be parsed from the two uploads.
    if single_csv_file:
        st.session_state["dffix_single_csv"] = guess_col_names_fix(single_csv_file)
        if st.session_state["dffix_single_csv"] is not None:
            csv_upl_col1.dataframe(
                st.session_state["dffix_single_csv"], use_container_width=True, hide_index=True, height=200
            )
    if single_csv_stim_file:
        st.session_state["stimdf_single_csv"] = guess_col_names_stim(single_csv_stim_file)
        # Stimulus may be a dict (from .json) or a dataframe (from .csv).
        if ".json" in single_csv_stim_file.name:
            st.session_state["colnames_stim"] = st.session_state["stimdf_single_csv"].keys()
        else:
            st.session_state["colnames_stim"] = st.session_state["stimdf_single_csv"].columns
        if st.session_state["stimdf_single_csv"] is not None:
            if ".json" in single_csv_stim_file.name:
                csv_upl_col2.json(st.session_state["stimdf_single_csv"])
            else:
                csv_upl_col2.dataframe(
                    st.session_state["stimdf_single_csv"], use_container_width=True, hide_index=True, height=200
                )

if single_csv_file and single_csv_stim_file:
    # Let the user map their column/key names onto the names the pipeline expects.
    with single_file_tab_csv_tab.expander("Column names for csv files", expanded=True):
        with st.form("Column names for csv files"):
            st.markdown("### Please set column/key names for csv/json files")
            st.markdown("#### Fixation file column names:")
            c1, c2, c3 = st.columns(3)
            x_col_name_fix = c1.text_input("x coordinate", key="x_col_name_fix", value="x")
            y_col_name_fix = c2.text_input("y coordinate", key="y_col_name_fix", value="y")
            subject_col_name_fix = c1.text_input("subject id", key="subject_col_name_fix", value="sub_id")
            trial_id_col_name_fix = c3.text_input("trial id", key="trial_id_col_name_fix", value="trial_id")
            time_start_col_name_fix = c2.text_input(
                "fixation start time", key="time_start_col_name_fix", value="corrected_start_time"
            )
            time_stop_col_name_fix = c3.text_input(
                "fixation end time", key="time_stop_col_name_fix", value="corrected_end_time"
            )
            st.markdown("#### Stimulus file column/key names:")
            c1, c2, c3 = st.columns(3)
            x_col_name_fix_stim = c1.text_input("x coordinate", key="x_col_name_fix_stim", value="char_x_center")
            y_col_name_fix_stim = c2.text_input("y coordinate", key="y_col_name_fix_stim", value="char_y_center")
            x_start_col_name_fix_stim = c3.text_input(
                "x min of interest areas", key="x_start_col_name_fix_stim", value="char_xmin"
            )
            x_end_col_name_fix_stim = c1.text_input(
                "x max of interest areas", key="x_end_col_name_fix_stim", value="char_xmax"
            )
            y_start_col_name_fix_stim = c2.text_input(
                "y min of interest areas", key="y_start_col_name_fix_stim", value="char_ymin"
            )
            # NOTE(review): the label below says "x max" but this is the y max
            # field -- runtime string, left unchanged here.
            y_end_col_name_fix_stim = c3.text_input(
                "x max of interest areas", key="y_end_col_name_fix_stim", value="char_ymax"
            )
            char_col_name_fix_stim = c1.text_input(
                "content of interest area", key="char_col_name_fix_stim", value="char"
            )
            line_num_col_name_stim = c3.text_input(
                "line number for interest areas", key="line_num_col_name_stim", value="assigned_line"
            )
            subject_col_name_stim = c1.text_input("subject id", key="subject_col_name_stim", value="sub_id")
            trial_id_col_name_stim = c2.text_input("trial id", key="trial_id_col_name_stim", value="trial_id")
            has_multiple_subject = c2.checkbox("multiple subject in file", key="has_multiple_subject")
            form_submitted = st.form_submit_button("Confirm column/key names")


if single_csv_file and single_csv_stim_file:
    process_custom_csvs_button = single_file_tab_csv_tab.button(
        "Load data from files",
    )
    # Re-enter this branch on later reruns once trial keys exist in state.
    if process_custom_csvs_button or "trial_keys_single_csv" in st.session_state:
        trials_by_ids, trial_keys = get_fixations_file_trials_list(
            st.session_state["dffix_single_csv"], st.session_state["stimdf_single_csv"]
        )

        st.session_state["trials_by_ids_single_csv"] = trials_by_ids
        st.session_state["trial_keys_single_csv"] = trial_keys
        with single_file_tab_csv_tab.form(key="trial_selection_algo_selection_form_single_csv"):
            col_a1, col_a2 = st.columns((1, 2))
            with col_a1:
                trial_choice = st.selectbox(
                    "Which trial should be corrected?",
                    st.session_state["trial_keys_single_csv"],
                    key="trial_id_selected_custom_csv",
                    index=0,
                )
            with col_a2:
                algo_choice_single_csv = st.multiselect(
                    "Choose correction algorithm",
                    ALGO_CHOICES,
                    key="algo_choice_single_csv",
                    default=[ALGO_CHOICES[0]],
                )
            process_trial_btn = st.form_submit_button("Correct trial")
        if "trial_id_selected_custom_csv" in st.session_state and "algo_choice_single_csv" in st.session_state:
            trial = st.session_state["trials_by_ids_single_csv"][trial_choice]
            dffix, trial, dpi, screen_res, font, font_size = process_trial_choice_single_csv(
                trial, algo_choice_single_csv
            )
            st.session_state["trial_single_csv"] = trial
            csv = convert_df(dffix)

            single_file_tab_csv_tab.download_button(
                "Download corrected fixation data",
                csv,
                f'{trial["trial_id"]}.csv',
                "text/csv",
                key="download-csv-custom-csv",
            )
            with single_file_tab_csv_tab.expander("Show corrected fixation data", expanded=True):
                st.dataframe(dffix, use_container_width=True, hide_index=True, height=200)
            with single_file_tab_csv_tab.expander("Show fixation plots", expanded=True):
                plotting_checkboxes_single_single_csv = st.multiselect(
                    "Select what gets plotted",
                    ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
                    key="plotting_checkboxes_single_single_csv",
                    default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
                )

                st.plotly_chart(
                    plotly_plot_with_image(
                        dffix,
                        trial,
                        to_plot_list=plotting_checkboxes_single_single_csv,
                        algo_choice=algo_choice_single_csv,
                    ),
                    use_container_width=True,
                )
                st.plotly_chart(plot_y_corr(dffix, algo_choice_single_csv), use_container_width=True)
|
1387 |
+
|
1388 |
+
|
1389 |
+
# --- Multiple-.asc-files workflow --------------------------------------------
# NOTE(review): indentation reconstructed from a diff rendering.
multi_file_tab.subheader("Upload multiple .asc files. Then select a correction algorithm and download the results.")

with multi_file_tab.form("Upload files to be processed and select algorithm"):
    multifile_col, multi_algo_col = st.columns((1, 1))

    with multifile_col:
        st.file_uploader(
            "Upload .asc Files", accept_multiple_files=True, key="multi_asc_filelist", type=["asc", "tar", "zip"]
        )
    with multi_algo_col:
        st.multiselect(
            "Choose correction algorithms",
            ALGO_CHOICES,
            key="algo_choice_multi",
            default=[ALGO_CHOICES[0]],
        )
    process_trial_btn_multi = st.form_submit_button("Load and correct asc files")
if process_trial_btn_multi:
    # Batch-process every uploaded file (results land in session state).
    get_trials_and_lines_from_asc_files(st.session_state["multi_asc_filelist"])

if "zipfiles_with_results" in st.session_state:
    # Offer the per-file result archives for download.
    multi_res_col1, multi_res_col2 = multi_file_tab.columns(2)

    chosen_zip = multi_res_col1.selectbox("Choose results to download", st.session_state["zipfiles_with_results"])
    st.session_state["logger"].info(f"Download button for {chosen_zip}")
    st.session_state["logger"].info(st.session_state["zipfiles_with_results"])
    zipnamestem = pl.Path(chosen_zip).stem
    with open(chosen_zip, "rb") as f:
        multi_res_col2.download_button(f"Download {zipnamestem}", f, file_name=f"results_{zipnamestem}.zip")


if "trial_choices_multi" in st.session_state:
    # Preview a single trial out of the batch results.
    multi_plotting_options_col1, multi_plotting_options_col2 = multi_file_tab.columns(2)

    trial_choice_multi = multi_plotting_options_col1.selectbox(
        "Which trial should be plotted?",
        st.session_state["trial_choices_multi"],
        key="trial_id_multi",
        placeholder="Select trial to display and plot",
        on_change=process_trial_choice_and_update_df_multi,
    )

    plotting_checkboxes_multi = multi_plotting_options_col2.multiselect(
        "Select what gets plotted",
        ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
        default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
    )

    if trial_choice_multi and "dffix_multi" in st.session_state:
        df_expander_multi = multi_file_tab.expander("Show Dataframe", False)
        plot_expander_multi = multi_file_tab.expander("Show Plots", True)

        df_expander_multi.dataframe(st.session_state["dffix_multi"])
        dffix_multi = st.session_state["dffix_multi"]
        trial_multi = st.session_state["trial_multi"]

        plot_expander_multi.plotly_chart(
            plotly_plot_with_image(
                dffix_multi, trial_multi, st.session_state["algo_choice_multi"], to_plot_list=plotting_checkboxes_multi
            ),
            use_container_width=True,
        )
        plot_expander_multi.plotly_chart(
            plot_y_corr(dffix_multi, st.session_state["algo_choice_multi"]), use_container_width=True
        )
|
classic_correction_algos.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Mostly adapted from https://github.com/jwcarr/eyekit/blob/350d055eecaa1581b03db5a847424825ffbb10f6/eyekit/_snap.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
|
8 |
+
|
9 |
+
def apply_classic_algo(
    dffix,
    trial,
    algo="slice",
    algo_params=dict(x_thresh=192, y_thresh=32, w_thresh=32, n_thresh=90),
):
    """Assign each fixation to a text line using a classic correction algorithm.

    Parameters
    ----------
    dffix : pandas.DataFrame
        Fixation data; must contain "x" and "y" columns.
    trial : dict
        Trial metadata. Must contain "y_diff" and either "y_char_unique"
        (preferred) or "y_midline" listing the candidate line y-positions.
        The "warp" and "compare" algorithms additionally require "words_list".
    algo : str
        Name of the correction algorithm to apply (see branches below).
    algo_params : dict
        Keyword arguments forwarded to the chosen algorithm. Copied before
        use, so neither the caller's dict nor the shared default is mutated.

    Returns
    -------
    pandas.DataFrame
        ``dffix`` with two added columns: ``y_<algo>`` (corrected y value)
        and ``line_num_<algo>`` (index of the nearest midline).

    Raises
    ------
    NotImplementedError
        If ``algo`` is not a supported algorithm name.
    """
    fixation_array = dffix.loc[:, ["x", "y"]].values
    y_diff = trial["y_diff"]
    if "y_char_unique" in trial:
        midlines = trial["y_char_unique"]
    else:
        midlines = trial["y_midline"]
    # Work on a copy: the "compare" branch rewrites "n_nearest_lines", and the
    # default argument dict is shared across calls, so mutating it in place
    # would leak state between invocations (classic mutable-default bug).
    params = dict(algo_params)
    if len(midlines) == 1 or fixation_array.shape[0] <= 2:
        # Degenerate cases: a single text line, or too few fixations for any
        # algorithm to work with -- snap everything to the first midline
        # (preserves the original behavior for <= 2 fixations).
        corrected_fix_y_vals = np.ones(fixation_array.shape[0]) * midlines[0]
    else:
        # NOTE: "slice" intentionally keeps the upstream eyekit name even
        # though it shadows the builtin; renaming would break module callers.
        if algo == "slice":
            corrected_fix_y_vals = slice(fixation_array, midlines, line_height=y_diff, **params)
        elif algo == "warp":
            word_center_list = [(word["word_x_center"], word["word_y_center"]) for word in trial["words_list"]]
            corrected_fix_y_vals = warp(fixation_array, word_center_list)
        elif algo == "chain":
            corrected_fix_y_vals = chain(fixation_array, midlines, **params)
        elif algo == "cluster":
            corrected_fix_y_vals = cluster(fixation_array, midlines)
        elif algo == "merge":
            corrected_fix_y_vals = merge(fixation_array, midlines, **params)
        elif algo == "regress":
            corrected_fix_y_vals = regress(fixation_array, midlines, **params)
        elif algo == "segment":
            corrected_fix_y_vals = segment(fixation_array, midlines, **params)
        elif algo == "split":
            corrected_fix_y_vals = split(fixation_array, midlines, **params)
        elif algo == "stretch":
            corrected_fix_y_vals = stretch(fixation_array, midlines, **params)
        elif algo == "attach":
            corrected_fix_y_vals = attach(fixation_array, midlines)
        elif algo == "compare":
            word_center_list = [(word["word_x_center"], word["word_y_center"]) for word in trial["words_list"]]
            params["n_nearest_lines"] = min(params["n_nearest_lines"], len(midlines) - 1)
            corrected_fix_y_vals = compare(fixation_array, np.array(word_center_list), **params)
        else:
            raise NotImplementedError(f"{algo} not implemented")

    # Map each corrected y value to the index of the nearest midline.  The
    # original indexed trial["y_char_unique"] directly, which raised KeyError
    # whenever only "y_midline" was available and ValueError when an algorithm
    # (e.g. "warp") returned a y value not exactly present in that list.
    # Exact matches yield the same indices as before.
    midline_list = list(midlines)
    corrected_line_nums = [
        min(range(len(midline_list)), key=lambda i: abs(midline_list[i] - y)) for y in corrected_fix_y_vals
    ]
    dffix[f"y_{algo}"] = corrected_fix_y_vals
    dffix[f"line_num_{algo}"] = corrected_line_nums
    return dffix
|
60 |
+
|
61 |
+
|
62 |
+
def slice(fixation_XY, midlines, line_height: float, x_thresh=192, y_thresh=32, w_thresh=32, n_thresh=90):
    """
    Adapted from Eyekit(https://github.com/jwcarr/eyekit/blob/350d055eecaa1581b03db5a847424825ffbb10f6/eyekit/_snap.py)
    implementation

    Form a set of runs and then reduce the set to *m* by repeatedly merging
    those that appear to be on the same line. Merged sequences are then
    assigned to text lines in positional order. Default params:
    `x_thresh=192`, `y_thresh=32`, `w_thresh=32`, `n_thresh=90`. Requires
    NumPy. Original method by [Glandorf & Schroeder (2021)](https://doi.org/10.1016/j.procs.2021.09.069).

    Args:
        fixation_XY: sequence of (x, y) fixation coordinates.
        midlines: y-coordinates of the text-line midlines.
        line_height: vertical distance between adjacent text lines; used to
            place "phantom" proto lines where no fixations were observed.
        x_thresh, y_thresh: jump sizes (in px) that end a run of fixations.
        w_thresh: max mean y-difference for merging a run into a proto line.
        n_thresh: upper bound on the y-difference for merging into the
            adjacent proto line.

    Returns:
        The corrected y-coordinate (one of ``midlines``) for every fixation.
    """
    fixation_XY = np.array(fixation_XY, dtype=float)
    line_Y = np.array(midlines, dtype=float)
    proto_lines, phantom_proto_lines = {}, {}
    # 1. Segment runs
    dist_X = abs(np.diff(fixation_XY[:, 0]))
    dist_Y = abs(np.diff(fixation_XY[:, 1]))
    end_run_indices = list(np.where(np.logical_or(dist_X > x_thresh, dist_Y > y_thresh))[0] + 1)
    run_starts = [0] + end_run_indices
    run_ends = end_run_indices + [len(fixation_XY)]
    runs = [list(range(start, end)) for start, end in zip(run_starts, run_ends)]
    # 2. Determine starting run (the horizontally longest run seeds proto line 0)
    longest_run_i = np.argmax([fixation_XY[run[-1], 0] - fixation_XY[run[0], 0] for run in runs])
    proto_lines[0] = runs.pop(longest_run_i)
    # 3. Group runs into proto lines, growing outward above (direction -1)
    # and below (direction +1) the seed line.
    while runs:
        merger_on_this_iteration = False
        for proto_line_i, direction in [(min(proto_lines), -1), (max(proto_lines), 1)]:
            # Create new proto line above or below (depending on direction)
            proto_lines[proto_line_i + direction] = []
            # Get current proto line XY coordinates (if proto line is empty, get phanton coordinates)
            if proto_lines[proto_line_i]:
                proto_line_XY = fixation_XY[proto_lines[proto_line_i]]
            else:
                proto_line_XY = phantom_proto_lines[proto_line_i]
            # Compute differences between current proto line and all runs
            # (each fixation is compared against the horizontally-nearest
            # fixation on the proto line).
            run_differences = np.zeros(len(runs))
            for run_i, run in enumerate(runs):
                y_diffs = [y - proto_line_XY[np.argmin(abs(proto_line_XY[:, 0] - x)), 1] for x, y in fixation_XY[run]]
                run_differences[run_i] = np.mean(y_diffs)
            # Find runs that can be merged into this proto line
            merge_into_current = list(np.where(abs(run_differences) < w_thresh)[0])
            # Find runs that can be merged into the adjacent proto line
            merge_into_adjacent = list(
                np.where(
                    np.logical_and(
                        run_differences * direction >= w_thresh,
                        run_differences * direction < n_thresh,
                    )
                )[0]
            )
            # Perform mergers
            for index in merge_into_current:
                proto_lines[proto_line_i].extend(runs[index])
            for index in merge_into_adjacent:
                proto_lines[proto_line_i + direction].extend(runs[index])
            # If no, mergers to the adjacent, create phantom line for the adjacent
            if not merge_into_adjacent:
                average_x, average_y = np.mean(proto_line_XY, axis=0)
                adjacent_y = average_y + line_height * direction
                phantom_proto_lines[proto_line_i + direction] = np.array([[average_x, adjacent_y]])
            # Remove all runs that were merged on this iteration
            # (reverse order so earlier indices stay valid during deletion).
            for index in sorted(merge_into_current + merge_into_adjacent, reverse=True):
                del runs[index]
                merger_on_this_iteration = True
        # If no mergers were made, break the while loop
        if not merger_on_this_iteration:
            break
    # 4. Assign any leftover runs to the closest proto lines
    for run in runs:
        best_pl_distance = np.inf
        best_pl_assignemnt = None
        for proto_line_i in proto_lines:
            if proto_lines[proto_line_i]:
                proto_line_XY = fixation_XY[proto_lines[proto_line_i]]
            else:
                proto_line_XY = phantom_proto_lines[proto_line_i]
            y_diffs = [y - proto_line_XY[np.argmin(abs(proto_line_XY[:, 0] - x)), 1] for x, y in fixation_XY[run]]
            pl_distance = abs(np.mean(y_diffs))
            if pl_distance < best_pl_distance:
                best_pl_distance = pl_distance
                best_pl_assignemnt = proto_line_i
        proto_lines[best_pl_assignemnt].extend(run)
    # 5. Prune proto lines until there are exactly as many as text lines,
    # folding the smaller of the outermost proto lines into its neighbor.
    while len(proto_lines) > len(line_Y):
        top, bot = min(proto_lines), max(proto_lines)
        if len(proto_lines[top]) < len(proto_lines[bot]):
            proto_lines[top + 1].extend(proto_lines[top])
            del proto_lines[top]
        else:
            proto_lines[bot - 1].extend(proto_lines[bot])
            del proto_lines[bot]
    # 6. Map proto lines to text lines (top-to-bottom positional order)
    for line_i, proto_line_i in enumerate(sorted(proto_lines)):
        fixation_XY[proto_lines[proto_line_i], 1] = line_Y[line_i]
    return fixation_XY[:, 1]
|
158 |
+
|
159 |
+
|
160 |
+
def attach(fixation_XY, line_Y):
    """Snap every fixation to the vertically nearest text line.

    The y-coordinate of each row of ``fixation_XY`` is replaced in place by
    the closest value in ``line_Y``; the corrected y-column is returned.
    """
    for idx in range(len(fixation_XY)):
        nearest = np.argmin(abs(line_Y - fixation_XY[idx, 1]))
        fixation_XY[idx, 1] = line_Y[nearest]
    return fixation_XY[:, 1]
|
166 |
+
|
167 |
+
|
168 |
+
def chain(fixation_XY, midlines, x_thresh=192, y_thresh=32):
    """
    Adapted from Eyekit(https://github.com/jwcarr/eyekit/blob/350d055eecaa1581b03db5a847424825ffbb10f6/eyekit/_snap.py)
    implementation

    Group consecutive fixations into chains wherever neither the x- nor the
    y-jump between neighbors exceeds its threshold, then snap every chain to
    the text line closest to the chain's mean y. Default params:
    `x_thresh=192`, `y_thresh=32`. Requires NumPy. Original method
    implemented in [popEye](https://github.com/sascha2schroeder/popEye/).
    """
    try:
        import numpy as np
    except ModuleNotFoundError as e:
        e.msg = "The chain method requires NumPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    # A chain breaks wherever either jump exceeds its threshold.
    big_jump_x = abs(np.diff(fixation_XY[:, 0])) > x_thresh
    big_jump_y = abs(np.diff(fixation_XY[:, 1])) > y_thresh
    breaks = list(np.where(np.logical_or(big_jump_x, big_jump_y))[0] + 1)
    bounds = [0] + breaks + [len(fixation_XY)]
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        nearest = np.argmin(abs(line_Y - np.mean(fixation_XY[lo:hi, 1])))
        fixation_XY[lo:hi, 1] = line_Y[nearest]
    return fixation_XY[:, 1]
|
195 |
+
|
196 |
+
|
197 |
+
def cluster(fixation_XY, line_Y):
    """Cluster fixation y-values into ``len(line_Y)`` groups with KMeans and
    map each group, ordered by mean vertical position, onto the text lines.

    Mutates the y-column of ``fixation_XY`` in place and returns it.
    """
    n_lines = len(line_Y)
    y_column = fixation_XY[:, 1].reshape(-1, 1)
    labels = KMeans(n_lines, n_init=100, max_iter=300).fit_predict(y_column)
    # Rank clusters top-to-bottom so label order matches line order.
    mean_ys = [y_column[labels == lbl].mean() for lbl in range(n_lines)]
    rank_order = np.argsort(mean_ys)
    for row, lbl in enumerate(labels):
        fixation_XY[row, 1] = line_Y[np.where(rank_order == lbl)[0][0]]
    return fixation_XY[:, 1]
|
207 |
+
|
208 |
+
|
209 |
+
def compare(fixation_XY, word_XY, x_thresh=512, n_nearest_lines=3):
    """COMPARE algorithm.

    Split the fixation sequence at large leftward x-jumps (return sweeps),
    then assign each resulting gaze line to whichever of its
    ``n_nearest_lines`` candidate text lines yields the lowest
    dynamic-time-warping cost against that line's word x-positions.

    Lima Sanches, C., Kise, K., & Augereau, O. (2015). Eye gaze and text
    line matching for reading analysis. In Adjunct proceedings of the
    2015 ACM International Joint Conference on Pervasive and
    Ubiquitous Computing and proceedings of the 2015 ACM International
    Symposium on Wearable Computers (pp. 1227-1233). Association for
    Computing Machinery. https://doi.org/10.1145/2800835.2807936
    """
    line_Y = np.unique(word_XY[:, 1])
    sweeps = list(np.where(np.diff(fixation_XY[:, 0]) < -x_thresh)[0] + 1)
    bounds = [0] + sweeps + [len(fixation_XY)]
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        gaze_line = fixation_XY[lo:hi]
        by_proximity = np.argsort(abs(line_Y - np.mean(gaze_line[:, 1])))
        candidates = by_proximity[:n_nearest_lines]
        costs = np.zeros(n_nearest_lines)
        for slot in range(n_nearest_lines):
            text_line = word_XY[word_XY[:, 1] == line_Y[candidates[slot]]]
            costs[slot], _ = dynamic_time_warping(gaze_line[:, 0:1], text_line[:, 0:1])
        fixation_XY[lo:hi, 1] = line_Y[candidates[np.argmin(costs)]]
    return fixation_XY[:, 1]
|
241 |
+
|
242 |
+
|
243 |
+
def merge(fixation_XY, midlines, text_right_to_left=False, y_thresh=32, gradient_thresh=0.1, error_thresh=20):
    """
    Form a set of progressive sequences and then reduce the set to *m* by
    repeatedly merging those that appear to be on the same line. Merged
    sequences are then assigned to text lines in positional order. Default
    params: `y_thresh=32`, `gradient_thresh=0.1`, `error_thresh=20`. Requires
    NumPy. Original method by [Špakov et al. (2019)](https://doi.org/10.3758/s13428-018-1120-x).

    Args:
        fixation_XY: sequence of (x, y) fixation coordinates.
        midlines: y-coordinates of the text-line midlines.
        text_right_to_left: if True, a sequence breaks on rightward
            (positive) x-steps instead of leftward ones.
        y_thresh: y-jump size that always breaks a sequence.
        gradient_thresh: max |slope| of the fitted line for a merger to be
            allowed in the constrained phases.
        error_thresh: max RMS residual of the fitted line for a merger to be
            allowed in the constrained phases.

    Returns:
        The corrected y-coordinate (one of ``midlines``) for every fixation.
    """
    try:
        import numpy as np
    except ModuleNotFoundError as e:
        e.msg = "The merge method requires NumPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    diff_X = np.diff(fixation_XY[:, 0])
    dist_Y = abs(np.diff(fixation_XY[:, 1]))
    # A sequence breaks on a regression in reading direction or a large y-jump.
    if text_right_to_left:
        sequence_boundaries = list(np.where(np.logical_or(diff_X > 0, dist_Y > y_thresh))[0] + 1)
    else:
        sequence_boundaries = list(np.where(np.logical_or(diff_X < 0, dist_Y > y_thresh))[0] + 1)
    sequence_starts = [0] + sequence_boundaries
    sequence_ends = sequence_boundaries + [len(fixation_XY)]
    sequences = [list(range(start, end)) for start, end in zip(sequence_starts, sequence_ends)]
    # Four phases with progressively looser constraints: first merge only
    # long sequences, then shorter ones, finally drop the fit constraints
    # entirely so the count is guaranteed to reach len(line_Y).
    for min_i, min_j, remove_constraints in [
        (3, 3, False),  # Phase 1
        (1, 3, False),  # Phase 2
        (1, 1, False),  # Phase 3
        (1, 1, True),  # Phase 4
    ]:
        while len(sequences) > len(line_Y):
            best_merger = None
            best_error = np.inf
            for i in range(len(sequences) - 1):
                if len(sequences[i]) < min_i:
                    continue  # first sequence too short, skip to next i
                for j in range(i + 1, len(sequences)):
                    if len(sequences[j]) < min_j:
                        continue  # second sequence too short, skip to next j
                    # RMS residual of a straight-line fit through both
                    # sequences measures how line-like the merged result is.
                    candidate_XY = fixation_XY[sequences[i] + sequences[j]]
                    gradient, intercept = np.polyfit(candidate_XY[:, 0], candidate_XY[:, 1], 1)
                    residuals = candidate_XY[:, 1] - (gradient * candidate_XY[:, 0] + intercept)
                    error = np.sqrt(sum(residuals**2) / len(candidate_XY))
                    if remove_constraints or (abs(gradient) < gradient_thresh and error < error_thresh):
                        if error < best_error:
                            best_merger = (i, j)
                            best_error = error
            if best_merger is None:
                break  # no possible mergers, break while and move to next phase
            merge_i, merge_j = best_merger
            merged_sequence = sequences[merge_i] + sequences[merge_j]
            sequences.append(merged_sequence)
            # Delete j before i (j > i) so the first index stays valid.
            del sequences[merge_j], sequences[merge_i]
    # Assign merged sequences to text lines in top-to-bottom order.
    mean_Y = [fixation_XY[sequence, 1].mean() for sequence in sequences]
    ordered_sequence_indices = np.argsort(mean_Y)
    for line_i, sequence_i in enumerate(ordered_sequence_indices):
        fixation_XY[sequences[sequence_i], 1] = line_Y[line_i]
    return fixation_XY[:, 1]
|
301 |
+
|
302 |
+
|
303 |
+
def regress(
    fixation_XY,
    midlines,
    slope_bounds=(-0.1, 0.1),
    offset_bounds=(-50, 50),
    std_bounds=(1, 20),
):
    """
    Find *m* regression lines that best fit the fixations and group fixations
    according to best fit regression lines, and then assign groups to text
    lines in positional order. Default params: `slope_bounds=(-0.1, 0.1)`,
    `offset_bounds=(-50, 50)`, `std_bounds=(1, 20)`. Requires SciPy.
    Original method by [Cohen (2013)](https://doi.org/10.3758/s13428-012-0280-3).

    Args:
        fixation_XY: sequence of (x, y) fixation coordinates.
        midlines: y-coordinates of the text-line midlines.
        slope_bounds: allowed range of the shared regression-line slope.
        offset_bounds: allowed range of the shared vertical offset.
        std_bounds: allowed range of the Gaussian noise std.

    Returns:
        The corrected y-coordinate (one of ``midlines``) for every fixation.
    """
    try:
        import numpy as np
        from scipy.optimize import minimize
        from scipy.stats import norm
    except ModuleNotFoundError as e:
        e.msg = "The regress method requires SciPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    # density[f, l] = log-likelihood of fixation f under line l's regression
    # line. NOTE: mutated inside fit_lines and read again after optimization.
    density = np.zeros((len(fixation_XY), len(line_Y)))

    def fit_lines(params):
        # The unconstrained optimizer parameters are squashed through a
        # normal CDF into their bounded ranges.
        k = slope_bounds[0] + (slope_bounds[1] - slope_bounds[0]) * norm.cdf(params[0])
        o = offset_bounds[0] + (offset_bounds[1] - offset_bounds[0]) * norm.cdf(params[1])
        s = std_bounds[0] + (std_bounds[1] - std_bounds[0]) * norm.cdf(params[2])
        predicted_Y_from_slope = fixation_XY[:, 0] * k
        line_Y_plus_offset = line_Y + o
        for line_i in range(len(line_Y)):
            fit_Y = predicted_Y_from_slope + line_Y_plus_offset[line_i]
            density[:, line_i] = norm.logpdf(fixation_XY[:, 1], fit_Y, s)
        # Negative sum of each fixation's best line log-likelihood.
        return -sum(density.max(axis=1))

    best_fit = minimize(fit_lines, [0, 0, 0], method="powell")
    # Re-evaluate at the optimum so `density` holds the final assignments.
    fit_lines(best_fit.x)
    return line_Y[density.argmax(axis=1)]
|
342 |
+
|
343 |
+
|
344 |
+
def segment(fixation_XY, midlines, text_right_to_left=False):
    """
    Segment the fixation sequence into *m* subsequences at the *m*-1
    most-likely return sweeps (the largest backward x-saccades), assigning
    subsequences to text lines in chronological order. Requires NumPy.
    Original method by
    [Abdulin & Komogortsev (2015)](https://doi.org/10.1109/BTAS.2015.7358786).
    """
    try:
        import numpy as np
    except ModuleNotFoundError as e:
        e.msg = "The segment method requires NumPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    by_length = np.argsort(np.diff(fixation_XY[:, 0]))
    # Return sweeps are the most negative x-steps for LTR text, the most
    # positive for RTL text.
    if text_right_to_left:
        sweep_indices = by_length[-(len(line_Y) - 1):]
    else:
        sweep_indices = by_length[: len(line_Y) - 1]
    current_line = 0
    for row in range(len(fixation_XY)):
        fixation_XY[row, 1] = line_Y[current_line]
        if row in sweep_indices:
            current_line += 1
    return fixation_XY[:, 1]
|
370 |
+
|
371 |
+
|
372 |
+
def split(fixation_XY, midlines, text_right_to_left=False):
    """
    Split fixation sequence into subsequences based on best candidate return
    sweeps, and then assign subsequences to closest text lines. Requires
    SciPy. Original method by [Carr et al. (2022)](https://doi.org/10.3758/s13428-021-01554-0).

    Args:
        fixation_XY: sequence of (x, y) fixation coordinates.
        midlines: y-coordinates of the text-line midlines.
        text_right_to_left: if True, treat rightward (positive) x-jumps as
            return sweeps instead of leftward ones.

    Returns:
        np.ndarray: corrected y-coordinate for every fixation.
    """
    try:
        import numpy as np
        from scipy.cluster.vq import kmeans2
    except ModuleNotFoundError as e:
        e.msg = "The split method requires SciPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    diff_X = np.array(np.diff(fixation_XY[:, 0]), dtype=float).reshape(-1, 1)
    # Two clusters of x-displacements: within-line saccades vs return sweeps.
    centers, clusters = kmeans2(diff_X, 2, iter=100, minit="++", missing="raise")
    if text_right_to_left:
        sweep_marker = np.argmax(centers)
    else:
        sweep_marker = np.argmin(centers)
    end_line_indices = list(np.where(clusters == sweep_marker)[0] + 1)
    end_line_indices.append(len(fixation_XY))
    start_of_line = 0
    for end_of_line in end_line_indices:
        mean_y = np.mean(fixation_XY[start_of_line:end_of_line, 1])
        line_i = np.argmin(abs(line_Y - mean_y))
        # Bug fix: assign only to the y-column. The previous code wrote
        # `fixation_XY[start_of_line:end_of_line] = ...`, which broadcast the
        # line value over BOTH columns and clobbered the x-coordinates.
        fixation_XY[start_of_line:end_of_line, 1] = line_Y[line_i]
        start_of_line = end_of_line
    return fixation_XY[:, 1]
|
401 |
+
|
402 |
+
|
403 |
+
def stretch(fixation_XY, midlines, stretch_bounds=(0.9, 1.1), offset_bounds=(-50, 50)):
    """
    Search for a vertical stretch factor and offset that best align the
    fixations with the text lines, then snap each transformed fixation to its
    nearest line. Default params: `stretch_bounds=(0.9, 1.1)`,
    `offset_bounds=(-50, 50)`. Requires SciPy.
    Original method by [Lohmeier (2015)](http://www.monochromata.de/master_thesis/ma1.3.pdf).
    """
    try:
        import numpy as np
        from scipy.optimize import minimize
    except ModuleNotFoundError as e:
        e.msg = "The stretch method requires SciPy."
        raise
    fixation_Y = np.array(fixation_XY)[:, 1]
    line_Y = np.array(midlines)
    corrected_Y = np.zeros(len(fixation_Y))

    def alignment_cost(params):
        # Snap each stretched-and-shifted y to its nearest line; the cost is
        # the total snapping distance. corrected_Y is filled as a side effect.
        candidate_Y = fixation_Y * params[0] + params[1]
        for idx in range(len(candidate_Y)):
            corrected_Y[idx] = line_Y[np.argmin(abs(line_Y - candidate_Y[idx]))]
        return sum(abs(candidate_Y - corrected_Y))

    best_fit = minimize(alignment_cost, [1, 0], method="powell", bounds=[stretch_bounds, offset_bounds])
    # Re-evaluate at the optimum so corrected_Y holds the final snapping.
    alignment_cost(best_fit.x)
    return corrected_Y
|
432 |
+
|
433 |
+
|
434 |
+
def warp(fixation_XY, word_center_list):
    """
    Map fixations to word centers using [Dynamic Time
    Warping](https://en.wikipedia.org/wiki/Dynamic_time_warping). This finds a
    monotonically increasing mapping between fixations and words with the
    shortest overall distance, effectively resulting in *m* subsequences.
    Fixations are then assigned to the lines that their mapped words belong
    to, effectively assigning subsequences to text lines in chronological
    order. Requires NumPy.
    Original method by [Carr et al. (2022)](https://doi.org/10.3758/s13428-021-01554-0).

    Args:
        fixation_XY: sequence of (x, y) fixation coordinates.
        word_center_list: sequence of (x, y) word-center coordinates.

    Returns:
        The corrected y-coordinate for every fixation (the most common
        y-value among the words DTW maps that fixation to).
    """
    try:
        import numpy as np
    except ModuleNotFoundError as e:
        e.msg = "The warp method requires NumPy."
        raise
    fixation_XY = np.array(fixation_XY)
    word_XY = np.array(word_center_list)
    n1 = len(fixation_XY)
    n2 = len(word_XY)
    # Cumulative DTW cost with an extra padding row/column of infinities so
    # the first real cell has valid "previous" neighbors.
    cost = np.zeros((n1 + 1, n2 + 1))
    cost[0, :] = np.inf
    cost[:, 0] = np.inf
    cost[0, 0] = 0
    for fixation_i in range(n1):
        for word_i in range(n2):
            distance = np.sqrt(sum((fixation_XY[fixation_i] - word_XY[word_i]) ** 2))
            cost[fixation_i + 1, word_i + 1] = distance + min(
                cost[fixation_i, word_i + 1],
                cost[fixation_i + 1, word_i],
                cost[fixation_i, word_i],
            )
    cost = cost[1:, 1:]
    warping_path = [[] for _ in range(n1)]
    # Bug fix: start the backtrace explicitly at the bottom-right cell. The
    # previous code relied on the leftover loop variables, which are
    # undefined when either input is empty.
    fixation_i, word_i = n1 - 1, n2 - 1
    while fixation_i > 0 or word_i > 0:
        warping_path[fixation_i].append(word_i)
        possible_moves = [np.inf, np.inf, np.inf]
        if fixation_i > 0 and word_i > 0:
            possible_moves[0] = cost[fixation_i - 1, word_i - 1]
        if fixation_i > 0:
            possible_moves[1] = cost[fixation_i - 1, word_i]
        if word_i > 0:
            possible_moves[2] = cost[fixation_i, word_i - 1]
        best_move = np.argmin(possible_moves)
        if best_move == 0:
            fixation_i -= 1
            word_i -= 1
        elif best_move == 1:
            fixation_i -= 1
        else:
            word_i -= 1
    warping_path[0].append(0)
    # Each fixation takes the modal y among its mapped words' centers.
    for fixation_i, words_mapped_to_fixation_i in enumerate(warping_path):
        candidate_Y = list(word_XY[words_mapped_to_fixation_i, 1])
        fixation_XY[fixation_i, 1] = max(set(candidate_Y), key=candidate_Y.count)
    return fixation_XY[:, 1]
|
490 |
+
|
491 |
+
|
492 |
+
def dynamic_time_warping(sequence1, sequence2):
    """Align two sequences with dynamic time warping.

    Args:
        sequence1, sequence2: arrays of vectors (e.g. shape (n, 1) slices of
            x-coordinates) to be aligned with Euclidean point distances.

    Returns:
        tuple: ``(total_cost, path)`` where ``path[i]`` lists the indices of
        ``sequence2`` matched to element ``i`` of ``sequence1``.
    """
    n1 = len(sequence1)
    n2 = len(sequence2)
    # Cumulative cost matrix with an infinity border so the first real cell
    # has valid "previous" neighbors.
    dtw_cost = np.zeros((n1 + 1, n2 + 1))
    dtw_cost[0, :] = np.inf
    dtw_cost[:, 0] = np.inf
    dtw_cost[0, 0] = 0
    for i in range(n1):
        for j in range(n2):
            this_cost = np.sqrt(sum((sequence1[i] - sequence2[j]) ** 2))
            dtw_cost[i + 1, j + 1] = this_cost + min(dtw_cost[i, j + 1], dtw_cost[i + 1, j], dtw_cost[i, j])
    dtw_cost = dtw_cost[1:, 1:]
    dtw_path = [[] for _ in range(n1)]
    # Bug fix: start the backtrace explicitly at the bottom-right cell rather
    # than relying on the leftover loop variables (undefined if either
    # sequence is empty).
    i, j = n1 - 1, n2 - 1
    while i > 0 or j > 0:
        dtw_path[i].append(j)
        possible_moves = [np.inf, np.inf, np.inf]
        if i > 0 and j > 0:
            possible_moves[0] = dtw_cost[i - 1, j - 1]
        if i > 0:
            possible_moves[1] = dtw_cost[i - 1, j]
        if j > 0:
            possible_moves[2] = dtw_cost[i, j - 1]
        best_move = np.argmin(possible_moves)
        if best_move == 0:
            i -= 1
            j -= 1
        elif best_move == 1:
            i -= 1
        else:
            j -= 1
    dtw_path[0].append(0)
    return dtw_cost[-1, -1], dtw_path
|
524 |
+
|
525 |
+
|
526 |
+
def wisdom_of_the_crowd(assignments):
    """
    For each fixation, choose the y-value with the most votes across multiple
    algorithms. In the event of a tie, the left-most algorithm is given
    priority.
    """
    votes_per_fixation = np.column_stack(assignments)
    correction = []
    for votes in votes_per_fixation:
        tally = {}
        for y in votes:
            tally[y] = tally.get(y, 0) + 1
        top_count = max(tally.values())
        # First value (left-most column) among those with the top count.
        winner = next(y for y in votes if tally[y] == top_count)
        correction.append(winner)
    return correction
|
eyekit_measures.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import eyekit as ek
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
# Template mapping eyekit reading-measure names to empty result lists.
# Deep-copied per call in get_eyekit_measures so this module-level template
# itself is never mutated; each key must match a callable in ek.measure.
MEASURES_DICT = {
    "number_of_fixations": [],
    "initial_fixation_duration": [],
    "first_of_many_duration": [],
    "total_fixation_duration": [],
    "gaze_duration": [],
    "go_past_duration": [],
    "second_pass_duration": [],
    "initial_landing_position": [],
    "initial_landing_distance": [],
    "landing_distances": [],
    "number_of_regressions_in": [],
}
|
21 |
+
|
22 |
+
|
23 |
+
def get_fix_seq_and_text_block(
    dffix,
    trial,
    x_txt_start=None,
    y_txt_start=None,
    font_face="Courier New",
    font_size=None,
    line_height=None,
    use_corrected_fixations=True,
    correction_algo="warp",
):
    """Build an eyekit FixationSequence and TextBlock for one trial.

    Args:
        dffix: fixation dataframe; rows must provide "x", "y",
            "corrected_start_time", "corrected_end_time" and, when
            ``use_corrected_fixations`` is set, a "y_<correction_algo>"
            column.
        trial: dict with trial metadata ("chars_list", "line_heights",
            "trial_id", optionally "display_coords", "font", "font_size").
        x_txt_start, y_txt_start: text-origin override; defaults to the
            first character's bounding box.
        font_face, font_size, line_height: TextBlock typography overrides.
        use_corrected_fixations: use algorithm-corrected y-values.
        correction_algo: name suffix of the corrected-y column.

    Returns:
        (fixation_sequence, textblock, screen_size) on success; the
        unchanged ``dffix`` if FixationSequence construction fails
        (NOTE(review): inconsistent return type — callers must check).

    Side effects: writes two JSON files under ``results/``.
    """
    if use_corrected_fixations and correction_algo is not None:
        # Eyekit requires start < end; degenerate fixations fall back to the
        # raw y and get their end time bumped by 1 ms.
        fixations_tuples = [
            (
                (x[1]["x"], x[1][f"y_{correction_algo}"], x[1]["corrected_start_time"], x[1]["corrected_end_time"])
                if x[1]["corrected_start_time"] < x[1]["corrected_end_time"]
                else (x[1]["x"], x[1]["y"], x[1]["corrected_start_time"], x[1]["corrected_end_time"] + 1)
            )
            for x in dffix.iterrows()
        ]
    else:
        fixations_tuples = [
            (
                (x[1]["x"], x[1]["y"], x[1]["corrected_start_time"], x[1]["corrected_end_time"])
                if x[1]["corrected_start_time"] < x[1]["corrected_end_time"]
                else (x[1]["x"], x[1]["y"], x[1]["corrected_start_time"], x[1]["corrected_end_time"] + 1)
            )
            for x in dffix.iterrows()
        ]

    try:
        fixation_sequence = ek.FixationSequence(fixations_tuples)
    except Exception as e:
        print(e)
        print(f"Creating fixation failed for {trial['trial_id']} {trial['filename']}")
        return dffix

    if "display_coords" in trial:
        display_coords = trial["display_coords"]
    else:
        display_coords = (0, 0, 1920, 1080)
    screen_size = ((display_coords[2] - display_coords[0]), (display_coords[3] - display_coords[1]))

    # Use the smallest observed line height when lines are unevenly spaced.
    y_diffs = np.unique(trial["line_heights"])
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    else:
        y_diff = np.min(y_diffs)
    chars_list = trial["chars_list"]
    max_line = int(chars_list[-1]["assigned_line"])
    # Reassemble each text line from its per-character records.
    words_on_lines = {x: [] for x in range(int(max_line) + 1)}
    [words_on_lines[x["assigned_line"]].append(x["char"]) for x in chars_list]
    sentence_list = ["".join([s for s in v]) for idx, v in words_on_lines.items()]

    if x_txt_start is None:
        x_txt_start = float(chars_list[0]["char_xmin"])
    if y_txt_start is None:
        y_txt_start = float(chars_list[0]["char_ymax"])

    if font_face is None and "font" in trial:
        font_face = trial["font"]
    elif font_face is None:
        font_face = "DejaVu Sans Mono"

    if font_size is None and "font_size" in trial:
        font_size = trial["font_size"]
    elif font_size is None:
        font_size = float(y_diff * 0.333)  # pixel to point conversion
    if line_height is None:
        line_height = float(y_diff)
    textblock = ek.TextBlock(
        sentence_list,
        position=(float(x_txt_start), float(y_txt_start)),
        font_face=font_face,
        line_height=line_height,
        font_size=font_size,
        anchor="left",
        align="left",
    )

    # eyekit_plot(textblock, fixation_sequence, screen_size)
    ek.io.save(fixation_sequence, f'results/fixation_sequence_eyekit_{trial["trial_id"]}.json', compress=False)
    ek.io.save(textblock, f'results/textblock_eyekit_{trial["trial_id"]}.json', compress=False)

    return fixation_sequence, textblock, screen_size
|
108 |
+
|
109 |
+
|
110 |
+
def eyekit_plot(textblock, fixation_sequence, screen_size):
    """Render the text block, word boxes, and fixation sequence.

    Writes the rendering to ``temp_eyekit_img.png`` (cropped) and returns it
    re-opened as a PIL image.
    """
    canvas = ek.vis.Image(*screen_size)
    canvas.draw_text_block(textblock)
    for word_box in textblock.words():
        canvas.draw_rectangle(word_box, color="hotpink")
    canvas.draw_fixation_sequence(fixation_sequence)
    canvas.save("temp_eyekit_img.png", crop_margin=200)
    return Image.open("temp_eyekit_img.png")
|
119 |
+
|
120 |
+
|
121 |
+
def plot_with_measure(textblock, fixation_sequence, screen_size, measure, use_characters=False):
    """Render the text block annotated with a per-word (or per-character)
    eyekit reading measure.

    Writes ``multiline_passage_piccol.png`` and returns it as a PIL image.
    NOTE: eyekit_plot is invoked first purely for its side effect of writing
    ``temp_eyekit_img.png``; its return value is discarded.
    """
    eyekit_plot(textblock, fixation_sequence, screen_size)
    canvas = ek.vis.Image(*screen_size)
    canvas.draw_text_block(textblock)
    # The interest areas are requested twice on purpose: once for the
    # measure computation and once for drawing.
    if use_characters:
        measure_results = getattr(ek.measure, measure)(textblock.characters(), fixation_sequence)
        interest_areas = textblock.characters()
    else:
        measure_results = getattr(ek.measure, measure)(textblock.words(), fixation_sequence)
        interest_areas = textblock.words()
    for area in interest_areas:
        canvas.draw_rectangle(area, color="lightseagreen")
        label = f"{measure_results[area.id]}"
        canvas.draw_annotation(
            (area.onset, area.y_br - 3), label, color="lightseagreen", font_face="Arial bold", font_size=15
        )
    canvas.draw_fixation_sequence(fixation_sequence, color="gray")
    canvas.save("multiline_passage_piccol.png", crop_margin=100)
    return Image.open("multiline_passage_piccol.png")
|
142 |
+
|
143 |
+
|
144 |
+
def get_eyekit_measures(_txt, _seq, get_char_measures=False):
    """Compute every measure in MEASURES_DICT per word (and optionally per
    character) of a TextBlock against a FixationSequence.

    Returns:
        (word_measures_df, character_measures_df); the latter is None unless
        ``get_char_measures`` is True. The identifier columns come first.
    """
    word_measures = copy.deepcopy(MEASURES_DICT)
    words = []
    for w in _txt.words():
        words.append(w.text)
        for measure_name in word_measures:
            word_measures[measure_name].append(getattr(ek.measure, measure_name)(w, _seq))
    word_measures_df = pd.DataFrame(word_measures)
    word_measures_df.insert(0, "word", words)
    word_measures_df.insert(0, "word_number", np.arange(0, len(words)))

    character_measures_df = None
    if get_char_measures:
        char_measures = copy.deepcopy(MEASURES_DICT)
        characters = []
        for c in _txt.characters():
            characters.append(c.text)
            for measure_name in char_measures:
                char_measures[measure_name].append(getattr(ek.measure, measure_name)(c, _seq))
        character_measures_df = pd.DataFrame(char_measures)
        character_measures_df.insert(0, "character", characters)
        character_measures_df.insert(0, "char_number", np.arange(0, len(characters)))
    return word_measures_df, character_measures_df
|
loss_functions.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch as t
|
2 |
+
|
3 |
+
|
4 |
+
def macro_soft_f1(real_vals, predictions, reduction="mean"):
    """Macro soft-F1 loss: ``1 - soft_F1``, averaged over classes by default.

    A differentiable surrogate of the F1 score computed directly from soft
    (probability-valued) predictions, so it can be optimised with gradient
    descent. Reference:
    https://towardsdatascience.com/the-unknown-benefits-of-using-a-soft-f1-loss-in-classification-systems-753902c0105d

    Parameters
    ----------
    real_vals : torch.Tensor, shape (num_examples, num_classes)
        Ground-truth labels in {0, 1} (soft targets also work).
    predictions : torch.Tensor, shape (num_examples, num_classes)
        Predicted probabilities in [0, 1].
    reduction : str, default "mean"
        "mean" returns the scalar mean over classes; any other value returns
        the per-class loss vector. (Default added for convenience; passing it
        explicitly, as existing callers do, is unchanged.)

    Returns
    -------
    torch.Tensor
        Scalar loss if ``reduction == "mean"``, else shape (num_classes,).
    """
    # Soft counts are accumulated per class (dim=0 sums over examples).
    true_positive = (real_vals * predictions).sum(dim=0)
    false_positive = (predictions * (1 - real_vals)).sum(dim=0)
    false_negative = ((1 - predictions) * real_vals).sum(dim=0)
    # 1e-16 guards against division by zero for classes absent from the batch.
    soft_f1 = 2 * true_positive / (2 * true_positive + false_negative + false_positive + 1e-16)
    if reduction == "mean":
        return t.mean(1 - soft_f1)
    return 1 - soft_f1
|
15 |
+
|
16 |
+
|
17 |
+
def coral_loss(logits, levels, importance_weights=None, reduction="mean"):
    """CORAL ordinal-regression loss (Cao, Mirjalili & Raschka, 2020).

    Treats K-class ordinal prediction as K-1 binary "rank exceeds threshold"
    tasks and sums their binary cross-entropies per example.
    See *Rank Consistent Ordinal Regression for Neural Networks with
    Application to Age Estimation*, Pattern Recognition Letters,
    https://doi.org/10.1016/j.patrec.2020.11.008 and the reference code at
    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/losses.py

    Parameters
    ----------
    logits : torch.Tensor, shape (num_examples, num_classes-1)
        Outputs of the CORAL layer.
    levels : torch.Tensor, shape (num_examples, num_classes-1)
        True labels as extended binary (cumulative) vectors.
    importance_weights : torch.Tensor, shape (num_classes-1,), optional
        Per-threshold weights; ``None`` is equivalent to uniform weights.
    reduction : str or None, default "mean"
        "mean"/"sum" reduce over examples; ``None`` returns the per-example
        loss vector of shape (num_examples,).

    Returns
    -------
    torch.Tensor
        Scalar for "mean"/"sum", else shape (num_examples,).

    Raises
    ------
    ValueError
        If shapes disagree or ``reduction`` is not a recognised value.
    """
    if logits.shape != levels.shape:
        raise ValueError(
            "Please ensure that logits (%s) has the same shape as levels (%s). " % (logits.shape, levels.shape)
        )

    # Numerically stable BCE written via log-sigmoid:
    # log σ(z) for positive thresholds, log(1-σ(z)) = log σ(z) - z otherwise.
    log_sig = t.nn.functional.logsigmoid(logits)
    per_threshold = log_sig * levels + (log_sig - logits) * (1 - levels)

    if importance_weights is not None:
        per_threshold = per_threshold * importance_weights

    per_example = -t.sum(per_threshold, dim=1)

    if reduction == "mean":
        return t.mean(per_example)
    if reduction == "sum":
        return t.sum(per_example)
    if reduction is None:
        return per_example
    raise ValueError('Invalid value for `reduction`. Should be "mean", ' '"sum", or None. Got %s' % reduction)
|
84 |
+
|
85 |
+
|
86 |
+
def corn_loss(logits, y_train, num_classes):
    """CORN ordinal-regression loss (conditional-probability formulation).

    From 'Deep Neural Networks for Rank Consistent Ordinal Regression based
    on Conditional Probabilities'. Each of the ``num_classes - 1`` threshold
    tasks is trained only on the examples whose label reaches that threshold,
    making each binary task a *conditional* probability. Reference code:
    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/losses.py

    Parameters
    ----------
    logits : torch.Tensor, shape (num_examples, num_classes-1)
        Outputs of the CORN layer.
    y_train : torch.Tensor, shape (num_examples,)
        Integer class labels starting at 0.
    num_classes : int
        Number of unique class labels.

    Returns
    -------
    torch.Tensor
        A single scalar: total BCE over all conditional tasks divided by the
        total number of per-task training examples.
    """
    total_loss = 0.0
    total_count = 0
    for task_idx in range(num_classes - 1):
        # Condition on labels that reach this threshold (y > task_idx - 1).
        mask = y_train > task_idx - 1
        targets = (y_train[mask] > task_idx).to(t.int64)
        if len(targets) < 1:
            # No conditioning examples left for this (and any later) task.
            continue

        total_count += len(targets)
        task_logits = logits[mask, task_idx]

        # Stable BCE via log-sigmoid: log σ(z)·y + (log σ(z) - z)·(1 - y).
        log_sig = t.nn.functional.logsigmoid(task_logits)
        total_loss += -t.sum(log_sig * targets + (log_sig - task_logits) * (1 - targets))

    return total_loss / total_count
|
151 |
+
|
152 |
+
|
153 |
+
def corn_label_from_logits(logits):
    """Decode predicted rank labels from CORN logits.

    The CORN head emits one logit per rank threshold; the sigmoid of each is
    a *conditional* probability, so the cumulative product along the
    threshold axis yields unconditional P(rank > k). The predicted label is
    the number of thresholds whose cumulative probability exceeds 0.5.
    Reference:
    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/dataset.py

    Parameters
    ----------
    logits : torch.Tensor, shape (n_examples, n_classes-1)
        Raw logits from a CORN-trained network.

    Returns
    -------
    torch.Tensor, shape (n_examples,)
        Integer rank (class) labels.
    """
    cumulative_probas = t.cumprod(t.sigmoid(logits), dim=1)
    return (cumulative_probas > 0.5).sum(dim=1)
|
models.py
ADDED
@@ -0,0 +1,897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm
|
2 |
+
import os
|
3 |
+
from typing import Any
|
4 |
+
from pytorch_lightning.utilities.types import LRSchedulerTypeUnion
|
5 |
+
import torch as t
|
6 |
+
from torch import nn
|
7 |
+
import numpy as np
|
8 |
+
import transformers
|
9 |
+
import pytorch_lightning as plight
|
10 |
+
import torchmetrics
|
11 |
+
import einops as eo
|
12 |
+
from loss_functions import coral_loss, corn_loss, corn_label_from_logits, macro_soft_f1
|
13 |
+
|
14 |
+
# Allow TF32-based matmul kernels on supporting GPUs: trades a little float32
# precision for substantially faster matrix multiplies.
t.set_float32_matmul_precision("medium")
# Module-wide feature switches. ``try_using_torch_compile`` gates every
# t.compile(...) call site below (each site additionally requires a POSIX OS).
global_settings = dict(try_using_torch_compile=False)
|
16 |
+
|
17 |
+
|
18 |
+
class EnsembleModel(plight.LightningModule):
    """Ensemble over trained member models, half trained on unnormalised and
    half on normalised inputs.

    Member outputs are combined either by a learned linear layer over their
    concatenated class scores (default) or by a plain average of the stacked
    scores (``use_simple_average=True``). Loss and prediction decoding are
    delegated to the first normalised member's ``_get_loss`` /
    ``_get_preds_reals`` methods.
    """

    def __init__(self, models_without_norm_df, models_with_norm_df, learning_rate=0.0002, use_simple_average=False):
        super().__init__()
        # ModuleList registration makes the members' parameters trainable here.
        self.models_without_norm = nn.ModuleList(list(models_without_norm_df))
        self.models_with_norm = nn.ModuleList(list(models_with_norm_df))
        self.learning_rate = learning_rate
        self.use_simple_average = use_simple_average

        if not self.use_simple_average:
            # Learned combiner: concatenation of every member's class scores
            # projected back down to a single set of class scores.
            self.combiner = nn.Linear(
                self.models_with_norm[0].num_classes * (len(self.models_with_norm) + len(self.models_without_norm)),
                self.models_with_norm[0].num_classes,
            )

    def forward(self, x):
        # x is a pair of batches: one for the unnormalised members, one for
        # the normalised members.
        x_unnormed, x_normed = x
        if not self.use_simple_average:
            out_unnormed = t.cat([model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm], dim=-1)
            out_normed = t.cat([model.model_step(x_normed, 0)[0] for model in self.models_with_norm], dim=-1)
            out_avg = self.combiner(t.cat((out_unnormed, out_normed), dim=-1))
        else:
            out_unnormed = [model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm]
            out_normed = [model.model_step(x_normed, 0)[0] for model in self.models_with_norm]

            # NOTE(review): stacking then dividing by 2 *before* .mean(-1)
            # halves the ensemble mean — confirm the extra /2 is intended.
            out_avg = (t.stack(out_unnormed + out_normed, dim=-1) / 2).mean(-1)
        # Assumes the last element of the unnormalised batch tuple is the
        # target tensor — TODO confirm against the dataloader.
        return {"out_avg": out_avg, "out_unnormed": out_unnormed, "out_normed": out_normed}, x_unnormed[-1]

    def training_step(self, batch, batch_idx):
        out, y = self(batch)
        # Loss computation is delegated to the first normalised member.
        loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0])
        self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        out, y = self(batch)
        preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y)
        acc = torchmetrics.functional.accuracy(
            preds,
            y_onecold.to(t.long),
            ignore_index=ignore_index_val,
            num_classes=self.models_with_norm[0].num_classes,
            task="multiclass",
        )
        # Accuracy is logged as a percentage.
        self.log("acc", acc * 100, prog_bar=True, sync_dist=True)
        loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0])
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        out, y = self(batch)
        preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y)
        # (hard predictions, raw ensemble/member outputs, decoded targets)
        return preds, out, y_onecold

    def configure_optimizers(self):
        # Plain Adam over all registered parameters (members + combiner).
        return t.optim.Adam(self.parameters(), lr=self.learning_rate)
|
73 |
+
|
74 |
+
|
75 |
+
class TimmHeadReplace(nn.Module):
    """Drop-in replacement for a timm backbone's classification head.

    Used to strip or alter the pooling/flatten stage of a pretrained model so
    it yields flat feature vectors instead of class logits.

    Parameters
    ----------
    pooling : str or None
        "AdaptiveAvgPool2d" or "AdaptiveMaxPool2d" to pool before flattening;
        ``None`` flattens the input as-is.
    in_channels : int
        Unused; kept for backward compatibility with existing call sites.
    pooling_output_dimension : int
        Target spatial output size for the adaptive pooling layer.
    all_identity : bool
        If True the module is a pure pass-through (no pooling, no flatten).

    Raises
    ------
    NotImplementedError
        If ``pooling`` is not one of the supported values.
    """

    def __init__(self, pooling=None, in_channels=512, pooling_output_dimension=1, all_identity=False) -> None:
        super().__init__()

        if all_identity:
            self.head = nn.Identity()
            self.pooling = None
        else:
            self.pooling = pooling
            if pooling is not None:
                self.pooling_output_dimension = pooling_output_dimension
                if self.pooling == "AdaptiveAvgPool2d":
                    self.pooling_layer = nn.AdaptiveAvgPool2d(pooling_output_dimension)
                elif self.pooling == "AdaptiveMaxPool2d":
                    self.pooling_layer = nn.AdaptiveMaxPool2d(pooling_output_dimension)
                else:
                    # Fail fast: previously any other value (including
                    # "stack_avg_max_attn", which forward() references but no
                    # branch here ever builds layers for) left
                    # ``self.pooling_layer`` undefined, so the module only
                    # crashed with an AttributeError at the first forward pass.
                    raise NotImplementedError(f"pooling={pooling!r} not implemented")
            self.head = nn.Flatten()

    def forward(self, x, pre_logits=False):
        # ``pre_logits`` mirrors the timm head signature; it is ignored here.
        if self.pooling is not None:
            if self.pooling == "stack_avg_max_attn":
                # Only reachable if a caller assigns a list of pooling layers
                # to ``self.pooling_layer`` manually after construction.
                x = t.cat([layer(x) for layer in self.pooling_layer], dim=-1)
            else:
                x = self.pooling_layer(x)
        return self.head(x)
|
99 |
+
|
100 |
+
|
101 |
+
class CVModel(nn.Module):
    """Vision-backbone sequence classifier built on a pretrained timm model.

    The timm backbone's features are projected to ``max_seq_length`` scores
    (one per sequence position), then each position's scalar is projected to
    ``out_shape`` class scores and passed through the final activation.

    Parameters
    ----------
    modelname : str
        timm model identifier passed to ``timm.create_model``.
    in_shape : tuple
        Input tensor shape used to probe the backbone's output width.
    num_classes : int
        Number of output classes.
    loss_func : str
        Loss name; "OrdinalRegLoss" switches to a single regression output.
    last_activation : str
        One of "Softmax", "Sigmoid", "LogSigmoid", "Identity".
    input_padding_val : int
        # NOTE(review): accepted but never stored or used here — confirm
        # whether it can be dropped from callers.
    char_dims : int
        Stored only; not used in this class's forward pass.
    max_seq_length : int
        Sequence length the backbone features are projected to.
    """

    def __init__(
        self,
        modelname,
        in_shape,
        num_classes,
        loss_func,
        last_activation: str,
        input_padding_val=10,
        char_dims=2,
        max_seq_length=1000,
    ) -> None:
        super().__init__()
        self.modelname = modelname
        self.loss_func = loss_func
        self.in_shape = in_shape
        self.char_dims = char_dims
        self.x_shape = in_shape
        self.last_activation = last_activation
        self.max_seq_length = max_seq_length
        self.num_classes = num_classes
        # Ordinal regression predicts a single scalar per position.
        if self.loss_func == "OrdinalRegLoss":
            self.out_shape = 1
        else:
            self.out_shape = num_classes

        # num_classes=0 asks timm for a headless feature extractor.
        self.cv_model = timm.create_model(modelname, pretrained=True, num_classes=0)
        self.cv_model.classifier = nn.Identity()
        # Probe the backbone once to learn its feature width.
        with t.inference_mode():
            test_out = self.cv_model(t.ones(self.in_shape, dtype=t.float32))
        self.cv_model_out_dim = test_out.shape[1]
        # Replace the (identity) classifier with a projection to one score
        # per sequence position.
        self.cv_model.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.cv_model_out_dim, self.max_seq_length))
        if self.out_shape == 1:
            self.logit_norm = nn.Identity()
            self.out_project = nn.Identity()
        else:
            # NOTE(review): ``logit_norm`` is constructed but never applied in
            # forward() — confirm whether it should be used or removed.
            self.logit_norm = nn.LayerNorm(self.max_seq_length)
            self.out_project = nn.Linear(1, self.out_shape)

        if last_activation == "Softmax":
            self.final_activation = nn.Softmax(dim=-1)
        elif last_activation == "Sigmoid":
            self.final_activation = nn.Sigmoid()
        elif last_activation == "LogSigmoid":
            self.final_activation = nn.LogSigmoid()
        elif last_activation == "Identity":
            self.final_activation = nn.Identity()
        else:
            raise NotImplementedError(f"{last_activation} not implemented")

    def forward(self, x):
        # Batches may arrive as a list whose first element is the image tensor.
        if isinstance(x, list):
            x = x[0]
        x = self.cv_model(x)
        # (batch, max_seq_length) -> (batch, max_seq_length, 1)
        x = self.cv_model.classifier(x).unsqueeze(-1)
        # Per-position projection to class scores (identity for ordinal mode).
        x = self.out_project(x)
        return self.final_activation(x)
|
158 |
+
|
159 |
+
|
160 |
+
class LitModel(plight.LightningModule):
|
161 |
+
    def __init__(
        self,
        in_shape: tuple,
        hidden_dim: int,
        num_attention_heads: int,
        num_layers: int,
        loss_func: str,
        learning_rate: float,
        weight_decay: float,
        cfg: dict,
        use_lr_warmup: bool,
        use_reduce_on_plateau: bool,
        track_gradient_histogram=False,
        register_forw_hook=False,
        char_dims=2,
    ) -> None:
        """Build the fixation-sequence model from the experiment config.

        Wires up (in order): hyperparameter capture, config unpacking, the
        input projection, the optional character-embedding sub-model, the
        main sequence model (BERT or CV-only), the output MLP head, the
        final activation, and an optional torch profiler.

        ``cfg`` is the experiment dictionary; most architecture choices are
        read from it. Mutates ``cfg`` in place to backfill defaults for
        "only_use_2nd_input_stream", "gamma_step_size" and
        "gamma_step_factor".
        """
        super().__init__()
        # Backfill defaults for keys absent from older experiment configs.
        if "only_use_2nd_input_stream" not in cfg:
            cfg["only_use_2nd_input_stream"] = False

        if "gamma_step_size" not in cfg:
            cfg["gamma_step_size"] = 5
        if "gamma_step_factor" not in cfg:
            cfg["gamma_step_factor"] = 0.5
        # Persist everything needed to reload this model from a checkpoint.
        self.save_hyperparameters(
            dict(
                in_shape=in_shape,
                hidden_dim=hidden_dim,
                num_attention_heads=num_attention_heads,
                num_layers=num_layers,
                loss_func=loss_func,
                learning_rate=learning_rate,
                cfg=cfg,
                x_shape=in_shape,
                num_classes=cfg["num_classes"],
                use_lr_warmup=use_lr_warmup,
                num_warmup_steps=cfg["num_warmup_steps"],
                use_reduce_on_plateau=use_reduce_on_plateau,
                weight_decay=weight_decay,
                track_gradient_histogram=track_gradient_histogram,
                register_forw_hook=register_forw_hook,
                char_dims=char_dims,
                remove_timm_classifier_head_pooling=cfg["remove_timm_classifier_head_pooling"],
                change_pooling_for_timm_head_to=cfg["change_pooling_for_timm_head_to"],
                chars_conv_pooling_out_dim=cfg["chars_conv_pooling_out_dim"],
            )
        )
        # --- unpack config into attributes -------------------------------
        self.model_to_use = cfg["model_to_use"]
        self.num_classes = cfg["num_classes"]
        self.x_shape = in_shape
        self.in_shape = in_shape
        self.hidden_dim = hidden_dim
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers

        self.use_lr_warmup = use_lr_warmup
        self.num_warmup_steps = cfg["num_warmup_steps"]
        self.warmup_exponent = cfg["warmup_exponent"]

        self.use_reduce_on_plateau = use_reduce_on_plateau
        self.loss_func = loss_func
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.using_one_hot_targets = cfg["one_hot_y"]
        self.track_gradient_histogram = track_gradient_histogram
        self.register_forw_hook = register_forw_hook
        # Ordinal regression maps predictions onto a [min, max] value range.
        if self.loss_func == "OrdinalRegLoss":
            self.ord_reg_loss_max = cfg["ord_reg_loss_max"]
            self.ord_reg_loss_min = cfg["ord_reg_loss_min"]

        self.num_lin_layers = cfg["num_lin_layers"]
        self.linear_activation = cfg["linear_activation"]
        self.last_activation = cfg["last_activation"]

        self.max_seq_length = cfg["manual_max_sequence_for_model"]

        self.use_char_embed_info = cfg["use_embedded_char_pos_info"]

        self.method_chars_into_model = cfg["method_chars_into_model"]
        self.source_for_pretrained_cv_model = cfg["source_for_pretrained_cv_model"]
        self.method_to_include_char_positions = cfg["method_to_include_char_positions"]

        self.char_dims = char_dims
        self.char_sequence_length = cfg["max_len_chars_list"] if self.use_char_embed_info else 0

        self.chars_conv_lr_reduction_factor = cfg["chars_conv_lr_reduction_factor"]
        if self.use_char_embed_info:
            self.chars_bert_reduction_factor = cfg["chars_bert_reduction_factor"]

        self.use_in_projection_bias = cfg["use_in_projection_bias"]
        self.add_layer_norm_to_in_projection = cfg["add_layer_norm_to_in_projection"]

        self.hidden_dropout_prob = cfg["hidden_dropout_prob"]
        self.layer_norm_after_in_projection = cfg["layer_norm_after_in_projection"]
        # NOTE(review): duplicate of the assignment above — harmless, but one
        # of the two could be removed.
        self.method_chars_into_model = cfg["method_chars_into_model"]
        self.input_padding_val = cfg["input_padding_val"]
        self.cv_char_modelname = cfg["cv_char_modelname"]
        self.char_plot_shape = cfg["char_plot_shape"]

        self.remove_timm_classifier_head_pooling = cfg["remove_timm_classifier_head_pooling"]
        self.change_pooling_for_timm_head_to = cfg["change_pooling_for_timm_head_to"]
        self.chars_conv_pooling_out_dim = cfg["chars_conv_pooling_out_dim"]

        self.add_layer_norm_to_char_mlp = cfg["add_layer_norm_to_char_mlp"]
        if "profile_torch_run" in cfg:
            self.profile_torch_run = cfg["profile_torch_run"]
        else:
            self.profile_torch_run = False
        # Ordinal regression predicts a single scalar per position.
        if self.loss_func == "OrdinalRegLoss":
            self.out_shape = 1
        else:
            self.out_shape = cfg["num_classes"]

        # --- input projection (first input stream) -----------------------
        # When char info is concatenated, the projection only fills half of
        # hidden_dim; the char sub-model supplies the other half.
        if not self.hparams.cfg["only_use_2nd_input_stream"]:
            if (
                self.method_chars_into_model == "dense"
                and self.use_char_embed_info
                and self.method_to_include_char_positions == "concat"
            ):
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias)
            elif (
                self.method_chars_into_model == "bert"
                and self.use_char_embed_info
                and self.method_to_include_char_positions == "concat"
            ):
                self.hidden_dim_chars = self.hidden_dim // 2
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim_chars, bias=self.use_in_projection_bias)
            elif (
                self.method_chars_into_model == "resnet"
                and self.method_to_include_char_positions == "concat"
                and self.use_char_embed_info
            ):
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias)
            elif self.model_to_use == "cv_only_model":
                self.project = nn.Identity()
            else:
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim, bias=self.use_in_projection_bias)
            if self.add_layer_norm_to_in_projection:
                # Rebuild the projection with a LayerNorm appended.
                self.project = nn.Sequential(
                    nn.Linear(self.project.in_features, self.project.out_features, bias=self.use_in_projection_bias),
                    nn.LayerNorm(self.project.out_features),
                )

        # torch.compile is attempted only on POSIX and when globally enabled.
        if hasattr(self, "project") and "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.project = t.compile(self.project)

        # --- character-embedding sub-model (second input stream) ---------
        if self.use_char_embed_info:
            self._create_char_model()

        if self.layer_norm_after_in_projection:
            if self.hparams.cfg["only_use_2nd_input_stream"]:
                self.layer_norm_in = nn.LayerNorm(self.hidden_dim // 2)
            else:
                self.layer_norm_in = nn.LayerNorm(self.hidden_dim)

            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.layer_norm_in = t.compile(self.layer_norm_in)

        # --- main sequence model (BERT / CV-only) ------------------------
        self._create_main_seq_model(cfg)

        if register_forw_hook:
            self.register_hooks()
        if self.hparams.cfg["only_use_2nd_input_stream"]:
            linear_in_dim = self.hidden_dim // 2
        else:
            linear_in_dim = self.hidden_dim

        # --- output MLP head ---------------------------------------------
        if self.num_lin_layers == 1:
            self.linear = nn.Linear(linear_in_dim, self.out_shape)
        else:
            lin_layers = []
            for _ in range(self.num_lin_layers - 1):
                lin_layers.extend(
                    [
                        nn.Linear(linear_in_dim, linear_in_dim),
                        getattr(nn, self.linear_activation)(),
                    ]
                )
            self.linear = nn.Sequential(*lin_layers, nn.Linear(linear_in_dim, self.out_shape))

        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.linear = t.compile(self.linear)

        if self.last_activation == "Softmax":
            self.final_activation = nn.Softmax(dim=-1)
        elif self.last_activation == "Sigmoid":
            self.final_activation = nn.Sigmoid()
        elif self.last_activation == "Identity":
            self.final_activation = nn.Identity()
        else:
            raise NotImplementedError(f"{self.last_activation} not implemented")

        # Optional torch profiler writing TensorBoard traces to "tblogs".
        if self.profile_torch_run:
            self.profilerr = t.profiler.profile(
                schedule=t.profiler.schedule(wait=1, warmup=10, active=10, repeat=1),
                on_trace_ready=t.profiler.tensorboard_trace_handler("tblogs"),
                with_stack=True,
                record_shapes=True,
                profile_memory=False,
            )
|
361 |
+
|
362 |
+
    def _create_main_seq_model(self, cfg):
        """Instantiate the main sequence model as ``self.bert_model``.

        Either a freshly-configured (untrained) HuggingFace BERT encoder or a
        ``CVModel`` wrapper around a pretrained vision backbone, depending on
        ``cfg["model_to_use"]``. Returns 0 on success.
        """
        # When only the 2nd input stream is used, the sequence model runs at
        # half width (the char sub-model's output dimension).
        if self.hparams.cfg["only_use_2nd_input_stream"]:
            hidden_dim = self.hidden_dim // 2
        else:
            hidden_dim = self.hidden_dim
        if self.model_to_use == "BERT":
            self.bert_config = transformers.BertConfig(
                vocab_size=self.x_shape[-1],
                hidden_size=hidden_dim,
                num_hidden_layers=self.num_layers,
                intermediate_size=hidden_dim,
                num_attention_heads=self.num_attention_heads,
                max_position_embeddings=self.max_seq_length,
            )
            # Randomly initialised encoder (no pretrained weights loaded).
            self.bert_model = transformers.BertModel(self.bert_config)
        elif self.model_to_use == "cv_only_model":
            self.bert_model = CVModel(
                modelname=cfg["cv_modelname"],
                in_shape=self.in_shape,
                num_classes=cfg["num_classes"],
                loss_func=cfg["loss_function"],
                last_activation=cfg["last_activation"],
                input_padding_val=cfg["input_padding_val"],
                char_dims=self.char_dims,
                max_seq_length=cfg["manual_max_sequence_for_model"],
            )
        else:
            raise NotImplementedError(f"{self.model_to_use} not implemented")
        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.bert_model = t.compile(self.bert_model)
        return 0
|
393 |
+
|
394 |
+
    def _create_char_model(self):
        """Build the character-position sub-model selected by
        ``self.method_chars_into_model`` ("dense", "bert" or "resnet").

        "dense": two linear projections over the char-coordinate sequence.
        "bert": a small BERT sequence classifier over projected char coords.
        "resnet": a pretrained vision backbone (timm / huggingface /
        torch_hub) applied to rendered character-position plots, followed by
        a linear classifier. Returns 0 on success.
        """
        if self.method_chars_into_model == "dense":
            # Collapse each char's coordinate vector to a single scalar ...
            self.chars_project_0 = nn.Linear(self.char_dims, 1, bias=self.use_in_projection_bias)
            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_project_0 = t.compile(self.chars_project_0)
            # ... then project the char sequence to (half of) hidden_dim,
            # depending on whether char info is concatenated or added.
            if self.method_to_include_char_positions == "concat":
                self.chars_project_1 = nn.Linear(
                    self.char_sequence_length, self.hidden_dim // 2, bias=self.use_in_projection_bias
                )
            else:
                self.chars_project_1 = nn.Linear(
                    self.char_sequence_length, self.hidden_dim, bias=self.use_in_projection_bias
                )

            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_project_1 = t.compile(self.chars_project_1)
        elif not self.method_chars_into_model == "resnet":
            # Covers the "bert" path; assumes ``self.hidden_dim_chars`` was
            # already set (done in __init__ for the concat case) — TODO
            # confirm for non-concat configurations.
            self.chars_project = nn.Linear(self.char_dims, self.hidden_dim_chars, bias=self.use_in_projection_bias)
            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_project = t.compile(self.chars_project)

        if self.method_chars_into_model == "bert":
            if not hasattr(self, "hidden_dim_chars"):
                if self.hidden_dim // self.chars_bert_reduction_factor > 1:
                    self.hidden_dim_chars = self.hidden_dim // self.chars_bert_reduction_factor
                else:
                    self.hidden_dim_chars = self.hidden_dim
            # Keep the per-head width identical to the main model's heads.
            self.num_attention_heads_chars = self.hidden_dim_chars // (self.hidden_dim // self.num_attention_heads)
            self.chars_bert_config = transformers.BertConfig(
                vocab_size=self.x_shape[-1],
                hidden_size=self.hidden_dim_chars,
                num_hidden_layers=self.num_layers,
                intermediate_size=self.hidden_dim_chars,
                num_attention_heads=self.num_attention_heads_chars,
                max_position_embeddings=self.char_sequence_length + 1,
                num_labels=1,
            )
            self.chars_bert = transformers.BertForSequenceClassification(self.chars_bert_config)

            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_bert = t.compile(self.chars_bert)
            # Expand the single classification score back to hidden_dim_chars.
            self.chars_project_class_output = nn.Linear(1, self.hidden_dim_chars, bias=self.use_in_projection_bias)
            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_project_class_output = t.compile(self.chars_project_class_output)
        elif self.method_chars_into_model == "resnet":
            if self.source_for_pretrained_cv_model == "timm":
                self.chars_conv = timm.create_model(
                    self.cv_char_modelname,
                    pretrained=True,
                    num_classes=0,  # remove classifier nn.Linear
                )
                if self.remove_timm_classifier_head_pooling:
                    # Strip the head entirely, probe the output shape, and if
                    # the backbone still emits spatial maps, reinstate a
                    # configurable pooling head.
                    self.chars_conv.head = TimmHeadReplace(all_identity=True)
                    with t.inference_mode():
                        test_out = self.chars_conv(
                            t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32)
                        )
                    if test_out.ndim > 3:
                        self.chars_conv.head = TimmHeadReplace(
                            self.change_pooling_for_timm_head_to,
                            test_out.shape[1],
                        )
            elif self.source_for_pretrained_cv_model == "huggingface":
                self.chars_conv = transformers.AutoModelForImageClassification.from_pretrained(self.cv_char_modelname)
            elif self.source_for_pretrained_cv_model == "torch_hub":
                # cv_char_modelname is "repo,model" for torch.hub.load.
                self.chars_conv = t.hub.load(*self.cv_char_modelname.split(","))

            # Neutralise whichever classifier attribute the backbone exposes.
            if hasattr(self.chars_conv, "classifier"):
                self.chars_conv.classifier = nn.Identity()
            elif hasattr(self.chars_conv, "cls_classifier"):
                self.chars_conv.cls_classifier = nn.Identity()
            elif hasattr(self.chars_conv, "fc"):
                self.chars_conv.fc = nn.Identity()

            if hasattr(self.chars_conv, "distillation_classifier"):
                self.chars_conv.distillation_classifier = nn.Identity()
            # Probe the headless backbone to learn its feature width; output
            # type varies by source (tensor, HF output object, or list).
            with t.inference_mode():
                test_out = self.chars_conv(
                    t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32)
                )
            if hasattr(test_out, "last_hidden_state"):
                self.chars_conv_out_dim = test_out.last_hidden_state.shape[1]
            elif hasattr(test_out, "logits"):
                self.chars_conv_out_dim = test_out.logits.shape[1]
            elif isinstance(test_out, list):
                self.chars_conv_out_dim = test_out[0].shape[1]
            else:
                self.chars_conv_out_dim = test_out.shape[1]

            char_lin_layers = [nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)]
            if self.add_layer_norm_to_char_mlp:
                char_lin_layers.append(nn.LayerNorm(self.hidden_dim // 2))
            self.chars_classifier = nn.Sequential(*char_lin_layers)
            if hasattr(self.chars_conv, "distillation_classifier"):
                self.chars_conv.distillation_classifier = nn.Sequential(
                    nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)
                )

            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_classifier = t.compile(self.chars_classifier)
            if "posix" in os.name and global_settings["try_using_torch_compile"]:
                self.chars_conv = t.compile(self.chars_conv)
        return 0
|
497 |
+
|
498 |
+
def register_hooks(self):
    """Attach a forward hook to every submodule that logs activation
    histograms to each attached logger exposing ``add_histogram``."""

    def add_to_tb(layer):
        # Build a hook closed over this layer's tag prefix.
        def hook(model, input, output):
            # Plain tensors only; skip tuple/dict outputs without .detach().
            if not hasattr(output, "detach"):
                return
            for logger in self.loggers:
                experiment = logger.experiment
                if hasattr(experiment, "add_histogram"):
                    experiment.add_histogram(
                        tag=f"{layer}_{str(list(output.shape))}",
                        values=output.detach(),
                        global_step=self.trainer.global_step,
                    )

        return hook

    for layer_id, layer in dict([*self.named_modules()]).items():
        layer.register_forward_hook(add_to_tb(f"act_{layer_id}"))
514 |
+
|
515 |
+
def on_after_backward(self) -> None:
    # Lightning hook, called after loss.backward(): optionally log
    # per-parameter gradient histograms to every logger that supports
    # ``add_histogram`` (e.g. TensorBoard).
    if self.track_gradient_histogram:
        # Throttle to every 200th step to keep logging overhead low.
        if self.trainer.global_step % 200 == 0:
            for logger in self.loggers:
                if hasattr(logger.experiment, "add_histogram"):
                    # NOTE(review): named_modules() yields nested modules, so a
                    # parameter can be logged once per enclosing module level.
                    for layer_id, layer in dict([*self.named_modules()]).items():
                        parameters = layer.parameters()
                        for idx2, p in enumerate(parameters):
                            grad_val = p.grad
                            if grad_val is not None:
                                grad_name = f"grad_{idx2}_{layer_id}_{str(list(p.grad.shape))}"
                                logger.experiment.add_histogram(
                                    tag=grad_name, values=grad_val, global_step=self.trainer.global_step
                                )

    return super().on_after_backward()
531 |
+
|
532 |
+
def _fold_in_seq_dim(self, out, y):
|
533 |
+
batch_size, seq_len, num_classes = out.shape
|
534 |
+
out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
|
535 |
+
if y is None:
|
536 |
+
return out, None
|
537 |
+
if len(y.shape) > 2:
|
538 |
+
y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len)
|
539 |
+
else:
|
540 |
+
y = eo.rearrange(y, "b s -> (b s)", s=seq_len)
|
541 |
+
return out, y
|
542 |
+
|
543 |
+
def _get_loss(self, out, y, batch):
|
544 |
+
attention_mask = batch[-2]
|
545 |
+
if self.loss_func == "BCELoss":
|
546 |
+
if self.last_activation == "Identity":
|
547 |
+
loss = t.nn.functional.binary_cross_entropy_with_logits(out, y, reduction="none")
|
548 |
+
else:
|
549 |
+
loss = t.nn.functional.binary_cross_entropy(out, y, reduction="none")
|
550 |
+
|
551 |
+
replace_tensor = t.zeros(loss[1, 1, :].shape, device=loss.device, dtype=loss.dtype, requires_grad=False)
|
552 |
+
loss[~attention_mask.bool()] = replace_tensor
|
553 |
+
loss = loss.mean()
|
554 |
+
elif self.loss_func == "CrossEntropyLoss":
|
555 |
+
if len(out.shape) > 2:
|
556 |
+
out, y = self._fold_in_seq_dim(out, y)
|
557 |
+
loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100)
|
558 |
+
else:
|
559 |
+
loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100)
|
560 |
+
|
561 |
+
elif self.loss_func == "OrdinalRegLoss":
|
562 |
+
loss = t.nn.functional.mse_loss(out, y, reduction="none")
|
563 |
+
loss = loss[attention_mask.bool()].sum() * 10.0 / attention_mask.sum()
|
564 |
+
elif self.loss_func == "macro_soft_f1":
|
565 |
+
loss = macro_soft_f1(y, out, reduction="mean")
|
566 |
+
elif self.loss_func == "coral_loss":
|
567 |
+
loss = coral_loss(out, y)
|
568 |
+
elif self.loss_func == "corn_loss":
|
569 |
+
out, y = self._fold_in_seq_dim(out, y)
|
570 |
+
loss = corn_loss(out, y.squeeze(), self.out_shape)
|
571 |
+
else:
|
572 |
+
raise ValueError("Loss Function not reckognized")
|
573 |
+
return loss
|
574 |
+
|
575 |
+
def training_step(self, batch, batch_idx):
    """One training step: advance the profiler (if active), run the model,
    compute and log the training loss."""
    if self.profile_torch_run:
        self.profilerr.step()
    model_out, targets = self.model_step(batch, batch_idx)
    train_loss = self._get_loss(model_out, targets, batch)
    self.log("train_loss", train_loss, on_epoch=True, on_step=True, sync_dist=True)
    return train_loss
582 |
+
|
583 |
+
def forward(*args):
    """Delegate to the module-level ``forward``: the first positional
    argument is the instance, the remaining ones are bundled as the batch."""
    instance, remainder = args[0], args[1:]
    return forward(instance, remainder)
585 |
+
|
586 |
+
def model_step(self, batch, batch_idx):
    """Run the forward pass and return ``(predictions, targets)``;
    targets are taken from the last element of the batch tuple."""
    predictions = self.forward(batch)
    targets = batch[-1]
    return predictions, targets
589 |
+
|
590 |
+
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_closure,
):
    # Lightning hook: run the optimizer, then apply manual linear LR warmup
    # and periodically log the per-group learning rates.
    optimizer.step(closure=optimizer_closure)

    # Manual warmup only when OneCycleLR is not handling warmup itself.
    if self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR":
        if self.trainer.global_step < self.num_warmup_steps:
            # Ramp from ~0 to 1 over num_warmup_steps, shaped by warmup_exponent.
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.num_warmup_steps) ** self.warmup_exponent
            # NOTE(review): this sets the same scaled LR on every param group,
            # overriding the reduced chars_conv group LR during warmup — confirm intended.
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate
    # Log current learning rates every 10 steps (and at step 0).
    if self.trainer.global_step % 10 == 0 or self.trainer.global_step == 0:
        for idx, pg in enumerate(optimizer.param_groups):
            self.log(f"lr_{idx}", pg["lr"], prog_bar=True, sync_dist=True)
607 |
+
|
608 |
+
def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None:
    """Step the LR scheduler, except while manual warmup is still active.

    The original implementation duplicated the identical step logic in both
    branches; this consolidates it into a single guard + step.
    """
    warmup_active = self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR"
    if warmup_active and self.trainer.global_step <= self.num_warmup_steps:
        # The LR is being scaled manually in optimizer_step during warmup;
        # do not let the scheduler interfere yet.
        return
    if metric is None:
        scheduler.step()
    else:
        # Metric-driven schedulers (e.g. ReduceLROnPlateau) take the value.
        scheduler.step(metric)
620 |
+
|
621 |
+
def _get_preds_reals(self, out, y):
|
622 |
+
if self.loss_func == "corn_loss":
|
623 |
+
seq_len = out.shape[1]
|
624 |
+
out, y = self._fold_in_seq_dim(out, y)
|
625 |
+
preds = corn_label_from_logits(out)
|
626 |
+
preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len)
|
627 |
+
if y is not None:
|
628 |
+
y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len)
|
629 |
+
|
630 |
+
elif self.loss_func == "OrdinalRegLoss":
|
631 |
+
preds = out * (self.ord_reg_loss_max - self.ord_reg_loss_min)
|
632 |
+
preds = (preds + self.ord_reg_loss_min).round().to(t.long)
|
633 |
+
|
634 |
+
else:
|
635 |
+
preds = t.argmax(out, dim=-1)
|
636 |
+
if y is None:
|
637 |
+
return preds, y, -100
|
638 |
+
else:
|
639 |
+
if self.using_one_hot_targets:
|
640 |
+
y_onecold = t.argmax(y, dim=-1)
|
641 |
+
ignore_index_val = 0
|
642 |
+
elif self.loss_func == "OrdinalRegLoss":
|
643 |
+
y_onecold = (y * self.num_classes).round().to(t.long)
|
644 |
+
|
645 |
+
y_onecold = y * (self.ord_reg_loss_max - self.ord_reg_loss_min)
|
646 |
+
y_onecold = (y_onecold + self.ord_reg_loss_min).round().to(t.long)
|
647 |
+
ignore_index_val = t.min(y_onecold).to(t.long)
|
648 |
+
else:
|
649 |
+
y_onecold = y
|
650 |
+
ignore_index_val = -100
|
651 |
+
|
652 |
+
if len(preds.shape) > len(y_onecold.shape):
|
653 |
+
preds = preds.squeeze()
|
654 |
+
return preds, y_onecold, ignore_index_val
|
655 |
+
|
656 |
+
def validation_step(self, batch, batch_idx):
    # Compute and log accuracy and validation loss for one batch.
    out, y = self.model_step(batch, batch_idx)
    preds, y_onecold, ignore_index_val = self._get_preds_reals(out, y)

    if self.loss_func == "OrdinalRegLoss":
        # Manual accuracy: drop positions carrying the ignore index before comparing.
        y_onecold = y_onecold.flatten()
        preds = preds.flatten()[y_onecold != ignore_index_val]
        y_onecold = y_onecold[y_onecold != ignore_index_val]
        acc = (preds == y_onecold).sum() / len(y_onecold)
    else:
        acc = torchmetrics.functional.accuracy(
            preds,
            y_onecold.to(t.long),
            ignore_index=ignore_index_val,
            num_classes=self.num_classes,
            task="multiclass",
        )
    # Logged as a percentage.
    self.log("acc", acc * 100, prog_bar=True, sync_dist=True)
    loss = self._get_loss(out, y, batch)
    self.log("val_loss", loss, prog_bar=True, sync_dist=True)

    return loss
678 |
+
|
679 |
+
def predict_step(self, batch, batch_idx):
    """Inference hook: return (predicted labels, integer-label targets)."""
    model_out, targets = self.model_step(batch, batch_idx)
    predictions, reals, _ignore_index = self._get_preds_reals(model_out, targets)
    return predictions, reals
683 |
+
|
684 |
+
def configure_optimizers(self):
    # Build an AdamW optimizer with two parameter groups (a reduced LR for the
    # pretrained chars_conv backbone) and an optional LR scheduler chosen from
    # the config.
    params = list(self.named_parameters())

    def is_chars_conv(n):
        # True for chars_conv backbone parameters, excluding its classifier head.
        if "chars_conv" not in n:
            return False
        if "chars_conv" in n and "classifier" in n:
            return False
        else:
            return True

    grouped_parameters = [
        {
            # Backbone group: LR divided by chars_conv_lr_reduction_factor.
            "params": [p for n, p in params if is_chars_conv(n)],
            "lr": self.learning_rate / self.chars_conv_lr_reduction_factor,
            "weight_decay": self.weight_decay,
        },
        {
            # Everything else trains at the full learning rate.
            "params": [p for n, p in params if not is_chars_conv(n)],
            "lr": self.learning_rate,
            "weight_decay": self.weight_decay,
        },
    ]
    opti = t.optim.AdamW(grouped_parameters, lr=self.learning_rate, weight_decay=self.weight_decay)
    if self.use_reduce_on_plateau:
        # Plateau scheduling monitors val_loss once per epoch.
        opti_dict = {
            "optimizer": opti,
            "lr_scheduler": {
                "scheduler": t.optim.lr_scheduler.ReduceLROnPlateau(opti, mode="min", patience=2, factor=0.5),
                "monitor": "val_loss",
                "frequency": 1,
                "interval": "epoch",
            },
        }
        return opti_dict
    else:
        cfg = self.hparams["cfg"]
        if cfg["use_reduce_on_plateau"]:
            # Handled by the branch above when the instance attribute agrees;
            # here it means: no additional scheduler.
            scheduler = None
        elif cfg["lr_scheduling"] == "multistep":
            scheduler = t.optim.lr_scheduler.MultiStepLR(
                opti, milestones=cfg["multistep_milestones"], gamma=cfg["gamma_multistep"], verbose=False
            )
            interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch"
        elif cfg["lr_scheduling"] == "StepLR":
            scheduler = t.optim.lr_scheduler.StepLR(
                opti, step_size=cfg["gamma_step_size"], gamma=cfg["gamma_step_factor"]
            )
            interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch"
        elif cfg["lr_scheduling"] == "anneal":
            # Cosine annealing over a fixed horizon of 250 scheduler steps.
            scheduler = t.optim.lr_scheduler.CosineAnnealingLR(
                opti, 250, eta_min=cfg["min_lr_anneal"], last_epoch=-1, verbose=False
            )
            interval = "step"
        elif cfg["lr_scheduling"] == "ExponentialLR":
            scheduler = t.optim.lr_scheduler.ExponentialLR(opti, gamma=cfg["lr_sched_exp_fac"])
            interval = "step"
        else:
            scheduler = None
        if scheduler is None:
            # Optimizer only, no scheduling.
            return [opti]
        else:
            opti_dict = {
                "optimizer": opti,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "global_step",
                    "frequency": 1,
                    "interval": interval,
                },
            }
            return opti_dict
756 |
+
|
757 |
+
def on_fit_start(self) -> None:
    """Start the torch profiler (when profiling is enabled) before training."""
    if self.profile_torch_run:
        self.profilerr.start()
    return super().on_fit_start()
761 |
+
|
762 |
+
def on_fit_end(self) -> None:
    """Stop the torch profiler (when profiling is enabled) after training."""
    if self.profile_torch_run:
        self.profilerr.stop()
    return super().on_fit_end()
766 |
+
|
767 |
+
|
768 |
+
def prep_model_input(self, batch):
    # Unpack the batch, project the fixation stream, and (optionally) embed
    # the character information via one of three mechanisms ("dense", "bert",
    # "resnet"). Returns (x_embedded, attention_mask).
    if len(batch) == 1:
        # Dataloader sometimes wraps the batch tuple in a 1-tuple.
        batch = batch[0]
    if self.use_char_embed_info:
        # Batch layout depends on what char information was produced:
        # 5 elements -> coords + images; 4 with a 4-D second element -> images
        # only; otherwise coords only.
        if len(batch) == 5:
            x, chars_coords, ims, attention_mask, _ = batch
        elif batch[1].ndim == 4:
            x, ims, attention_mask, _ = batch
        else:
            x, chars_coords, attention_mask, _ = batch
        padding_list = None  # NOTE(review): assigned but never read below.
    else:
        if len(batch) > 3:
            x = batch[0]
            y = batch[-1]  # NOTE(review): y is unpacked but unused here.
            attention_mask = batch[1]
        else:
            x, attention_mask, y = batch

    # Project the fixation features unless the model consumes only the
    # second (character) input stream or is a pure CV model.
    if self.model_to_use != "cv_only_model" and not self.hparams.cfg["only_use_2nd_input_stream"]:
        x_embedded = self.project(x)
    else:
        x_embedded = x
    if self.use_char_embed_info:
        if self.method_chars_into_model == "dense":
            # bool_mask is True where chars_coords equals the padding value.
            # NOTE(review): multiplying the projection by this mask keeps the
            # PADDED positions and zeroes the real ones — confirm the intended
            # polarity (a ~ may be missing).
            bool_mask = chars_coords == self.input_padding_val
            bool_mask = bool_mask[:, :, 0]
            chars_coords_projected = self.chars_project_0(chars_coords).squeeze(-1)
            chars_coords_projected = chars_coords_projected * bool_mask
            if self.chars_project_1.in_features == chars_coords_projected.shape[-1]:
                chars_coords_projected = self.chars_project_1(chars_coords_projected)
            else:
                # Fallback when the projection width does not match: average
                # and broadcast across the embedding dimension.
                chars_coords_projected = chars_coords_projected.mean(dim=-1)
                chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[2])
        elif self.method_chars_into_model == "bert":
            # Attention mask over char positions, with a leading 1 for the
            # CLS token prepended below.
            chars_mask = chars_coords != self.input_padding_val
            chars_mask = t.cat(
                (
                    t.ones(chars_mask[:, :1, 0].shape, dtype=t.long, device=chars_coords.device),
                    chars_mask[:, :, 0].to(t.long),
                ),
                dim=1,
            )
            chars_coords_projected = self.chars_project(chars_coords)

            position_ids = t.arange(
                0, chars_coords_projected.shape[1] + 1, dtype=t.long, device=chars_coords_projected.device
            )
            token_type_ids = t.zeros(
                (chars_coords_projected.size()[0], chars_coords_projected.size()[1] + 1),
                dtype=t.long,
                device=chars_coords_projected.device,
            )  # +1 for CLS
            chars_coords_projected = t.cat(
                (t.ones_like(chars_coords_projected[:, :1, :]), chars_coords_projected), dim=1
            )  # to add CLS token
            chars_coords_projected = self.chars_bert(
                position_ids=position_ids,
                inputs_embeds=chars_coords_projected,
                token_type_ids=token_type_ids,
                attention_mask=chars_mask,
            )
            # Pull a single summary vector (the CLS position) out of whatever
            # output object the char-BERT returned.
            if hasattr(chars_coords_projected, "last_hidden_state"):
                chars_coords_projected = chars_coords_projected.last_hidden_state[:, 0, :]
            elif hasattr(chars_coords_projected, "logits"):
                chars_coords_projected = chars_coords_projected.logits
            else:
                chars_coords_projected = chars_coords_projected.hidden_states[-1][:, 0, :]
        elif self.method_chars_into_model == "resnet":
            # CNN path: run the char image through the conv backbone, then the
            # linear classifier head.
            chars_conv_out = self.chars_conv(ims)
            if isinstance(chars_conv_out, list):
                chars_conv_out = chars_conv_out[0]
            if hasattr(chars_conv_out, "logits"):
                chars_conv_out = chars_conv_out.logits
            chars_coords_projected = self.chars_classifier(chars_conv_out)

        # Broadcast the per-sequence char embedding across all fixation steps.
        chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[1], 1)
        if hasattr(self, "chars_project_class_output"):
            chars_coords_projected = self.chars_project_class_output(chars_coords_projected)

        # Combine the two input streams.
        if self.hparams.cfg["only_use_2nd_input_stream"]:
            x_embedded = chars_coords_projected
        elif self.method_to_include_char_positions == "concat":
            x_embedded = t.cat((x_embedded, chars_coords_projected), dim=-1)
        else:
            x_embedded = x_embedded + chars_coords_projected
    return x_embedded, attention_mask
855 |
+
|
856 |
+
|
857 |
+
def forward(self, batch):
    # Full forward pass: prepare/merge the input streams, run the sequence
    # backbone selected by self.model_to_use, and map to per-step outputs.
    prepped_input = prep_model_input(self, batch)

    # NOTE(review): prep_model_input returns a 2-tuple above, but this branch
    # unpacks 4 values when len(batch) > 5 — confirm this path is reachable.
    if len(batch) > 5:
        x_embedded, padding_list, attention_mask, attention_mask_for_prediction = prepped_input
    elif len(batch) > 2:
        x_embedded, attention_mask = prepped_input
    else:
        x_embedded = prepped_input[0]
        attention_mask = prepped_input[-1]

    position_ids = t.arange(0, x_embedded.shape[1], dtype=t.long, device=x_embedded.device)
    token_type_ids = t.zeros(x_embedded.size()[:-1], dtype=t.long, device=x_embedded.device)

    if self.layer_norm_after_in_projection:
        x_embedded = self.layer_norm_in(x_embedded)

    # Each backbone family expects a different calling convention.
    if self.model_to_use == "LSTM":
        bert_out = self.bert_model(x_embedded)
    elif self.model_to_use in ["ProphetNet", "T5", "FunnelModel"]:
        bert_out = self.bert_model(inputs_embeds=x_embedded, attention_mask=attention_mask)
    elif self.model_to_use == "xBERT":
        bert_out = self.bert_model(x_embedded, mask=attention_mask.to(bool))
    elif self.model_to_use == "cv_only_model":
        bert_out = self.bert_model(x_embedded)
    else:
        bert_out = self.bert_model(
            position_ids=position_ids,
            inputs_embeds=x_embedded,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
    # Normalize the heterogeneous backbone outputs to one tensor.
    if hasattr(bert_out, "last_hidden_state"):
        last_hidden_state = bert_out.last_hidden_state
        out = self.linear(last_hidden_state)
    elif hasattr(bert_out, "logits"):
        out = bert_out.logits
    else:
        out = bert_out
    out = self.final_activation(out)
    return out
models/BERT_20240104-223349_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00430.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c4ae65e81c722f3732563942ab40447a186869bebb1bbc8433a782805e73ac3
|
3 |
+
size 86691676
|
models/BERT_20240104-233803_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00719.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7588696e4afc4c8ffb0ff361d9566b7b360c61a3bb6fd6fcb484942b6d2568b
|
3 |
+
size 86692053
|
models/BERT_20240107-152040_loop_restrict_sim_data_to_4000_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00515.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:815b5500a1ae0a04bb55ae58c3896f07981757a2e1a2adf2cbc8a346551d88df
|
3 |
+
size 86686270
|
models/BERT_20240108-000344_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00706.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f2e56e1e33da611622315995e0cdf4db5aad6a086420401ca3ee95393b8977ac
|
3 |
+
size 86692053
|
models/BERT_20240108-011230_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00560.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4f060242cf0bc494d2908e0e99e9d411c9a9b131443cff91bb245229dad2f783
|
3 |
+
size 86691676
|
models/BERT_20240109-090419_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00518.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bbf23ac7baa88a957e1782158bd7a32aedcfcb0527b203079191ac259ec146c5
|
3 |
+
size 86692053
|
models/BERT_20240122-183729_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00523.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3fb7c8238752af51b64a23291080bb30edf9e090defcb2ec4015ddc8d543a9de
|
3 |
+
size 86691740
|
models/BERT_20240122-194041_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00462.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54fedcc5bdeda01bfae26bafcb7542c766807f1af9da7731aaa7ed38e93743d8
|
3 |
+
size 86692117
|
models/BERT_fin_exp_20240104-223349.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
6 |
+
char_dims: 0
|
7 |
+
char_plot_shape:
|
8 |
+
- 224
|
9 |
+
- 224
|
10 |
+
chars_bert_reduction_factor: 4
|
11 |
+
chars_conv_lr_reduction_factor: 1
|
12 |
+
chars_conv_pooling_out_dim: 1
|
13 |
+
convert_posix: false
|
14 |
+
convert_winpath: false
|
15 |
+
cv_char_modelname: coatnet_nano_rw_224
|
16 |
+
cv_modelname: null
|
17 |
+
early_stopping_patience: 15
|
18 |
+
gamma_multistep: null
|
19 |
+
gamma_step_factor: 0.5
|
20 |
+
gamma_step_size: 3000
|
21 |
+
head_multiplication_factor: 64
|
22 |
+
hidden_dim_bert: 512
|
23 |
+
hidden_dropout_prob: 0.0
|
24 |
+
im_partial_string: fixations_chars_channel_sep
|
25 |
+
input_padding_val: 10
|
26 |
+
last_activation: Identity
|
27 |
+
layer_norm_after_in_projection: true
|
28 |
+
linear_activation: GELU
|
29 |
+
load_best_checkpoint_at_end: false
|
30 |
+
loss_function: corn_loss
|
31 |
+
lr: 0.0004
|
32 |
+
lr_initial: '0.0004'
|
33 |
+
lr_sched_exp_fac: null
|
34 |
+
lr_scheduling: StepLR
|
35 |
+
manual_max_sequence_for_model: 500
|
36 |
+
max_len_chars_list: 0
|
37 |
+
max_seq_length: 500
|
38 |
+
method_chars_into_model: resnet
|
39 |
+
method_to_include_char_positions: concat
|
40 |
+
min_lr_anneal: 1e-6
|
41 |
+
model_to_use: BERT
|
42 |
+
multistep_milestones: null
|
43 |
+
n_layers_BERT: 4
|
44 |
+
norm_by_char_averages: false
|
45 |
+
norm_by_line_width: false
|
46 |
+
norm_coords_by_letter_min_x_y: false
|
47 |
+
normalize_by_line_height_and_width: true
|
48 |
+
num_attention_heads: 8
|
49 |
+
num_classes: 16
|
50 |
+
num_lin_layers: 1
|
51 |
+
num_warmup_steps: 3000
|
52 |
+
one_hot_y: false
|
53 |
+
ord_reg_loss_max: 16
|
54 |
+
ord_reg_loss_min: -1
|
55 |
+
padding_at_end: true
|
56 |
+
plot_histogram: true
|
57 |
+
plot_learning_curves: true
|
58 |
+
precision: 16-mixed
|
59 |
+
prediction_only: false
|
60 |
+
pretrained_model_name_to_load: null
|
61 |
+
profile_torch_run: false
|
62 |
+
reload_model: false
|
63 |
+
reload_model_date: null
|
64 |
+
remove_eval_idx_from_train_idx: true
|
65 |
+
remove_timm_classifier_head_pooling: true
|
66 |
+
sample_cols:
|
67 |
+
- x
|
68 |
+
- y
|
69 |
+
sample_means:
|
70 |
+
- 0.7326
|
71 |
+
- 6.6381
|
72 |
+
- 2.4717
|
73 |
+
sample_std:
|
74 |
+
- 0.2778
|
75 |
+
- 1.882
|
76 |
+
- 1.8562
|
77 |
+
sample_std_unscaled:
|
78 |
+
- 285.193
|
79 |
+
- 131.1842
|
80 |
+
- 1.8562
|
81 |
+
save_weights_only: true
|
82 |
+
set_max_seq_len_manually: true
|
83 |
+
set_num_classes_manually: true
|
84 |
+
source_for_pretrained_cv_model: timm
|
85 |
+
target_padding_number: -100
|
86 |
+
track_activations_via_hook: false
|
87 |
+
track_gradient_histogram: false
|
88 |
+
use_char_bounding_boxes: true
|
89 |
+
use_early_stopping: false
|
90 |
+
use_embedded_char_pos_info: true
|
91 |
+
use_fixation_duration_information: false
|
92 |
+
use_in_projection_bias: false
|
93 |
+
use_lr_warmup: true
|
94 |
+
use_pupil_size_information: false
|
95 |
+
use_reduce_on_plateau: false
|
96 |
+
use_start_time_as_input_col: false
|
97 |
+
use_training_steps_for_end_and_lr_decay: true
|
98 |
+
use_words_coords: false
|
99 |
+
warmup_exponent: 1
|
100 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240104-233803.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
6 |
+
char_dims: 0
|
7 |
+
char_plot_shape:
|
8 |
+
- 224
|
9 |
+
- 224
|
10 |
+
chars_bert_reduction_factor: 4
|
11 |
+
chars_conv_lr_reduction_factor: 1
|
12 |
+
chars_conv_pooling_out_dim: 1
|
13 |
+
convert_posix: false
|
14 |
+
convert_winpath: false
|
15 |
+
cv_char_modelname: coatnet_nano_rw_224
|
16 |
+
cv_modelname: null
|
17 |
+
early_stopping_patience: 15
|
18 |
+
gamma_multistep: null
|
19 |
+
gamma_step_factor: 0.5
|
20 |
+
gamma_step_size: 3000
|
21 |
+
head_multiplication_factor: 64
|
22 |
+
hidden_dim_bert: 512
|
23 |
+
hidden_dropout_prob: 0.0
|
24 |
+
im_partial_string: fixations_chars_channel_sep
|
25 |
+
input_padding_val: 10
|
26 |
+
last_activation: Identity
|
27 |
+
layer_norm_after_in_projection: true
|
28 |
+
linear_activation: GELU
|
29 |
+
load_best_checkpoint_at_end: false
|
30 |
+
loss_function: corn_loss
|
31 |
+
lr: 0.0004
|
32 |
+
lr_initial: '0.0004'
|
33 |
+
lr_sched_exp_fac: null
|
34 |
+
lr_scheduling: StepLR
|
35 |
+
manual_max_sequence_for_model: 500
|
36 |
+
max_len_chars_list: 0
|
37 |
+
max_seq_length: 500
|
38 |
+
method_chars_into_model: resnet
|
39 |
+
method_to_include_char_positions: concat
|
40 |
+
min_lr_anneal: 1e-6
|
41 |
+
model_to_use: BERT
|
42 |
+
multistep_milestones: null
|
43 |
+
n_layers_BERT: 4
|
44 |
+
norm_by_char_averages: false
|
45 |
+
norm_by_line_width: false
|
46 |
+
norm_coords_by_letter_min_x_y: false
|
47 |
+
normalize_by_line_height_and_width: false
|
48 |
+
num_attention_heads: 8
|
49 |
+
num_classes: 16
|
50 |
+
num_lin_layers: 1
|
51 |
+
num_warmup_steps: 3000
|
52 |
+
one_hot_y: false
|
53 |
+
ord_reg_loss_max: 16
|
54 |
+
ord_reg_loss_min: -1
|
55 |
+
padding_at_end: true
|
56 |
+
plot_histogram: true
|
57 |
+
plot_learning_curves: true
|
58 |
+
precision: 16-mixed
|
59 |
+
prediction_only: false
|
60 |
+
pretrained_model_name_to_load: null
|
61 |
+
profile_torch_run: false
|
62 |
+
reload_model: false
|
63 |
+
reload_model_date: null
|
64 |
+
remove_eval_idx_from_train_idx: true
|
65 |
+
remove_timm_classifier_head_pooling: true
|
66 |
+
sample_cols:
|
67 |
+
- x
|
68 |
+
- y
|
69 |
+
sample_means:
|
70 |
+
- 710.6114
|
71 |
+
- 473.7518
|
72 |
+
- 2.4717
|
73 |
+
sample_std:
|
74 |
+
- 285.1937
|
75 |
+
- 131.1842
|
76 |
+
- 1.8562
|
77 |
+
sample_std_unscaled:
|
78 |
+
- 285.193
|
79 |
+
- 131.1842
|
80 |
+
- 1.8562
|
81 |
+
save_weights_only: true
|
82 |
+
set_max_seq_len_manually: true
|
83 |
+
set_num_classes_manually: true
|
84 |
+
source_for_pretrained_cv_model: timm
|
85 |
+
target_padding_number: -100
|
86 |
+
track_activations_via_hook: false
|
87 |
+
track_gradient_histogram: false
|
88 |
+
use_char_bounding_boxes: true
|
89 |
+
use_early_stopping: false
|
90 |
+
use_embedded_char_pos_info: true
|
91 |
+
use_fixation_duration_information: false
|
92 |
+
use_in_projection_bias: false
|
93 |
+
use_lr_warmup: true
|
94 |
+
use_pupil_size_information: false
|
95 |
+
use_reduce_on_plateau: false
|
96 |
+
use_start_time_as_input_col: false
|
97 |
+
use_training_steps_for_end_and_lr_decay: true
|
98 |
+
use_words_coords: false
|
99 |
+
warmup_exponent: 1
|
100 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240107-152040.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
6 |
+
char_dims: 0
|
7 |
+
char_plot_shape:
|
8 |
+
- 224
|
9 |
+
- 224
|
10 |
+
chars_bert_reduction_factor: 4
|
11 |
+
chars_conv_lr_reduction_factor: 1
|
12 |
+
chars_conv_pooling_out_dim: 1
|
13 |
+
convert_posix: false
|
14 |
+
convert_winpath: false
|
15 |
+
cv_char_modelname: coatnet_nano_rw_224
|
16 |
+
cv_modelname: null
|
17 |
+
early_stopping_patience: 15
|
18 |
+
gamma_multistep: null
|
19 |
+
gamma_step_factor: 0.5
|
20 |
+
gamma_step_size: 3000
|
21 |
+
head_multiplication_factor: 64
|
22 |
+
hidden_dim_bert: 512
|
23 |
+
hidden_dropout_prob: 0.0
|
24 |
+
im_partial_string: fixations_chars_channel_sep
|
25 |
+
input_padding_val: 10
|
26 |
+
last_activation: Identity
|
27 |
+
layer_norm_after_in_projection: true
|
28 |
+
linear_activation: GELU
|
29 |
+
load_best_checkpoint_at_end: false
|
30 |
+
loss_function: corn_loss
|
31 |
+
lr: 0.0004
|
32 |
+
lr_initial: '0.0004'
|
33 |
+
lr_sched_exp_fac: null
|
34 |
+
lr_scheduling: StepLR
|
35 |
+
manual_max_sequence_for_model: 500
|
36 |
+
max_len_chars_list: 0
|
37 |
+
max_seq_length: 500
|
38 |
+
method_chars_into_model: resnet
|
39 |
+
method_to_include_char_positions: concat
|
40 |
+
min_lr_anneal: 1e-6
|
41 |
+
model_to_use: BERT
|
42 |
+
multistep_milestones: null
|
43 |
+
n_layers_BERT: 4
|
44 |
+
norm_by_char_averages: false
|
45 |
+
norm_by_line_width: false
|
46 |
+
norm_coords_by_letter_min_x_y: true
|
47 |
+
normalize_by_line_height_and_width: true
|
48 |
+
num_attention_heads: 8
|
49 |
+
num_classes: 16
|
50 |
+
num_lin_layers: 1
|
51 |
+
num_warmup_steps: 3000
|
52 |
+
one_hot_y: false
|
53 |
+
ord_reg_loss_max: 16
|
54 |
+
ord_reg_loss_min: -1
|
55 |
+
padding_at_end: true
|
56 |
+
plot_histogram: true
|
57 |
+
plot_learning_curves: true
|
58 |
+
precision: 16-mixed
|
59 |
+
prediction_only: false
|
60 |
+
pretrained_model_name_to_load: null
|
61 |
+
profile_torch_run: false
|
62 |
+
reload_model: false
|
63 |
+
reload_model_date: null
|
64 |
+
remove_eval_idx_from_train_idx: true
|
65 |
+
remove_timm_classifier_head_pooling: true
|
66 |
+
sample_cols:
|
67 |
+
- x
|
68 |
+
- y
|
69 |
+
sample_means:
|
70 |
+
- 0.4423
|
71 |
+
- 3.1164
|
72 |
+
- 2.4717
|
73 |
+
sample_std:
|
74 |
+
- 0.2778
|
75 |
+
- 1.882
|
76 |
+
- 1.8562
|
77 |
+
sample_std_unscaled:
|
78 |
+
- 285.193
|
79 |
+
- 131.1842
|
80 |
+
- 1.8562
|
81 |
+
save_weights_only: true
|
82 |
+
set_max_seq_len_manually: true
|
83 |
+
set_num_classes_manually: true
|
84 |
+
source_for_pretrained_cv_model: timm
|
85 |
+
target_padding_number: -100
|
86 |
+
track_activations_via_hook: false
|
87 |
+
track_gradient_histogram: false
|
88 |
+
use_char_bounding_boxes: true
|
89 |
+
use_early_stopping: false
|
90 |
+
use_embedded_char_pos_info: true
|
91 |
+
use_fixation_duration_information: false
|
92 |
+
use_in_projection_bias: false
|
93 |
+
use_lr_warmup: true
|
94 |
+
use_pupil_size_information: false
|
95 |
+
use_reduce_on_plateau: false
|
96 |
+
use_start_time_as_input_col: false
|
97 |
+
use_training_steps_for_end_and_lr_decay: true
|
98 |
+
use_words_coords: false
|
99 |
+
warmup_exponent: 1
|
100 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240108-000344.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
6 |
+
char_dims: 0
|
7 |
+
char_plot_shape:
|
8 |
+
- 224
|
9 |
+
- 224
|
10 |
+
chars_bert_reduction_factor: 4
|
11 |
+
chars_conv_lr_reduction_factor: 1
|
12 |
+
chars_conv_pooling_out_dim: 1
|
13 |
+
convert_posix: false
|
14 |
+
convert_winpath: true
|
15 |
+
cv_char_modelname: coatnet_nano_rw_224
|
16 |
+
cv_modelname: null
|
17 |
+
early_stopping_patience: 15
|
18 |
+
gamma_multistep: null
|
19 |
+
gamma_step_factor: 0.5
|
20 |
+
gamma_step_size: 3000
|
21 |
+
head_multiplication_factor: 64
|
22 |
+
hidden_dim_bert: 512
|
23 |
+
hidden_dropout_prob: 0.0
|
24 |
+
im_partial_string: fixations_chars_channel_sep
|
25 |
+
input_padding_val: 10
|
26 |
+
last_activation: Identity
|
27 |
+
layer_norm_after_in_projection: true
|
28 |
+
linear_activation: GELU
|
29 |
+
load_best_checkpoint_at_end: false
|
30 |
+
loss_function: corn_loss
|
31 |
+
lr: 0.0004
|
32 |
+
lr_initial: '0.0004'
|
33 |
+
lr_sched_exp_fac: null
|
34 |
+
lr_scheduling: StepLR
|
35 |
+
manual_max_sequence_for_model: 500
|
36 |
+
max_len_chars_list: 0
|
37 |
+
max_seq_length: 500
|
38 |
+
method_chars_into_model: resnet
|
39 |
+
method_to_include_char_positions: concat
|
40 |
+
min_lr_anneal: 1e-6
|
41 |
+
model_to_use: BERT
|
42 |
+
multistep_milestones: null
|
43 |
+
n_layers_BERT: 4
|
44 |
+
norm_by_char_averages: false
|
45 |
+
norm_by_line_width: false
|
46 |
+
norm_coords_by_letter_min_x_y: true
|
47 |
+
normalize_by_line_height_and_width: false
|
48 |
+
num_attention_heads: 8
|
49 |
+
num_classes: 16
|
50 |
+
num_lin_layers: 1
|
51 |
+
num_warmup_steps: 3000
|
52 |
+
one_hot_y: false
|
53 |
+
ord_reg_loss_max: 16
|
54 |
+
ord_reg_loss_min: -1
|
55 |
+
padding_at_end: true
|
56 |
+
plot_histogram: true
|
57 |
+
plot_learning_curves: true
|
58 |
+
precision: 16-mixed
|
59 |
+
prediction_only: false
|
60 |
+
pretrained_model_name_to_load: null
|
61 |
+
profile_torch_run: false
|
62 |
+
reload_model: false
|
63 |
+
reload_model_date: null
|
64 |
+
remove_eval_idx_from_train_idx: true
|
65 |
+
remove_timm_classifier_head_pooling: true
|
66 |
+
sample_cols:
|
67 |
+
- x
|
68 |
+
- y
|
69 |
+
sample_means:
|
70 |
+
- 455.5905
|
71 |
+
- 218.0598
|
72 |
+
- 2.4717
|
73 |
+
sample_std:
|
74 |
+
- 285.1936
|
75 |
+
- 131.1842
|
76 |
+
- 1.8562
|
77 |
+
sample_std_unscaled:
|
78 |
+
- 285.1939
|
79 |
+
- 131.1844
|
80 |
+
- 1.8562
|
81 |
+
save_weights_only: true
|
82 |
+
set_max_seq_len_manually: true
|
83 |
+
set_num_classes_manually: true
|
84 |
+
source_for_pretrained_cv_model: timm
|
85 |
+
target_padding_number: -100
|
86 |
+
track_activations_via_hook: false
|
87 |
+
track_gradient_histogram: false
|
88 |
+
use_char_bounding_boxes: true
|
89 |
+
use_early_stopping: false
|
90 |
+
use_embedded_char_pos_info: true
|
91 |
+
use_fixation_duration_information: false
|
92 |
+
use_in_projection_bias: false
|
93 |
+
use_lr_warmup: true
|
94 |
+
use_pupil_size_information: false
|
95 |
+
use_reduce_on_plateau: false
|
96 |
+
use_start_time_as_input_col: false
|
97 |
+
use_training_steps_for_end_and_lr_decay: true
|
98 |
+
use_words_coords: false
|
99 |
+
warmup_exponent: 1
|
100 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240108-011230.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
6 |
+
char_dims: 0
|
7 |
+
char_plot_shape:
|
8 |
+
- 224
|
9 |
+
- 224
|
10 |
+
chars_bert_reduction_factor: 4
|
11 |
+
chars_conv_lr_reduction_factor: 1
|
12 |
+
chars_conv_pooling_out_dim: 1
|
13 |
+
convert_posix: false
|
14 |
+
convert_winpath: true
|
15 |
+
cv_char_modelname: coatnet_nano_rw_224
|
16 |
+
cv_modelname: null
|
17 |
+
early_stopping_patience: 15
|
18 |
+
gamma_multistep: null
|
19 |
+
gamma_step_factor: 0.5
|
20 |
+
gamma_step_size: 3000
|
21 |
+
head_multiplication_factor: 64
|
22 |
+
hidden_dim_bert: 512
|
23 |
+
hidden_dropout_prob: 0.0
|
24 |
+
im_partial_string: fixations_chars_channel_sep
|
25 |
+
input_padding_val: 10
|
26 |
+
last_activation: Identity
|
27 |
+
layer_norm_after_in_projection: true
|
28 |
+
linear_activation: GELU
|
29 |
+
load_best_checkpoint_at_end: false
|
30 |
+
loss_function: corn_loss
|
31 |
+
lr: 0.0004
|
32 |
+
lr_initial: '0.0004'
|
33 |
+
lr_sched_exp_fac: null
|
34 |
+
lr_scheduling: StepLR
|
35 |
+
manual_max_sequence_for_model: 500
|
36 |
+
max_len_chars_list: 0
|
37 |
+
max_seq_length: 500
|
38 |
+
method_chars_into_model: resnet
|
39 |
+
method_to_include_char_positions: concat
|
40 |
+
min_lr_anneal: 1e-6
|
41 |
+
model_to_use: BERT
|
42 |
+
multistep_milestones: null
|
43 |
+
n_layers_BERT: 4
|
44 |
+
norm_by_char_averages: false
|
45 |
+
norm_by_line_width: false
|
46 |
+
norm_coords_by_letter_min_x_y: true
|
47 |
+
normalize_by_line_height_and_width: true
|
48 |
+
num_attention_heads: 8
|
49 |
+
num_classes: 16
|
50 |
+
num_lin_layers: 1
|
51 |
+
num_warmup_steps: 3000
|
52 |
+
one_hot_y: false
|
53 |
+
ord_reg_loss_max: 16
|
54 |
+
ord_reg_loss_min: -1
|
55 |
+
padding_at_end: true
|
56 |
+
plot_histogram: true
|
57 |
+
plot_learning_curves: true
|
58 |
+
precision: 16-mixed
|
59 |
+
prediction_only: false
|
60 |
+
pretrained_model_name_to_load: null
|
61 |
+
profile_torch_run: false
|
62 |
+
reload_model: false
|
63 |
+
reload_model_date: null
|
64 |
+
remove_eval_idx_from_train_idx: true
|
65 |
+
remove_timm_classifier_head_pooling: true
|
66 |
+
sample_cols:
|
67 |
+
- x
|
68 |
+
- y
|
69 |
+
sample_means:
|
70 |
+
- 0.4423
|
71 |
+
- 3.1164
|
72 |
+
- 2.4717
|
73 |
+
sample_std:
|
74 |
+
- 0.2778
|
75 |
+
- 1.882
|
76 |
+
- 1.8562
|
77 |
+
sample_std_unscaled:
|
78 |
+
- 285.1939
|
79 |
+
- 131.1844
|
80 |
+
- 1.8562
|
81 |
+
save_weights_only: true
|
82 |
+
set_max_seq_len_manually: true
|
83 |
+
set_num_classes_manually: true
|
84 |
+
source_for_pretrained_cv_model: timm
|
85 |
+
target_padding_number: -100
|
86 |
+
track_activations_via_hook: false
|
87 |
+
track_gradient_histogram: false
|
88 |
+
use_char_bounding_boxes: true
|
89 |
+
use_early_stopping: false
|
90 |
+
use_embedded_char_pos_info: true
|
91 |
+
use_fixation_duration_information: false
|
92 |
+
use_in_projection_bias: false
|
93 |
+
use_lr_warmup: true
|
94 |
+
use_pupil_size_information: false
|
95 |
+
use_reduce_on_plateau: false
|
96 |
+
use_start_time_as_input_col: false
|
97 |
+
use_training_steps_for_end_and_lr_decay: true
|
98 |
+
use_words_coords: false
|
99 |
+
warmup_exponent: 1
|
100 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240109-090419.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
6 |
+
char_dims: 0
|
7 |
+
char_plot_shape:
|
8 |
+
- 224
|
9 |
+
- 224
|
10 |
+
chars_bert_reduction_factor: 4
|
11 |
+
chars_conv_lr_reduction_factor: 1
|
12 |
+
chars_conv_pooling_out_dim: 1
|
13 |
+
convert_posix: false
|
14 |
+
convert_winpath: true
|
15 |
+
cv_char_modelname: coatnet_nano_rw_224
|
16 |
+
cv_modelname: null
|
17 |
+
early_stopping_patience: 15
|
18 |
+
gamma_multistep: null
|
19 |
+
gamma_step_factor: 0.5
|
20 |
+
gamma_step_size: 3000
|
21 |
+
head_multiplication_factor: 64
|
22 |
+
hidden_dim_bert: 512
|
23 |
+
hidden_dropout_prob: 0.0
|
24 |
+
im_partial_string: fixations_chars_channel_sep
|
25 |
+
input_padding_val: 10
|
26 |
+
last_activation: Identity
|
27 |
+
layer_norm_after_in_projection: true
|
28 |
+
linear_activation: GELU
|
29 |
+
load_best_checkpoint_at_end: false
|
30 |
+
loss_function: corn_loss
|
31 |
+
lr: 0.0004
|
32 |
+
lr_initial: '0.0004'
|
33 |
+
lr_sched_exp_fac: null
|
34 |
+
lr_scheduling: StepLR
|
35 |
+
manual_max_sequence_for_model: 500
|
36 |
+
max_len_chars_list: 0
|
37 |
+
max_seq_length: 500
|
38 |
+
method_chars_into_model: resnet
|
39 |
+
method_to_include_char_positions: concat
|
40 |
+
min_lr_anneal: 1e-6
|
41 |
+
model_to_use: BERT
|
42 |
+
multistep_milestones: null
|
43 |
+
n_layers_BERT: 4
|
44 |
+
norm_by_char_averages: false
|
45 |
+
norm_by_line_width: false
|
46 |
+
norm_coords_by_letter_min_x_y: true
|
47 |
+
normalize_by_line_height_and_width: false
|
48 |
+
num_attention_heads: 8
|
49 |
+
num_classes: 16
|
50 |
+
num_lin_layers: 1
|
51 |
+
num_warmup_steps: 3000
|
52 |
+
one_hot_y: false
|
53 |
+
ord_reg_loss_max: 16
|
54 |
+
ord_reg_loss_min: -1
|
55 |
+
padding_at_end: true
|
56 |
+
plot_histogram: true
|
57 |
+
plot_learning_curves: true
|
58 |
+
precision: 16-mixed
|
59 |
+
prediction_only: false
|
60 |
+
pretrained_model_name_to_load: null
|
61 |
+
profile_torch_run: false
|
62 |
+
reload_model: false
|
63 |
+
reload_model_date: null
|
64 |
+
remove_eval_idx_from_train_idx: true
|
65 |
+
remove_timm_classifier_head_pooling: true
|
66 |
+
sample_cols:
|
67 |
+
- x
|
68 |
+
- y
|
69 |
+
sample_means:
|
70 |
+
- 455.708
|
71 |
+
- 217.8342
|
72 |
+
- 2.4706
|
73 |
+
sample_std:
|
74 |
+
- 285.2534
|
75 |
+
- 131.0263
|
76 |
+
- 1.8542
|
77 |
+
sample_std_unscaled:
|
78 |
+
- 285.2527
|
79 |
+
- 131.0262
|
80 |
+
- 1.8543
|
81 |
+
save_weights_only: true
|
82 |
+
set_max_seq_len_manually: true
|
83 |
+
set_num_classes_manually: true
|
84 |
+
source_for_pretrained_cv_model: timm
|
85 |
+
target_padding_number: -100
|
86 |
+
track_activations_via_hook: false
|
87 |
+
track_gradient_histogram: false
|
88 |
+
use_char_bounding_boxes: true
|
89 |
+
use_early_stopping: false
|
90 |
+
use_embedded_char_pos_info: true
|
91 |
+
use_fixation_duration_information: false
|
92 |
+
use_in_projection_bias: false
|
93 |
+
use_lr_warmup: true
|
94 |
+
use_pupil_size_information: false
|
95 |
+
use_reduce_on_plateau: false
|
96 |
+
use_start_time_as_input_col: false
|
97 |
+
use_training_steps_for_end_and_lr_decay: true
|
98 |
+
use_words_coords: false
|
99 |
+
warmup_exponent: 1
|
100 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240122-183729.yaml
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
add_woc_feature: false
|
6 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
7 |
+
char_dims: 0
|
8 |
+
char_plot_shape:
|
9 |
+
- 224
|
10 |
+
- 224
|
11 |
+
chars_bert_reduction_factor: 4
|
12 |
+
chars_conv_lr_reduction_factor: 1
|
13 |
+
chars_conv_pooling_out_dim: 1
|
14 |
+
convert_posix: false
|
15 |
+
convert_winpath: false
|
16 |
+
cv_char_modelname: coatnet_nano_rw_224
|
17 |
+
cv_modelname: null
|
18 |
+
early_stopping_patience: 15
|
19 |
+
gamma_multistep: null
|
20 |
+
gamma_step_factor: 0.5
|
21 |
+
gamma_step_size: 3000
|
22 |
+
head_multiplication_factor: 64
|
23 |
+
hidden_dim_bert: 512
|
24 |
+
hidden_dropout_prob: 0.0
|
25 |
+
im_partial_string: fixations_chars_channel_sep
|
26 |
+
input_padding_val: 10
|
27 |
+
last_activation: Identity
|
28 |
+
layer_norm_after_in_projection: true
|
29 |
+
linear_activation: GELU
|
30 |
+
load_best_checkpoint_at_end: false
|
31 |
+
loss_function: corn_loss
|
32 |
+
lr: 0.0004
|
33 |
+
lr_initial: '0.0004'
|
34 |
+
lr_sched_exp_fac: null
|
35 |
+
lr_scheduling: StepLR
|
36 |
+
manual_max_sequence_for_model: 500
|
37 |
+
max_len_chars_list: 0
|
38 |
+
max_seq_length: 500
|
39 |
+
method_chars_into_model: resnet
|
40 |
+
method_to_include_char_positions: concat
|
41 |
+
min_lr_anneal: 1e-6
|
42 |
+
model_to_use: BERT
|
43 |
+
multistep_milestones: null
|
44 |
+
n_layers_BERT: 4
|
45 |
+
norm_by_char_averages: false
|
46 |
+
norm_by_line_width: false
|
47 |
+
norm_coords_by_letter_min_x_y: true
|
48 |
+
normalize_by_line_height_and_width: true
|
49 |
+
num_attention_heads: 8
|
50 |
+
num_classes: 16
|
51 |
+
num_lin_layers: 1
|
52 |
+
num_warmup_steps: 3000
|
53 |
+
one_hot_y: false
|
54 |
+
only_use_2nd_input_stream: false
|
55 |
+
ord_reg_loss_max: 16
|
56 |
+
ord_reg_loss_min: -1
|
57 |
+
padding_at_end: true
|
58 |
+
plot_histogram: true
|
59 |
+
plot_learning_curves: true
|
60 |
+
precision: 16-mixed
|
61 |
+
prediction_only: false
|
62 |
+
pretrained_model_name_to_load: null
|
63 |
+
profile_torch_run: false
|
64 |
+
reload_model: false
|
65 |
+
reload_model_date: null
|
66 |
+
remove_eval_idx_from_train_idx: true
|
67 |
+
remove_timm_classifier_head_pooling: true
|
68 |
+
sample_cols:
|
69 |
+
- x
|
70 |
+
- y
|
71 |
+
sample_means:
|
72 |
+
- 0.4433
|
73 |
+
- 2.9599
|
74 |
+
- 2.3264
|
75 |
+
sample_std:
|
76 |
+
- 0.2782
|
77 |
+
- 1.7872
|
78 |
+
- 1.7619
|
79 |
+
sample_std_unscaled:
|
80 |
+
- 287.0107
|
81 |
+
- 124.4113
|
82 |
+
- 1.7619
|
83 |
+
save_weights_only: true
|
84 |
+
set_max_seq_len_manually: true
|
85 |
+
set_num_classes_manually: true
|
86 |
+
source_for_pretrained_cv_model: timm
|
87 |
+
target_padding_number: -100
|
88 |
+
track_activations_via_hook: false
|
89 |
+
track_gradient_histogram: false
|
90 |
+
use_char_bounding_boxes: true
|
91 |
+
use_early_stopping: false
|
92 |
+
use_embedded_char_pos_info: true
|
93 |
+
use_fixation_duration_information: false
|
94 |
+
use_in_projection_bias: false
|
95 |
+
use_lr_warmup: true
|
96 |
+
use_pupil_size_information: false
|
97 |
+
use_reduce_on_plateau: false
|
98 |
+
use_start_time_as_input_col: false
|
99 |
+
use_training_steps_for_end_and_lr_decay: true
|
100 |
+
use_words_coords: false
|
101 |
+
warmup_exponent: 1
|
102 |
+
weight_decay: 0.0
|
models/BERT_fin_exp_20240122-194041.yaml
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_layer_norm_to_char_mlp: true
|
2 |
+
add_layer_norm_to_in_projection: false
|
3 |
+
add_line_overlap_feature: true
|
4 |
+
add_normalised_values_as_features: false
|
5 |
+
add_woc_feature: false
|
6 |
+
change_pooling_for_timm_head_to: AdaptiveAvgPool2d
|
7 |
+
char_dims: 0
|
8 |
+
char_plot_shape:
|
9 |
+
- 224
|
10 |
+
- 224
|
11 |
+
chars_bert_reduction_factor: 4
|
12 |
+
chars_conv_lr_reduction_factor: 1
|
13 |
+
chars_conv_pooling_out_dim: 1
|
14 |
+
convert_posix: false
|
15 |
+
convert_winpath: false
|
16 |
+
cv_char_modelname: coatnet_nano_rw_224
|
17 |
+
cv_modelname: null
|
18 |
+
early_stopping_patience: 15
|
19 |
+
gamma_multistep: null
|
20 |
+
gamma_step_factor: 0.5
|
21 |
+
gamma_step_size: 3000
|
22 |
+
head_multiplication_factor: 64
|
23 |
+
hidden_dim_bert: 512
|
24 |
+
hidden_dropout_prob: 0.0
|
25 |
+
im_partial_string: fixations_chars_channel_sep
|
26 |
+
input_padding_val: 10
|
27 |
+
last_activation: Identity
|
28 |
+
layer_norm_after_in_projection: true
|
29 |
+
linear_activation: GELU
|
30 |
+
load_best_checkpoint_at_end: false
|
31 |
+
loss_function: corn_loss
|
32 |
+
lr: 0.0004
|
33 |
+
lr_initial: '0.0004'
|
34 |
+
lr_sched_exp_fac: null
|
35 |
+
lr_scheduling: StepLR
|
36 |
+
manual_max_sequence_for_model: 500
|
37 |
+
max_len_chars_list: 0
|
38 |
+
max_seq_length: 500
|
39 |
+
method_chars_into_model: resnet
|
40 |
+
method_to_include_char_positions: concat
|
41 |
+
min_lr_anneal: 1e-6
|
42 |
+
model_to_use: BERT
|
43 |
+
multistep_milestones: null
|
44 |
+
n_layers_BERT: 4
|
45 |
+
norm_by_char_averages: false
|
46 |
+
norm_by_line_width: false
|
47 |
+
norm_coords_by_letter_min_x_y: true
|
48 |
+
normalize_by_line_height_and_width: false
|
49 |
+
num_attention_heads: 8
|
50 |
+
num_classes: 16
|
51 |
+
num_lin_layers: 1
|
52 |
+
num_warmup_steps: 3000
|
53 |
+
one_hot_y: false
|
54 |
+
only_use_2nd_input_stream: false
|
55 |
+
ord_reg_loss_max: 16
|
56 |
+
ord_reg_loss_min: -1
|
57 |
+
padding_at_end: true
|
58 |
+
plot_histogram: true
|
59 |
+
plot_learning_curves: true
|
60 |
+
precision: 16-mixed
|
61 |
+
prediction_only: false
|
62 |
+
pretrained_model_name_to_load: null
|
63 |
+
profile_torch_run: false
|
64 |
+
reload_model: false
|
65 |
+
reload_model_date: null
|
66 |
+
remove_eval_idx_from_train_idx: true
|
67 |
+
remove_timm_classifier_head_pooling: true
|
68 |
+
sample_cols:
|
69 |
+
- x
|
70 |
+
- y
|
71 |
+
sample_means:
|
72 |
+
- 459.3367
|
73 |
+
- 206.88
|
74 |
+
- 2.3264
|
75 |
+
sample_std:
|
76 |
+
- 287.0111
|
77 |
+
- 124.4113
|
78 |
+
- 1.7619
|
79 |
+
sample_std_unscaled:
|
80 |
+
- 287.0107
|
81 |
+
- 124.4113
|
82 |
+
- 1.7619
|
83 |
+
save_weights_only: true
|
84 |
+
set_max_seq_len_manually: true
|
85 |
+
set_num_classes_manually: true
|
86 |
+
source_for_pretrained_cv_model: timm
|
87 |
+
target_padding_number: -100
|
88 |
+
track_activations_via_hook: false
|
89 |
+
track_gradient_histogram: false
|
90 |
+
use_char_bounding_boxes: true
|
91 |
+
use_early_stopping: false
|
92 |
+
use_embedded_char_pos_info: true
|
93 |
+
use_fixation_duration_information: false
|
94 |
+
use_in_projection_bias: false
|
95 |
+
use_lr_warmup: true
|
96 |
+
use_pupil_size_information: false
|
97 |
+
use_reduce_on_plateau: false
|
98 |
+
use_start_time_as_input_col: false
|
99 |
+
use_training_steps_for_end_and_lr_decay: true
|
100 |
+
use_words_coords: false
|
101 |
+
warmup_exponent: 1
|
102 |
+
weight_decay: 0.0
|
requirements.txt
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
einops
|
3 |
+
matplotlib
|
4 |
+
numpy
|
5 |
+
pandas
|
6 |
+
PyYAML
|
7 |
+
seaborn
|
8 |
+
tqdm
|
9 |
+
transformers==4.30.2
|
10 |
+
tensorboard
|
11 |
+
torchmetrics
|
12 |
+
pytorch-lightning
|
13 |
+
scikit-learn
|
14 |
+
plotly
|
15 |
+
lovely-tensors
|
16 |
+
timm
|
17 |
+
openpyxl
|
18 |
+
torch==2.*
|
19 |
+
pydantic==1.10
|
20 |
+
streamlit
|
21 |
+
pycairo
|
22 |
+
eyekit
|
23 |
+
stqdm
|
24 |
+
jellyfish
|
25 |
+
icecream
|
run_in_notebook.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,2016 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import zipfile
|
2 |
+
import os
|
3 |
+
import plotly.express as px
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from torch.utils.data.dataloader import DataLoader as dl
|
6 |
+
import yaml
|
7 |
+
from io import StringIO
|
8 |
+
import torch as t
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
from torch.utils.data import Dataset as torch_dset
|
12 |
+
from PIL import Image
|
13 |
+
import torchvision.transforms.functional as tvfunc
|
14 |
+
import json
|
15 |
+
from matplotlib import pyplot as plt
|
16 |
+
import matplotlib.patches as patches
|
17 |
+
from matplotlib.font_manager import FontProperties
|
18 |
+
import pathlib as pl
|
19 |
+
import matplotlib as mpl
|
20 |
+
import streamlit as st
|
21 |
+
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
22 |
+
import einops as eo
|
23 |
+
import copy
|
24 |
+
|
25 |
+
# import stqdm
|
26 |
+
from tqdm.auto import tqdm
|
27 |
+
import time
|
28 |
+
import requests
|
29 |
+
|
30 |
+
from matplotlib.patches import Rectangle
|
31 |
+
from matplotlib import font_manager
|
32 |
+
from models import LitModel, EnsembleModel
|
33 |
+
from loss_functions import corn_label_from_logits
|
34 |
+
import classic_correction_algos as calgo
|
35 |
+
import analysis_funcs as anf
|
36 |
+
|
37 |
+
# --- Output / asset folder layout -----------------------------------------
TEMP_FOLDER = pl.Path("results")  # correction results are written here
AVAILABLE_FONTS = [x.name for x in font_manager.fontManager.ttflist]
PLOTS_FOLDER = pl.Path("plots")
# Scratch file for the matplotlib-rendered stimulus image.
TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER / "temp_matplotlib_plot_stimulus.png"
# NOTE(review): duplicates AVAILABLE_FONTS above — presumably kept for an
# older call site; confirm before removing either.
all_fonts = [x.name for x in font_manager.fontManager.ttflist]
mpl.use("agg")  # headless backend: figures are saved to files, never shown

DIST_MODELS_FOLDER = pl.Path("models")  # pretrained DIST model checkpoints
# ImageNet channel statistics used to normalize images fed to timm models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
gradio_plots = pl.Path("plots")  # same folder as PLOTS_FOLDER

# Event marker strings expected in eye-tracker recording files
# (assumes EyeLink-style ASC events — TODO confirm against the parser).
event_strs = [
    "EFIX",
    "EFIX R",
    "EFIX L",
    "SSACC",
    "ESACC",
    "SFIX",
    "MSG",
    "SBLINK",
    "EBLINK",
    "BUTTON",
    "INPUT",
    "END",
    "START",
    "DISPLAY ON",
]
# Human-readable description and line pattern for each event marker.
names_dict = {
    "SSACC": {"Descr": "Start of Saccade", "Pattern": "SSACC <eye > <stime>"},
    "ESACC": {
        "Descr": "End of Saccade",
        "Pattern": "ESACC <eye > <stime> <etime > <dur> <sxp > <syp> <exp > <eyp> <ampl > <pv >",
    },
    "SFIX": {"Descr": "Start of Fixation", "Pattern": "SFIX <eye > <stime>"},
    "EFIX": {"Descr": "End of Fixation", "Pattern": "EFIX <eye > <stime> <etime > <dur> <axp > <ayp> <aps >"},
    "SBLINK": {"Descr": "Start of Blink", "Pattern": "SBLINK <eye > <stime>"},
    "EBLINK": {"Descr": "End of Blink", "Pattern": "EBLINK <eye > <stime> <etime > <dur>"},
    "DISPLAY ON": {"Descr": "Actual start of Trial", "Pattern": "DISPLAY ON"},
}
# Metadata lines of interest in the recording header.
metadata_strs = ["DISPLAY COORDS", "GAZE_COORDS", "FRAMERATE"]

# Fixation-correction algorithms offered in the UI. Also mirrored into the
# Streamlit session state so widgets elsewhere can read the same list.
ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [
    "warp",
    "regress",
    "compare",
    "attach",
    "segment",
    "split",
    "stretch",
    "chain",
    "slice",
    "cluster",
    "merge",
    "Wisdom_of_Crowds",
    "DIST",
    "DIST-Ensemble",
    "Wisdom_of_Crowds_with_DIST",
    "Wisdom_of_Crowds_with_DIST_Ensemble",
]
# Qualitative color cycle used when plotting per-algorithm results.
COLORS = px.colors.qualitative.Alphabet
|
98 |
+
|
99 |
+
|
100 |
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder for objects the stdlib encoder rejects.

    Serializes:
      * ``np.ndarray``        -> nested Python lists (via ``tolist``)
      * ``np.generic`` scalar -> the equivalent native Python scalar (via ``item``)
      * ``pathlib.Path`` / Streamlit ``UploadedFile`` -> ``str``

    Adapted from
    https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
    """

    def default(self, obj):
        """Return a JSON-serializable form of *obj*, or defer to the base class."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalar types (np.int64, np.float32, ...) are not JSON
            # serializable either; unwrap them to plain Python scalars.
            # Previously these fell through and raised TypeError.
            return obj.item()
        # Short-circuit `or` keeps UploadedFile lazily evaluated.
        if isinstance(obj, pl.Path) or isinstance(obj, UploadedFile):
            return str(obj)
        # Base-class default() raises TypeError for anything else.
        return json.JSONEncoder.default(self, obj)
|
109 |
+
|
110 |
+
|
111 |
+
class DSet(torch_dset):
    """Torch dataset pairing fixation sequences with character coordinates,
    optional attention masks, optional stimulus images, and target categories.

    The tuple returned by ``__getitem__`` varies with the configuration flags;
    see that method's docstring for the exact layouts.
    """

    def __init__(
        self,
        in_sequence: t.Tensor,
        chars_center_coords_padded: t.Tensor,
        out_categories: t.Tensor,
        trialslist: list,
        padding_list: list = None,
        padding_at_end: bool = False,
        return_images_for_conv: bool = False,
        im_partial_string: str = "fixations_chars_channel_sep",
        input_im_shape=[224, 224],
    ) -> None:
        # NOTE(review): input_im_shape has a mutable default list; it is only
        # ever read below, so sharing across instances appears harmless —
        # confirm no caller mutates it.
        super().__init__()

        self.in_sequence = in_sequence
        self.chars_center_coords_padded = chars_center_coords_padded
        self.out_categories = out_categories
        self.padding_list = padding_list
        self.padding_at_end = padding_at_end
        self.trialslist = trialslist
        self.return_images_for_conv = return_images_for_conv
        self.input_im_shape = input_im_shape
        if return_images_for_conv:
            self.im_partial_string = im_partial_string
            # Derive per-trial image paths by swapping the filename token
            # "fixations_words" for the requested image variant.
            self.plot_files = [
                str(x["plot_file"]).replace("fixations_words", im_partial_string) for x in self.trialslist
            ]

    def __getitem__(self, index):
        """Return one sample.

        Tuple layout depends on configuration:
        - with chars_center_coords_padded and padding_list:
          (sequence, char_coords[, image], attention_mask, target)
        - with chars_center_coords_padded only:
          (sequence, char_coords[, image], target)
        - with padding_list only: (sequence[, image], attention_mask, target)
        - otherwise: (sequence[, image], target)
        The attention mask zeroes padded positions, at the end or the start of
        the sequence depending on ``padding_at_end``.
        """

        if self.return_images_for_conv:
            im = Image.open(self.plot_files[index])
            # PIL reports size as (width, height); input_im_shape is (H, W).
            if [im.size[1], im.size[0]] != self.input_im_shape:
                im = tvfunc.resize(im, self.input_im_shape)
            im = tvfunc.normalize(tvfunc.to_tensor(im), IMAGENET_MEAN, IMAGENET_STD)
        if self.chars_center_coords_padded is not None:
            if self.padding_list is not None:
                # Mask over the sequence dimension: 1 = real token, 0 = padding.
                attention_mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
                if self.padding_at_end:
                    # Guard: a zero pad count with `[-0:]` would zero everything.
                    if self.padding_list[index] > 0:
                        attention_mask[-self.padding_list[index] :] = 0
                else:
                    attention_mask[: self.padding_list[index]] = 0
                if self.return_images_for_conv:
                    return (
                        self.in_sequence[index],
                        self.chars_center_coords_padded[index],
                        im,
                        attention_mask,
                        self.out_categories[index],
                    )
                return (
                    self.in_sequence[index],
                    self.chars_center_coords_padded[index],
                    attention_mask,
                    self.out_categories[index],
                )
            else:
                if self.return_images_for_conv:
                    return (
                        self.in_sequence[index],
                        self.chars_center_coords_padded[index],
                        im,
                        self.out_categories[index],
                    )
                else:
                    return (self.in_sequence[index], self.chars_center_coords_padded[index], self.out_categories[index])

        # No character coordinates available: same mask logic, shorter tuples.
        if self.padding_list is not None:
            attention_mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
            if self.padding_at_end:
                if self.padding_list[index] > 0:
                    attention_mask[-self.padding_list[index] :] = 0
            else:
                attention_mask[: self.padding_list[index]] = 0
            if self.return_images_for_conv:
                return (self.in_sequence[index], im, attention_mask, self.out_categories[index])
            else:
                return (self.in_sequence[index], attention_mask, self.out_categories[index])
        if self.return_images_for_conv:
            return (self.in_sequence[index], im, self.out_categories[index])
        else:
            return (self.in_sequence[index], self.out_categories[index])

    def __len__(self):
        """Number of samples: first tensor dimension, or list length."""
        if isinstance(self.in_sequence, t.Tensor):
            return self.in_sequence.shape[0]
        else:
            return len(self.in_sequence)
|
201 |
+
|
202 |
+
|
203 |
+
def download_url(url, target_filename):
    """Download *url* and write the response body to *target_filename*.

    Returns 0 (legacy success code, kept for backward compatibility).
    """
    r = requests.get(url)
    # Context manager guarantees the handle is closed even if the write fails;
    # the original `open(...).write(...)` left it dangling until GC.
    with open(target_filename, "wb") as f:
        f.write(r.content)
    return 0
|
207 |
+
|
208 |
+
|
209 |
+
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    """Parse an .asc file and key its paragraph trials by trial id.

    Returns (trials_by_ids, lines): a mapping of trial_id -> trial dict and
    the raw list of file lines.
    """
    session = st.session_state
    if "logger" in session:
        session["logger"].debug("asc_to_trial_ids entered")

    # ISO-8859-15 is the encoding EyeLink .asc exports typically use.
    trials_dict, lines = file_to_trials_and_lines(
        asc_file, "ISO-8859-15", close_gap_between_words=close_gap_between_words
    )

    trials_by_ids = {}
    for trial_idx in trials_dict["paragraph_trials"]:
        trial = trials_dict[trial_idx]
        trials_by_ids[trial["trial_id"]] = trial

    if hasattr(asc_file, "name") and "logger" in session:
        session["logger"].info(f"Found {len(trials_by_ids)} trials in {asc_file.name}.")
    return trials_by_ids, lines
|
222 |
+
|
223 |
+
|
224 |
+
def get_trials_list(asc_file=None, close_gap_between_words=True):
    """Resolve the .asc file (argument or session state) and parse its trials.

    Returns (trial_keys, trials_by_ids, lines, asc_file), or None when no
    file is available.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("get_trials_list entered")

    # Fix: identity comparison with None (was `== None`).
    if asc_file is None:
        # Fall back to the file stored by the single-file upload widget.
        if st.session_state.get("single_asc_file") is not None:
            asc_file = st.session_state["single_asc_file"]
        else:
            if "logger" in st.session_state:
                st.session_state["logger"].warning("Asc file is None")
            return None

    if hasattr(asc_file, "name") and "logger" in st.session_state:
        st.session_state["logger"].info(f"get_trials_list entered with asc_file {asc_file.name}")

    trials_by_ids, lines = asc_to_trial_ids(asc_file, close_gap_between_words=close_gap_between_words)
    trial_keys = list(trials_by_ids.keys())

    return trial_keys, trials_by_ids, lines, asc_file
|
244 |
+
|
245 |
+
|
246 |
+
def save_trial_to_json(trial, savename):
    """Serialize *trial* to a UTF-8 JSON file at *savename*.

    Note: mutates *trial* in place by removing the non-serializable "dffix"
    dataframe entry when present.
    """
    trial.pop("dffix", None)
    with open(savename, "w", encoding="utf-8") as f:
        json.dump(trial, f, ensure_ascii=False, indent=4, cls=NumpyEncoder)
|
251 |
+
|
252 |
+
|
253 |
+
def export_csv(dffix, trial):
    """Write the fixation dataframe to CSV and the trial metadata to JSON.

    Accepts either the dataframe itself or a {"value": df} wrapper dict.
    Returns (csv_name, trial_name), the two paths written.
    """
    if isinstance(dffix, dict):
        dffix = dffix["value"]

    stem = TEMP_FOLDER.joinpath(pl.Path(trial["fname"]).stem)
    trial_id = trial["trial_id"]
    csv_name = f"{stem}_{trial_id}.csv"
    trial_name = f"{stem}_{trial_id}_trial_info.json"

    logger = st.session_state["logger"] if "logger" in st.session_state else None

    dffix.to_csv(csv_name)
    if logger is not None:
        logger.info(f"Saved processed data as {csv_name}")

    save_trial_to_json(trial, trial_name)
    if logger is not None:
        logger.info(f"Saved processed trial data as {trial_name}")

    return csv_name, trial_name
|
268 |
+
|
269 |
+
|
270 |
+
def get_all_classic_preds(dffix, trial, classic_algos_cfg):
    """Run every configured classic correction algorithm on *dffix*.

    Returns the augmented dataframe (one y_<algo> column per algorithm) and
    the list of corrected-y arrays in configuration order. The config is
    deep-copied so the algorithms cannot mutate the shared configuration.
    """
    corrected_ys = []
    cfgs = copy.deepcopy(classic_algos_cfg)
    for algo_name, params in cfgs.items():
        dffix = calgo.apply_classic_algo(dffix, trial, algo_name, params)
        corrected_ys.append(np.asarray(dffix.loc[:, f"y_{algo_name}"]))
    return dffix, corrected_ys
|
276 |
+
|
277 |
+
|
278 |
+
def apply_woc(dffix, trial, corrections, algo_choice):
    """Apply the wisdom-of-the-crowd vote over *corrections* to the fixations.

    Adds y_<algo>, y_<algo>_correction (rounded delta vs. raw y) and
    line_num_y_<algo> columns to *dffix* and returns it.
    """
    y_col = f"y_{algo_choice}"
    voted_y = calgo.wisdom_of_the_crowd(corrections)

    dffix.loc[:, y_col] = voted_y
    dffix[f"{y_col}_correction"] = (dffix.loc[:, y_col] - dffix.loc[:, "y"]).round(1)

    # Map each voted y coordinate back to its line index within the trial.
    line_lookup = trial["y_char_unique"]
    dffix.loc[:, f"line_num_{y_col}"] = [line_lookup.index(y) for y in voted_y]
    return dffix
|
286 |
+
|
287 |
+
|
288 |
+
def calc_xdiff_ydiff(line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False):
    """Estimate the character x-spacing and line y-spacing of the stimulus.

    Parameters
    ----------
    line_xcoords_no_pad, line_ycoords_no_pad : 1-D sequences of unique
        character-center x/y coordinates.
    line_heights : per-character heights; line_heights[0] is used as the
        y-spacing when the text is single-line.
    allow_multiple_values : when True, return the full array of distinct
        spacings instead of collapsing to the minimum.

    Returns (x_diff, y_diff).
    """
    x_diffs = np.unique(np.diff(line_xcoords_no_pad))
    if len(x_diffs) == 1:
        x_diff = x_diffs[0]
    elif len(x_diffs) == 0:
        # Single x coordinate: no spacing measurable. Mirrors the y handling
        # below; previously np.min([]) raised here.
        x_diff = 0
    elif not allow_multiple_values:
        x_diff = np.min(x_diffs)
    else:
        x_diff = x_diffs

    # Single-line text: fall back to the first character height for y.
    if np.unique(line_ycoords_no_pad).shape[0] == 1:
        return x_diff, line_heights[0]

    y_diffs = np.unique(np.diff(line_ycoords_no_pad))
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    elif len(y_diffs) == 0:
        y_diff = 0
    elif not allow_multiple_values:
        y_diff = np.min(y_diffs)
    else:
        y_diff = y_diffs
    return x_diff, y_diff
|
309 |
+
|
310 |
+
|
311 |
+
def add_words(trial, close_gap_between_words=True):
    """Reconstruct word bounding boxes from the trial's per-character boxes.

    Walks trial["chars_list"] in reading order, closing the current word when
    a separator character, a leftwards x jump (visual line wrap), or the end
    of the list is reached. Optionally widens neighbouring same-line words so
    each absorbs half of the inter-word gap.

    Returns a list of word dicts (word text, bounding box, center,
    assigned_line). Assumes chars_list is non-empty and in reading order —
    see the NOTE(review) items below.
    """
    chars_list_reconstructed = []
    words_list = []
    word_start_idx = 0
    chars_df = pd.DataFrame(trial["chars_list"])
    chars_df["char_width"] = chars_df.char_xmax - chars_df.char_xmin
    # NOTE(review): space_width is computed but never read below — confirm
    # whether it is still needed or is leftover.
    space_width = chars_df.loc[chars_df["char"] == " ", "char_width"].mean()

    for idx, char_dict in enumerate(trial["chars_list"]):
        on_line_num = char_dict["assigned_line"]  # NOTE(review): unused.
        chars_list_reconstructed.append(char_dict)
        # A word ends on whitespace/punctuation, on a leftwards jump in x
        # (start of a new visual line), or at the final character.
        if (
            char_dict["char"] in [" ", ",", ";", ".", ":"]
            or (
                len(chars_list_reconstructed) > 2
                and (chars_list_reconstructed[-1]["char_xmin"] < chars_list_reconstructed[-2]["char_xmin"])
            )
            or len(chars_list_reconstructed) == len(trial["chars_list"])
        ):
            triggered = True  # NOTE(review): `triggered` is never read.
            # Box spans from the word's first character to the second-to-last
            # appended character (the terminating char is excluded).
            word_xmin = chars_list_reconstructed[word_start_idx]["char_xmin"]
            word_xmax = chars_list_reconstructed[-2]["char_xmax"]
            # Vertical extent is taken from the word's first character only.
            word_ymin = chars_list_reconstructed[word_start_idx]["char_ymin"]
            word_ymax = chars_list_reconstructed[word_start_idx]["char_ymax"]
            word_x_center = (word_xmax - word_xmin) / 2 + word_xmin
            word_y_center = (word_ymax - word_ymin) / 2 + word_ymin
            # The terminating character is excluded from the word text too.
            word = "".join(
                [
                    chars_list_reconstructed[idx]["char"]
                    for idx in range(word_start_idx, len(chars_list_reconstructed) - 1)
                ]
            )
            assigned_line = chars_list_reconstructed[word_start_idx]["assigned_line"]

            word_dict = dict(
                word=word,
                word_xmin=word_xmin,
                word_xmax=word_xmax,
                word_ymin=word_ymin,
                word_ymax=word_ymax,
                word_x_center=word_x_center,
                word_y_center=word_y_center,
                assigned_line=assigned_line,
            )
            # Non-space terminators (punctuation / wrap char) begin the next
            # word at the current character; spaces are skipped entirely.
            if char_dict["char"] != " ":
                word_start_idx = idx
            else:
                word_start_idx = idx + 1
            words_list.append(word_dict)
        else:
            triggered = False

    # Post-loop: if the final character was not folded into the last word
    # (the join above stops one short), append it as its own one-char word.
    # NOTE(review): word_dict / char_dict / assigned_line are unbound when
    # chars_list is empty or no word was ever closed — confirm callers never
    # pass such trials.
    last_letter_in_word = word_dict["word"][-1]
    last_letter_in_chars_list_reconstructed = char_dict["char"]
    if last_letter_in_word != last_letter_in_chars_list_reconstructed:
        word_dict = dict(
            word=char_dict["char"],
            word_xmin=char_dict["char_xmin"],
            word_xmax=char_dict["char_xmax"],
            word_ymin=char_dict["char_ymin"],
            word_ymax=char_dict["char_ymax"],
            word_x_center=char_dict["char_x_center"],
            word_y_center=char_dict["char_y_center"],
            assigned_line=assigned_line,
        )
        words_list.append(word_dict)

    if close_gap_between_words:
        # Stretch adjacent same-line words toward each other so each absorbs
        # half of the inter-word gap (useful for fixation-to-word mapping).
        for widx in range(1, len(words_list)):
            if words_list[widx]["assigned_line"] == words_list[widx - 1]["assigned_line"]:
                word_sep_half_width = (words_list[widx]["word_xmin"] - words_list[widx - 1]["word_xmax"]) / 2
                words_list[widx - 1]["word_xmax"] = words_list[widx - 1]["word_xmax"] + word_sep_half_width
                words_list[widx]["word_xmin"] = words_list[widx]["word_xmin"] - word_sep_half_width

    return words_list
|
385 |
+
|
386 |
+
|
387 |
+
def asc_lines_to_trials_by_trail_id(
    lines: list, paragraph_trials_only=False, fname: str = "", close_gap_between_words=True
) -> dict:
    """Split the lines of an .asc file into per-trial dicts keyed by index.

    First pass: scan all lines for TRIALID / TRIAL_RESULT markers, display
    metadata and aborted/repeated trials. Second pass: for each kept trial,
    extract timing markers and REGION CHAR character boxes, assign characters
    to lines, and reconstruct word boxes via add_words().

    Returns trials_dict with integer trial indices plus bookkeeping keys
    ("paragraph_trials", "display_coords", "fps", "max_trial_idx", ...).
    """
    if hasattr(fname, "name"):
        fname = fname.name
    # -999 marks "not seen yet"; the isinstance(display_coords, int) checks
    # below rely on this sentinel to take only the first coords line.
    fps = -999
    display_coords = -999
    trials_dict = dict(paragraph_trials=[], paragraph_trial_IDs=[])
    trial_idx = -1
    removed_trial_ids = []

    # --- Pass 1: locate trials and file-level metadata -------------------
    for idx, l in enumerate(lines):
        parts = l.strip().split(" ")
        if "TRIALID" in l:
            trial_id = parts[-1]
            trial_idx += 1
            # Trial id prefix encodes the trial kind: F = question,
            # P = practice, anything else = paragraph reading.
            if trial_id[0] == "F":
                trial_is = "question"
            elif trial_id[0] == "P":
                trial_is = "practice"
            else:
                trial_is = "paragraph"
                trials_dict["paragraph_trials"].append(trial_idx)
                trials_dict["paragraph_trial_IDs"].append(trial_id)
            trials_dict[trial_idx] = dict(trial_id=trial_id, trial_id_idx=idx, trial_is=trial_is, filename=fname)
            last_trial_skipped = False

        elif "TRIAL_RESULT" in l or "stop_trial" in l:
            trials_dict[trial_idx]["trial_result_idx"] = idx
            trials_dict[trial_idx]["trial_result_timestamp"] = int(parts[0].split("\t")[1])
            if len(parts) > 2:
                trials_dict[trial_idx]["trial_result_number"] = int(parts[2])
        elif "DISPLAY COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "GAZE_COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "FRAMERATE" in l:
            l_idx = parts.index(metadata_strs[2])
            fps = float(parts[l_idx + 1])
        elif "TRIAL ABORTED" in l or "TRIAL REPEATED" in l:
            # Drop the current trial once; repeats re-use the same TRIALID,
            # so the index is rewound and the id remembered.
            if not last_trial_skipped:
                if trial_is == "paragraph":
                    trials_dict["paragraph_trials"].remove(trial_idx)
                trial_idx -= 1
                removed_trial_ids.append(trial_id)
                last_trial_skipped = True

    if paragraph_trials_only:
        # Drop every non-paragraph trial entry (including bookkeeping keys
        # other than "paragraph_trials").
        trials_dict_temp = trials_dict.copy()
        for k in trials_dict_temp.keys():
            if k not in ["paragraph_trials"] + trials_dict_temp["paragraph_trials"]:
                trials_dict.pop(k)
        if len(trials_dict_temp["paragraph_trials"]):
            trial_idx = trials_dict_temp["paragraph_trials"][-1]
        else:
            return trials_dict
    trials_dict["display_coords"] = display_coords
    trials_dict["fps"] = fps
    trials_dict["max_trial_idx"] = trial_idx

    # --- Pass 2: per-trial timing markers and character boxes ------------
    enum = trials_dict["paragraph_trials"] if "paragraph_trials" in trials_dict.keys() else range(len(trials_dict))
    for trial_idx in enum:
        if trial_idx not in trials_dict.keys():
            continue
        chars_list = []
        if "display_coords" not in trials_dict[trial_idx].keys():
            trials_dict[trial_idx]["display_coords"] = trials_dict["display_coords"]
        trial_start_idx = trials_dict[trial_idx]["trial_id_idx"]
        trial_end_idx = trials_dict[trial_idx]["trial_result_idx"]
        trial_lines = lines[trial_start_idx:trial_end_idx]
        for idx, l in enumerate(trial_lines):
            parts = l.strip().split(" ")
            if "START" in l and " MSG" not in l:
                # NOTE(review): +7 / -2 offsets below skip the fixed-size
                # recording header/footer lines — confirm for other trackers.
                trials_dict[trial_idx]["start_idx"] = trial_start_idx + idx + 7
                trials_dict[trial_idx]["start_time"] = int(parts[0].split("\t")[1])
            elif "END" in l and "ENDBUTTON" not in l and " MSG" not in l:
                trials_dict[trial_idx]["end_idx"] = trial_start_idx + idx - 2
                trials_dict[trial_idx]["end_time"] = int(parts[0].split("\t")[1])
            elif "SYNCTIME" in l:
                trials_dict[trial_idx]["synctime"] = trial_start_idx + idx
                trials_dict[trial_idx]["synctime_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET OFF" in l:
                trials_dict[trial_idx]["gaze_targ_off_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET ON" in l:
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "DISPLAY_SENTENCE" in l:  # some .asc files seem to use this
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "REGION CHAR" in l:
                # REGION CHAR lines carry one character and its box. Space
                # characters and line-broken records shift the field offsets.
                rg_idx = parts.index("CHAR")
                if len(parts[rg_idx:]) > 8:
                    char = " "
                    idx_correction = 1
                elif len(parts[rg_idx:]) == 3:
                    # Record continues on the following line.
                    char = " "
                    if "REGION CHAR" not in trial_lines[idx + 1]:
                        parts = trial_lines[idx + 1].strip().split(" ")
                        idx_correction = -rg_idx - 4
                else:
                    char = parts[rg_idx + 3]
                    idx_correction = 0
                try:
                    char_dict = {
                        "char": char,
                        "char_xmin": float(parts[rg_idx + 4 + idx_correction]),
                        "char_ymin": float(parts[rg_idx + 5 + idx_correction]),
                        "char_xmax": float(parts[rg_idx + 6 + idx_correction]),
                        "char_ymax": float(parts[rg_idx + 7 + idx_correction]),
                    }
                    char_dict["char_y_center"] = (char_dict["char_ymax"] - char_dict["char_ymin"]) / 2 + char_dict[
                        "char_ymin"
                    ]
                    char_dict["char_x_center"] = (char_dict["char_xmax"] - char_dict["char_xmin"]) / 2 + char_dict[
                        "char_xmin"
                    ]
                    chars_list.append(char_dict)
                except Exception as e:
                    # Malformed record: log and keep parsing the rest.
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(f"char_dict creation failed for parts {parts}")
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(e)

        # Trial time zero: gaze-target onset when present, else START time.
        if "gaze_targ_on_time" in trials_dict[trial_idx]:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["gaze_targ_on_time"]
        else:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["start_time"]

        if len(chars_list) > 0:
            # Assign each character to a text line via its vertical center.
            line_ycoords = []
            for idx in range(len(chars_list)):
                chars_list[idx]["char_line_y"] = (
                    chars_list[idx]["char_ymax"] - chars_list[idx]["char_ymin"]
                ) / 2 + chars_list[idx]["char_ymin"]
                if chars_list[idx]["char_line_y"] not in line_ycoords:
                    line_ycoords.append(chars_list[idx]["char_line_y"])
            for idx in range(len(chars_list)):
                chars_list[idx]["assigned_line"] = line_ycoords.index(chars_list[idx]["char_line_y"])

            line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
            line_xcoords_all = [x["char_x_center"] for x in chars_list]
            line_xcoords_no_pad = np.unique(line_xcoords_all)

            line_ycoords_all = [x["char_y_center"] for x in chars_list]
            line_ycoords_no_pad = np.unique(line_ycoords_all)

            trials_dict[trial_idx]["x_char_unique"] = list(line_xcoords_no_pad)
            trials_dict[trial_idx]["y_char_unique"] = list(line_ycoords_no_pad)
            x_diff, y_diff = calc_xdiff_ydiff(
                line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
            )
            trials_dict[trial_idx]["x_diff"] = float(x_diff)
            trials_dict[trial_idx]["y_diff"] = float(y_diff)
            trials_dict[trial_idx]["num_char_lines"] = len(line_ycoords_no_pad)
            trials_dict[trial_idx]["line_heights"] = line_heights
            trials_dict[trial_idx]["chars_list"] = chars_list

            words_list = add_words(trials_dict[trial_idx], close_gap_between_words=close_gap_between_words)
            trials_dict[trial_idx]["words_list"] = words_list

    return trials_dict
|
544 |
+
|
545 |
+
|
546 |
+
def file_to_trials_and_lines(uploaded_file, asc_encoding: str = "ISO-8859-15", close_gap_between_words=True):
    """Read an EyeLink .asc file (path or uploaded file object) and parse it
    into a trials dict plus the raw list of lines.

    Parameters
    ----------
    uploaded_file : str | pathlib.Path | uploaded-file object
        A filesystem path, or an object exposing ``getvalue()`` (Streamlit
        upload) whose bytes are decoded with *asc_encoding*.
    asc_encoding : str
        Text encoding of the file.
    close_gap_between_words : bool
        Forwarded to the word-box reconstruction (see add_words).

    Returns
    -------
    (trials_dict, lines)
    """
    if isinstance(uploaded_file, (str, pl.Path)):
        with open(uploaded_file, "r", encoding=asc_encoding) as f:
            lines = f.readlines()
    else:
        # In-memory upload: decode the raw bytes directly (the previous
        # StringIO round-trip added nothing).
        lines = uploaded_file.getvalue().decode(asc_encoding).split("\n")

    trials_dict = asc_lines_to_trials_by_trail_id(
        lines, True, uploaded_file, close_gap_between_words=close_gap_between_words
    )

    # Rebuild the paragraph index when the parser did not provide one.
    if "paragraph_trials" not in trials_dict and "trial_is" in trials_dict[0]:
        trials_dict["paragraph_trials"] = [
            k for k in range(trials_dict["max_trial_idx"]) if trials_dict[k]["trial_is"] == "paragraph"
        ]

    enum = (
        trials_dict["paragraph_trials"]
        if "paragraph_trials" in trials_dict
        else range(trials_dict["max_trial_idx"])
    )
    for k in enum:
        if "chars_list" in trials_dict[k]:
            max_line = trials_dict[k]["chars_list"][-1]["assigned_line"]
            # Group the characters of each text line (plain loop instead of a
            # side-effect-only list comprehension).
            chars_on_lines = {x: [] for x in range(max_line + 1)}
            for ch in trials_dict[k]["chars_list"]:
                chars_on_lines[ch["assigned_line"]].append(ch["char"])
            sentence_list = ["".join(line_chars) for line_chars in chars_on_lines.values()]
            # Fix: join ALL lines with newlines — the previous
            # `sentence_list[0] + "\n".join(sentence_list[1:])` dropped the
            # newline between the first and second line.
            text = "\n".join(sentence_list)
            trials_dict[k]["sentence_list"] = sentence_list
            trials_dict[k]["text"] = text
            trials_dict[k]["max_line"] = max_line

    return trials_dict, lines
|
582 |
+
|
583 |
+
|
584 |
+
def get_plot_props(trial, available_fonts):
    """Derive (font, font_size, dpi, screen_res) for plotting a trial.

    Falls back to "DejaVu Sans Mono" when the trial's font is missing or not
    installed, to 21 pt when no size is recorded, and to 1920x1080 when the
    trial carries no display coordinates. dpi is fixed at 100.
    """
    fallback_font = "DejaVu Sans Mono"
    if "font" in trial:
        font_size = trial["font_size"]
        font = trial["font"] if trial["font"] in available_fonts else fallback_font
    else:
        font, font_size = fallback_font, 21

    dpi = 100
    if "display_coords" in trial:
        coords = trial["display_coords"]
        screen_res = (coords[2], coords[3])
    else:
        screen_res = (1920, 1080)
    return font, font_size, dpi, screen_res
|
599 |
+
|
600 |
+
|
601 |
+
def trial_to_dfs(
    trial: dict, lines: list, use_synctime: bool = False, save_lines_to_txt=False, cut_out_outer_fixations=False
):
    """Extract raw gaze samples and fixations for one trial from .asc lines.

    Parameters
    ----------
    trial : dict
        Trial dict holding line-index bounds (start_idx/end_idx, or
        synctime/trial_result_idx when use_synctime) and trial_start_time.
    lines : list
        All lines of the .asc file.
    use_synctime : bool
        Use the SYNCTIME..TRIAL_RESULT span instead of START..END.
    save_lines_to_txt : bool
        Debug helper: dump the span (±500 lines) to "Lines_plus500.txt".
    cut_out_outer_fixations : bool
        Drop fixations outside a fixed window (-10..1050 x, -10..800 y).

    Returns
    -------
    (df, dffix, trial): raw-sample dataframe, fixation dataframe (with
    trial-relative times when any fixations were found), and the trial dict
    augmented with event counts and the eye used.
    """

    if use_synctime and "synctime" in trial:
        idx0, idxend = trial["synctime"] + 1, trial["trial_result_idx"]
    else:
        idx0, idxend = trial["start_idx"], trial["end_idx"]

    line_dicts = []
    fixations_dicts = []
    blink_started = False
    fixation_started = False
    efix_count = 0
    sfix_count = 0
    sblink_count = 0

    if save_lines_to_txt:
        with open("Lines_plus500.txt", "w") as f:
            f.writelines(lines[idx0 - 500 : idxend + 500])

    # Use whichever eye produces the first fixation-end event in this span.
    eye_to_use = "R"
    for l in lines[idx0 : idxend + 1]:
        if "EFIX R" in l:
            eye_to_use = "R"
            break
        elif "EFIX L" in l:
            eye_to_use = "L"
            break

    for l in lines[idx0 : idxend + 1]:
        parts = [x.strip() for x in l.split("\t")]
        if f"EFIX {eye_to_use}" in l:
            efix_count += 1
            if fixation_started:
                if parts[1] == "." and parts[2] == ".":
                    continue
                fixations_dicts.append(
                    {
                        "start_time": float(parts[0].split()[-1].strip()),
                        "end_time": float(parts[1].strip()),
                        "duration": float(parts[2].strip()),
                        "x": float(parts[3].strip()),
                        "y": float(parts[4].strip()),
                        "pupil_size": float(parts[5].strip()),
                    }
                )
                if len(fixations_dicts) >= 2:
                    assert (
                        fixations_dicts[-1]["start_time"] > fixations_dicts[-2]["start_time"]
                    ), "start times not in order"
                fixation_started = False

        elif f"SFIX {eye_to_use}" in l:
            sfix_count += 1
            fixation_started = True
        elif f"SBLINK {eye_to_use}" in l:
            sblink_count += 1
            blink_started = True
        if not blink_started and not any(x in l for x in event_strs):
            # Raw sample line: <timestamp> <x> <y> <pupil>. Require all four
            # fields (the original guard tested for three but read parts[3]).
            if len(parts) < 4 or (parts[1] == "." and parts[2] == "."):
                continue
            line_dicts.append(
                {
                    "idx": float(parts[0].strip()),
                    "x": float(parts[1].strip()),
                    "y": float(parts[2].strip()),
                    "p": float(parts[3].strip()),
                }
            )

        elif f"EBLINK {eye_to_use}" in l:
            blink_started = False

    df = pd.DataFrame(line_dicts)
    dffix = pd.DataFrame(fixations_dicts)
    if len(fixations_dicts) > 0:
        # Shift all times so the trial starts at 0.
        dffix["corrected_start_time"] = dffix.start_time - trial["trial_start_time"]
        dffix["corrected_end_time"] = dffix.end_time - trial["trial_start_time"]
        dffix["fix_duration"] = dffix.corrected_end_time.values - dffix.corrected_start_time.values
        assert all(np.diff(dffix["corrected_start_time"]) > 0), "start times not in order"

    # Fix: the original had a dead expression (`df, pd.DataFrame(), trial`)
    # in the empty-fixations branch and then crashed on `dffix.x` below when
    # cut_out_outer_fixations was set; only filter when columns exist.
    if cut_out_outer_fixations and not dffix.empty:
        dffix = dffix[(dffix.x > -10) & (dffix.y > -10) & (dffix.x < 1050) & (dffix.y < 800)]

    trial["efix_count"] = efix_count
    trial["eye_to_use"] = eye_to_use
    trial["sfix_count"] = sfix_count
    trial["sblink_count"] = sblink_count
    return df, dffix, trial
|
693 |
+
|
694 |
+
|
695 |
+
def get_save_path(fpath, fname_ending):
    """Return the plot-output PNG path built from *fpath*'s stem plus suffix."""
    filename = f"{fpath.stem}_{fname_ending}.png"
    return gradio_plots.joinpath(filename)
|
698 |
+
|
699 |
+
|
700 |
+
def save_im_load_convert(fpath, fig, fname_ending, mode):
    """Save *fig* as a PNG, reload it, convert to *mode* (e.g. "L"/"RGB"),
    overwrite the file with the converted version, and return the PIL image."""
    out_path = get_save_path(fpath, fname_ending)
    fig.savefig(out_path)
    converted = Image.open(out_path).convert(mode)
    converted.save(out_path)
    return converted
|
706 |
+
|
707 |
+
|
708 |
+
def get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, dffix=None, prefix="word"):
    """Create a borderless full-bleed figure/axes sized to the screen.

    Axis limits come from the fixation extents when *dffix* is given,
    otherwise from the text-center extents padded by the margins. The y axis
    is inverted to match screen coordinates (origin top-left).
    """
    width_in, height_in = screen_res[0] / dpi, screen_res[1] / dpi
    fig = plt.figure(figsize=(width_in, height_in), dpi=dpi)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()

    if dffix is not None:
        y_limits = (dffix.y.min(), dffix.y.max())
        x_limits = (dffix.x.min(), dffix.x.max())
    else:
        y_center = words_df[f"{prefix}_y_center"]
        x_center = words_df[f"{prefix}_x_center"]
        y_limits = (y_center.min() - y_margin, y_center.max() + y_margin)
        x_limits = (x_center.min() - x_margin, x_center.max() + x_margin)
    ax.set_ylim(y_limits)
    ax.set_xlim(x_limits)

    ax.invert_yaxis()
    fig.add_axes(ax)
    return fig, ax
|
721 |
+
|
722 |
+
|
723 |
+
def plot_text_boxes_fixations(
    fpath,
    dpi,
    screen_res,
    data_dir_sub,
    set_font_size: bool,
    font_size: int,
    use_words: bool,
    save_channel_repeats: bool,
    save_combo_grey_and_rgb: bool,
    dffix=None,
    trial=None,
):
    """Render the stimulus text, text boxes and fixation scatter to PNGs.

    Produces, per trial: a greyscale text image, a greyscale box image, a
    greyscale fixation scatter, and a 3-channel composite of those three
    (used as CNN model input). Optionally also saves RGB/grey combined
    overlays and single-channel images repeated across three channels.

    Parameters
    ----------
    fpath : str | pathlib.Path
        Fixation CSV path; its stem also names all output images.
    dpi, screen_res : figure resolution and pixel size.
    data_dir_sub : unused here; kept for interface compatibility.
    set_font_size : use *font_size* as given when True, otherwise derive the
        size from trial["font_size"].
    use_words : plot word boxes when True, character boxes otherwise.
    save_channel_repeats, save_combo_grey_and_rgb : extra output toggles.
    dffix, trial : preloaded fixations/trial dict; loaded from disk when None.
    """
    if isinstance(fpath, str):
        fpath = pl.Path(fpath)
    prefix = "word" if use_words else "char"
    if dffix is None:
        dffix = pd.read_csv(fpath)
    if trial is None:
        json_fpath = str(fpath).replace("_fixations.csv", "_trial.json")
        with open(json_fpath, "r") as f:
            trial = json.load(f)

    words_df = pd.DataFrame(trial[f"{prefix}s_list"])
    x_min_col = words_df[f"{prefix}_xmin"]
    x_max_col = words_df[f"{prefix}_xmax"]
    y_min_col = words_df[f"{prefix}_ymin"]
    y_max_col = words_df[f"{prefix}_ymax"]

    if f"{prefix}_x_center" not in words_df.columns:
        words_df[f"{prefix}_x_center"] = (words_df[f"{prefix}_xmax"] - words_df[f"{prefix}_xmin"]) / 2 + words_df[
            f"{prefix}_xmin"
        ]
        words_df[f"{prefix}_y_center"] = (words_df[f"{prefix}_ymax"] - words_df[f"{prefix}_ymin"]) / 2 + words_df[
            f"{prefix}_ymin"
        ]

    x_margin = words_df[f"{prefix}_x_center"].mean() / 8
    y_margin = words_df[f"{prefix}_y_center"].mean() / 4
    # Alpha ramp over fixation order (earliest faintest). The original first
    # normalized corrected_start_time and then discarded it; the linspace is
    # the only value ever used.
    times = np.linspace(0.25, 1, len(dffix))

    # Fix: the original left `font` undefined (NameError) when set_font_size
    # was False. The family is always monospace; only the size varies.
    font = "monospace"
    if not set_font_size:
        font_size = trial["font_size"] * 27 // dpi

    font_props = FontProperties(family=font, style="normal", size=font_size)

    def _draw_boxes_with_text(ax, edgecolor, text_color=None):
        # One red/green (or black) rectangle + label per text box. The
        # one-pixel offset matches the original rendering ("seems to need
        # one pixel offset").
        for idx in range(len(x_min_col)):
            rect = patches.Rectangle(
                (x_min_col[idx] - 1, y_min_col[idx] - 1),
                x_max_col[idx] - x_min_col[idx],
                y_max_col[idx] - y_min_col[idx],
                alpha=0.9,
                linewidth=0.8,
                edgecolor=edgecolor,
                facecolor="none",
            )
            text_kwargs = dict(
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
            )
            if text_color is not None:
                text_kwargs["color"] = text_color
            ax.text(
                words_df[f"{prefix}_x_center"][idx],
                words_df[f"{prefix}_y_center"][idx],
                words_df[prefix][idx],
                **text_kwargs,
            )
            ax.add_patch(rect)

    if save_combo_grey_and_rgb:
        # RGB overlay: blue fixations, red boxes, green text.
        fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
        ax.scatter(dffix.x, dffix.y, alpha=times, facecolor="b")
        _draw_boxes_with_text(ax, edgecolor="r", text_color="g")
        fname_ending = f"{prefix}s_combo_rgb"
        words_combo_rgb_im = save_im_load_convert(fpath, fig, fname_ending, "RGB")
        plt.close("all")

        # Greyscale overlay of the same content.
        fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
        ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
        _draw_boxes_with_text(ax, edgecolor="k")
        fname_ending = f"{prefix}s_combo_grey"
        words_combo_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
        plt.close("all")

    # Channel 1: text only (faint anchors keep the limits consistent).
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
    ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.01)
    for idx in range(len(x_min_col)):
        ax.text(
            words_df[f"{prefix}_x_center"][idx],
            words_df[f"{prefix}_y_center"][idx],
            words_df[prefix][idx],
            horizontalalignment="center",
            verticalalignment="center",
            fontproperties=font_props,
        )
    fname_ending = f"{prefix}s_grey"
    words_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
    plt.close("all")

    # Channel 2: filled boxes only.
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
    ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.1)
    for idx in range(len(x_min_col)):
        rect = patches.Rectangle(
            (x_min_col[idx] - 1, y_min_col[idx] - 1),
            x_max_col[idx] - x_min_col[idx],
            y_max_col[idx] - y_min_col[idx],
            alpha=0.9,
            linewidth=1,
            edgecolor="k",
            facecolor="grey",
        )  # seems to need one pixel offset
        ax.add_patch(rect)
    fname_ending = f"{prefix}_boxes_grey"
    word_boxes_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
    plt.close("all")

    # Channel 3: fixation scatter only.
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
    ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
    fname_ending = "fix_scatter_grey"
    fix_scatter_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
    plt.close("all")

    # Stack the three greyscale channels into one RGB-shaped composite.
    arr_combo = np.stack(
        [
            np.asarray(words_grey_im),
            np.asarray(word_boxes_grey_im),
            np.asarray(fix_scatter_grey_im),
        ],
        axis=2,
    )

    im_combo = Image.fromarray(arr_combo)
    fname_ending = f"{prefix}s_channel_sep"

    save_path = get_save_path(fpath, fname_ending)
    print(f"save_path for im combo is {save_path}")
    # Fix: the original saved to `fpath` (the input CSV path), clobbering the
    # fixation file with PNG bytes instead of writing the computed save_path.
    im_combo.save(save_path)

    if save_channel_repeats:
        arr_combo = np.stack([np.asarray(words_grey_im)] * 3, axis=2)
        im_combo = Image.fromarray(arr_combo)
        fname_ending = f"{prefix}s_channel_repeat"
        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)

        arr_combo = np.stack([np.asarray(word_boxes_grey_im)] * 3, axis=2)
        im_combo = Image.fromarray(arr_combo)
        fname_ending = f"{prefix}boxes_channel_repeat"
        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)

        arr_combo = np.stack([np.asarray(fix_scatter_grey_im)] * 3, axis=2)
        im_combo = Image.fromarray(arr_combo)
        fname_ending = "fix_channel_repeat"
        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)
|
909 |
+
|
910 |
+
|
911 |
+
def add_line_overlaps_to_sample(trial, sample):
    """Append a line-overlap feature column to *sample*.

    Each fixation's y coordinate (feature column 1) is tested against the
    vertical band of every text line, taken from the trial's character boxes
    in first-occurrence order; the index of the matching line (or -1.0 when
    the fixation falls between lines) is concatenated as one extra column.
    """
    boxes = pd.DataFrame(trial["chars_list"])
    # One (ymin, ymax) band per text line, in the order the lines appear.
    line_bands = list(zip(boxes.char_ymin.unique(), boxes.char_ymax.unique()))

    def _band_index(y_val):
        for band_idx, (top, bottom) in enumerate(line_bands):
            if top <= y_val <= bottom:
                return t.tensor(band_idx, dtype=t.float32)
        return t.tensor(-1, dtype=t.float32)

    overlap_col = t.stack([_band_index(row[1]) for row in sample], dim=0)
    return t.cat([sample, overlap_col.unsqueeze(1)], dim=1)
|
925 |
+
|
926 |
+
|
927 |
+
def norm_coords_by_letter_min_x_y(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    chars_center_coords_list: list = None,
):
    """Shift one sample so the text's top-left character corner is the origin.

    Subtracts the minimum character x from feature column 0 and the minimum
    character y from column 1; all other feature columns are left unchanged.
    When character center coordinates are supplied they are shifted by the
    same offset.  The lists are modified in place and returned.
    """
    chars = pd.DataFrame(trialslist[sample_idx]["chars_list"])
    trialslist[sample_idx]["x_char_unique"] = chars.char_xmin.unique()

    sample = samplelist[sample_idx]
    # Offset vector: (min_x, min_y) in the first two columns, zeros elsewhere.
    offset = t.zeros((1, sample.shape[1]), dtype=sample.dtype, device=sample.device)
    offset[0, 0] = offset[0, 0] + chars.char_xmin.min()
    offset[0, 1] = offset[0, 1] + chars.char_ymin.min()

    samplelist[sample_idx] = sample - offset

    if chars_center_coords_list is not None:
        xy_offset = offset.squeeze(0)[:2]
        centers = chars_center_coords_list[sample_idx]
        if centers.shape[-1] == xy_offset.shape[-1] * 2:
            # Boxes stored as (x0, y0, x1, y1): shift both corner pairs.
            centers[:, :2] -= xy_offset
            centers[:, 2:] -= xy_offset
        else:
            centers -= xy_offset
    return trialslist, samplelist, chars_center_coords_list
|
955 |
+
|
956 |
+
|
957 |
+
def norm_coords_by_letter_positions(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    meanlist: list = None,
    stdlist: list = None,
    return_mean_std_lists=False,
    norm_by_char_averages=False,
    chars_center_coords_list: list = None,
    add_normalised_values_as_features=False,
):
    """Scale one sample's x/y features by the text's character dimensions.

    Column 0 is divided by the text width (or the average character width when
    ``norm_by_char_averages`` is set) and column 1 by the minimum line height
    (or average character height); any remaining feature columns are divided
    by 1 and so unchanged.  With ``add_normalised_values_as_features`` the
    scaled x/y values are appended as extra columns instead of replacing the
    originals.  Optionally records per-sample mean/std and rescales character
    center coordinates the same way.  Lists are modified in place and
    returned.

    Bug fixes vs. the original implementation:
    * the NaN checks on ``mean_val``/``std_val`` used ``~any(...)`` — the
      *bitwise* complement of a Python bool is -1 or -2, both truthy, so the
      asserts could never fire; they now use ``not ... .any()``;
    * the second check inspected ``mean_val`` again although its message
      claimed to check ``std_val``.
    """
    chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
    trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()

    min_x_chars = chars_df.char_xmin.min()
    max_x_chars = chars_df.char_xmax.max()

    sample = samplelist[sample_idx]
    norm_vector_multi = t.ones((1, sample.shape[1]), dtype=sample.dtype, device=sample.device)
    if norm_by_char_averages:
        chars_list = trialslist[sample_idx]["chars_list"]
        char_widths = np.asarray([x["char_xmax"] - x["char_xmin"] for x in chars_list])
        char_heights = np.asarray([x["char_ymax"] - x["char_ymin"] for x in chars_list])
        # Zero-sized boxes (e.g. line-break markers) are excluded from the averages.
        norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * np.mean(char_widths[char_widths > 0])
        norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * np.mean(char_heights[char_heights > 0])
    else:
        line_height = min(np.unique(trialslist[sample_idx]["line_heights"]))
        line_width = max_x_chars - min_x_chars
        norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * line_width
        norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * line_height

    assert not t.isnan(norm_vector_multi).any(), "Nan found in char norming vector"

    norm_vector_multi = norm_vector_multi.squeeze(0)
    if add_normalised_values_as_features:
        # Keep only the non-unit factors (the x and y divisors).
        norm_vector_multi = norm_vector_multi[norm_vector_multi != 1]
        normed_features = samplelist[sample_idx][:, : norm_vector_multi.shape[0]] / norm_vector_multi
        samplelist[sample_idx] = t.cat([samplelist[sample_idx], normed_features], dim=1)
    else:
        samplelist[sample_idx] = samplelist[sample_idx] / norm_vector_multi  # in case time or pupil size is included
    if chars_center_coords_list is not None:
        norm_vector_multi = norm_vector_multi[:2]
        if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_multi.shape[-1] * 2:
            chars_center_coords_list[sample_idx][:, :2] /= norm_vector_multi
            chars_center_coords_list[sample_idx][:, 2:] /= norm_vector_multi
        else:
            chars_center_coords_list[sample_idx] /= norm_vector_multi
    if return_mean_std_lists:
        mean_val = samplelist[sample_idx].mean(axis=0).cpu().numpy()
        meanlist.append(mean_val)
        std_val = samplelist[sample_idx].std(axis=0).cpu().numpy()
        stdlist.append(std_val)
        assert not np.isnan(mean_val).any(), "Nan found in mean_val"
        assert not np.isnan(std_val).any(), "Nan found in std_val"

        return trialslist, samplelist, meanlist, stdlist, chars_center_coords_list
    return trialslist, samplelist, chars_center_coords_list
|
1018 |
+
|
1019 |
+
|
1020 |
+
def remove_compile_from_model(model):
    """Strip ``torch.compile`` wrappers from the model's submodules.

    ``torch.compile`` wraps a module and keeps the original under
    ``_orig_mod``; this replaces each wrapped submodule by its original so
    plain state dicts can be loaded.  ``model.project`` acts as the sentinel:
    if it is not wrapped, the model is assumed uncompiled and returned as-is.
    """
    if not hasattr(model.project, "_orig_mod"):
        print(f"remove_compile_from_model not done since model.project {model.project} has no orig_mod")
        return model
    for attr_name in ("project", "chars_conv", "chars_classifier", "layer_norm_in", "bert_model", "linear"):
        setattr(model, attr_name, getattr(model, attr_name)._orig_mod)
    return model
|
1031 |
+
|
1032 |
+
|
1033 |
+
def remove_compile_from_dict(state_dict):
    """Rewrite ``torch.compile`` checkpoint keys to their uncompiled form.

    Compiled submodules serialise parameters under
    ``<name>._orig_mod.<param>``; dropping the ``._orig_mod`` segment makes
    the state dict loadable by an uncompiled model.  The dict is modified in
    place and also returned.
    """
    for old_key in tuple(state_dict):
        state_dict[old_key.replace("._orig_mod.", ".")] = state_dict.pop(old_key)
    return state_dict
|
1038 |
+
|
1039 |
+
|
1040 |
+
def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
    """Draw character/word stimuli onto *ax*.

    Every entry of *chars_list* is a dict holding ``<prefix>_xmin``/``_xmax``/
    ``_ymin``/``_ymax`` box limits, ``<prefix>_x_center``/``_y_center`` and
    the text under ``<prefix>``.  Text is drawn at the box centers and/or the
    boxes are outlined; optional *box_annotations* (one per box) are drawn in
    a smaller font at the top edge of each box.  Returns None.
    """
    if not (plot_boxes or plot_text):
        return None
    font_props = FontProperties(family=font_to_use, style="normal", size=fontsize)
    if box_annotations is None:
        paired = ((box, None) for box in chars_list)
    else:
        paired = zip(chars_list, box_annotations)
    for box, annot_text in paired:
        x0, y0 = box[f"{prefix}_xmin"], box[f"{prefix}_ymin"]
        box_w = box[f"{prefix}_xmax"] - box[f"{prefix}_xmin"]
        box_h = box[f"{prefix}_ymax"] - box[f"{prefix}_ymin"]
        if plot_text:
            ax.text(
                box[f"{prefix}_x_center"],
                box[f"{prefix}_y_center"],
                box[prefix],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
            )
        if plot_boxes:
            ax.add_patch(Rectangle((x0, y0), box_w, box_h, edgecolor="grey", facecolor="none", lw=0.8, alpha=0.4))
            if box_annotations is not None:
                ax.annotate(
                    str(annot_text),
                    (x0 + box_w / 2, y0),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontproperties=FontProperties(family=font_to_use, style="normal", size=fontsize / 1.5),
                )
|
1081 |
+
|
1082 |
+
|
1083 |
+
def plot_fixations_and_text(
    dffix: pd.DataFrame,
    trial: dict,
    plot_prefix="chars_",
    show=False,
    returnfig=False,
    save=False,
    savelocation="plot.png",
    font_to_use="DejaVu Sans Mono",
    fontsize=20,
    plot_classic=True,
    plot_boxes=True,
    plot_text=True,
    fig_size=(14, 8),
    dpi=300,
    turn_axis_on=True,
    algo_choice="slice",
):
    """Plot raw fixations over the trial's text, optionally with corrected positions.

    Raw fixations are black crosses.  When *plot_classic* is set and the
    ``line_num_<algo_choice>`` column exists, corrected fixations
    (``y_<algo_choice>``) are green stars with an arrow drawn from each
    corrected position to the corresponding raw one.  The y axis is inverted
    to match screen coordinates.

    Returns the figure when *returnfig* is True; otherwise closes it and
    returns None.  *save*/*show* control writing to *savelocation* and
    displaying the figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=fig_size, tight_layout=True, dpi=dpi)
    if f"{plot_prefix}list" in trial.keys():
        add_text_to_ax(
            trial[f"{plot_prefix}list"],
            ax,
            font_to_use,
            fontsize=fontsize,
            prefix=plot_prefix[:-2],  # "chars_" -> "char": key prefix used inside the box dicts
            plot_boxes=plot_boxes,
            plot_text=plot_text,
        )
    ax.plot(dffix.x, dffix.y, "kX", label="Raw Fixations", alpha=0.9)

    if plot_classic and f"line_num_{algo_choice}" in dffix.columns:
        ax.scatter(
            dffix.x,
            dffix[f"y_{algo_choice}"],
            marker="*",
            color="tab:green",
            label=f"{algo_choice} Prediction",
            alpha=0.9,
        )
        # One arrow per fixation from (x, y_corrected) to (x, y_raw); x is the
        # same column on both sides, so arrows are vertical.
        for x_before, y_before, x_after, y_after in zip(
            dffix.x.values, dffix[f"y_{algo_choice}"].values, dffix.x, dffix.y
        ):
            arr_delta_x = x_after - x_before
            arr_delta_y = y_after - y_before
            ax.arrow(x_before, y_before, arr_delta_x, arr_delta_y, color="tab:green", alpha=0.6)
    ax.set_ylabel("y (pixel)")
    ax.set_xlabel("x (pixel)")

    ax.invert_yaxis()  # screen coordinates: y grows downwards
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    if not turn_axis_on:
        ax.axis("off")
    if save:
        plt.savefig(savelocation, dpi=dpi)
    if show:
        plt.show()
    if returnfig:
        return fig
    else:
        plt.close()
        return None
|
1145 |
+
|
1146 |
+
|
1147 |
+
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    """Ensure the three working directories exist (parents must already exist).

    Each argument is a ``pathlib.Path``.  Returns 0.
    """
    for folder in (gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
        folder.mkdir(exist_ok=True)
    return 0
|
1152 |
+
|
1153 |
+
|
1154 |
+
def get_classic_cfg(fname):
    """Load the per-algorithm configuration for the classic correction algorithms.

    Parameters
    ----------
    fname : str or Path
        Path to a JSON file mapping algorithm names to their keyword configs.

    Returns
    -------
    dict
        The parsed configuration mapping.

    Notes
    -----
    The original implementation read the file manually and piped it through
    ``json.loads``, then executed two self-assignments
    (``cfg["slice"] = cfg["slice"]`` and ``cfg = cfg``) whose only effect was
    to require a ``"slice"`` key; both no-ops have been removed.
    """
    with open(fname, "r") as f:
        return json.load(f)
|
1161 |
+
|
1162 |
+
|
1163 |
+
def find_and_load_model(model_date="20240104-223349"):
    """Locate a model's YAML config and checkpoint by timestamp and load them.

    Searches ``DIST_MODELS_FOLDER`` for ``*<model_date>*.yaml`` and the local
    ``models`` folder for the matching ``*.ckpt``.

    Returns ``(model, model_cfg)``; ``(None, None)`` when no config is found.
    Raises IndexError if a config exists but no matching checkpoint does.
    """
    model_cfg_file = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.yaml"))
    if len(model_cfg_file) == 0:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"No model cfg yaml found for {model_date}")
        return None, None
    model_cfg_file = model_cfg_file[0]
    with open(model_cfg_file) as f:
        model_cfg = yaml.safe_load(f)

    # NOTE(review): system_type is forced to "linux" regardless of what the
    # YAML says — presumably configs were written on other machines; confirm.
    model_cfg["system_type"] = "linux"
    model_file = list(pl.Path("models").glob(f"*{model_date}*.ckpt"))[0]
    model = load_model(model_file, model_cfg)

    return model, model_cfg
|
1178 |
+
|
1179 |
+
|
1180 |
+
def load_model(model_file, cfg):
    """Instantiate a LitModel and load checkpoint weights from *model_file*.

    The hyper-parameter config stored inside the checkpoint (Lightning's
    ``hyper_parameters``) takes precedence over the supplied *cfg*.  Any
    ``torch.compile`` wrappers are stripped from both the freshly built model
    and the checkpoint's state dict before loading.

    Returns the model frozen and in eval mode, or None when the checkpoint
    cannot be read.
    """
    try:
        model_loaded = t.load(model_file, map_location="cpu")
        # Lightning checkpoints carry their own config; fall back to *cfg* otherwise.
        if "hyper_parameters" in model_loaded.keys():
            model_cfg_temp = model_loaded["hyper_parameters"]["cfg"]
        else:
            model_cfg_temp = cfg
        model_state_dict = model_loaded["state_dict"]
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(e)
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"Failed to load {model_file}")
        return None
    # Positional argument order must match LitModel's signature — do not reorder.
    model = LitModel(
        [1, 500, 3],  # assumed to be the expected input dims — TODO confirm against LitModel
        model_cfg_temp["hidden_dim_bert"],
        model_cfg_temp["num_attention_heads"],
        model_cfg_temp["n_layers_BERT"],
        model_cfg_temp["loss_function"],
        1e-4,  # learning rate; irrelevant at inference time
        model_cfg_temp["weight_decay"],
        model_cfg_temp,
        model_cfg_temp["use_lr_warmup"],
        model_cfg_temp["use_reduce_on_plateau"],
        track_gradient_histogram=model_cfg_temp["track_gradient_histogram"],
        register_forw_hook=model_cfg_temp["track_activations_via_hook"],
        char_dims=model_cfg_temp["char_dims"],
    )
    model = remove_compile_from_model(model)
    model_state_dict = remove_compile_from_dict(model_state_dict)
    with t.no_grad():
        # strict=False tolerates keys that differ between checkpoint and model.
        model.load_state_dict(model_state_dict, strict=False)
    model.eval()
    model.freeze()
    return model
|
1216 |
+
|
1217 |
+
|
1218 |
+
def set_up_models(dist_models_folder):
    """Load every DIST checkpoint in *dist_models_folder* and build the ensemble.

    Checkpoints are split by filename into models trained with and without
    line-height/width normalisation.  Returns a dict containing the averaging
    ensemble, the two representative configs, the single DIST model (the one
    matching the first "with norm" checkpoint's date) and its config/date.
    """
    out_dict = {}
    if "logger" in st.session_state:
        st.session_state["logger"].info("Loading Ensemble")
    dist_models_with_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_True*.ckpt"))
    dist_models_without_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_False*.ckpt"))
    # Filenames look like "BERT_<date>_..."; the date is the second "_" field.
    DIST_MODEL_DATE_WITH_NORM = dist_models_with_norm[0].stem.split("_")[1]

    models_without_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm]
    models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

    # Keep the first successfully loaded config of each group ...
    model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
    model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]

    # ... and drop models that failed to load.
    models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
    models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]

    ensemble_model_avg = EnsembleModel(
        models_without_norm_df, models_with_norm_df, learning_rate=0.0058, use_simple_average=True
    )
    out_dict["ensemble_model_avg"] = ensemble_model_avg

    out_dict["model_cfg_without_norm_df"] = model_cfg_without_norm_df
    out_dict["model_cfg_with_norm_df"] = model_cfg_with_norm_df

    single_DIST_model, single_DIST_model_cfg = find_and_load_model(model_date=DIST_MODEL_DATE_WITH_NORM)
    out_dict["DIST_MODEL_DATE_WITH_NORM"] = DIST_MODEL_DATE_WITH_NORM
    out_dict["single_DIST_model"] = single_DIST_model
    out_dict["single_DIST_model_cfg"] = single_DIST_model_cfg
    return out_dict
|
1248 |
+
|
1249 |
+
|
1250 |
+
def prep_data_for_dist(model_cfg, dffix, trial=None):
    """Build a single-trial DataLoader matching the DIST model's training pipeline.

    Converts the fixation columns named in ``model_cfg["sample_cols"]`` into a
    float32 tensor, optionally appends the line-overlap feature, applies the
    same letter-based coordinate normalisations the model was trained with,
    standardises with the config's stored means/stds and renders the stimulus
    images consumed by the model's convolutional branch.

    The order of these steps mirrors the training pipeline — do not reorder.

    Returns ``(val_loader, val_set)`` with batch size 1.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("prep_data_for_dist entered")
    if trial is None:
        trial = st.session_state["trial"]
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    sample_tensor = t.tensor(dffix.loc[:, model_cfg["sample_cols"]].to_numpy(), dtype=t.float32)

    if model_cfg["add_line_overlap_feature"]:
        sample_tensor = add_line_overlaps_to_sample(trial, sample_tensor)

    has_nans = t.any(t.isnan(sample_tensor))
    assert not has_nans, "NaNs found in sample tensor"
    samplelist_eval = [sample_tensor]
    trialslist_eval = [trial]
    chars_center_coords_list_eval = None
    if model_cfg["norm_coords_by_letter_min_x_y"]:
        for sample_idx, _ in enumerate(samplelist_eval):
            trialslist_eval, samplelist_eval, chars_center_coords_list_eval = norm_coords_by_letter_min_x_y(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                chars_center_coords_list=chars_center_coords_list_eval,
            )

    if model_cfg["normalize_by_line_height_and_width"]:
        meanlist_eval, stdlist_eval = [], []
        for sample_idx, _ in enumerate(samplelist_eval):
            (
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                chars_center_coords_list_eval,
            ) = norm_coords_by_letter_positions(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                return_mean_std_lists=True,
                norm_by_char_averages=model_cfg["norm_by_char_averages"],
                chars_center_coords_list=chars_center_coords_list_eval,
                add_normalised_values_as_features=model_cfg["add_normalised_values_as_features"],
            )
    sample_tensor = samplelist_eval[0]
    # Standardise with the training-set statistics stored in the config.
    sample_means = t.tensor(model_cfg["sample_means"], dtype=t.float32)
    sample_std = t.tensor(model_cfg["sample_std"], dtype=t.float32)
    sample_tensor = (sample_tensor - sample_means) / sample_std
    sample_tensor = sample_tensor.unsqueeze(0)  # add the batch dimension

    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Using path {trial['plot_file']} for plotting")
    # Renders the image channels the DSet below will read back from disk.
    plot_text_boxes_fixations(
        fpath=trial["plot_file"],
        dpi=250,
        screen_res=(1024, 768),
        data_dir_sub=None,
        set_font_size=True,
        font_size=4,
        use_words=False,
        save_channel_repeats=False,
        save_combo_grey_and_rgb=False,
        dffix=dffix,
        trial=trial,
    )

    val_set = DSet(
        sample_tensor,
        None,
        t.zeros((1, sample_tensor.shape[1])),  # dummy targets; inference only
        trialslist_eval,
        padding_list=[0],
        padding_at_end=model_cfg["padding_at_end"],
        return_images_for_conv=True,
        im_partial_string=model_cfg["im_partial_string"],
        input_im_shape=model_cfg["char_plot_shape"],
    )
    val_loader = dl(val_set, batch_size=1, shuffle=False, num_workers=0)
    return val_loader, val_set
|
1331 |
+
|
1332 |
+
|
1333 |
+
def fold_in_seq_dim(out, y=None):
    """Collapse the batch and sequence axes of *out* (and optionally *y*).

    ``(batch, seq, classes)`` becomes ``(batch * seq, classes)``; a label
    tensor is flattened the same way whether or not it carries a class axis.
    Returns ``(out, y)`` with ``y`` None when no labels were passed.
    """
    batch_size, seq_len, num_classes = out.shape

    out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
    if y is None:
        return out, None
    pattern = "b s c -> (b s) c" if len(y.shape) > 2 else "b s -> (b s)"
    y = eo.rearrange(y, pattern, s=seq_len)
    return out, y
|
1344 |
+
|
1345 |
+
|
1346 |
+
def logits_to_pred(out, y=None):
    """Convert CORN ordinal logits to per-fixation line predictions.

    *out* has shape ``(batch, seq, num_logits)``.  The sequence axis is folded
    away, ``corn_label_from_logits`` decodes the ordinal logits, and the
    predictions (and labels, when given) are reshaped back to
    ``(batch, seq)``.

    Returns ``(preds, y)``; ``y`` is None when no labels were passed.

    Change vs. original: removed a dead ``y = y`` statement.
    """
    seq_len = out.shape[1]
    out, y = fold_in_seq_dim(out, y)
    preds = corn_label_from_logits(out)
    preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len)
    if y is not None:
        y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len)
    return preds, y
|
1355 |
+
|
1356 |
+
|
1357 |
+
def get_DIST_preds(dffix, trial, models_dict=None):
    """Run the single DIST model on one trial's fixations and attach predictions.

    Adds ``line_num_DIST``, ``y_DIST`` and ``y_DIST_correction`` columns to
    *dffix* and returns it.  The model and its config come from *models_dict*
    when given, otherwise from ``st.session_state`` (re-loading them first if
    they are missing).  Prediction errors are logged and swallowed so the
    surrounding correction pipeline keeps running.

    Bug fix: previously, when the session-state model had to be re-loaded,
    ``model`` and ``loader`` were never assigned in that branch, so the
    subsequent ``next(iter(loader))`` raised ``UnboundLocalError``.  Both are
    now always set after a (re-)load.
    """
    algo_choice = "DIST"

    if models_dict is None:
        if st.session_state["single_DIST_model"] is None or st.session_state["single_DIST_model_cfg"] is None:
            st.session_state["single_DIST_model"], st.session_state["single_DIST_model_cfg"] = find_and_load_model(
                model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"]
            )
            if "logger" in st.session_state:
                st.session_state["logger"].info("Model is None, reiniting model")
        model = st.session_state["single_DIST_model"]
        loader, dset = prep_data_for_dist(st.session_state["single_DIST_model_cfg"], dffix, trial)
    else:
        model = models_dict["single_DIST_model"]
        loader, dset = prep_data_for_dist(models_dict["single_DIST_model_cfg"], dffix, trial)
    batch = next(iter(loader))

    if "cpu" not in str(model.device):
        batch = [x.cuda() for x in batch]
    try:
        out = model(batch)
        preds, y = logits_to_pred(out, y=None)
        if "logger" in st.session_state:
            st.session_state["logger"].debug(
                f"y_char_unique are {trial['y_char_unique']} for trial {trial['trial_id']}"
            )
            st.session_state["logger"].debug(f"trial keys are {trial.keys()} for trial {trial['trial_id']}")
            st.session_state["logger"].debug(
                f"chars_list has len {len(trial['chars_list'])} for trial {trial['trial_id']}"
            )
            st.session_state["logger"].debug(f"y_char_unique {trial['y_char_unique']} for trial {trial['trial_id']}")
        # Fall back to the character boxes when the trial carries no
        # precomputed per-line y centers.
        if len(trial["y_char_unique"]) < 1:
            y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
        else:
            y_char_unique = trial["y_char_unique"]
        num_lines = trial["num_char_lines"] - 1
        # Clamp predictions to valid line indices before mapping to y values.
        preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
        y_pred_DIST = [y_char_unique[idx] for idx in preds]

        dffix[f"line_num_{algo_choice}"] = preds
        dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
        dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"Exception on model(batch) for DIST \n{e}")
    return dffix
|
1408 |
+
|
1409 |
+
|
1410 |
+
def get_DIST_ensemble_preds(
    dffix,
    trial,
    model_cfg_without_norm_df,
    model_cfg_with_norm_df,
    ensemble_model_avg,
):
    """Run the DIST ensemble on one trial and attach its predictions.

    Two loaders are built — one per preprocessing variant (with/without
    line-height normalisation) — and fed jointly to *ensemble_model_avg*,
    whose averaged logits are decoded into line assignments.  Adds
    ``line_num_DIST-Ensemble``, ``y_DIST-Ensemble`` and
    ``y_DIST-Ensemble_correction`` columns to *dffix* and returns it.
    """
    algo_choice = "DIST-Ensemble"
    loader_without_norm, dset_without_norm = prep_data_for_dist(model_cfg_without_norm_df, dffix, trial)
    loader_with_norm, dset_with_norm = prep_data_for_dist(model_cfg_with_norm_df, dffix, trial)
    batch_without_norm = next(iter(loader_without_norm))
    batch_with_norm = next(iter(loader_with_norm))
    out = ensemble_model_avg((batch_without_norm, batch_with_norm))
    preds, y = logits_to_pred(out[0]["out_avg"], y=None)
    # Fall back to the character boxes when the trial has no per-line y centers.
    if len(trial["y_char_unique"]) < 1:
        y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
    else:
        y_char_unique = trial["y_char_unique"]
    num_lines = trial["num_char_lines"] - 1
    # Clamp predictions to valid line indices before mapping them to y values.
    preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
    if "logger" in st.session_state:
        st.session_state["logger"].debug(f"preds are {preds} for trial {trial['trial_id']}")
    y_pred_DIST = [y_char_unique[idx] for idx in preds]

    dffix[f"line_num_{algo_choice}"] = preds
    dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
    dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    return dffix
|
1438 |
+
|
1439 |
+
|
1440 |
+
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None, models_dict=None):
    """Run the DIST ensemble, (re)building it from checkpoints if needed.

    When *models_dict* is given, its ensemble and configs are used directly.
    Otherwise the ensemble comes from ``st.session_state``; if neither the
    *ensemble_model_avg* argument nor the session-state entry exists it is
    reconstructed from the checkpoints in ``DIST_MODELS_FOLDER`` and cached in
    the session state.  Delegates prediction to ``get_DIST_ensemble_preds``
    and returns the augmented dataframe.
    """

    if models_dict is None:
        if ensemble_model_avg is None and "ensemble_model_avg" not in st.session_state:
            if "logger" in st.session_state:
                st.session_state["logger"].info("Ensemble Model is None, reiniting model")
            dist_models_with_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_True*.ckpt")
            dist_models_without_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_False*.ckpt")

            models_without_norm_df = [
                find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm
            ]
            models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

            # Keep the first available config of each group ...
            model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
            model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]

            # ... and drop models that failed to load.
            models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
            models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]

            ensemble_model_avg = EnsembleModel(
                models_without_norm_df, models_with_norm_df, learning_rate=0.0, use_simple_average=True
            )
            st.session_state["ensemble_model_avg"] = ensemble_model_avg
            st.session_state["model_cfg_without_norm_df"] = model_cfg_without_norm_df
            st.session_state["model_cfg_with_norm_df"] = model_cfg_with_norm_df
        else:
            model_cfg_without_norm_df = st.session_state["model_cfg_without_norm_df"]
            model_cfg_with_norm_df = st.session_state["model_cfg_with_norm_df"]
            ensemble_model_avg = st.session_state["ensemble_model_avg"]
        dffix = get_DIST_ensemble_preds(
            dffix,
            trial,
            st.session_state["model_cfg_without_norm_df"],
            st.session_state["model_cfg_with_norm_df"],
            st.session_state["ensemble_model_avg"],
        )
    else:
        dffix = get_DIST_ensemble_preds(
            dffix,
            trial,
            models_dict["model_cfg_without_norm_df"],
            models_dict["model_cfg_with_norm_df"],
            models_dict["ensemble_model_avg"],
        )
    return dffix
|
1486 |
+
|
1487 |
+
|
1488 |
+
def correct_df(
    dffix,
    algo_choice,
    trial=None,
    for_multi=False,
    ensemble_model_avg=None,
    is_outside_of_streamlit=False,
    classic_algos_cfg=None,
    models_dict=None,
):
    """Apply one or several line-assignment correction algorithms to *dffix*.

    Parameters
    ----------
    dffix : pd.DataFrame or dict
        Fixations (a dict is unwrapped via its ``"value"`` entry); must
        contain ``x`` and ``y`` columns.
    algo_choice : str or list[str]
        Algorithm name(s): "DIST", "DIST-Ensemble", the Wisdom_of_Crowds
        variants, or any key of *classic_algos_cfg*.
    trial : dict, optional
        Trial metadata; taken from ``st.session_state`` when omitted (with
        *for_multi* set the caller is expected to supply it).
    for_multi : bool
        When True, only the corrected dataframe is returned; otherwise a
        ``(dffix, csv)`` pair is produced via ``export_csv``.

    Bug fix: the input-column guard previously tested for ``"x"`` twice; the
    second test now correctly checks for ``"y"``.
    """
    if is_outside_of_streamlit:
        stqdm = tqdm
    else:
        from stqdm import stqdm
    if classic_algos_cfg is None:
        classic_algos_cfg = st.session_state["classic_algos_cfg"]
    if trial is None and not for_multi:
        trial = st.session_state["trial"]
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Applying {algo_choice} to fixations for trial {trial['trial_id']}")

    if isinstance(dffix, dict):
        dffix = dffix["value"]
    if "x" not in dffix.keys() or "y" not in dffix.keys():
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"x or y not in dffix")
        if "logger" in st.session_state:
            st.session_state["logger"].warning(dffix.columns)
        return dffix
    if isinstance(algo_choice, list):
        algo_choices = algo_choice
        repeats = range(len(algo_choice))
    else:
        algo_choices = [algo_choice]
        repeats = range(1)
    for algoIdx in stqdm(repeats, desc="Applying correction algorithms"):
        algo_choice = algo_choices[algoIdx]
        st_proc = time.process_time()
        st_wall = time.time()

        if algo_choice == "DIST":
            dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)

        elif algo_choice == "DIST-Ensemble":
            dffix = get_EDIST_preds_with_model_check(dffix, trial, models_dict=models_dict)
        elif algo_choice == "Wisdom_of_Crowds_with_DIST":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)
            # The DIST prediction gets triple weight in the vote.
            for _ in range(3):
                corrections.append(np.asarray(dffix.loc[:, "y_DIST"]))
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        elif algo_choice == "Wisdom_of_Crowds_with_DIST_Ensemble":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg, models_dict=models_dict)
            # The ensemble prediction gets triple weight in the vote.
            for _ in range(3):
                corrections.append(np.asarray(dffix.loc[:, "y_DIST-Ensemble"]))
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        elif algo_choice == "Wisdom_of_Crowds":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = apply_woc(dffix, trial, corrections, algo_choice)

        else:
            # Any other name is treated as a classic algorithm with its own cfg.
            algo_cfg = classic_algos_cfg[algo_choice]
            dffix = calgo.apply_classic_algo(dffix, trial, algo_choice, algo_cfg)
            dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)

        et_proc = time.process_time()
        time_proc = et_proc - st_proc
        et_wall = time.time()
        time_wall = et_wall - st_wall
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"time_proc {algo_choice} {time_proc}")
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"time_wall {algo_choice} {time_wall}")
    if for_multi:
        return dffix
    else:
        if "start_time" in dffix.columns:
            dffix = dffix.drop(axis=1, labels=["start_time", "end_time"])
        return dffix, export_csv(dffix, trial)
|
1568 |
+
|
1569 |
+
def set_font_from_chars_list(trial):
    """Estimate the stimulus font size (in points) from the character boxes.

    The smallest vertical distance between adjacent line centers is taken as
    the line height in pixels, snapped to the nearest half pixel, and
    converted to points with the 0.333 pixel-to-point factor used throughout
    this module.  The result is rounded to the nearest quarter point.

    Bug fix: with a single line of text there are no line-center differences
    and the original ``np.min`` raised on an empty array; that case (and a
    missing ``chars_list``) now falls back to the default line height,
    corresponding to 18 pt.
    """
    y_diff = None
    if "chars_list" in trial:
        chars_df = pd.DataFrame(trial["chars_list"])
        line_diffs = np.diff(chars_df.char_y_center.unique())
        y_diffs = np.unique(line_diffs)
        if len(y_diffs) == 1:
            y_diff = y_diffs[0]
        elif len(y_diffs) > 1:
            y_diff = np.min(y_diffs)
        if y_diff is not None:
            # Snap the pixel line height to the nearest half pixel.
            y_diff = round(y_diff * 2) / 2
    if y_diff is None:
        # Default line height: an 18 pt font at the 0.333 pt/px factor.
        y_diff = 1 / 0.333 * 18
    font_size = y_diff * 0.333  # pixel to point conversion
    return round((font_size) * 4, ndigits=0) / 4
|
1585 |
+
|
1586 |
+
def get_font_and_font_size_from_trial(trial):
    """Return the trial's font face and font size.

    The size reported by ``get_plot_props`` wins; when it is None the trial
    dict's ``font_size`` entry is used, and failing that the size is
    estimated from the character boxes.
    """
    font_face, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS)

    if font_size is None:
        font_size = trial["font_size"] if "font_size" in trial else set_font_from_chars_list(trial)
    return font_face, font_size
|
1594 |
+
|
1595 |
+
|
1596 |
+
def sigmoid(x):
    """Logistic function: maps any real input (scalar or array) into (0, 1)."""
    negative_exp = np.exp(-x)
    return 1 / (1 + negative_exp)
|
1598 |
+
|
1599 |
+
|
1600 |
+
def _draw_saccade_arrows(ax, x_values, y_values, color):
    """Annotate *ax* with small arrowheads between consecutive fixations.

    Each arrow starts at fixation i and points toward fixation i+1, so the
    reading path direction is visible on the static plot.
    """
    x0, x1 = x_values[:-1], x_values[1:]
    y0, y1 = y_values[:-1], y_values[1:]
    for X, Y, dX, dY in zip(x0, y0, x1 - x0, y1 - y0):
        ax.annotate(
            "",
            xytext=(X, Y),
            # Tiny offset keeps the arrowhead at the fixation itself while
            # still encoding the direction of the following saccade.
            xy=(X + 0.001 * dX, Y + 0.001 * dY),
            arrowprops=dict(arrowstyle="fancy", color=color),
            size=8,
            alpha=0.3,
        )


def matplotlib_plot_df(
    dffix,
    trial,
    algo_choice,
    stimulus_prefix="word",
    desired_dpi=300,
    fix_to_plot=[],
    stim_info_to_plot=["Words", "Word boxes"],
    box_annotations=None,
):
    """Render the stimulus (text / word boxes) and fixation traces with matplotlib.

    Parameters
    ----------
    dffix : pandas.DataFrame
        Fixations with ``x``/``y`` columns and, per correction algorithm,
        a ``y_<algo>`` column.
    trial : dict
        Trial metadata; may contain "chars_list", "words_list",
        "display_coords", "font" and "font_size".
    algo_choice : str or list of str
        Algorithm name(s) whose corrected y values should be plotted.
    stimulus_prefix : str
        Key prefix for the stimulus list in *trial* (e.g. "word").
    desired_dpi : int
        DPI of the produced figure.
    fix_to_plot, stim_info_to_plot : list of str
        Which fixation traces / stimulus layers to draw.
    box_annotations : optional
        Passed through to ``add_text_to_ax`` for the stimulus boxes.

    Returns
    -------
    (matplotlib.figure.Figure, int, int)
        The figure plus the pixel width and height it represents.
    """
    chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None

    if chars_df is not None:
        # NOTE(review): font_size derived here is overwritten below by the
        # trial["font_size"]/20 fallback; kept because get_plot_props may
        # have side effects — confirm before removing.
        font_face, font_size = get_font_and_font_size_from_trial(trial)
        font_size = font_size * 0.65
    else:
        st.warning("No character or word information available to plot")

    # Figure canvas sized to the recording screen (display_coords are
    # inclusive, hence the +1); defaults to full HD.
    if "display_coords" in trial:
        desired_width_in_pixels = trial["display_coords"][2] + 1
        desired_height_in_pixels = trial["display_coords"][3] + 1
    else:
        desired_width_in_pixels = 1920
        desired_height_in_pixels = 1080

    figure_width = desired_width_in_pixels / desired_dpi
    figure_height = desired_height_in_pixels / desired_dpi

    fig = plt.figure(figsize=(figure_width, figure_height), dpi=desired_dpi)
    ax = fig.add_subplot(1, 1, 1)
    # Remove all margins so figure pixels map 1:1 onto screen pixels.
    fig.subplots_adjust(bottom=0, top=1, right=1, left=0)

    if "font" in trial and trial["font"] in AVAILABLE_FONTS:
        font_to_use = trial["font"]
    else:
        font_to_use = "DejaVu Sans Mono"
    font_size = trial["font_size"] if "font_size" in trial else 20

    # Word-level boxes (no text — the text comes from the character layer).
    if f"{stimulus_prefix}s_list" in trial:
        add_text_to_ax(
            trial[f"{stimulus_prefix}s_list"],
            ax,
            font_to_use,
            prefix=stimulus_prefix,
            fontsize=font_size / 3.89,
            plot_text=False,
            plot_boxes="Word boxes" in stim_info_to_plot,
            box_annotations=box_annotations,
        )

    # Character-level text (no boxes).
    if "chars_list" in trial:
        add_text_to_ax(
            trial["chars_list"],
            ax,
            font_to_use,
            prefix="char",
            fontsize=font_size / 3.89,
            plot_text="Words" in stim_info_to_plot,
            plot_boxes=False,
            box_annotations=None,
        )

    if "Uncorrected Fixations" in fix_to_plot:
        ax.plot(dffix.x, dffix.y, label="Raw fixations", color="blue", alpha=0.6, linewidth=0.6)
        _draw_saccade_arrows(ax, dffix.x.values, dffix.y.values, "blue")

    if "Corrected Fixations" in fix_to_plot:
        algo_choices = algo_choice if isinstance(algo_choice, list) else [algo_choice]
        for algoIdx, algo in enumerate(algo_choices):
            if f"y_{algo}" in dffix.columns:
                ax.plot(
                    dffix.x,
                    dffix.loc[:, f"y_{algo}"],
                    # BUGFIX: was label="Raw fixations", which mislabeled the
                    # corrected trace in the legend.
                    label=f"{algo} corrected",
                    color=COLORS[algoIdx],
                    alpha=0.6,
                    linewidth=0.6,
                )
                _draw_saccade_arrows(ax, dffix.x.values, dffix.loc[:, f"y_{algo}"].values, COLORS[algoIdx])

    ax.set_xlim((0, desired_width_in_pixels))
    ax.set_ylim((0, desired_height_in_pixels))
    ax.invert_yaxis()  # screen coordinates: y grows downwards

    return fig, desired_width_in_pixels, desired_height_in_pixels
|
1730 |
+
|
1731 |
+
def plotly_plot_with_image(
    dffix,
    trial,
    algo_choice,
    to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"],
    scale_factor=0.5,
):
    """Build an interactive plotly figure with the stimulus as a background image.

    The stimulus (words/boxes) is first rendered with matplotlib, saved to
    TEMP_FIGURE_STIMULUS_PATH, and embedded as a layout image; fixation
    traces are then drawn on top as plotly scatter traces.

    Parameters
    ----------
    dffix : pandas.DataFrame
        Fixations with x/y (and optionally y_<algo>) columns plus a
        ``duration`` column used to size the raw-fixation markers.
    trial : dict
        Trial metadata passed through to matplotlib_plot_df.
    algo_choice : str or list of str
        Correction algorithm name(s) to overlay.
    to_plot_list : list of str
        Which layers to draw ("Words", "Word boxes", "Uncorrected
        Fixations", "Corrected Fixations").
    scale_factor : float
        Uniform scaling from stimulus pixels to plot pixels.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Render only the stimulus layers (fix_to_plot=[]) to an image file.
    fig, img_width, img_height = matplotlib_plot_df(
        dffix, trial, algo_choice, desired_dpi=300, fix_to_plot=[], stim_info_to_plot=to_plot_list
    )
    fig.savefig(TEMP_FIGURE_STIMULUS_PATH)
    fig = go.Figure()
    # Invisible two-point trace that pins the axis ranges so the background
    # image keeps its aspect ratio; hidden from the legend at the bottom.
    fig.add_trace(
        go.Scatter(
            x=[0, img_width * scale_factor],
            y=[img_height * scale_factor, 0],
            mode="markers",
            marker_opacity=0,
            name="scale_helper",
        )
    )

    fig.update_xaxes(visible=False, range=[0, img_width * scale_factor])

    # y range is reversed (top-left origin) to match screen coordinates.
    fig.update_yaxes(
        visible=False,
        range=[img_height * scale_factor, 0],
        scaleanchor="x",
    )
    if "Words" in to_plot_list or "Word boxes" in to_plot_list:
        imsource = Image.open(str(TEMP_FIGURE_STIMULUS_PATH))
        fig.add_layout_image(
            dict(
                x=0,
                sizex=img_width * scale_factor,
                y=0,
                sizey=img_height * scale_factor,
                xref="x",
                yref="y",
                opacity=1.0,
                layer="below",
                sizing="stretch",
                source=imsource,
            )
        )

    if "Uncorrected Fixations" in to_plot_list:
        # Marker size encodes fixation duration: min-shifted, scaled to
        # roughly [-1.5, 1.5], squashed through a sigmoid, then scaled to px.
        duration_scaled = dffix.duration - dffix.duration.min()
        duration_scaled = ((duration_scaled / duration_scaled.max()) - 0.5) * 3
        duration = sigmoid(duration_scaled) * 50 * scale_factor
        fig.add_trace(
            go.Scatter(
                x=dffix.x * scale_factor,
                y=dffix.y * scale_factor,
                mode="markers+lines+text",
                name="Raw fixations",
                marker=dict(
                    color=COLORS[-1],
                    symbol="arrow",
                    size=duration.values,
                    # Orient each arrow marker along the incoming saccade.
                    angleref="previous",
                    line=dict(color="black", width=duration.values / 10),
                ),
                line_width=2 * scale_factor,
                # Label each fixation with its index.
                text=np.arange(len(dffix.x)),
                textposition="middle right",
                textfont=dict(
                    family="sans serif",
                    size=18 * scale_factor,
                ),
                hoverinfo="text+x+y",
                opacity=0.9,
            )
        )

    if "Corrected Fixations" in to_plot_list:
        # Normalise algo_choice to a list so one or many algorithms are
        # handled by the same loop.
        if isinstance(algo_choice, list):
            algo_choices = algo_choice
            repeats = range(len(algo_choice))
        else:
            algo_choices = [algo_choice]
            repeats = range(1)
        for algoIdx in repeats:
            algo_choice = algo_choices[algoIdx]
            # Skip algorithms whose corrected column is not present.
            if f"y_{algo_choice}" in dffix.columns:
                fig.add_trace(
                    go.Scatter(
                        x=dffix.x * scale_factor,
                        y=dffix.loc[:, f"y_{algo_choice}"] * scale_factor,
                        mode="markers",
                        name=f"{algo_choice} corrected",
                        marker_color=COLORS[algoIdx],
                        marker_size=10 * scale_factor,
                        hoverinfo="text+x+y",
                        opacity=0.75,
                    )
                )

    fig.update_layout(
        plot_bgcolor=None,
        width=img_width * scale_factor,
        height=img_height * scale_factor,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
    )

    # Hide the axis-pinning helper trace from the legend.
    for trace in fig["data"]:
        if trace["name"] == "scale_helper":
            trace["showlegend"] = False
    return fig
|
1842 |
+
|
1843 |
+
def plot_y_corr(dffix, algo_choice, margin=dict(t=40, l=10, r=10, b=1)):
    """Scatter-plot the per-fixation y correction applied by each algorithm.

    Parameters
    ----------
    dffix : pandas.DataFrame or dict
        Fixation dataframe, or a ``{"value": DataFrame}`` wrapper.
    algo_choice : str or list of str
        Algorithm name(s); each needs a ``y_<algo>_correction`` column.
    margin : dict
        Plotly figure margin.

    Returns
    -------
    plotly.graph_objects.Figure
        Empty (layout-only) figure when no correction column exists.
    """
    # BUGFIX: unwrap the {"value": df} wrapper *before* touching attributes.
    # The original computed len(dffix.x) first, which raised AttributeError
    # for dict input (and then repeated this unwrap twice).
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    num_datapoints = len(dffix.x)

    layout = dict(
        plot_bgcolor="white",
        autosize=True,
        margin=margin,
        xaxis=dict(
            title="Fixation Index",
            linecolor="black",
            range=[-1, num_datapoints + 1],
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        yaxis=dict(
            title="y correction",
            side="left",
            linecolor="black",
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        legend=dict(orientation="v", yanchor="middle", y=0.95, xanchor="left", x=1.05),
    )

    # Probe the first algorithm's column; if missing, return an empty figure.
    algo_string = algo_choice[0] if isinstance(algo_choice, list) else algo_choice
    if f"y_{algo_string}_correction" not in dffix.columns:
        st.session_state["logger"].warning("No correction column found in dataframe")
        return go.Figure(layout=layout)

    fig = go.Figure(layout=layout)

    algo_choices = algo_choice if isinstance(algo_choice, list) else [algo_choice]
    for algoIdx, algo in enumerate(algo_choices):
        fig.add_trace(
            go.Scatter(
                x=np.arange(num_datapoints),
                y=dffix.loc[:, f"y_{algo}_correction"],
                mode="markers",
                name=f"{algo} y correction",
                marker_color=COLORS[algoIdx],
                marker_size=3,
                showlegend=True,
            )
        )
    # Zero line marks fixations that needed no correction.
    fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor="black")

    return fig
|
1902 |
+
|
1903 |
+
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH):
    """Ensure the example .asc files are present locally and return them.

    Creates the examples folder, downloads the zip archive from OSF if it is
    missing, extracts it when fewer than the expected 4 .asc files exist,
    and returns the .asc files found afterwards.

    Parameters
    ----------
    EXAMPLES_FOLDER : str
        Folder the archive is extracted into.
    EXAMPLES_ASC_ZIP_FILENAME : str
        Local path of the downloaded zip archive.
    OSF_DOWNLAOD_LINK : str
        Download URL (parameter name kept for caller compatibility).
    EXAMPLES_FOLDER_PATH : pathlib.Path
        Path object for EXAMPLES_FOLDER used for globbing.

    Returns
    -------
    list of pathlib.Path
        The .asc files found; empty when download/extraction failed.
    """
    # BUGFIX: initialise the result so the final return cannot raise
    # NameError when the zip or the folder is missing.
    example_asc_files = []

    if not os.path.isdir(EXAMPLES_FOLDER):
        os.mkdir(EXAMPLES_FOLDER)

    if not os.path.exists(EXAMPLES_ASC_ZIP_FILENAME):
        download_url(OSF_DOWNLAOD_LINK, EXAMPLES_ASC_ZIP_FILENAME)

    if os.path.exists(EXAMPLES_ASC_ZIP_FILENAME) and EXAMPLES_FOLDER_PATH.exists():
        example_asc_files = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
        if len(example_asc_files) != 4:  # expected number of example recordings
            try:
                with zipfile.ZipFile(EXAMPLES_ASC_ZIP_FILENAME, "r") as zip_ref:
                    zip_ref.extractall(EXAMPLES_FOLDER)
            except Exception as e:
                # Best-effort: log and fall through with whatever is on disk.
                st.session_state["logger"].warning(e)
                st.session_state["logger"].warning(f"Extracting {EXAMPLES_ASC_ZIP_FILENAME} failed")

        example_asc_files = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
    return example_asc_files
|
1925 |
+
|
1926 |
+
def process_trial_choice_single_csv(trial, algo_choice, file=None):
    """Prepare a single-CSV trial for plotting and optionally apply correction.

    Mutates *trial* in place (adds plot_file, fname, dffix, chars_df and
    y_char_unique as needed) and returns the fixation dataframe together
    with the plot properties required downstream.
    """
    trial_id = trial["trial_id"]

    if "dffix" not in trial:
        # Fixations not cached on the trial yet: pull them from session
        # state and record the derived plot/file metadata.
        if file is None:
            file = st.session_state["single_csv_file"]
        trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{file.name}_{trial_id}_2ndInput_chars_channel_sep.png"))
        trial["fname"] = str(file.name)
        trial["dffix"] = st.session_state["trials_by_ids_single_csv"][trial_id]["dffix"]
    dffix = trial["dffix"]

    font, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS)

    characters = pd.DataFrame(trial["chars_list"])
    trial["chars_df"] = characters.to_dict()
    trial["y_char_unique"] = list(characters.char_y_center.sort_values().unique())

    if algo_choice is not None:
        dffix, _ = correct_df(dffix, algo_choice, trial)
    return dffix, trial, dpi, screen_res, font, font_size
|
1945 |
+
|
1946 |
+
def add_default_font_and_character_props_to_state(trial):
    """Derive layout defaults (line spacing, text origin, font) from a trial.

    Parameters
    ----------
    trial : dict
        Must contain "chars_list": per-character dicts with at least
        "char_xmin" and "char_y_center".

    Returns
    -------
    tuple
        (y_diff, x_txt_start, y_txt_start, font_face, font_size,
        line_height) where line_height equals y_diff, the half-pixel-snapped
        vertical line spacing.
    """
    chars_list = trial["chars_list"]
    chars_df = pd.DataFrame(chars_list)
    line_diffs = np.diff(chars_df.char_y_center.unique())
    y_diffs = np.unique(line_diffs)
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    elif len(y_diffs) == 0:
        # Single line of text: no spacing to measure. The original called
        # np.min on an empty array (ValueError); use the default line height
        # corresponding to an 18 pt font instead.
        y_diff = 1 / 0.333 * 18
    else:
        y_diff = np.min(y_diffs)
    y_diff = round(y_diff * 2) / 2  # snap to half a pixel

    # Text origin: position of the first character.
    x_txt_start = chars_list[0]["char_xmin"]
    y_txt_start = chars_list[0]["char_y_center"]

    font_face, font_size = get_font_and_font_size_from_trial(trial)

    line_height = y_diff
    return y_diff, x_txt_start, y_txt_start, font_face, font_size, line_height
|
1964 |
+
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
    """Compute all per-interest-area reading measures for one trial.

    Parameters
    ----------
    trial : dict
        Trial metadata passed to every measure function.
    dffix : pandas.DataFrame
        Fixations; must contain ``y_<correction_algo>`` when
        *use_corrected_fixations* is True.
    prefix : str
        Interest-area prefix (e.g. "word"); measure dataframes are indexed
        by ``<prefix>_index`` and carry a ``<prefix>`` label column.
    use_corrected_fixations : bool
        When True, work on a deep copy whose ``y`` is replaced by the
        corrected ``y_<correction_algo>`` values.
    correction_algo : str
        Name of the correction algorithm providing the corrected y column.

    Returns
    -------
    pandas.DataFrame
        One row per interest area: ``<prefix>_num``, ``<prefix>`` and one
        column per measure.
    """
    if use_corrected_fixations:
        fixations = copy.deepcopy(dffix)
        fixations["y"] = fixations[f"y_{correction_algo}"]
    else:
        fixations = dffix

    # Column order of the final dataframe follows this tuple.
    # NOTE(review): the measure functions are assumed to be side-effect
    # free, so evaluation order does not matter — confirm in analysis_funcs.
    measure_funcs = (
        anf.number_of_fixations_own,
        anf.initial_fixation_duration_own,
        anf.first_of_many_duration_own,
        anf.total_fixation_duration_own,
        anf.gaze_duration_own,
        anf.go_past_duration_own,
        anf.second_pass_duration_own,
        anf.initial_landing_position_own,
        anf.initial_landing_distance_own,
        anf.landing_distances_own,
        anf.number_of_regressions_in_own,
    )
    index_col = f"{prefix}_index"
    measure_dfs = [func(trial, fixations, prefix).set_index(index_col) for func in measure_funcs]

    # Each measure df repeats the interest-area label column; drop it before
    # concatenation and re-attach a single copy afterwards.
    own_measure_df = pd.concat([frame.drop(prefix, axis=1) for frame in measure_dfs], axis=1)
    own_measure_df[prefix] = measure_dfs[0][prefix]
    first_column = own_measure_df.pop(prefix)
    own_measure_df.insert(0, prefix, first_column)
    own_measure_df.insert(0, f"{prefix}_num", np.arange(own_measure_df.shape[0]))
    return own_measure_df
|