shlomihod committed
Commit 4e5327b · 1 Parent(s): 6c25427

identify labels and inputs for the dataset

Files changed (1)
  1. app.py +71 -49
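
The change makes the app dataset-agnostic: instead of hard-coding the `content`/`label` columns of `amazon_polarity`, `prepare_datasets()` now reads the dataset named by the `hf_dataset` secret, detects the label column from its `ClassLabel` feature, and treats every remaining column as a prompt input. A minimal standalone sketch of that detection logic, using the stable `datasets.ClassLabel` import and a plain `str.lower` as a stand-in for the app's `normalize()` helper (`amazon_polarity` is only an example dataset here):

from datasets import ClassLabel, load_dataset

ds = load_dataset("amazon_polarity")  # example; any single-label text-classification dataset works

# A label column is any feature declared as ClassLabel; everything else becomes a prompt input.
label_columns = [
    (name, info)
    for name, info in ds["train"].features.items()
    if isinstance(info, ClassLabel)
]
assert len(label_columns) == 1
label_column, label_info = label_columns[0]

labels = [name.lower() for name in label_info.names]  # stand-in for normalize()
input_columns = [name for name in ds["train"].features if name != label_column]

print(label_column)   # "label"
print(labels)         # ["negative", "positive"]
print(input_columns)  # e.g. ["title", "content"]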
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
 import pandas as pd
 import streamlit as st
 from datasets import load_dataset
+from datasets.tasks.text_classification import ClassLabel
 from huggingface_hub import InferenceClient
 from huggingface_hub.utils import HfHubHTTPError
 from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, confusion_matrix
@@ -18,15 +19,12 @@ TITLE = "Prompter"
 
 HF_MODEL = st.secrets.get("hf_model")
 
-HF_DATASET = "amazon_polarity"
+HF_DATASET = st.secrets.get("hf_dataset")
 
 DATASET_SHUFFLE_SEED = 42
 NUM_SAMPLES = 25
 PROMPT_TEXT_HEIGHT = 300
 
-TEXT_COLUMN = "content"
-ANNOTATION_COLUMN = "label"
-
 UNKNOWN_LABEL = "Unknown"
 
 SEARCH_ROW_DICT = {"First": 0, "Last": -1}
@@ -80,7 +78,7 @@ GENERATION_CONFIG_DEFAULTS = {
     key: value["DEFAULT"] for key, value in GENERATION_CONFIG_PARAMS.items()
 }
 
-STARTER_PROMPT = """{text}
+STARTER_PROMPT = """{content}
 
 The sentiment of the text is"""
 
@@ -94,26 +92,41 @@ def normalize(text):
 
 
 def prepare_datasets():
-    label_dict = {0: normalize("negative"), 1: normalize("positive")}
+    ds = load_dataset(HF_DATASET)
+
+    label_columns = [
+        (name, info)
+        for name, info in ds["train"].features.items()
+        if isinstance(info, ClassLabel)
+    ]
+    assert len(label_columns) == 1
+    label_column, label_column_info = label_columns[0]
+    labels = [normalize(label) for label in label_column_info.names]
+    label_dict = dict(enumerate(labels))
+    input_columns = [name for name in ds["train"].features if name != label_column]
 
     def load(split):
-        df = (
-            load_dataset(HF_DATASET, split=split)
-            .shuffle(seed=DATASET_SHUFFLE_SEED)
-            .select(range(NUM_SAMPLES))
-            .to_pandas()
-        )
+        ds_split = ds[split]
+
+        if split == "train":
+            ds_split = ds_split.shuffle(seed=DATASET_SHUFFLE_SEED).select(
+                range(NUM_SAMPLES)
+            )
+
+        df = ds_split.to_pandas()
 
-        df["content"] = df["content"].apply(strip_newline_space)
-        df["label"].replace(label_dict, inplace=True)
-        df.drop(columns=["title"], inplace=True)
+        for input_column in input_columns:
+            df[input_column] = df[input_column].apply(strip_newline_space)
+
+        df[label_column] = df[label_column].replace(label_dict)
 
         return df
 
-    return (load(split) for split in ("train", "test"))
+    # (load(split) for split in ("train",))
+    return (load("train"), input_columns, label_column, labels)
 
 
-def complete(prompt, generation_config):
+def complete(prompt, generation_config, details=False):
     if generation_config is None:
         generation_config = {}
 
@@ -144,13 +157,13 @@ def complete(prompt, generation_config):
     else:
         assert generation_config["do_sample"]
 
-    LOGGER.warning(f"API Call {generation_config=}")
+    LOGGER.warning(f"API Call\n\n``{prompt}``\n\n{generation_config=}")
     response = st.session_state.client.text_generation(
-        prompt, stream=False, details=True, **generation_config
+        prompt, stream=False, details=details, **generation_config
    )
     LOGGER.debug(response)
 
-    output = response.generated_text
+    output = response.generated_text if details else response
 
     # Remove stop sequences from the output
     # Inspired by
@@ -170,22 +183,22 @@ def complete(prompt, generation_config):
     return output
 
 
-def infer(prompt_template, text, generation_config=None):
-    prompt = prompt_template.format(text=text)
+def infer(prompt_template, inputs, generation_config=None):
+    prompt = prompt_template.format(**inputs)
     output = complete(prompt, generation_config)
     return output
 
 
-def infer_multi(prompt_template, text_series, generation_config=None, progress=None):
-    props = (i / len(text_series) for i in range(1, len(text_series) + 1))
+def infer_multi(prompt_template, inputs_df, generation_config=None, progress=None):
+    props = (i / len(inputs_df) for i in range(1, len(inputs_df) + 1))
 
-    def infer_with_progress(text):
-        output = infer(prompt_template, text, generation_config)
+    def infer_with_progress(inputs):
+        output = infer(prompt_template, inputs.to_dict(), generation_config)
         if progress is not None:
             progress.progress(next(props))
         return output
 
-    return text_series.apply(infer_with_progress)
+    return inputs_df.apply(infer_with_progress, axis=1)
 
 
 def preprocess_output_line(text):
@@ -221,34 +234,33 @@ def canonize_label(output, annotation_labels, search_row):
 
 
 def measure(dataset, outputs, search_row):
-    annotation_labels = sorted(dataset[ANNOTATION_COLUMN].unique())
-
     inferences = [
-        canonize_label(output, annotation_labels, search_row) for output in outputs
+        canonize_label(output, st.session_state.labels, search_row)
+        for output in outputs
    ]
 
-    inference_labels = annotation_labels.copy() + [UNKNOWN_LABEL]
+    print(f"{inferences=}")
+    print(f"{st.session_state.labels=}")
+    inference_labels = st.session_state.labels + [UNKNOWN_LABEL]
 
     evaluation_df = pd.DataFrame(
         {
             "hit/miss": np.where(
-                dataset[ANNOTATION_COLUMN] == inferences, "hit", "miss"
+                dataset[st.session_state.label_column] == inferences, "hit", "miss"
            ),
-            "annotation": dataset[ANNOTATION_COLUMN],
+            "annotation": dataset[st.session_state.label_column],
            "inference": inferences,
            "output": outputs,
-            "text": dataset[TEXT_COLUMN],
        }
+        | dataset[st.session_state.input_columns].to_dict("list")
    )
 
-    all_labels = sorted(set(annotation_labels + inference_labels))
-
     acc = accuracy_score(evaluation_df["annotation"], evaluation_df["inference"])
     cm = confusion_matrix(
-        evaluation_df["annotation"], evaluation_df["inference"], labels=all_labels
+        evaluation_df["annotation"], evaluation_df["inference"], labels=inference_labels
    )
 
-    cm_display = ConfusionMatrixDisplay(cm, display_labels=all_labels)
+    cm_display = ConfusionMatrixDisplay(cm, display_labels=inference_labels)
     cm_display.plot()
     cm_display.ax_.set_xlabel("inference Labels")
     cm_display.ax_.set_ylabel("Annotation Labels")
@@ -259,7 +271,7 @@ def measure(dataset, outputs, search_row):
         "confusion_matrix": cm,
         "confusion_matrix_display": cm_display.figure_,
         "hit_miss": evaluation_df,
-        "annotation_labels": annotation_labels,
+        "annotation_labels": st.session_state.labels,
         "inference_labels": inference_labels,
    }
 
@@ -270,7 +282,10 @@ def run_evaluation(
     prompt_template, dataset, search_row, generation_config=None, progress=None
 ):
     outputs = infer_multi(
-        prompt_template, dataset[TEXT_COLUMN], generation_config, progress
+        prompt_template,
+        dataset[st.session_state.input_columns],
+        generation_config,
+        progress,
    )
     metrics = measure(dataset, outputs, search_row)
     return metrics
@@ -291,8 +306,12 @@ if "tokenizer" not in st.session_state:
 if "train_dataset" not in st.session_state or "test_dataset" not in st.session_state:
     (
         st.session_state["train_dataset"],
-        st.session_state["test_dataset"],
+        # st.session_state["test_dataset"],
+        st.session_state["input_columns"],
+        st.session_state["label_column"],
+        st.session_state["labels"],
    ) = prepare_datasets()
+    st.session_state["test_dataset"] = st.session_state["train_dataset"]
 
 if "generation_config" not in st.session_state:
     st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
@@ -368,7 +387,7 @@ with st.sidebar:
             value != GENERATION_CONFIG_DEFAULTS[name]
             for name, value in generation_confing_slider_sampling.items()
        )
-        and not generation_config["do_sample"]
+        and not do_sample
    ):
         sampling_slider_default_values_info = " | ".join(
             f"{name}={GENERATION_CONFIG_DEFAULTS[name]}"
@@ -379,14 +398,14 @@ with st.sidebar:
        )
         st.stop()
 
-    if seed is not None and not generation_config["do_sample"]:
+    if seed is not None and not do_sample:
         st.error(
             "Sampling must be enabled to use a seed. Otherwise, the seed field should be empty."
        )
         st.stop()
 
     generation_config = generation_config_sliders | dict(
-        stop_sequences=stop_sequences, seed=seed
+        do_sample=do_sample, stop_sequences=stop_sequences, seed=seed
    )
 
     st.session_state["client"] = InferenceClient(
@@ -406,6 +425,9 @@ with tab1:
         "Prompt Template", STARTER_PROMPT, height=PROMPT_TEXT_HEIGHT
    )
 
+    st.write(f"Labels: {combine_labels(st.session_state.labels)}")
+    st.write(f"Inputs: {combine_labels(st.session_state.input_columns)}")
+
     col1, col2 = st.columns(2)
 
     with col1:
@@ -422,12 +444,12 @@ with tab1:
         st.stop()
 
     _, formats, *_ = zip(*string.Formatter().parse(prompt_template))
-    is_valid_prompt_template = set(formats) == {"text"} or set(formats) == {
-        "text",
-        None,
-    }
+    is_valid_prompt_template = set(formats).issubset(
+        {None} | set(st.session_state.input_columns)
+    )
+
     if not is_valid_prompt_template:
-        st.error("Prompt template must contain a single {text} field.")
+        st.error("Prompt template must contain some or all inputs fields.")
         st.stop()
 
     inference_progress = st.progress(0, "Executing inference")
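
With inputs no longer fixed to a single `{text}` field, the template check in `tab1` accepts any subset of the detected input columns, and `infer()` fills the template from a whole row. A small sketch of both pieces, assuming the `["title", "content"]` input columns from the example above (row values are made up):

import string

input_columns = ["title", "content"]  # as detected by prepare_datasets()
prompt_template = "{title}\n\n{content}\n\nThe sentiment of the text is"

# Same validation as the new tab1 code: every named field must be a known input column.
_, formats, *_ = zip(*string.Formatter().parse(prompt_template))
assert set(formats).issubset({None} | set(input_columns))

# infer() now formats the template from a row dict instead of a single string.
row = {"title": "Great battery", "content": "Lasts two days on a single charge."}
prompt = prompt_template.format(**row)
print(prompt)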