wuhp committed on
Commit
54e45c0
·
verified ·
1 Parent(s): dca5fea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -60
app.py CHANGED
@@ -19,7 +19,7 @@ from transformers import (
19
  pipeline,
20
  )
21
 
22
- NUM_EXAMPLES_FOR_FINETUNING = 75 # Constant for the number of examples to use for finetuning
23
  TEXT_PIPELINE = None # Global to store the custom R1 text generation pipeline
24
  COMPARISON_PIPELINE = None # Global to store the official R1 text generation pipeline
25
 
@@ -424,70 +424,72 @@ def chat_rag(
424
  return history, history
425
 
426
 
427
# Build the Gradio interface.
# NOTE(review): reconstructed from a diff-mangled paste; indentation and line
# structure restored to valid Gradio Blocks syntax. Components, labels, and
# event wiring are unchanged from the pasted original.
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
    gr.Markdown("---")

    # --- Optional QLoRA fine-tuning ---
    gr.Markdown("## ⚙️ Fine-tuning (Optional)")
    gr.Markdown("This section allows you to fine-tune the custom R1 model on a small subset of the ServiceNow dataset. This step is optional but can potentially improve the model's performance on ServiceNow-related tasks. **Note:** This process may take up to 5 minutes.")
    finetune_btn = gr.Button("🚀 Start Fine-tuning (QLoRA)")
    status_box = gr.Textbox(label="Fine-tuning Status", interactive=False)
    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
    gr.Markdown("---")

    # --- Direct generation (no retrieval) ---
    gr.Markdown("## ✍️ Direct Generation (No Retrieval)")
    gr.Markdown("Enter a prompt below to generate text directly using the custom R1 model. This is standard text generation without retrieval augmentation.")
    prompt_in = gr.Textbox(lines=3, label="Input Prompt", placeholder="Enter your prompt here...")
    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature (Creativity)")
    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p (Sampling Nucleus)")
    min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
    max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
    output_box = gr.Textbox(label="Custom R1 Output", lines=8, interactive=False)
    gen_btn = gr.Button(" Generate Text")
    gen_btn.click(
        fn=predict,
        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=output_box,
    )
    gr.Markdown("---")

    # --- Custom vs official model comparison (reuses the sliders above) ---
    gr.Markdown("## 🆚 Compare Custom R1 vs Official R1")
    gr.Markdown("Enter a prompt to compare the text generation of your fine-tuned custom R1 model with the official DeepSeek-R1-Distill-Llama-8B model.")
    compare_prompt_in = gr.Textbox(lines=3, label="Comparison Prompt", placeholder="Enter prompt for comparison...")
    compare_btn = gr.Button("⚖️ Compare Models")
    out_custom = gr.Textbox(label="Custom R1 Output", lines=6, interactive=False)
    out_official = gr.Textbox(label="Official R1 Output", lines=6, interactive=False)
    compare_btn.click(
        fn=compare_models,
        inputs=[compare_prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=[out_custom, out_official],
    )
    gr.Markdown("---")

    # --- RAG chat ---
    gr.Markdown("## 💬 Chat with Retrieval-Augmented Memory (RAG)")
    gr.Markdown("Chat with the custom R1 model, enhanced with a retrieval-augmented memory. The model will retrieve relevant information based on your queries to provide more informed responses.")
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(label="RAG Chatbot")
            chat_state = gr.State([])
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Ask a question to the RAG Chatbot...",
                lines=2,
            )
            send_btn = gr.Button("➡️ Send")
            # Enter-to-submit and the Send button both route through chat_rag,
            # which returns (history, history) -> (chat_state, chatbot).
            user_input.submit(
                fn=chat_rag,
                inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
                outputs=[chat_state, chatbot],
            )
            send_btn.click(
                fn=chat_rag,
                inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
                outputs=[chat_state, chatbot],
            )
    gr.Markdown("---")


demo.launch()
 
19
  pipeline,
20
  )
21
 
22
+ NUM_EXAMPLES_FOR_FINETUNING = 50 # Constant for the number of examples to use for finetuning
23
  TEXT_PIPELINE = None # Global to store the custom R1 text generation pipeline
24
  COMPARISON_PIPELINE = None # Global to store the official R1 text generation pipeline
25
 
 
424
  return history, history
425
 
426
 
427
# Build the Gradio interface with tabs.
# FIX(review): the pasted revision wrapped gr.TabbedInterface in a
# `with ....render():` statement (invalid — gr.TabbedInterface is a standalone
# Blocks app, not a nestable layout or context manager) and passed kwargs that
# do not exist: `submit_button_text` on gr.Interface, and `inputs=[...]` /
# `clear_btn=None` on gr.ChatInterface (whose fn contract is (message, history),
# not chat_rag's 6-argument signature). Rebuilt as gr.Tab sections inside a
# single gr.Blocks, keeping the same handlers, components, and labels.
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
    gr.Markdown("---")

    with gr.Tab("⚙️ Fine-tuning (Optional)"):
        gr.Markdown("This section allows you to fine-tune the custom R1 model on a small subset of the ServiceNow dataset. This step is optional but can potentially improve the model's performance on ServiceNow-related tasks. **Note:** This process may take up to 5 minutes.")
        finetune_btn = gr.Button("🚀 Start Fine-tuning (QLoRA)")
        status_box = gr.Textbox(label="Fine-tuning Status", interactive=False)
        # finetune_small_subset takes no UI inputs; it reports progress via the status box.
        finetune_btn.click(fn=finetune_small_subset, outputs=status_box)

    with gr.Tab("✍️ Direct Generation"):
        gr.Markdown("Enter a prompt to generate text directly using the custom R1 model. This is standard text generation without retrieval augmentation.")
        prompt_in = gr.Textbox(lines=3, label="Input Prompt", placeholder="Enter your prompt here...")
        temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature (Creativity)")
        top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p (Sampling Nucleus)")
        min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
        max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
        output_box = gr.Textbox(label="Custom R1 Output", lines=8, interactive=False)
        gen_btn = gr.Button("✨ Generate Text")
        gen_btn.click(
            fn=predict,
            inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
            outputs=output_box,
        )

    with gr.Tab("🆚 Model Comparison"):
        gr.Markdown("Enter a prompt to compare the text generation of your fine-tuned custom R1 model with the official DeepSeek-R1-Distill-Llama-8B model.")
        compare_prompt_in = gr.Textbox(lines=3, label="Comparison Prompt", placeholder="Enter prompt for comparison...")
        # Each tab owns its own sliders (tabs render independently).
        cmp_temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
        cmp_top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
        cmp_min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
        cmp_max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
        compare_btn = gr.Button("⚖️ Compare Models")
        out_custom = gr.Textbox(label="Custom R1 Output", lines=6, interactive=False)
        out_official = gr.Textbox(label="Official R1 Output", lines=6, interactive=False)
        compare_btn.click(
            fn=compare_models,
            inputs=[compare_prompt_in, cmp_temperature, cmp_top_p, cmp_min_tokens, cmp_max_tokens],
            outputs=[out_custom, out_official],
        )

    with gr.Tab("💬 RAG Chat"):
        gr.Markdown("Chat with the custom R1 model, enhanced with retrieval-augmented memory. The model retrieves relevant info for informed responses.")
        chatbot = gr.Chatbot(label="RAG Chatbot")
        chat_state = gr.State([])
        user_input = gr.Textbox(placeholder="Ask a question to the RAG Chatbot...", lines=2, show_label=False)
        chat_temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
        chat_top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
        chat_min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
        chat_max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
        send_btn = gr.Button("➡️ Send")
        # chat_rag returns (history, history) -> (chat_state, chatbot); wire both
        # Enter-to-submit and the Send button to the same handler.
        user_input.submit(
            fn=chat_rag,
            inputs=[user_input, chat_state, chat_temperature, chat_top_p, chat_min_tokens, chat_max_tokens],
            outputs=[chat_state, chatbot],
        )
        send_btn.click(
            fn=chat_rag,
            inputs=[user_input, chat_state, chat_temperature, chat_top_p, chat_min_tokens, chat_max_tokens],
            outputs=[chat_state, chatbot],
        )


demo.launch()