davidberenstein1957 HF staff commited on
Commit
f1f92f7
1 Parent(s): 3445828

feat: Add Argilla review integration

Browse files
app.py CHANGED
@@ -54,6 +54,10 @@ demo = gr.TabbedInterface(
54
  margin-bottom: 20px;
55
  }
56
  }
 
 
 
 
57
  </style>
58
  <div class="header-container">
59
  <div class="logo-container">
@@ -62,7 +66,7 @@ demo = gr.TabbedInterface(
62
  </a>
63
  </div>
64
  <div class="title-container">
65
- <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
66
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
67
  </div>
68
  </div>
 
54
  margin-bottom: 20px;
55
  }
56
  }
57
+ button[role="tab"].selected {
58
+ color: black;
59
+ font-weight: bold;
60
+ }
61
  </style>
62
  <div class="header-container">
63
  <div class="logo-container">
 
66
  </a>
67
  </div>
68
  <div class="title-container">
69
+ <h1 style="margin: 0; font-size: 2em;">🧬 Synthetic Data Generator</h1>
70
  <p style="margin: 10px 0 0 0; color: #666; font-size: 1.1em;">Build datasets using natural language</p>
71
  </div>
72
  </div>
pdm.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -6,11 +6,12 @@ authors = [
6
  {name = "davidberenstein1957", email = "[email protected]"},
7
  ]
8
  dependencies = [
9
- "distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
 
12
  ]
13
- requires-python = ">=3.10"
14
  readme = "README.md"
15
  license = {text = "apache 2"}
16
 
 
6
  {name = "davidberenstein1957", email = "[email protected]"},
7
  ]
8
  dependencies = [
9
+ "distilabel[hf-inference-endpoints,argilla]==1.4.0",
10
  "gradio[oauth]<5,>=4.38",
11
  "transformers>=4.44.2",
12
+ "sentence-transformers>=3.2.0",
13
  ]
14
+ requires-python = "<3.13,>=3.10"
15
  readme = "README.md"
16
  license = {text = "apache 2"}
17
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers
2
  gradio[oauth]
3
- distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/distilabel.git@develop
4
- beautifulsoup4
 
 
1
  transformers
2
  gradio[oauth]
3
+ distilabel[hf-inference-endpoints,argilla]
4
+ beautifulsoup4
5
+ sentence-transformers
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,6 +1,8 @@
 
1
  import io
2
- from typing import Union
3
 
 
4
  import gradio as gr
5
  import pandas as pd
6
  from datasets import Dataset
@@ -8,7 +10,12 @@ from distilabel.distiset import Distiset
8
  from distilabel.steps.tasks.text_generation import TextGeneration
9
  from gradio.oauth import OAuthToken
10
  from huggingface_hub import upload_file
 
11
 
 
 
 
 
12
  from src.distilabel_dataset_generator.pipelines.sft import (
13
  DEFAULT_BATCH_SIZE,
14
  DEFAULT_DATASET_DESCRIPTIONS,
@@ -21,12 +28,21 @@ from src.distilabel_dataset_generator.pipelines.sft import (
21
  get_response_generator,
22
  )
23
  from src.distilabel_dataset_generator.utils import (
 
24
  get_login_button,
25
  get_org_dropdown,
26
  swap_visibilty,
27
  )
28
 
29
 
 
 
 
 
 
 
 
 
30
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
31
  progress(0.0, desc="Generating system prompt")
32
  if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
@@ -82,7 +98,7 @@ def generate_dataset(
82
  num_rows: int = 5,
83
  is_sample: bool = False,
84
  progress=gr.Progress(),
85
- ):
86
  progress(0.0, desc="(1/2) Generating instructions")
87
  magpie_generator = get_magpie_generator(
88
  num_turns, num_rows, system_prompt, is_sample
@@ -191,7 +207,12 @@ def push_to_hub(
191
  repo_name: str = None,
192
  oauth_token: Union[OAuthToken, None] = None,
193
  progress=gr.Progress(),
194
- ):
 
 
 
 
 
195
  progress(0.1, desc="Setting up dataset")
196
  repo_id = _check_push_to_hub(org_name, repo_name)
197
  distiset = Distiset(
@@ -208,7 +229,142 @@ def push_to_hub(
208
  create_pr=False,
209
  )
210
  progress(1.0, desc="Dataset pushed to hub")
211
- return dataframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  def upload_pipeline_code(
@@ -333,27 +489,47 @@ with gr.Blocks(
333
  maximum=500,
334
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
335
  )
336
- with gr.Row(variant="panel"):
337
- org_name = get_org_dropdown()
338
- repo_name = gr.Textbox(
339
- label="Repo name", placeholder="dataset_name", value="my-distiset"
340
- )
341
- private = gr.Checkbox(
342
- label="Private dataset",
343
- value=True,
344
- interactive=True,
345
- scale=0.5,
346
- )
347
- with gr.Row() as regenerate_row:
348
- btn_generate_full_dataset = gr.Button(
349
- value="Generate", variant="primary", scale=2
350
- )
351
- btn_generate_and_push_to_hub = gr.Button(
352
- value="Generate and Push to Hub", variant="primary", scale=2
353
- )
354
- btn_push_to_hub = gr.Button(
355
- value="Push to Hub", variant="primary", scale=2
356
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  with gr.Row():
358
  final_dataset = gr.Dataframe(
359
  value=DEFAULT_DATASETS[0],
@@ -365,7 +541,29 @@ with gr.Blocks(
365
  with gr.Row():
366
  success_message = gr.Markdown(visible=False)
367
 
368
- def show_success_message(org_name, repo_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  return gr.Markdown(
370
  value=f"""
371
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
@@ -378,7 +576,7 @@ with gr.Blocks(
378
  </a>
379
  </p>
380
  </div>
381
- """,
382
  visible=True,
383
  )
384
 
@@ -407,8 +605,21 @@ with gr.Blocks(
407
  inputs=[sample_dataset],
408
  outputs=[final_dataset],
409
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
- btn_generate_full_dataset.click(
412
  fn=hide_success_message,
413
  outputs=[success_message],
414
  ).then(
@@ -416,6 +627,15 @@ with gr.Blocks(
416
  inputs=[system_prompt, num_turns, num_rows],
417
  outputs=[final_dataset],
418
  show_progress=True,
 
 
 
 
 
 
 
 
 
419
  )
420
 
421
  btn_generate_and_push_to_hub.click(
@@ -437,7 +657,7 @@ with gr.Blocks(
437
  outputs=[],
438
  show_progress=True,
439
  ).success(
440
- fn=show_success_message,
441
  inputs=[org_name, repo_name],
442
  outputs=[success_message],
443
  )
@@ -456,11 +676,25 @@ with gr.Blocks(
456
  outputs=[],
457
  show_progress=True,
458
  ).success(
459
- fn=show_success_message,
460
  inputs=[org_name, repo_name],
461
  outputs=[success_message],
462
  )
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  system_prompt.change(
465
  fn=generate_pipeline_code,
466
  inputs=[system_prompt, num_turns, num_rows],
 
1
+ import ast
2
  import io
3
+ from typing import Dict, List, Union
4
 
5
+ import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
  from datasets import Dataset
 
10
  from distilabel.steps.tasks.text_generation import TextGeneration
11
  from gradio.oauth import OAuthToken
12
  from huggingface_hub import upload_file
13
+ from huggingface_hub.hf_api import HfApi
14
 
15
+ from src.distilabel_dataset_generator.pipelines.embeddings import (
16
+ get_embeddings,
17
+ get_sentence_embedding_dimensions,
18
+ )
19
  from src.distilabel_dataset_generator.pipelines.sft import (
20
  DEFAULT_BATCH_SIZE,
21
  DEFAULT_DATASET_DESCRIPTIONS,
 
28
  get_response_generator,
29
  )
30
  from src.distilabel_dataset_generator.utils import (
31
+ get_argilla_client,
32
  get_login_button,
33
  get_org_dropdown,
34
  swap_visibilty,
35
  )
36
 
37
 
38
+ def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
39
+ return ast.literal_eval(
40
+ messages.replace("'user'}", "'user'},")
41
+ .replace("'system'}", "'system'},")
42
+ .replace("'assistant'}", "'assistant'},")
43
+ )
44
+
45
+
46
  def generate_system_prompt(dataset_description, progress=gr.Progress()):
47
  progress(0.0, desc="Generating system prompt")
48
  if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
 
98
  num_rows: int = 5,
99
  is_sample: bool = False,
100
  progress=gr.Progress(),
101
+ ) -> pd.DataFrame:
102
  progress(0.0, desc="(1/2) Generating instructions")
103
  magpie_generator = get_magpie_generator(
104
  num_turns, num_rows, system_prompt, is_sample
 
207
  repo_name: str = None,
208
  oauth_token: Union[OAuthToken, None] = None,
209
  progress=gr.Progress(),
210
+ ) -> pd.DataFrame:
211
+ original_dataframe = dataframe.copy(deep=True)
212
+ if "messages" in dataframe.columns:
213
+ dataframe["messages"] = dataframe["messages"].apply(
214
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
215
+ )
216
  progress(0.1, desc="Setting up dataset")
217
  repo_id = _check_push_to_hub(org_name, repo_name)
218
  distiset = Distiset(
 
229
  create_pr=False,
230
  )
231
  progress(1.0, desc="Dataset pushed to hub")
232
+ return original_dataframe
233
+
234
+
235
+ def push_to_argilla(
236
+ dataframe: pd.DataFrame,
237
+ dataset_name: str,
238
+ oauth_token: Union[OAuthToken, None] = None,
239
+ progress=gr.Progress(),
240
+ ) -> pd.DataFrame:
241
+ original_dataframe = dataframe.copy(deep=True)
242
+ if "messages" in dataframe.columns:
243
+ dataframe["messages"] = dataframe["messages"].apply(
244
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
245
+ )
246
+ try:
247
+ progress(0.1, desc="Setting up user and workspace")
248
+ client = get_argilla_client()
249
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
250
+
251
+ # Create user if it doesn't exist
252
+ rg_user = client.users(username=hf_user)
253
+ if rg_user is None:
254
+ rg_user = client.users.add(rg.User(username=hf_user, role="admin"))
255
+
256
+ # Create workspace if it doesn't exist
257
+ workspace = client.workspaces(name=rg_user.username)
258
+ if workspace is None:
259
+ workspace = client.workspaces.add(rg.Workspace(name=rg_user.username))
260
+ workspace.add_user(rg_user)
261
+
262
+ if "messages" in dataframe.columns:
263
+ settings = rg.Settings(
264
+ fields=[
265
+ rg.ChatField(
266
+ name="messages", description="The messages in the conversation"
267
+ ),
268
+ ],
269
+ questions=[
270
+ rg.TextQuestion(
271
+ name="correct_response",
272
+ description="The corrected response from the assistant",
273
+ ),
274
+ ],
275
+ metadata=[
276
+ rg.IntegerMetadataProperty(
277
+ name="messages_length", title="Messages Length"
278
+ ),
279
+ rg.IntegerMetadataProperty(
280
+ name="response_length", title="Response Length"
281
+ ),
282
+ ],
283
+ vectors=[
284
+ rg.VectorField(
285
+ name="messages_embeddings",
286
+ dimensions=get_sentence_embedding_dimensions(),
287
+ )
288
+ ],
289
+ guidelines="Please review the conversation and provide a score for the assistant's response.",
290
+ )
291
+ import pdb
292
+
293
+ pdb.set_trace()
294
+ dataframe["messages_length"] = dataframe["messages"].apply(
295
+ lambda x: sum([len(y["content"]) for y in x])
296
+ )
297
+ dataframe["messages_embeddings"] = get_embeddings(
298
+ dataframe["messages"].apply(
299
+ lambda x: " ".join([y["content"] for y in x])
300
+ )
301
+ )
302
+ dataframe["correct_response"] = dataframe["messages"].apply(
303
+ lambda x: x[-1]["content"]
304
+ )
305
+ dataframe["response_length"] = dataframe["correct_response"].apply(len)
306
+ dataframe["messages"] = dataframe["messages"].apply(lambda x: x[:-1])
307
+ else:
308
+ settings = rg.Settings(
309
+ fields=[
310
+ rg.TextField(
311
+ name="prompt",
312
+ description="The prompt used for the conversation",
313
+ ),
314
+ rg.TextField(
315
+ name="completion",
316
+ description="The completion from the assistant",
317
+ ),
318
+ ],
319
+ questions=[
320
+ rg.TextQuestion(
321
+ name="correct_prompt",
322
+ description="The corrected prompt from the assistant",
323
+ ),
324
+ rg.TextQuestion(
325
+ name="correct_completion",
326
+ description="The corrected completion from the assistant",
327
+ ),
328
+ ],
329
+ metadata=[
330
+ rg.IntegerMetadataProperty(
331
+ name="prompt_length", title="Prompt Length"
332
+ ),
333
+ rg.IntegerMetadataProperty(
334
+ name="completion_length", title="Completion Length"
335
+ ),
336
+ ],
337
+ vectors=[
338
+ rg.VectorField(
339
+ name="prompt_embeddings",
340
+ dimensions=get_sentence_embedding_dimensions(),
341
+ )
342
+ ],
343
+ guidelines="Please review the conversation and correct the prompt and completion where needed.",
344
+ )
345
+ dataframe["correct_prompt"] = dataframe["prompt"]
346
+ dataframe["correct_completion"] = dataframe["completion"]
347
+ dataframe["prompt_length"] = dataframe["prompt"].apply(len)
348
+ dataframe["completion_length"] = dataframe["completion"].apply(len)
349
+ dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])
350
+
351
+ progress(0.5, desc="Creating dataset")
352
+ if client.datasets(name=dataset_name, workspace=rg_user.username) is not None:
353
+ raise gr.Error(f"Dataset {dataset_name} already exists")
354
+ rg_dataset = rg.Dataset(
355
+ name=dataset_name,
356
+ workspace=rg_user.username,
357
+ settings=settings,
358
+ client=client,
359
+ )
360
+ rg_dataset = rg_dataset.create()
361
+ progress(0.7, desc="Pushing dataset to Argilla")
362
+ hf_dataset = Dataset.from_pandas(dataframe)
363
+ rg_dataset.records.log(records=hf_dataset)
364
+ progress(1.0, desc="Dataset pushed to Argilla")
365
+ except Exception as e:
366
+ raise gr.Error(f"Error pushing dataset to Argilla: {e}")
367
+ return original_dataframe
368
 
369
 
370
  def upload_pipeline_code(
 
489
  maximum=500,
490
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
491
  )
492
+ with gr.Tab("Hugging Face Hub"):
493
+ with gr.Row(variant="panel"):
494
+ org_name = get_org_dropdown()
495
+ repo_name = gr.Textbox(
496
+ label="Repo name",
497
+ placeholder="dataset_name",
498
+ value="my-distiset",
499
+ )
500
+ private = gr.Checkbox(
501
+ label="Private dataset",
502
+ value=True,
503
+ interactive=True,
504
+ scale=0.5,
505
+ )
506
+ with gr.Row(variant="panel"):
507
+ btn_generate_full_dataset = gr.Button(
508
+ value="Generate", variant="primary", scale=2
509
+ )
510
+ btn_generate_and_push_to_hub = gr.Button(
511
+ value="Generate and Push to Hub", variant="primary", scale=2
512
+ )
513
+ btn_push_to_hub = gr.Button(
514
+ value="Push to Hub", variant="primary", scale=2
515
+ )
516
+ with gr.Tab(label="Argilla"):
517
+ with gr.Row(variant="panel"):
518
+ dataset_name = gr.Textbox(
519
+ label="Dataset name",
520
+ placeholder="dataset_name",
521
+ value="my-distiset",
522
+ )
523
+ with gr.Row(variant="panel"):
524
+ btn_generate_full_dataset_copy = gr.Button(
525
+ value="Generate", variant="primary", scale=2
526
+ )
527
+ btn_generate_and_push_to_argilla = gr.Button(
528
+ value="Generate and Push to Argilla", variant="primary", scale=2
529
+ )
530
+ btn_push_to_argilla = gr.Button(
531
+ value="Push to Argilla", variant="primary", scale=2
532
+ )
533
  with gr.Row():
534
  final_dataset = gr.Dataframe(
535
  value=DEFAULT_DATASETS[0],
 
541
  with gr.Row():
542
  success_message = gr.Markdown(visible=False)
543
 
544
+ def show_success_message_argilla():
545
+ client = get_argilla_client()
546
+ argilla_api_url = client.api_url
547
+ return gr.Markdown(
548
+ value=f"""
549
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
550
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
551
+ <p style="margin-top: 0.5em;">
552
+ Your dataset is now available at:
553
+ <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
554
+ {argilla_api_url}
555
+ </a>
556
+ Here are some docs to help you:
557
+ • <a href="https://docs.argilla.io/latest/getting_started/quickstart/#sign-in-into-the-argilla-ui" target="_blank">Login with OAuth</a>
558
+ • <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">Curate your data</a>
559
+ • <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">Export your data</a>
560
+ </p>
561
+ </div>
562
+ """,
563
+ visible=True,
564
+ )
565
+
566
+ def show_success_message_hub(org_name, repo_name):
567
  return gr.Markdown(
568
  value=f"""
569
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
 
576
  </a>
577
  </p>
578
  </div>
579
+ """,
580
  visible=True,
581
  )
582
 
 
605
  inputs=[sample_dataset],
606
  outputs=[final_dataset],
607
  )
608
+ gr.on(
609
+ triggers=[
610
+ btn_generate_full_dataset.click,
611
+ btn_generate_full_dataset_copy.click,
612
+ ],
613
+ fn=hide_success_message,
614
+ outputs=[success_message],
615
+ ).then(
616
+ fn=generate_dataset,
617
+ inputs=[system_prompt, num_turns, num_rows],
618
+ outputs=[final_dataset],
619
+ show_progress=True,
620
+ )
621
 
622
+ btn_generate_and_push_to_argilla.click(
623
  fn=hide_success_message,
624
  outputs=[success_message],
625
  ).then(
 
627
  inputs=[system_prompt, num_turns, num_rows],
628
  outputs=[final_dataset],
629
  show_progress=True,
630
+ ).then(
631
+ fn=push_to_argilla,
632
+ inputs=[final_dataset, dataset_name],
633
+ outputs=[final_dataset],
634
+ show_progress=True,
635
+ ).success(
636
+ fn=show_success_message_argilla,
637
+ inputs=[],
638
+ outputs=[success_message],
639
  )
640
 
641
  btn_generate_and_push_to_hub.click(
 
657
  outputs=[],
658
  show_progress=True,
659
  ).success(
660
+ fn=show_success_message_hub,
661
  inputs=[org_name, repo_name],
662
  outputs=[success_message],
663
  )
 
676
  outputs=[],
677
  show_progress=True,
678
  ).success(
679
+ fn=show_success_message_hub,
680
  inputs=[org_name, repo_name],
681
  outputs=[success_message],
682
  )
683
 
684
+ btn_push_to_argilla.click(
685
+ fn=hide_success_message,
686
+ outputs=[success_message],
687
+ ).then(
688
+ fn=push_to_argilla,
689
+ inputs=[final_dataset, dataset_name],
690
+ outputs=[final_dataset],
691
+ show_progress=True,
692
+ ).success(
693
+ fn=show_success_message_argilla,
694
+ inputs=[],
695
+ outputs=[success_message],
696
+ )
697
+
698
  system_prompt.change(
699
  fn=generate_pipeline_code,
700
  inputs=[system_prompt, num_turns, num_rows],
src/distilabel_dataset_generator/pipelines/embeddings.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from sentence_transformers import SentenceTransformer
4
+ from sentence_transformers.models import StaticEmbedding
5
+
6
+ # Initialize a StaticEmbedding module
7
+ static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
8
+ model = SentenceTransformer(modules=[static_embedding])
9
+
10
+
11
+ def get_embeddings(texts: List[str]) -> List[List[float]]:
12
+ return [embedding.tolist() for embedding in model.encode(texts)]
13
+
14
+
15
+ def get_sentence_embedding_dimensions() -> int:
16
+ return model.get_sentence_embedding_dimension()
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
 
 
3
  import gradio as gr
4
  from gradio.oauth import (
5
  OAUTH_CLIENT_ID,
@@ -81,3 +82,12 @@ def swap_visibilty(oauth_token: OAuthToken = None):
81
  return gr.update(elem_classes=["main_ui_logged_in"])
82
  else:
83
  return gr.update(elem_classes=["main_ui_logged_out"])
 
 
 
 
 
 
 
 
 
 
1
  import os
2
 
3
+ import argilla as rg
4
  import gradio as gr
5
  from gradio.oauth import (
6
  OAUTH_CLIENT_ID,
 
82
  return gr.update(elem_classes=["main_ui_logged_in"])
83
  else:
84
  return gr.update(elem_classes=["main_ui_logged_out"])
85
+
86
+
87
+ def get_argilla_client():
88
+ return rg.Argilla(
89
+ api_url=os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
90
+ or os.getenv("ARGILLA_API_URL"),
91
+ api_key=os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
92
+ or os.getenv("ARGILLA_API_KEY"),
93
+ )