Spaces: Runtime error

shlomihod committed · Commit 220a2a2 · 1 Parent(s): 4e5327b

add support to any dataset

app.py CHANGED
@@ -21,8 +21,11 @@ HF_MODEL = st.secrets.get("hf_model")
 
 HF_DATASET = st.secrets.get("hf_dataset")
 
-
-
+DATASET_SPLIT_SEED = 42
+TRAIN_SIZE = 20
+TEST_SIZE = 50
+SPLITS = ["train", "test"]
+
 PROMPT_TEXT_HEIGHT = 300
 
 UNKNOWN_LABEL = "Unknown"
@@ -91,8 +94,8 @@ def normalize(text):
     return strip_newline_space(text).lower().capitalize()
 
 
-def prepare_datasets():
-    ds = load_dataset(
+def prepare_datasets(dataset_name):
+    ds = load_dataset(dataset_name)
 
     label_columns = [
         (name, info)
@@ -105,13 +108,14 @@ def prepare_datasets():
     label_dict = dict(enumerate(labels))
     input_columns = [name for name in ds["train"].features if name != label_column]
 
-
-
+    ds = ds["train"].train_test_split(
+        train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=DATASET_SPLIT_SEED
+    )
 
-
-
-
-
+    dfs = {}
+
+    for split in ["train", "test"]:
+        ds_split = ds[split]
 
         df = ds_split.to_pandas()
 
@@ -120,10 +124,9 @@ def prepare_datasets():
 
         df[label_column] = df[label_column].replace(label_dict)
 
-
+        dfs[split] = df
 
-
-    return (load("train"), input_columns, label_column, labels)
+    return dfs, input_columns, label_column, labels
 
 
 def complete(prompt, generation_config, details=False):
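For reference, a minimal standalone sketch of what the reworked prepare_datasets now does end to end, outside Streamlit. The dataset name "ag_news" is only an illustrative stand-in for any Hub dataset with a ClassLabel column, and the label-column selection below is a simplified guess at the condition elided in the hunk above.

# Sketch only: mirrors the reshaped prepare_datasets(dataset_name), run standalone.
from datasets import load_dataset

DATASET_SPLIT_SEED = 42
TRAIN_SIZE = 20
TEST_SIZE = 50

ds = load_dataset("ag_news")  # illustrative stand-in for the user-supplied dataset

# Pick a ClassLabel column and its human-readable names (simplified heuristic).
label_column, label_info = next(
    (name, info)
    for name, info in ds["train"].features.items()
    if hasattr(info, "names")  # ClassLabel features carry .names
)
labels = label_info.names
label_dict = dict(enumerate(labels))
input_columns = [name for name in ds["train"].features if name != label_column]

# Small, reproducible train/test sample drawn from the original train split.
ds = ds["train"].train_test_split(
    train_size=TRAIN_SIZE, test_size=TEST_SIZE, seed=DATASET_SPLIT_SEED
)

dfs = {}
for split in ["train", "test"]:
    df = ds[split].to_pandas()
    df[label_column] = df[label_column].replace(label_dict)  # ints -> label names
    dfs[split] = df

print(dfs["train"].shape, dfs["test"].shape)  # (20, ...), (50, ...)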
@@ -204,7 +207,7 @@ def infer_multi(prompt_template, inputs_df, generation_config=None, progress=Non
 def preprocess_output_line(text):
     return [
         normalized_token
-        for token in st.session_state.
+        for token in st.session_state.processing_tokenizer(text)
         if (normalized_token := normalize(str(token)))
     ]
 
@@ -300,18 +303,19 @@ if "client" not in st.session_state:
         token=st.secrets.get("hf_token"), model=HF_MODEL
     )
 
-if "
-    st.session_state["
+if "processing_tokenizer" not in st.session_state:
+    st.session_state["processing_tokenizer"] = English().tokenizer
 
 if "train_dataset" not in st.session_state or "test_dataset" not in st.session_state:
     (
-
-        # st.session_state["test_dataset"],
+        splits_df,
         st.session_state["input_columns"],
         st.session_state["label_column"],
         st.session_state["labels"],
-    ) = prepare_datasets()
-
+    ) = prepare_datasets(HF_DATASET)
+
+    for split in splits_df:
+        st.session_state[f"{split}_dataset"] = splits_df[split]
 
 if "generation_config" not in st.session_state:
     st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
@@ -322,7 +326,9 @@ st.title(TITLE)
 
 with st.sidebar:
     with st.form("model_form"):
-
+        dataset = st.text_input("Dataset", HF_DATASET).strip()
+
+        model = st.text_input("Model", HF_MODEL).strip()
 
         # Defautlt values from:
         # https://huggingface.co/docs/transformers/v4.30.0/main_classes/text_generation
@@ -365,6 +371,10 @@ with st.sidebar:
         submitted = st.form_submit_button("Set")
 
         if submitted:
+            if not dataset:
+                st.error("Dataset must be specified.")
+                st.stop()
+
             if not model:
                 st.error("Model must be specified.")
                 st.stop()
@@ -411,8 +421,20 @@ with st.sidebar:
             st.session_state["client"] = InferenceClient(
                 token=st.secrets.get("hf_token"), model=model
            )
+
            st.session_state["generation_config"] = generation_config
 
+            (
+                splits_df,
+                st.session_state["input_columns"],
+                st.session_state["label_column"],
+                st.session_state["labels"],
+            ) = prepare_datasets(dataset)
+
+            for split in splits_df:
+                st.session_state[f"{split}_dataset"] = splits_df[split]
+
+            LOGGER.warning(f"FORM {dataset=}")
            LOGGER.warning(f"FORM {model=}")
            LOGGER.warning(f"FORM {generation_config=}")
 
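Finally, the sidebar "Set" form that this commit extends, reduced to a hedged sketch. The default strings and the two trailing session_state keys are hypothetical placeholders; the real app instead rebuilds its InferenceClient, generation config, and per-split DataFrames via prepare_datasets(dataset).

# Sketch only: the sidebar form pattern, with hypothetical defaults and keys.
import streamlit as st

with st.sidebar:
    with st.form("model_form"):
        dataset = st.text_input("Dataset", "my-org/my-dataset").strip()
        model = st.text_input("Model", "my-org/my-model").strip()

        submitted = st.form_submit_button("Set")

        if submitted:
            if not dataset:
                st.error("Dataset must be specified.")
                st.stop()

            if not model:
                st.error("Model must be specified.")
                st.stop()

            # app.py goes on to rebuild the client, the generation config, and
            # the train/test DataFrames returned by prepare_datasets(dataset).
            st.session_state["dataset_name"] = dataset  # hypothetical key
            st.session_state["model_name"] = model      # hypothetical key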