Spaces:
wuhp
/
Running on Zero

wuhp committed on
Commit
7014802
·
verified ·
1 Parent(s): e2ec65a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -57
app.py CHANGED
@@ -237,7 +237,7 @@ def compare_models(
237
  Args:
238
  prompt (str): The input prompt for text generation.
239
  temperature (float): Sampling temperature.
240
- top_p (float): Sampling top-p.
241
  min_new_tokens (int): Minimum number of new tokens to generate.
242
  max_new_tokens (int): Maximum number of new tokens to generate.
243
 
@@ -385,7 +385,7 @@ def chat_rag(
385
  user_input (str): The user's chat input.
386
  history (list[list[str]]): The chat history.
387
  temperature (float): Sampling temperature.
388
- top_p (float): Sampling top-p.
389
  min_new_tokens (int): Minimum number of new tokens to generate.
390
  max_new_tokens (int): Maximum number of new tokens to generate.
391
 
@@ -424,60 +424,70 @@ def chat_rag(
424
  return history, history
425
 
426
 
427
- # Build the Gradio interface with tabs.
428
- with gr.Blocks(css="""
429
- body {background-color: #f5f5f5; font-family: Arial, sans-serif;}
430
- .gradio-container {max-width: 1000px; margin: auto; background: white; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);}
431
- h1 {color: #333; text-align: center; font-size: 2rem;}
432
- h2 {color: #444; margin-top: 10px; font-size: 1.5rem;}
433
- .gr-tab {padding: 10px;}
434
- """) as demo:
435
-
436
- gr.Markdown("# 🚀 QLoRA Fine-tuning & RAG Chat Demo")
437
- gr.Markdown("Welcome to the enhanced **QLoRA fine-tuning and RAG-based chatbot interface**. This tool lets you fine-tune an AI model, generate text, and interact with a chatbot using retrieval-augmented responses.")
438
-
439
- with gr.TabbedInterface():
440
-
441
- # Fine-tuning tab
442
- with gr.Tab(label="⚙️ Fine-tune Model"):
443
- gr.Markdown("### Train your custom R1 model")
444
- gr.Markdown("Fine-tune the model using QLoRA. This is **optional**, but recommended for better performance.")
445
- finetune_btn = gr.Button("Start Fine-tuning")
446
- finetune_output = gr.Textbox(label="Status", interactive=False)
447
- finetune_btn.click(finetune_small_subset, inputs=None, outputs=finetune_output)
448
-
449
- # Text Generation tab
450
- with gr.Tab(label="✍️ Text Generation"):
451
- gr.Markdown("### Generate text using your fine-tuned model")
452
- input_prompt = gr.Textbox(label="Enter Prompt", placeholder="Type something here...", lines=3)
453
- temp_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
454
- topp_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
455
- min_tokens = gr.Slider(1, 1000, value=50, step=10, label="Min New Tokens")
456
- max_tokens = gr.Slider(1, 1000, value=200, step=10, label="Max New Tokens")
457
- generate_btn = gr.Button("Generate Text")
458
- output_box = gr.Textbox(label="Generated Output", lines=8, interactive=False)
459
- generate_btn.click(predict, inputs=[input_prompt, temp_slider, topp_slider, min_tokens, max_tokens], outputs=output_box)
460
-
461
- # Model Comparison tab
462
- with gr.Tab(label="🆚 Compare Models"):
463
- gr.Markdown("### Compare text outputs from your fine-tuned model and the official model")
464
- compare_prompt = gr.Textbox(label="Enter Comparison Prompt", placeholder="Enter a prompt here...", lines=3)
465
- compare_temp = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
466
- compare_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
467
- compare_min_tokens = gr.Slider(1, 1000, value=50, step=10, label="Min New Tokens")
468
- compare_max_tokens = gr.Slider(1, 1000, value=200, step=10, label="Max New Tokens")
469
- compare_btn = gr.Button("Compare Models")
470
- compare_output1 = gr.Textbox(label="Custom Model Output", lines=6, interactive=False)
471
- compare_output2 = gr.Textbox(label="Official Model Output", lines=6, interactive=False)
472
- compare_btn.click(compare_models, inputs=[compare_prompt, compare_temp, compare_topp, compare_min_tokens, compare_max_tokens], outputs=[compare_output1, compare_output2])
473
-
474
- # Chatbot tab
475
- with gr.Tab(label="💬 AI Chatbot"):
476
- gr.Markdown("### Chat with an AI assistant using retrieval-augmented generation (RAG)")
477
- chatbot = gr.Chatbot(label="AI Chatbot", height=400)
478
- chat_input = gr.Textbox(placeholder="Ask me anything...", lines=2)
479
- chat_btn = gr.Button("Send")
480
- chat_output = gr.Chatbot(label="Chat History")
481
- chat_btn.click(chat_rag, inputs=[chat_input, chatbot, temp_slider, topp_slider, min_tokens, max_tokens], outputs=[chat_output, chatbot])
 
 
 
 
 
 
 
 
 
 
482
 
483
  demo.launch()
 
237
  Args:
238
  prompt (str): The input prompt for text generation.
239
  temperature (float): Sampling temperature.
240
+ top_p (float): Top-p sampling probability.
241
  min_new_tokens (int): Minimum number of new tokens to generate.
242
  max_new_tokens (int): Maximum number of new tokens to generate.
243
 
 
385
  user_input (str): The user's chat input.
386
  history (list[list[str]]): The chat history.
387
  temperature (float): Sampling temperature.
388
+ top_p (float): Top-p sampling probability.
389
  min_new_tokens (int): Minimum number of new tokens to generate.
390
  max_new_tokens (int): Maximum number of new tokens to generate.
391
 
 
424
  return history, history
425
 
426
 
427
+ # Build the Gradio interface.
428
+ with gr.Blocks() as demo:
429
+ gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
430
+ gr.Markdown("---")
431
+
432
+ gr.Markdown("## ⚙️ Fine-tuning (Optional)")
433
+ gr.Markdown("This section allows you to fine-tune the custom R1 model on a small subset of the ServiceNow dataset. This step is optional but can potentially improve the model's performance on ServiceNow-related tasks. **Note:** This process may take up to 5 minutes.")
434
+ finetune_btn = gr.Button("🚀 Start Fine-tuning (QLoRA)")
435
+ status_box = gr.Textbox(label="Fine-tuning Status", interactive=False)
436
+ finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
437
+ gr.Markdown("---")
438
+
439
+ gr.Markdown("## ✍️ Direct Generation (No Retrieval)")
440
+ gr.Markdown("Enter a prompt below to generate text directly using the custom R1 model. This is standard text generation without retrieval augmentation.")
441
+ prompt_in = gr.Textbox(lines=3, label="Input Prompt", placeholder="Enter your prompt here...")
442
+ temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature (Creativity)")
443
+ top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p (Sampling Nucleus)")
444
+ min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
445
+ max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
446
+ output_box = gr.Textbox(label="Custom R1 Output", lines=8, interactive=False)
447
+ gen_btn = gr.Button("✨ Generate Text")
448
+ gen_btn.click(
449
+ fn=predict,
450
+ inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
451
+ outputs=output_box
452
+ )
453
+ gr.Markdown("---")
454
+
455
+ gr.Markdown("## 🆚 Compare Custom R1 vs Official R1")
456
+ gr.Markdown("Enter a prompt to compare the text generation of your fine-tuned custom R1 model with the official DeepSeek-R1-Distill-Llama-8B model.")
457
+ compare_prompt_in = gr.Textbox(lines=3, label="Comparison Prompt", placeholder="Enter prompt for comparison...")
458
+ compare_btn = gr.Button("⚖️ Compare Models")
459
+ out_custom = gr.Textbox(label="Custom R1 Output", lines=6, interactive=False)
460
+ out_official = gr.Textbox(label="Official R1 Output", lines=6, interactive=False)
461
+ compare_btn.click(
462
+ fn=compare_models,
463
+ inputs=[compare_prompt_in, temperature, top_p, min_tokens, max_tokens],
464
+ outputs=[out_custom, out_official]
465
+ )
466
+ gr.Markdown("---")
467
+
468
+ gr.Markdown("## 💬 Chat with Retrieval-Augmented Memory (RAG)")
469
+ gr.Markdown("Chat with the custom R1 model, enhanced with a retrieval-augmented memory. The model will retrieve relevant information based on your queries to provide more informed responses.")
470
+ with gr.Row():
471
+ with gr.Column():
472
+ chatbot = gr.Chatbot(label="RAG Chatbot")
473
+ chat_state = gr.State([])
474
+ user_input = gr.Textbox(
475
+ show_label=False,
476
+ placeholder="Ask a question to the RAG Chatbot...",
477
+ lines=2
478
+ )
479
+ send_btn = gr.Button("➡️ Send")
480
+ user_input.submit(
481
+ fn=chat_rag,
482
+ inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
483
+ outputs=[chat_state, chatbot]
484
+ )
485
+ send_btn.click(
486
+ fn=chat_rag,
487
+ inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
488
+ outputs=[chat_state, chatbot]
489
+ )
490
+ gr.Markdown("---")
491
+
492
 
493
  demo.launch()