Spaces:
Runtime error
Runtime error
shlomihod
commited on
Commit
·
644b61f
1
Parent(s):
984d233
add decoding seed
Browse files
app.py
CHANGED
@@ -27,7 +27,7 @@ HF_MODEL = st.secrets.get("hf_model", "")
|
|
27 |
|
28 |
HF_DATASET = st.secrets.get("hf_dataset", "")
|
29 |
|
30 |
-
|
31 |
TRAIN_SIZE = 10
|
32 |
TEST_SIZE = 25
|
33 |
SPLITS = ["train", "test"]
|
@@ -96,7 +96,7 @@ def normalize(text):
|
|
96 |
return strip_newline_space(text).lower().capitalize()
|
97 |
|
98 |
|
99 |
-
def prepare_datasets(dataset_name):
|
100 |
try:
|
101 |
ds = load_dataset(dataset_name)
|
102 |
except FileNotFoundError as e:
|
@@ -138,7 +138,7 @@ def prepare_datasets(dataset_name):
|
|
138 |
input_columns.append(lowered_input_column)
|
139 |
|
140 |
ds = ds["train"].train_test_split(
|
141 |
-
train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=
|
142 |
)
|
143 |
|
144 |
dfs = {}
|
@@ -350,6 +350,9 @@ def combine_labels(labels):
|
|
350 |
return "|".join(f"``{label}``" for label in labels)
|
351 |
|
352 |
|
|
|
|
|
|
|
353 |
if "client" not in st.session_state:
|
354 |
st.session_state["client"] = InferenceClient(
|
355 |
token=st.secrets.get("hf_token", None), model=HF_MODEL
|
@@ -365,7 +368,7 @@ if "train_dataset" not in st.session_state:
|
|
365 |
st.session_state["input_columns"],
|
366 |
st.session_state["label_column"],
|
367 |
st.session_state["labels"],
|
368 |
-
) = prepare_datasets(HF_DATASET)
|
369 |
|
370 |
for split in splits_df:
|
371 |
st.session_state[f"{split}_dataset"] = splits_df[split]
|
@@ -373,6 +376,7 @@ if "train_dataset" not in st.session_state:
|
|
373 |
if "generation_config" not in st.session_state:
|
374 |
st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
|
375 |
|
|
|
376 |
st.set_page_config(page_title=TITLE, initial_sidebar_state="collapsed")
|
377 |
|
378 |
st.title(TITLE)
|
@@ -419,7 +423,11 @@ with st.sidebar:
|
|
419 |
if not stop_sequences:
|
420 |
stop_sequences = None
|
421 |
|
422 |
-
|
|
|
|
|
|
|
|
|
423 |
|
424 |
submitted = st.form_submit_button("Set")
|
425 |
|
@@ -432,10 +440,10 @@ with st.sidebar:
|
|
432 |
st.error("Model must be specified.")
|
433 |
st.stop()
|
434 |
|
435 |
-
if not
|
436 |
-
|
437 |
elif seed.isnumeric():
|
438 |
-
|
439 |
else:
|
440 |
st.error("Seed must be numeric or empty.")
|
441 |
st.stop()
|
@@ -461,16 +469,26 @@ with st.sidebar:
|
|
461 |
)
|
462 |
st.stop()
|
463 |
|
464 |
-
if
|
465 |
st.error(
|
466 |
-
"Sampling must be enabled to use a seed. Otherwise, the seed field should be empty."
|
467 |
)
|
468 |
st.stop()
|
469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
generation_config = generation_config_sliders | dict(
|
471 |
-
do_sample=do_sample, stop_sequences=stop_sequences, seed=
|
472 |
)
|
473 |
|
|
|
|
|
474 |
st.session_state["client"] = InferenceClient(
|
475 |
token=st.secrets.get("hf_token", None), model=model
|
476 |
)
|
|
|
27 |
|
28 |
HF_DATASET = st.secrets.get("hf_dataset", "")
|
29 |
|
30 |
+
DATASET_SPLIT_SEED_DEFAULT = 42
|
31 |
TRAIN_SIZE = 10
|
32 |
TEST_SIZE = 25
|
33 |
SPLITS = ["train", "test"]
|
|
|
96 |
return strip_newline_space(text).lower().capitalize()
|
97 |
|
98 |
|
99 |
+
def prepare_datasets(dataset_name, dataset_split_seed=None):
|
100 |
try:
|
101 |
ds = load_dataset(dataset_name)
|
102 |
except FileNotFoundError as e:
|
|
|
138 |
input_columns.append(lowered_input_column)
|
139 |
|
140 |
ds = ds["train"].train_test_split(
|
141 |
+
train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=dataset_split_seed
|
142 |
)
|
143 |
|
144 |
dfs = {}
|
|
|
350 |
return "|".join(f"``{label}``" for label in labels)
|
351 |
|
352 |
|
353 |
+
if "dataset_split_seed" not in st.session_state:
|
354 |
+
st.session_state["dataset_split_seed"] = DATASET_SPLIT_SEED_DEFAULT
|
355 |
+
|
356 |
if "client" not in st.session_state:
|
357 |
st.session_state["client"] = InferenceClient(
|
358 |
token=st.secrets.get("hf_token", None), model=HF_MODEL
|
|
|
368 |
st.session_state["input_columns"],
|
369 |
st.session_state["label_column"],
|
370 |
st.session_state["labels"],
|
371 |
+
) = prepare_datasets(HF_DATASET, st.session_state["dataset_split_seed"])
|
372 |
|
373 |
for split in splits_df:
|
374 |
st.session_state[f"{split}_dataset"] = splits_df[split]
|
|
|
376 |
if "generation_config" not in st.session_state:
|
377 |
st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
|
378 |
|
379 |
+
|
380 |
st.set_page_config(page_title=TITLE, initial_sidebar_state="collapsed")
|
381 |
|
382 |
st.title(TITLE)
|
|
|
423 |
if not stop_sequences:
|
424 |
stop_sequences = None
|
425 |
|
426 |
+
decoding_seed = st.text_input("Decoding Seed").strip()
|
427 |
+
|
428 |
+
dataset_split_seed = st.text_input(
|
429 |
+
"Dataset Split Seed", st.session_state["dataset_split_seed"]
|
430 |
+
).strip()
|
431 |
|
432 |
submitted = st.form_submit_button("Set")
|
433 |
|
|
|
440 |
st.error("Model must be specified.")
|
441 |
st.stop()
|
442 |
|
443 |
+
if not decoding_seed:
|
444 |
+
decoding_seed = None
|
445 |
elif seed.isnumeric():
|
446 |
+
decoding_seed = int(seed)
|
447 |
else:
|
448 |
st.error("Seed must be numeric or empty.")
|
449 |
st.stop()
|
|
|
469 |
)
|
470 |
st.stop()
|
471 |
|
472 |
+
if decoding_seed is not None and not do_sample:
|
473 |
st.error(
|
474 |
+
"Sampling must be enabled to use a decoding seed. Otherwise, the seed field should be empty."
|
475 |
)
|
476 |
st.stop()
|
477 |
|
478 |
+
if not dataset_split_seed:
|
479 |
+
dataset_split_seed = None
|
480 |
+
elif dataset_split_seed.isnumeric():
|
481 |
+
dataset_split_seed = int(dataset_split_seed)
|
482 |
+
else:
|
483 |
+
st.error("Dataset split seed must be numeric or empty.")
|
484 |
+
st.stop()
|
485 |
+
|
486 |
generation_config = generation_config_sliders | dict(
|
487 |
+
do_sample=do_sample, stop_sequences=stop_sequences, seed=decoding_seed
|
488 |
)
|
489 |
|
490 |
+
st.session_state["dataset_split_seed"] = dataset_split_seed
|
491 |
+
|
492 |
st.session_state["client"] = InferenceClient(
|
493 |
token=st.secrets.get("hf_token", None), model=model
|
494 |
)
|