shlomihod committed · commit 9680652 · parent 566249d

make the defaults less model/dataset specific

app.py CHANGED
@@ -17,9 +17,9 @@ LOGGER = logging.getLogger(__name__)
 
 TITLE = "Prompter"
 
-HF_MODEL = st.secrets.get("hf_model")
+HF_MODEL = st.secrets.get("hf_model", "")
 
-HF_DATASET = st.secrets.get("hf_dataset")
+HF_DATASET = st.secrets.get("hf_dataset", "")
 
 DATASET_SPLIT_SEED = 42
 TRAIN_SIZE = 20
@@ -72,7 +72,7 @@ GENERATION_CONFIG_PARAMS = {
     },
     "stop_sequences": {
         "NAME": "Stop Sequences",
-        "DEFAULT": [r"\nUser:", r"<|endoftext|>"],
+        "DEFAULT": [r"\nUser:", r"<|endoftext|>", r"\n### Human:", r"\n### User:"],
         "SAMPLING": False,
     },
 }
@@ -81,10 +81,6 @@ GENERATION_CONFIG_DEFAULTS = {
     key: value["DEFAULT"] for key, value in GENERATION_CONFIG_PARAMS.items()
 }
 
-STARTER_PROMPT = """{content}
-
-The sentiment of the text is"""
-
 
 def strip_newline_space(text):
     return text.strip("\n").strip()
@@ -195,17 +191,19 @@ def complete(prompt, generation_config, details=True):
 def infer(prompt_template, inputs, generation_config=None):
     prompt = prompt_template.format(**inputs)
     output, length = complete(prompt, generation_config)
-    return output, length
+    return output, prompt, length
 
 
 def infer_multi(prompt_template, inputs_df, generation_config=None, progress=None):
     props = (i / len(inputs_df) for i in range(1, len(inputs_df) + 1))
 
     def infer_with_progress(inputs):
-        output, length = infer(prompt_template, inputs.to_dict(), generation_config)
+        output, prompt, length = infer(
+            prompt_template, inputs.to_dict(), generation_config
+        )
         if progress is not None:
             progress.progress(next(props))
-        return output, length
+        return output, prompt, length
 
     return zip(*inputs_df.apply(infer_with_progress, axis=1))
 
@@ -242,7 +240,7 @@ def canonize_label(output, annotation_labels, search_row):
     return UNKNOWN_LABEL
 
 
-def measure(dataset, outputs, lengths, search_row):
+def measure(dataset, outputs, search_row):
     inferences = [
         canonize_label(output, st.session_state.labels, search_row)
         for output in outputs
@@ -262,9 +260,6 @@ def measure(dataset, outputs, lengths, search_row):
             "output": outputs,
         }
         | dataset[st.session_state.input_columns].to_dict("list")
-        | {
-            "length": lengths,
-        }
     )
 
     acc = accuracy_score(evaluation_df["annotation"], evaluation_df["inference"])
@@ -294,14 +289,18 @@ def run_evaluation(
     prompt_template, dataset, search_row, generation_config=None, progress=None
 ):
     inputs_df = dataset[st.session_state.input_columns]
-    outputs, lengths = infer_multi(
+    outputs, prompts, lengths = infer_multi(
         prompt_template,
         inputs_df,
         generation_config,
        progress,
     )
 
-    metrics = measure(dataset, outputs, lengths, search_row)
+    metrics = measure(dataset, outputs, search_row)
+
+    metrics["hit_miss"]["prompt"] = prompts
+    metrics["hit_miss"]["length"] = lengths
+
     return metrics
 
 
@@ -311,7 +310,7 @@ def combine_labels(labels):
 
 if "client" not in st.session_state:
     st.session_state["client"] = InferenceClient(
-        token=st.secrets.get("hf_token"), model=HF_MODEL
+        token=st.secrets.get("hf_token", None), model=HF_MODEL
     )
 
 if "processing_tokenizer" not in st.session_state:
@@ -430,7 +429,7 @@ with st.sidebar:
         )
 
         st.session_state["client"] = InferenceClient(
-            token=st.secrets.get("hf_token"), model=model
+            token=st.secrets.get("hf_token", None), model=model
         )
 
         st.session_state["generation_config"] = generation_config
@@ -461,9 +460,7 @@ tab1, tab2, tab3 = st.tabs(["Evaluation", "Training Dataset", "Playground"])
 
 with tab1:
     with st.form("prompt_form"):
-        prompt_template = st.text_area(
-            "Prompt Template", STARTER_PROMPT, height=PROMPT_TEXT_HEIGHT
-        )
+        prompt_template = st.text_area("Prompt Template", height=PROMPT_TEXT_HEIGHT)
 
         st.write(f"Labels: {combine_labels(st.session_state.labels)}")
         st.write(f"Inputs: {combine_labels(st.session_state.input_columns)}")
@@ -489,7 +486,7 @@ with tab1:
     )
 
     if not is_valid_prompt_template:
-        st.error("
+        st.error(f"The prompt template contains unrecognized fields.")
         st.stop()
 
     inference_progress = st.progress(0, "Executing inference")
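For reference, the infer_multi change above leans on a transpose idiom: infer now returns a three-tuple (output, prompt, length) per row, and zip(*inputs_df.apply(...)) turns the per-row tuples into the three parallel sequences unpacked in run_evaluation. Below is a minimal sketch of that idiom; fake_infer and the toy DataFrame are illustrative stand-ins, not part of the commit.

import pandas as pd

# Toy stand-in for the app's infer(); returns (output, prompt, length) per row.
def fake_infer(row):
    prompt = f"Classify: {row['text']}"
    return f"label-for-{row['text']}", prompt, len(prompt)

df = pd.DataFrame({"text": ["good movie", "bad movie"]})

# apply(axis=1) yields a Series with one tuple per row; zip(*...) transposes
# the row tuples into three parallel tuples, matching infer_multi's return shape.
outputs, prompts, lengths = zip(*df.apply(fake_infer, axis=1))

print(outputs)  # ('label-for-good movie', 'label-for-bad movie')
print(prompts)  # ('Classify: good movie', 'Classify: bad movie')
print(lengths)  # (20, 19)

Returning one tuple per row keeps the prompt aligned with its output and length by construction, which is what lets run_evaluation attach prompts and lengths to the hit_miss table without a separate join.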
|