import copy |
from PIL import Image |
from io import StringIO |
import streamlit as st |
import pandas as pd |
import numpy as np |
import re |
import time |
import os |
from matplotlib.font_manager import FontProperties |
from matplotlib.patches import Rectangle |
from matplotlib import pyplot as plt |
import plotly.graph_objects as go |
import plotly.express as px |
import numpy as np |
import pandas as pd |
import pathlib as pl |
import json |
import logging |
import zipfile |
from stqdm import stqdm |
import jellyfish as jf |
import lovely_tensors |
import shutil |
import eyekit_measures as ekm |
import zipfile |
import utils as ut |
# Module-level configuration executed on every Streamlit script run.
# MPLCONFIGDIR is set to "<current working directory>/configs/".
os.environ["MPLCONFIGDIR"] = os.getcwd() + "/configs/" |
st.set_page_config("Correction", page_icon=":eye:", layout="wide") |
DEFAULT_PLOT_FONT = "DejaVu Sans Mono" |
# Location and archive name used for the example .asc files.
EXAMPLES_FOLDER = "./testfiles/" |
EXAMPLES_ASC_ZIP_FILENAME = "asc_files.zip" |
# NOTE(review): the constant name is spelled "DOWNLAOD"; other parts of the
# file may reference it under this spelling, so it is left as is.
OSF_DOWNLAOD_LINK = "https://osf.io/download/us97f/" |
lovely_tensors.monkey_patch() |
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    """Forward the three folder arguments, in order, to ``ut.make_folders``.

    Returns whatever ``ut.make_folders`` returns.
    """
    folders = (gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots)
    return ut.make_folders(*folders)
# Working folders: each path is bound to a module-level name and mirrored
# into st.session_state under the key shown, then handed to make_folders().
TEMP_FOLDER = st.session_state["TEMP_FOLDER"] = ut.TEMP_FOLDER |
gradio_temp_unzipped_folder = st.session_state["gradio_temp_unzipped_folder"] = pl.Path("unzipped") |
PLOTS_FOLDER = st.session_state["PLOTS_FOLDER"] = pl.Path("plots") |
TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER.joinpath("temp_matplotlib_plot_stimulus.png") |
make_folders(TEMP_FOLDER, gradio_temp_unzipped_folder, PLOTS_FOLDER) |
@st.cache_data
def get_classic_cfg(fname):
    """Return ``ut.get_classic_cfg(fname)``; the call is wrapped in ``st.cache_data``."""
    cfg = ut.get_classic_cfg(fname)
    return cfg
# Shared configuration: each value is bound to a module-level name and also
# stored in st.session_state under the same key.
classic_algos_cfg = st.session_state["classic_algos_cfg"] = get_classic_cfg("algo_cfgs_all.json") |
DIST_MODELS_FOLDER = st.session_state["DIST_MODELS_FOLDER"] = pl.Path("models") |
# Colour list; plotly_df indexes into it per algorithm and uses [-1] for boxes.
COLORS = st.session_state["COLORS"] = px.colors.qualitative.Alphabet |
# Names of the correction algorithms offered by the app.
ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [ |
"warp", |
"regress", |
"compare", |
"attach", |
"segment", |
"split", |
"stretch", |
"chain", |
"slice", |
"cluster", |
"merge", |
"Wisdom_of_Crowds", |
"DIST", |
"DIST-Ensemble", |
"Wisdom_of_Crowds_with_DIST", |
"Wisdom_of_Crowds_with_DIST_Ensemble", |
] |
# Default column names for custom fixation ("_fix") and stimulus ("_stim")
# files. guess_col_names_fix iterates over this mapping; other functions read
# same-named keys directly from st.session_state.
st.session_state["colnames_custom_csv_fix"] = { |
"x_col_name_fix": "x", |
"y_col_name_fix": "y", |
"x_col_name_fix_stim": "char_x_center", |
"x_start_col_name_fix_stim": "char_xmin", |
"x_end_col_name_fix_stim": "char_xmax", |
"y_col_name_fix_stim": "char_y_center", |
"y_start_col_name_fix_stim": "char_ymin", |
"y_end_col_name_fix_stim": "char_ymax", |
"char_col_name_fix_stim": "char", |
"trial_id_col_name_fix": "trial_id", |
"trial_id_col_name_stim": "trial_id", |
"subject_col_name_fix": "subid", |
"subject_col_name_stim": "subid", |
"line_num_col_name_stim": "assigned_line", |
"time_start_col_name_fix": "start", |
"time_stop_col_name_fix": "stop", |
} |
# Results are keyed by "<asc file stem>_<trial id>" (see process_multiple_asc);
# the dict is only created when it does not exist yet.
if "results" not in st.session_state: |
st.session_state["results"] = {} |
@st.cache_resource
def load_model(model_file, cfg):
    """Return ``ut.load_model(model_file, cfg)``; wrapped in ``st.cache_resource``."""
    loaded = ut.load_model(model_file, cfg)
    return loaded
@st.cache_resource
def find_and_load_model(model_date="20240104-223349"):
    """Return ``ut.find_and_load_model(model_date)``; wrapped in ``st.cache_resource``.

    Callers in this file unpack the result as ``(model, model_cfg)``.
    """
    found = ut.find_and_load_model(model_date)
    return found
def create_logger(name, level="DEBUG", file=None):
    """Return the logger called ``name``, attaching handlers at most once.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level set on the logger (name or numeric value).
        file: Optional path of a log file. When given and the logger has no
            ``FileHandler`` yet, one is added that opens the file in ``"w"`` mode.

    Returns:
        The configured ``logging.Logger`` (propagation disabled).
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(level)
    formatter = logging.Formatter(
        "%(asctime)s.%(msecs)03d-%(name)s-p%(process)s-{%(pathname)s:%(lineno)d}-%(levelname)s >>> %(message)s",
        "%m-%d %H:%M:%S",
    )

    # FileHandler subclasses StreamHandler, so it must be excluded explicitly;
    # otherwise a logger that only has a file handler never gets console output.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    if file is not None:
        has_file_handler = any(isinstance(handler, logging.FileHandler) for handler in logger.handlers)
        if not has_file_handler:
            file_handler = logging.FileHandler(file, "w")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    logger.debug("Logger added")
    return logger
# Create the "app" logger only when the session does not hold one yet; it logs
# at DEBUG level and is given the file "log_for_app.log".
if "logger" not in st.session_state: |
st.session_state["logger"] = create_logger(name="app", level="DEBUG", file="log_for_app.log") |
# NOTE(review): the decorator below is followed by an assignment, not a `def`.
# The definition of download_example_ascs is not visible in this part of the
# file - confirm whether a function definition was lost here.
@st.cache_data |
EXAMPLE_ASC_FILES = download_example_ascs( |
) |
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    """Return ``ut.asc_to_trial_ids(asc_file, close_gap_between_words)``.

    Callers in this file unpack the result as ``(trials_by_ids, lines)``.
    """
    parsed = ut.asc_to_trial_ids(asc_file, close_gap_between_words)
    return parsed
@st.cache_data
def get_trials_list(asc_file=None, close_gap_between_words=True):
    """Return ``ut.get_trials_list(asc_file, close_gap_between_words)``; cached via ``st.cache_data``."""
    trials_info = ut.get_trials_list(asc_file, close_gap_between_words)
    return trials_info
@st.cache_data
def prep_data_for_dist(model_cfg, dffix, trial=None):
    """Return ``ut.prep_data_for_dist(model_cfg, dffix, trial)``; cached via ``st.cache_data``."""
    prepared = ut.prep_data_for_dist(model_cfg, dffix, trial)
    return prepared
def save_trial_to_json(trial, savename):
    """Return ``ut.save_trial_to_json(trial, savename)``.

    A later definition with the same name in this module rebinds the name,
    so this version is only in effect until that definition is executed.
    """
    outcome = ut.save_trial_to_json(trial, savename)
    return outcome
def export_csv(dffix, trial):
    """Return ``ut.export_csv(dffix, trial)``."""
    exported = ut.export_csv(dffix, trial)
    return exported
@st.cache_data
def get_DIST_preds(dffix, trial):
    """Return ``ut.get_DIST_preds(dffix, trial)``; cached via ``st.cache_data``."""
    preds = ut.get_DIST_preds(dffix, trial)
    return preds
@st.cache_data
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None):
    """Return ``ut.get_EDIST_preds_with_model_check(...)`` for the given arguments.

    All three arguments are forwarded positionally; cached via ``st.cache_data``.
    """
    preds = ut.get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg)
    return preds
def get_all_classic_preds(dffix, trial):
    """Return ``ut.get_all_classic_preds(dffix, trial)``."""
    preds = ut.get_all_classic_preds(dffix, trial)
    return preds
def apply_woc(dffix, trial, corrections, algo_choice):
    """Return ``ut.apply_woc(dffix, trial, corrections, algo_choice)``."""
    forwarded = (dffix, trial, corrections, algo_choice)
    return ut.apply_woc(*forwarded)
@st.cache_data
def correct_df(dffix, algo_choice, trial=None, for_multi=False, ensemble_model_avg=None):
    """Return ``ut.correct_df(...)`` with all five arguments forwarded positionally.

    The call is cached via ``st.cache_data``.
    """
    forwarded = (dffix, algo_choice, trial, for_multi, ensemble_model_avg)
    return ut.correct_df(*forwarded)
@st.cache_data
def get_font_and_font_size_from_trial(trial):
    """Return ``ut.get_font_and_font_size_from_trial(trial)``; cached via ``st.cache_data``.

    ``plotly_df`` in this file unpacks the result as ``(font_face, font_size)``.
    """
    font_info = ut.get_font_and_font_size_from_trial(trial)
    return font_info
@st.cache_data
def add_default_font_and_character_props_to_state(trial):
    """Return ``ut.add_default_font_and_character_props_to_state(trial)``.

    The call is cached via ``st.cache_data``; ``process_trial_choice`` unpacks
    the result into six values.
    """
    props = ut.add_default_font_and_character_props_to_state(trial)
    return props
@st.cache_data
def get_plot_props(trial, available_fonts):
    """Return ``ut.get_plot_props(trial, available_fonts)``; cached via ``st.cache_data``."""
    plot_props = ut.get_plot_props(trial, available_fonts)
    return plot_props
# Prepare one trial selected from st.session_state["trials_by_ids"] and return
# (dffix, trial, dpi, screen_res, font, font_size).
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of the statements below has to be confirmed against the original source.
def process_trial_choice(trial_id, algo_choice): |
# trial_id may arrive as a dict; in that case the id is taken from its "value" key.
if isinstance(trial_id, dict): |
trial_id = trial_id["value"] |
trials_by_ids = st.session_state["trials_by_ids"] |
trial = trials_by_ids[trial_id] |
# With character information available, font/layout values are copied into
# session state under "*_for_eyekit" keys.
if "chars_list" in trial: |
( |
y_diff, |
x_txt_start, |
y_txt_start, |
font_face, |
_, |
line_height, |
) = add_default_font_and_character_props_to_state(trial) |
font_size = ut.set_font_from_chars_list(trial) |
st.session_state["y_diff_for_eyekit"] = y_diff |
st.session_state["x_txt_start_for_eyekit"] = x_txt_start |
st.session_state["y_txt_start_for_eyekit"] = y_txt_start |
st.session_state["font_face_for_eyekit"] = font_face |
st.session_state["font_size_for_eyekit"] = font_size |
st.session_state["line_height_for_eyekit"] = line_height |
# Reuse an existing fixation dataframe, otherwise build it from the .asc lines.
if "dffix" in trial: |
dffix = trial["dffix"] |
else: |
asc_file = pl.Path(st.session_state["asc_file"].name) |
trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{asc_file.stem}_{trial_id}_2ndInput_chars_channel_sep.png")) |
trial["fname"] = str(asc_file.name).split(".")[0] |
df, dffix, trial = ut.trial_to_dfs(trial, st.session_state["lines"], use_synctime=True) |
st.session_state["logger"].info(f"dffix.columns after trial_to_dfs {dffix.columns}") |
# font, font_size, dpi and screen_res in the return value come from this call;
# font_size assigned earlier is rebound here. AVAILABLE_FONTS is not defined in
# the visible part of this file.
font, font_size, dpi, screen_res = ut.get_plot_props(trial, AVAILABLE_FONTS) |
st.session_state["trial"] = trial |
if "chars_list" in trial: |
chars_df = pd.DataFrame(trial["chars_list"]) |
trial["chars_df"] = chars_df.to_dict() |
trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique()) |
# Correction is only attempted when an algorithm is chosen and the trial has
# character or word information; otherwise a warning is shown in the UI.
if algo_choice is not None and ("chars_list" in trial or "words_list" in trial): |
dffix, _ = correct_df(dffix, algo_choice, trial) |
else: |
st.warning("🚨 Stimulus information needed for fixation correction 🚨") |
return dffix, trial, dpi, screen_res, font, font_size |
@st.cache_data
def process_trial_choice_single_csv(trial, algo_choice, file=None):
    """Return ``ut.process_trial_choice_single_csv(trial, algo_choice, file=file)``.

    The call is cached via ``st.cache_data``.
    """
    processed = ut.process_trial_choice_single_csv(trial, algo_choice, file=file)
    return processed
def quick_dffix_save(dffix, savename):
    """Write ``dffix`` to ``savename`` as CSV and log the destination.

    The message is sent at INFO level to the logger stored in
    ``st.session_state["logger"]``.
    """
    dffix.to_csv(savename)
    app_logger = st.session_state["logger"]
    app_logger.info(f"Saved processed data as {savename}")
def save_trial_to_json(trial, savename):
    """Serialize ``trial`` to ``savename`` as indented UTF-8 JSON.

    The ``"dffix"`` entry, if present, is removed from ``trial`` itself before
    dumping, so the caller's dict is modified. Values are encoded with
    ``ut.NumpyEncoder``.
    """
    trial.pop("dffix", None)
    with open(savename, "w", encoding="utf-8") as json_file:
        json.dump(trial, json_file, ensure_ascii=False, indent=4, cls=ut.NumpyEncoder)
@st.cache_data
def process_trial(trial, asc_file_stem, lines, algo_choice, for_multi=False):
    """Build the fixation dataframe for one trial and optionally correct it.

    Args:
        trial: Trial dict; must contain ``"trial_id"``. It is updated with plot
            file name, file stem, font, font size, dpi and screen resolution.
        asc_file_stem: Stem of the .asc file, used for naming outputs.
        lines: Passed through to ``ut.trial_to_dfs``.
        algo_choice: Correction algorithm(s); ``None`` skips correction.
        for_multi: Forwarded to ``correct_df``.

    Returns:
        ``(dffix, trial)``. When the fixation dataframe is empty an empty
        ``pd.DataFrame`` is returned in its place.
    """
    trial_id = trial["trial_id"]
    plot_name = f"{asc_file_stem}_{trial_id}_2ndInput_chars_channel_sep.png"
    trial["plot_file"] = str(PLOTS_FOLDER.joinpath(plot_name))
    trial["fname"] = str(asc_file_stem)

    font, font_size, dpi, screen_res = ut.get_plot_props(trial, AVAILABLE_FONTS)
    trial["font"] = font
    trial["font_size"] = font_size
    trial["dpi"] = dpi
    trial["screen_res"] = screen_res

    _, dffix, trial = ut.trial_to_dfs(trial, lines, use_synctime=True)
    if dffix.empty:
        return pd.DataFrame(), trial

    chars_df = pd.DataFrame(trial["chars_list"])
    # Computed once (the value used to be calculated and assigned twice);
    # assignment order is kept so the key order of the trial dict is unchanged.
    trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
    trial["chars_df"] = chars_df.to_dict()

    if algo_choice is not None:
        dffix = correct_df(dffix, algo_choice, trial, for_multi)
    return dffix, trial
def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
    """Return ``ut.add_text_to_ax(chars_list, ax, ...)``.

    ``chars_list`` and ``ax`` are passed positionally; every other argument is
    forwarded as a keyword argument of the same name.
    """
    options = dict(
        font_to_use=font_to_use,
        fontsize=fontsize,
        prefix=prefix,
        plot_boxes=plot_boxes,
        plot_text=plot_text,
        box_annotations=box_annotations,
    )
    return ut.add_text_to_ax(chars_list, ax, **options)
@st.cache_data
def matplotlib_plot_df(
    dffix,
    trial,
    algo_choice,
    stimulus_prefix="word",
    desired_dpi=300,
    fix_to_plot=[],
    stim_info_to_plot=["Words", "Word boxes"],
    box_annotations=None,
):
    """Return ``ut.matplotlib_plot_df(dffix, trial, algo_choice, ...)``.

    The first three arguments are passed positionally, the remaining ones as
    keyword arguments of the same name. Cached via ``st.cache_data``.
    """
    options = dict(
        stimulus_prefix=stimulus_prefix,
        desired_dpi=desired_dpi,
        fix_to_plot=fix_to_plot,
        stim_info_to_plot=stim_info_to_plot,
        box_annotations=box_annotations,
    )
    return ut.matplotlib_plot_df(dffix, trial, algo_choice, **options)
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Evaluated with ``np.exp``, so array inputs are handled element-wise.
    """
    denominator = 1 + np.exp(-1 * x)
    return 1 / denominator
@st.cache_data
def plotly_plot_with_image(
    dffix,
    trial,
    algo_choice,
    to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"],
    scale_factor=0.5,
):
    """Return ``ut.plotly_plot_with_image(dffix, trial, algo_choice, ...)``.

    ``to_plot_list`` and ``scale_factor`` are forwarded as keyword arguments.
    Cached via ``st.cache_data``.
    """
    options = dict(to_plot_list=to_plot_list, scale_factor=scale_factor)
    return ut.plotly_plot_with_image(dffix, trial, algo_choice, **options)
@st.cache_data
def plot_y_corr(dffix, algo_choice):
    """Return ``ut.plot_y_corr(dffix, algo_choice)``; cached via ``st.cache_data``."""
    figure = ut.plot_y_corr(dffix, algo_choice)
    return figure
def plotly_df(
    dffix=None, trial=None, algo_choice=None, to_plot_list=["fixations", "characters", "corrected fixations"], title=""
):
    """Build a plotly figure of fixations and stimulus characters.

    Args:
        dffix: Fixation dataframe; defaults to ``st.session_state["dffix"]``.
        trial: Trial dict. When ``None`` the character data is taken from
            ``st.session_state["results"][title]`` if ``title`` is a key
            there, else from ``st.session_state["trial"]["chars_df"]``.
        algo_choice: One algorithm name or a list of names; defaults to
            ``st.session_state["algo_choice"]``.
        to_plot_list: Layers to draw. The layer names checked are
            ``"Uncorrected Fixations"``, ``"Corrected Fixations"``,
            ``"Characters"`` and ``"Character boxes (slow to plot)"``.
        title: Key into ``st.session_state["results"]`` used when ``trial`` is ``None``.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    state = st.session_state
    if dffix is None:
        dffix = state["dffix"]
    if algo_choice is None:
        algo_choice = state["algo_choice"]
    state["logger"].info(f"Plotting {to_plot_list}")
    point_labels = dffix.index

    # Resolve the character table for the stimulus.
    if trial is not None:
        chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None
    elif title in state["results"]:
        chars_df = pd.DataFrame(state["results"][title]["trial"]["chars_list"])
    else:
        chars_df = pd.DataFrame(state["trial"]["chars_df"])

    # Axis extents come from the characters when available, else the fixations.
    if chars_df is None:
        st.warning("No character or word information available to plot")
        xmin, xmax = dffix.x.min(), dffix.x.max()
        ymin, ymax = dffix.y.min(), dffix.y.max()
    else:
        font_face, font_size = get_font_and_font_size_from_trial(trial)
        font_size = font_size * 0.65
        xmin = chars_df.char_x_center.min()
        xmax = chars_df.char_x_center.max()
        ymin = chars_df.char_y_center.min()
        ymax = chars_df.char_y_center.max()

    x_axis = dict(
        title="x-coordinate",
        linecolor="black",
        range=[xmin - 100, xmax + 100],
        showgrid=False,
        mirror="all",
        showline=True,
    )
    # The y range is given as [max, min], i.e. in reversed order.
    y_axis = dict(
        title="y-coordinate",
        range=[ymax + 100, ymin - 100],
        linecolor="black",
        showgrid=False,
        mirror="all",
        showline=True,
    )
    layout = dict(
        plot_bgcolor="white",
        autosize=True,
        margin=dict(t=1, l=10, r=10, b=1),
        xaxis=x_axis,
        yaxis=y_axis,
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
    )
    fig = go.Figure(layout=layout)

    if "Uncorrected Fixations" in to_plot_list:
        # Marker size scales with fixation duration relative to the median.
        duration_scaled = dffix.duration - dffix.duration.min()
        duration = ((duration_scaled + 0.1) / duration_scaled.median()) * 5
        raw_trace = go.Scatter(
            x=dffix.x,
            y=dffix.y,
            mode="markers+lines+text",
            name="Raw fixations",
            marker=dict(
                symbol="arrow",
                size=duration.values,
                angleref="previous",
            ),
            line_width=1.2,
            text=point_labels,
            textposition="middle right",
            textfont=dict(
                family="sans serif",
                size=9,
            ),
            hoverinfo="text+x+y",
            opacity=0.6,
        )
        fig.add_trace(raw_trace)

    if "Corrected Fixations" in to_plot_list:
        algo_names = algo_choice if isinstance(algo_choice, list) else [algo_choice]
        for color_idx, algo_name in enumerate(algo_names):
            y_column = f"y_{algo_name}"
            if y_column not in dffix.columns:
                continue
            corrected_trace = go.Scatter(
                x=dffix.x,
                y=dffix.loc[:, y_column],
                mode="markers",
                name=f"{algo_name} corrected",
                marker_color=state["COLORS"][color_idx],
                marker_size=5,
                hoverinfo="text+x+y",
                opacity=0.75,
            )
            fig.add_trace(corrected_trace)

    if "Characters" in to_plot_list and chars_df is not None:
        char_trace = go.Scatter(
            x=chars_df.char_x_center,
            y=chars_df.char_y_center,
            mode="markers+text",
            name="",
            showlegend=False,
            text=chars_df.char,
            textposition="middle center",
            marker=dict(color="black", size=0.1),
            textfont=dict(family=font_face, size=font_size, color="Black"),
        )
        fig.add_trace(char_trace)

    if "Character boxes (slow to plot)" in to_plot_list and chars_df is not None:
        for _, char_row in stqdm(chars_df.iterrows(), "Adding boxes"):
            fig.add_shape(
                type="rect",
                x0=char_row.char_xmin,
                y0=char_row.char_ymin,
                x1=char_row.char_xmax,
                y1=char_row.char_ymax,
                line=dict(color=state["COLORS"][-1], width=1),
            )
    return fig
def save_to_zips(folder, pattern, savename):
    """Add every file in ``folder`` matching ``pattern`` to a zip in ``TEMP_FOLDER``.

    Args:
        folder: ``pathlib.Path`` whose ``glob(pattern)`` yields the files to add.
        pattern: Glob pattern.
        savename: File name of the archive inside ``TEMP_FOLDER``. An existing
            archive is appended to; otherwise a new one is created.

    The archive is opened once for all files. Previously it was reopened per
    file and only switched from write to append mode after the second file,
    so the second write truncated the archive and the first file was lost.
    """
    archive_path = TEMP_FOLDER.joinpath(savename)
    mode = "a" if os.path.exists(archive_path) else "w"
    # Collect matches first so that no archive is created when nothing matches.
    matches = list(folder.glob(pattern))
    if matches:
        with zipfile.ZipFile(archive_path, mode=mode) as archive:
            for f in matches:
                archive.write(f)
                st.session_state["logger"].info(f"Written {f} to zip {archive_path}")
    st.session_state["logger"].info("Done zipping")
# Process every trial of every given .asc file with the algorithm stored in
# st.session_state["algo_choice_multi"], save CSV/JSON/PNG outputs per trial in
# TEMP_FOLDER, bundle them into one zip per .asc file, and return
# (zipfiles_with_results, results_keys).
# NOTE(review): indentation was lost in this copy of the file; confirm the
# nesting of the zipping statements against the original source.
def process_multiple_asc(asc_files): |
algo_choice = st.session_state["algo_choice_multi"] |
if algo_choice is not None and "DIST" in algo_choice: |
# The values returned by find_and_load_model are replaced by the session-state
# entries on the next two lines; model and model_cfg are not used again below.
model, model_cfg = find_and_load_model(model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"]) |
model = st.session_state["single_DIST_model"] |
model_cfg = st.session_state["single_DIST_model_cfg"] |
st.session_state["logger"].info(f"process_multiple_asc loaded model") |
else: |
model, model_cfg = None, None |
zipfiles_with_results = [] |
st.session_state["logger"].info(f"found asc_files {asc_files}") |
for asc_file in stqdm(asc_files, desc="Processing asc files"): |
st.session_state["logger"].info(f"processing asc_file {asc_file}") |
asc_file_stem = pl.Path(asc_file.name).stem |
trials_by_ids, lines = asc_to_trial_ids(asc_file) |
for trial_id, trial in stqdm(trials_by_ids.items(), desc=f"\nProcessing trials in {asc_file_stem}"): |
# The final positional argument True is process_trial's for_multi flag.
dffix, trial = process_trial( |
trial, |
asc_file_stem, |
lines, |
algo_choice, |
True, |
) |
st.session_state["logger"].debug(f"dffix.columns after process trial {dffix.columns}") |
# Trials without fixations are skipped and produce no output files.
if dffix.empty: |
st.session_state["logger"].warning(f"Dataframe for {trial_id} is empty, skipping") |
continue |
# Results are stored under the key "<asc file stem>_<trial id>".
st.session_state["results"][f"{asc_file_stem}_{trial_id}"] = { |
"trial": trial, |
"dffix": dffix, |
} |
st.session_state["logger"].debug(f"Added {asc_file_stem}_{trial_id} to st.session_state") |
quick_dffix_save(dffix, TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.csv")) |
save_trial_to_json(trial, TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.json")) |
ut.plot_fixations_and_text( |
dffix, |
trial, |
save=True, |
savelocation=TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.png"), |
algo_choice=algo_choice, |
turn_axis_on=False, |
) |
# Any previous archive for this .asc file is removed, then rebuilt from the
# CSV, JSON and PNG files whose names start with the file stem.
if os.path.exists(TEMP_FOLDER.joinpath(f"{asc_file_stem}.zip")): |
os.remove(TEMP_FOLDER.joinpath(f"{asc_file_stem}.zip")) |
save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.csv", f"{asc_file_stem}.zip") |
save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.json", f"{asc_file_stem}.zip") |
save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.png", f"{asc_file_stem}.zip") |
zipfiles_with_results += [str(x) for x in TEMP_FOLDER.glob(f"{asc_file_stem}*.zip")] |
# results_keys lists every key in st.session_state["results"], not only the
# ones added by this call.
results_keys = list(st.session_state["results"].keys()) |
st.session_state["logger"].debug(f"results_keys are {results_keys}") |
st.session_state["trial_choices_multi"] = results_keys |
st.session_state["zipfiles_with_results"] = zipfiles_with_results |
return (zipfiles_with_results, results_keys) |
# Collect .asc files from a list of uploads/paths (unpacking .zip and .tar
# archives into gradio_temp_unzipped_folder/<file stem>), parse each one with
# asc_to_trial_ids and store the per-file trials and lines in session state.
# NOTE(review): indentation was lost in this copy of the file. No return
# statement is visible, and it cannot be told from here whether the final
# process_multiple_asc(...) call belongs to this function - confirm against
# the original source.
@st.cache_data |
def get_trials_and_lines_from_asc_files(asc_files): |
list_of_trial_lists = [] |
list_of_lines = [] |
# total_num_trials is accumulated below but not read again in the visible code.
total_num_trials = 0 |
asc_files_to_do = [] |
for filename_full in asc_files: |
# Objects with a .name attribute that are not pathlib paths are reduced to
# that name; everything else is used as given.
if hasattr(filename_full, "name") and not isinstance(filename_full, pl.Path): |
file = filename_full.name |
st.session_state["logger"].info(f"Filename is {file}, filename_full is {filename_full}") |
else: |
file = filename_full |
if not isinstance(file, str): |
file_stem = pl.Path(file.name).stem |
else: |
file_stem = pl.Path(file).stem |
savefolder = gradio_temp_unzipped_folder.joinpath(file_stem) |
st.session_state["logger"].info(f"Operating on file {file}") |
# The file type is decided by substring tests on `file`.
# NOTE(review): these `in` tests presumably require `file` to be a str; a
# pathlib.Path reaching this point would not support them - verify.
if ".zip" in file: |
with zipfile.ZipFile(filename_full, "r") as z: |
z.extractall(str(savefolder)) |
elif ".tar" in file: |
shutil.unpack_archive(file, savefolder, "tar") |
elif ".asc" in file: |
asc_files_to_do.append(filename_full) |
else: |
st.session_state["logger"].warning(f"Unsopported file format found in files") |
# .asc files found directly inside the extraction folder are queued as strings.
newfiles = [str(x) for x in savefolder.glob(f"*.asc")] |
asc_files_to_do += newfiles |
st.session_state["logger"].info(f"asc_files_to_do is {asc_files_to_do}") |
for asc_file in asc_files_to_do: |
trials_by_ids, lines = asc_to_trial_ids(asc_file) |
total_num_trials += len(trials_by_ids) |
list_of_trial_lists.append(trials_by_ids) |
list_of_lines.append(lines) |
st.session_state["list_of_trial_lists"] = list_of_trial_lists |
st.session_state["list_of_lines"] = list_of_lines |
process_multiple_asc(st.session_state["multi_asc_filelist"]) |
def process_trial_choice_and_update_df_multi():
    """Copy the selected multi-file result into the ``*_multi`` session keys.

    The entry of ``st.session_state["results"]`` named by
    ``st.session_state["trial_id_multi"]`` supplies ``"dffix_multi"`` and
    ``"trial_multi"``. When the dataframe has a ``"start_time"`` column, the
    ``"start_time"`` and ``"end_time"`` columns are dropped from the copy
    stored as ``"dffix_multi"``.
    """
    state = st.session_state
    selected = state["results"][state["trial_id_multi"]]
    fixations = selected["dffix"]
    if "start_time" in fixations.columns:
        fixations = fixations.drop(axis=1, labels=["start_time", "end_time"])
    state["dffix_multi"] = fixations
    state["trial_multi"] = selected["trial"]
@st.cache_data
def convert_df(df):
    """Return ``df`` rendered as CSV without the index, encoded as UTF-8 bytes.

    Cached via ``st.cache_data``.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
# Build a trial dict from a character-level stimulus dataframe. Column names
# are looked up through the "*_col_name_*_stim" entries of st.session_state.
# The returned dict holds the character list, a derived word list, the joined
# text and spacing information (x_diff, y_diff, line heights, unique centres).
# NOTE(review): indentation was lost in this copy of the file; confirm nesting
# against the original source.
def make_trial_from_stimulus_df( |
stim_plot_df, |
filename, |
trial_id, |
): |
chars_list = [] |
words_list = [] |
# Index into chars_list of the first character of the word being collected.
word_start_idx = 0 |
for idx, row in stim_plot_df.reset_index().iterrows(): |
char_dict = dict( |
char_xmin=row[st.session_state["x_start_col_name_fix_stim"]], |
char_xmax=row[st.session_state["x_end_col_name_fix_stim"]], |
char_ymin=row[st.session_state["y_start_col_name_fix_stim"]], |
char_ymax=row[st.session_state["y_end_col_name_fix_stim"]], |
char_x_center=row[st.session_state["x_col_name_fix_stim"]], |
char_y_center=row[st.session_state["y_col_name_fix_stim"]], |
char=row[st.session_state["char_col_name_fix_stim"]], |
assigned_line=int(row[st.session_state["line_num_col_name_stim"]]), |
) |
chars_list.append(char_dict) |
# A word is closed when the current character is a space, or when its xmin is
# smaller than the previous character's xmin (with more than two characters
# collected). The word spans word_start_idx up to the character before the
# current one.
# NOTE(review): a word is only emitted when a later character closes it, so
# the last run of characters appears not to reach words_list unless it is
# followed by such a character - confirm this is intended.
if len(chars_list) > 1 and ( |
char_dict["char"] == " " |
or (len(chars_list) > 2 and (chars_list[-1]["char_xmin"] < chars_list[-2]["char_xmin"])) |
): |
word_dict = dict( |
word_xmin=chars_list[word_start_idx]["char_xmin"], |
word_xmax=chars_list[-2]["char_xmax"], |
word_ymin=chars_list[word_start_idx]["char_ymin"], |
word_ymax=chars_list[word_start_idx]["char_ymax"], |
word_x_center=(chars_list[-2]["char_xmax"] - chars_list[word_start_idx]["char_xmin"]) / 2 |
+ chars_list[word_start_idx]["char_xmin"], |
word_y_center=(chars_list[word_start_idx]["char_ymax"] - chars_list[word_start_idx]["char_ymin"]) / 2 |
+ chars_list[word_start_idx]["char_ymin"], |
# `idx` inside this comprehension is local to it and does not change the loop variable.
word="".join([chars_list[idx]["char"] for idx in range(word_start_idx, len(chars_list) - 1)]), |
) |
# After a space the next word starts at the following character; otherwise
# the current character already starts the next word.
if char_dict["char"] != " ": |
word_start_idx = idx |
else: |
word_start_idx = idx + 1 |
words_list.append(word_dict) |
line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list] |
line_xcoords_all = [x["char_x_center"] for x in chars_list] |
line_xcoords_no_pad = np.unique(line_xcoords_all) |
line_ycoords_all = [x["char_y_center"] for x in chars_list] |
line_ycoords_no_pad = np.unique(line_ycoords_all) |
# y_diff and num_char_lines set here are overwritten further down.
trial = dict( |
filename=filename, |
y_midline=[float(x) for x in list(stim_plot_df[st.session_state["y_col_name_fix_stim"]].unique())], |
num_char_lines=len(stim_plot_df[st.session_state["y_col_name_fix_stim"]].unique()), |
y_diff=[ |
float(x) for x in list(np.unique(np.diff(stim_plot_df[st.session_state["y_start_col_name_fix_stim"]]))) |
], |
trial_id=trial_id, |
chars_list=chars_list, |
words_list=words_list, |
trial_is="paragraph", |
text="".join([x["char"] for x in chars_list]), |
) |
# Values are converted to plain floats before being stored in the trial dict.
trial["x_char_unique"] = [float(x) for x in list(line_xcoords_no_pad)] |
trial["y_char_unique"] = list(map(float, list(line_ycoords_no_pad))) |
x_diff, y_diff = ut.calc_xdiff_ydiff( |
line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False |
) |
trial["x_diff"] = float(x_diff) |
trial["y_diff"] = float(y_diff) |
trial["num_char_lines"] = len(line_ycoords_no_pad) |
trial["line_heights"] = list(map(float, line_heights)) |
trial["chars_list"] = chars_list |
return trial |
@st.cache_data
def get_fixations_file_trials_list(fixations_df, stimulus):
    """Split a custom fixation table into per-trial dicts.

    Args:
        fixations_df: Fixation dataframe. Its columns are renamed in place to
            ``x``, ``y``, ``corrected_start_time``, ``corrected_end_time`` and
            ``trial_id`` using the column names stored in ``st.session_state``.
        stimulus: Either a character-level stimulus dataframe (renamed in place
            to the ``char_*`` column names) or an already built trial mapping.

    Returns:
        ``(trials_by_ids, trial_keys)``: a dict mapping each trial id to its
        trial dict (with ``"dffix"``, ``"fname"`` and ``"plot_file"`` set) and
        the list of trial ids found.

    Changes compared with the previous version:
        * ``trial_keys`` is also set when no trial id column exists; the return
          statement used to fail with a NameError on that path.
        * after the rename the trial ids are read from the ``trial_id`` column,
          not from the pre-rename column name.
        * a non-dataframe stimulus is copied per trial, so trials no longer
          share (and overwrite) one dict.
    """
    state = st.session_state
    logger = state["logger"]

    if isinstance(stimulus, pd.DataFrame):
        line_col = state["line_num_col_name_stim"]
        # Shift line numbers so that the first line is 0.
        stimulus[line_col] -= stimulus[line_col].min()
        stimulus.rename(
            {
                state["x_col_name_fix_stim"]: "char_x_center",
                state["x_start_col_name_fix_stim"]: "char_xmin",
                state["x_end_col_name_fix_stim"]: "char_xmax",
                state["y_col_name_fix_stim"]: "char_y_center",
                state["y_start_col_name_fix_stim"]: "char_ymin",
                state["y_end_col_name_fix_stim"]: "char_ymax",
                state["char_col_name_fix_stim"]: "char",
                state["trial_id_col_name_stim"]: "trial_id",
            },
            axis=1,
            inplace=True,
        )

    fixations_df.rename(
        mapper={
            state["x_col_name_fix"]: "x",
            state["y_col_name_fix"]: "y",
            state["time_start_col_name_fix"]: "corrected_start_time",
            state["time_stop_col_name_fix"]: "corrected_end_time",
            state["trial_id_col_name_fix"]: "trial_id",
        },
        axis=1,
        inplace=True,
    )
    fixations_df["duration"] = fixations_df.corrected_end_time - fixations_df.corrected_start_time

    if "trial_id" in stimulus:
        fixations_df["trial_id"] = stimulus["trial_id"]

    if "trial_id" in fixations_df:
        if state["has_multiple_subject"]:
            # Make ids unique across subjects: "<subject>_<trial id>".
            subjects = fixations_df[state["subject_col_name_fix"]]
            fixations_df["trial_id"] = [
                f"{subject}_{num}" for subject, num in zip(subjects, fixations_df["trial_id"])
            ]
        trial_keys = list(fixations_df["trial_id"].unique())
        logger.info(f"Found keys {trial_keys} for {state['single_csv_file'].name}")
    else:
        logger.warning(f"trial id column not found assigning trial id trial_0.")
        st.warning(f"trial id column not found assigning trial id trial_0.")
        fixations_df["trial_id"] = "trial_0"
        trial_keys = ["trial_0"]
    state["fixations_df"] = fixations_df

    trials_by_ids = {}
    for trial_id, subdf in fixations_df.groupby("trial_id"):
        if isinstance(stimulus, pd.DataFrame):
            stim_df = stimulus[stimulus.trial_id == trial_id]
            stim_df = stim_df.dropna(axis=0, how="any")
            subdf = subdf.dropna(axis=0, how="any")
            subdf = subdf.reset_index(drop=True)
            stim_df = stim_df.reset_index(drop=True)
            assert not stim_df.empty, "stimulus df is empty"
            trial = make_trial_from_stimulus_df(
                stim_df,
                state["single_csv_file_stim"].name,
                trial_id,
            )
        else:
            # Copy so that each trial gets its own dict and the input is untouched.
            trial = copy.deepcopy(stimulus)
        trial["dffix"] = subdf
        trial["fname"] = f"{trial_id}"
        trial["plot_file"] = str(state["PLOTS_FOLDER"].joinpath(f"{trial_id}_2ndInput_chars_channel_sep.png"))
        trials_by_ids[trial_id] = trial
    return trials_by_ids, trial_keys
def try_reading_csv(file):
    """Parse an uploaded delimited text file into a dataframe.

    The content of ``file`` (an object with ``getvalue()`` returning UTF-8
    bytes) is parsed first as comma-separated and then as tab-separated text.
    A result is accepted only if it has more than one column.

    Returns:
        The first accepted ``pd.DataFrame``, or ``None`` when both attempts
        fail or yield a single column.
    """
    text = file.getvalue().decode("utf-8")
    logger = st.session_state["logger"]

    for attempt, read_kwargs in enumerate(({}, {"delimiter": "\t"})):
        try:
            df = pd.read_csv(StringIO(text), **read_kwargs)
        except Exception as e:  # pandas raises several error types; treat all as "try the next delimiter"
            logger.warning(e)
            continue
        if attempt == 0:
            logger.info(f"\n{df.head()}")
        # Explicit check instead of `assert`, which is removed when running with -O.
        if len(df.columns) > 1:
            return df
        logger.warning("Parsed file has a single column; trying next delimiter if available")
    return None
@st.cache_data
def guess_col_names_fix(file=None):
    """Read the fixation file into a dataframe.

    Args:
        file: Uploaded file object; defaults to
            ``st.session_state["single_csv_file"]``.

    Returns:
        The dataframe from ``try_reading_csv`` when it has more than one
        column, otherwise ``None`` (also when no file is available or the
        file cannot be parsed).

    The former header-scoring code was removed: its results were never used,
    and it could raise on empty files. A failed parse now returns ``None``
    instead of raising AttributeError on ``None.shape``.
    """
    if file is None:
        file = st.session_state["single_csv_file"]
    if file is None:
        return None
    df = try_reading_csv(file)
    if df is not None and df.shape[1] > 1:
        return df
    return None
@st.cache_data
def guess_col_names_stim(file=None):
    """Read the stimulus file as a JSON trial or as a dataframe.

    Args:
        file: Uploaded file object; defaults to
            ``st.session_state["single_csv_file_stim"]``.

    Returns:
        The decoded JSON object when ``".json"`` occurs in the file name;
        otherwise the dataframe from ``try_reading_csv`` if it has more than
        one column; ``None`` when no file is available or parsing fails
        (previously a failed parse raised AttributeError on ``None.shape``).
    """
    if file is None:
        file = st.session_state["single_csv_file_stim"]
    if file is None:
        return None
    if ".json" in file.name:
        return json.loads(file.getvalue().decode("utf-8"))
    df = try_reading_csv(file)
    if df is not None and df.shape[1] > 1:
        return df
    return None
@st.cache_resource
def set_up_models(dist_models_folder):
    """Return ``ut.set_up_models(dist_models_folder)``; wrapped in ``st.cache_resource``.

    The module merges the result into ``st.session_state`` via ``update``.
    """
    models_state = ut.set_up_models(dist_models_folder)
    return models_state
@st.cache_data
def get_eyekit_measures(_txt, _seq, get_char_measures=False):
    """Return ``ekm.get_eyekit_measures(_txt, _seq, get_char_measures=...)``.

    The call is wrapped in ``st.cache_data``.
    """
    measures = ekm.get_eyekit_measures(_txt, _seq, get_char_measures=get_char_measures)
    return measures
@st.cache_data
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
    """Return ``ut.get_all_measures(trial, dffix, prefix, ...)``.

    The two optional arguments are forwarded as keyword arguments of the same
    name. Cached via ``st.cache_data``.
    """
    options = dict(
        use_corrected_fixations=use_corrected_fixations,
        correction_algo=correction_algo,
    )
    return ut.get_all_measures(trial, dffix, prefix, **options)
# Sanity check that the session-state defaults set earlier in this file exist.
assert "ALGO_CHOICES" in st.session_state, f"st.session_state not initialized\n{list(st.session_state.keys())}" |
# Everything returned by set_up_models is merged into st.session_state.
set_up_models_out = set_up_models(DIST_MODELS_FOLDER) |
st.session_state.update(set_up_models_out) |
# Page title, header and links.
st.title("Fixation data vertical alignment") |
st.header("👀 Read asc file or files and plot fixations 👀") |
# NOTE(review): the mailto target below looks like an email-obfuscation
# placeholder rather than a real address - confirm.
st.markdown("[Contact Us](mailto:[email protected])") |
st.markdown("[Read about DIST model](https://arxiv.org/abs/2311.06095)") |
# Tab layout: single vs. multiple files; the single-file tab is split again
# into .asc input and custom-file input.
single_file_tab, multi_file_tab = st.tabs(["Single File 📁", "Multiple Files 📁 📁"]) |
single_file_tab_asc_tab, single_file_tab_csv_tab = single_file_tab.tabs([".asc files", "custom files"]) |
single_file_tab_asc_tab.subheader( |
"Upload an .asc file and select a trial. Then select a correction algorithm and plot/download the results" |
) |
def change_which_file_is_used_and_clear_results():
    """Form callback: drop stale results and select the active .asc source.

    Removes ``"dffix"`` and ``"trial"`` from ``st.session_state`` (if present)
    and stores the file to work on under ``"single_asc_file_asc"``: the chosen
    example file when the radio widget is set to ``"Example File"``, otherwise
    the uploaded file (``None`` when nothing was uploaded).
    """
    for stale_key in ("dffix", "trial"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]

    # The example-file widgets appear to be rendered only when the example
    # files exist on disk, so their keys may be absent from session state.
    # Read the radio value defensively instead of raising KeyError.
    source_choice = st.session_state.get(
        "single_file_tab_asc_tab_example_use_example_or_uploaded_file_choice"
    )
    if source_choice == "Example File":
        st.session_state["single_asc_file_asc"] = st.session_state["single_file_tab_asc_tab_example_file_choice"]
    else:
        st.session_state["single_asc_file_asc"] = st.session_state.get("single_asc_uploaded_file")
# Form for choosing the .asc source: an uploaded file or one of the example files.
with single_file_tab_asc_tab.form("single_file_tab_asc_tab_load_example_form"):
    single_asc_file_asc_uploaded = st.file_uploader(
        "Select .asc File",
        type=["asc"],
        accept_multiple_files=False,
        key="single_asc_uploaded_file",
    )
    close_gap_between_words_single_asc = st.checkbox(
        label="Should spaces between words be included in word bounding box?",
        key="close_gap_between_words_single_asc",
        value=False,
    )
    # NOTE(review): indentation was lost in the source; both example-file
    # widgets are assumed to belong to this check — confirm.
    if os.path.isfile(EXAMPLE_ASC_FILES[0]):
        example_file_choice = st.selectbox(
            "Select example file",
            options=EXAMPLE_ASC_FILES,
            key="single_file_tab_asc_tab_example_file_choice",
        )
        use_example_or_uploaded_file_choice = st.radio(
            "Should the uploaded file be used or the selected example file?",
            options=["Uploaded File", "Example File"],
            index=1,
            key="single_file_tab_asc_tab_example_use_example_or_uploaded_file_choice",
        )
    # The callback clears old results and records which file is active.
    upload_file_button = st.form_submit_button(
        label="Load selected data.",
        on_click=change_which_file_is_used_and_clear_results,
    )
# Once an .asc source is selected: parse it, offer its trials and run the
# chosen correction algorithm(s) on the selected trial.
if "single_asc_file_asc" in st.session_state and st.session_state["single_asc_file_asc"] is not None:
    trial_choices_single_asc, trials_by_ids, lines, asc_file = get_trials_list(
        st.session_state["single_asc_file_asc"], close_gap_between_words=close_gap_between_words_single_asc
    )
    st.session_state["trials_by_ids"] = trials_by_ids
    st.session_state["trial_choices"] = trial_choices_single_asc
    st.session_state["lines"] = lines
    st.session_state["asc_file"] = asc_file

    if trial_choices_single_asc:
        with single_file_tab_asc_tab.form(key="single_file_tab_asc_tab_trial_select_form"):
            col_a1, col_a2 = st.columns((1, 2))
            with col_a1:
                trial_choice = st.selectbox(
                    "Which trial should be corrected?",
                    trial_choices_single_asc,
                    key="trial_id",
                    index=0,
                )
            with col_a2:
                # `options` is a required argument of st.multiselect; the
                # default is taken from ALGO_CHOICES, so that list is offered.
                st.multiselect(
                    "Choose correction algorithm",
                    ALGO_CHOICES,
                    key="algo_choice",
                    default=[ALGO_CHOICES[0]],
                )
            process_trial_btn = st.form_submit_button("Load and correct trial")

        if process_trial_btn:
            single_file_tab_asc_tab.write(f'You selected: {st.session_state["trial_id"]}')
            dffix, trial, dpi, screen_res, font, font_size = process_trial_choice(
                trial_choice, st.session_state["algo_choice"]
            )
            # Keep the results in session state so they survive reruns.
            st.session_state["dffix"] = dffix
            st.session_state["trial"] = trial
            st.session_state["dpi"] = dpi
            st.session_state["screen_res"] = screen_res
            st.session_state["font"] = font
            st.session_state["font_size"] = font_size
            export_csv(dffix, trial)
# Results for the corrected single-.asc trial: dataframe, plots and analysis.
if "dffix" in st.session_state and "trial" in st.session_state:
    df_expander_single = single_file_tab_asc_tab.expander("Show Dataframe", False)
    plot_expander_single = single_file_tab_asc_tab.expander("Show Plots", True)

    df_expander_single.dataframe(st.session_state["dffix"])
    csv = convert_df(st.session_state["dffix"])
    df_expander_single.download_button(
        "Download fixation dataframe",
        csv,
        f'{st.session_state["trial_id"]}.csv',
        "text/csv",
        key="download-csv-single",
    )

    plotting_checkboxes_single = plot_expander_single.multiselect(
        "Select what gets plotted",
        ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
        key="plotting_checkboxes_single",
        default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
    )
    scale_factor_single_asc = plot_expander_single.number_input(
        label="Scale factor for stimulus image", min_value=0.01, max_value=3.0, value=0.5, step=0.1
    )
    plot_expander_single.plotly_chart(
        plotly_plot_with_image(
            st.session_state["dffix"],
            st.session_state["trial"],
            to_plot_list=plotting_checkboxes_single,
            algo_choice=st.session_state["algo_choice"],
            scale_factor=scale_factor_single_asc,
        ),
        use_container_width=False,
    )
    plot_expander_single.plotly_chart(
        plot_y_corr(st.session_state["dffix"], st.session_state["algo_choice"]), use_container_width=True
    )

    # The analysis section is only offered when the trial carries "chars_list".
    if "chars_list" in st.session_state["trial"]:
        analysis_expander_single_asc = single_file_tab_asc_tab.expander("Show Analysis results", True)
        use_corrected_fixations_tickbox = analysis_expander_single_asc.checkbox(
            "Use corrected",
            True,
            "use_corrected_fixations_tickbox",
            help="Whether to use the corrected or uncorrected fixations for the analysis.",
        )
        eyekit_tab, own_analysis_tab = analysis_expander_single_asc.tabs(
            ["Analysis using eyekit", "Analysis without eyekit"]
        )

        with eyekit_tab:
            st.markdown("Analysis powered by [eyekit](https://jwcarr.github.io/eyekit/)")
            st.markdown(
                "Please adjust parameters below to align fixations with stimulus using the sliders.Eyekit analysis is based on this alignment."
            )
            a_c1, a_c2, a_c3, a_c4, a_c5, a_c6 = st.columns(6)

            # Preselect the first of these fonts that is available.
            if "Consolas" in AVAILABLE_FONTS:
                font_index = AVAILABLE_FONTS.index("Consolas")
            elif "Courier New" in AVAILABLE_FONTS:
                font_index = AVAILABLE_FONTS.index("Courier New")
            elif "DejaVu Sans Mono" in AVAILABLE_FONTS:
                font_index = AVAILABLE_FONTS.index("DejaVu Sans Mono")
            else:
                font_index = 0
            # `options` is required by st.selectbox; font_index indexes
            # AVAILABLE_FONTS, so that list is what must be offered.
            font_face = a_c1.selectbox(
                label="Select Font",
                options=AVAILABLE_FONTS,
                index=font_index,
                key="font_face_for_eyekit_single_asc",
            )
            algo_choice_single_asc_eyekit = a_c1.selectbox(
                "Algorithm", st.session_state["algo_choice"], index=0, key="algo_choice_single_asc_eyekit"
            )
            sliders_on_tickbox = a_c6.checkbox(
                "Sliders", True, "single_asc_eyekit_sliders_checkbox", help="Turns sliders on and off"
            )

            # Seed the alignment defaults once per session.
            if "font_size_for_eyekit" not in st.session_state:
                (
                    y_diff,
                    x_txt_start,
                    y_txt_start,
                    _,
                    _,
                    line_height,
                ) = add_default_font_and_character_props_to_state(st.session_state["trial"])
                font_size = ut.set_font_from_chars_list(st.session_state["trial"])
                st.session_state["y_diff_for_eyekit"] = y_diff
                st.session_state["x_txt_start_for_eyekit"] = x_txt_start
                st.session_state["y_txt_start_for_eyekit"] = y_txt_start
                st.session_state["font_face_for_eyekit"] = font_face
                st.session_state["font_size_for_eyekit"] = font_size
                st.session_state["line_height_for_eyekit"] = line_height

            # Same four alignment parameters, either as sliders or number inputs.
            # NOTE(review): the slider defaults must be members of the option
            # ranges below — confirm the seeded values always fall inside them.
            if sliders_on_tickbox:
                font_size = a_c2.select_slider(
                    "Font Size",
                    np.arange(5, 36, 0.25),
                    st.session_state["font_size_for_eyekit"],
                    key="font_size_for_eyekit_single_asc",
                )
                x_txt_start = a_c3.select_slider(
                    "x",
                    np.arange(300, 601, 1),
                    round(st.session_state["x_txt_start_for_eyekit"]),
                    key="x_txt_start_for_eyekit_single_asc",
                    help="x coordinate of first character",
                )
                y_txt_start = a_c4.select_slider(
                    "y",
                    np.arange(100, 501, 1),
                    round(st.session_state["y_txt_start_for_eyekit"]),
                    key="y_txt_start_for_eyekit_single_asc",
                    help="y coordinate of first character",
                )
                line_height = a_c5.select_slider(
                    "Line height",
                    np.arange(0, 151, 1),
                    round(st.session_state["line_height_for_eyekit"]),
                    key="line_height_for_eyekit_single_asc",
                )
            else:
                font_size = a_c2.number_input(
                    "Font Size",
                    None,
                    None,
                    round(st.session_state["font_size_for_eyekit"], ndigits=0),
                    key="font_size_for_eyekit_single_asc",
                )
                x_txt_start = a_c3.number_input(
                    "x",
                    None,
                    None,
                    round(st.session_state["x_txt_start_for_eyekit"]),
                    key="x_txt_start_for_eyekit_single_asc",
                    help="x coordinate of first character",
                )
                y_txt_start = a_c4.number_input(
                    "y",
                    None,
                    None,
                    round(st.session_state["y_txt_start_for_eyekit"]),
                    key="y_txt_start_for_eyekit_single_asc",
                    help="y coordinate of first character",
                )
                line_height = a_c5.number_input(
                    "Line height",
                    None,
                    None,
                    round(st.session_state["line_height_for_eyekit"]),
                    key="line_height_for_eyekit_single_asc",
                )

            fixation_sequence, textblock, screen_size = ekm.get_fix_seq_and_text_block(
                st.session_state["dffix"],
                st.session_state["trial"],
                x_txt_start=st.session_state["x_txt_start_for_eyekit_single_asc"],
                y_txt_start=st.session_state["y_txt_start_for_eyekit_single_asc"],
                font_face=st.session_state["font_face_for_eyekit_single_asc"],
                font_size=st.session_state["font_size_for_eyekit_single_asc"],
                line_height=line_height,
                use_corrected_fixations=st.session_state["use_corrected_fixations_tickbox"],
                correction_algo=st.session_state["algo_choice_single_asc_eyekit"],
            )
            eyekitplot_img = ekm.eyekit_plot(textblock, fixation_sequence, screen_size)
            st.image(eyekitplot_img, "Fixations and stimulus as used for anaylsis")

            # Offer the eyekit JSON files found under results/ for download.
            # NOTE(review): presumably written by the ekm call above — confirm.
            with open(
                f'results/fixation_sequence_eyekit_{st.session_state["trial"]["trial_id"]}.json', "r"
            ) as f:
                fixation_sequence_json = json.load(f)
            fixation_sequence_json_str = json.dumps(fixation_sequence_json)
            st.download_button(
                "Download fixations in eyekits format",
                fixation_sequence_json_str,
                f'fixation_sequence_eyekit_{st.session_state["trial"]["trial_id"]}.json',
                "application/json",  # was "json", which is not a valid MIME type
                key="download_eyekit_fix_json_single_asc",
            )
            with open(f'results/textblock_eyekit_{st.session_state["trial"]["trial_id"]}.json', "r") as f:
                textblock_json = json.load(f)
            textblock_json_str = json.dumps(textblock_json)
            st.download_button(
                "Download stimulus in eyekits format",
                textblock_json_str,
                f'textblock_eyekit_{st.session_state["trial"]["trial_id"]}.json',
                "application/json",  # was "json", which is not a valid MIME type
                key="download_eyekit_text_json_single_asc",
            )

            word_measures_df, character_measures_df = get_eyekit_measures(
                textblock, fixation_sequence, get_char_measures=False
            )
            st.dataframe(word_measures_df, use_container_width=True, hide_index=True)
            word_measures_df_csv = convert_df(word_measures_df)
            word_measures_df_download_btn = st.download_button(
                "Download word measures data",
                word_measures_df_csv,
                f'{st.session_state["trial"]["trial_id"]}_word_measures_df.csv',
                "text/csv",
                key="word_measures_df_download_btn",
            )
            measure_words = st.selectbox(
                "Select measure to visualize", list(ekm.MEASURES_DICT.keys()), key="measure_words"
            )
            st.image(ekm.plot_with_measure(textblock, fixation_sequence, screen_size, measure_words))

        with own_analysis_tab:
            st.markdown(
                "This analysis method does not require manual alignment and works when the automated stimulus coordinates are correct."
            )
            own_word_measures = get_all_measures(
                st.session_state["trial"],
                st.session_state["dffix"],
                prefix="word",
                use_corrected_fixations=st.session_state["use_corrected_fixations_tickbox"],
                correction_algo=st.session_state["algo_choice_single_asc_eyekit"],
            )
            st.dataframe(own_word_measures, use_container_width=True, hide_index=True)
            own_word_measures_csv = convert_df(own_word_measures)
            word_measures_df_download_btn = st.download_button(
                "Download word measures data",
                own_word_measures_csv,
                f'{st.session_state["trial"]["trial_id"]}_own_word_measures_df.csv',
                "text/csv",
                key="own_word_measures_df_download_btn",
            )
            fix_to_plot = (
                ["Corrected Fixations"]
                if st.session_state["use_corrected_fixations_tickbox"]
                else ["Uncorrected Fixations"]
            )
            # NOTE(review): `measure_words` comes from the eyekit tab's selectbox;
            # assumes own_word_measures has a column of that name — confirm.
            own_word_measures_fig, desired_width_in_pixels, desired_height_in_pixels = matplotlib_plot_df(
                st.session_state["dffix"],
                st.session_state["trial"],
                st.session_state["algo_choice"],
                stimulus_prefix="word",
                box_annotations=own_word_measures[measure_words],
                fix_to_plot=fix_to_plot,
            )
            st.pyplot(own_word_measures_fig)
    else:
        single_file_tab_asc_tab.warning("🚨 Stimulus information needed for analysis 🚨")
# Custom-files tab: upload and preview the fixation and stimulus files, then
# ask for the column/key names used in them.
single_file_tab_csv_tab.markdown(
    "#### Upload one .csv file for the fixations and one .json or .csv file for the stimulus information and select a trial. Then select a correction algorithm and plot/download the results"
)
with single_file_tab_csv_tab.expander("Upload and preview data", expanded=True):
    csv_upl_col1, csv_upl_col2 = st.columns(2)
    # Allowed extensions are passed as lists (the documented form) instead of sets.
    single_csv_file = csv_upl_col1.file_uploader(
        "Select .csv file containing the fixation data",
        accept_multiple_files=False,
        key="single_csv_file",
        type=["csv", "txt", "dat"],
    )
    single_csv_stim_file = csv_upl_col2.file_uploader(
        "Select .csv or .json file containing the stimulus data",
        accept_multiple_files=False,
        key="single_csv_file_stim",
        type=["json", "csv", "txt", "dat"],
    )

    if single_csv_file:
        st.session_state["dffix_single_csv"] = guess_col_names_fix(single_csv_file)
        if st.session_state["dffix_single_csv"] is not None:
            csv_upl_col1.dataframe(
                st.session_state["dffix_single_csv"], use_container_width=True, hide_index=True, height=200
            )

    if single_csv_stim_file:
        st.session_state["stimdf_single_csv"] = guess_col_names_stim(single_csv_stim_file)
        # guess_col_names_stim returns None when the file cannot be read as a
        # table with more than one column; only touch .keys()/.columns when
        # there is data, otherwise this raised AttributeError.
        if st.session_state["stimdf_single_csv"] is not None:
            if ".json" in single_csv_stim_file.name:
                st.session_state["colnames_stim"] = st.session_state["stimdf_single_csv"].keys()
                csv_upl_col2.json(st.session_state["stimdf_single_csv"])
            else:
                st.session_state["colnames_stim"] = st.session_state["stimdf_single_csv"].columns
                csv_upl_col2.dataframe(
                    st.session_state["stimdf_single_csv"], use_container_width=True, hide_index=True, height=200
                )

if single_csv_file and single_csv_stim_file:
    with single_file_tab_csv_tab.expander("Column names for csv files", expanded=True):
        with st.form("Column names for csv files"):
            st.markdown("### Please set column/key names for csv/json files")

            st.markdown("#### Fixation file column names:")
            c1, c2, c3 = st.columns(3)
            x_col_name_fix = c1.text_input("x coordinate", key="x_col_name_fix", value="x")
            y_col_name_fix = c2.text_input("y coordinate", key="y_col_name_fix", value="y")
            subject_col_name_fix = c1.text_input("subject id", key="subject_col_name_fix", value="sub_id")
            trial_id_col_name_fix = c3.text_input("trial id", key="trial_id_col_name_fix", value="trial_id")
            time_start_col_name_fix = c2.text_input(
                "fixation start time", key="time_start_col_name_fix", value="corrected_start_time"
            )
            time_stop_col_name_fix = c3.text_input(
                "fixation end time", key="time_stop_col_name_fix", value="corrected_end_time"
            )

            st.markdown("#### Stimulus file column/key names:")
            c1, c2, c3 = st.columns(3)
            x_col_name_fix_stim = c1.text_input("x coordinate", key="x_col_name_fix_stim", value="char_x_center")
            y_col_name_fix_stim = c2.text_input("y coordinate", key="y_col_name_fix_stim", value="char_y_center")
            x_start_col_name_fix_stim = c3.text_input(
                "x min of interest areas", key="x_start_col_name_fix_stim", value="char_xmin"
            )
            x_end_col_name_fix_stim = c1.text_input(
                "x max of interest areas", key="x_end_col_name_fix_stim", value="char_xmax"
            )
            y_start_col_name_fix_stim = c2.text_input(
                "y min of interest areas", key="y_start_col_name_fix_stim", value="char_ymin"
            )
            # Label corrected: this field holds the y maximum (it read "x max").
            y_end_col_name_fix_stim = c3.text_input(
                "y max of interest areas", key="y_end_col_name_fix_stim", value="char_ymax"
            )
            char_col_name_fix_stim = c1.text_input(
                "content of interest area", key="char_col_name_fix_stim", value="char"
            )
            line_num_col_name_stim = c3.text_input(
                "line number for interest areas", key="line_num_col_name_stim", value="assigned_line"
            )
            subject_col_name_stim = c1.text_input("subject id", key="subject_col_name_stim", value="sub_id")
            trial_id_col_name_stim = c2.text_input("trial id", key="trial_id_col_name_stim", value="trial_id")
            has_multiple_subject = c2.checkbox("multiple subject in file", key="has_multiple_subject")
            form_submitted = st.form_submit_button("Confirm column/key names")
# Custom-files tab: build the trial list from the uploaded files, correct the
# selected trial and show/download the results.
if single_csv_file and single_csv_stim_file:
    process_custom_csvs_button = single_file_tab_csv_tab.button(
        "Load data from files",
    )
    # Also runs on later reruns once trials have been loaded.
    if process_custom_csvs_button or "trial_keys_single_csv" in st.session_state:
        trials_by_ids, trial_keys = get_fixations_file_trials_list(
            st.session_state["dffix_single_csv"], st.session_state["stimdf_single_csv"]
        )
        st.session_state["trials_by_ids_single_csv"] = trials_by_ids
        st.session_state["trial_keys_single_csv"] = trial_keys

        with single_file_tab_csv_tab.form(key="trial_selection_algo_selection_form_single_csv"):
            col_a1, col_a2 = st.columns((1, 2))
            with col_a1:
                trial_choice = st.selectbox(
                    "Which trial should be corrected?",
                    st.session_state["trial_keys_single_csv"],
                    key="trial_id_selected_custom_csv",
                    index=0,
                )
            with col_a2:
                # `options` is a required argument of st.multiselect; the
                # default is taken from ALGO_CHOICES, so that list is offered.
                algo_choice_single_csv = st.multiselect(
                    "Choose correction algorithm",
                    ALGO_CHOICES,
                    key="algo_choice_single_csv",
                    default=[ALGO_CHOICES[0]],
                )
            process_trial_btn = st.form_submit_button("Correct trial")

        if "trial_id_selected_custom_csv" in st.session_state and "algo_choice_single_csv" in st.session_state:
            trial = st.session_state["trials_by_ids_single_csv"][trial_choice]
            dffix, trial, dpi, screen_res, font, font_size = process_trial_choice_single_csv(
                trial, algo_choice_single_csv
            )
            st.session_state["trial_single_csv"] = trial

            csv = convert_df(dffix)
            single_file_tab_csv_tab.download_button(
                "Download corrected fixation data",
                csv,
                f'{trial["trial_id"]}.csv',
                "text/csv",
                key="download-csv-custom-csv",
            )
            with single_file_tab_csv_tab.expander("Show corrected fixation data", expanded=True):
                st.dataframe(dffix, use_container_width=True, hide_index=True, height=200)
            with single_file_tab_csv_tab.expander("Show fixation plots", expanded=True):
                plotting_checkboxes_single_single_csv = st.multiselect(
                    "Select what gets plotted",
                    ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
                    key="plotting_checkboxes_single_single_csv",
                    default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
                )
                st.plotly_chart(
                    plotly_plot_with_image(
                        dffix,
                        trial,
                        to_plot_list=plotting_checkboxes_single_single_csv,
                        algo_choice=algo_choice_single_csv,
                    ),
                    use_container_width=True,
                )
                st.plotly_chart(plot_y_corr(dffix, algo_choice_single_csv), use_container_width=True)
# Multi-file tab: upload several files, choose the algorithms and process them.
multi_file_tab.subheader("Upload multiple .asc files. Then select a correction algorithm and download the results.")
with multi_file_tab.form("Upload files to be processed and select algorithm"):
    multifile_col, multi_algo_col = st.columns((1, 1))
    with multifile_col:
        st.file_uploader(
            "Upload .asc Files", accept_multiple_files=True, key="multi_asc_filelist", type=["asc", "tar", "zip"]
        )
    with multi_algo_col:
        # `options` is a required argument of st.multiselect; the default is
        # taken from ALGO_CHOICES, so that list is offered.
        st.multiselect(
            "Choose correction algorithms",
            ALGO_CHOICES,
            key="algo_choice_multi",
            default=[ALGO_CHOICES[0]],
        )
    process_trial_btn_multi = st.form_submit_button("Load and correct asc files")

if process_trial_btn_multi:
    get_trials_and_lines_from_asc_files(st.session_state["multi_asc_filelist"])
# Multi-file tab: download of result archives.
if "zipfiles_with_results" in st.session_state:
    multi_res_col1, multi_res_col2 = multi_file_tab.columns(2)
    chosen_zip = multi_res_col1.selectbox("Choose results to download", st.session_state["zipfiles_with_results"])
    st.session_state["logger"].info(f"Download button for {chosen_zip}")
    st.session_state["logger"].info(st.session_state["zipfiles_with_results"])
    zipnamestem = pl.Path(chosen_zip).stem
    with open(chosen_zip, "rb") as f:
        multi_res_col2.download_button(f"Download {zipnamestem}", f, file_name=f"results_{zipnamestem}.zip")

# Multi-file tab: inspect and plot a single trial from the processed files.
if "trial_choices_multi" in st.session_state:
    multi_plotting_options_col1, multi_plotting_options_col2 = multi_file_tab.columns(2)
    trial_choice_multi = multi_plotting_options_col1.selectbox(
        "Which trial should be plotted?",
        st.session_state["trial_choices_multi"],
        key="trial_id_multi",
        placeholder="Select trial to display and plot",
        on_change=process_trial_choice_and_update_df_multi,
    )
    plotting_checkboxes_multi = multi_plotting_options_col2.multiselect(
        "Select what gets plotted",
        ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
        default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
    )

    if trial_choice_multi and "dffix_multi" in st.session_state:
        df_expander_multi = multi_file_tab.expander("Show Dataframe", False)
        plot_expander_multi = multi_file_tab.expander("Show Plots", True)
        df_expander_multi.dataframe(st.session_state["dffix_multi"])
        dffix_multi = st.session_state["dffix_multi"]
        trial_multi = st.session_state["trial_multi"]
        # algo_choice is passed by keyword, as at the other call sites of
        # plotly_plot_with_image in this file, so it cannot be bound to the
        # wrong positional parameter.
        plot_expander_multi.plotly_chart(
            plotly_plot_with_image(
                dffix_multi,
                trial_multi,
                to_plot_list=plotting_checkboxes_multi,
                algo_choice=st.session_state["algo_choice_multi"],
            ),
            use_container_width=True,
        )
        plot_expander_multi.plotly_chart(
            plot_y_corr(dffix_multi, st.session_state["algo_choice_multi"]), use_container_width=True
        )