shlomihod committed
Commit · 4e5327b
1 Parent(s): 6c25427
identify labels and inputs for the dataset
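The change replaces the hard-coded TEXT_COLUMN / ANNOTATION_COLUMN constants with schema-based detection: the label column is whichever dataset feature is typed as ClassLabel, and every other feature becomes a prompt input. A minimal sketch of that idea, assuming an illustrative Hub dataset ("imdb"); the commit itself reads the dataset name from st.secrets:

from datasets import ClassLabel, load_dataset

# Illustrative only: "imdb" stands in for the Space's secret-configured dataset.
ds = load_dataset("imdb", split="train")

# The annotation column is the (single) feature typed as ClassLabel;
# every remaining column is treated as a prompt input.
label_columns = [name for name, info in ds.features.items() if isinstance(info, ClassLabel)]
assert len(label_columns) == 1
label_column = label_columns[0]
input_columns = [name for name in ds.features if name != label_column]
labels = ds.features[label_column].names  # ["neg", "pos"] for imdb

print(label_column, input_columns, labels)  # label ['text'] ['neg', 'pos']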
app.py
CHANGED
@@ -7,6 +7,7 @@ import numpy as np
import pandas as pd
import streamlit as st
from datasets import load_dataset
+from datasets.tasks.text_classification import ClassLabel
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, confusion_matrix

@@ -18,15 +19,12 @@ TITLE = "Prompter"

HF_MODEL = st.secrets.get("hf_model")

-HF_DATASET = "
+HF_DATASET = st.secrets.get("hf_dataset")

DATASET_SHUFFLE_SEED = 42
NUM_SAMPLES = 25
PROMPT_TEXT_HEIGHT = 300

-TEXT_COLUMN = "content"
-ANNOTATION_COLUMN = "label"
-
UNKNOWN_LABEL = "Unknown"

SEARCH_ROW_DICT = {"First": 0, "Last": -1}

@@ -80,7 +78,7 @@ GENERATION_CONFIG_DEFAULTS = {
    key: value["DEFAULT"] for key, value in GENERATION_CONFIG_PARAMS.items()
}

-STARTER_PROMPT = """{
+STARTER_PROMPT = """{content}

The sentiment of the text is"""

@@ -94,26 +92,41 @@ def normalize(text):


def prepare_datasets():
+    ds = load_dataset(HF_DATASET)
+
+    label_columns = [
+        (name, info)
+        for name, info in ds["train"].features.items()
+        if isinstance(info, ClassLabel)
+    ]
+    assert len(label_columns) == 1
+    label_column, label_column_info = label_columns[0]
+    labels = [normalize(label) for label in label_column_info.names]
+    label_dict = dict(enumerate(labels))
+    input_columns = [name for name in ds["train"].features if name != label_column]

    def load(split):
+        ds_split = ds[split]
+
+        if split == "train":
+            ds_split = ds_split.shuffle(seed=DATASET_SHUFFLE_SEED).select(
+                range(NUM_SAMPLES)
+            )
+
+        df = ds_split.to_pandas()

+        for input_column in input_columns:
+            df[input_column] = df[input_column].apply(strip_newline_space)
+
+        df[label_column] = df[label_column].replace(label_dict)

        return df

+    # (load(split) for split in ("train",))
+    return (load("train"), input_columns, label_column, labels)


-def complete(prompt, generation_config):
+def complete(prompt, generation_config, details=False):
    if generation_config is None:
        generation_config = {}

@@ -144,13 +157,13 @@ def complete(prompt, generation_config):
    else:
        assert generation_config["do_sample"]

-    LOGGER.warning(f"API Call
+    LOGGER.warning(f"API Call\n\n``{prompt}``\n\n{generation_config=}")
    response = st.session_state.client.text_generation(
-        prompt, stream=False, details=
+        prompt, stream=False, details=details, **generation_config
    )
    LOGGER.debug(response)

-    output = response.generated_text
+    output = response.generated_text if details else response

    # Remove stop sequences from the output
    # Inspired by

@@ -170,22 +183,22 @@ def complete(prompt, generation_config):
    return output


-def infer(prompt_template,
-    prompt = prompt_template.format(
+def infer(prompt_template, inputs, generation_config=None):
+    prompt = prompt_template.format(**inputs)
    output = complete(prompt, generation_config)
    return output


-def infer_multi(prompt_template,
-    props = (i / len(
+def infer_multi(prompt_template, inputs_df, generation_config=None, progress=None):
+    props = (i / len(inputs_df) for i in range(1, len(inputs_df) + 1))

-    def infer_with_progress(
-        output = infer(prompt_template,
+    def infer_with_progress(inputs):
+        output = infer(prompt_template, inputs.to_dict(), generation_config)
        if progress is not None:
            progress.progress(next(props))
        return output

-    return
+    return inputs_df.apply(infer_with_progress, axis=1)


def preprocess_output_line(text):

@@ -221,34 +234,33 @@ def canonize_label(output, annotation_labels, search_row):


def measure(dataset, outputs, search_row):
-    annotation_labels = sorted(dataset[ANNOTATION_COLUMN].unique())
-
    inferences = [
-        canonize_label(output,
+        canonize_label(output, st.session_state.labels, search_row)
+        for output in outputs
    ]

+    print(f"{inferences=}")
+    print(f"{st.session_state.labels=}")
+    inference_labels = st.session_state.labels + [UNKNOWN_LABEL]

    evaluation_df = pd.DataFrame(
        {
            "hit/miss": np.where(
-                dataset[
+                dataset[st.session_state.label_column] == inferences, "hit", "miss"
            ),
-            "annotation": dataset[
+            "annotation": dataset[st.session_state.label_column],
            "inference": inferences,
            "output": outputs,
-            "text": dataset[TEXT_COLUMN],
        }
+        | dataset[st.session_state.input_columns].to_dict("list")
    )

-    all_labels = sorted(set(annotation_labels + inference_labels))
-
    acc = accuracy_score(evaluation_df["annotation"], evaluation_df["inference"])
    cm = confusion_matrix(
-        evaluation_df["annotation"], evaluation_df["inference"], labels=
+        evaluation_df["annotation"], evaluation_df["inference"], labels=inference_labels
    )

-    cm_display = ConfusionMatrixDisplay(cm, display_labels=
+    cm_display = ConfusionMatrixDisplay(cm, display_labels=inference_labels)
    cm_display.plot()
    cm_display.ax_.set_xlabel("inference Labels")
    cm_display.ax_.set_ylabel("Annotation Labels")

@@ -259,7 +271,7 @@ def measure(dataset, outputs, search_row):
        "confusion_matrix": cm,
        "confusion_matrix_display": cm_display.figure_,
        "hit_miss": evaluation_df,
-        "annotation_labels":
+        "annotation_labels": st.session_state.labels,
        "inference_labels": inference_labels,
    }

@@ -270,7 +282,10 @@ def run_evaluation(
    prompt_template, dataset, search_row, generation_config=None, progress=None
):
    outputs = infer_multi(
-        prompt_template,
+        prompt_template,
+        dataset[st.session_state.input_columns],
+        generation_config,
+        progress,
    )
    metrics = measure(dataset, outputs, search_row)
    return metrics

@@ -291,8 +306,12 @@ if "tokenizer" not in st.session_state:
if "train_dataset" not in st.session_state or "test_dataset" not in st.session_state:
    (
        st.session_state["train_dataset"],
-        st.session_state["test_dataset"],
+        # st.session_state["test_dataset"],
+        st.session_state["input_columns"],
+        st.session_state["label_column"],
+        st.session_state["labels"],
    ) = prepare_datasets()
+    st.session_state["test_dataset"] = st.session_state["train_dataset"]

if "generation_config" not in st.session_state:
    st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS

@@ -368,7 +387,7 @@ with st.sidebar:
            value != GENERATION_CONFIG_DEFAULTS[name]
            for name, value in generation_confing_slider_sampling.items()
        )
-        and not
+        and not do_sample
    ):
        sampling_slider_default_values_info = " | ".join(
            f"{name}={GENERATION_CONFIG_DEFAULTS[name]}"

@@ -379,14 +398,14 @@ with st.sidebar:
        )
        st.stop()

-    if seed is not None and not
+    if seed is not None and not do_sample:
        st.error(
            "Sampling must be enabled to use a seed. Otherwise, the seed field should be empty."
        )
        st.stop()

    generation_config = generation_config_sliders | dict(
-        stop_sequences=stop_sequences, seed=seed
+        do_sample=do_sample, stop_sequences=stop_sequences, seed=seed
    )

    st.session_state["client"] = InferenceClient(

@@ -406,6 +425,9 @@ with tab1:
        "Prompt Template", STARTER_PROMPT, height=PROMPT_TEXT_HEIGHT
    )

+    st.write(f"Labels: {combine_labels(st.session_state.labels)}")
+    st.write(f"Inputs: {combine_labels(st.session_state.input_columns)}")
+
    col1, col2 = st.columns(2)

    with col1:

@@ -422,12 +444,12 @@ with tab1:
        st.stop()

    _, formats, *_ = zip(*string.Formatter().parse(prompt_template))
-    is_valid_prompt_template = set(formats)
+    is_valid_prompt_template = set(formats).issubset(
+        {None} | set(st.session_state.input_columns)
+    )

    if not is_valid_prompt_template:
-        st.error("Prompt template must contain
+        st.error("Prompt template must contain some or all inputs fields.")
        st.stop()

    inference_progress = st.progress(0, "Executing inference")
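For reference, the prompt-template validation added at the end of the diff can be exercised on its own: every placeholder parsed by string.Formatter must be one of the dataset's input columns (the None entries produced by the parser correspond to literal text segments). A small standalone sketch, with the input_columns value assumed for illustration; in the app it comes from prepare_datasets():

import string

# Mirrors the template check added above: placeholders must be input columns.
input_columns = ["content"]  # assumed here for illustration
prompt_template = "{content}\n\nThe sentiment of the text is"

_, formats, *_ = zip(*string.Formatter().parse(prompt_template))
is_valid_prompt_template = set(formats).issubset({None} | set(input_columns))
print(is_valid_prompt_template)  # True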