Commit 8571d5a (1 parent: 2eb6d1a) by davidberenstein1957

feat: small quality updates

src/distilabel_dataset_generator/sft.py CHANGED
@@ -8,14 +8,7 @@ from distilabel.llms import InferenceEndpointsLLM
 from distilabel.pipeline import Pipeline
 from distilabel.steps import KeepColumns
 from distilabel.steps.tasks import MagpieGenerator, TextGeneration
-
-from src.distilabel_dataset_generator.utils import (
-    OAuthToken,
-    get_duplicate_button,
-    get_login_button,
-    get_org_dropdown,
-    swap_visibilty,
-)
+from huggingface_hub import whoami

 INFORMATION_SEEKING_PROMPT = (
     "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -180,7 +173,7 @@ def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str =
     result_queue.put(distiset)


-def generate_system_prompt(dataset_description, token: OAuthToken = None, progress=gr.Progress()):
+def generate_system_prompt(dataset_description, progress=gr.Progress()):
     progress(0.1, desc="Initializing text generation")
     generate_description = TextGeneration(
         llm=InferenceEndpointsLLM(
@@ -230,10 +223,13 @@ def generate_dataset(
     if repo_id is not None:
         if not repo_id:
             raise gr.Error("Please provide a dataset name to push the dataset to.")
-        if token is None:
+        try:
+            whoami(token=token)
+        except Exception:
             raise gr.Error(
-                "Please sign in with Hugging Face to be able to push the dataset to the Hub."
+                "Provide a Hugging Face token to be able to push the dataset to the Hub."
             )
+
     if num_turns > 4:
         raise gr.Info(
             "You can only generate a dataset with 4 or fewer turns. Setting to 4."
@@ -263,20 +259,22 @@ def generate_dataset(
         target=_run_pipeline,
         args=(result_queue, num_turns, num_rows, system_prompt),
     )
-
+
     try:
         p.start()
         total_steps = 100
         for step in range(total_steps):
            if not p.is_alive():
                break
-            progress((step + 1) / total_steps, desc=f"Generating dataset with {num_rows} rows")
+            progress(
+                (step + 1) / total_steps,
+                desc=f"Generating dataset with {num_rows} rows",
+            )
            time.sleep(0.5)  # Adjust this value based on your needs
        p.join()
    except Exception as e:
        raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
-
-
+
    distiset = result_queue.get()

    if repo_id is not None:
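The reflowed progress call above runs inside a poll-and-sleep loop that watches the pipeline process; a self-contained sketch of that pattern, with a dummy worker and print standing in for _run_pipeline and gr.Progress (both stand-ins are assumptions, not code from the file):

import multiprocessing
import time


def _worker(queue):
    # Stand-in for the distilabel pipeline run performed by _run_pipeline.
    time.sleep(5)
    queue.put("distiset")


if __name__ == "__main__":
    result_queue = multiprocessing.Queue()
    p = multiprocessing.Process(target=_worker, args=(result_queue,))
    p.start()

    total_steps = 100
    for step in range(total_steps):
        if not p.is_alive():
            break
        # gr.Progress would be updated here; print keeps the sketch dependency-free.
        print(f"progress: {(step + 1) / total_steps:.0%}")
        time.sleep(0.5)
    p.join()

    print(result_queue.get())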
@@ -290,20 +288,13 @@ def generate_dataset(
         gr.Info(
             f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
         )
-
+
     # If not pushing to hub generate the dataset directly
     distiset = distiset["default"]["train"]
     if num_turns == 1:
         outputs = distiset.to_pandas()[["prompt", "completion"]]
     else:
         outputs = distiset.to_pandas()[["messages"]]
-        # outputs = {"conversation_id": [], "role": [], "content": []}
-        # conversations = distiset["messages"]
-        # for idx, entry in enumerate(conversations):
-        #     for message in entry["messages"]:
-        #         outputs["conversation_id"].append(idx + 1)
-        #         outputs["role"].append(message["role"])
-        #         outputs["content"].append(message["content"])

     progress(1.0, desc="Dataset generation completed")
     return pd.DataFrame(outputs)
@@ -320,9 +311,7 @@ with gr.Blocks(
     )
     with gr.Row():
         gr.Column(scale=1)
-        btn_generate_system_prompt = gr.Button(
-            value="Generate sample dataset"
-        )
+        btn_generate_system_prompt = gr.Button(value="Generate sample dataset")
         gr.Column(scale=1)

         system_prompt = gr.TextArea(
@@ -337,12 +326,12 @@ with gr.Blocks(
         )
         gr.Column(scale=1)

-    table = gr.DataFrame(
-        value=DEFAULT_DATASET,
-        interactive=False,
-        wrap=True,
-
-    )
+    with gr.Row():
+        table = gr.DataFrame(
+            value=DEFAULT_DATASET,
+            interactive=False,
+            wrap=True,
+        )

     result = btn_generate_system_prompt.click(
         fn=generate_system_prompt,
@@ -362,17 +351,21 @@ with gr.Blocks(
         outputs=[table],
         show_progress=True,
     )
-
-    # Add a header for the full dataset generation section
+
+    # Add a header for the full dataset generation section
     gr.Markdown("## Generate full dataset")
-    gr.Markdown("Once you're satisfied with the sample, generate a larger dataset and push it to the hub.")
+    gr.Markdown(
+        "Once you're satisfied with the sample, generate a larger dataset and push it to the hub. Get <a href='https://huggingface.co/settings/tokens' target='_blank'>a Hugging Face token</a> with write access to the organization you want to push the dataset to."
+    )
     with gr.Column() as push_to_hub_ui:
         with gr.Row(variant="panel"):
             num_turns = gr.Number(
                 value=1,
                 label="Number of turns in the conversation",
+                minimum=1,
                 maximum=4,
-                info="Whether the dataset is for a single turn with 'instruction-response' columns or a multi-turn conversation with a 'conversation' column.",
+                step=1,
+                info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'conversation' column).",
             )
             num_rows = gr.Number(
                 value=100,
@@ -381,10 +374,9 @@ with gr.Blocks(
                 maximum=5000,
                 info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
             )
-

         with gr.Row(variant="panel"):
-            hf_token = gr.Textbox(label="HF token")
+            hf_token = gr.Textbox(label="HF token", type="password")
             repo_id = gr.Textbox(label="HF repo ID", placeholder="owner/dataset_name")
             private = gr.Checkbox(label="Private dataset", value=True, interactive=True)

@@ -394,13 +386,7 @@ with gr.Blocks(

     btn_generate_full_dataset.click(
         fn=generate_dataset,
-        inputs=[
-            system_prompt,
-            num_turns,
-            num_rows,
-            private,
-            repo_id,
-        ],
+        inputs=[system_prompt, num_turns, num_rows, private, repo_id, hf_token],
         outputs=[table],
         show_progress=True,
     )
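The UI hunks above bound the turn count, mask the token field, and pass the token through the click handler; a minimal Gradio sketch of just those two widgets (the surrounding Blocks/launch scaffolding is illustrative, not taken from the file):

import gradio as gr

with gr.Blocks() as demo:
    # Bounded integer input: 1 turn means 'instruction-response' columns, 2-4 a multi-turn 'conversation' column.
    num_turns = gr.Number(
        value=1,
        label="Number of turns in the conversation",
        minimum=1,
        maximum=4,
        step=1,
    )
    # type="password" masks the token in the UI instead of rendering it as plain text.
    hf_token = gr.Textbox(label="HF token", type="password")

if __name__ == "__main__":
    demo.launch()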
 