IBYDMT / app_lib /user_input.py
jacopoteneggi's picture
Update
8e05eba verified
import json
import os
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select
import app_lib.defaults as defaults
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
# Directory holding the example images offered in the sidebar picker.
IMAGE_DIR = os.path.join("assets", "images")
# Example ``.jpg`` file names, sorted so the picker order is stable.
IMAGE_NAMES = sorted(name for name in os.listdir(IMAGE_DIR) if name.endswith(".jpg"))
# Full paths to the example images, index-aligned with IMAGE_NAMES.
IMAGE_PATHS = [os.path.join(IMAGE_DIR, name) for name in IMAGE_NAMES]
# Per-image presets (each entry holds "class_name" and "concepts"), keyed by
# the image file name without its extension. The context manager closes the
# file handle deterministically, and JSON is read as UTF-8 regardless of the
# platform's default encoding.
with open(os.path.join("assets", "image_presets.json"), encoding="utf-8") as _presets_file:
    IMAGE_PRESETS = json.load(_presets_file)
def _validate_class_name(class_name):
if class_name is None:
return (False, "Class name cannot be empty.")
if class_name.strip() == "":
return (False, "Class name cannot be empty.")
return (True, None)
def _validate_concepts(concepts):
if len(concepts) < 3:
return (False, "You must provide at least 3 concepts")
if len(concepts) > 10:
return (False, "Maximum 10 concepts allowed")
return (True, None)
def _get_significance_level():
    """Render the significance-level slider and return the selected value.

    The slider spans from one step up to 1.0, starts at
    ``defaults.SIGNIFICANCE_LEVEL_VALUE``, and is disabled whenever
    ``st.session_state.disabled`` is truthy.
    """
    step_size = defaults.SIGNIFICANCE_LEVEL_STEP
    default_level = defaults.SIGNIFICANCE_LEVEL_VALUE
    slider_kwargs = dict(
        help=(
            "The level of significance of the tests. "
            f"Defaults to {default_level:.2F}."
        ),
        min_value=step_size,
        max_value=1.0,
        value=default_level,
        step=step_size,
        disabled=st.session_state.disabled,
    )
    return st.slider("Significance level", **slider_kwargs)
def _get_tau_max():
    """Render the "Length of test" slider and return the choice as an ``int``.

    The slider spans from one step up to 1000 steps, starts at
    ``defaults.TAU_MAX_VALUE``, and is disabled whenever
    ``st.session_state.disabled`` is truthy.
    """
    tau_step = defaults.TAU_MAX_STEP
    tau_default = defaults.TAU_MAX_VALUE
    selected = st.slider(
        "Length of test",
        value=tau_default,
        min_value=tau_step,
        max_value=1000,
        step=tau_step,
        help=f"The maximum number of steps for each test. Defaults to {tau_default}.",
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_number_of_tests():
    """Render the "tests per concept" slider and return the choice as an ``int``.

    The slider spans from one step up to 100, starts at ``defaults.R_VALUE``,
    and is disabled whenever ``st.session_state.disabled`` is truthy.
    """
    r_step = defaults.R_STEP
    r_default = defaults.R_VALUE
    help_text = (
        f"The number of tests to average for each concept. Defaults to {r_default}."
    )
    chosen = st.slider(
        "Number of tests per concept",
        value=r_default,
        min_value=r_step,
        max_value=100,
        step=r_step,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(chosen)
def _get_cardinality(concepts, concepts_ready):
    """Render the conditioning-set-size slider and return the selected value.

    Args:
        concepts: The concepts being tested; its length bounds the slider's
            maximum (all but one concept).
        concepts_ready: Whether ``concepts`` passed validation. The slider is
            disabled until it does (or while ``st.session_state.disabled``).

    Returns:
        The value selected on the slider.
    """
    default = defaults.CARDINALITY_VALUE
    step = defaults.CARDINALITY_STEP
    # Condition on at most all-but-one concept. The floor of 2 presumably keeps
    # the slider's range non-degenerate (min_value is 1) before enough
    # concepts have been entered.
    max_value = max(2, len(concepts) - 1)
    return st.slider(
        "Size of conditioning set",
        help=(
            "The number of concepts to condition model predictions on. "
            # Must be an f-string, otherwise "{default}" is shown verbatim.
            f"Defaults to {default}."
        ),
        min_value=1,
        max_value=max_value,
        # Streamlit rejects an initial value outside [min_value, max_value],
        # and max_value shrinks with the number of concepts, so clamp it.
        value=min(default, max_value),
        step=step,
        disabled=st.session_state.disabled or not concepts_ready,
    )
def _get_dataset_name():
    """Render the dataset selector and return the chosen dataset name.

    Options come from ``SUPPORTED_DATASETS``; the initial selection is
    ``defaults.DATASET_NAME`` (``.index`` raises ``ValueError`` if that default
    is not one of the options). Disabled while ``st.session_state.disabled``.
    """
    options = SUPPORTED_DATASETS
    default_idx = options.index(defaults.DATASET_NAME)
    return st.selectbox(
        "Dataset",
        options=options,
        index=default_idx,
        help=(
            # Trailing space needed: adjacent literals are joined verbatim.
            "Name of the dataset to use to train the sampler. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )
def get_model_name():
    """Render the model selector and return the chosen model name.

    Options are the keys of ``SUPPORTED_MODELS``; the initial selection is
    ``defaults.MODEL_NAME`` (``.index`` raises ``ValueError`` if that default
    is not one of the options). Disabled while ``st.session_state.disabled``.
    """
    options = list(SUPPORTED_MODELS.keys())
    default_idx = options.index(defaults.MODEL_NAME)
    return st.selectbox(
        "Model to test",
        options=options,
        index=default_idx,
        help=(
            # Trailing space needed: adjacent literals are joined verbatim.
            "Name of the vision-language model to test the predictions of. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user upload an image or pick one of the bundled examples.

    Both widgets live in the sidebar. The example picker is only rendered when
    nothing has been uploaded; it preselects ``bowl_ace.jpg``.

    Returns:
        A ``(image_name, image)`` pair: ``image_name`` is the example's file
        name, or ``None`` for an uploaded file; ``image`` is the opened PIL image.
    """
    with st.sidebar:
        upload = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if upload is None:
            preselected = IMAGE_NAMES.index("bowl_ace.jpg")
            chosen = image_select(
                label="or select one",
                images=IMAGE_PATHS,
                index=preselected,
                return_value="index",
            )
            name, path = IMAGE_NAMES[chosen], IMAGE_PATHS[chosen]
            return (name, Image.open(path))
        return (None, Image.open(upload))
def get_class_name(image_name=None):
    """Render the class-name text box and validate its contents.

    Args:
        image_name: File name of the selected example image, or ``None``/empty.
            When given and a preset exists for it in ``IMAGE_PRESETS``, the
            preset's class name pre-fills the box; otherwise it starts empty.

    Returns:
        ``(class_name, class_ready, class_error)``: the entered text, whether
        it is valid, and the validation message (``None`` when valid).
    """
    default = ""
    if image_name:
        # Presets are keyed by the name before the first "." (as before). An
        # example image without a preset falls back to an empty box instead
        # of crashing the app with a KeyError.
        preset = IMAGE_PRESETS.get(image_name.split(".")[0], {})
        default = preset.get("class_name", "")
    class_name = st.text_input(
        "Class to predict",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value=default,
        disabled=st.session_state.disabled,
        placeholder="Type class name here",
    )
    class_ready, class_error = _validate_class_name(class_name)
    return class_name, class_ready, class_error
def get_concepts(image_name=None):
    """Render the concepts text area, parse it, and validate the result.

    One concept per line. Lines are stripped, blank lines are dropped, and
    duplicates are removed while keeping the order in which concepts first
    appear.

    Args:
        image_name: File name of the selected example image, or ``None``/empty.
            When given and a preset exists for it in ``IMAGE_PRESETS``, the
            preset's concepts pre-fill the area; otherwise it starts empty.

    Returns:
        ``(concepts, concepts_ready, concepts_error)``: the parsed list,
        whether its size is valid, and the validation message (``None`` when
        valid).
    """
    default = ""
    if image_name:
        # Presets are keyed by the name before the first "." (as before). A
        # missing preset yields an empty area rather than a KeyError.
        preset = IMAGE_PRESETS.get(image_name.split(".")[0], {})
        default = "\n".join(preset.get("concepts", []))
    raw_text = st.text_area(
        "Concepts to test",
        help=(
            "List of concepts to test the predictions of the model with. "
            "Write one concept per line. Maximum 10 concepts allowed."
        ),
        height=180,
        value=default,
        disabled=st.session_state.disabled,
        placeholder="Type one concept\nper line",
    )
    concepts = [line.strip() for line in raw_text.split("\n")]
    concepts = [concept for concept in concepts if concept != ""]
    # Deduplicate while preserving first-occurrence order. ``list(set(...))``
    # ordered strings by hash, which is randomized per process, so the
    # concept order changed from run to run.
    concepts = list(dict.fromkeys(concepts))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and collect its values.

    Widgets are drawn in this order: dataset, significance level, test length,
    tests per concept, conditioning-set size, then a divider.

    Args:
        concepts: Current list of concepts, forwarded to the cardinality slider.
        concepts_ready: Whether ``concepts`` passed validation.

    Returns:
        ``(significance_level, tau_max, r, cardinality, dataset_name)``.
    """
    expander = st.expander("Advanced settings")
    with expander:
        # Call order determines on-screen order; keep it fixed.
        dataset = _get_dataset_name()
        alpha = _get_significance_level()
        test_length = _get_tau_max()
        n_tests = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)
        st.divider()
    return alpha, test_length, n_tests, conditioning_size, dataset