jacopoteneggi commited on
Commit
7e207f0
1 Parent(s): 5ead791
Files changed (8) hide show
  1. app.py +5 -7
  2. app_lib/main.py +23 -24
  3. app_lib/test.py +106 -71
  4. app_lib/user_input.py +86 -12
  5. header.md +1 -3
  6. ibydmt/test.py +5 -2
  7. requirements.txt +2 -1
  8. style.css +15 -0
app.py CHANGED
@@ -2,11 +2,6 @@ import streamlit as st
2
 
3
  from app_lib.main import main
4
 
5
- with open("style.css", "r") as f:
6
- style = f.read()
7
- with open("header.md", "r") as f:
8
- header = f.read()
9
-
10
  if "sidebar_state" not in st.session_state:
11
  st.session_state.sidebar_state = "collapsed"
12
  if "disabled" not in st.session_state:
@@ -16,9 +11,12 @@ if "results" not in st.session_state:
16
 
17
  st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
18
 
19
- st.session_state.sidebar_state = "collapsed"
20
- st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
 
 
21
 
 
22
  st.markdown(header)
23
 
24
  if __name__ == "__main__":
 
2
 
3
  from app_lib.main import main
4
 
 
 
 
 
 
5
  if "sidebar_state" not in st.session_state:
6
  st.session_state.sidebar_state = "collapsed"
7
  if "disabled" not in st.session_state:
 
11
 
12
  st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
13
 
14
+ with open("style.css", "r") as f:
15
+ style = f.read()
16
+ with open("header.md", "r") as f:
17
+ header = f.read()
18
 
19
+ st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
20
  st.markdown(header)
21
 
22
  if __name__ == "__main__":
app_lib/main.py CHANGED
@@ -1,15 +1,14 @@
1
  import torch
2
  import streamlit as st
3
- import time
4
 
5
  from app_lib.user_input import (
6
- get_cardinality,
7
  get_class_name,
8
  get_concepts,
9
  get_image,
10
  get_model_name,
 
11
  )
12
- from app_lib.test import test
13
  from app_lib.viz import viz_results
14
 
15
 
@@ -20,6 +19,10 @@ def _disable():
20
  def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
21
  columns = st.columns([0.40, 0.60])
22
 
 
 
 
 
23
  with columns[0]:
24
  st.header("Choose Image and Concepts")
25
 
@@ -41,8 +44,6 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
41
  model_name = get_model_name()
42
  class_name, class_ready, class_error = get_class_name()
43
  concepts, concepts_ready, concepts_error = get_concepts()
44
- cardinality = int(len(concepts) / 2)
45
- # get_cardinality(concepts, concepts_ready)
46
 
47
  ready = class_ready and concepts_ready
48
 
@@ -55,6 +56,10 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
55
  st.error(error_message)
56
 
57
  with st.container():
 
 
 
 
58
  test_button = st.button(
59
  "Test Concepts",
60
  use_container_width=True,
@@ -62,25 +67,19 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
62
  disabled=st.session_state.disabled or not ready,
63
  )
64
 
65
- with st.popover("Advanced settings", disabled=st.session_state.disabled):
66
- st.markdown("Hello World 👋")
67
-
68
- with columns[1]:
69
- st.header("Results")
70
-
71
  if test_button:
72
  st.session_state.results = None
73
 
74
- _, centercol, _ = st.columns(3)
75
- with centercol:
76
- test(
77
- image,
78
- class_name,
79
- concepts,
80
- cardinality,
81
- "imagenette",
82
- model_name,
83
- device,
84
- )
85
-
86
- viz_results()
 
1
  import torch
2
  import streamlit as st
 
3
 
4
  from app_lib.user_input import (
 
5
  get_class_name,
6
  get_concepts,
7
  get_image,
8
  get_model_name,
9
+ get_advanced_settings,
10
  )
11
+ from app_lib.test import get_testing_config, test
12
  from app_lib.viz import viz_results
13
 
14
 
 
19
  def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
20
  columns = st.columns([0.40, 0.60])
21
 
22
+ with columns[1]:
23
+ st.header("Results")
24
+ viz_results()
25
+
26
  with columns[0]:
27
  st.header("Choose Image and Concepts")
28
 
 
44
  model_name = get_model_name()
45
  class_name, class_ready, class_error = get_class_name()
46
  concepts, concepts_ready, concepts_error = get_concepts()
 
 
47
 
48
  ready = class_ready and concepts_ready
49
 
 
56
  st.error(error_message)
57
 
58
  with st.container():
59
+ significance_level, tau_max, r, cardinality = get_advanced_settings(
60
+ concepts, concepts_ready
61
+ )
62
+
63
  test_button = st.button(
64
  "Test Concepts",
65
  use_container_width=True,
 
67
  disabled=st.session_state.disabled or not ready,
68
  )
69
 
 
 
 
 
 
 
70
  if test_button:
71
  st.session_state.results = None
72
 
73
+ testing_config = get_testing_config(
74
+ significance_level=significance_level, tau_max=tau_max, r=r
75
+ )
76
+ test(
77
+ testing_config,
78
+ image,
79
+ class_name,
80
+ concepts,
81
+ cardinality,
82
+ "imagenette",
83
+ model_name,
84
+ device,
85
+ )
app_lib/test.py CHANGED
@@ -4,7 +4,6 @@ import open_clip
4
  import h5py
5
  import streamlit as st
6
  import numpy as np
7
- import pandas as pd
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
 
10
  import ml_collections
@@ -16,16 +15,6 @@ from app_lib.ckde import cKDE
16
 
17
  rng = np.random.default_rng()
18
 
19
- testing_config = ml_collections.ConfigDict()
20
- testing_config.significance_level = 0.05
21
- testing_config.wealth = "ons"
22
- testing_config.bet = "tanh"
23
- testing_config.kernel = "rbf"
24
- testing_config.kernel_scale_method = "quantile"
25
- testing_config.kernel_scale = 0.5
26
- testing_config.tau_max = 200
27
- testing_config.r = 10
28
-
29
 
30
  def _get_open_clip_model(model_name, device):
31
  backbone = model_name.split(":")[-1]
@@ -45,19 +34,7 @@ def _get_clip_model(model_name, device):
45
  return model, preprocess, tokenizer
46
 
47
 
48
- def load_dataset(dataset_name, model_name):
49
- dataset_path = hf_hub_download(
50
- repo_id="jacopoteneggi/IBYDMT",
51
- filename=f"{dataset_name}_{model_name}_train.h5",
52
- repo_type="dataset",
53
- )
54
-
55
- with h5py.File(dataset_path, "r") as dataset:
56
- embedding = dataset["embedding"][:]
57
- return embedding
58
-
59
-
60
- def load_model(model_name, device):
61
  if "open_clip" in model_name:
62
  model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
63
  elif "clip" in model_name:
@@ -67,7 +44,7 @@ def load_model(model_name, device):
67
 
68
  @torch.no_grad()
69
  @torch.cuda.amp.autocast()
70
- def encode_concepts(tokenizer, model, concepts, device):
71
  concepts_text = tokenizer(concepts).to(device)
72
 
73
  concept_features = model.encode_text(concepts_text)
@@ -77,7 +54,7 @@ def encode_concepts(tokenizer, model, concepts, device):
77
 
78
  @torch.no_grad()
79
  @torch.cuda.amp.autocast()
80
- def encode_image(model, preprocess, image, device):
81
  image = preprocess(image)
82
  image = image.unsqueeze(0)
83
  image = image.to(device)
@@ -89,7 +66,7 @@ def encode_image(model, preprocess, image, device):
89
 
90
  @torch.no_grad()
91
  @torch.cuda.amp.autocast()
92
- def encode_class_name(tokenizer, model, class_name, device):
93
  class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
94
 
95
  class_features = model.encode_text(class_text)
@@ -97,12 +74,24 @@ def encode_class_name(tokenizer, model, class_name, device):
97
  return class_features.cpu().numpy()
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def _sample_random_subset(concept_idx, concepts, cardinality):
101
  sample_idx = list(set(range(len(concepts))) - {concept_idx})
102
  return rng.permutation(sample_idx)[:cardinality].tolist()
103
 
104
 
105
- def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
106
  def cond_p(z, cond_idx, m):
107
  _, sample_h = sampler.sample(z, cond_idx, m=m)
108
  return sample_h
@@ -118,9 +107,16 @@ def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
118
 
119
  tester = xSKIT(testing_config)
120
  rejected, tau = tester.test(
121
- z, concept_idx, subset_idx, cond_p, f, interrupt_on_rejection=False
 
 
 
 
 
 
122
  )
123
  wealth = tester.wealth._wealth
 
124
 
125
  rejected_hist.append(rejected)
126
  tau_hist.append(tau)
@@ -136,60 +132,99 @@ def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
136
  }
137
 
138
 
139
- def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  with st.spinner("Loading model"):
141
- model, preprocess, tokenizer = load_model(model_name, device)
142
 
143
  with st.spinner("Encoding concepts"):
144
- cbm = encode_concepts(tokenizer, model, concepts, device)
145
 
146
  with st.spinner("Encoding image"):
147
- h = encode_image(model, preprocess, image, device)
148
  z = h @ cbm.T
149
  z = z.squeeze()
150
 
151
- with st.spinner("Testing"):
152
- progress_bar = st.progress(0)
153
-
154
- embedding = load_dataset("imagenette", model_name)
155
- semantics = embedding @ cbm.T
156
- sampler = cKDE(embedding, semantics)
157
-
158
- classifier = encode_class_name(tokenizer, model, class_name, device)
159
-
160
- with ThreadPoolExecutor() as executor:
161
- futures = [
162
- executor.submit(
163
- _test, z, concept_idx, concepts, cardinality, sampler, classifier
164
- )
165
- for concept_idx in range(len(concepts))
166
- ]
167
-
168
- results = []
169
- for idx, future in enumerate(as_completed(futures)):
170
- results.append(future.result())
171
- progress_bar.progress((idx + 1) / len(concepts))
172
-
173
- rejected = np.empty((testing_config.r, len(concepts)))
174
- tau = np.empty((testing_config.r, len(concepts)))
175
- wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
176
 
177
- for _results in results:
178
- concept_idx = concepts.index(_results["concept"])
 
179
 
180
- rejected[:, concept_idx] = np.array(_results["rejected"])
181
- tau[:, concept_idx] = np.array(_results["tau"])
182
- wealth[:, :, concept_idx] = np.array(_results["wealth"])
183
 
184
- tau /= testing_config.tau_max
 
 
 
185
 
186
- st.session_state.results = {
187
- "significance_level": testing_config.significance_level,
188
- "concepts": concepts,
189
- "rejected": rejected,
190
- "tau": tau,
191
- "wealth": wealth,
192
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  st.session_state.disabled = False
195
  st.experimental_rerun()
 
4
  import h5py
5
  import streamlit as st
6
  import numpy as np
 
7
  from concurrent.futures import ThreadPoolExecutor, as_completed
8
 
9
  import ml_collections
 
15
 
16
  rng = np.random.default_rng()
17
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def _get_open_clip_model(model_name, device):
20
  backbone = model_name.split(":")[-1]
 
34
  return model, preprocess, tokenizer
35
 
36
 
37
+ def _load_model(model_name, device):
 
 
 
 
 
 
 
 
 
 
 
 
38
  if "open_clip" in model_name:
39
  model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
40
  elif "clip" in model_name:
 
44
 
45
  @torch.no_grad()
46
  @torch.cuda.amp.autocast()
47
+ def _encode_concepts(tokenizer, model, concepts, device):
48
  concepts_text = tokenizer(concepts).to(device)
49
 
50
  concept_features = model.encode_text(concepts_text)
 
54
 
55
  @torch.no_grad()
56
  @torch.cuda.amp.autocast()
57
+ def _encode_image(model, preprocess, image, device):
58
  image = preprocess(image)
59
  image = image.unsqueeze(0)
60
  image = image.to(device)
 
66
 
67
  @torch.no_grad()
68
  @torch.cuda.amp.autocast()
69
+ def _encode_class_name(tokenizer, model, class_name, device):
70
  class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
71
 
72
  class_features = model.encode_text(class_text)
 
74
  return class_features.cpu().numpy()
75
 
76
 
77
+ def _load_dataset(dataset_name, model_name):
78
+ dataset_path = hf_hub_download(
79
+ repo_id="jacopoteneggi/IBYDMT",
80
+ filename=f"{dataset_name}_{model_name}_train.h5",
81
+ repo_type="dataset",
82
+ )
83
+
84
+ with h5py.File(dataset_path, "r") as dataset:
85
+ embedding = dataset["embedding"][:]
86
+ return embedding
87
+
88
+
89
  def _sample_random_subset(concept_idx, concepts, cardinality):
90
  sample_idx = list(set(range(len(concepts))) - {concept_idx})
91
  return rng.permutation(sample_idx)[:cardinality].tolist()
92
 
93
 
94
+ def _test(testing_config, z, concept_idx, concepts, cardinality, sampler, classifier):
95
  def cond_p(z, cond_idx, m):
96
  _, sample_h = sampler.sample(z, cond_idx, m=m)
97
  return sample_h
 
107
 
108
  tester = xSKIT(testing_config)
109
  rejected, tau = tester.test(
110
+ z,
111
+ concept_idx,
112
+ subset_idx,
113
+ cond_p,
114
+ f,
115
+ interrupt_on="max_wealth",
116
+ max_wealth=100,
117
  )
118
  wealth = tester.wealth._wealth
119
+ wealth = wealth + [wealth[-1]] * (testing_config.tau_max - len(wealth))
120
 
121
  rejected_hist.append(rejected)
122
  tau_hist.append(tau)
 
132
  }
133
 
134
 
135
+ def get_testing_config(**kwargs):
136
+ testing_config = st.session_state.testing_config = ml_collections.ConfigDict()
137
+ testing_config.significance_level = kwargs.get("significance_level", 0.05)
138
+ testing_config.wealth = kwargs.get("wealth", "ons")
139
+ testing_config.bet = kwargs.get("bet", "tanh")
140
+ testing_config.kernel = kwargs.get("kernel", "rbf")
141
+ testing_config.kernel_scale_method = kwargs.get("kernel_scale_method", "quantile")
142
+ testing_config.kernel_scale = kwargs.get("kernel_scale", 0.5)
143
+ testing_config.tau_max = kwargs.get("tau_max", 200)
144
+ testing_config.r = kwargs.get("r", 10)
145
+ return testing_config
146
+
147
+
148
+ def test(
149
+ testing_config,
150
+ image,
151
+ class_name,
152
+ concepts,
153
+ cardinality,
154
+ dataset_name,
155
+ model_name,
156
+ device,
157
+ ):
158
  with st.spinner("Loading model"):
159
+ model, preprocess, tokenizer = _load_model(model_name, device)
160
 
161
  with st.spinner("Encoding concepts"):
162
+ cbm = _encode_concepts(tokenizer, model, concepts, device)
163
 
164
  with st.spinner("Encoding image"):
165
+ h = _encode_image(model, preprocess, image, device)
166
  z = h @ cbm.T
167
  z = z.squeeze()
168
 
169
+ progress_bar = st.progress(
170
+ 0,
171
+ text=f"Testing concepts (can take a few minutes) [0 / {len(concepts)} completed]",
172
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ embedding = _load_dataset(dataset_name, model_name)
175
+ semantics = embedding @ cbm.T
176
+ sampler = cKDE(embedding, semantics)
177
 
178
+ classifier = _encode_class_name(tokenizer, model, class_name, device)
 
 
179
 
180
+ progress_bar.progress(
181
+ 1 / (len(concepts) + 1),
182
+ text=f"Testing concepts (can take a few minutes) [0 / {len(concepts)} completed]",
183
+ )
184
 
185
+ with ThreadPoolExecutor() as executor:
186
+ futures = [
187
+ executor.submit(
188
+ _test,
189
+ testing_config,
190
+ z,
191
+ concept_idx,
192
+ concepts,
193
+ cardinality,
194
+ sampler,
195
+ classifier,
196
+ )
197
+ for concept_idx in range(len(concepts))
198
+ ]
199
+
200
+ results = []
201
+ for idx, future in enumerate(as_completed(futures)):
202
+ results.append(future.result())
203
+ progress_bar.progress(
204
+ (idx + 2) / (len(concepts) + 1),
205
+ text=f"Testing concepts (can take a few minutes) [{idx + 1} / {len(concepts)} completed]",
206
+ )
207
+
208
+ rejected = np.empty((testing_config.r, len(concepts)))
209
+ tau = np.empty((testing_config.r, len(concepts)))
210
+ wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
211
+
212
+ for _results in results:
213
+ concept_idx = concepts.index(_results["concept"])
214
+
215
+ rejected[:, concept_idx] = np.array(_results["rejected"])
216
+ tau[:, concept_idx] = np.array(_results["tau"])
217
+ wealth[:, :, concept_idx] = np.array(_results["wealth"])
218
+
219
+ tau /= testing_config.tau_max
220
+
221
+ st.session_state.results = {
222
+ "significance_level": testing_config.significance_level,
223
+ "concepts": concepts,
224
+ "rejected": rejected,
225
+ "tau": tau,
226
+ "wealth": wealth,
227
+ }
228
 
229
  st.session_state.disabled = False
230
  st.experimental_rerun()
app_lib/user_input.py CHANGED
@@ -20,6 +20,82 @@ def _validate_concepts(concepts):
20
  return (False, "Maximum 10 concepts allowed")
21
  return (True, None)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def get_model_name():
24
  return st.selectbox(
25
  "Model to test",
@@ -58,8 +134,8 @@ def get_class_name():
58
 
59
  def get_concepts():
60
  concepts = st.text_area(
61
- "Concepts to test (max 10)",
62
- help="List of concepts to test the predictions of the model with. Write one concept per line.",
63
  height=160,
64
  value="piano\ncute\nwhiskers\nmusic\nwild",
65
  disabled=st.session_state.disabled,
@@ -73,13 +149,11 @@ def get_concepts():
73
  return concepts, concepts_ready, concepts_error
74
 
75
 
76
- def get_cardinality(concepts, concepts_ready):
77
- return st.slider(
78
- "Size of conditioning set",
79
- help="The number of concepts to condition model predictions on.",
80
- min_value=1,
81
- max_value=max(2, len(concepts) - 1),
82
- value=2,
83
- step=1,
84
- disabled=st.session_state.disabled or not concepts_ready,
85
- )
 
20
  return (False, "Maximum 10 concepts allowed")
21
  return (True, None)
22
 
23
+
24
+ def _get_significance_level():
25
+ DEFAULT = 0.05
26
+ return st.slider(
27
+ "Significance level",
28
+ help=" ".join(
29
+ [
30
+ "The level of significance of the tests.",
31
+ f"Defaults to {DEFAULT:.2F}.",
32
+ ]
33
+ ),
34
+ min_value=0.01,
35
+ max_value=1.0,
36
+ value=DEFAULT,
37
+ step=0.01,
38
+ disabled=st.session_state.disabled,
39
+ )
40
+
41
+
42
+ def _get_tau_max():
43
+ DEFAULT = 200
44
+ return int(
45
+ st.slider(
46
+ "Duration of test",
47
+ help=" ".join(
48
+ [
49
+ "The maximum number of steps for each test.",
50
+ f"Defaults to {DEFAULT}.",
51
+ ]
52
+ ),
53
+ min_value=1,
54
+ max_value=1000,
55
+ step=1,
56
+ value=DEFAULT,
57
+ disabled=st.session_state.disabled,
58
+ )
59
+ )
60
+
61
+
62
+ def _get_number_of_tests():
63
+ DEFAULT = 20
64
+ return int(
65
+ st.slider(
66
+ "Number of tests per concept",
67
+ help=" ".join(
68
+ [
69
+ "The number of tests to average for each concept.",
70
+ f"Defaults to {DEFAULT}.",
71
+ ]
72
+ ),
73
+ min_value=1,
74
+ max_value=100,
75
+ step=1,
76
+ value=DEFAULT,
77
+ disabled=st.session_state.disabled,
78
+ )
79
+ )
80
+
81
+
82
+ def _get_cardinality(concepts, concepts_ready):
83
+ return st.slider(
84
+ "Size of conditioning set",
85
+ help=" ".join(
86
+ [
87
+ "The number of concepts to condition model predictions on.",
88
+ "Defaults to half of the number of concepts.",
89
+ ]
90
+ ),
91
+ min_value=1,
92
+ max_value=max(2, len(concepts) - 1),
93
+ value=int(len(concepts) / 2),
94
+ step=1,
95
+ disabled=st.session_state.disabled or not concepts_ready,
96
+ )
97
+
98
+
99
  def get_model_name():
100
  return st.selectbox(
101
  "Model to test",
 
134
 
135
  def get_concepts():
136
  concepts = st.text_area(
137
+ "Concepts to test",
138
+ help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
139
  height=160,
140
  value="piano\ncute\nwhiskers\nmusic\nwild",
141
  disabled=st.session_state.disabled,
 
149
  return concepts, concepts_ready, concepts_error
150
 
151
 
152
+ def get_advanced_settings(concepts, concepts_ready):
153
+ with st.popover("Advanced settings", disabled=st.session_state.disabled):
154
+ significance_level = _get_significance_level()
155
+ tau_max = _get_tau_max()
156
+ r = _get_number_of_tests()
157
+ cardinality = _get_cardinality(concepts, concepts_ready)
158
+
159
+ return significance_level, tau_max, r, cardinality
 
 
header.md CHANGED
@@ -1,5 +1,3 @@
1
  # 🤔 I Bet You Did Not Mean That
2
 
3
- Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantic Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
4
-
5
- ---
 
1
  # 🤔 I Bet You Did Not Mean That
2
 
3
+ Official 🤗 Space for the paper [*I Bet You Did Not Mean That: Testing Semantic Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
 
 
ibydmt/test.py CHANGED
@@ -141,7 +141,8 @@ class xSKIT(SequentialTester):
141
  C: list[int],
142
  cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
143
  model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
144
- interrupt_on_rejection: bool = True,
 
145
  ) -> Tuple[bool, int]:
146
  sample = functools.partial(self._sample, z, j, C, cond_p, model)
147
 
@@ -159,6 +160,8 @@ class xSKIT(SequentialTester):
159
 
160
  if self.wealth.rejected:
161
  tau = min(tau, t)
162
- if interrupt_on_rejection:
 
 
163
  break
164
  return (self.wealth.rejected, tau)
 
141
  C: list[int],
142
  cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
143
  model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
144
+ interrupt_on: str = "rejection",
145
+ max_wealth: float = None,
146
  ) -> Tuple[bool, int]:
147
  sample = functools.partial(self._sample, z, j, C, cond_p, model)
148
 
 
160
 
161
  if self.wealth.rejected:
162
  tau = min(tau, t)
163
+ if interrupt_on == "rejection":
164
+ break
165
+ if interrupt_on == "max_wealth" and self.wealth._w >= max_wealth:
166
  break
167
  return (self.wealth.rejected, tau)
requirements.txt CHANGED
@@ -4,4 +4,5 @@ open_clip_torch
4
  h5py
5
  ml_collections
6
  jaxtyping
7
- scikit-learn
 
 
4
  h5py
5
  ml_collections
6
  jaxtyping
7
+ scikit-learn
8
+ plotly
style.css CHANGED
@@ -33,7 +33,22 @@ h1 {
33
  }
34
  }
35
 
 
 
 
 
 
 
 
 
36
  button:hover>div:first-of-type>p {
37
  text-decoration: underline;
38
  }
 
 
 
 
 
 
 
39
  }
 
33
  }
34
  }
35
 
36
+ button:active {
37
+ background: white;
38
+ }
39
+
40
+ button:focus:not(:active) {
41
+ color: rgb(49, 51, 63);
42
+ }
43
+
44
  button:hover>div:first-of-type>p {
45
  text-decoration: underline;
46
  }
47
+ }
48
+
49
+ [data-testid="stSpinner"] {
50
+ >div {
51
+ display: flex;
52
+ justify-content: center;
53
+ }
54
  }