fix: minor bug and feat:use seuqence(classlabel) for multilabel
Browse files
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Callable, List, Tuple, Union
|
|
5 |
import argilla as rg
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
-
from datasets import ClassLabel, Dataset, Features, Value
|
9 |
from distilabel.distiset import Distiset
|
10 |
from gradio import OAuthToken
|
11 |
from huggingface_hub import HfApi, upload_file
|
@@ -421,30 +421,23 @@ def push_dataset_to_hub(
|
|
421 |
progress(0.1, desc="Setting up dataset")
|
422 |
repo_id = _check_push_to_hub(org_name, repo_name)
|
423 |
|
424 |
-
if task == TEXTCAT_TASK
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
{
|
431 |
-
"
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
}
|
438 |
-
),
|
439 |
-
),
|
440 |
-
}
|
441 |
-
)
|
442 |
else:
|
443 |
-
distiset = Distiset(
|
444 |
-
|
445 |
-
|
446 |
-
}
|
447 |
-
)
|
448 |
progress(0.2, desc="Pushing dataset to hub")
|
449 |
distiset.push_to_hub(
|
450 |
repo_id=repo_id,
|
|
|
5 |
import argilla as rg
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
+
from datasets import ClassLabel, Dataset, Features, Sequence, Value
|
9 |
from distilabel.distiset import Distiset
|
10 |
from gradio import OAuthToken
|
11 |
from huggingface_hub import HfApi, upload_file
|
|
|
421 |
progress(0.1, desc="Setting up dataset")
|
422 |
repo_id = _check_push_to_hub(org_name, repo_name)
|
423 |
|
424 |
+
if task == TEXTCAT_TASK:
|
425 |
+
if num_labels == 1:
|
426 |
+
features = Features(
|
427 |
+
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
428 |
+
)
|
429 |
+
else:
|
430 |
+
features = Features({
|
431 |
+
"text": Value("string"),
|
432 |
+
"labels": Sequence(feature=ClassLabel(names=labels))
|
433 |
+
})
|
434 |
+
distiset = Distiset({
|
435 |
+
"default": Dataset.from_pandas(dataframe, features=features)
|
436 |
+
})
|
|
|
|
|
|
|
|
|
|
|
437 |
else:
|
438 |
+
distiset = Distiset({
|
439 |
+
"default": Dataset.from_pandas(dataframe)
|
440 |
+
})
|
|
|
|
|
441 |
progress(0.2, desc="Pushing dataset to hub")
|
442 |
distiset.push_to_hub(
|
443 |
repo_id=repo_id,
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -37,6 +37,7 @@ from src.distilabel_dataset_generator.pipelines.textcat import (
|
|
37 |
get_prompt_generator,
|
38 |
get_textcat_generator,
|
39 |
)
|
|
|
40 |
|
41 |
TASK = "text_classification"
|
42 |
|
@@ -52,6 +53,7 @@ def push_dataset_to_hub(
|
|
52 |
num_labels: int = 1,
|
53 |
):
|
54 |
original_dataframe = dataframe.copy(deep=True)
|
|
|
55 |
try:
|
56 |
push_to_hub_base(
|
57 |
dataframe,
|
@@ -82,7 +84,7 @@ def push_dataset_to_argilla(
|
|
82 |
progress(0.1, desc="Setting up user and workspace")
|
83 |
client = get_argilla_client()
|
84 |
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
|
85 |
-
labels =
|
86 |
settings = rg.Settings(
|
87 |
fields=[
|
88 |
rg.TextField(
|
@@ -205,6 +207,7 @@ def generate_dataset(
|
|
205 |
progress=gr.Progress(),
|
206 |
) -> pd.DataFrame:
|
207 |
progress(0.0, desc="(1/2) Generating text classification data")
|
|
|
208 |
textcat_generator = get_textcat_generator(
|
209 |
difficulty=difficulty, clarity=clarity, is_sample=is_sample
|
210 |
)
|
@@ -247,8 +250,8 @@ def generate_dataset(
|
|
247 |
desc="(1/2) Labeling text classification data",
|
248 |
)
|
249 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
250 |
-
|
251 |
-
labeller_results.extend(
|
252 |
n_processed += batch_size
|
253 |
progress(
|
254 |
1,
|
@@ -268,8 +271,24 @@ def generate_dataset(
|
|
268 |
distiset_results.append(record)
|
269 |
|
270 |
dataframe = pd.DataFrame(distiset_results)
|
271 |
-
if
|
272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
progress(1.0, desc="Dataset generation completed")
|
274 |
return dataframe
|
275 |
|
@@ -339,7 +358,7 @@ with app:
|
|
339 |
)
|
340 |
clarity = gr.Dropdown(
|
341 |
choices=[
|
342 |
-
("Clear", "
|
343 |
(
|
344 |
"Understandable",
|
345 |
"understandable with some effort",
|
|
|
37 |
get_prompt_generator,
|
38 |
get_textcat_generator,
|
39 |
)
|
40 |
+
from src.distilabel_dataset_generator.utils import get_preprocess_labels
|
41 |
|
42 |
TASK = "text_classification"
|
43 |
|
|
|
53 |
num_labels: int = 1,
|
54 |
):
|
55 |
original_dataframe = dataframe.copy(deep=True)
|
56 |
+
labels = get_preprocess_labels(labels)
|
57 |
try:
|
58 |
push_to_hub_base(
|
59 |
dataframe,
|
|
|
84 |
progress(0.1, desc="Setting up user and workspace")
|
85 |
client = get_argilla_client()
|
86 |
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
|
87 |
+
labels = get_preprocess_labels(labels)
|
88 |
settings = rg.Settings(
|
89 |
fields=[
|
90 |
rg.TextField(
|
|
|
207 |
progress=gr.Progress(),
|
208 |
) -> pd.DataFrame:
|
209 |
progress(0.0, desc="(1/2) Generating text classification data")
|
210 |
+
labels = get_preprocess_labels(labels)
|
211 |
textcat_generator = get_textcat_generator(
|
212 |
difficulty=difficulty, clarity=clarity, is_sample=is_sample
|
213 |
)
|
|
|
250 |
desc="(1/2) Labeling text classification data",
|
251 |
)
|
252 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
253 |
+
labels_batch = list(labeller_generator.process(inputs=batch))
|
254 |
+
labeller_results.extend(labels_batch[0])
|
255 |
n_processed += batch_size
|
256 |
progress(
|
257 |
1,
|
|
|
271 |
distiset_results.append(record)
|
272 |
|
273 |
dataframe = pd.DataFrame(distiset_results)
|
274 |
+
if not is_sample:
|
275 |
+
if num_labels == 1:
|
276 |
+
dataframe = dataframe.rename(columns={"labels": "label"})
|
277 |
+
dataframe["label"] = dataframe["label"].apply(
|
278 |
+
lambda x: x.lower().strip() if x.lower().strip() in labels else None
|
279 |
+
)
|
280 |
+
else:
|
281 |
+
dataframe["labels"] = dataframe["labels"].apply(
|
282 |
+
lambda x: (
|
283 |
+
[
|
284 |
+
label.lower().strip()
|
285 |
+
for label in x
|
286 |
+
if label.lower().strip() in labels
|
287 |
+
]
|
288 |
+
if isinstance(x, list)
|
289 |
+
else None
|
290 |
+
)
|
291 |
+
)
|
292 |
progress(1.0, desc="Dataset generation completed")
|
293 |
return dataframe
|
294 |
|
|
|
358 |
)
|
359 |
clarity = gr.Dropdown(
|
360 |
choices=[
|
361 |
+
("Clear", "clear"),
|
362 |
(
|
363 |
"Understandable",
|
364 |
"understandable with some effort",
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -7,11 +7,11 @@ from distilabel.steps.tasks import (
|
|
7 |
TextClassification,
|
8 |
TextGeneration,
|
9 |
)
|
10 |
-
|
11 |
from src.distilabel_dataset_generator.pipelines.base import (
|
12 |
MODEL,
|
13 |
_get_next_api_key,
|
14 |
)
|
|
|
15 |
|
16 |
PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
|
17 |
|
@@ -84,7 +84,7 @@ def generate_pipeline_code(
|
|
84 |
num_labels: int = 1,
|
85 |
num_rows: int = 10,
|
86 |
) -> str:
|
87 |
-
labels =
|
88 |
base_code = f"""
|
89 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
90 |
import os
|
@@ -159,6 +159,7 @@ with Pipeline(name="textcat") as pipeline:
|
|
159 |
default_label="unknown"
|
160 |
)
|
161 |
|
|
|
162 |
task_generator >> textcat_generation >> keep_columns >> textcat_labeller
|
163 |
|
164 |
if __name__ == "__main__":
|
@@ -186,7 +187,6 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
186 |
|
187 |
|
188 |
def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
|
189 |
-
labels = [label.lower().strip() for label in labels]
|
190 |
labeller_generator = TextClassification(
|
191 |
llm=InferenceEndpointsLLM(
|
192 |
model_id=MODEL,
|
|
|
7 |
TextClassification,
|
8 |
TextGeneration,
|
9 |
)
|
|
|
10 |
from src.distilabel_dataset_generator.pipelines.base import (
|
11 |
MODEL,
|
12 |
_get_next_api_key,
|
13 |
)
|
14 |
+
from src.distilabel_dataset_generator.utils import get_preprocess_labels
|
15 |
|
16 |
PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
|
17 |
|
|
|
84 |
num_labels: int = 1,
|
85 |
num_rows: int = 10,
|
86 |
) -> str:
|
87 |
+
labels = get_preprocess_labels(labels)
|
88 |
base_code = f"""
|
89 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
90 |
import os
|
|
|
159 |
default_label="unknown"
|
160 |
)
|
161 |
|
162 |
+
# Connect steps in the pipeline
|
163 |
task_generator >> textcat_generation >> keep_columns >> textcat_labeller
|
164 |
|
165 |
if __name__ == "__main__":
|
|
|
187 |
|
188 |
|
189 |
def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
|
|
|
190 |
labeller_generator = TextClassification(
|
191 |
llm=InferenceEndpointsLLM(
|
192 |
model_id=MODEL,
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import Union, Optional
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
@@ -122,3 +122,6 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
|
|
122 |
)
|
123 |
except Exception:
|
124 |
return None
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
from typing import Union, List, Optional
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
|
|
122 |
)
|
123 |
except Exception:
|
124 |
return None
|
125 |
+
|
126 |
+
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
|
127 |
+
return [label.lower().strip() for label in labels] if labels else []
|