sdiazlor HF staff commited on
Commit
229dcf3
1 Parent(s): 288d796

Add working textcat version

Browse files
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ synthetic-data-generator
requirements.txt CHANGED
@@ -3,4 +3,5 @@ gradio[oauth]
3
  distilabel[hf-inference-endpoints,argilla]
4
  beautifulsoup4
5
  sentence-transformers
6
- model2vec
 
 
3
  distilabel[hf-inference-endpoints,argilla]
4
  beautifulsoup4
5
  sentence-transformers
6
+ model2vec
7
+ outlines
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -1,11 +1,12 @@
1
  import io
 
2
  import uuid
3
  from typing import Any, Callable, List, Optional, Tuple, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
- from datasets import Dataset
9
  from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
@@ -16,6 +17,8 @@ from src.distilabel_dataset_generator.utils import (
16
  list_orgs,
17
  )
18
 
 
 
19
 
20
  def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
21
  if oauth_token:
@@ -30,15 +33,21 @@ def get_main_ui(
30
  default_datasets: List[pd.DataFrame],
31
  fn_generate_system_prompt: Callable,
32
  fn_generate_dataset: Callable,
 
33
  ):
34
  def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
35
  if system_prompt in default_system_prompts:
36
  index = default_system_prompts.index(system_prompt)
37
  if index < len(default_datasets):
38
  return default_datasets[index]
39
- result = fn_generate_dataset(
40
- system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
41
- )
 
 
 
 
 
42
  return result
43
 
44
  with gr.Blocks(
@@ -109,7 +118,7 @@ def get_main_ui(
109
  outputs=[sample_dataset],
110
  show_progress=True,
111
  )
112
-
113
  btn_generate_sample_dataset.click(
114
  fn=fn_generate_sample_dataset,
115
  inputs=[system_prompt],
@@ -306,14 +315,14 @@ def get_argilla_tab() -> Tuple[Any]:
306
  dataset_name = gr.Textbox(
307
  label="Dataset name",
308
  placeholder="dataset_name",
309
- value="my-distiset",
310
  )
311
  add_to_existing_dataset = gr.Checkbox(
312
  label="Allow adding records to existing dataset",
313
- info="When selected, you do need to ensure the number of turns in the conversation is the same as the number of turns in the existing dataset.",
314
  value=False,
315
  interactive=True,
316
- scale=0.5,
317
  )
318
 
319
  with gr.Row(variant="panel"):
@@ -354,7 +363,7 @@ def get_hf_tab() -> Tuple[Any]:
354
  label="Private dataset",
355
  value=True,
356
  interactive=True,
357
- scale=0.5,
358
  )
359
  with gr.Row(variant="panel"):
360
  btn_generate_full_dataset = gr.Button(
@@ -403,14 +412,33 @@ def push_dataset_to_hub(
403
  repo_name: str = None,
404
  oauth_token: Union[OAuthToken, None] = None,
405
  progress=gr.Progress(),
 
 
 
406
  ) -> pd.DataFrame:
407
  progress(0.1, desc="Setting up dataset")
408
  repo_id = _check_push_to_hub(org_name, repo_name)
409
- distiset = Distiset(
410
- {
411
- "default": Dataset.from_pandas(dataframe),
412
- }
413
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  progress(0.2, desc="Pushing dataset to hub")
415
  distiset.push_to_hub(
416
  repo_id=repo_id,
@@ -444,6 +472,7 @@ def get_final_dataset_row(default_datasets) -> gr.Dataframe:
444
  label="Generated dataset",
445
  interactive=False,
446
  wrap=True,
 
447
  )
448
  return final_dataset
449
 
 
1
  import io
2
+ import re
3
  import uuid
4
  from typing import Any, Callable, List, Optional, Tuple, Union
5
 
6
  import argilla as rg
7
  import gradio as gr
8
  import pandas as pd
9
+ from datasets import Dataset, Features, ClassLabel, Value
10
  from distilabel.distiset import Distiset
11
  from gradio import OAuthToken
12
  from huggingface_hub import HfApi, upload_file
 
17
  list_orgs,
18
  )
19
 
20
+ TEXTCAT_TASK = "text_classification"
21
+ SFT_TASK = "supervised_finetuning"
22
 
23
  def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
24
  if oauth_token:
 
33
  default_datasets: List[pd.DataFrame],
34
  fn_generate_system_prompt: Callable,
35
  fn_generate_dataset: Callable,
36
+ task: str,
37
  ):
38
  def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
39
  if system_prompt in default_system_prompts:
40
  index = default_system_prompts.index(system_prompt)
41
  if index < len(default_datasets):
42
  return default_datasets[index]
43
+ if task == TEXTCAT_TASK:
44
+ result = fn_generate_dataset(
45
+ system_prompt, difficulty="mixed", clarity="mixed", labels=[], num_labels=1, num_rows=1, progress=progress, is_sample=True
46
+ )
47
+ else:
48
+ result = fn_generate_dataset(
49
+ system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
50
+ )
51
  return result
52
 
53
  with gr.Blocks(
 
118
  outputs=[sample_dataset],
119
  show_progress=True,
120
  )
121
+
122
  btn_generate_sample_dataset.click(
123
  fn=fn_generate_sample_dataset,
124
  inputs=[system_prompt],
 
315
  dataset_name = gr.Textbox(
316
  label="Dataset name",
317
  placeholder="dataset_name",
318
+ value=f"my-distiset-{uuid.uuid4()}", ######## CHANGE AFTER TESTING
319
  )
320
  add_to_existing_dataset = gr.Checkbox(
321
  label="Allow adding records to existing dataset",
322
+ info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.",
323
  value=False,
324
  interactive=True,
325
+ scale=1,
326
  )
327
 
328
  with gr.Row(variant="panel"):
 
363
  label="Private dataset",
364
  value=True,
365
  interactive=True,
366
+ scale=1,
367
  )
368
  with gr.Row(variant="panel"):
369
  btn_generate_full_dataset = gr.Button(
 
412
  repo_name: str = None,
413
  oauth_token: Union[OAuthToken, None] = None,
414
  progress=gr.Progress(),
415
+ labels: List[str] = None,
416
+ num_labels: int = None,
417
+ task: str = TEXTCAT_TASK,
418
  ) -> pd.DataFrame:
419
  progress(0.1, desc="Setting up dataset")
420
  repo_id = _check_push_to_hub(org_name, repo_name)
421
+
422
+ if task == TEXTCAT_TASK and num_labels == 1:
423
+ distiset = Distiset(
424
+ {
425
+ "default": Dataset.from_pandas(
426
+ dataframe,
427
+ features=Features(
428
+ {
429
+ "text": Value("string"),
430
+ "label": ClassLabel(names=labels),
431
+ }
432
+ ),
433
+ ),
434
+ }
435
+ )
436
+ else:
437
+ distiset = Distiset(
438
+ {
439
+ "default": Dataset.from_pandas(dataframe),
440
+ }
441
+ )
442
  progress(0.2, desc="Pushing dataset to hub")
443
  distiset.push_to_hub(
444
  repo_id=repo_id,
 
472
  label="Generated dataset",
473
  interactive=False,
474
  wrap=True,
475
+ min_width=300,
476
  )
477
  return final_dataset
478
 
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -39,6 +39,8 @@ from src.distilabel_dataset_generator.pipelines.sft import (
39
  get_response_generator,
40
  )
41
 
 
 
42
 
43
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
44
  def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
@@ -65,7 +67,9 @@ def push_dataset_to_hub(
65
  ):
66
  original_dataframe = dataframe.copy(deep=True)
67
  dataframe = convert_dataframe_messages(dataframe)
68
- push_to_hub_base(dataframe, private, org_name, repo_name, oauth_token, progress)
 
 
69
  return original_dataframe
70
 
71
 
@@ -357,6 +361,7 @@ def generate_dataset(
357
  default_datasets=DEFAULT_DATASETS,
358
  fn_generate_system_prompt=generate_system_prompt,
359
  fn_generate_dataset=generate_dataset,
 
360
  )
361
 
362
  with app:
 
39
  get_response_generator,
40
  )
41
 
42
+ TASK = "supervised_fine_tuning"
43
+
44
 
45
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
46
  def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
 
67
  ):
68
  original_dataframe = dataframe.copy(deep=True)
69
  dataframe = convert_dataframe_messages(dataframe)
70
+ push_to_hub_base(
71
+ dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
72
+ )
73
  return original_dataframe
74
 
75
 
 
361
  default_datasets=DEFAULT_DATASETS,
362
  fn_generate_system_prompt=generate_system_prompt,
363
  fn_generate_dataset=generate_dataset,
364
+ task=TASK,
365
  )
366
 
367
  with app:
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -1,43 +1,250 @@
1
- from typing import List
 
2
 
 
3
  import gradio as gr
4
  import pandas as pd
 
 
 
5
 
6
  from src.distilabel_dataset_generator.apps.base import (
 
7
  get_main_ui,
8
  get_pipeline_code_ui,
9
  hide_success_message,
10
- push_dataset_to_hub,
11
  push_pipeline_code_to_hub,
12
  show_success_message_argilla,
13
  show_success_message_hub,
14
  validate_argilla_user_workspace_dataset,
15
  )
 
 
 
 
 
 
 
 
 
 
16
  from src.distilabel_dataset_generator.pipelines.textcat import (
17
  DEFAULT_DATASET_DESCRIPTIONS,
18
  DEFAULT_DATASETS,
19
  DEFAULT_SYSTEM_PROMPTS,
 
20
  generate_pipeline_code,
 
 
 
21
  )
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- def push_dataset_to_argilla(dataset: pd.DataFrame, dataset_name: str) -> pd.DataFrame:
25
- return dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def generate_system_prompt(dataset_description: str) -> str:
29
- return dataset_description
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  def generate_dataset(
33
  system_prompt: str,
34
  difficulty: str,
35
  clarity: str,
36
- labels: List[str],
37
- num_labels: int,
38
- num_rows: int,
 
 
39
  ) -> pd.DataFrame:
40
- return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  (
@@ -69,8 +276,20 @@ def generate_dataset(
69
  default_datasets=DEFAULT_DATASETS,
70
  fn_generate_system_prompt=generate_system_prompt,
71
  fn_generate_dataset=generate_dataset,
 
72
  )
73
 
 
 
 
 
 
 
 
 
 
 
 
74
  with app:
75
  with main_ui:
76
  with custom_input_ui:
@@ -78,7 +297,7 @@ with app:
78
  choices=[
79
  ("High School", "high school"),
80
  ("College", "college"),
81
- ("PhD", "phd"),
82
  ("Mixed", "mixed"),
83
  ],
84
  value="mixed",
@@ -86,29 +305,38 @@ with app:
86
  )
87
  clarity = gr.Dropdown(
88
  choices=[
89
- ("Clear", "Clear"),
90
  (
91
  "Understandable",
92
  "understandable with some effort",
93
  ),
94
- ("Ambiguous", "Ambiguous"),
95
  ("Mixed", "mixed"),
96
  ],
97
  value="mixed",
98
  label="Clarity",
99
  )
100
- labels = gr.Dropdown(
101
- choices=[],
102
- allow_custom_value=True,
103
- interactive=True,
104
- label="Labels",
105
- multiselect=True,
106
- )
 
 
 
 
 
 
107
  num_labels = gr.Number(
108
  label="Number of labels", value=1, minimum=1, maximum=10
109
  )
110
  num_rows = gr.Number(
111
- label="Number of rows", value=10, minimum=1, maximum=500
 
 
 
112
  )
113
 
114
  pipeline_code = get_pipeline_code_ui(
@@ -123,6 +351,12 @@ with app:
123
  )
124
 
125
  # define app triggers
 
 
 
 
 
 
126
  gr.on(
127
  triggers=[
128
  btn_generate_full_dataset.click,
@@ -152,7 +386,7 @@ with app:
152
  show_progress=True,
153
  ).success(
154
  fn=push_dataset_to_argilla,
155
- inputs=[final_dataset, dataset_name],
156
  outputs=[final_dataset],
157
  show_progress=True,
158
  ).success(
@@ -171,7 +405,7 @@ with app:
171
  show_progress=True,
172
  ).then(
173
  fn=push_dataset_to_hub,
174
- inputs=[final_dataset, private, org_name, repo_name],
175
  outputs=[final_dataset],
176
  show_progress=True,
177
  ).then(
@@ -190,7 +424,7 @@ with app:
190
  outputs=[success_message],
191
  ).then(
192
  fn=push_dataset_to_hub,
193
- inputs=[final_dataset, private, org_name, repo_name],
194
  outputs=[final_dataset],
195
  show_progress=True,
196
  ).then(
@@ -214,7 +448,7 @@ with app:
214
  show_progress=True,
215
  ).success(
216
  fn=push_dataset_to_argilla,
217
- inputs=[final_dataset, dataset_name],
218
  outputs=[final_dataset],
219
  show_progress=True,
220
  ).success(
 
1
+ import re
2
+ from typing import Dict, List, Union
3
 
4
+ import argilla as rg
5
  import gradio as gr
6
  import pandas as pd
7
+ from datasets import Dataset
8
+ from distilabel.distiset import Distiset
9
+ from huggingface_hub import HfApi
10
 
11
  from src.distilabel_dataset_generator.apps.base import (
12
+ get_argilla_client,
13
  get_main_ui,
14
  get_pipeline_code_ui,
15
  hide_success_message,
 
16
  push_pipeline_code_to_hub,
17
  show_success_message_argilla,
18
  show_success_message_hub,
19
  validate_argilla_user_workspace_dataset,
20
  )
21
+ from src.distilabel_dataset_generator.apps.base import (
22
+ push_dataset_to_hub as push_to_hub_base,
23
+ )
24
+ from src.distilabel_dataset_generator.pipelines.base import (
25
+ DEFAULT_BATCH_SIZE,
26
+ )
27
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
28
+ get_embeddings,
29
+ get_sentence_embedding_dimensions,
30
+ )
31
  from src.distilabel_dataset_generator.pipelines.textcat import (
32
  DEFAULT_DATASET_DESCRIPTIONS,
33
  DEFAULT_DATASETS,
34
  DEFAULT_SYSTEM_PROMPTS,
35
+ PROMPT_CREATION_PROMPT,
36
  generate_pipeline_code,
37
+ get_textcat_generator,
38
+ get_prompt_generator,
39
+ get_labeller_generator,
40
  )
41
 
42
+ TASK = "text_classification"
43
+
44
+ def push_dataset_to_hub(
45
+ dataframe: pd.DataFrame,
46
+ private: bool = True,
47
+ org_name: str = None,
48
+ repo_name: str = None,
49
+ oauth_token: Union[gr.OAuthToken, None] = None,
50
+ progress=gr.Progress(),
51
+ labels: List[str] = None,
52
+ num_labels: int = 1,
53
+ ):
54
+ original_dataframe = dataframe.copy(deep=True)
55
+ push_to_hub_base(
56
+ dataframe,
57
+ private,
58
+ org_name,
59
+ repo_name,
60
+ oauth_token,
61
+ progress,
62
+ labels,
63
+ num_labels,
64
+ task=TASK,
65
+ )
66
+ return original_dataframe
67
+
68
 
69
+ def push_dataset_to_argilla(
70
+ dataframe: pd.DataFrame,
71
+ dataset_name: str,
72
+ oauth_token: Union[gr.OAuthToken, None] = None,
73
+ progress=gr.Progress(),
74
+ num_labels: int = 1,
75
+ labels: List[str] = None,
76
+ ) -> pd.DataFrame:
77
+ original_dataframe = dataframe.copy(deep=True)
78
+ try:
79
+ progress(0.1, desc="Setting up user and workspace")
80
+ client = get_argilla_client()
81
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
82
+ settings = rg.Settings(
83
+ fields=[
84
+ rg.TextField(
85
+ name="text",
86
+ description="The text classification data",
87
+ title="Text",
88
+ ),
89
+ ],
90
+ questions=[
91
+ (
92
+ rg.LabelQuestion(
93
+ name="label",
94
+ title="Label",
95
+ description="The label of the text",
96
+ labels=labels,
97
+ )
98
+ if num_labels == 1
99
+ else rg.MultiLabelQuestion(
100
+ name="labels",
101
+ title="Labels",
102
+ description="The labels of the conversation",
103
+ labels=labels,
104
+ )
105
+ ),
106
+ ],
107
+ metadata=[
108
+ rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
109
+ ],
110
+ vectors=[
111
+ rg.VectorField(
112
+ name="text_embeddings",
113
+ dimensions=get_sentence_embedding_dimensions(),
114
+ )
115
+ ],
116
+ guidelines="Please review the text and provide or correct the label where needed.",
117
+ )
118
 
119
+ dataframe["text_length"] = dataframe["text"].apply(len)
120
+ dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
121
+
122
+ progress(0.5, desc="Creating dataset")
123
+ rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
124
+ if rg_dataset is None:
125
+ rg_dataset = rg.Dataset(
126
+ name=dataset_name,
127
+ workspace=hf_user,
128
+ settings=settings,
129
+ client=client,
130
+ )
131
+ rg_dataset = rg_dataset.create()
132
+ progress(0.7, desc="Pushing dataset to Argilla")
133
+ hf_dataset = Dataset.from_pandas(dataframe)
134
+ rg_dataset.records.log(records=hf_dataset)
135
+ progress(1.0, desc="Dataset pushed to Argilla")
136
+ except Exception as e:
137
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
138
+ return original_dataframe
139
 
140
+
141
+ def generate_system_prompt(dataset_description, progress=gr.Progress()):
142
+ progress(0.0, desc="Generating text classification task")
143
+ if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
144
+ index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
145
+ if index < len(DEFAULT_SYSTEM_PROMPTS):
146
+ return DEFAULT_SYSTEM_PROMPTS[index]
147
+
148
+ progress(0.3, desc="Initializing text generation")
149
+ generate_description = get_prompt_generator()
150
+ progress(0.7, desc="Generating text classification task")
151
+ result = next(
152
+ generate_description.process(
153
+ [
154
+ {
155
+ "system_prompt": PROMPT_CREATION_PROMPT,
156
+ "instruction": dataset_description,
157
+ }
158
+ ]
159
+ )
160
+ )[0]["generation"]
161
+ progress(1.0, desc="Text classification task generated")
162
+ return result
163
 
164
 
165
  def generate_dataset(
166
  system_prompt: str,
167
  difficulty: str,
168
  clarity: str,
169
+ labels: List[str] = [],
170
+ num_labels: int = 2,
171
+ num_rows: int = 10,
172
+ is_sample: bool = False,
173
+ progress=gr.Progress(),
174
  ) -> pd.DataFrame:
175
+ progress(0.0, desc="(1/2) Generating text classification data")
176
+ textcat_generator = get_textcat_generator(difficulty, clarity, is_sample)
177
+ labeler_generator = get_labeller_generator(num_labels, labels, is_sample)
178
+ total_steps: int = num_rows * 2
179
+ batch_size = DEFAULT_BATCH_SIZE
180
+
181
+ # create text classification data
182
+ n_processed = 0
183
+ textcat_results = []
184
+ while n_processed < num_rows:
185
+ progress(
186
+ 0.5 * n_processed / num_rows,
187
+ total=total_steps,
188
+ desc="(1/2) Generating text classification data",
189
+ )
190
+ remaining_rows = num_rows - n_processed
191
+ batch_size = min(batch_size, remaining_rows)
192
+ inputs = [{"task": system_prompt} for _ in range(batch_size)]
193
+ batch = list(textcat_generator.process(inputs=inputs))
194
+ textcat_results.extend(batch[0])
195
+ n_processed += batch_size
196
+ for result in textcat_results:
197
+ result["text"] = result["input_text"]
198
+
199
+ # label text classification data
200
+ progress(0.5, desc="(1/2) Labeling text classification data")
201
+ if not is_sample:
202
+ n_processed = 0
203
+ labeler_results = []
204
+ while n_processed < num_rows:
205
+ progress(
206
+ 0.5 + 0.5 * n_processed / num_rows,
207
+ total=total_steps,
208
+ desc="(1/2) Generating text classification data",
209
+ )
210
+ batch = textcat_results[n_processed : n_processed + batch_size]
211
+ labels = list(labeler_generator.process(inputs=batch))
212
+ labeler_results.extend(labels[0])
213
+ n_processed += batch_size
214
+ progress(
215
+ 1,
216
+ total=total_steps,
217
+ desc="(2/2) Labeling text classification data",
218
+ )
219
+
220
+ # create final dataset
221
+ distiset_results = []
222
+ if is_sample:
223
+ for result in textcat_results:
224
+ record = {}
225
+ for relevant_keys in [
226
+ "text",
227
+ "label",
228
+ ]:
229
+ if relevant_keys in result:
230
+ record[relevant_keys] = result[relevant_keys]
231
+ distiset_results.append(record)
232
+ else:
233
+ for result in labeler_results:
234
+ record = {}
235
+ for relevant_keys in [
236
+ "text",
237
+ "labels",
238
+ ]:
239
+ if relevant_keys in result:
240
+ record[relevant_keys] = result[relevant_keys]
241
+ distiset_results.append(record)
242
+
243
+ dataframe = pd.DataFrame(distiset_results)
244
+ if num_labels == 1:
245
+ dataframe = dataframe.rename(columns={"labels": "label"})
246
+ progress(1.0, desc="Dataset generation completed")
247
+ return dataframe
248
 
249
 
250
  (
 
276
  default_datasets=DEFAULT_DATASETS,
277
  fn_generate_system_prompt=generate_system_prompt,
278
  fn_generate_dataset=generate_dataset,
279
+ task=TASK,
280
  )
281
 
282
+
283
+ def update_labels_based_on_checkbox(checked, system_prompt):
284
+ if checked:
285
+ pattern = r"'(\b\w+\b)'"
286
+ new_labels = re.findall(pattern, system_prompt)
287
+ gr.update(choices=new_labels)
288
+ return gr.update(value=new_labels)
289
+ else:
290
+ return gr.update(choices=[])
291
+
292
+
293
  with app:
294
  with main_ui:
295
  with custom_input_ui:
 
297
  choices=[
298
  ("High School", "high school"),
299
  ("College", "college"),
300
+ ("PhD", "PhD"),
301
  ("Mixed", "mixed"),
302
  ],
303
  value="mixed",
 
305
  )
306
  clarity = gr.Dropdown(
307
  choices=[
308
+ ("Clear", "CLEAR"),
309
  (
310
  "Understandable",
311
  "understandable with some effort",
312
  ),
313
+ ("Ambiguous", "ambiguous"),
314
  ("Mixed", "mixed"),
315
  ],
316
  value="mixed",
317
  label="Clarity",
318
  )
319
+ with gr.Row(variant="default"):
320
+ labels = gr.Dropdown(
321
+ choices=[],
322
+ allow_custom_value=True,
323
+ interactive=True,
324
+ label="Labels",
325
+ multiselect=True,
326
+ )
327
+ suggested_labels = gr.Checkbox(
328
+ label="Add suggested labels",
329
+ value=False,
330
+ interactive=True,
331
+ )
332
  num_labels = gr.Number(
333
  label="Number of labels", value=1, minimum=1, maximum=10
334
  )
335
  num_rows = gr.Number(
336
+ label="Number of rows",
337
+ value=1,
338
+ minimum=1,
339
+ maximum=500, ###### CHANGE AFTER TESTING
340
  )
341
 
342
  pipeline_code = get_pipeline_code_ui(
 
351
  )
352
 
353
  # define app triggers
354
+ suggested_labels.change(
355
+ update_labels_based_on_checkbox,
356
+ inputs=[suggested_labels, system_prompt],
357
+ outputs=labels,
358
+ )
359
+
360
  gr.on(
361
  triggers=[
362
  btn_generate_full_dataset.click,
 
386
  show_progress=True,
387
  ).success(
388
  fn=push_dataset_to_argilla,
389
+ inputs=[final_dataset, dataset_name, num_labels, labels],
390
  outputs=[final_dataset],
391
  show_progress=True,
392
  ).success(
 
405
  show_progress=True,
406
  ).then(
407
  fn=push_dataset_to_hub,
408
+ inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
409
  outputs=[final_dataset],
410
  show_progress=True,
411
  ).then(
 
424
  outputs=[success_message],
425
  ).then(
426
  fn=push_dataset_to_hub,
427
+ inputs=[final_dataset, private, org_name, repo_name, labels],
428
  outputs=[final_dataset],
429
  show_progress=True,
430
  ).then(
 
448
  show_progress=True,
449
  ).success(
450
  fn=push_dataset_to_argilla,
451
+ inputs=[final_dataset, dataset_name, num_labels, labels],
452
  outputs=[final_dataset],
453
  show_progress=True,
454
  ).success(
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,6 +1,42 @@
 
 
1
  from typing import List
 
 
2
 
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  DEFAULT_DATASET_DESCRIPTIONS = [
6
  "A dataset covering customer reviews for an e-commerce website.",
@@ -23,14 +59,14 @@ DEFAULT_DATASETS = [
23
  "Yesterday, the US stock market had a significant increase.",
24
  "New research suggests that the Earth is not a perfect sphere.",
25
  ],
26
- "label": [["economy", "politics"], ["science", "environment"]],
27
  }
28
  ),
29
  ]
30
 
31
  DEFAULT_SYSTEM_PROMPTS = [
32
- "Classify the following customer review as positive or negative.",
33
- "Classify the following news article into one or more categories.",
34
  ]
35
 
36
 
@@ -42,8 +78,118 @@ def generate_pipeline_code(
42
  num_labels: int,
43
  num_rows: int,
44
  ) -> str:
45
- return """
46
- from distilabel import Distilabel
 
 
 
 
 
47
 
48
- #### PIPELINE CODE HERE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
  from typing import List
4
+ from distilabel.llms import InferenceEndpointsLLM
5
+ from distilabel.steps.tasks import GenerateTextClassificationData, TextClassification, TextGeneration
6
 
7
+ from src.distilabel_dataset_generator.pipelines.base import (
8
+ MODEL,
9
+ _get_next_api_key,
10
+ )
11
+
12
+ PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
13
+
14
+ Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
15
+
16
+ The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels where applicable:
17
+
18
+ Classify the following customer review of a cinema as either 'positive' or 'negative'.
19
+
20
+ Classify the following news article into one or more of the following categories: 'politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international'.
21
+
22
+ Determine the sentiment of the following social media post: 'ambiguous', 'sarcastic', 'informative', 'emotional'.
23
+
24
+ Identify the issue category for the following technical support ticket: 'billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription'.
25
+
26
+ Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
27
+
28
+ Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly satisfied', 'somewhat dissatisfied', 'indifferent'.
29
+
30
+ Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
31
+
32
+ Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
33
+
34
+ Classify the following restaurant review into one of the following categories: 'food quality', 'service', 'ambiance', or 'price'.
35
+
36
+ Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable fashion'.
37
+
38
+ User dataset description:
39
+ """
40
 
41
  DEFAULT_DATASET_DESCRIPTIONS = [
42
  "A dataset covering customer reviews for an e-commerce website.",
 
59
  "Yesterday, the US stock market had a significant increase.",
60
  "New research suggests that the Earth is not a perfect sphere.",
61
  ],
62
+ "labels": [["economy", "politics"], ["science", "environment"]],
63
  }
64
  ),
65
  ]
66
 
67
  DEFAULT_SYSTEM_PROMPTS = [
68
+ "Classify the following customer review as either 'positive' or 'negative'.",
69
+ "Classify the following news article into one of the following categories: 'politics', 'economy', 'environment', 'science', 'health'.",
70
  ]
71
 
72
 
 
78
  num_labels: int,
79
  num_rows: int,
80
  ) -> str:
81
+ base = f"""
82
+ # Requirements: `pip install distilabel[hf-inference-endpoints]`
83
+ import os
84
+ from distilabel.llms import InferenceEndpointsLLM
85
+ from distilabel.pipeline import Pipeline
86
+ from distilabel.steps import LoadDataFromDicts
87
+ from distilabel.steps.tasks import GenerateTextClassificationData
88
 
89
+ MODEL = "{MODEL}"
90
+ TEXTCAT_TASK = "{system_prompt}"
91
+ os.environ["HF_TOKEN"] = (
92
+ "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
93
+ )
94
+
95
+ with Pipeline(name="textcat") as pipeline:
96
+ textcat_generation = GenerateTextClassificationData(
97
+ llm=InferenceEndpointsLLM(
98
+ model_id=MODEL,
99
+ tokenizer_id=MODEL,
100
+ api_key=_get_next_api_key(),
101
+ generation_kwargs={{
102
+ "temperature": 0.8,
103
+ "max_new_tokens": 2048,
104
+ }},
105
+ ),
106
+ difficulty={None if difficulty == "mixed" else difficulty},
107
+ clarity={None if clarity == "mixed" else clarity},
108
+ num_generations={num_rows},
109
+ )
110
+ keep_columns = KeepColumns(
111
+ columns=["input_text", "model_name"],
112
+ )
113
  """
114
+ if num_labels > 1:
115
+ return base + """
116
+ textcat_generation >> keep_columns >> textcat_labeler
117
+
118
+ if __name__ == "__main__":
119
+ distiset = pipeline.run()
120
+ """
121
+
122
+ return f"""
123
+ textcat_labeler = TextClassification(
124
+ llm=InferenceEndpointsLLM(
125
+ model_id=MODEL,
126
+ tokenizer_id=MODEL,
127
+ api_key=_get_next_api_key(),
128
+ generation_kwargs={{
129
+ "temperature": 0.8,
130
+ "max_new_tokens": 2048,
131
+ }},
132
+ ),
133
+ n= {num_labels},
134
+ available_labels={labels},
135
+ )
136
+
137
+ textcat_generation >> keep_columns >> textcat_labeler
138
+
139
+ if __name__ == "__main__":
140
+ distiset = pipeline.run()
141
+
142
+ """
143
+
144
+ def get_textcat_generator(difficulty, clarity, is_sample):
145
+ textcat_generator = GenerateTextClassificationData(
146
+ llm=InferenceEndpointsLLM(
147
+ model_id=MODEL,
148
+ tokenizer_id=MODEL,
149
+ api_key=_get_next_api_key(),
150
+ generation_kwargs={
151
+ "temperature": 0.8,
152
+ "max_new_tokens": 256 if is_sample else 1024,
153
+ },
154
+ ),
155
+ difficulty=None if difficulty == "mixed" else difficulty,
156
+ clarity=None if clarity == "mixed" else clarity,
157
+ )
158
+ textcat_generator.load()
159
+ return textcat_generator
160
+
161
+
162
+ def get_labeller_generator(num_labels, labels, is_sample):
163
+ labeller_generator = TextClassification(
164
+ llm=InferenceEndpointsLLM(
165
+ model_id=MODEL,
166
+ tokenizer_id=MODEL,
167
+ api_key=_get_next_api_key(),
168
+ generation_kwargs={
169
+ "temperature": 0.8,
170
+ "max_new_tokens": 256 if is_sample else 1024,
171
+ },
172
+ ),
173
+ n= num_labels,
174
+ available_labels=labels,
175
+ )
176
+ labeller_generator.load()
177
+ return labeller_generator
178
+
179
+
180
+ def get_prompt_generator():
181
+ prompt_generator = TextGeneration(
182
+ llm=InferenceEndpointsLLM(
183
+ api_key=_get_next_api_key(),
184
+ model_id=MODEL,
185
+ tokenizer_id=MODEL,
186
+ generation_kwargs={
187
+ "temperature": 0.8,
188
+ "max_new_tokens": 2048,
189
+ "do_sample": True,
190
+ },
191
+ ),
192
+ use_system_prompt=True,
193
+ )
194
+ prompt_generator.load()
195
+ return prompt_generator