Spaces:
wuhp
/
Running on Zero

wuhp committed on
Commit
4df7266
·
verified ·
1 Parent(s): c8cf005

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -23,12 +23,12 @@ from sentence_transformers import SentenceTransformer
23
  # Global variables for pipelines and settings.
24
  TEXT_PIPELINE = None
25
  COMPARISON_PIPELINE = None
26
- NUM_EXAMPLES = 50
27
 
28
  @spaces.GPU(duration=300)
29
  def finetune_small_subset():
30
  """
31
- Fine-tunes the custom DeepSeekV3 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
32
  Steps:
33
  1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
34
  2) Applies 4-bit quantization and prepares for QLoRA training.
@@ -163,7 +163,7 @@ def ensure_pipeline():
163
 
164
  def ensure_comparison_pipeline():
165
  """
166
- Loads a reference DeepSeek model pipeline if not already loaded.
167
  """
168
  global COMPARISON_PIPELINE
169
  if COMPARISON_PIPELINE is None:
@@ -180,7 +180,7 @@ def ensure_comparison_pipeline():
180
  @spaces.GPU(duration=120)
181
  def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
182
  """
183
- Direct generation without retrieval.
184
  """
185
  pipe = ensure_pipeline()
186
  out = pipe(
@@ -196,7 +196,7 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
196
  @spaces.GPU(duration=120)
197
  def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
198
  """
199
- Compare outputs between your custom model and a reference DeepSeek model.
200
  """
201
  local_pipe = ensure_pipeline()
202
  comp_pipe = ensure_comparison_pipeline()
@@ -299,34 +299,34 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_to
299
 
300
  # Build the Gradio interface.
301
  with gr.Blocks() as demo:
302
- gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom DeepSeekV3 Model")
303
 
304
  finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
305
  status_box = gr.Textbox(label="Finetune Status")
306
  finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
307
 
308
- gr.Markdown("## Direct Generation (No Retrieval)")
309
  prompt_in = gr.Textbox(lines=3, label="Prompt")
310
  temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
311
  top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
312
  min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
313
  max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
314
- output_box = gr.Textbox(label="DeepSeekV3 Output", lines=8)
315
- gen_btn = gr.Button("Generate with DeepSeekV3")
316
  gen_btn.click(
317
  fn=predict,
318
  inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
319
  outputs=output_box
320
  )
321
 
322
- gr.Markdown("## Compare DeepSeekV3 vs Reference DeepSeek")
323
  compare_btn = gr.Button("Compare")
324
- out_local = gr.Textbox(label="DeepSeekV3 Output", lines=6)
325
- out_deepseek = gr.Textbox(label="Reference DeepSeek Output", lines=6)
326
  compare_btn.click(
327
  fn=compare_models,
328
  inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
329
- outputs=[out_local, out_deepseek]
330
  )
331
 
332
  gr.Markdown("## Chat with Retrieval-Augmented Memory")
 
23
  # Global variables for pipelines and settings.
24
  TEXT_PIPELINE = None
25
  COMPARISON_PIPELINE = None
26
+ NUM_EXAMPLES = 100
27
 
28
  @spaces.GPU(duration=300)
29
  def finetune_small_subset():
30
  """
31
+ Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
32
  Steps:
33
  1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
34
  2) Applies 4-bit quantization and prepares for QLoRA training.
 
163
 
164
  def ensure_comparison_pipeline():
165
  """
166
+ Loads the official R1 model pipeline if not already loaded.
167
  """
168
  global COMPARISON_PIPELINE
169
  if COMPARISON_PIPELINE is None:
 
180
  @spaces.GPU(duration=120)
181
  def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
182
  """
183
+ Direct generation without retrieval using the custom R1 model.
184
  """
185
  pipe = ensure_pipeline()
186
  out = pipe(
 
196
  @spaces.GPU(duration=120)
197
  def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
198
  """
199
+ Compare outputs between your custom R1 model and the official R1 model.
200
  """
201
  local_pipe = ensure_pipeline()
202
  comp_pipe = ensure_comparison_pipeline()
 
299
 
300
  # Build the Gradio interface.
301
  with gr.Blocks() as demo:
302
+ gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
303
 
304
  finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
305
  status_box = gr.Textbox(label="Finetune Status")
306
  finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
307
 
308
+ gr.Markdown("## Direct Generation (No Retrieval) using Custom R1")
309
  prompt_in = gr.Textbox(lines=3, label="Prompt")
310
  temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
311
  top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
312
  min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
313
  max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
314
+ output_box = gr.Textbox(label="Custom R1 Output", lines=8)
315
+ gen_btn = gr.Button("Generate with Custom R1")
316
  gen_btn.click(
317
  fn=predict,
318
  inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
319
  outputs=output_box
320
  )
321
 
322
+ gr.Markdown("## Compare Custom R1 vs Official R1")
323
  compare_btn = gr.Button("Compare")
324
+ out_custom = gr.Textbox(label="Custom R1 Output", lines=6)
325
+ out_official = gr.Textbox(label="Official R1 Output", lines=6)
326
  compare_btn.click(
327
  fn=compare_models,
328
  inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
329
+ outputs=[out_custom, out_official]
330
  )
331
 
332
  gr.Markdown("## Chat with Retrieval-Augmented Memory")