sdiazlor HF staff commited on
Commit
28b1761
1 Parent(s): 5c28c1d

fix: minor bug and feat:use seuqence(classlabel) for multilabel

Browse files
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Callable, List, Tuple, Union
5
  import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
- from datasets import ClassLabel, Dataset, Features, Value
9
  from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
@@ -421,30 +421,23 @@ def push_dataset_to_hub(
421
  progress(0.1, desc="Setting up dataset")
422
  repo_id = _check_push_to_hub(org_name, repo_name)
423
 
424
- if task == TEXTCAT_TASK and num_labels == 1:
425
- labels = [label.lower().strip() for label in labels]
426
- dataframe["label"] = dataframe["label"].apply(
427
- lambda x: x if x in labels else None
428
- )
429
- distiset = Distiset(
430
- {
431
- "default": Dataset.from_pandas(
432
- dataframe,
433
- features=Features(
434
- {
435
- "text": Value("string"),
436
- "label": ClassLabel(names=labels),
437
- }
438
- ),
439
- ),
440
- }
441
- )
442
  else:
443
- distiset = Distiset(
444
- {
445
- "default": Dataset.from_pandas(dataframe),
446
- }
447
- )
448
  progress(0.2, desc="Pushing dataset to hub")
449
  distiset.push_to_hub(
450
  repo_id=repo_id,
 
5
  import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
+ from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
  from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
 
421
  progress(0.1, desc="Setting up dataset")
422
  repo_id = _check_push_to_hub(org_name, repo_name)
423
 
424
+ if task == TEXTCAT_TASK:
425
+ if num_labels == 1:
426
+ features = Features(
427
+ {"text": Value("string"), "label": ClassLabel(names=labels)}
428
+ )
429
+ else:
430
+ features = Features({
431
+ "text": Value("string"),
432
+ "labels": Sequence(feature=ClassLabel(names=labels))
433
+ })
434
+ distiset = Distiset({
435
+ "default": Dataset.from_pandas(dataframe, features=features)
436
+ })
 
 
 
 
 
437
  else:
438
+ distiset = Distiset({
439
+ "default": Dataset.from_pandas(dataframe)
440
+ })
 
 
441
  progress(0.2, desc="Pushing dataset to hub")
442
  distiset.push_to_hub(
443
  repo_id=repo_id,
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -37,6 +37,7 @@ from src.distilabel_dataset_generator.pipelines.textcat import (
37
  get_prompt_generator,
38
  get_textcat_generator,
39
  )
 
40
 
41
  TASK = "text_classification"
42
 
@@ -52,6 +53,7 @@ def push_dataset_to_hub(
52
  num_labels: int = 1,
53
  ):
54
  original_dataframe = dataframe.copy(deep=True)
 
55
  try:
56
  push_to_hub_base(
57
  dataframe,
@@ -82,7 +84,7 @@ def push_dataset_to_argilla(
82
  progress(0.1, desc="Setting up user and workspace")
83
  client = get_argilla_client()
84
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
85
- labels = [label.lower().strip() for label in labels]
86
  settings = rg.Settings(
87
  fields=[
88
  rg.TextField(
@@ -205,6 +207,7 @@ def generate_dataset(
205
  progress=gr.Progress(),
206
  ) -> pd.DataFrame:
207
  progress(0.0, desc="(1/2) Generating text classification data")
 
208
  textcat_generator = get_textcat_generator(
209
  difficulty=difficulty, clarity=clarity, is_sample=is_sample
210
  )
@@ -247,8 +250,8 @@ def generate_dataset(
247
  desc="(1/2) Labeling text classification data",
248
  )
249
  batch = textcat_results[n_processed : n_processed + batch_size]
250
- labels = list(labeller_generator.process(inputs=batch))
251
- labeller_results.extend(labels[0])
252
  n_processed += batch_size
253
  progress(
254
  1,
@@ -268,8 +271,24 @@ def generate_dataset(
268
  distiset_results.append(record)
269
 
270
  dataframe = pd.DataFrame(distiset_results)
271
- if num_labels == 1:
272
- dataframe = dataframe.rename(columns={"labels": "label"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  progress(1.0, desc="Dataset generation completed")
274
  return dataframe
275
 
@@ -339,7 +358,7 @@ with app:
339
  )
340
  clarity = gr.Dropdown(
341
  choices=[
342
- ("Clear", "CLEAR"),
343
  (
344
  "Understandable",
345
  "understandable with some effort",
 
37
  get_prompt_generator,
38
  get_textcat_generator,
39
  )
40
+ from src.distilabel_dataset_generator.utils import get_preprocess_labels
41
 
42
  TASK = "text_classification"
43
 
 
53
  num_labels: int = 1,
54
  ):
55
  original_dataframe = dataframe.copy(deep=True)
56
+ labels = get_preprocess_labels(labels)
57
  try:
58
  push_to_hub_base(
59
  dataframe,
 
84
  progress(0.1, desc="Setting up user and workspace")
85
  client = get_argilla_client()
86
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
87
+ labels = get_preprocess_labels(labels)
88
  settings = rg.Settings(
89
  fields=[
90
  rg.TextField(
 
207
  progress=gr.Progress(),
208
  ) -> pd.DataFrame:
209
  progress(0.0, desc="(1/2) Generating text classification data")
210
+ labels = get_preprocess_labels(labels)
211
  textcat_generator = get_textcat_generator(
212
  difficulty=difficulty, clarity=clarity, is_sample=is_sample
213
  )
 
250
  desc="(1/2) Labeling text classification data",
251
  )
252
  batch = textcat_results[n_processed : n_processed + batch_size]
253
+ labels_batch = list(labeller_generator.process(inputs=batch))
254
+ labeller_results.extend(labels_batch[0])
255
  n_processed += batch_size
256
  progress(
257
  1,
 
271
  distiset_results.append(record)
272
 
273
  dataframe = pd.DataFrame(distiset_results)
274
+ if not is_sample:
275
+ if num_labels == 1:
276
+ dataframe = dataframe.rename(columns={"labels": "label"})
277
+ dataframe["label"] = dataframe["label"].apply(
278
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
279
+ )
280
+ else:
281
+ dataframe["labels"] = dataframe["labels"].apply(
282
+ lambda x: (
283
+ [
284
+ label.lower().strip()
285
+ for label in x
286
+ if label.lower().strip() in labels
287
+ ]
288
+ if isinstance(x, list)
289
+ else None
290
+ )
291
+ )
292
  progress(1.0, desc="Dataset generation completed")
293
  return dataframe
294
 
 
358
  )
359
  clarity = gr.Dropdown(
360
  choices=[
361
+ ("Clear", "clear"),
362
  (
363
  "Understandable",
364
  "understandable with some effort",
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -7,11 +7,11 @@ from distilabel.steps.tasks import (
7
  TextClassification,
8
  TextGeneration,
9
  )
10
-
11
  from src.distilabel_dataset_generator.pipelines.base import (
12
  MODEL,
13
  _get_next_api_key,
14
  )
 
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
 
@@ -84,7 +84,7 @@ def generate_pipeline_code(
84
  num_labels: int = 1,
85
  num_rows: int = 10,
86
  ) -> str:
87
- labels = [label.lower().strip() for label in labels or []]
88
  base_code = f"""
89
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
90
  import os
@@ -159,6 +159,7 @@ with Pipeline(name="textcat") as pipeline:
159
  default_label="unknown"
160
  )
161
 
 
162
  task_generator >> textcat_generation >> keep_columns >> textcat_labeller
163
 
164
  if __name__ == "__main__":
@@ -186,7 +187,6 @@ def get_textcat_generator(difficulty, clarity, is_sample):
186
 
187
 
188
  def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
189
- labels = [label.lower().strip() for label in labels]
190
  labeller_generator = TextClassification(
191
  llm=InferenceEndpointsLLM(
192
  model_id=MODEL,
 
7
  TextClassification,
8
  TextGeneration,
9
  )
 
10
  from src.distilabel_dataset_generator.pipelines.base import (
11
  MODEL,
12
  _get_next_api_key,
13
  )
14
+ from src.distilabel_dataset_generator.utils import get_preprocess_labels
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
 
 
84
  num_labels: int = 1,
85
  num_rows: int = 10,
86
  ) -> str:
87
+ labels = get_preprocess_labels(labels)
88
  base_code = f"""
89
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
90
  import os
 
159
  default_label="unknown"
160
  )
161
 
162
+ # Connect steps in the pipeline
163
  task_generator >> textcat_generation >> keep_columns >> textcat_labeller
164
 
165
  if __name__ == "__main__":
 
187
 
188
 
189
  def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
 
190
  labeller_generator = TextClassification(
191
  llm=InferenceEndpointsLLM(
192
  model_id=MODEL,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Union, Optional
3
 
4
  import argilla as rg
5
  import gradio as gr
@@ -122,3 +122,6 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
122
  )
123
  except Exception:
124
  return None
 
 
 
 
1
  import os
2
+ from typing import Union, List, Optional
3
 
4
  import argilla as rg
5
  import gradio as gr
 
122
  )
123
  except Exception:
124
  return None
125
+
126
+ def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
127
+ return [label.lower().strip() for label in labels] if labels else []