Update app.py
app.py CHANGED
@@ -237,7 +237,7 @@ def compare_models(
     Args:
         prompt (str): The input prompt for text generation.
         temperature (float): Sampling temperature.
-        top_p (float):
+        top_p (float): Top-p sampling probability.
         min_new_tokens (int): Minimum number of new tokens to generate.
         max_new_tokens (int): Maximum number of new tokens to generate.
 
@@ -385,7 +385,7 @@ def chat_rag(
     Args:
         user_input (str): The user's chat input.
         history (list[list[str]]): The chat history.
         temperature (float): Sampling temperature.
-        top_p (float):
+        top_p (float): Top-p sampling probability.
         min_new_tokens (int): Minimum number of new tokens to generate.
         max_new_tokens (int): Maximum number of new tokens to generate.
 
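Both hunks above fill in the previously empty `top_p` docstring entry. For context, here is a minimal sketch of how these four sampling arguments typically map onto Hugging Face `generate` kwargs. The model id is the official model named in the comparison section of this diff; the loading code is an illustrative assumption, not app.py's actual setup, and the default values mirror the sliders wired below.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: the diff compares against this official model,
# but app.py's real model-loading code is not shown in this commit.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Summarize this ServiceNow incident:", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=True,       # temperature/top_p only take effect when sampling
    temperature=0.7,      # sampling temperature
    top_p=0.9,            # top-p (nucleus) sampling probability
    min_new_tokens=50,    # minimum number of new tokens to generate
    max_new_tokens=200,   # maximum number of new tokens to generate
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

With `top_p=0.9`, each decoding step samples only from the smallest set of tokens whose cumulative probability reaches 0.9, keeping low-probability tail tokens out of the output.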
@@ -424,60 +424,70 @@ def chat_rag(
     return history, history
 
 
-# Build the Gradio interface
-with gr.Blocks(
-gr.Markdown("
+# Build the Gradio interface.
+with gr.Blocks() as demo:
+    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
+    gr.Markdown("---")
+
+    gr.Markdown("## ⚙️ Fine-tuning (Optional)")
+    gr.Markdown("This section allows you to fine-tune the custom R1 model on a small subset of the ServiceNow dataset. This step is optional but can potentially improve the model's performance on ServiceNow-related tasks. **Note:** This process may take up to 5 minutes.")
+    finetune_btn = gr.Button("🚀 Start Fine-tuning (QLoRA)")
+    status_box = gr.Textbox(label="Fine-tuning Status", interactive=False)
+    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
+    gr.Markdown("---")
+
+    gr.Markdown("## ✍️ Direct Generation (No Retrieval)")
+    gr.Markdown("Enter a prompt below to generate text directly using the custom R1 model. This is standard text generation without retrieval augmentation.")
+    prompt_in = gr.Textbox(lines=3, label="Input Prompt", placeholder="Enter your prompt here...")
+    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature (Creativity)")
+    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p (Sampling Nucleus)")
+    min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
+    max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
+    output_box = gr.Textbox(label="Custom R1 Output", lines=8, interactive=False)
+    gen_btn = gr.Button("✨ Generate Text")
+    gen_btn.click(
+        fn=predict,
+        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
+        outputs=output_box
+    )
+    gr.Markdown("---")
+
+    gr.Markdown("## 🆚 Compare Custom R1 vs Official R1")
+    gr.Markdown("Enter a prompt to compare the text generation of your fine-tuned custom R1 model with the official DeepSeek-R1-Distill-Llama-8B model.")
+    compare_prompt_in = gr.Textbox(lines=3, label="Comparison Prompt", placeholder="Enter prompt for comparison...")
+    compare_btn = gr.Button("⚖️ Compare Models")
+    out_custom = gr.Textbox(label="Custom R1 Output", lines=6, interactive=False)
+    out_official = gr.Textbox(label="Official R1 Output", lines=6, interactive=False)
+    compare_btn.click(
+        fn=compare_models,
+        inputs=[compare_prompt_in, temperature, top_p, min_tokens, max_tokens],
+        outputs=[out_custom, out_official]
+    )
+    gr.Markdown("---")
+
+    gr.Markdown("## 💬 Chat with Retrieval-Augmented Memory (RAG)")
+    gr.Markdown("Chat with the custom R1 model, enhanced with a retrieval-augmented memory. The model will retrieve relevant information based on your queries to provide more informed responses.")
+    with gr.Row():
+        with gr.Column():
+            chatbot = gr.Chatbot(label="RAG Chatbot")
+            chat_state = gr.State([])
+            user_input = gr.Textbox(
+                show_label=False,
+                placeholder="Ask a question to the RAG Chatbot...",
+                lines=2
+            )
+            send_btn = gr.Button("➡️ Send")
+            user_input.submit(
+                fn=chat_rag,
+                inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
+                outputs=[chat_state, chatbot]
+            )
+            send_btn.click(
+                fn=chat_rag,
+                inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
+                outputs=[chat_state, chatbot]
+            )
+    gr.Markdown("---")
+
 
 demo.launch()
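The callbacks wired above (`finetune_small_subset`, `predict`, `compare_models`, `chat_rag`) are defined earlier in app.py and do not appear in this diff. As a reading aid, here are hypothetical stubs whose signatures follow from the `inputs=`/`outputs=` lists and the docstrings patched above; the bodies are placeholders, not the app's real implementation.

```python
# Hypothetical stubs inferred from the Gradio wiring; bodies are placeholders.

def predict(prompt: str, temperature: float, top_p: float,
            min_new_tokens: int, max_new_tokens: int) -> str:
    """One string out -> fills output_box."""
    ...

def compare_models(prompt: str, temperature: float, top_p: float,
                   min_new_tokens: int, max_new_tokens: int) -> tuple[str, str]:
    """Two strings out -> (out_custom, out_official)."""
    ...

def chat_rag(user_input: str, history: list[list[str]], temperature: float,
             top_p: float, min_new_tokens: int,
             max_new_tokens: int) -> tuple[list[list[str]], list[list[str]]]:
    """Returns the updated history twice -> (chat_state, chatbot),
    matching the `return history, history` context line in the diff."""
    ...
```

Note that `user_input.submit` and `send_btn.click` call `chat_rag` with identical inputs and outputs, so pressing Enter and clicking Send behave the same way.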