davidberenstein1957 HF staff commited on
Commit
b445efe
1 Parent(s): 35eb40e

feat: Add validation to check if data exists before pushing/generating

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -267,17 +267,19 @@ def push_to_argilla(
267
  ),
268
  ],
269
  questions=[
270
- rg.TextQuestion(
271
- name="correct_response",
272
- description="The corrected response from the assistant",
 
273
  ),
274
  ],
275
  metadata=[
276
  rg.IntegerMetadataProperty(
277
- name="messages_length", title="Messages Length"
278
  ),
279
  rg.IntegerMetadataProperty(
280
- name="response_length", title="Response Length"
 
281
  ),
282
  ],
283
  vectors=[
@@ -288,25 +290,28 @@ def push_to_argilla(
288
  ],
289
  guidelines="Please review the conversation and provide a score for the assistant's response.",
290
  )
291
- import pdb
292
 
293
- pdb.set_trace()
294
- dataframe["messages_length"] = dataframe["messages"].apply(
295
- lambda x: sum([len(y["content"]) for y in x])
 
 
 
 
296
  )
297
  dataframe["messages_embeddings"] = get_embeddings(
298
  dataframe["messages"].apply(
299
  lambda x: " ".join([y["content"] for y in x])
300
  )
301
  )
302
- dataframe["correct_response"] = dataframe["messages"].apply(
303
- lambda x: x[-1]["content"]
304
- )
305
- dataframe["response_length"] = dataframe["correct_response"].apply(len)
306
- dataframe["messages"] = dataframe["messages"].apply(lambda x: x[:-1])
307
  else:
308
  settings = rg.Settings(
309
  fields=[
 
 
 
 
 
310
  rg.TextField(
311
  name="prompt",
312
  description="The prompt used for the conversation",
@@ -317,13 +322,10 @@ def push_to_argilla(
317
  ),
318
  ],
319
  questions=[
320
- rg.TextQuestion(
321
- name="correct_prompt",
322
- description="The corrected prompt from the assistant",
323
- ),
324
- rg.TextQuestion(
325
- name="correct_completion",
326
- description="The corrected completion from the assistant",
327
  ),
328
  ],
329
  metadata=[
@@ -342,22 +344,20 @@ def push_to_argilla(
342
  ],
343
  guidelines="Please review the conversation and correct the prompt and completion where needed.",
344
  )
345
- dataframe["correct_prompt"] = dataframe["prompt"]
346
- dataframe["correct_completion"] = dataframe["completion"]
347
  dataframe["prompt_length"] = dataframe["prompt"].apply(len)
348
  dataframe["completion_length"] = dataframe["completion"].apply(len)
349
  dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
350
 
351
  progress(0.5, desc="Creating dataset")
352
- if client.datasets(name=dataset_name, workspace=rg_user.username) is not None:
353
- raise gr.Error(f"Dataset {dataset_name} already exists")
354
- rg_dataset = rg.Dataset(
355
- name=dataset_name,
356
- workspace=rg_user.username,
357
- settings=settings,
358
- client=client,
359
- )
360
- rg_dataset = rg_dataset.create()
361
  progress(0.7, desc="Pushing dataset to Argilla")
362
  hf_dataset = Dataset.from_pandas(dataframe)
363
  rg_dataset.records.log(records=hf_dataset)
@@ -367,6 +367,23 @@ def push_to_argilla(
367
  return original_dataframe
368
 
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  def upload_pipeline_code(
371
  pipeline_code,
372
  org_name,
@@ -469,7 +486,7 @@ with gr.Blocks(
469
  # Add a header for the full dataset generation section
470
  gr.Markdown("## Generate full dataset")
471
  gr.Markdown(
472
- "Once you're satisfied with the sample, generate a larger dataset and push it to the Hub."
473
  )
474
 
475
  with gr.Column() as push_to_hub_ui:
@@ -489,22 +506,31 @@ with gr.Blocks(
489
  maximum=500,
490
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
491
  )
 
492
  with gr.Tab(label="Argilla"):
493
- with gr.Row(variant="panel"):
494
- dataset_name = gr.Textbox(
495
- label="Dataset name",
496
- placeholder="dataset_name",
497
- value="my-distiset",
498
- )
499
- with gr.Row(variant="panel"):
500
- btn_generate_full_dataset_copy = gr.Button(
501
- value="Generate", variant="primary", scale=2
502
- )
503
- btn_generate_and_push_to_argilla = gr.Button(
504
- value="Generate and Push to Argilla", variant="primary", scale=2
505
- )
506
- btn_push_to_argilla = gr.Button(
507
- value="Push to Argilla", variant="primary", scale=2
 
 
 
 
 
 
 
 
508
  )
509
  with gr.Tab("Hugging Face Hub"):
510
  with gr.Row(variant="panel"):
@@ -554,10 +580,10 @@ with gr.Blocks(
554
  <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
555
  {argilla_api_url}
556
  </a>
557
- Here are some docs to help you:
558
- <a href="https://docs.argilla.io/latest/getting_started/quickstart/#sign-in-into-the-argilla-ui" target="_blank">Login with OAuth</a>
559
- <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">Curate your data</a>
560
- <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">Export your data</a>
561
  </p>
562
  </div>
563
  """,
@@ -621,14 +647,19 @@ with gr.Blocks(
621
  )
622
 
623
  btn_generate_and_push_to_argilla.click(
 
 
 
 
 
624
  fn=hide_success_message,
625
  outputs=[success_message],
626
- ).then(
627
  fn=generate_dataset,
628
  inputs=[system_prompt, num_turns, num_rows],
629
  outputs=[final_dataset],
630
  show_progress=True,
631
- ).then(
632
  fn=push_to_argilla,
633
  inputs=[final_dataset, dataset_name],
634
  outputs=[final_dataset],
@@ -685,7 +716,12 @@ with gr.Blocks(
685
  btn_push_to_argilla.click(
686
  fn=hide_success_message,
687
  outputs=[success_message],
688
- ).then(
 
 
 
 
 
689
  fn=push_to_argilla,
690
  inputs=[final_dataset, dataset_name],
691
  outputs=[final_dataset],
 
267
  ),
268
  ],
269
  questions=[
270
+ rg.RatingQuestion(
271
+ name="rating",
272
+ description="The rating of the conversation",
273
+ values=list(range(1, 6)),
274
  ),
275
  ],
276
  metadata=[
277
  rg.IntegerMetadataProperty(
278
+ name="user_message_length", title="User Message Length"
279
  ),
280
  rg.IntegerMetadataProperty(
281
+ name="assistant_message_length",
282
+ title="Assistant Message Length",
283
  ),
284
  ],
285
  vectors=[
 
290
  ],
291
  guidelines="Please review the conversation and provide a score for the assistant's response.",
292
  )
 
293
 
294
+ dataframe["user_message_length"] = dataframe["messages"].apply(
295
+ lambda x: sum([len(y["content"]) for y in x if y["role"] == "user"])
296
+ )
297
+ dataframe["assistant_message_length"] = dataframe["messages"].apply(
298
+ lambda x: sum(
299
+ [len(y["content"]) for y in x if y["role"] == "assistant"]
300
+ )
301
  )
302
  dataframe["messages_embeddings"] = get_embeddings(
303
  dataframe["messages"].apply(
304
  lambda x: " ".join([y["content"] for y in x])
305
  )
306
  )
 
 
 
 
 
307
  else:
308
  settings = rg.Settings(
309
  fields=[
310
+ rg.TextField(
311
+ name="system_prompt",
312
+ description="The system prompt used for the conversation",
313
+ required=False,
314
+ ),
315
  rg.TextField(
316
  name="prompt",
317
  description="The prompt used for the conversation",
 
322
  ),
323
  ],
324
  questions=[
325
+ rg.RatingQuestion(
326
+ name="rating",
327
+ description="The rating of the conversation",
328
+ values=list(range(1, 6)),
 
 
 
329
  ),
330
  ],
331
  metadata=[
 
344
  ],
345
  guidelines="Please review the conversation and correct the prompt and completion where needed.",
346
  )
 
 
347
  dataframe["prompt_length"] = dataframe["prompt"].apply(len)
348
  dataframe["completion_length"] = dataframe["completion"].apply(len)
349
  dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
350
 
351
  progress(0.5, desc="Creating dataset")
352
+ rg_dataset = client.datasets(name=dataset_name, workspace=rg_user.username)
353
+ if rg_dataset is None:
354
+ rg_dataset = rg.Dataset(
355
+ name=dataset_name,
356
+ workspace=rg_user.username,
357
+ settings=settings,
358
+ client=client,
359
+ )
360
+ rg_dataset = rg_dataset.create()
361
  progress(0.7, desc="Pushing dataset to Argilla")
362
  hf_dataset = Dataset.from_pandas(dataframe)
363
  rg_dataset.records.log(records=hf_dataset)
 
367
  return original_dataframe
368
 
369
 
370
+ def validate_argilla_dataset_name(
371
+ dataset_name: str,
372
+ final_dataset: pd.DataFrame,
373
+ oauth_token: Union[OAuthToken, None] = None,
374
+ progress=gr.Progress(),
375
+ ) -> str:
376
+ progress(0, desc="Validating dataset configuration")
377
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
378
+ client = get_argilla_client()
379
+ if dataset_name is None or dataset_name == "":
380
+ raise gr.Error("Dataset name is required")
381
+ dataset = client.datasets(name=dataset_name, workspace=hf_user)
382
+ if dataset:
383
+ raise gr.Error(f"Dataset {dataset_name} already exists")
384
+ return final_dataset
385
+
386
+
387
  def upload_pipeline_code(
388
  pipeline_code,
389
  org_name,
 
486
  # Add a header for the full dataset generation section
487
  gr.Markdown("## Generate full dataset")
488
  gr.Markdown(
489
+ "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
490
  )
491
 
492
  with gr.Column() as push_to_hub_ui:
 
506
  maximum=500,
507
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
508
  )
509
+
510
  with gr.Tab(label="Argilla"):
511
+ if get_argilla_client():
512
+ with gr.Row(variant="panel"):
513
+ dataset_name = gr.Textbox(
514
+ label="Dataset name",
515
+ placeholder="dataset_name",
516
+ value="my-distiset",
517
+ )
518
+
519
+ with gr.Row(variant="panel"):
520
+ btn_generate_full_dataset_copy = gr.Button(
521
+ value="Generate", variant="primary", scale=2
522
+ )
523
+ btn_generate_and_push_to_argilla = gr.Button(
524
+ value="Generate and Push to Argilla",
525
+ variant="primary",
526
+ scale=2,
527
+ )
528
+ btn_push_to_argilla = gr.Button(
529
+ value="Push to Argilla", variant="primary", scale=2
530
+ )
531
+ else:
532
+ gr.Markdown(
533
+ "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla."
534
  )
535
  with gr.Tab("Hugging Face Hub"):
536
  with gr.Row(variant="panel"):
 
580
  <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
581
  {argilla_api_url}
582
  </a>
583
+ <br>Unfamiliar with Argilla? Here are some docs to help you get started:
584
+ <br>• <a href="https://docs.argilla.io/latest/getting_started/quickstart/#sign-in-into-the-argilla-ui" target="_blank">Login with OAuth</a>
585
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">Curate your data</a>
586
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">Export your data</a>
587
  </p>
588
  </div>
589
  """,
 
647
  )
648
 
649
  btn_generate_and_push_to_argilla.click(
650
+ fn=validate_argilla_dataset_name,
651
+ inputs=[dataset_name, final_dataset],
652
+ outputs=[final_dataset],
653
+ show_progress=True,
654
+ ).success(
655
  fn=hide_success_message,
656
  outputs=[success_message],
657
+ ).success(
658
  fn=generate_dataset,
659
  inputs=[system_prompt, num_turns, num_rows],
660
  outputs=[final_dataset],
661
  show_progress=True,
662
+ ).success(
663
  fn=push_to_argilla,
664
  inputs=[final_dataset, dataset_name],
665
  outputs=[final_dataset],
 
716
  btn_push_to_argilla.click(
717
  fn=hide_success_message,
718
  outputs=[success_message],
719
+ ).success(
720
+ fn=validate_argilla_dataset_name,
721
+ inputs=[dataset_name, final_dataset],
722
+ outputs=[final_dataset],
723
+ show_progress=True,
724
+ ).success(
725
  fn=push_to_argilla,
726
  inputs=[final_dataset, dataset_name],
727
  outputs=[final_dataset],
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -189,7 +189,7 @@ with Pipeline(name="sft") as pipeline:
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
- "temperature": 0.8,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
@@ -231,7 +231,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
- "temperature": 0.8,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
@@ -250,7 +250,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
- "temperature": 0.8,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
 
189
  tokenizer_id=MODEL,
190
  magpie_pre_query_template="llama3",
191
  generation_kwargs={{
192
+ "temperature": 1,
193
  "do_sample": True,
194
  "max_new_tokens": 2048,
195
  "stop_sequences": {_STOP_SEQUENCES}
 
231
  api_key=_get_next_api_key(),
232
  magpie_pre_query_template="llama3",
233
  generation_kwargs={
234
+ "temperature": 1,
235
  "do_sample": True,
236
  "max_new_tokens": 256 if is_sample else 512,
237
  "stop_sequences": _STOP_SEQUENCES,
 
250
  api_key=_get_next_api_key(),
251
  magpie_pre_query_template="llama3",
252
  generation_kwargs={
253
+ "temperature": 1,
254
  "do_sample": True,
255
  "max_new_tokens": 256 if is_sample else 1024,
256
  "stop_sequences": _STOP_SEQUENCES,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  import argilla as rg
4
  import gradio as gr
@@ -84,10 +85,13 @@ def swap_visibilty(oauth_token: OAuthToken = None):
84
  return gr.update(elem_classes=["main_ui_logged_out"])
85
 
86
 
87
- def get_argilla_client():
88
- return rg.Argilla(
89
- api_url=os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
90
- or os.getenv("ARGILLA_API_URL"),
91
- api_key=os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
92
- or os.getenv("ARGILLA_API_KEY"),
93
- )
 
 
 
 
1
  import os
2
+ from typing import Union
3
 
4
  import argilla as rg
5
  import gradio as gr
 
85
  return gr.update(elem_classes=["main_ui_logged_out"])
86
 
87
 
88
+ def get_argilla_client() -> Union[rg.Argilla, None]:
89
+ try:
90
+ return rg.Argilla(
91
+ api_url=os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
92
+ or os.getenv("ARGILLA_API_URL"),
93
+ api_key=os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
94
+ or os.getenv("ARGILLA_API_KEY"),
95
+ )
96
+ except Exception:
97
+ return None