sdiazlor HF staff davidberenstein1957 HF staff commited on
Commit
3c2fc33
1 Parent(s): abbcdf5

feat/text-classification (#11)

Browse files

- feat: Add basic layout textcat (f6a1e437026116667f981091a1d55f2d0266b1cf)
- Merge branch 'main' into pr/11 (2e2beb7bc7be95d2222fecb5af5158b9325ce533)
- refactor: re-usable gradio component (54d4d8d8a537f2114a7fa56b487591a8dee99e92)
- feat: Add support for textcat (adc79cea2eb743ffc85665fd188b040bc80983ba)
- feat: Add buttons to align with textcat and textcatgenerator arguments (288d796777464cd54105e1dcc8ebf28b9fdc09dd)
- Add working textcat version (229dcf3cb0731f80012197b5ebf6815b4261d948)
- feat: Address edge cases and improve textcat UI (6a8a817258c9851000ecb128041b4872be076b00)
- fix: remove typo when copying runnable pipeline (5c28c1d076376ec1b3910666cd6311303d5c500b)
- fix: minor bug and feat:use seuqence(classlabel) for multilabel (28b1761da8c831cd53fad907bbea2603ab29c4f7)


Co-authored-by: David Berenstein <[email protected]>

.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ synthetic-data-generator
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
 
3
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
 
5
 
6
  theme = gr.themes.Monochrome(
7
  spacing_size="md",
@@ -25,8 +26,8 @@ css = """
25
  """
26
 
27
  demo = gr.TabbedInterface(
28
- [sft_app, faq_app],
29
- ["Supervised Fine-Tuning", "FAQ"],
30
  css=css,
31
  title="""
32
  <style>
 
2
 
3
  from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
  from src.distilabel_dataset_generator.apps.sft import app as sft_app
5
+ from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
6
 
7
  theme = gr.themes.Monochrome(
8
  spacing_size="md",
 
26
  """
27
 
28
  demo = gr.TabbedInterface(
29
+ [textcat_app, sft_app, faq_app],
30
+ ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
31
  css=css,
32
  title="""
33
  <style>
requirements.txt CHANGED
@@ -3,4 +3,5 @@ gradio[oauth]
3
  distilabel[hf-inference-endpoints,argilla]
4
  beautifulsoup4
5
  sentence-transformers
6
- model2vec
 
 
3
  distilabel[hf-inference-endpoints,argilla]
4
  beautifulsoup4
5
  sentence-transformers
6
+ model2vec
7
+ outlines
src/distilabel_dataset_generator/apps/base.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import uuid
3
+ from typing import Any, Callable, List, Tuple, Union
4
+
5
+ import argilla as rg
6
+ import gradio as gr
7
+ import pandas as pd
8
+ from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
+ from distilabel.distiset import Distiset
10
+ from gradio import OAuthToken
11
+ from huggingface_hub import HfApi, upload_file
12
+
13
+ from src.distilabel_dataset_generator.utils import (
14
+ _LOGGED_OUT_CSS,
15
+ get_argilla_client,
16
+ list_orgs,
17
+ swap_visibilty,
18
+ get_login_button,
19
+ )
20
+
21
+ TEXTCAT_TASK = "text_classification"
22
+ SFT_TASK = "supervised_fine_tuning"
23
+
24
+
25
+ def get_main_ui(
26
+ default_dataset_descriptions: List[str],
27
+ default_system_prompts: List[str],
28
+ default_datasets: List[pd.DataFrame],
29
+ fn_generate_system_prompt: Callable,
30
+ fn_generate_dataset: Callable,
31
+ task: str,
32
+ ):
33
+ def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
34
+ if system_prompt in default_system_prompts:
35
+ index = default_system_prompts.index(system_prompt)
36
+ if index < len(default_datasets):
37
+ return default_datasets[index]
38
+ if task == TEXTCAT_TASK:
39
+ result = fn_generate_dataset(
40
+ system_prompt=system_prompt,
41
+ difficulty="mixed",
42
+ clarity="mixed",
43
+ labels=[],
44
+ num_labels=1,
45
+ num_rows=1,
46
+ progress=progress,
47
+ is_sample=True,
48
+ )
49
+ else:
50
+ result = fn_generate_dataset(
51
+ system_prompt=system_prompt,
52
+ num_turns=1,
53
+ num_rows=1,
54
+ progress=progress,
55
+ is_sample=True,
56
+ )
57
+ return result
58
+
59
+ with gr.Blocks(
60
+ title="🧬 Synthetic Data Generator",
61
+ head="🧬 Synthetic Data Generator",
62
+ css=_LOGGED_OUT_CSS,
63
+ ) as app:
64
+ with gr.Row():
65
+ gr.Markdown(
66
+ "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
67
+ )
68
+ with gr.Row():
69
+ gr.Column()
70
+ get_login_button()
71
+ gr.Column()
72
+
73
+ gr.Markdown("## Iterate on a sample dataset")
74
+ with gr.Column() as main_ui:
75
+ (
76
+ dataset_description,
77
+ examples,
78
+ btn_generate_system_prompt,
79
+ system_prompt,
80
+ sample_dataset,
81
+ btn_generate_sample_dataset,
82
+ ) = get_iterate_on_sample_dataset_ui(
83
+ default_dataset_descriptions=default_dataset_descriptions,
84
+ default_system_prompts=default_system_prompts,
85
+ default_datasets=default_datasets,
86
+ task=task,
87
+ )
88
+ gr.Markdown("## Generate full dataset")
89
+ gr.Markdown(
90
+ "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
91
+ )
92
+ with gr.Row(variant="panel") as custom_input_ui:
93
+ pass
94
+
95
+ (
96
+ dataset_name,
97
+ add_to_existing_dataset,
98
+ btn_generate_full_dataset_argilla,
99
+ btn_generate_and_push_to_argilla,
100
+ btn_push_to_argilla,
101
+ org_name,
102
+ repo_name,
103
+ private,
104
+ btn_generate_full_dataset,
105
+ btn_generate_and_push_to_hub,
106
+ btn_push_to_hub,
107
+ final_dataset,
108
+ success_message,
109
+ ) = get_push_to_ui(default_datasets)
110
+
111
+ sample_dataset.change(
112
+ fn=lambda x: x,
113
+ inputs=[sample_dataset],
114
+ outputs=[final_dataset],
115
+ )
116
+
117
+ btn_generate_system_prompt.click(
118
+ fn=fn_generate_system_prompt,
119
+ inputs=[dataset_description],
120
+ outputs=[system_prompt],
121
+ show_progress=True,
122
+ ).then(
123
+ fn=fn_generate_sample_dataset,
124
+ inputs=[system_prompt],
125
+ outputs=[sample_dataset],
126
+ show_progress=True,
127
+ )
128
+
129
+ btn_generate_sample_dataset.click(
130
+ fn=fn_generate_sample_dataset,
131
+ inputs=[system_prompt],
132
+ outputs=[sample_dataset],
133
+ show_progress=True,
134
+ )
135
+
136
+ app.load(fn=swap_visibilty, outputs=main_ui)
137
+ app.load(get_org_dropdown, outputs=[org_name])
138
+
139
+ return (
140
+ app,
141
+ main_ui,
142
+ custom_input_ui,
143
+ dataset_description,
144
+ examples,
145
+ btn_generate_system_prompt,
146
+ system_prompt,
147
+ sample_dataset,
148
+ btn_generate_sample_dataset,
149
+ dataset_name,
150
+ add_to_existing_dataset,
151
+ btn_generate_full_dataset_argilla,
152
+ btn_generate_and_push_to_argilla,
153
+ btn_push_to_argilla,
154
+ org_name,
155
+ repo_name,
156
+ private,
157
+ btn_generate_full_dataset,
158
+ btn_generate_and_push_to_hub,
159
+ btn_push_to_hub,
160
+ final_dataset,
161
+ success_message,
162
+ )
163
+
164
+
165
+ def validate_argilla_user_workspace_dataset(
166
+ dataset_name: str,
167
+ final_dataset: pd.DataFrame,
168
+ add_to_existing_dataset: bool,
169
+ oauth_token: Union[OAuthToken, None] = None,
170
+ progress=gr.Progress(),
171
+ ) -> str:
172
+ progress(0, desc="Validating dataset configuration")
173
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
174
+ client = get_argilla_client()
175
+ if dataset_name is None or dataset_name == "":
176
+ raise gr.Error("Dataset name is required")
177
+ # Create user if it doesn't exist
178
+ rg_user = client.users(username=hf_user)
179
+ if rg_user is None:
180
+ rg_user = client.users.add(
181
+ rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
182
+ )
183
+ # Create workspace if it doesn't exist
184
+ workspace = client.workspaces(name=hf_user)
185
+ if workspace is None:
186
+ workspace = client.workspaces.add(rg.Workspace(name=hf_user))
187
+ workspace.add_user(hf_user)
188
+ # Check if dataset exists
189
+ dataset = client.datasets(name=dataset_name, workspace=hf_user)
190
+ if dataset and not add_to_existing_dataset:
191
+ raise gr.Error(f"Dataset {dataset_name} already exists")
192
+ return final_dataset
193
+
194
+
195
+ def get_org_dropdown(oauth_token: OAuthToken = None):
196
+ orgs = list_orgs(oauth_token)
197
+ return gr.Dropdown(
198
+ label="Organization",
199
+ choices=orgs,
200
+ value=orgs[0] if orgs else None,
201
+ allow_custom_value=True,
202
+ )
203
+
204
+
205
+ def get_push_to_ui(default_datasets):
206
+ with gr.Column() as push_to_ui:
207
+ (
208
+ dataset_name,
209
+ add_to_existing_dataset,
210
+ btn_generate_full_dataset_argilla,
211
+ btn_generate_and_push_to_argilla,
212
+ btn_push_to_argilla,
213
+ ) = get_argilla_tab()
214
+ (
215
+ org_name,
216
+ repo_name,
217
+ private,
218
+ btn_generate_full_dataset,
219
+ btn_generate_and_push_to_hub,
220
+ btn_push_to_hub,
221
+ ) = get_hf_tab()
222
+ final_dataset = get_final_dataset_row(default_datasets)
223
+ success_message = get_success_message_row()
224
+ return (
225
+ dataset_name,
226
+ add_to_existing_dataset,
227
+ btn_generate_full_dataset_argilla,
228
+ btn_generate_and_push_to_argilla,
229
+ btn_push_to_argilla,
230
+ org_name,
231
+ repo_name,
232
+ private,
233
+ btn_generate_full_dataset,
234
+ btn_generate_and_push_to_hub,
235
+ btn_push_to_hub,
236
+ final_dataset,
237
+ success_message,
238
+ )
239
+
240
+
241
+ def get_iterate_on_sample_dataset_ui(
242
+ default_dataset_descriptions: List[str],
243
+ default_system_prompts: List[str],
244
+ default_datasets: List[pd.DataFrame],
245
+ task: str,
246
+ ):
247
+ with gr.Column():
248
+ dataset_description = gr.TextArea(
249
+ label="Give a precise description of your desired application. Check the examples for inspiration.",
250
+ value=default_dataset_descriptions[0],
251
+ lines=2,
252
+ )
253
+ examples = gr.Examples(
254
+ elem_id="system_prompt_examples",
255
+ examples=[[example] for example in default_dataset_descriptions],
256
+ inputs=[dataset_description],
257
+ )
258
+ with gr.Row():
259
+ gr.Column(scale=1)
260
+ btn_generate_system_prompt = gr.Button(
261
+ value="Generate system prompt and sample dataset"
262
+ )
263
+ gr.Column(scale=1)
264
+
265
+ system_prompt = gr.TextArea(
266
+ label="System prompt for dataset generation. You can tune it and regenerate the sample.",
267
+ value=default_system_prompts[0],
268
+ lines=2 if task == TEXTCAT_TASK else 5,
269
+ )
270
+
271
+ with gr.Row():
272
+ sample_dataset = gr.Dataframe(
273
+ value=default_datasets[0],
274
+ label="Sample dataset. Prompts and completions truncated to 256 tokens.",
275
+ interactive=False,
276
+ wrap=True,
277
+ )
278
+
279
+ with gr.Row():
280
+ gr.Column(scale=1)
281
+ btn_generate_sample_dataset = gr.Button(
282
+ value="Generate sample dataset",
283
+ )
284
+ gr.Column(scale=1)
285
+
286
+ return (
287
+ dataset_description,
288
+ examples,
289
+ btn_generate_system_prompt,
290
+ system_prompt,
291
+ sample_dataset,
292
+ btn_generate_sample_dataset,
293
+ )
294
+
295
+
296
+ def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
297
+ gr.Markdown("## Or run this pipeline locally with distilabel")
298
+ gr.Markdown(
299
+ "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
300
+ )
301
+ with gr.Accordion(
302
+ "Run this pipeline using distilabel",
303
+ open=False,
304
+ ):
305
+ pipeline_code = gr.Code(
306
+ value=pipeline_code,
307
+ language="python",
308
+ label="Distilabel Pipeline Code",
309
+ )
310
+ return pipeline_code
311
+
312
+
313
+ def get_argilla_tab() -> Tuple[Any]:
314
+ with gr.Tab(label="Argilla"):
315
+ if get_argilla_client() is not None:
316
+ with gr.Row(variant="panel"):
317
+ dataset_name = gr.Textbox(
318
+ label="Dataset name",
319
+ placeholder="dataset_name",
320
+ value="my-distiset",
321
+ )
322
+ add_to_existing_dataset = gr.Checkbox(
323
+ label="Allow adding records to existing dataset",
324
+ info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.",
325
+ value=False,
326
+ interactive=True,
327
+ scale=1,
328
+ )
329
+
330
+ with gr.Row(variant="panel"):
331
+ btn_generate_full_dataset_argilla = gr.Button(
332
+ value="Generate", variant="primary", scale=2
333
+ )
334
+ btn_generate_and_push_to_argilla = gr.Button(
335
+ value="Generate and Push to Argilla",
336
+ variant="primary",
337
+ scale=2,
338
+ )
339
+ btn_push_to_argilla = gr.Button(
340
+ value="Push to Argilla", variant="primary", scale=2
341
+ )
342
+ else:
343
+ gr.Markdown(
344
+ "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
345
+ )
346
+ return (
347
+ dataset_name,
348
+ add_to_existing_dataset,
349
+ btn_generate_full_dataset_argilla,
350
+ btn_generate_and_push_to_argilla,
351
+ btn_push_to_argilla,
352
+ )
353
+
354
+
355
+ def get_hf_tab() -> Tuple[Any]:
356
+ with gr.Tab("Hugging Face Hub"):
357
+ with gr.Row(variant="panel"):
358
+ org_name = get_org_dropdown()
359
+ repo_name = gr.Textbox(
360
+ label="Repo name",
361
+ placeholder="dataset_name",
362
+ value="my-distiset",
363
+ )
364
+ private = gr.Checkbox(
365
+ label="Private dataset",
366
+ value=True,
367
+ interactive=True,
368
+ scale=1,
369
+ )
370
+ with gr.Row(variant="panel"):
371
+ btn_generate_full_dataset = gr.Button(
372
+ value="Generate", variant="primary", scale=2
373
+ )
374
+ btn_generate_and_push_to_hub = gr.Button(
375
+ value="Generate and Push to Hub", variant="primary", scale=2
376
+ )
377
+ btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2)
378
+ return (
379
+ org_name,
380
+ repo_name,
381
+ private,
382
+ btn_generate_full_dataset,
383
+ btn_generate_and_push_to_hub,
384
+ btn_push_to_hub,
385
+ )
386
+
387
+
388
+ def push_pipeline_code_to_hub(
389
+ pipeline_code: str,
390
+ org_name: str,
391
+ repo_name: str,
392
+ oauth_token: Union[OAuthToken, None] = None,
393
+ progress=gr.Progress(),
394
+ ):
395
+ repo_id = _check_push_to_hub(org_name, repo_name)
396
+ progress(0.1, desc="Uploading pipeline code")
397
+ with io.BytesIO(pipeline_code.encode("utf-8")) as f:
398
+ upload_file(
399
+ path_or_fileobj=f,
400
+ path_in_repo="pipeline.py",
401
+ repo_id=repo_id,
402
+ repo_type="dataset",
403
+ token=oauth_token.token,
404
+ commit_message="Include pipeline script",
405
+ create_pr=False,
406
+ )
407
+ progress(1.0, desc="Pipeline code uploaded")
408
+
409
+
410
+ def push_dataset_to_hub(
411
+ dataframe: pd.DataFrame,
412
+ private: bool = True,
413
+ org_name: str = None,
414
+ repo_name: str = None,
415
+ oauth_token: Union[OAuthToken, None] = None,
416
+ progress=gr.Progress(),
417
+ labels: List[str] = None,
418
+ num_labels: int = None,
419
+ task: str = TEXTCAT_TASK,
420
+ ) -> pd.DataFrame:
421
+ progress(0.1, desc="Setting up dataset")
422
+ repo_id = _check_push_to_hub(org_name, repo_name)
423
+
424
+ if task == TEXTCAT_TASK:
425
+ if num_labels == 1:
426
+ features = Features(
427
+ {"text": Value("string"), "label": ClassLabel(names=labels)}
428
+ )
429
+ else:
430
+ features = Features({
431
+ "text": Value("string"),
432
+ "labels": Sequence(feature=ClassLabel(names=labels))
433
+ })
434
+ distiset = Distiset({
435
+ "default": Dataset.from_pandas(dataframe, features=features)
436
+ })
437
+ else:
438
+ distiset = Distiset({
439
+ "default": Dataset.from_pandas(dataframe)
440
+ })
441
+ progress(0.2, desc="Pushing dataset to hub")
442
+ distiset.push_to_hub(
443
+ repo_id=repo_id,
444
+ private=private,
445
+ include_script=False,
446
+ token=oauth_token.token,
447
+ create_pr=False,
448
+ )
449
+ progress(1.0, desc="Dataset pushed to hub")
450
+ return dataframe
451
+
452
+
453
+ def _check_push_to_hub(org_name, repo_name):
454
+ repo_id = (
455
+ f"{org_name}/{repo_name}"
456
+ if repo_name is not None and org_name is not None
457
+ else None
458
+ )
459
+ if repo_id is not None:
460
+ if not all([repo_id, org_name, repo_name]):
461
+ raise gr.Error(
462
+ "Please provide a `repo_name` and `org_name` to push the dataset to."
463
+ )
464
+ return repo_id
465
+
466
+
467
+ def get_final_dataset_row(default_datasets) -> gr.Dataframe:
468
+ with gr.Row():
469
+ final_dataset = gr.Dataframe(
470
+ value=default_datasets[0],
471
+ label="Generated dataset",
472
+ interactive=False,
473
+ wrap=True,
474
+ min_width=300,
475
+ )
476
+ return final_dataset
477
+
478
+
479
+ def get_success_message_row() -> gr.Markdown:
480
+ with gr.Row():
481
+ success_message = gr.Markdown(visible=False)
482
+ return success_message
483
+
484
+
485
+ def show_success_message_argilla() -> gr.Markdown:
486
+ client = get_argilla_client()
487
+ argilla_api_url = client.api_url
488
+ return gr.Markdown(
489
+ value=f"""
490
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
491
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
492
+ <p style="margin-top: 0.5em;">
493
+ Your dataset is now available at:
494
+ <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
495
+ {argilla_api_url}
496
+ </a>
497
+ <br>Unfamiliar with Argilla? Here are some docs to help you get started:
498
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
499
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
500
+ </p>
501
+ </div>
502
+ """,
503
+ visible=True,
504
+ )
505
+
506
+
507
+ def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
508
+ return gr.Markdown(
509
+ value=f"""
510
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
511
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
512
+ <p style="margin-top: 0.5em;">
513
+ The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
514
+ Your dataset is now available at:
515
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
516
+ https://huggingface.co/datasets/{org_name}/{repo_name}
517
+ </a>
518
+ </p>
519
+ </div>
520
+ """,
521
+ visible=True,
522
+ )
523
+
524
+
525
+ def hide_success_message() -> gr.Markdown:
526
+ return gr.Markdown(visible=False)
src/distilabel_dataset_generator/apps/faq.py CHANGED
@@ -15,7 +15,7 @@ with gr.Blocks() as app:
15
  <p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
16
  <ul>
17
  <li>Define the characteristics of your desired application</li>
18
- <li>Generate system prompts automatically</li>
19
  <li>Create sample datasets for quick iteration</li>
20
  <li>Produce full-scale datasets with customizable parameters</li>
21
  <li>Push your generated datasets directly to the Hugging Face Hub</li>
 
15
  <p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
16
  <ul>
17
  <li>Define the characteristics of your desired application</li>
18
+ <li>Generate system prompts and tasks automatically</li>
19
  <li>Create sample datasets for quick iteration</li>
20
  <li>Produce full-scale datasets with customizable parameters</li>
21
  <li>Push your generated datasets directly to the Hugging Face Hub</li>
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,6 +1,4 @@
1
  import ast
2
- import io
3
- import uuid
4
  from typing import Dict, List, Union
5
 
6
  import argilla as rg
@@ -8,17 +6,29 @@ import gradio as gr
8
  import pandas as pd
9
  from datasets import Dataset
10
  from distilabel.distiset import Distiset
11
- from distilabel.steps.tasks.text_generation import TextGeneration
12
- from gradio.oauth import OAuthToken
13
- from huggingface_hub import upload_file
14
- from huggingface_hub.hf_api import HfApi
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from src.distilabel_dataset_generator.pipelines.embeddings import (
17
  get_embeddings,
18
  get_sentence_embedding_dimensions,
19
  )
20
  from src.distilabel_dataset_generator.pipelines.sft import (
21
- DEFAULT_BATCH_SIZE,
22
  DEFAULT_DATASET_DESCRIPTIONS,
23
  DEFAULT_DATASETS,
24
  DEFAULT_SYSTEM_PROMPTS,
@@ -28,222 +38,52 @@ from src.distilabel_dataset_generator.pipelines.sft import (
28
  get_prompt_generator,
29
  get_response_generator,
30
  )
31
- from src.distilabel_dataset_generator.utils import (
32
- get_argilla_client,
33
- get_login_button,
34
- get_org_dropdown,
35
- swap_visibilty,
36
- )
37
-
38
-
39
- def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
40
- return ast.literal_eval(
41
- messages.replace("'user'}", "'user'},")
42
- .replace("'system'}", "'system'},")
43
- .replace("'assistant'}", "'assistant'},")
44
- )
45
 
 
46
 
47
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
48
- progress(0.0, desc="Generating system prompt")
49
- if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
50
- index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
51
- if index < len(DEFAULT_SYSTEM_PROMPTS):
52
- return DEFAULT_SYSTEM_PROMPTS[index]
53
 
54
- progress(0.3, desc="Initializing text generation")
55
- generate_description: TextGeneration = get_prompt_generator()
56
- progress(0.7, desc="Generating system prompt")
57
- result = next(
58
- generate_description.process(
59
- [
60
- {
61
- "system_prompt": PROMPT_CREATION_PROMPT,
62
- "instruction": dataset_description,
63
- }
64
- ]
65
  )
66
- )[0]["generation"]
67
- progress(1.0, desc="System prompt generated")
68
- return result
69
-
70
-
71
- def generate_sample_dataset(system_prompt, progress=gr.Progress()):
72
- if system_prompt in DEFAULT_SYSTEM_PROMPTS:
73
- index = DEFAULT_SYSTEM_PROMPTS.index(system_prompt)
74
- if index < len(DEFAULT_DATASETS):
75
- return DEFAULT_DATASETS[index]
76
- result = generate_dataset(
77
- system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
78
- )
79
- return result
80
-
81
-
82
- def _check_push_to_hub(org_name, repo_name):
83
- repo_id = (
84
- f"{org_name}/{repo_name}"
85
- if repo_name is not None and org_name is not None
86
- else None
87
- )
88
- if repo_id is not None:
89
- if not all([repo_id, org_name, repo_name]):
90
- raise gr.Error(
91
- "Please provide a `repo_name` and `org_name` to push the dataset to."
92
- )
93
- return repo_id
94
-
95
-
96
- def generate_dataset(
97
- system_prompt: str,
98
- num_turns: int = 1,
99
- num_rows: int = 5,
100
- is_sample: bool = False,
101
- progress=gr.Progress(),
102
- ) -> pd.DataFrame:
103
- progress(0.0, desc="(1/2) Generating instructions")
104
- magpie_generator = get_magpie_generator(
105
- num_turns, num_rows, system_prompt, is_sample
106
- )
107
- response_generator = get_response_generator(num_turns, system_prompt, is_sample)
108
- total_steps: int = num_rows * 2
109
- batch_size = DEFAULT_BATCH_SIZE
110
 
111
- # create instructions
112
- n_processed = 0
113
- magpie_results = []
114
- while n_processed < num_rows:
115
- progress(
116
- 0.5 * n_processed / num_rows,
117
- total=total_steps,
118
- desc="(1/2) Generating instructions",
119
  )
120
- remaining_rows = num_rows - n_processed
121
- batch_size = min(batch_size, remaining_rows)
122
- inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
123
- batch = list(magpie_generator.process(inputs=inputs))
124
- magpie_results.extend(batch[0])
125
- n_processed += batch_size
126
- progress(0.5, desc="(1/2) Generating instructions")
127
-
128
- # generate responses
129
- n_processed = 0
130
- response_results = []
131
- if num_turns == 1:
132
- while n_processed < num_rows:
133
- progress(
134
- 0.5 + 0.5 * n_processed / num_rows,
135
- total=total_steps,
136
- desc="(2/2) Generating responses",
137
- )
138
- batch = magpie_results[n_processed : n_processed + batch_size]
139
- responses = list(response_generator.process(inputs=batch))
140
- response_results.extend(responses[0])
141
- n_processed += batch_size
142
- for result in response_results:
143
- result["prompt"] = result["instruction"]
144
- result["completion"] = result["generation"]
145
- result["system_prompt"] = system_prompt
146
- else:
147
- for result in magpie_results:
148
- result["conversation"].insert(
149
- 0, {"role": "system", "content": system_prompt}
150
- )
151
- result["messages"] = result["conversation"]
152
- while n_processed < num_rows:
153
- progress(
154
- 0.5 + 0.5 * n_processed / num_rows,
155
- total=total_steps,
156
- desc="(2/2) Generating responses",
157
- )
158
- batch = magpie_results[n_processed : n_processed + batch_size]
159
- responses = list(response_generator.process(inputs=batch))
160
- response_results.extend(responses[0])
161
- n_processed += batch_size
162
- for result in response_results:
163
- result["messages"].append(
164
- {"role": "assistant", "content": result["generation"]}
165
- )
166
- progress(
167
- 1,
168
- total=total_steps,
169
- desc="(2/2) Generating responses",
170
- )
171
-
172
- # create distiset
173
- distiset_results = []
174
- for result in response_results:
175
- record = {}
176
- for relevant_keys in [
177
- "messages",
178
- "prompt",
179
- "completion",
180
- "model_name",
181
- "system_prompt",
182
- ]:
183
- if relevant_keys in result:
184
- record[relevant_keys] = result[relevant_keys]
185
- distiset_results.append(record)
186
-
187
- distiset = Distiset(
188
- {
189
- "default": Dataset.from_list(distiset_results),
190
- }
191
- )
192
-
193
- # If not pushing to hub generate the dataset directly
194
- distiset = distiset["default"]
195
- if num_turns == 1:
196
- outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
197
- else:
198
- outputs = distiset.to_pandas()[["messages"]]
199
- dataframe = pd.DataFrame(outputs)
200
- progress(1.0, desc="Dataset generation completed")
201
  return dataframe
202
 
203
 
204
- def push_to_hub(
205
  dataframe: pd.DataFrame,
206
  private: bool = True,
207
  org_name: str = None,
208
  repo_name: str = None,
209
- oauth_token: Union[OAuthToken, None] = None,
210
  progress=gr.Progress(),
211
- ) -> pd.DataFrame:
212
  original_dataframe = dataframe.copy(deep=True)
213
- if "messages" in dataframe.columns:
214
- dataframe["messages"] = dataframe["messages"].apply(
215
- lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
 
216
  )
217
- progress(0.1, desc="Setting up dataset")
218
- repo_id = _check_push_to_hub(org_name, repo_name)
219
- distiset = Distiset(
220
- {
221
- "default": Dataset.from_pandas(dataframe),
222
- }
223
- )
224
- progress(0.2, desc="Pushing dataset to hub")
225
- distiset.push_to_hub(
226
- repo_id=repo_id,
227
- private=private,
228
- include_script=False,
229
- token=oauth_token.token,
230
- create_pr=False,
231
- )
232
- progress(1.0, desc="Dataset pushed to hub")
233
  return original_dataframe
234
 
235
 
236
- def push_to_argilla(
237
  dataframe: pd.DataFrame,
238
  dataset_name: str,
239
- oauth_token: Union[OAuthToken, None] = None,
240
  progress=gr.Progress(),
241
  ) -> pd.DataFrame:
242
  original_dataframe = dataframe.copy(deep=True)
243
- if "messages" in dataframe.columns:
244
- dataframe["messages"] = dataframe["messages"].apply(
245
- lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
246
- )
247
  try:
248
  progress(0.1, desc="Setting up user and workspace")
249
  client = get_argilla_client()
@@ -363,294 +203,198 @@ def push_to_argilla(
363
  return original_dataframe
364
 
365
 
366
- def validate_argilla_dataset_name(
367
- dataset_name: str,
368
- final_dataset: pd.DataFrame,
369
- add_to_existing_dataset: bool,
370
- oauth_token: Union[OAuthToken, None] = None,
371
- progress=gr.Progress(),
372
- ) -> str:
373
- progress(0, desc="Validating dataset configuration")
374
- hf_user = HfApi().whoami(token=oauth_token.token)["name"]
375
- client = get_argilla_client()
376
- if dataset_name is None or dataset_name == "":
377
- raise gr.Error("Dataset name is required")
378
- # Create user if it doesn't exist
379
- rg_user = client.users(username=hf_user)
380
- if rg_user is None:
381
- rg_user = client.users.add(
382
- rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
383
- )
384
- # Create workspace if it doesn't exist
385
- workspace = client.workspaces(name=hf_user)
386
- if workspace is None:
387
- workspace = client.workspaces.add(rg.Workspace(name=hf_user))
388
- workspace.add_user(rg_user)
389
- # Check if dataset exists
390
- dataset = client.datasets(name=dataset_name, workspace=hf_user)
391
- if dataset and not add_to_existing_dataset:
392
- raise gr.Error(f"Dataset {dataset_name} already exists")
393
- return final_dataset
394
-
395
-
396
- def upload_pipeline_code(
397
- pipeline_code,
398
- org_name,
399
- repo_name,
400
- oauth_token: Union[OAuthToken, None] = None,
401
- progress=gr.Progress(),
402
- ):
403
- repo_id = _check_push_to_hub(org_name, repo_name)
404
- progress(0.1, desc="Uploading pipeline code")
405
- with io.BytesIO(pipeline_code.encode("utf-8")) as f:
406
- upload_file(
407
- path_or_fileobj=f,
408
- path_in_repo="pipeline.py",
409
- repo_id=repo_id,
410
- repo_type="dataset",
411
- token=oauth_token.token,
412
- commit_message="Include pipeline script",
413
- create_pr=False,
414
- )
415
- progress(1.0, desc="Pipeline code uploaded")
416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
- css = """
419
- .main_ui_logged_out{opacity: 0.3; pointer-events: none}
420
- """
421
 
422
- with gr.Blocks(
423
- title="🧬 Synthetic Data Generator",
424
- head="🧬 Synthetic Data Generator",
425
- css=css,
426
- ) as app:
427
- with gr.Row():
428
- gr.Markdown(
429
- "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
430
- )
431
- with gr.Row():
432
- gr.Column()
433
- get_login_button()
434
- gr.Column()
435
-
436
- gr.Markdown("## Iterate on a sample dataset")
437
- with gr.Column() as main_ui:
438
- dataset_description = gr.TextArea(
439
- label="Give a precise description of the assistant or tool. Don't describe the dataset",
440
- value=DEFAULT_DATASET_DESCRIPTIONS[0],
441
- lines=2,
442
- )
443
- examples = gr.Examples(
444
- elem_id="system_prompt_examples",
445
- examples=[[example] for example in DEFAULT_DATASET_DESCRIPTIONS],
446
- inputs=[dataset_description],
447
- )
448
- with gr.Row():
449
- gr.Column(scale=1)
450
- btn_generate_system_prompt = gr.Button(
451
- value="Generate system prompt and sample dataset"
452
- )
453
- gr.Column(scale=1)
454
 
455
- system_prompt = gr.TextArea(
456
- label="System prompt for dataset generation. You can tune it and regenerate the sample",
457
- value=DEFAULT_SYSTEM_PROMPTS[0],
458
- lines=5,
 
 
 
 
459
  )
 
 
 
 
 
 
 
460
 
461
- with gr.Row():
462
- sample_dataset = gr.Dataframe(
463
- value=DEFAULT_DATASETS[0],
464
- label="Sample dataset. Prompts and completions truncated to 256 tokens.",
465
- interactive=False,
466
- wrap=True,
 
 
 
467
  )
468
-
469
- with gr.Row():
470
- gr.Column(scale=1)
471
- btn_generate_sample_dataset = gr.Button(
472
- value="Generate sample dataset",
 
 
 
 
 
 
 
473
  )
474
- gr.Column(scale=1)
475
-
476
- result = btn_generate_system_prompt.click(
477
- fn=generate_system_prompt,
478
- inputs=[dataset_description],
479
- outputs=[system_prompt],
480
- show_progress=True,
481
- ).then(
482
- fn=generate_sample_dataset,
483
- inputs=[system_prompt],
484
- outputs=[sample_dataset],
485
- show_progress=True,
486
- )
487
-
488
- btn_generate_sample_dataset.click(
489
- fn=generate_sample_dataset,
490
- inputs=[system_prompt],
491
- outputs=[sample_dataset],
492
- show_progress=True,
493
- )
494
-
495
- # Add a header for the full dataset generation section
496
- gr.Markdown("## Generate full dataset")
497
- gr.Markdown(
498
- "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
499
- )
500
-
501
- with gr.Column() as push_to_hub_ui:
502
- with gr.Row(variant="panel"):
503
- num_turns = gr.Number(
504
- value=1,
505
- label="Number of turns in the conversation",
506
- minimum=1,
507
- maximum=4,
508
- step=1,
509
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
510
- )
511
- num_rows = gr.Number(
512
- value=10,
513
- label="Number of rows in the dataset",
514
- minimum=1,
515
- maximum=500,
516
- info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
517
- )
518
-
519
- with gr.Tab(label="Argilla"):
520
- if get_argilla_client() is not None:
521
- with gr.Row(variant="panel"):
522
- dataset_name = gr.Textbox(
523
- label="Dataset name",
524
- placeholder="dataset_name",
525
- value="my-distiset",
526
- )
527
- add_to_existing_dataset = gr.Checkbox(
528
- label="Allow adding records to existing dataset",
529
- info="When selected, you do need to ensure the number of turns in the conversation is the same as the number of turns in the existing dataset.",
530
- value=False,
531
- interactive=True,
532
- scale=0.5,
533
- )
534
-
535
- with gr.Row(variant="panel"):
536
- btn_generate_full_dataset_copy = gr.Button(
537
- value="Generate", variant="primary", scale=2
538
- )
539
- btn_generate_and_push_to_argilla = gr.Button(
540
- value="Generate and Push to Argilla",
541
- variant="primary",
542
- scale=2,
543
- )
544
- btn_push_to_argilla = gr.Button(
545
- value="Push to Argilla", variant="primary", scale=2
546
- )
547
- else:
548
- gr.Markdown(
549
- "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
550
- )
551
- with gr.Tab("Hugging Face Hub"):
552
- with gr.Row(variant="panel"):
553
- org_name = get_org_dropdown()
554
- repo_name = gr.Textbox(
555
- label="Repo name",
556
- placeholder="dataset_name",
557
- value="my-distiset",
558
- )
559
- private = gr.Checkbox(
560
- label="Private dataset",
561
- value=True,
562
- interactive=True,
563
- scale=0.5,
564
- )
565
- with gr.Row(variant="panel"):
566
- btn_generate_full_dataset = gr.Button(
567
- value="Generate", variant="primary", scale=2
568
- )
569
- btn_generate_and_push_to_hub = gr.Button(
570
- value="Generate and Push to Hub", variant="primary", scale=2
571
- )
572
- btn_push_to_hub = gr.Button(
573
- value="Push to Hub", variant="primary", scale=2
574
- )
575
 
576
- with gr.Row():
577
- final_dataset = gr.Dataframe(
578
- value=DEFAULT_DATASETS[0],
579
- label="Generated dataset",
580
- interactive=False,
581
- wrap=True,
582
- )
 
 
 
 
 
 
 
583
 
584
- with gr.Row():
585
- success_message = gr.Markdown(visible=False)
 
 
 
586
 
587
- def show_success_message_argilla():
588
- client = get_argilla_client()
589
- argilla_api_url = client.api_url
590
- return gr.Markdown(
591
- value=f"""
592
- <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
593
- <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
594
- <p style="margin-top: 0.5em;">
595
- Your dataset is now available at:
596
- <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
597
- {argilla_api_url}
598
- </a>
599
- <br>Unfamiliar with Argilla? Here are some docs to help you get started:
600
- <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
601
- <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
602
- </p>
603
- </div>
604
- """,
605
- visible=True,
606
- )
607
 
608
- def show_success_message_hub(org_name, repo_name):
609
- return gr.Markdown(
610
- value=f"""
611
- <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
612
- <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
613
- <p style="margin-top: 0.5em;">
614
- The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
615
- Your dataset is now available at:
616
- <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
617
- https://huggingface.co/datasets/{org_name}/{repo_name}
618
- </a>
619
- </p>
620
- </div>
621
- """,
622
- visible=True,
623
- )
624
 
625
- def hide_success_message():
626
- return gr.Markdown(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
 
628
- gr.Markdown("## Or run this pipeline locally with distilabel")
629
- gr.Markdown(
630
- "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
631
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632
 
633
- with gr.Accordion(
634
- "Run this pipeline using distilabel",
635
- open=False,
636
- ):
637
- pipeline_code = gr.Code(
638
- value=generate_pipeline_code(
639
- system_prompt.value, num_turns.value, num_rows.value
640
- ),
641
- language="python",
642
- label="Distilabel Pipeline Code",
643
  )
644
 
645
- sample_dataset.change(
646
- fn=lambda x: x,
647
- inputs=[sample_dataset],
648
- outputs=[final_dataset],
649
- )
650
  gr.on(
651
  triggers=[
652
  btn_generate_full_dataset.click,
653
- btn_generate_full_dataset_copy.click,
654
  ],
655
  fn=hide_success_message,
656
  outputs=[success_message],
@@ -662,7 +406,7 @@ with gr.Blocks(
662
  )
663
 
664
  btn_generate_and_push_to_argilla.click(
665
- fn=validate_argilla_dataset_name,
666
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
667
  outputs=[final_dataset],
668
  show_progress=True,
@@ -675,7 +419,7 @@ with gr.Blocks(
675
  outputs=[final_dataset],
676
  show_progress=True,
677
  ).success(
678
- fn=push_to_argilla,
679
  inputs=[final_dataset, dataset_name],
680
  outputs=[final_dataset],
681
  show_progress=True,
@@ -694,12 +438,12 @@ with gr.Blocks(
694
  outputs=[final_dataset],
695
  show_progress=True,
696
  ).then(
697
- fn=push_to_hub,
698
  inputs=[final_dataset, private, org_name, repo_name],
699
  outputs=[final_dataset],
700
  show_progress=True,
701
  ).then(
702
- fn=upload_pipeline_code,
703
  inputs=[pipeline_code, org_name, repo_name],
704
  outputs=[],
705
  show_progress=True,
@@ -713,12 +457,12 @@ with gr.Blocks(
713
  fn=hide_success_message,
714
  outputs=[success_message],
715
  ).then(
716
- fn=push_to_hub,
717
  inputs=[final_dataset, private, org_name, repo_name],
718
  outputs=[final_dataset],
719
  show_progress=True,
720
  ).then(
721
- fn=upload_pipeline_code,
722
  inputs=[pipeline_code, org_name, repo_name],
723
  outputs=[],
724
  show_progress=True,
@@ -732,12 +476,12 @@ with gr.Blocks(
732
  fn=hide_success_message,
733
  outputs=[success_message],
734
  ).success(
735
- fn=validate_argilla_dataset_name,
736
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
737
  outputs=[final_dataset],
738
  show_progress=True,
739
  ).success(
740
- fn=push_to_argilla,
741
  inputs=[final_dataset, dataset_name],
742
  outputs=[final_dataset],
743
  show_progress=True,
@@ -762,5 +506,3 @@ with gr.Blocks(
762
  inputs=[system_prompt, num_turns, num_rows],
763
  outputs=[pipeline_code],
764
  )
765
- app.load(get_org_dropdown, outputs=[org_name])
766
- app.load(fn=swap_visibilty, outputs=main_ui)
 
1
  import ast
 
 
2
  from typing import Dict, List, Union
3
 
4
  import argilla as rg
 
6
  import pandas as pd
7
  from datasets import Dataset
8
  from distilabel.distiset import Distiset
9
+ from huggingface_hub import HfApi
 
 
 
10
 
11
+ from src.distilabel_dataset_generator.apps.base import (
12
+ get_argilla_client,
13
+ get_main_ui,
14
+ get_pipeline_code_ui,
15
+ hide_success_message,
16
+ push_pipeline_code_to_hub,
17
+ show_success_message_argilla,
18
+ show_success_message_hub,
19
+ validate_argilla_user_workspace_dataset,
20
+ )
21
+ from src.distilabel_dataset_generator.apps.base import (
22
+ push_dataset_to_hub as push_to_hub_base,
23
+ )
24
+ from src.distilabel_dataset_generator.pipelines.base import (
25
+ DEFAULT_BATCH_SIZE,
26
+ )
27
  from src.distilabel_dataset_generator.pipelines.embeddings import (
28
  get_embeddings,
29
  get_sentence_embedding_dimensions,
30
  )
31
  from src.distilabel_dataset_generator.pipelines.sft import (
 
32
  DEFAULT_DATASET_DESCRIPTIONS,
33
  DEFAULT_DATASETS,
34
  DEFAULT_SYSTEM_PROMPTS,
 
38
  get_prompt_generator,
39
  get_response_generator,
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ TASK = "supervised_fine_tuning"
43
 
 
 
 
 
 
 
44
 
45
+ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
46
+ def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
47
+ return ast.literal_eval(
48
+ messages.replace("'user'}", "'user'},")
49
+ .replace("'system'}", "'system'},")
50
+ .replace("'assistant'}", "'assistant'},")
 
 
 
 
 
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ if "messages" in dataframe.columns:
54
+ dataframe["messages"] = dataframe["messages"].apply(
55
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
 
 
 
 
 
56
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return dataframe
58
 
59
 
60
+ def push_dataset_to_hub(
61
  dataframe: pd.DataFrame,
62
  private: bool = True,
63
  org_name: str = None,
64
  repo_name: str = None,
65
+ oauth_token: Union[gr.OAuthToken, None] = None,
66
  progress=gr.Progress(),
67
+ ):
68
  original_dataframe = dataframe.copy(deep=True)
69
+ dataframe = convert_dataframe_messages(dataframe)
70
+ try:
71
+ push_to_hub_base(
72
+ dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
73
  )
74
+ except Exception as e:
75
+ raise gr.Error(f"Error pushing dataset to the Hub: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  return original_dataframe
77
 
78
 
79
+ def push_dataset_to_argilla(
80
  dataframe: pd.DataFrame,
81
  dataset_name: str,
82
+ oauth_token: Union[gr.OAuthToken, None] = None,
83
  progress=gr.Progress(),
84
  ) -> pd.DataFrame:
85
  original_dataframe = dataframe.copy(deep=True)
86
+ dataframe = convert_dataframe_messages(dataframe)
 
 
 
87
  try:
88
  progress(0.1, desc="Setting up user and workspace")
89
  client = get_argilla_client()
 
203
  return original_dataframe
204
 
205
 
206
+ def generate_system_prompt(dataset_description, progress=gr.Progress()):
207
+ progress(0.0, desc="Generating system prompt")
208
+ if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
209
+ index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
210
+ if index < len(DEFAULT_SYSTEM_PROMPTS):
211
+ return DEFAULT_SYSTEM_PROMPTS[index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
+ progress(0.3, desc="Initializing text generation")
214
+ generate_description = get_prompt_generator()
215
+ progress(0.7, desc="Generating system prompt")
216
+ result = next(
217
+ generate_description.process(
218
+ [
219
+ {
220
+ "system_prompt": PROMPT_CREATION_PROMPT,
221
+ "instruction": dataset_description,
222
+ }
223
+ ]
224
+ )
225
+ )[0]["generation"]
226
+ progress(1.0, desc="System prompt generated")
227
+ return result
228
 
 
 
 
229
 
230
+ def generate_dataset(
231
+ system_prompt: str,
232
+ num_turns: int = 1,
233
+ num_rows: int = 5,
234
+ is_sample: bool = False,
235
+ progress=gr.Progress(),
236
+ ) -> pd.DataFrame:
237
+ progress(0.0, desc="(1/2) Generating instructions")
238
+ magpie_generator = get_magpie_generator(
239
+ num_turns, num_rows, system_prompt, is_sample
240
+ )
241
+ response_generator = get_response_generator(num_turns, system_prompt, is_sample)
242
+ total_steps: int = num_rows * 2
243
+ batch_size = DEFAULT_BATCH_SIZE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ # create instructions
246
+ n_processed = 0
247
+ magpie_results = []
248
+ while n_processed < num_rows:
249
+ progress(
250
+ 0.5 * n_processed / num_rows,
251
+ total=total_steps,
252
+ desc="(1/2) Generating instructions",
253
  )
254
+ remaining_rows = num_rows - n_processed
255
+ batch_size = min(batch_size, remaining_rows)
256
+ inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
257
+ batch = list(magpie_generator.process(inputs=inputs))
258
+ magpie_results.extend(batch[0])
259
+ n_processed += batch_size
260
+ progress(0.5, desc="(1/2) Generating instructions")
261
 
262
+ # generate responses
263
+ n_processed = 0
264
+ response_results = []
265
+ if num_turns == 1:
266
+ while n_processed < num_rows:
267
+ progress(
268
+ 0.5 + 0.5 * n_processed / num_rows,
269
+ total=total_steps,
270
+ desc="(2/2) Generating responses",
271
  )
272
+ batch = magpie_results[n_processed : n_processed + batch_size]
273
+ responses = list(response_generator.process(inputs=batch))
274
+ response_results.extend(responses[0])
275
+ n_processed += batch_size
276
+ for result in response_results:
277
+ result["prompt"] = result["instruction"]
278
+ result["completion"] = result["generation"]
279
+ result["system_prompt"] = system_prompt
280
+ else:
281
+ for result in magpie_results:
282
+ result["conversation"].insert(
283
+ 0, {"role": "system", "content": system_prompt}
284
  )
285
+ result["messages"] = result["conversation"]
286
+ while n_processed < num_rows:
287
+ progress(
288
+ 0.5 + 0.5 * n_processed / num_rows,
289
+ total=total_steps,
290
+ desc="(2/2) Generating responses",
291
+ )
292
+ batch = magpie_results[n_processed : n_processed + batch_size]
293
+ responses = list(response_generator.process(inputs=batch))
294
+ response_results.extend(responses[0])
295
+ n_processed += batch_size
296
+ for result in response_results:
297
+ result["messages"].append(
298
+ {"role": "assistant", "content": result["generation"]}
299
+ )
300
+ progress(
301
+ 1,
302
+ total=total_steps,
303
+ desc="(2/2) Creating dataset",
304
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
+ # create distiset
307
+ distiset_results = []
308
+ for result in response_results:
309
+ record = {}
310
+ for relevant_keys in [
311
+ "messages",
312
+ "prompt",
313
+ "completion",
314
+ "model_name",
315
+ "system_prompt",
316
+ ]:
317
+ if relevant_keys in result:
318
+ record[relevant_keys] = result[relevant_keys]
319
+ distiset_results.append(record)
320
 
321
+ distiset = Distiset(
322
+ {
323
+ "default": Dataset.from_list(distiset_results),
324
+ }
325
+ )
326
 
327
+ # If not pushing to hub generate the dataset directly
328
+ distiset = distiset["default"]
329
+ if num_turns == 1:
330
+ outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
331
+ else:
332
+ outputs = distiset.to_pandas()[["messages"]]
333
+ dataframe = pd.DataFrame(outputs)
334
+ progress(1.0, desc="Dataset generation completed")
335
+ return dataframe
 
 
 
 
 
 
 
 
 
 
 
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
+ (
339
+ app,
340
+ main_ui,
341
+ custom_input_ui,
342
+ dataset_description,
343
+ examples,
344
+ btn_generate_system_prompt,
345
+ system_prompt,
346
+ sample_dataset,
347
+ btn_generate_sample_dataset,
348
+ dataset_name,
349
+ add_to_existing_dataset,
350
+ btn_generate_full_dataset_argilla,
351
+ btn_generate_and_push_to_argilla,
352
+ btn_push_to_argilla,
353
+ org_name,
354
+ repo_name,
355
+ private,
356
+ btn_generate_full_dataset,
357
+ btn_generate_and_push_to_hub,
358
+ btn_push_to_hub,
359
+ final_dataset,
360
+ success_message,
361
+ ) = get_main_ui(
362
+ default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
363
+ default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
364
+ default_datasets=DEFAULT_DATASETS,
365
+ fn_generate_system_prompt=generate_system_prompt,
366
+ fn_generate_dataset=generate_dataset,
367
+ task=TASK,
368
+ )
369
 
370
+ with app:
371
+ with main_ui:
372
+ with custom_input_ui:
373
+ num_turns = gr.Number(
374
+ value=1,
375
+ label="Number of turns in the conversation",
376
+ minimum=1,
377
+ maximum=4,
378
+ step=1,
379
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
380
+ )
381
+ num_rows = gr.Number(
382
+ value=10,
383
+ label="Number of rows in the dataset",
384
+ minimum=1,
385
+ maximum=500,
386
+ info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
387
+ )
388
 
389
+ pipeline_code = get_pipeline_code_ui(
390
+ generate_pipeline_code(system_prompt.value, num_turns.value, num_rows.value)
 
 
 
 
 
 
 
 
391
  )
392
 
393
+ # define app triggers
 
 
 
 
394
  gr.on(
395
  triggers=[
396
  btn_generate_full_dataset.click,
397
+ btn_generate_full_dataset_argilla.click,
398
  ],
399
  fn=hide_success_message,
400
  outputs=[success_message],
 
406
  )
407
 
408
  btn_generate_and_push_to_argilla.click(
409
+ fn=validate_argilla_user_workspace_dataset,
410
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
411
  outputs=[final_dataset],
412
  show_progress=True,
 
419
  outputs=[final_dataset],
420
  show_progress=True,
421
  ).success(
422
+ fn=push_dataset_to_argilla,
423
  inputs=[final_dataset, dataset_name],
424
  outputs=[final_dataset],
425
  show_progress=True,
 
438
  outputs=[final_dataset],
439
  show_progress=True,
440
  ).then(
441
+ fn=push_dataset_to_hub,
442
  inputs=[final_dataset, private, org_name, repo_name],
443
  outputs=[final_dataset],
444
  show_progress=True,
445
  ).then(
446
+ fn=push_pipeline_code_to_hub,
447
  inputs=[pipeline_code, org_name, repo_name],
448
  outputs=[],
449
  show_progress=True,
 
457
  fn=hide_success_message,
458
  outputs=[success_message],
459
  ).then(
460
+ fn=push_dataset_to_hub,
461
  inputs=[final_dataset, private, org_name, repo_name],
462
  outputs=[final_dataset],
463
  show_progress=True,
464
  ).then(
465
+ fn=push_pipeline_code_to_hub,
466
  inputs=[pipeline_code, org_name, repo_name],
467
  outputs=[],
468
  show_progress=True,
 
476
  fn=hide_success_message,
477
  outputs=[success_message],
478
  ).success(
479
+ fn=validate_argilla_user_workspace_dataset,
480
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
481
  outputs=[final_dataset],
482
  show_progress=True,
483
  ).success(
484
+ fn=push_dataset_to_argilla,
485
  inputs=[final_dataset, dataset_name],
486
  outputs=[final_dataset],
487
  show_progress=True,
 
506
  inputs=[system_prompt, num_turns, num_rows],
507
  outputs=[pipeline_code],
508
  )
 
 
src/distilabel_dataset_generator/apps/textcat.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Union
3
+
4
+ import argilla as rg
5
+ import gradio as gr
6
+ import pandas as pd
7
+ from datasets import Dataset
8
+ from huggingface_hub import HfApi
9
+
10
+ from src.distilabel_dataset_generator.apps.base import (
11
+ get_argilla_client,
12
+ get_main_ui,
13
+ get_pipeline_code_ui,
14
+ hide_success_message,
15
+ push_pipeline_code_to_hub,
16
+ show_success_message_argilla,
17
+ show_success_message_hub,
18
+ validate_argilla_user_workspace_dataset,
19
+ )
20
+ from src.distilabel_dataset_generator.apps.base import (
21
+ push_dataset_to_hub as push_to_hub_base,
22
+ )
23
+ from src.distilabel_dataset_generator.pipelines.base import (
24
+ DEFAULT_BATCH_SIZE,
25
+ )
26
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
27
+ get_embeddings,
28
+ get_sentence_embedding_dimensions,
29
+ )
30
+ from src.distilabel_dataset_generator.pipelines.textcat import (
31
+ DEFAULT_DATASET_DESCRIPTIONS,
32
+ DEFAULT_DATASETS,
33
+ DEFAULT_SYSTEM_PROMPTS,
34
+ PROMPT_CREATION_PROMPT,
35
+ generate_pipeline_code,
36
+ get_labeller_generator,
37
+ get_prompt_generator,
38
+ get_textcat_generator,
39
+ )
40
+ from src.distilabel_dataset_generator.utils import get_preprocess_labels
41
+
42
+ TASK = "text_classification"
43
+
44
+
45
+ def push_dataset_to_hub(
46
+ dataframe: pd.DataFrame,
47
+ private: bool = True,
48
+ org_name: str = None,
49
+ repo_name: str = None,
50
+ oauth_token: Union[gr.OAuthToken, None] = None,
51
+ progress=gr.Progress(),
52
+ labels: List[str] = None,
53
+ num_labels: int = 1,
54
+ ):
55
+ original_dataframe = dataframe.copy(deep=True)
56
+ labels = get_preprocess_labels(labels)
57
+ try:
58
+ push_to_hub_base(
59
+ dataframe,
60
+ private,
61
+ org_name,
62
+ repo_name,
63
+ oauth_token,
64
+ progress,
65
+ labels,
66
+ num_labels,
67
+ task=TASK,
68
+ )
69
+ except Exception as e:
70
+ raise gr.Error(f"Error pushing dataset to the Hub: {e}")
71
+ return original_dataframe
72
+
73
+
74
+ def push_dataset_to_argilla(
75
+ dataframe: pd.DataFrame,
76
+ dataset_name: str,
77
+ oauth_token: Union[gr.OAuthToken, None] = None,
78
+ progress=gr.Progress(),
79
+ num_labels: int = 1,
80
+ labels: List[str] = None,
81
+ ) -> pd.DataFrame:
82
+ original_dataframe = dataframe.copy(deep=True)
83
+ try:
84
+ progress(0.1, desc="Setting up user and workspace")
85
+ client = get_argilla_client()
86
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
87
+ labels = get_preprocess_labels(labels)
88
+ settings = rg.Settings(
89
+ fields=[
90
+ rg.TextField(
91
+ name="text",
92
+ description="The text classification data",
93
+ title="Text",
94
+ ),
95
+ ],
96
+ questions=[
97
+ (
98
+ rg.LabelQuestion(
99
+ name="label",
100
+ title="Label",
101
+ description="The label of the text",
102
+ labels=labels,
103
+ )
104
+ if num_labels == 1
105
+ else rg.MultiLabelQuestion(
106
+ name="labels",
107
+ title="Labels",
108
+ description="The labels of the conversation",
109
+ labels=labels,
110
+ )
111
+ ),
112
+ ],
113
+ metadata=[
114
+ rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
115
+ ],
116
+ vectors=[
117
+ rg.VectorField(
118
+ name="text_embeddings",
119
+ dimensions=get_sentence_embedding_dimensions(),
120
+ )
121
+ ],
122
+ guidelines="Please review the text and provide or correct the label where needed.",
123
+ )
124
+
125
+ dataframe["text_length"] = dataframe["text"].apply(len)
126
+ dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
127
+
128
+ progress(0.5, desc="Creating dataset")
129
+ rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
130
+ if rg_dataset is None:
131
+ rg_dataset = rg.Dataset(
132
+ name=dataset_name,
133
+ workspace=hf_user,
134
+ settings=settings,
135
+ client=client,
136
+ )
137
+ rg_dataset = rg_dataset.create()
138
+ progress(0.7, desc="Pushing dataset to Argilla")
139
+ hf_dataset = Dataset.from_pandas(dataframe)
140
+ records = [
141
+ rg.Record(
142
+ fields={
143
+ "text": sample["text"],
144
+ },
145
+ metadata={"text_length": sample["text_length"]},
146
+ vectors={"text_embeddings": sample["text_embeddings"]},
147
+ suggestions=(
148
+ [
149
+ rg.Suggestion(
150
+ question_name="label" if num_labels == 1 else "labels",
151
+ value=(
152
+ sample["label"] if num_labels == 1 else sample["labels"]
153
+ ),
154
+ )
155
+ ]
156
+ if (
157
+ (num_labels == 1 and sample["label"] in labels)
158
+ or (
159
+ num_labels > 1
160
+ and all(label in labels for label in sample["labels"])
161
+ )
162
+ )
163
+ else []
164
+ ),
165
+ )
166
+ for sample in hf_dataset
167
+ ]
168
+ rg_dataset.records.log(records=records)
169
+ progress(1.0, desc="Dataset pushed to Argilla")
170
+ except Exception as e:
171
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
172
+ return original_dataframe
173
+
174
+
175
+ def generate_system_prompt(dataset_description, progress=gr.Progress()):
176
+ progress(0.0, desc="Generating text classification task")
177
+ if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
178
+ index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
179
+ if index < len(DEFAULT_SYSTEM_PROMPTS):
180
+ return DEFAULT_SYSTEM_PROMPTS[index]
181
+
182
+ progress(0.3, desc="Initializing text generation")
183
+ generate_description = get_prompt_generator()
184
+ progress(0.7, desc="Generating text classification task")
185
+ result = next(
186
+ generate_description.process(
187
+ [
188
+ {
189
+ "system_prompt": PROMPT_CREATION_PROMPT,
190
+ "instruction": dataset_description,
191
+ }
192
+ ]
193
+ )
194
+ )[0]["generation"]
195
+ progress(1.0, desc="Text classification task generated")
196
+ return result
197
+
198
+
199
+ def generate_dataset(
200
+ system_prompt: str,
201
+ difficulty: str,
202
+ clarity: str,
203
+ labels: List[str] = None,
204
+ num_labels: int = 1,
205
+ num_rows: int = 10,
206
+ is_sample: bool = False,
207
+ progress=gr.Progress(),
208
+ ) -> pd.DataFrame:
209
+ progress(0.0, desc="(1/2) Generating text classification data")
210
+ labels = get_preprocess_labels(labels)
211
+ textcat_generator = get_textcat_generator(
212
+ difficulty=difficulty, clarity=clarity, is_sample=is_sample
213
+ )
214
+ labeller_generator = get_labeller_generator(
215
+ system_prompt=system_prompt,
216
+ labels=labels,
217
+ num_labels=num_labels,
218
+ is_sample=is_sample,
219
+ )
220
+ total_steps: int = num_rows * 2
221
+ batch_size = DEFAULT_BATCH_SIZE
222
+
223
+ # create text classification data
224
+ n_processed = 0
225
+ textcat_results = []
226
+ while n_processed < num_rows:
227
+ progress(
228
+ 0.5 * n_processed / num_rows,
229
+ total=total_steps,
230
+ desc="(1/2) Generating text classification data",
231
+ )
232
+ remaining_rows = num_rows - n_processed
233
+ batch_size = min(batch_size, remaining_rows)
234
+ inputs = [{"task": system_prompt} for _ in range(batch_size)]
235
+ batch = list(textcat_generator.process(inputs=inputs))
236
+ textcat_results.extend(batch[0])
237
+ n_processed += batch_size
238
+ for result in textcat_results:
239
+ result["text"] = result["input_text"]
240
+
241
+ # label text classification data
242
+ progress(0.5, desc="(1/2) Generating text classification data")
243
+ if not is_sample:
244
+ n_processed = 0
245
+ labeller_results = []
246
+ while n_processed < num_rows:
247
+ progress(
248
+ 0.5 + 0.5 * n_processed / num_rows,
249
+ total=total_steps,
250
+ desc="(1/2) Labeling text classification data",
251
+ )
252
+ batch = textcat_results[n_processed : n_processed + batch_size]
253
+ labels_batch = list(labeller_generator.process(inputs=batch))
254
+ labeller_results.extend(labels_batch[0])
255
+ n_processed += batch_size
256
+ progress(
257
+ 1,
258
+ total=total_steps,
259
+ desc="(2/2) Creating dataset",
260
+ )
261
+
262
+ # create final dataset
263
+ distiset_results = []
264
+ source_results = textcat_results if is_sample else labeller_results
265
+ for result in source_results:
266
+ record = {
267
+ key: result[key]
268
+ for key in ["text", "label" if is_sample else "labels"]
269
+ if key in result
270
+ }
271
+ distiset_results.append(record)
272
+
273
+ dataframe = pd.DataFrame(distiset_results)
274
+ if not is_sample:
275
+ if num_labels == 1:
276
+ dataframe = dataframe.rename(columns={"labels": "label"})
277
+ dataframe["label"] = dataframe["label"].apply(
278
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
279
+ )
280
+ else:
281
+ dataframe["labels"] = dataframe["labels"].apply(
282
+ lambda x: (
283
+ [
284
+ label.lower().strip()
285
+ for label in x
286
+ if label.lower().strip() in labels
287
+ ]
288
+ if isinstance(x, list)
289
+ else None
290
+ )
291
+ )
292
+ progress(1.0, desc="Dataset generation completed")
293
+ return dataframe
294
+
295
+
296
+ def update_suggested_labels(system_prompt):
297
+ new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
298
+ if not new_labels:
299
+ return gr.Warning(
300
+ "No labels found in the system prompt. Please add labels manually."
301
+ )
302
+ return gr.update(choices=new_labels, value=new_labels)
303
+
304
+
305
+ def validate_input_labels(labels):
306
+ if not labels or len(labels) < 2:
307
+ raise gr.Error(
308
+ f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
309
+ )
310
+ return labels
311
+
312
+
313
+ (
314
+ app,
315
+ main_ui,
316
+ custom_input_ui,
317
+ dataset_description,
318
+ examples,
319
+ btn_generate_system_prompt,
320
+ system_prompt,
321
+ sample_dataset,
322
+ btn_generate_sample_dataset,
323
+ dataset_name,
324
+ add_to_existing_dataset,
325
+ btn_generate_full_dataset_argilla,
326
+ btn_generate_and_push_to_argilla,
327
+ btn_push_to_argilla,
328
+ org_name,
329
+ repo_name,
330
+ private,
331
+ btn_generate_full_dataset,
332
+ btn_generate_and_push_to_hub,
333
+ btn_push_to_hub,
334
+ final_dataset,
335
+ success_message,
336
+ ) = get_main_ui(
337
+ default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
338
+ default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
339
+ default_datasets=DEFAULT_DATASETS,
340
+ fn_generate_system_prompt=generate_system_prompt,
341
+ fn_generate_dataset=generate_dataset,
342
+ task=TASK,
343
+ )
344
+
345
+ with app:
346
+ with main_ui:
347
+ with custom_input_ui:
348
+ difficulty = gr.Dropdown(
349
+ choices=[
350
+ ("High School", "high school"),
351
+ ("College", "college"),
352
+ ("PhD", "PhD"),
353
+ ("Mixed", "mixed"),
354
+ ],
355
+ value="mixed",
356
+ label="Difficulty",
357
+ info="The difficulty of the text to be generated.",
358
+ )
359
+ clarity = gr.Dropdown(
360
+ choices=[
361
+ ("Clear", "clear"),
362
+ (
363
+ "Understandable",
364
+ "understandable with some effort",
365
+ ),
366
+ ("Ambiguous", "ambiguous"),
367
+ ("Mixed", "mixed"),
368
+ ],
369
+ value="mixed",
370
+ label="Clarity",
371
+ info="The clarity of the text to be generated.",
372
+ )
373
+ with gr.Column():
374
+ labels = gr.Dropdown(
375
+ choices=[],
376
+ allow_custom_value=True,
377
+ interactive=True,
378
+ label="Labels",
379
+ multiselect=True,
380
+ info="Add the labels to classify the text.",
381
+ )
382
+ with gr.Blocks():
383
+ btn_suggested_labels = gr.Button(
384
+ value="Add suggested labels",
385
+ size="sm",
386
+ )
387
+ num_labels = gr.Number(
388
+ label="Number of labels",
389
+ value=1,
390
+ minimum=1,
391
+ maximum=10,
392
+ info="The number of labels to classify the text.",
393
+ )
394
+ num_rows = gr.Number(
395
+ label="Number of rows",
396
+ value=10,
397
+ minimum=1,
398
+ maximum=500,
399
+ info="More rows will take longer to generate.",
400
+ )
401
+
402
+ pipeline_code = get_pipeline_code_ui(
403
+ generate_pipeline_code(
404
+ system_prompt.value,
405
+ difficulty=difficulty.value,
406
+ clarity=clarity.value,
407
+ labels=labels.value,
408
+ num_labels=num_labels.value,
409
+ num_rows=num_rows.value,
410
+ )
411
+ )
412
+
413
+ # define app triggers
414
+ btn_suggested_labels.click(
415
+ fn=update_suggested_labels,
416
+ inputs=[system_prompt],
417
+ outputs=labels,
418
+ )
419
+
420
+ gr.on(
421
+ triggers=[
422
+ btn_generate_full_dataset.click,
423
+ btn_generate_full_dataset_argilla.click,
424
+ ],
425
+ fn=hide_success_message,
426
+ outputs=[success_message],
427
+ ).then(
428
+ fn=validate_input_labels,
429
+ inputs=[labels],
430
+ outputs=[labels],
431
+ ).success(
432
+ fn=generate_dataset,
433
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
434
+ outputs=[final_dataset],
435
+ show_progress=True,
436
+ )
437
+
438
+ btn_generate_and_push_to_argilla.click(
439
+ fn=validate_argilla_user_workspace_dataset,
440
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
441
+ outputs=[final_dataset],
442
+ show_progress=True,
443
+ ).success(
444
+ fn=hide_success_message,
445
+ outputs=[success_message],
446
+ ).success(
447
+ fn=generate_dataset,
448
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
449
+ outputs=[final_dataset],
450
+ show_progress=True,
451
+ ).success(
452
+ fn=push_dataset_to_argilla,
453
+ inputs=[final_dataset, dataset_name, num_labels, labels],
454
+ outputs=[final_dataset],
455
+ show_progress=True,
456
+ ).success(
457
+ fn=show_success_message_argilla,
458
+ inputs=[],
459
+ outputs=[success_message],
460
+ )
461
+
462
+ btn_generate_and_push_to_hub.click(
463
+ fn=hide_success_message,
464
+ outputs=[success_message],
465
+ ).then(
466
+ fn=generate_dataset,
467
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
468
+ outputs=[final_dataset],
469
+ show_progress=True,
470
+ ).then(
471
+ fn=push_dataset_to_hub,
472
+ inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
473
+ outputs=[final_dataset],
474
+ show_progress=True,
475
+ ).then(
476
+ fn=push_pipeline_code_to_hub,
477
+ inputs=[pipeline_code, org_name, repo_name],
478
+ outputs=[],
479
+ show_progress=True,
480
+ ).success(
481
+ fn=show_success_message_hub,
482
+ inputs=[org_name, repo_name],
483
+ outputs=[success_message],
484
+ )
485
+
486
+ btn_push_to_hub.click(
487
+ fn=hide_success_message,
488
+ outputs=[success_message],
489
+ ).then(
490
+ fn=push_dataset_to_hub,
491
+ inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
492
+ outputs=[final_dataset],
493
+ show_progress=True,
494
+ ).then(
495
+ fn=push_pipeline_code_to_hub,
496
+ inputs=[pipeline_code, org_name, repo_name],
497
+ outputs=[],
498
+ show_progress=True,
499
+ ).success(
500
+ fn=show_success_message_hub,
501
+ inputs=[org_name, repo_name],
502
+ outputs=[success_message],
503
+ )
504
+
505
+ btn_push_to_argilla.click(
506
+ fn=hide_success_message,
507
+ outputs=[success_message],
508
+ ).success(
509
+ fn=validate_argilla_user_workspace_dataset,
510
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
511
+ outputs=[final_dataset],
512
+ show_progress=True,
513
+ ).success(
514
+ fn=push_dataset_to_argilla,
515
+ inputs=[final_dataset, dataset_name, num_labels, labels],
516
+ outputs=[final_dataset],
517
+ show_progress=True,
518
+ ).success(
519
+ fn=show_success_message_argilla,
520
+ inputs=[],
521
+ outputs=[success_message],
522
+ )
523
+
524
+ system_prompt.change(
525
+ fn=generate_pipeline_code,
526
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
527
+ outputs=[pipeline_code],
528
+ )
529
+ difficulty.change(
530
+ fn=generate_pipeline_code,
531
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
532
+ outputs=[pipeline_code],
533
+ )
534
+ clarity.change(
535
+ fn=generate_pipeline_code,
536
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
537
+ outputs=[pipeline_code],
538
+ )
539
+ labels.change(
540
+ fn=generate_pipeline_code,
541
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
542
+ outputs=[pipeline_code],
543
+ )
544
+ num_labels.change(
545
+ fn=generate_pipeline_code,
546
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
547
+ outputs=[pipeline_code],
548
+ )
src/distilabel_dataset_generator/pipelines/base.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.distilabel_dataset_generator.utils import HF_TOKENS
2
+
3
+ DEFAULT_BATCH_SIZE = 5
4
+ TOKEN_INDEX = 0
5
+ MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
+
7
+
8
+ def _get_next_api_key():
9
+ global TOKEN_INDEX
10
+ api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
11
+ TOKEN_INDEX += 1
12
+ return api_key
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,12 +1,11 @@
1
  import pandas as pd
2
- from datasets import Dataset
3
- from distilabel.distiset import Distiset
4
  from distilabel.llms import InferenceEndpointsLLM
5
- from distilabel.pipeline import Pipeline
6
- from distilabel.steps import KeepColumns
7
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
8
 
9
- from src.distilabel_dataset_generator.utils import HF_TOKENS
 
 
 
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -120,7 +119,6 @@ The prompt you write should follow the same style and structure as the following
120
  User dataset description:
121
  """
122
 
123
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
124
  DEFAULT_DATASET_DESCRIPTIONS = (
125
  "rude customer assistant for a phone company",
126
  "assistant that solves math puzzles using python",
@@ -157,8 +155,6 @@ _STOP_SEQUENCES = [
157
  "assistant",
158
  " \n\n",
159
  ]
160
- DEFAULT_BATCH_SIZE = 5
161
- TOKEN_INDEX = 0
162
 
163
 
164
  def _get_output_mappings(num_turns):
@@ -213,13 +209,6 @@ if __name__ == "__main__":
213
  return code
214
 
215
 
216
- def _get_next_api_key():
217
- global TOKEN_INDEX
218
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
219
- TOKEN_INDEX += 1
220
- return api_key
221
-
222
-
223
  def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
224
  input_mappings = _get_output_mappings(num_turns)
225
  output_mappings = input_mappings.copy()
@@ -300,12 +289,9 @@ def get_response_generator(num_turns, system_prompt, is_sample):
300
 
301
 
302
  def get_prompt_generator():
303
- global TOKEN_INDEX
304
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
305
- TOKEN_INDEX += 1
306
  prompt_generator = TextGeneration(
307
  llm=InferenceEndpointsLLM(
308
- api_key=api_key,
309
  model_id=MODEL,
310
  tokenizer_id=MODEL,
311
  generation_kwargs={
@@ -318,95 +304,3 @@ def get_prompt_generator():
318
  )
319
  prompt_generator.load()
320
  return prompt_generator
321
-
322
-
323
- def get_pipeline(num_turns, num_rows, system_prompt, is_sample):
324
- input_mappings = _get_output_mappings(num_turns)
325
- output_mappings = input_mappings
326
-
327
- with Pipeline(name="sft") as pipeline:
328
- magpie = get_magpie_generator(num_turns, num_rows, system_prompt, is_sample)
329
- generate_response = get_response_generator(system_prompt, is_sample)
330
-
331
- keep_columns = KeepColumns(
332
- columns=list(output_mappings.values()) + ["model_name"],
333
- )
334
-
335
- magpie.connect(generate_response)
336
- generate_response.connect(keep_columns)
337
- return pipeline
338
-
339
-
340
- if __name__ == "__main__":
341
- prompt_generation_step = get_prompt_generator()
342
- system_prompt = next(
343
- prompt_generation_step.process(
344
- [
345
- {
346
- "system_prompt": PROMPT_CREATION_PROMPT,
347
- "instruction": DEFAULT_DATASET_DESCRIPTIONS[0],
348
- }
349
- ]
350
- )
351
- )[0]["generation"]
352
- num_rows = 2
353
- num_turns = 1
354
- magpie_generator = get_magpie_generator(num_turns, num_rows, system_prompt, False)
355
- response_generator = get_response_generator(num_turns, system_prompt, False)
356
- total_steps = num_rows * 2
357
- batch_size = 5 # Adjust this value as needed
358
-
359
- # create instructions
360
- magpie_results = []
361
- for i in range(0, num_rows, batch_size):
362
- batch = list(magpie_generator.process())[:batch_size]
363
- magpie_results.extend([item[0] for item in batch])
364
-
365
- # generate responses
366
- response_results = []
367
- if num_turns == 1:
368
- for i in range(0, len(magpie_results), batch_size):
369
- batch = magpie_results[i : i + batch_size]
370
- batch = [entry[0] for entry in batch]
371
- responses = list(response_generator.process(inputs=batch))
372
- response_results.extend(responses)
373
- for result in response_results:
374
- result[0]["prompt"] = result[0]["instruction"]
375
- result[0]["completion"] = result[0]["generation"]
376
- result[0]["system_prompt"] = system_prompt
377
- else:
378
- for result in magpie_results:
379
- result[0]["conversation"].insert(
380
- 0, {"role": "system", "content": system_prompt}
381
- )
382
- result[0]["messages"] = result[0]["conversation"]
383
- for i in range(0, len(magpie_results), batch_size):
384
- batch = magpie_results[i : i + batch_size]
385
- batch = [entry[0] for entry in batch]
386
- responses = list(response_generator.process(inputs=batch))
387
- response_results.extend(responses)
388
-
389
- for result in response_results:
390
- result[0]["messages"].append(
391
- {"role": "assistant", "content": result[0]["generation"]}
392
- )
393
-
394
- distiset_results = []
395
- for result in response_results[0]:
396
- record = {}
397
- for relevant_keys in [
398
- "messages",
399
- "prompt",
400
- "completion",
401
- "model_name",
402
- "system_prompt",
403
- ]:
404
- if relevant_keys in result:
405
- record[relevant_keys] = result[relevant_keys]
406
- distiset_results.append(record)
407
-
408
- distiset = Distiset(
409
- {
410
- "default": Dataset.from_list(distiset_results),
411
- }
412
- )
 
1
  import pandas as pd
 
 
2
  from distilabel.llms import InferenceEndpointsLLM
 
 
3
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
4
 
5
+ from src.distilabel_dataset_generator.pipelines.base import (
6
+ MODEL,
7
+ _get_next_api_key,
8
+ )
9
 
10
  INFORMATION_SEEKING_PROMPT = (
11
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
119
  User dataset description:
120
  """
121
 
 
122
  DEFAULT_DATASET_DESCRIPTIONS = (
123
  "rude customer assistant for a phone company",
124
  "assistant that solves math puzzles using python",
 
155
  "assistant",
156
  " \n\n",
157
  ]
 
 
158
 
159
 
160
  def _get_output_mappings(num_turns):
 
209
  return code
210
 
211
 
 
 
 
 
 
 
 
212
  def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
213
  input_mappings = _get_output_mappings(num_turns)
214
  output_mappings = input_mappings.copy()
 
289
 
290
 
291
  def get_prompt_generator():
 
 
 
292
  prompt_generator = TextGeneration(
293
  llm=InferenceEndpointsLLM(
294
+ api_key=_get_next_api_key(),
295
  model_id=MODEL,
296
  tokenizer_id=MODEL,
297
  generation_kwargs={
 
304
  )
305
  prompt_generator.load()
306
  return prompt_generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/distilabel_dataset_generator/pipelines/textcat.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+ from distilabel.llms import InferenceEndpointsLLM
5
+ from distilabel.steps.tasks import (
6
+ GenerateTextClassificationData,
7
+ TextClassification,
8
+ TextGeneration,
9
+ )
10
+ from src.distilabel_dataset_generator.pipelines.base import (
11
+ MODEL,
12
+ _get_next_api_key,
13
+ )
14
+ from src.distilabel_dataset_generator.utils import get_preprocess_labels
15
+
16
+ PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
+
18
+ Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
19
+
20
+ The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
21
+
22
+ If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
23
+
24
+ Classify the following customer review of a cinema as either 'positive' or 'negative'.
25
+
26
+ Classify the following news article into one or more of the following categories: 'politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international'.
27
+
28
+ Determine the sentiment of the following social media post: 'ambiguous', 'sarcastic', 'informative', 'emotional'.
29
+
30
+ Identify the issue category for the following technical support ticket: 'billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription'.
31
+
32
+ Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
33
+
34
+ Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent'.
35
+
36
+ Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
37
+
38
+ Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
39
+
40
+ Classify the following restaurant review into one of the following categories: 'food-quality', 'service', 'ambiance', or 'price'.
41
+
42
+ Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable-fashion'.
43
+
44
+ User dataset description:
45
+ """
46
+
47
+ DEFAULT_DATASET_DESCRIPTIONS = [
48
+ "A dataset covering customer reviews for an e-commerce website.",
49
+ "A dataset covering news articles about various topics.",
50
+ ]
51
+
52
+ DEFAULT_DATASETS = [
53
+ pd.DataFrame.from_dict(
54
+ {
55
+ "text": [
56
+ "I love the product! It's amazing and I'll buy it again.",
57
+ "The product was okay, but I wouldn't buy it again.",
58
+ ],
59
+ "label": ["positive", "negative"],
60
+ }
61
+ ),
62
+ pd.DataFrame.from_dict(
63
+ {
64
+ "text": [
65
+ "Yesterday, the US stock market had a significant increase.",
66
+ "New research suggests that the Earth is not a perfect sphere.",
67
+ ],
68
+ "labels": [["economy", "politics"], ["science", "environment"]],
69
+ }
70
+ ),
71
+ ]
72
+
73
+ DEFAULT_SYSTEM_PROMPTS = [
74
+ "Classify the following customer review as either 'positive' or 'negative'.",
75
+ "Classify the following news article into one of the following categories: 'politics', 'economy', 'environment', 'science', 'health'.",
76
+ ]
77
+
78
+
79
+ def generate_pipeline_code(
80
+ system_prompt: str,
81
+ difficulty: str = None,
82
+ clarity: str = None,
83
+ labels: List[str] = None,
84
+ num_labels: int = 1,
85
+ num_rows: int = 10,
86
+ ) -> str:
87
+ labels = get_preprocess_labels(labels)
88
+ base_code = f"""
89
+ # Requirements: `pip install distilabel[hf-inference-endpoints]`
90
+ import os
91
+ from distilabel.llms import InferenceEndpointsLLM
92
+ from distilabel.pipeline import Pipeline
93
+ from distilabel.steps import LoadDataFromDicts, KeepColumns
94
+ from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
95
+
96
+ MODEL = "{MODEL}"
97
+ TEXT_CLASSIFICATION_TASK = "{system_prompt}"
98
+ os.environ["HF_TOKEN"] = (
99
+ "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
100
+ )
101
+
102
+ with Pipeline(name="textcat") as pipeline:
103
+
104
+ task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
105
+
106
+ textcat_generation = GenerateTextClassificationData(
107
+ llm=InferenceEndpointsLLM(
108
+ model_id=MODEL,
109
+ tokenizer_id=MODEL,
110
+ api_key=os.environ["HF_TOKEN"],
111
+ generation_kwargs={{
112
+ "temperature": 0.8,
113
+ "max_new_tokens": 2048,
114
+ }},
115
+ ),
116
+ difficulty={None if difficulty == "mixed" else repr(difficulty)},
117
+ clarity={None if clarity == "mixed" else repr(clarity)},
118
+ num_generations={num_rows},
119
+ output_mappings={{"input_text": "text"}},
120
+ )
121
+ """
122
+
123
+ if num_labels == 1:
124
+ return (
125
+ base_code
126
+ + """
127
+ keep_columns = KeepColumns(
128
+ columns=["text", "label"],
129
+ )
130
+
131
+ # Connect steps in the pipeline
132
+ task_generator >> textcat_generation >> keep_columns
133
+
134
+ if __name__ == "__main__":
135
+ distiset = pipeline.run()
136
+ """
137
+ )
138
+
139
+ return (
140
+ base_code
141
+ + f"""
142
+ keep_columns = KeepColumns(
143
+ columns=["text"],
144
+ )
145
+
146
+ textcat_labeller = TextClassification(
147
+ llm=InferenceEndpointsLLM(
148
+ model_id=MODEL,
149
+ tokenizer_id=MODEL,
150
+ api_key=os.environ["HF_TOKEN"],
151
+ generation_kwargs={{
152
+ "temperature": 0.8,
153
+ "max_new_tokens": 2048,
154
+ }},
155
+ ),
156
+ n={num_labels},
157
+ available_labels={labels},
158
+ context=TEXT_CLASSIFICATION_TASK,
159
+ default_label="unknown"
160
+ )
161
+
162
+ # Connect steps in the pipeline
163
+ task_generator >> textcat_generation >> keep_columns >> textcat_labeller
164
+
165
+ if __name__ == "__main__":
166
+ distiset = pipeline.run()
167
+ """
168
+ )
169
+
170
+
171
+ def get_textcat_generator(difficulty, clarity, is_sample):
172
+ textcat_generator = GenerateTextClassificationData(
173
+ llm=InferenceEndpointsLLM(
174
+ model_id=MODEL,
175
+ tokenizer_id=MODEL,
176
+ api_key=_get_next_api_key(),
177
+ generation_kwargs={
178
+ "temperature": 0.8,
179
+ "max_new_tokens": 256 if is_sample else 1024,
180
+ },
181
+ ),
182
+ difficulty=None if difficulty == "mixed" else difficulty,
183
+ clarity=None if clarity == "mixed" else clarity,
184
+ )
185
+ textcat_generator.load()
186
+ return textcat_generator
187
+
188
+
189
+ def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
190
+ labeller_generator = TextClassification(
191
+ llm=InferenceEndpointsLLM(
192
+ model_id=MODEL,
193
+ tokenizer_id=MODEL,
194
+ api_key=_get_next_api_key(),
195
+ generation_kwargs={
196
+ "temperature": 0.8,
197
+ "max_new_tokens": 256 if is_sample else 1024,
198
+ },
199
+ ),
200
+ context=system_prompt,
201
+ available_labels=labels,
202
+ n=num_labels,
203
+ default_label="unknown",
204
+ )
205
+ labeller_generator.load()
206
+ return labeller_generator
207
+
208
+
209
+ def get_prompt_generator():
210
+ prompt_generator = TextGeneration(
211
+ llm=InferenceEndpointsLLM(
212
+ api_key=_get_next_api_key(),
213
+ model_id=MODEL,
214
+ tokenizer_id=MODEL,
215
+ generation_kwargs={
216
+ "temperature": 0.8,
217
+ "max_new_tokens": 2048,
218
+ "do_sample": True,
219
+ },
220
+ ),
221
+ use_system_prompt=True,
222
+ )
223
+ prompt_generator.load()
224
+ return prompt_generator
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Union
3
 
4
  import argilla as rg
5
  import gradio as gr
@@ -12,6 +12,8 @@ from gradio.oauth import (
12
  )
13
  from huggingface_hub import whoami
14
 
 
 
15
  HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
16
  HF_TOKENS = [token for token in HF_TOKENS if token]
17
 
@@ -78,13 +80,35 @@ def get_token(oauth_token: OAuthToken = None):
78
  return ""
79
 
80
 
81
- def swap_visibilty(oauth_token: OAuthToken = None):
82
  if oauth_token:
83
  return gr.update(elem_classes=["main_ui_logged_in"])
84
  else:
85
  return gr.update(elem_classes=["main_ui_logged_out"])
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def get_argilla_client() -> Union[rg.Argilla, None]:
89
  try:
90
  api_url = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
@@ -98,3 +122,6 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
98
  )
99
  except Exception:
100
  return None
 
 
 
 
1
  import os
2
+ from typing import Union, List, Optional
3
 
4
  import argilla as rg
5
  import gradio as gr
 
12
  )
13
  from huggingface_hub import whoami
14
 
15
+ _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
16
+
17
  HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
18
  HF_TOKENS = [token for token in HF_TOKENS if token]
19
 
 
80
  return ""
81
 
82
 
83
+ def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
84
  if oauth_token:
85
  return gr.update(elem_classes=["main_ui_logged_in"])
86
  else:
87
  return gr.update(elem_classes=["main_ui_logged_out"])
88
 
89
 
90
+ def get_base_app():
91
+ with gr.Blocks(
92
+ title="🧬 Synthetic Data Generator",
93
+ head="🧬 Synthetic Data Generator",
94
+ css=_LOGGED_OUT_CSS,
95
+ ) as app:
96
+ with gr.Row():
97
+ gr.Markdown(
98
+ "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
99
+ )
100
+ with gr.Row():
101
+ gr.Column()
102
+ get_login_button()
103
+ gr.Column()
104
+
105
+ gr.Markdown("## Iterate on a sample dataset")
106
+ with gr.Column() as main_ui:
107
+ pass
108
+
109
+ return app
110
+
111
+
112
  def get_argilla_client() -> Union[rg.Argilla, None]:
113
  try:
114
  api_url = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
 
122
  )
123
  except Exception:
124
  return None
125
+
126
+ def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
127
+ return [label.lower().strip() for label in labels] if labels else []