dvilasuero HF staff commited on
Commit
9d1a2d6
1 Parent(s): 8571d5a

Update src/distilabel_dataset_generator/sft.py

Browse files
src/distilabel_dataset_generator/sft.py CHANGED
@@ -300,6 +300,48 @@ def generate_dataset(
300
  return pd.DataFrame(outputs)
301
 
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  with gr.Blocks(
304
  title="⚗️ Distilabel Dataset Generator",
305
  head="⚗️ Distilabel Dataset Generator",
@@ -357,6 +399,8 @@ with gr.Blocks(
357
  gr.Markdown(
358
  "Once you're satisfied with the sample, generate a larger dataset and push it to the hub. Get <a href='https://huggingface.co/settings/tokens' target='_blank'>a Hugging Face token</a> with write access to the organization you want to push the dataset to."
359
  )
 
 
360
  with gr.Column() as push_to_hub_ui:
361
  with gr.Row(variant="panel"):
362
  num_turns = gr.Number(
@@ -384,9 +428,40 @@ with gr.Blocks(
384
  value="⚗️ Generate Full Dataset", variant="primary"
385
  )
386
 
387
- btn_generate_full_dataset.click(
388
- fn=generate_dataset,
389
- inputs=[system_prompt, num_turns, num_rows, private, repo_id, hf_token],
390
- outputs=[table],
391
- show_progress=True,
392
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  return pd.DataFrame(outputs)
301
 
302
 
303
+ def generate_pipeline_code(system_prompt):
304
+ code = f"""
305
+ from distilabel.pipeline import Pipeline
306
+ from distilabel.steps import KeepColumns
307
+ from distilabel.steps.tasks import MagpieGenerator
308
+ from distilabel.llms import InferenceEndpointsLLM
309
+
310
+ MODEL = "{MODEL}"
311
+ SYSTEM_PROMPT = "{system_prompt}"
312
+
313
+ with Pipeline(name="sft") as pipeline:
314
+ magpie = MagpieGenerator(
315
+ llm=InferenceEndpointsLLM(
316
+ model_id=MODEL,
317
+ tokenizer_id=MODEL,
318
+ magpie_pre_query_template="llama3",
319
+ generation_kwargs={{
320
+ "temperature": 0.8,
321
+ "do_sample": True,
322
+ "max_new_tokens": 2048,
323
+ "stop_sequences": [
324
+ "<|eot_id|>",
325
+ "<|end_of_text|>",
326
+ "<|start_header_id|>",
327
+ "<|end_header_id|>",
328
+ "assistant",
329
+ ],
330
+ }}
331
+ ),
332
+ n_turns=1,
333
+ num_rows=100,
334
+ system_prompt=SYSTEM_PROMPT,
335
+ )
336
+
337
+ if __name__ == "__main__":
338
+ distiset = pipeline.run()
339
+ """
340
+ return code
341
+
342
+ def update_pipeline_code(system_prompt):
343
+ return generate_pipeline_code(system_prompt)
344
+
345
  with gr.Blocks(
346
  title="⚗️ Distilabel Dataset Generator",
347
  head="⚗️ Distilabel Dataset Generator",
 
399
  gr.Markdown(
400
  "Once you're satisfied with the sample, generate a larger dataset and push it to the hub. Get <a href='https://huggingface.co/settings/tokens' target='_blank'>a Hugging Face token</a> with write access to the organization you want to push the dataset to."
401
  )
402
+
403
+
404
  with gr.Column() as push_to_hub_ui:
405
  with gr.Row(variant="panel"):
406
  num_turns = gr.Number(
 
428
  value="⚗️ Generate Full Dataset", variant="primary"
429
  )
430
 
431
+ # Add this line here, before the button click event
432
+ success_message = gr.Markdown(visible=False)
433
+
434
+ def show_success_message(repo_id_value):
435
+ return gr.update(value=f"""
436
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
437
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
438
+ <p style="margin-top: 0.5em;">
439
+ Your dataset is now available at:
440
+ <a href="https://huggingface.co/datasets/{repo_id_value}" target="_blank" style="color: #1565c0; text-decoration: none;">
441
+ https://huggingface.co/datasets/{repo_id_value}
442
+ </a>
443
+ </p>
444
+ </div>
445
+ """, visible=True)
446
+
447
+ btn_generate_full_dataset.click(
448
+ fn=generate_dataset,
449
+ inputs=[system_prompt, num_turns, num_rows, private, repo_id, hf_token],
450
+ outputs=[table],
451
+ show_progress=True,
452
+ ).then(
453
+ fn=show_success_message,
454
+ inputs=[repo_id],
455
+ outputs=[success_message]
456
+ )
457
+
458
+ gr.Markdown("## Or run this pipeline locally with distilabel")
459
+
460
+ with gr.Accordion("Run this pipeline on Distilabel", open=False):
461
+ pipeline_code = gr.Code(language="python", label="Distilabel Pipeline Code")
462
+
463
+ system_prompt.change(
464
+ fn=update_pipeline_code,
465
+ inputs=[system_prompt],
466
+ outputs=[pipeline_code],
467
+ )