Update app.py
app.py CHANGED
```diff
@@ -23,12 +23,12 @@ from sentence_transformers import SentenceTransformer
 # Global variables for pipelines and settings.
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
-NUM_EXAMPLES =
+NUM_EXAMPLES = 100
 
 @spaces.GPU(duration=300)
 def finetune_small_subset():
     """
-    Fine-tunes the custom
+    Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
     Steps:
     1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
     2) Applies 4-bit quantization and prepares for QLoRA training.
```
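The hunk only shows the head of `finetune_small_subset()`, so the quantization and QLoRA setup the docstring describes is not visible here. As rough orientation, step 2 typically looks like the sketch below, assuming the standard `transformers`/`peft`/`bitsandbytes` APIs; the LoRA hyperparameters and `target_modules` are illustrative assumptions, not values taken from app.py.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization config, the usual starting point for QLoRA.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Step 1 from the docstring: load from "wuhp/myr1" via the "myr1" subfolder.
model = AutoModelForCausalLM.from_pretrained(
    "wuhp/myr1",
    subfolder="myr1",
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map="auto",
)

# Step 2: prepare the quantized model and attach LoRA adapters.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,                 # assumed rank, not visible in the diff
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```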
```diff
@@ -163,7 +163,7 @@ def ensure_pipeline():
 
 def ensure_comparison_pipeline():
     """
-    Loads
+    Loads the official R1 model pipeline if not already loaded.
     """
     global COMPARISON_PIPELINE
     if COMPARISON_PIPELINE is None:
```
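The body is cut off after the `if`. A lazy-loading helper of this shape would plausibly continue as sketched below; the repo id is a placeholder assumption, since the checkpoint app.py actually loads is not visible in this hunk.

```python
from transformers import pipeline

COMPARISON_PIPELINE = None

# Hypothetical repo id; the official R1 checkpoint used in app.py is not shown.
OFFICIAL_R1_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

def ensure_comparison_pipeline():
    """Loads the official R1 model pipeline if not already loaded."""
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
        # Build the pipeline once and cache it in the module-level global.
        COMPARISON_PIPELINE = pipeline(
            "text-generation",
            model=OFFICIAL_R1_ID,
            device_map="auto",
        )
    return COMPARISON_PIPELINE
```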
```diff
@@ -180,7 +180,7 @@ def ensure_comparison_pipeline():
 @spaces.GPU(duration=120)
 def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Direct generation without retrieval.
+    Direct generation without retrieval using the custom R1 model.
     """
     pipe = ensure_pipeline()
     out = pipe(
```
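The `pipe(` call is truncated in the diff. Given the sliders wired up in the UI below, a plausible completion looks like this sketch; the exact keyword arguments and return handling in app.py are assumptions, and the `gpt2` checkpoint is only a stand-in for `ensure_pipeline()`.

```python
from transformers import pipeline

# Stand-in for app.py's ensure_pipeline(); any causal-LM checkpoint works here.
pipe = pipeline("text-generation", model="gpt2")

def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """Direct generation without retrieval (assumed completion of the call)."""
    out = pipe(
        prompt,
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
    )
    # text-generation pipelines return a list of dicts.
    return out[0]["generated_text"]

print(predict("Hello", 0.7, 0.9, 5, 40))
```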
```diff
@@ -196,7 +196,7 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
 @spaces.GPU(duration=120)
 def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Compare outputs between your custom model and
+    Compare outputs between your custom R1 model and the official R1 model.
     """
     local_pipe = ensure_pipeline()
     comp_pipe = ensure_comparison_pipeline()
```
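Only the first two lines of the body appear in the hunk. Since the UI in the next hunk wires this function to two textboxes, it presumably runs both pipelines and returns a pair; the sketch below illustrates that contract with assumed generation kwargs and stand-in checkpoints.

```python
from transformers import pipeline

# Stand-ins for ensure_pipeline()/ensure_comparison_pipeline();
# app.py's actual checkpoints are not visible in this hunk.
local_pipe = pipeline("text-generation", model="gpt2")
comp_pipe = pipeline("text-generation", model="distilgpt2")

def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """Run both models on the same prompt and return their outputs as a pair."""
    gen_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
    )
    local_out = local_pipe(prompt, **gen_kwargs)[0]["generated_text"]
    official_out = comp_pipe(prompt, **gen_kwargs)[0]["generated_text"]
    # Returning a 2-tuple lets Gradio fill outputs=[out_custom, out_official].
    return local_out, official_out
```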
```diff
@@ -299,34 +299,34 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens
 
 # Build the Gradio interface.
 with gr.Blocks() as demo:
-    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom
+    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
 
     finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
     status_box = gr.Textbox(label="Finetune Status")
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    gr.Markdown("## Direct Generation (No Retrieval)")
+    gr.Markdown("## Direct Generation (No Retrieval) using Custom R1")
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
     top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
     min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
     max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
-    output_box = gr.Textbox(label="
-    gen_btn = gr.Button("Generate with
+    output_box = gr.Textbox(label="Custom R1 Output", lines=8)
+    gen_btn = gr.Button("Generate with Custom R1")
     gen_btn.click(
         fn=predict,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
         outputs=output_box
     )
 
-    gr.Markdown("## Compare
+    gr.Markdown("## Compare Custom R1 vs Official R1")
     compare_btn = gr.Button("Compare")
-
-
+    out_custom = gr.Textbox(label="Custom R1 Output", lines=6)
+    out_official = gr.Textbox(label="Official R1 Output", lines=6)
     compare_btn.click(
         fn=compare_models,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
-        outputs=[
+        outputs=[out_custom, out_official]
     )
 
     gr.Markdown("## Chat with Retrieval-Augmented Memory")
```
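Gradio maps a function's return values positionally onto its `outputs` list, which is why the new `outputs=[out_custom, out_official]` requires `compare_models` to return a 2-tuple. A minimal, self-contained illustration of that wiring, using a trivial stand-in for `compare_models`:

```python
import gradio as gr

# Hypothetical stand-in: returns one string per output component.
def fake_compare(prompt):
    return f"custom: {prompt}", f"official: {prompt}"

with gr.Blocks() as mini:
    p = gr.Textbox(label="Prompt")
    a = gr.Textbox(label="Custom R1 Output")
    b = gr.Textbox(label="Official R1 Output")
    # The two returned strings land in a and b, in order.
    gr.Button("Compare").click(fn=fake_compare, inputs=p, outputs=[a, b])

mini.launch()
```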