shlomihod committed
Commit 220a2a2 · 1 Parent(s): 4e5327b

add support to any dataset

Files changed (1): app.py (+43 -21)
app.py CHANGED
@@ -21,8 +21,11 @@ HF_MODEL = st.secrets.get("hf_model")
 
 HF_DATASET = st.secrets.get("hf_dataset")
 
-DATASET_SHUFFLE_SEED = 42
-NUM_SAMPLES = 25
+DATASET_SPLIT_SEED = 42
+TRAIN_SIZE = 20
+TEST_SIZE = 50
+SPLITS = ["train", "test"]
+
 PROMPT_TEXT_HEIGHT = 300
 
 UNKNOWN_LABEL = "Unknown"
@@ -91,8 +94,8 @@ def normalize(text):
     return strip_newline_space(text).lower().capitalize()
 
 
-def prepare_datasets():
-    ds = load_dataset(HF_DATASET)
+def prepare_datasets(dataset_name):
+    ds = load_dataset(dataset_name)
 
     label_columns = [
         (name, info)
@@ -105,13 +108,14 @@ def prepare_datasets():
     label_dict = dict(enumerate(labels))
     input_columns = [name for name in ds["train"].features if name != label_column]
 
-    def load(split):
-        ds_split = ds[split]
+    ds = ds["train"].train_test_split(
+        train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=DATASET_SPLIT_SEED
+    )
 
-        if split == "train":
-            ds_split = ds_split.shuffle(seed=DATASET_SHUFFLE_SEED).select(
-                range(NUM_SAMPLES)
-            )
+    dfs = {}
+
+    for split in ["train", "test"]:
+        ds_split = ds[split]
 
         df = ds_split.to_pandas()
 
@@ -120,10 +124,9 @@
 
         df[label_column] = df[label_column].replace(label_dict)
 
-        return df
+        dfs[split] = df
 
-    # (load(split) for split in ("train",))
-    return (load("train"), input_columns, label_column, labels)
+    return dfs, input_columns, label_column, labels
 
 
 def complete(prompt, generation_config, details=False):
@@ -204,7 +207,7 @@ def infer_multi(prompt_template, inputs_df, generation_config=None, progress=None
     def preprocess_output_line(text):
         return [
             normalized_token
-            for token in st.session_state.tokenizer(text)
+            for token in st.session_state.processing_tokenizer(text)
             if (normalized_token := normalize(str(token)))
         ]
 
@@ -300,18 +303,19 @@ if "client" not in st.session_state:
         token=st.secrets.get("hf_token"), model=HF_MODEL
     )
 
-if "tokenizer" not in st.session_state:
-    st.session_state["tokenizer"] = English().tokenizer
+if "processing_tokenizer" not in st.session_state:
+    st.session_state["processing_tokenizer"] = English().tokenizer
 
 if "train_dataset" not in st.session_state or "test_dataset" not in st.session_state:
     (
-        st.session_state["train_dataset"],
-        # st.session_state["test_dataset"],
+        splits_df,
         st.session_state["input_columns"],
         st.session_state["label_column"],
         st.session_state["labels"],
-    ) = prepare_datasets()
-    st.session_state["test_dataset"] = st.session_state["train_dataset"]
+    ) = prepare_datasets(HF_DATASET)
+
+    for split in splits_df:
+        st.session_state[f"{split}_dataset"] = splits_df[split]
 
 if "generation_config" not in st.session_state:
     st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
@@ -322,7 +326,9 @@ st.title(TITLE)
 
 with st.sidebar:
     with st.form("model_form"):
-        model = st.text_input("Model", HF_MODEL)
+        dataset = st.text_input("Dataset", HF_DATASET).strip()
+
+        model = st.text_input("Model", HF_MODEL).strip()
 
         # Defautlt values from:
         # https://huggingface.co/docs/transformers/v4.30.0/main_classes/text_generation
@@ -365,6 +371,10 @@ with st.sidebar:
         submitted = st.form_submit_button("Set")
 
        if submitted:
+            if not dataset:
+                st.error("Dataset must be specified.")
+                st.stop()
+
            if not model:
                st.error("Model must be specified.")
                st.stop()
@@ -411,8 +421,20 @@ with st.sidebar:
            st.session_state["client"] = InferenceClient(
                token=st.secrets.get("hf_token"), model=model
            )
+
            st.session_state["generation_config"] = generation_config
 
+            (
+                splits_df,
+                st.session_state["input_columns"],
+                st.session_state["label_column"],
+                st.session_state["labels"],
+            ) = prepare_datasets(dataset)
+
+            for split in splits_df:
+                st.session_state[f"{split}_dataset"] = splits_df[split]
+
+            LOGGER.warning(f"FORM {dataset=}")
            LOGGER.warning(f"FORM {model=}")
            LOGGER.warning(f"FORM {generation_config=}")
 
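
For readers skimming the diff: the reworked prepare_datasets() no longer shuffles and samples only a train slice of the hard-coded HF_DATASET; it accepts any dataset name, detects the ClassLabel column, and carves reproducible train/test samples out of the original train split. The sketch below reproduces that flow standalone with the public datasets API; the "imdb" dataset name, the next(...) shortcut for picking the label column, and the print calls are illustrative stand-ins, not part of the app.

# Minimal standalone sketch of the new prepare_datasets() flow, assuming
# "imdb" as a stand-in for any Hub dataset with a ClassLabel column.
from datasets import ClassLabel, load_dataset

DATASET_SPLIT_SEED = 42
TRAIN_SIZE = 20
TEST_SIZE = 50

ds = load_dataset("imdb")

# Treat the first ClassLabel feature as the label column; its .names
# give the human-readable label strings.
label_column, label_info = next(
    (name, info)
    for name, info in ds["train"].features.items()
    if isinstance(info, ClassLabel)
)
labels = label_info.names
label_dict = dict(enumerate(labels))
input_columns = [name for name in ds["train"].features if name != label_column]

# train_test_split() returns a DatasetDict with "train" and "test" keys.
ds = ds["train"].train_test_split(
    train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=DATASET_SPLIT_SEED
)

dfs = {}
for split in ("train", "test"):
    df = ds[split].to_pandas()
    df[label_column] = df[label_column].replace(label_dict)  # int ids -> names
    dfs[split] = df

print(input_columns, labels)
print(dfs["train"].head())

Because train_test_split() shuffles with the fixed seed before sampling, the same 20/50 rows come back on every Streamlit rerun, which is what lets the sidebar form swap in a new dataset without any extra caching state.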