Spaces: wuhp · Running on Zero

wuhp committed · Commit 2f957f0 · verified · 1 Parent(s): b5aeb95

Update app.py

Files changed (1):
  1. app.py +192 -49
app.py CHANGED
@@ -1,6 +1,8 @@
 import gradio as gr
 import spaces
 import torch
+import faiss
+import numpy as np
 
 from datasets import load_dataset
 from transformers import (
@@ -17,16 +19,18 @@ from transformers import (
 # PEFT (LoRA / QLoRA)
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
 
+# For embeddings
+from sentence_transformers import SentenceTransformer
+
 ##############################################################################
-# ZeroGPU + QLoRA Example
+# QLoRA Demo Setup
 ##############################################################################
 
-TEXT_PIPELINE = None  # Pipeline for wuhp/myr1 (fine-tuned or base)
-COMPARISON_PIPELINE = None  # Pipeline for the DeepSeek model
-
+TEXT_PIPELINE = None
+COMPARISON_PIPELINE = None
 NUM_EXAMPLES = 50  # We'll train on 50 rows for demonstration
 
-@spaces.GPU(duration=300)  # up to 5 min
+@spaces.GPU(duration=300)
 def finetune_small_subset():
     """
     1) Loads 'wuhp/myr1' in 4-bit quantization (QLoRA style),
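
Note: the new `sentence_transformers` import feeds the `ConversationRetriever` added further down, whose `embed_dim=384` default matches the all-MiniLM-L6-v2 embedding width. A minimal sanity check, assuming that embedding model (this snippet is illustrative and not part of the commit):

```python
# Illustrative only: confirm the embedding width that the retriever's
# embed_dim=384 default relies on.
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print(embedder.get_sentence_embedding_dimension())  # 384 for this model
```
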
@@ -42,15 +46,13 @@ def finetune_small_subset():
         split="train"
     )
 
-    # For demonstration, pick a single conversation_id
     unique_ids = list(set(ds["conversation_id"]))
     single_id = unique_ids[0]
     ds = ds.filter(lambda x: x["conversation_id"] == single_id)
 
-    # Then select only NUM_EXAMPLES from that subset
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
-    # --- 2) Setup 4-bit quantization with BitsAndBytes ---
+    # --- 2) Setup 4-bit quantization ---
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16
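
Note: this hunk cuts off after `bnb_4bit_compute_dtype`; the remaining `bnb_config` fields are outside the diff context and unchanged by this commit. For orientation, a typical QLoRA-style setup looks like the sketch below — the NF4/double-quantization values are assumptions, not necessarily what app.py uses:

```python
# Hedged sketch of a common QLoRA quantization config; only the first three
# fields are confirmed by the hunk above.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # as shown in the hunk
    bnb_4bit_quant_type="nf4",              # assumption
    bnb_4bit_use_double_quant=True,         # assumption
)
```
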
@@ -78,7 +80,6 @@ def finetune_small_subset():
         trust_remote_code=True
     )
 
-    # Prepare the model for k-bit training
     base_model = prepare_model_for_kbit_training(base_model)
 
     # --- 3) Create LoRA config & wrap the base model in LoRA ---
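
Note: the LoRA configuration built under "--- 3)" falls between these two hunks and is untouched by this commit, so it does not appear in the diff. A representative causal-LM setup with the imported `LoraConfig`, `TaskType`, and `get_peft_model` is sketched below; the rank, alpha, and target modules are assumptions:

```python
# Hedged sketch; the real hyperparameters live outside the shown hunks.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                                 # assumption
    lora_alpha=32,                        # assumption
    lora_dropout=0.05,                    # assumption
    target_modules=["q_proj", "v_proj"],  # assumption
)
# lora_model = get_peft_model(base_model, lora_config)  # base_model as prepared above
```
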
@@ -94,10 +95,6 @@ def finetune_small_subset():
 
     # --- 4) Tokenize dataset ---
     def tokenize_fn(ex):
-        """
-        Combine instruction + response into a single text.
-        You can adjust this to include more fields or different formatting.
-        """
         text = (
             f"Instruction: {ex['instruction']}\n\n"
             f"Response: {ex['response']}"
@@ -116,27 +113,24 @@ def finetune_small_subset():
         per_device_train_batch_size=1,
         gradient_accumulation_steps=2,
         logging_steps=5,
-        save_steps=999999,  # effectively don't save mid-epoch
+        save_steps=999999,
         save_total_limit=1,
-        fp16=False,  # rely on bfloat16 from quantization
+        fp16=False,
     )
 
-    # Trainer
     trainer = Trainer(
         model=lora_model,
         args=training_args,
         train_dataset=ds,
         data_collator=collator,
     )
-
-    # --- 5) Train ---
     trainer.train()
 
-    # --- 6) Save LoRA adapter + tokenizer ---
+    # --- 5) Save LoRA adapter + tokenizer ---
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    # --- 7) Reload the base model + LoRA adapter for inference
+    # --- 6) Reload for inference
     base_model_2 = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
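
Note: the reload under the renumbered "--- 6) Reload for inference" comment is truncated by this hunk, so the adapter attach is not visible here. Given the `PeftModel` import at the top of the file and the adapter saved to "finetuned_myr1", the continuation presumably resembles this sketch (not the verbatim code):

```python
# Hedged sketch: re-attach the freshly saved LoRA adapter to the reloaded base
# model and rebuild a generation pipeline. `base_model_2` and `tokenizer` are
# the objects from the surrounding function in app.py.
from peft import PeftModel
from transformers import pipeline

lora_model_2 = PeftModel.from_pretrained(base_model_2, "finetuned_myr1")
text_pipeline = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)
```
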
@@ -191,8 +185,6 @@ def ensure_comparison_pipeline():
     """
     global COMPARISON_PIPELINE
     if COMPARISON_PIPELINE is None:
-        # If you prefer 4-bit, you can define BitsAndBytesConfig here,
-        # but let's keep it simpler for demonstration (fp16 or bf16).
         config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
         tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
         model = AutoModelForCausalLM.from_pretrained(
@@ -200,18 +192,14 @@ def ensure_comparison_pipeline():
             config=config,
             device_map="auto"
         )
-        COMPARISON_PIPELINE = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer
-        )
+        COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
     return COMPARISON_PIPELINE
 
 
-@spaces.GPU(duration=120)  # up to 2 min for text generation
+@spaces.GPU(duration=120)
 def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Generates text from the fine-tuned (LoRA) model if present, else the base model.
+    Simple single-prompt generation (no retrieval).
     """
     pipe = ensure_pipeline()
     out = pipe(
@@ -225,11 +213,10 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     return out[0]["generated_text"]
 
 
-@spaces.GPU(duration=120)  # up to 2 min for text generation
+@spaces.GPU(duration=120)
 def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Generates text side-by-side from the local myr1 pipeline (fine-tuned or base)
-    AND from the DeepSeek model. Returns two strings.
+    Compare local pipeline vs. DeepSeek side-by-side.
     """
     local_pipe = ensure_pipeline()
    comp_pipe = ensure_comparison_pipeline()
@@ -242,8 +229,6 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
         max_new_tokens=int(max_new_tokens),
         do_sample=True
     )
-    local_text = local_out[0]["generated_text"]
-
     comp_out = comp_pipe(
         prompt,
         temperature=float(temperature),
@@ -252,47 +237,205 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
         max_new_tokens=int(max_new_tokens),
         do_sample=True
     )
-    comp_text = comp_out[0]["generated_text"]
+    return local_out[0]["generated_text"], comp_out[0]["generated_text"]
 
-    return local_text, comp_text
 
+###############################################################################
+# Retrieval-Augmented Memory with FAISS
+###############################################################################
+class ConversationRetriever:
+    """
+    A simple in-memory store + FAISS for retrieval of conversation chunks.
+    Each chunk is embedded via SentenceTransformer. On a new user query,
+    we embed the query, do similarity search, and retrieve top-k relevant chunks.
+    """
+
+    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", embed_dim=384):
+        """
+        model_name: embedding model for messages
+        embed_dim: dimension of the embeddings from that model
+        """
+        self.embed_model = SentenceTransformer(model_name)
+        self.embed_dim = embed_dim
 
-# Build Gradio UI
+        # We'll store (text, vector) in FAISS. For metadata, store in python list/dict.
+        # For a real app, you'd probably want a more robust store.
+        self.index = faiss.IndexFlatL2(embed_dim)
+        self.texts = []    # store the raw text chunks
+        self.vectors = []  # store vectors (redundant but simpler to show)
+        self.ids = []      # store an integer ID or similar
+
+        self.id_counter = 0
+
+    def add_text(self, text):
+        """
+        Add a new text chunk to the vector store.
+        Could chunk it up if desired, but here we treat the entire text as one chunk.
+        """
+        if not text.strip():
+            return
+
+        emb = self.embed_model.encode([text], convert_to_numpy=True)
+        vec = emb[0].astype(np.float32)  # shape [embed_dim]
+        self.index.add(vec.reshape(1, -1))
+
+        self.texts.append(text)
+        self.vectors.append(vec)
+        self.ids.append(self.id_counter)
+
+        self.id_counter += 1
+
+    def search(self, query, top_k=3):
+        """
+        Given a query, embed it, do similarity search in FAISS, return top-k texts.
+        """
+        q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
+        q_vec = q_emb[0].reshape(1, -1)
+        distances, indices = self.index.search(q_vec, top_k)
+
+        # indices is shape [1, top_k], distances is shape [1, top_k]
+        results = []
+        for dist, idx in zip(distances[0], indices[0]):
+            if idx < len(self.texts):  # safety check
+                results.append((self.texts[idx], dist))
+        return results
+
+
+###############################################################################
+# Build a Chat that uses RAG
+###############################################################################
+retriever = ConversationRetriever()  # global retriever instance
+
+def build_rag_prompt(user_query, retrieved_chunks):
+    """
+    Construct a prompt that includes:
+      - The user's new query
+      - A "Relevant Context" section from retrieved chunks
+      - "Assistant:" to let the model continue
+    Feel free to customize the formatting as you like.
+    """
+    context_str = ""
+    for i, (chunk, dist) in enumerate(retrieved_chunks):
+        context_str += f"Chunk #{i+1} (similarity score ~ {dist:.2f}):\n{chunk}\n\n"
+
+    prompt = (
+        f"User's Query:\n{user_query}\n\n"
+        f"Relevant Context from Conversation:\n{context_str}"
+        "Assistant:"
+    )
+    return prompt
+
+
+@spaces.GPU(duration=120)
+def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
+    """
+    Our RAG-based chat function. We'll:
+      1) Add user input to FAISS
+      2) Retrieve top-k relevant older messages from FAISS
+      3) Build a prompt that includes the relevant chunks + user query
+      4) Generate a response from the pipeline
+      5) Add the assistant's response to FAISS as well
+    """
+    pipe = ensure_pipeline()
+
+    # 1) Add the user input as a chunk to the retriever DB.
+    retriever.add_text(f"User: {user_input}")
+
+    # 2) Retrieve top-3 older chunks. We can skip the chunk we just added if we want to
+    #    (since it's the same text), but for simplicity let's just do a search for user_input.
+    top_k = 3
+    results = retriever.search(user_input, top_k=top_k)
+
+    # 3) Build final prompt
+    prompt = build_rag_prompt(user_input, results)
+
+    # 4) Generate
+    output = pipe(
+        prompt,
+        temperature=float(temperature),
+        top_p=float(top_p),
+        min_new_tokens=int(min_new_tokens),
+        max_new_tokens=int(max_new_tokens),
+        do_sample=True
+    )[0]["generated_text"]
+
+    # We only want the new part after "Assistant:"
+    # Because the pipeline output includes the entire prompt + new text.
+    if output.startswith(prompt):
+        assistant_reply = output[len(prompt):].strip()
+    else:
+        assistant_reply = output.strip()
+
+    # 5) Add the assistant's response to the DB as well
+    retriever.add_text(f"Assistant: {assistant_reply}")
+
+    # 6) Update the chat history for display in the Gradio Chatbot
+    history.append([user_input, assistant_reply])
+    return history, history
+
+
+###############################################################################
+# Gradio UI
+###############################################################################
 with gr.Blocks() as demo:
-    gr.Markdown("# QLoRA Fine-tuning & Comparison Demo")
-    gr.Markdown("**Fine-tune wuhp/myr1** on a small subset of the Magpie dataset, then generate or compare output with the DeepSeek model.")
+    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo")
 
     finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on Magpie subset (up to 5 min)")
     status_box = gr.Textbox(label="Finetune Status")
-    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    gr.Markdown("### Generate with myr1 (fine-tuned if done, else base)")
+    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
+    # Simple generation UI (no retrieval):
+    gr.Markdown("## Direct Generation (No Retrieval)")
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
     top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
-    min_tokens = gr.Slider(50, 1024, value=50, step=10, label="Min New Tokens")
-    max_tokens = gr.Slider(50, 1024, value=200, step=50, label="Max New Tokens")
+    min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
+    max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
 
     output_box = gr.Textbox(label="myr1 Output", lines=8)
     gen_btn = gr.Button("Generate with myr1")
-
     gen_btn.click(
         fn=predict,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=output_box
     )
 
-    gr.Markdown("### Compare myr1 vs DeepSeek side-by-side")
-
+    # Comparison UI:
+    gr.Markdown("## Compare myr1 vs DeepSeek")
     compare_btn = gr.Button("Compare")
-    out_local = gr.Textbox(label="myr1 Output", lines=8)
-    out_deepseek = gr.Textbox(label="DeepSeek Output", lines=8)
-
+    out_local = gr.Textbox(label="myr1 Output", lines=6)
+    out_deepseek = gr.Textbox(label="DeepSeek Output", lines=6)
     compare_btn.click(
         fn=compare_models,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=[out_local, out_deepseek]
     )
 
+    # RAG-based Chat
+    gr.Markdown("## Chat with Retrieval-Augmented Memory")
+    with gr.Row():
+        with gr.Column():
+            chatbot = gr.Chatbot(label="RAG Chat")
+            chat_state = gr.State([])  # just for display
+
+            user_input = gr.Textbox(
+                show_label=False,
+                placeholder="Ask a question...",
+                lines=2
+            )
+            send_btn = gr.Button("Send")
+
+    # On user submit, call chat_rag
+    user_input.submit(
+        fn=chat_rag,
+        inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
+        outputs=[chat_state, chatbot]
+    )
+    send_btn.click(
+        fn=chat_rag,
+        inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
+        outputs=[chat_state, chatbot]
+    )
+
 demo.launch()
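
Note: a quick way to see what the newly added `ConversationRetriever` returns is to exercise it outside Gradio. The snippet below is illustrative usage of the class exactly as added in this commit; since the index is a FAISS `IndexFlatL2`, the second element of each result is an L2 distance, so smaller means more similar:

```python
# Illustrative usage of the ConversationRetriever defined in app.py above.
retriever = ConversationRetriever()
retriever.add_text("User: How do I enable 4-bit quantization?")
retriever.add_text("Assistant: Pass a BitsAndBytesConfig with load_in_4bit=True.")

for text, dist in retriever.search("4-bit quantization", top_k=2):
    print(f"{dist:.3f}  {text}")  # L2 distance: smaller is closer
```
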
 
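Note: `chat_rag` returns `history, history`, so the same list of `[user, assistant]` pairs feeds both `chat_state` and the `gr.Chatbot`. The prompt it sends to the pipeline is the plain-text layout produced by `build_rag_prompt`; with one retrieved chunk (illustrative values) it looks like this:

```python
# Illustrative: what build_rag_prompt assembles for a single retrieved chunk.
chunks = [("User: How do I enable 4-bit quantization?", 0.42)]
print(build_rag_prompt("Which compute dtype should I use?", chunks))
# User's Query:
# Which compute dtype should I use?
#
# Relevant Context from Conversation:
# Chunk #1 (similarity score ~ 0.42):
# User: How do I enable 4-bit quantization?
#
# Assistant:
```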