wuhp committed (verified)
Commit 3986b4b · 1 Parent(s): 09f030f

Update app.py

Files changed (1): app.py (+90 −46)
app.py CHANGED
@@ -21,7 +21,7 @@ from transformers import (
 
 NUM_EXAMPLES_FOR_FINETUNING = 50 # Constant for the number of examples to use for finetuning
 TEXT_PIPELINE = None # Global to store the custom R1 text generation pipeline
-COMPARISON_PIPELINE = None # Global to store the official R1 text generation pipeline
+COMPARISON_PIPELINE = None # Global to store the official R1 text generation pipeline
 
 
 def _load_model_and_tokenizer(model_name: str, subfolder: str = None, quantization_config: BitsAndBytesConfig = None, device_map: str = "auto", trust_remote_code: bool = True) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
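Note: the COMPARISON_PIPELINE global kept here feeds compare_models() further down via ensure_comparison_pipeline(), whose body is not part of this diff. A minimal sketch of such a lazy loader, assuming the official checkpoint is deepseek-ai/DeepSeek-R1-Distill-Llama-8B (the model named in the UI text added later in this commit):

```python
# Hypothetical sketch only: ensure_comparison_pipeline() exists in app.py but its
# body is not shown in this diff, and the repo id below is an assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

COMPARISON_PIPELINE = None  # Global cache for the official R1 pipeline


def ensure_comparison_pipeline():
    """Load the official R1 text-generation pipeline once and reuse it."""
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
        model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # assumed repo id
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
        COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return COMPARISON_PIPELINE
```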
@@ -66,7 +66,6 @@ def finetune_small_subset() -> str:
     Returns:
         str: A message indicating finetuning completion.
     """
-    # Specify the configuration ("v0" or "v1") explicitly.
     ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v0", split="train")
     ds = ds.select(range(min(NUM_EXAMPLES_FOR_FINETUNING, len(ds))))
 
@@ -76,8 +75,6 @@ def finetune_small_subset() -> str:
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
     )
-
-    # Load the custom model configuration from the repository.
     base_model, tokenizer = _load_model_and_tokenizer(
         "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto"
     )
@@ -112,8 +109,8 @@ def finetune_small_subset() -> str:
         per_device_train_batch_size=1,
         gradient_accumulation_steps=2,
         logging_steps=5,
-        save_steps=999999, # Save infrequently to avoid filling up disk during demo
-        save_total_limit=1, # Keep only the last saved checkpoint
+        save_steps=999999,
+        save_total_limit=1,
         fp16=False,
     )
 
@@ -128,7 +125,7 @@ def finetune_small_subset() -> str:
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    base_model_2, tokenizer_2 = _load_model_and_tokenizer( # Re-load base model for inference adapter application
+    base_model_2, tokenizer_2 = _load_model_and_tokenizer(
         "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto"
     )
     base_model_2 = prepare_model_for_kbit_training(base_model_2)
@@ -139,7 +136,7 @@ def finetune_small_subset() -> str:
     )
 
     global TEXT_PIPELINE
-    TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer_2) # Use tokenizer_2 here to be consistent
+    TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer_2)
 
     return "Finetuning complete. Model loaded for inference."
 
@@ -205,18 +202,26 @@ def predict(
         max_new_tokens (int): Maximum number of new tokens to generate.
 
     Returns:
-        str: The generated text output.
+        str: The generated text output with "Thinking Process" and "Solution" sections.
     """
     pipe = ensure_pipeline()
-    out = pipe(
+    thinking_prefix = "**Thinking Process:**\n"
+    solution_prefix = "\n**Solution:**\n"
+    formatted_output = thinking_prefix
+
+    output = pipe(
         prompt,
         temperature=float(temperature),
         top_p=float(top_p),
         min_new_tokens=int(min_new_tokens),
         max_new_tokens=int(max_new_tokens),
         do_sample=True
-    )
-    return out[0]["generated_text"]
+    )[0]["generated_text"]
+
+    formatted_output += output.strip() + solution_prefix
+    formatted_output += "Final Answer (This part is a placeholder and needs better extraction): ... "
+
+    return formatted_output
 
 
 @spaces.GPU(duration=120)
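The rewritten predict() wraps the raw generation in "Thinking Process" / "Solution" headings but, as its own placeholder string says, does not yet extract a final answer. One possible follow-up, sketched under the assumption that the model emits R1-style <think>...</think> tags (this diff does not confirm that):

```python
# Illustrative helper, not part of the commit. Assumes the model wraps its
# reasoning in <think>...</think>, as the official R1 distills do.
def split_thinking_and_solution(raw_output: str) -> tuple[str, str]:
    marker = "</think>"
    if marker in raw_output:
        thinking, solution = raw_output.split(marker, 1)
        return thinking.replace("<think>", "").strip(), solution.strip()
    # No marker found: treat the whole generation as the solution.
    return "", raw_output.strip()
```

predict() could then fill the two sections from this split instead of appending the placeholder "Final Answer" string.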
@@ -238,28 +243,41 @@ def compare_models(
         max_new_tokens (int): Maximum number of new tokens to generate.
 
     Returns:
-        tuple[str, str]: A tuple containing the generated text from the custom R1 and official R1 models.
+        tuple[str, str]: A tuple containing the formatted generated text from the custom R1 and official R1 models, each with "Thinking Process" and "Solution" sections.
     """
     local_pipe = ensure_pipeline()
     comp_pipe = ensure_comparison_pipeline()
 
-    local_out = local_pipe(
+    def format_comparison_output(model_name, raw_output):
+        thinking_prefix = f"**{model_name} - Thinking Process:**\n"
+        solution_prefix = f"\n**{model_name} - Solution:**\n"
+        formatted_output = thinking_prefix
+        formatted_output += raw_output.strip() + solution_prefix
+        formatted_output += f"{model_name} Final Answer: ... "
+        return formatted_output
+
+    local_out_raw = local_pipe(
         prompt,
         temperature=float(temperature),
         top_p=float(top_p),
         min_new_tokens=int(min_new_tokens),
         max_new_tokens=int(max_new_tokens),
         do_sample=True
-    )
-    comp_out = comp_pipe(
+    )[0]["generated_text"]
+
+    comp_out_raw = comp_pipe(
         prompt,
         temperature=float(temperature),
         top_p=float(top_p),
         min_new_tokens=int(min_new_tokens),
         max_new_tokens=int(max_new_tokens),
         do_sample=True
-    )
-    return local_out[0]["generated_text"], comp_out[0]["generated_text"]
+    )[0]["generated_text"]
+
+    local_out_formatted = format_comparison_output("Custom R1", local_out_raw)
+    comp_out_formatted = format_comparison_output("Official R1", comp_out_raw)
+
+    return local_out_formatted, comp_out_formatted
 
 
 class ConversationRetriever:
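Both pipelines in compare_models() are now called with identical generation settings; a possible later cleanup (not part of this commit) is to factor the duplicated keyword blocks into a small helper:

```python
# Refactor sketch only; mirrors the exact call pattern used in compare_models().
def _generate(pipe, prompt: str, temperature: float, top_p: float,
              min_new_tokens: int, max_new_tokens: int) -> str:
    return pipe(
        prompt,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
    )[0]["generated_text"]
```

compare_models() would then reduce to two _generate() calls followed by the two format_comparison_output() calls.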
@@ -335,15 +353,20 @@ def build_rag_prompt(user_query: str, retrieved_chunks: list[tuple[str, float]])
         retrieved_chunks (list[tuple[str, float]]): List of retrieved text chunks and their distances.
 
     Returns:
-        str: The formatted prompt string.
+        str: The formatted prompt string including instructions for step-by-step thinking and using context.
     """
     context_str = ""
-    for i, (chunk, dist) in enumerate(retrieved_chunks):
-        context_str += f"Chunk #{i+1} (similarity ~ {dist:.2f}):\n{chunk}\n\n"
+    if retrieved_chunks:
+        context_str += "**Relevant Context:**\n"
+        for i, (chunk, dist) in enumerate(retrieved_chunks):
+            context_str += f"Chunk #{i+1} (similarity ~ {dist:.2f}):\n> {chunk}\n\n"
+
+    prompt_instruction = "Please provide a detailed answer, showing your thinking process step-by-step before stating the final answer. Use the provided context if relevant."
     prompt = (
-        f"User's Query:\n{user_query}\n\n"
-        f"Relevant Context:\n{context_str}"
-        "Assistant:"
+        f"**User Query:**\n{user_query}\n\n"
+        f"{context_str}\n"
+        f"{prompt_instruction}\n\n"
+        "**Answer:**\n"
    )
     return prompt
 
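For reference, with the new build_rag_prompt() a query with one retrieved chunk produces a prompt laid out roughly as follows (query and chunk text here are made up; the function itself comes from the hunk above):

```python
# Usage illustration, assuming build_rag_prompt() from app.py is in scope.
chunks = [("QLoRA fine-tunes a 4-bit quantized base model with LoRA adapters.", 0.12)]
print(build_rag_prompt("What is QLoRA?", chunks))
# **User Query:**
# What is QLoRA?
#
# **Relevant Context:**
# Chunk #1 (similarity ~ 0.12):
# > QLoRA fine-tunes a 4-bit quantized base model with LoRA adapters.
#
#
# Please provide a detailed answer, showing your thinking process step-by-step before stating the final answer. Use the provided context if relevant.
#
# **Answer:**
```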
 
@@ -369,13 +392,18 @@ def chat_rag(
         max_new_tokens (int): Maximum number of new tokens to generate.
 
     Returns:
-        tuple[list[list[str]], list[list[str]]]: Updated chat history and chatbot display history.
+        tuple[list[list[str]], list[list[str]]]: Updated chat history and chatbot display history, with formatted assistant replies.
     """
     pipe = ensure_pipeline()
     retriever.add_text(f"User: {user_input}")
     top_k = 3
     results = retriever.search(user_input, top_k=top_k)
     prompt = build_rag_prompt(user_input, results)
+
+    thinking_prefix = "**Thinking Process:**\n"
+    solution_prefix = "\n**Solution:**\n"
+    formatted_output = thinking_prefix
+
     output = pipe(
         prompt,
         temperature=float(temperature),
@@ -385,10 +413,14 @@ def chat_rag(
         do_sample=True
     )[0]["generated_text"]
 
-    if output.startswith(prompt):
-        assistant_reply = output[len(prompt):].strip()
+    formatted_output += output.strip() + solution_prefix
+    formatted_output += "Final Answer (This part is a placeholder and needs better extraction): ... "
+    assistant_reply = formatted_output
+
+    if assistant_reply.startswith(prompt):
+        assistant_reply = assistant_reply[len(prompt):].strip()
     else:
-        assistant_reply = output.strip()
+        assistant_reply = assistant_reply.strip()
 
     retriever.add_text(f"Assistant: {assistant_reply}")
     history.append([user_input, assistant_reply])
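Note that the reply passed to the startswith() check above is the wrapped string, which always begins with "**Thinking Process:**", so a prompt echo in the raw generation is no longer stripped. A small helper that could restore the original intent (a sketch, not part of this commit):

```python
# Hypothetical helper: remove a prompt echo from the raw pipeline output before
# it is wrapped in the Thinking/Solution headings inside chat_rag().
def strip_prompt_echo(raw_output: str, prompt: str) -> str:
    if raw_output.startswith(prompt):
        return raw_output[len(prompt):].strip()
    return raw_output.strip()
```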
@@ -398,46 +430,56 @@ def chat_rag(
 # Build the Gradio interface.
 with gr.Blocks() as demo:
     gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
+    gr.Markdown("---")
 
-    finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
-    status_box = gr.Textbox(label="Finetune Status")
+    gr.Markdown("## ⚙️ Fine-tuning (Optional)")
+    gr.Markdown("This section allows you to fine-tune the custom R1 model on a small subset of the ServiceNow dataset. This step is optional but can potentially improve the model's performance on ServiceNow-related tasks. **Note:** This process may take up to 5 minutes.")
+    finetune_btn = gr.Button("🚀 Start Fine-tuning (QLoRA)")
+    status_box = gr.Textbox(label="Fine-tuning Status", interactive=False)
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
+    gr.Markdown("---")
 
-    gr.Markdown("## Direct Generation (No Retrieval) using Custom R1")
-    prompt_in = gr.Textbox(lines=3, label="Prompt")
-    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
-    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
+    gr.Markdown("## ✍️ Direct Generation (No Retrieval)")
+    gr.Markdown("Enter a prompt below to generate text directly using the custom R1 model. This is standard text generation without retrieval augmentation.")
+    prompt_in = gr.Textbox(lines=3, label="Input Prompt", placeholder="Enter your prompt here...")
+    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature (Creativity)")
+    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p (Sampling Nucleus)")
     min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
     max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
-    output_box = gr.Textbox(label="Custom R1 Output", lines=8)
-    gen_btn = gr.Button("Generate with Custom R1")
+    output_box = gr.Textbox(label="Custom R1 Output", lines=8, interactive=False)
+    gen_btn = gr.Button("Generate Text")
     gen_btn.click(
         fn=predict,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
         outputs=output_box
     )
-
-    gr.Markdown("## Compare Custom R1 vs Official R1")
-    compare_btn = gr.Button("Compare")
-    out_custom = gr.Textbox(label="Custom R1 Output", lines=6)
-    out_official = gr.Textbox(label="Official R1 Output", lines=6)
+    gr.Markdown("---")
+
+    gr.Markdown("## 🆚 Compare Custom R1 vs Official R1")
+    gr.Markdown("Enter a prompt to compare the text generation of your fine-tuned custom R1 model with the official DeepSeek-R1-Distill-Llama-8B model.")
+    compare_prompt_in = gr.Textbox(lines=3, label="Comparison Prompt", placeholder="Enter prompt for comparison...")
+    compare_btn = gr.Button("⚖️ Compare Models")
+    out_custom = gr.Textbox(label="Custom R1 Output", lines=6, interactive=False)
+    out_official = gr.Textbox(label="Official R1 Output", lines=6, interactive=False)
     compare_btn.click(
         fn=compare_models,
-        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
+        inputs=[compare_prompt_in, temperature, top_p, min_tokens, max_tokens],
         outputs=[out_custom, out_official]
     )
+    gr.Markdown("---")
 
-    gr.Markdown("## Chat with Retrieval-Augmented Memory")
+    gr.Markdown("## 💬 Chat with Retrieval-Augmented Memory (RAG)")
+    gr.Markdown("Chat with the custom R1 model, enhanced with a retrieval-augmented memory. The model will retrieve relevant information based on your queries to provide more informed responses.")
     with gr.Row():
         with gr.Column():
-            chatbot = gr.Chatbot(label="RAG Chat")
+            chatbot = gr.Chatbot(label="RAG Chatbot")
             chat_state = gr.State([])
             user_input = gr.Textbox(
                 show_label=False,
-                placeholder="Ask a question...",
+                placeholder="Ask a question to the RAG Chatbot...",
                 lines=2
             )
-            send_btn = gr.Button("Send")
+            send_btn = gr.Button("➡️ Send")
     user_input.submit(
         fn=chat_rag,
         inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
@@ -448,5 +490,7 @@ with gr.Blocks() as demo:
         inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
         outputs=[chat_state, chatbot]
     )
+    gr.Markdown("---")
+
 
 demo.launch()
 