shlomihod committed · Commit 9680652 · 1 Parent(s): 566249d

make the defaults less model/dataset specific
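In short: hard-coded, task-specific defaults (a sentiment starter prompt, model-specific stop sequences) give way to neutral fallbacks, and secrets are read with explicit defaults so the app still runs when they are unset. A minimal sketch of the pattern the diff adopts (names mirror the app; the comments are interpretation, not code from the commit):

    import streamlit as st

    # st.secrets supports mapping-style .get(), so explicit fallbacks keep
    # the app from assuming any particular model or dataset is configured.
    HF_MODEL = st.secrets.get("hf_model", "")      # "" -> no baked-in model
    HF_DATASET = st.secrets.get("hf_dataset", "")  # "" -> no baked-in dataset
    HF_TOKEN = st.secrets.get("hf_token", None)    # None -> anonymous API access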

Files changed (1): app.py (+19, -22)
app.py CHANGED
@@ -17,9 +17,9 @@ LOGGER = logging.getLogger(__name__)
 
 TITLE = "Prompter"
 
-HF_MODEL = st.secrets.get("hf_model")
+HF_MODEL = st.secrets.get("hf_model", "")
 
-HF_DATASET = st.secrets.get("hf_dataset")
+HF_DATASET = st.secrets.get("hf_dataset", "")
 
 DATASET_SPLIT_SEED = 42
 TRAIN_SIZE = 20
@@ -72,7 +72,7 @@ GENERATION_CONFIG_PARAMS = {
     },
     "stop_sequences": {
         "NAME": "Stop Sequences",
-        "DEFAULT": [r"\nUser:", r"<|endoftext|>"],
+        "DEFAULT": [r"\nUser:", r"<|endoftext|>", r"\n### Human:", r"\n### User:"],
         "SAMPLING": False,
     },
 }
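The widened stop-sequence defaults cover several common chat conventions (User:, ### Human:, ### User:) instead of a single model's format. A hedged sketch of how such a list is typically consumed, assuming the app's complete() forwards it to InferenceClient.text_generation (the model id and prompt below are placeholders, and real newline characters are used here):

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="some-model-id")  # placeholder model id

    # Generation halts at whichever stop marker the served model emits
    # first, so one default list works across several prompt formats.
    output = client.text_generation(
        "User: Is this review positive?\nAssistant:",
        max_new_tokens=32,
        stop_sequences=["\nUser:", "<|endoftext|>", "\n### Human:", "\n### User:"],
    )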
@@ -81,10 +81,6 @@ GENERATION_CONFIG_DEFAULTS = {
     key: value["DEFAULT"] for key, value in GENERATION_CONFIG_PARAMS.items()
 }
 
-STARTER_PROMPT = """{content}
-
-The sentiment of the text is"""
-
 
 def strip_newline_space(text):
     return text.strip("\n").strip()
@@ -195,17 +191,19 @@ def complete(prompt, generation_config, details=True):
 def infer(prompt_template, inputs, generation_config=None):
     prompt = prompt_template.format(**inputs)
     output, length = complete(prompt, generation_config)
-    return output, length
+    return output, prompt, length
 
 
 def infer_multi(prompt_template, inputs_df, generation_config=None, progress=None):
     props = (i / len(inputs_df) for i in range(1, len(inputs_df) + 1))
 
     def infer_with_progress(inputs):
-        output, length = infer(prompt_template, inputs.to_dict(), generation_config)
+        output, prompt, length = infer(
+            prompt_template, inputs.to_dict(), generation_config
+        )
         if progress is not None:
             progress.progress(next(props))
-        return output, length
+        return output, prompt, length
 
     return zip(*inputs_df.apply(infer_with_progress, axis=1))
 
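infer now also returns the rendered prompt, and infer_multi threads it through unchanged: DataFrame.apply with axis=1 yields a Series of per-row tuples, and zip(*...) transposes them into parallel sequences. A self-contained illustration of the pattern (toy data, not the app's schema):

    import pandas as pd

    df = pd.DataFrame({"text": ["good movie", "bad movie"]})

    def fake_infer(row):
        prompt = f"Review: {row['text']}"
        return "positive", prompt, len(prompt)  # (output, prompt, length)

    # apply() yields a Series of 3-tuples; zip(*...) transposes it into
    # three parallel tuples, mirroring infer_multi's return value.
    outputs, prompts, lengths = zip(*df.apply(fake_infer, axis=1))
    print(outputs)  # ('positive', 'positive')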
@@ -242,7 +240,7 @@ def canonize_label(output, annotation_labels, search_row):
     return UNKNOWN_LABEL
 
 
-def measure(dataset, outputs, lengths, search_row):
+def measure(dataset, outputs, search_row):
     inferences = [
         canonize_label(output, st.session_state.labels, search_row)
         for output in outputs
@@ -262,9 +260,6 @@ def measure(dataset, outputs, lengths, search_row):
             "output": outputs,
         }
         | dataset[st.session_state.input_columns].to_dict("list")
-        | {
-            "length": lengths,
-        }
     )
 
     acc = accuracy_score(evaluation_df["annotation"], evaluation_df["inference"])
@@ -294,14 +289,18 @@ def run_evaluation(
     prompt_template, dataset, search_row, generation_config=None, progress=None
 ):
     inputs_df = dataset[st.session_state.input_columns]
-    outputs, lengths = infer_multi(
+    outputs, prompts, lengths = infer_multi(
         prompt_template,
         inputs_df,
         generation_config,
         progress,
     )
 
-    metrics = measure(dataset, outputs, lengths, search_row)
+    metrics = measure(dataset, outputs, search_row)
+
+    metrics["hit_miss"]["prompt"] = prompts
+    metrics["hit_miss"]["length"] = lengths
+
     return metrics
 
 
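With lengths removed from measure(), run_evaluation attaches prompts and lengths to the hit/miss records afterwards. The diff does not show how metrics["hit_miss"] is constructed; assuming it is a per-example pandas DataFrame, assigning an equal-length sequence adds an aligned column:

    import pandas as pd

    # Toy stand-in for metrics["hit_miss"]: one row per evaluated example.
    hit_miss = pd.DataFrame(
        {"annotation": ["pos", "neg"], "inference": ["pos", "pos"]}
    )

    # Equal-length sequences become new columns, aligned by position.
    hit_miss["prompt"] = ["Review: good movie", "Review: bad movie"]
    hit_miss["length"] = [18, 17]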
@@ -311,7 +310,7 @@ def combine_labels(labels):
 
 if "client" not in st.session_state:
     st.session_state["client"] = InferenceClient(
-        token=st.secrets.get("hf_token"), model=HF_MODEL
+        token=st.secrets.get("hf_token", None), model=HF_MODEL
     )
 
 if "processing_tokenizer" not in st.session_state:
@@ -430,7 +429,7 @@ with st.sidebar:
     )
 
     st.session_state["client"] = InferenceClient(
-        token=st.secrets.get("hf_token"), model=model
+        token=st.secrets.get("hf_token", None), model=model
     )
 
     st.session_state["generation_config"] = generation_config
@@ -461,9 +460,7 @@ tab1, tab2, tab3 = st.tabs(["Evaluation", "Training Dataset", "Playground"])
 
 with tab1:
     with st.form("prompt_form"):
-        prompt_template = st.text_area(
-            "Prompt Template", STARTER_PROMPT, height=PROMPT_TEXT_HEIGHT
-        )
+        prompt_template = st.text_area("Prompt Template", height=PROMPT_TEXT_HEIGHT)
 
         st.write(f"Labels: {combine_labels(st.session_state.labels)}")
         st.write(f"Inputs: {combine_labels(st.session_state.input_columns)}")
@@ -489,7 +486,7 @@ with tab1:
     )
 
     if not is_valid_prompt_template:
-        st.error("Prompt template must contain some or all inputs fields.")
+        st.error(f"The prompt template contains unrecognized fields.")
         st.stop()
 
     inference_progress = st.progress(0, "Executing inference")
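The reworded error implies stricter semantics: rather than requiring the template to contain input fields, it must not reference fields that are not input columns. The commit does not show the validation itself; one plausible way to collect a template's placeholders uses string.Formatter (the column names below are toy values):

    from string import Formatter

    def template_fields(template):
        # Formatter().parse yields (literal, field_name, spec, conversion);
        # field_name is None for trailing literal text.
        return {field for _, field, _, _ in Formatter().parse(template) if field}

    fields = template_fields("{content}\n\nThe sentiment of the text is")
    is_valid = fields <= {"content", "title"}  # valid: no unrecognized fields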
 