davidberenstein1957 HF staff commited on
Commit
2e2beb7
2 Parent(s): f6a1e43 abbcdf5

Merge branch 'main' into pr/11

Browse files
app.py CHANGED
@@ -55,6 +55,17 @@ demo = gr.TabbedInterface(
55
  margin-bottom: 20px;
56
  }
57
  }
 
 
 
 
 
 
 
 
 
 
 
58
  </style>
59
  <div class="header-container">
60
  <div class="logo-container">
@@ -63,7 +74,7 @@ demo = gr.TabbedInterface(
63
  </a>
64
  </div>
65
  <div class="title-container">
66
- <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
67
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
68
  </div>
69
  </div>
 
55
  margin-bottom: 20px;
56
  }
57
  }
58
+ button[role="tab"].selected,
59
+ button[role="tab"][aria-selected="true"],
60
+ button[role="tab"][data-tab-id][aria-selected="true"] {
61
+ background-color: #000000;
62
+ color: white;
63
+ border: none;
64
+ font-size: 16px;
65
+ font-weight: bold;
66
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
67
+ transition: background-color 0.3s ease, color 0.3s ease;
68
+ }
69
  </style>
70
  <div class="header-container">
71
  <div class="logo-container">
 
74
  </a>
75
  </div>
76
  <div class="title-container">
77
+ <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
78
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
79
  </div>
80
  </div>
pdm.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -6,11 +6,13 @@ authors = [
6
  {name = "davidberenstein1957", email = "[email protected]"},
7
  ]
8
  dependencies = [
9
- "distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
 
 
12
  ]
13
- requires-python = ">=3.10"
14
  readme = "README.md"
15
  license = {text = "apache 2"}
16
 
 
6
  {name = "davidberenstein1957", email = "[email protected]"},
7
  ]
8
  dependencies = [
9
+ "distilabel[hf-inference-endpoints,argilla]==1.4.0",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
12
+ "sentence-transformers>=3.2.0",
13
+ "model2vec>=0.2.4",
14
  ]
15
+ requires-python = "<3.13,>=3.10"
16
  readme = "README.md"
17
  license = {text = "apache 2"}
18
 
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  transformers
2
  gradio[oauth]
3
- distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop
4
- beautifulsoup4
 
 
 
1
  transformers
2
  gradio[oauth]
3
+ distilabel[hf-inference-endpoints,argilla]
4
+ beautifulsoup4
5
+ sentence-transformers
6
+ model2vec
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,6 +1,9 @@
 
1
  import io
2
- from typing import Union
 
3
 
 
4
  import gradio as gr
5
  import pandas as pd
6
  from datasets import Dataset
@@ -8,7 +11,12 @@ from distilabel.distiset import Distiset
8
  from distilabel.steps.tasks.text_generation import TextGeneration
9
  from gradio.oauth import OAuthToken
10
  from huggingface_hub import upload_file
 
11
 
 
 
 
 
12
  from src.distilabel_dataset_generator.pipelines.sft import (
13
  DEFAULT_BATCH_SIZE,
14
  DEFAULT_DATASET_DESCRIPTIONS,
@@ -21,12 +29,21 @@ from src.distilabel_dataset_generator.pipelines.sft import (
21
  get_response_generator,
22
  )
23
  from src.distilabel_dataset_generator.utils import (
 
24
  get_base_app,
25
  get_org_dropdown,
26
  swap_visibilty,
27
  )
28
 
29
 
 
 
 
 
 
 
 
 
30
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
31
  progress(0.0, desc="Generating system prompt")
32
  if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
@@ -82,7 +99,7 @@ def generate_dataset(
82
  num_rows: int = 5,
83
  is_sample: bool = False,
84
  progress=gr.Progress(),
85
- ):
86
  progress(0.0, desc="(1/2) Generating instructions")
87
  magpie_generator = get_magpie_generator(
88
  num_turns, num_rows, system_prompt, is_sample
@@ -191,7 +208,12 @@ def push_to_hub(
191
  repo_name: str = None,
192
  oauth_token: Union[OAuthToken, None] = None,
193
  progress=gr.Progress(),
194
- ):
 
 
 
 
 
195
  progress(0.1, desc="Setting up dataset")
196
  repo_id = _check_push_to_hub(org_name, repo_name)
197
  distiset = Distiset(
@@ -208,7 +230,167 @@ def push_to_hub(
208
  create_pr=False,
209
  )
210
  progress(1.0, desc="Dataset pushed to hub")
211
- return dataframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  def upload_pipeline_code(
@@ -296,7 +478,7 @@ with get_base_app() as app:
296
  # Add a header for the full dataset generation section
297
  gr.Markdown("## Generate full dataset")
298
  gr.Markdown(
299
- "Once you're satisfied with the sample, generate a larger dataset and push it to the Hub."
300
  )
301
 
302
  with gr.Column() as push_to_hub_ui:
@@ -316,27 +498,64 @@ with get_base_app() as app:
316
  maximum=500,
317
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
318
  )
319
- with gr.Row(variant="panel"):
320
- org_name = get_org_dropdown()
321
- repo_name = gr.Textbox(
322
- label="Repo name", placeholder="dataset_name", value="my-distiset"
323
- )
324
- private = gr.Checkbox(
325
- label="Private dataset",
326
- value=True,
327
- interactive=True,
328
- scale=0.5,
329
- )
330
- with gr.Row() as regenerate_row:
331
- btn_generate_full_dataset = gr.Button(
332
- value="Generate", variant="primary", scale=2
333
- )
334
- btn_generate_and_push_to_hub = gr.Button(
335
- value="Generate and Push to Hub", variant="primary", scale=2
336
- )
337
- btn_push_to_hub = gr.Button(
338
- value="Push to Hub", variant="primary", scale=2
339
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  with gr.Row():
341
  final_dataset = gr.Dataframe(
342
  value=DEFAULT_DATASETS[0],
@@ -348,7 +567,28 @@ with get_base_app() as app:
348
  with gr.Row():
349
  success_message = gr.Markdown(visible=False)
350
 
351
- def show_success_message(org_name, repo_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  return gr.Markdown(
353
  value=f"""
354
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
@@ -361,7 +601,7 @@ with get_base_app() as app:
361
  </a>
362
  </p>
363
  </div>
364
- """,
365
  visible=True,
366
  )
367
 
@@ -390,8 +630,11 @@ with get_base_app() as app:
390
  inputs=[sample_dataset],
391
  outputs=[final_dataset],
392
  )
393
-
394
- btn_generate_full_dataset.click(
 
 
 
395
  fn=hide_success_message,
396
  outputs=[success_message],
397
  ).then(
@@ -401,6 +644,30 @@ with get_base_app() as app:
401
  show_progress=True,
402
  )
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  btn_generate_and_push_to_hub.click(
405
  fn=hide_success_message,
406
  outputs=[success_message],
@@ -420,7 +687,7 @@ with get_base_app() as app:
420
  outputs=[],
421
  show_progress=True,
422
  ).success(
423
- fn=show_success_message,
424
  inputs=[org_name, repo_name],
425
  outputs=[success_message],
426
  )
@@ -439,11 +706,30 @@ with get_base_app() as app:
439
  outputs=[],
440
  show_progress=True,
441
  ).success(
442
- fn=show_success_message,
443
  inputs=[org_name, repo_name],
444
  outputs=[success_message],
445
  )
446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  system_prompt.change(
448
  fn=generate_pipeline_code,
449
  inputs=[system_prompt, num_turns, num_rows],
 
1
+ import ast
2
  import io
3
+ import uuid
4
+ from typing import Dict, List, Union
5
 
6
+ import argilla as rg
7
  import gradio as gr
8
  import pandas as pd
9
  from datasets import Dataset
 
11
  from distilabel.steps.tasks.text_generation import TextGeneration
12
  from gradio.oauth import OAuthToken
13
  from huggingface_hub import upload_file
14
+ from huggingface_hub.hf_api import HfApi
15
 
16
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
17
+ get_embeddings,
18
+ get_sentence_embedding_dimensions,
19
+ )
20
  from src.distilabel_dataset_generator.pipelines.sft import (
21
  DEFAULT_BATCH_SIZE,
22
  DEFAULT_DATASET_DESCRIPTIONS,
 
29
  get_response_generator,
30
  )
31
  from src.distilabel_dataset_generator.utils import (
32
+ get_argilla_client,
33
  get_base_app,
34
  get_org_dropdown,
35
  swap_visibilty,
36
  )
37
 
38
 
39
+ def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
40
+ return ast.literal_eval(
41
+ messages.replace("'user'}", "'user'},")
42
+ .replace("'system'}", "'system'},")
43
+ .replace("'assistant'}", "'assistant'},")
44
+ )
45
+
46
+
47
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
48
  progress(0.0, desc="Generating system prompt")
49
  if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
 
99
  num_rows: int = 5,
100
  is_sample: bool = False,
101
  progress=gr.Progress(),
102
+ ) -> pd.DataFrame:
103
  progress(0.0, desc="(1/2) Generating instructions")
104
  magpie_generator = get_magpie_generator(
105
  num_turns, num_rows, system_prompt, is_sample
 
208
  repo_name: str = None,
209
  oauth_token: Union[OAuthToken, None] = None,
210
  progress=gr.Progress(),
211
+ ) -> pd.DataFrame:
212
+ original_dataframe = dataframe.copy(deep=True)
213
+ if "messages" in dataframe.columns:
214
+ dataframe["messages"] = dataframe["messages"].apply(
215
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
216
+ )
217
  progress(0.1, desc="Setting up dataset")
218
  repo_id = _check_push_to_hub(org_name, repo_name)
219
  distiset = Distiset(
 
230
  create_pr=False,
231
  )
232
  progress(1.0, desc="Dataset pushed to hub")
233
+ return original_dataframe
234
+
235
+
236
+ def push_to_argilla(
237
+ dataframe: pd.DataFrame,
238
+ dataset_name: str,
239
+ oauth_token: Union[OAuthToken, None] = None,
240
+ progress=gr.Progress(),
241
+ ) -> pd.DataFrame:
242
+ original_dataframe = dataframe.copy(deep=True)
243
+ if "messages" in dataframe.columns:
244
+ dataframe["messages"] = dataframe["messages"].apply(
245
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
246
+ )
247
+ try:
248
+ progress(0.1, desc="Setting up user and workspace")
249
+ client = get_argilla_client()
250
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
251
+ if "messages" in dataframe.columns:
252
+ settings = rg.Settings(
253
+ fields=[
254
+ rg.ChatField(
255
+ name="messages",
256
+ description="The messages in the conversation",
257
+ title="Messages",
258
+ ),
259
+ ],
260
+ questions=[
261
+ rg.RatingQuestion(
262
+ name="rating",
263
+ title="Rating",
264
+ description="The rating of the conversation",
265
+ values=list(range(1, 6)),
266
+ ),
267
+ ],
268
+ metadata=[
269
+ rg.IntegerMetadataProperty(
270
+ name="user_message_length", title="User Message Length"
271
+ ),
272
+ rg.IntegerMetadataProperty(
273
+ name="assistant_message_length",
274
+ title="Assistant Message Length",
275
+ ),
276
+ ],
277
+ vectors=[
278
+ rg.VectorField(
279
+ name="messages_embeddings",
280
+ dimensions=get_sentence_embedding_dimensions(),
281
+ )
282
+ ],
283
+ guidelines="Please review the conversation and provide a score for the assistant's response.",
284
+ )
285
+
286
+ dataframe["user_message_length"] = dataframe["messages"].apply(
287
+ lambda x: sum([len(y["content"]) for y in x if y["role"] == "user"])
288
+ )
289
+ dataframe["assistant_message_length"] = dataframe["messages"].apply(
290
+ lambda x: sum(
291
+ [len(y["content"]) for y in x if y["role"] == "assistant"]
292
+ )
293
+ )
294
+ dataframe["messages_embeddings"] = get_embeddings(
295
+ dataframe["messages"].apply(
296
+ lambda x: " ".join([y["content"] for y in x])
297
+ )
298
+ )
299
+ else:
300
+ settings = rg.Settings(
301
+ fields=[
302
+ rg.TextField(
303
+ name="system_prompt",
304
+ title="System Prompt",
305
+ description="The system prompt used for the conversation",
306
+ required=False,
307
+ ),
308
+ rg.TextField(
309
+ name="prompt",
310
+ title="Prompt",
311
+ description="The prompt used for the conversation",
312
+ ),
313
+ rg.TextField(
314
+ name="completion",
315
+ title="Completion",
316
+ description="The completion from the assistant",
317
+ ),
318
+ ],
319
+ questions=[
320
+ rg.RatingQuestion(
321
+ name="rating",
322
+ title="Rating",
323
+ description="The rating of the conversation",
324
+ values=list(range(1, 6)),
325
+ ),
326
+ ],
327
+ metadata=[
328
+ rg.IntegerMetadataProperty(
329
+ name="prompt_length", title="Prompt Length"
330
+ ),
331
+ rg.IntegerMetadataProperty(
332
+ name="completion_length", title="Completion Length"
333
+ ),
334
+ ],
335
+ vectors=[
336
+ rg.VectorField(
337
+ name="prompt_embeddings",
338
+ dimensions=get_sentence_embedding_dimensions(),
339
+ )
340
+ ],
341
+ guidelines="Please review the conversation and correct the prompt and completion where needed.",
342
+ )
343
+ dataframe["prompt_length"] = dataframe["prompt"].apply(len)
344
+ dataframe["completion_length"] = dataframe["completion"].apply(len)
345
+ dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
346
+
347
+ progress(0.5, desc="Creating dataset")
348
+ rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
349
+ if rg_dataset is None:
350
+ rg_dataset = rg.Dataset(
351
+ name=dataset_name,
352
+ workspace=hf_user,
353
+ settings=settings,
354
+ client=client,
355
+ )
356
+ rg_dataset = rg_dataset.create()
357
+ progress(0.7, desc="Pushing dataset to Argilla")
358
+ hf_dataset = Dataset.from_pandas(dataframe)
359
+ rg_dataset.records.log(records=hf_dataset)
360
+ progress(1.0, desc="Dataset pushed to Argilla")
361
+ except Exception as e:
362
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
363
+ return original_dataframe
364
+
365
+
366
+ def validate_argilla_dataset_name(
367
+ dataset_name: str,
368
+ final_dataset: pd.DataFrame,
369
+ add_to_existing_dataset: bool,
370
+ oauth_token: Union[OAuthToken, None] = None,
371
+ progress=gr.Progress(),
372
+ ) -> str:
373
+ progress(0, desc="Validating dataset configuration")
374
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
375
+ client = get_argilla_client()
376
+ if dataset_name is None or dataset_name == "":
377
+ raise gr.Error("Dataset name is required")
378
+ # Create user if it doesn't exist
379
+ rg_user = client.users(username=hf_user)
380
+ if rg_user is None:
381
+ rg_user = client.users.add(
382
+ rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
383
+ )
384
+ # Create workspace if it doesn't exist
385
+ workspace = client.workspaces(name=hf_user)
386
+ if workspace is None:
387
+ workspace = client.workspaces.add(rg.Workspace(name=hf_user))
388
+ workspace.add_user(rg_user)
389
+ # Check if dataset exists
390
+ dataset = client.datasets(name=dataset_name, workspace=hf_user)
391
+ if dataset and not add_to_existing_dataset:
392
+ raise gr.Error(f"Dataset {dataset_name} already exists")
393
+ return final_dataset
394
 
395
 
396
  def upload_pipeline_code(
 
478
  # Add a header for the full dataset generation section
479
  gr.Markdown("## Generate full dataset")
480
  gr.Markdown(
481
+ "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
482
  )
483
 
484
  with gr.Column() as push_to_hub_ui:
 
498
  maximum=500,
499
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
500
  )
501
+
502
+ with gr.Tab(label="Argilla"):
503
+ if get_argilla_client() is not None:
504
+ with gr.Row(variant="panel"):
505
+ dataset_name = gr.Textbox(
506
+ label="Dataset name",
507
+ placeholder="dataset_name",
508
+ value="my-distiset",
509
+ )
510
+ add_to_existing_dataset = gr.Checkbox(
511
+ label="Allow adding records to existing dataset",
512
+ info="When selected, you do need to ensure the number of turns in the conversation is the same as the number of turns in the existing dataset.",
513
+ value=False,
514
+ interactive=True,
515
+ scale=0.5,
516
+ )
517
+
518
+ with gr.Row(variant="panel"):
519
+ btn_generate_full_dataset_copy = gr.Button(
520
+ value="Generate", variant="primary", scale=2
521
+ )
522
+ btn_generate_and_push_to_argilla = gr.Button(
523
+ value="Generate and Push to Argilla",
524
+ variant="primary",
525
+ scale=2,
526
+ )
527
+ btn_push_to_argilla = gr.Button(
528
+ value="Push to Argilla", variant="primary", scale=2
529
+ )
530
+ else:
531
+ gr.Markdown(
532
+ "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
533
+ )
534
+ with gr.Tab("Hugging Face Hub"):
535
+ with gr.Row(variant="panel"):
536
+ org_name = get_org_dropdown()
537
+ repo_name = gr.Textbox(
538
+ label="Repo name",
539
+ placeholder="dataset_name",
540
+ value="my-distiset",
541
+ )
542
+ private = gr.Checkbox(
543
+ label="Private dataset",
544
+ value=True,
545
+ interactive=True,
546
+ scale=0.5,
547
+ )
548
+ with gr.Row(variant="panel"):
549
+ btn_generate_full_dataset = gr.Button(
550
+ value="Generate", variant="primary", scale=2
551
+ )
552
+ btn_generate_and_push_to_hub = gr.Button(
553
+ value="Generate and Push to Hub", variant="primary", scale=2
554
+ )
555
+ btn_push_to_hub = gr.Button(
556
+ value="Push to Hub", variant="primary", scale=2
557
+ )
558
+
559
  with gr.Row():
560
  final_dataset = gr.Dataframe(
561
  value=DEFAULT_DATASETS[0],
 
567
  with gr.Row():
568
  success_message = gr.Markdown(visible=False)
569
 
570
+ def show_success_message_argilla():
571
+ client = get_argilla_client()
572
+ argilla_api_url = client.api_url
573
+ return gr.Markdown(
574
+ value=f"""
575
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
576
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
577
+ <p style="margin-top: 0.5em;">
578
+ Your dataset is now available at:
579
+ <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
580
+ {argilla_api_url}
581
+ </a>
582
+ <br>Unfamiliar with Argilla? Here are some docs to help you get started:
583
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
584
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
585
+ </p>
586
+ </div>
587
+ """,
588
+ visible=True,
589
+ )
590
+
591
+ def show_success_message_hub(org_name, repo_name):
592
  return gr.Markdown(
593
  value=f"""
594
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
 
601
  </a>
602
  </p>
603
  </div>
604
+ """,
605
  visible=True,
606
  )
607
 
 
630
  inputs=[sample_dataset],
631
  outputs=[final_dataset],
632
  )
633
+ gr.on(
634
+ triggers=[
635
+ btn_generate_full_dataset.click,
636
+ btn_generate_full_dataset_copy.click,
637
+ ],
638
  fn=hide_success_message,
639
  outputs=[success_message],
640
  ).then(
 
644
  show_progress=True,
645
  )
646
 
647
+ btn_generate_and_push_to_argilla.click(
648
+ fn=validate_argilla_dataset_name,
649
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
650
+ outputs=[final_dataset],
651
+ show_progress=True,
652
+ ).success(
653
+ fn=hide_success_message,
654
+ outputs=[success_message],
655
+ ).success(
656
+ fn=generate_dataset,
657
+ inputs=[system_prompt, num_turns, num_rows],
658
+ outputs=[final_dataset],
659
+ show_progress=True,
660
+ ).success(
661
+ fn=push_to_argilla,
662
+ inputs=[final_dataset, dataset_name],
663
+ outputs=[final_dataset],
664
+ show_progress=True,
665
+ ).success(
666
+ fn=show_success_message_argilla,
667
+ inputs=[],
668
+ outputs=[success_message],
669
+ )
670
+
671
  btn_generate_and_push_to_hub.click(
672
  fn=hide_success_message,
673
  outputs=[success_message],
 
687
  outputs=[],
688
  show_progress=True,
689
  ).success(
690
+ fn=show_success_message_hub,
691
  inputs=[org_name, repo_name],
692
  outputs=[success_message],
693
  )
 
706
  outputs=[],
707
  show_progress=True,
708
  ).success(
709
+ fn=show_success_message_hub,
710
  inputs=[org_name, repo_name],
711
  outputs=[success_message],
712
  )
713
 
714
+ btn_push_to_argilla.click(
715
+ fn=hide_success_message,
716
+ outputs=[success_message],
717
+ ).success(
718
+ fn=validate_argilla_dataset_name,
719
+ inputs=[dataset_name, final_dataset, add_to_existing_dataset],
720
+ outputs=[final_dataset],
721
+ show_progress=True,
722
+ ).success(
723
+ fn=push_to_argilla,
724
+ inputs=[final_dataset, dataset_name],
725
+ outputs=[final_dataset],
726
+ show_progress=True,
727
+ ).success(
728
+ fn=show_success_message_argilla,
729
+ inputs=[],
730
+ outputs=[success_message],
731
+ )
732
+
733
  system_prompt.change(
734
  fn=generate_pipeline_code,
735
  inputs=[system_prompt, num_turns, num_rows],
src/distilabel_dataset_generator/pipelines/embeddings.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from sentence_transformers import SentenceTransformer
4
+ from sentence_transformers.models import StaticEmbedding
5
+
6
+ # Initialize a StaticEmbedding module
7
+ static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
8
+ model = SentenceTransformer(modules=[static_embedding])
9
+
10
+
11
+ def get_embeddings(texts: List[str]) -> List[List[float]]:
12
+ return [embedding.tolist() for embedding in model.encode(texts)]
13
+
14
+
15
+ def get_sentence_embedding_dimensions() -> int:
16
+ return model.get_sentence_embedding_dimension()
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -189,7 +189,7 @@ with Pipeline(name="sft") as pipeline:
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
- "temperature": 0.8,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
@@ -231,7 +231,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
- "temperature": 0.8,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
@@ -250,7 +250,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
- "temperature": 0.8,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
 
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
+ "temperature": 0.9,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
 
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
+ "temperature": 0.9,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
 
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
+ "temperature": 0.9,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
 
2
 
 
3
  import gradio as gr
4
  from gradio.oauth import (
5
  OAUTH_CLIENT_ID,
@@ -10,6 +12,8 @@ from gradio.oauth import (
10
  )
11
  from huggingface_hub import whoami
12
 
 
 
13
  HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
14
  HF_TOKENS = [token for token in HF_TOKENS if token]
15
 
@@ -105,4 +109,16 @@ def get_base_app():
105
  return app
106
 
107
 
108
- _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Union
3
 
4
+ import argilla as rg
5
  import gradio as gr
6
  from gradio.oauth import (
7
  OAUTH_CLIENT_ID,
 
12
  )
13
  from huggingface_hub import whoami
14
 
15
+ _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
16
+
17
  HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
18
  HF_TOKENS = [token for token in HF_TOKENS if token]
19
 
 
109
  return app
110
 
111
 
112
+ def get_argilla_client() -> Union[rg.Argilla, None]:
113
+ try:
114
+ api_url = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
115
+ api_key = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
116
+ if api_url is None or api_key is None:
117
+ api_url = os.getenv("ARGILLA_API_URL")
118
+ api_key = os.getenv("ARGILLA_API_KEY")
119
+ return rg.Argilla(
120
+ api_url=api_url,
121
+ api_key=api_key,
122
+ )
123
+ except Exception:
124
+ return None