shlomihod commited on
Commit
644b61f
·
1 Parent(s): 984d233

add decoding seed

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -27,7 +27,7 @@ HF_MODEL = st.secrets.get("hf_model", "")
27
 
28
  HF_DATASET = st.secrets.get("hf_dataset", "")
29
 
30
- DATASET_SPLIT_SEED = 42
31
  TRAIN_SIZE = 10
32
  TEST_SIZE = 25
33
  SPLITS = ["train", "test"]
@@ -96,7 +96,7 @@ def normalize(text):
96
  return strip_newline_space(text).lower().capitalize()
97
 
98
 
99
- def prepare_datasets(dataset_name):
100
  try:
101
  ds = load_dataset(dataset_name)
102
  except FileNotFoundError as e:
@@ -138,7 +138,7 @@ def prepare_datasets(dataset_name):
138
  input_columns.append(lowered_input_column)
139
 
140
  ds = ds["train"].train_test_split(
141
- train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=DATASET_SPLIT_SEED
142
  )
143
 
144
  dfs = {}
@@ -350,6 +350,9 @@ def combine_labels(labels):
350
  return "|".join(f"``{label}``" for label in labels)
351
 
352
 
 
 
 
353
  if "client" not in st.session_state:
354
  st.session_state["client"] = InferenceClient(
355
  token=st.secrets.get("hf_token", None), model=HF_MODEL
@@ -365,7 +368,7 @@ if "train_dataset" not in st.session_state:
365
  st.session_state["input_columns"],
366
  st.session_state["label_column"],
367
  st.session_state["labels"],
368
- ) = prepare_datasets(HF_DATASET)
369
 
370
  for split in splits_df:
371
  st.session_state[f"{split}_dataset"] = splits_df[split]
@@ -373,6 +376,7 @@ if "train_dataset" not in st.session_state:
373
  if "generation_config" not in st.session_state:
374
  st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
375
 
 
376
  st.set_page_config(page_title=TITLE, initial_sidebar_state="collapsed")
377
 
378
  st.title(TITLE)
@@ -419,7 +423,11 @@ with st.sidebar:
419
  if not stop_sequences:
420
  stop_sequences = None
421
 
422
- seed = st.text_input("Seed").strip()
 
 
 
 
423
 
424
  submitted = st.form_submit_button("Set")
425
 
@@ -432,10 +440,10 @@ with st.sidebar:
432
  st.error("Model must be specified.")
433
  st.stop()
434
 
435
- if not seed:
436
- seed = None
437
  elif seed.isnumeric():
438
- seed = int(seed)
439
  else:
440
  st.error("Seed must be numeric or empty.")
441
  st.stop()
@@ -461,16 +469,26 @@ with st.sidebar:
461
  )
462
  st.stop()
463
 
464
- if seed is not None and not do_sample:
465
  st.error(
466
- "Sampling must be enabled to use a seed. Otherwise, the seed field should be empty."
467
  )
468
  st.stop()
469
 
 
 
 
 
 
 
 
 
470
  generation_config = generation_config_sliders | dict(
471
- do_sample=do_sample, stop_sequences=stop_sequences, seed=seed
472
  )
473
 
 
 
474
  st.session_state["client"] = InferenceClient(
475
  token=st.secrets.get("hf_token", None), model=model
476
  )
 
27
 
28
  HF_DATASET = st.secrets.get("hf_dataset", "")
29
 
30
+ DATASET_SPLIT_SEED_DEFAULT = 42
31
  TRAIN_SIZE = 10
32
  TEST_SIZE = 25
33
  SPLITS = ["train", "test"]
 
96
  return strip_newline_space(text).lower().capitalize()
97
 
98
 
99
+ def prepare_datasets(dataset_name, dataset_split_seed=None):
100
  try:
101
  ds = load_dataset(dataset_name)
102
  except FileNotFoundError as e:
 
138
  input_columns.append(lowered_input_column)
139
 
140
  ds = ds["train"].train_test_split(
141
+ train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=dataset_split_seed
142
  )
143
 
144
  dfs = {}
 
350
  return "|".join(f"``{label}``" for label in labels)
351
 
352
 
353
+ if "dataset_split_seed" not in st.session_state:
354
+ st.session_state["dataset_split_seed"] = DATASET_SPLIT_SEED_DEFAULT
355
+
356
  if "client" not in st.session_state:
357
  st.session_state["client"] = InferenceClient(
358
  token=st.secrets.get("hf_token", None), model=HF_MODEL
 
368
  st.session_state["input_columns"],
369
  st.session_state["label_column"],
370
  st.session_state["labels"],
371
+ ) = prepare_datasets(HF_DATASET, st.session_state["dataset_split_seed"])
372
 
373
  for split in splits_df:
374
  st.session_state[f"{split}_dataset"] = splits_df[split]
 
376
  if "generation_config" not in st.session_state:
377
  st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
378
 
379
+
380
  st.set_page_config(page_title=TITLE, initial_sidebar_state="collapsed")
381
 
382
  st.title(TITLE)
 
423
  if not stop_sequences:
424
  stop_sequences = None
425
 
426
+ decoding_seed = st.text_input("Decoding Seed").strip()
427
+
428
+ dataset_split_seed = st.text_input(
429
+ "Dataset Split Seed", st.session_state["dataset_split_seed"]
430
+ ).strip()
431
 
432
  submitted = st.form_submit_button("Set")
433
 
 
440
  st.error("Model must be specified.")
441
  st.stop()
442
 
443
+ if not decoding_seed:
444
+ decoding_seed = None
445
  elif seed.isnumeric():
446
+ decoding_seed = int(seed)
447
  else:
448
  st.error("Seed must be numeric or empty.")
449
  st.stop()
 
469
  )
470
  st.stop()
471
 
472
+ if decoding_seed is not None and not do_sample:
473
  st.error(
474
+ "Sampling must be enabled to use a decoding seed. Otherwise, the seed field should be empty."
475
  )
476
  st.stop()
477
 
478
+ if not dataset_split_seed:
479
+ dataset_split_seed = None
480
+ elif dataset_split_seed.isnumeric():
481
+ dataset_split_seed = int(dataset_split_seed)
482
+ else:
483
+ st.error("Dataset split seed must be numeric or empty.")
484
+ st.stop()
485
+
486
  generation_config = generation_config_sliders | dict(
487
+ do_sample=do_sample, stop_sequences=stop_sequences, seed=decoding_seed
488
  )
489
 
490
+ st.session_state["dataset_split_seed"] = dataset_split_seed
491
+
492
  st.session_state["client"] = InferenceClient(
493
  token=st.secrets.get("hf_token", None), model=model
494
  )