wuhp committed
Commit 4df6952 · verified · 1 Parent(s): 4f357d9

Update app.py

Files changed (1): app.py (+64, -92)
app.py CHANGED
@@ -20,6 +20,11 @@ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
 from sentence_transformers import SentenceTransformer
 
+# Import your custom configuration overrides.
+# For example, your configuration_deepseek.py might export a dictionary called CONFIG_OVERRIDES.
+import configuration_deepseek
+
+# Global variables for pipelines and settings.
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
 NUM_EXAMPLES = 1000
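The new import relies on configuration_deepseek exposing a module-level dictionary named CONFIG_OVERRIDES, as the commit's own comment says. A minimal sketch of such a module, with purely illustrative keys, could look like this:

# configuration_deepseek.py (hypothetical sketch -- the real module may differ).
# app.py only depends on a module-level dict named CONFIG_OVERRIDES whose keys are
# attribute names on the model config; the keys below are examples, not the real ones.
CONFIG_OVERRIDES = {
    "max_position_embeddings": 4096,
    "use_cache": True,
}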
@@ -27,47 +32,46 @@ NUM_EXAMPLES = 1000
 @spaces.GPU(duration=300)
 def finetune_small_subset():
     """
-    1) Loads 'wuhp/myr1' in 4-bit quantization (QLoRA style),
+    1) Loads your custom model ("wuhp/myr1") in 4-bit quantization (QLoRA style),
     2) Adds LoRA adapters (trainable),
-    3) Trains on a small subset of the Magpie dataset,
-    4) Saves LoRA adapter to 'finetuned_myr1',
-    5) Reloads LoRA adapters for inference in a pipeline.
+    3) Fine-tunes on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset,
+    4) Saves the LoRA adapter to "finetuned_myr1",
+    5) Reloads the LoRA adapter for inference.
     """
-
-    ds = load_dataset(
-        "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B",
-        split="train"
-    )
-
-    unique_ids = list(set(ds["conversation_id"]))
-    single_id = unique_ids[0]
-    ds = ds.filter(lambda x: x["conversation_id"] == single_id)
-
+    # Load the new dataset.
+    ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", split="train")
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
     )
 
-    config = AutoConfig.from_pretrained(
-        "wuhp/myr1",
+    # Load the base configuration from your model repository.
+    base_config = AutoConfig.from_pretrained(
+        "wuhp/myr1",
         subfolder="myr1",
-        trust_remote_code=True
+        trust_remote_code=True,
     )
+    # Apply your custom overrides (from configuration_deepseek.py).
+    for key, value in configuration_deepseek.CONFIG_OVERRIDES.items():
+        setattr(base_config, key, value)
+
     tokenizer = AutoTokenizer.from_pretrained(
-        "wuhp/myr1",
+        "wuhp/myr1",
         subfolder="myr1",
         trust_remote_code=True
     )
 
+    # Load the model. With trust_remote_code=True, your custom model class (e.g. DeepseekV3ForCausalLM)
+    # will be loaded from the repository.
     base_model = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
-        config=config,
-        quantization_config=bnb_config,
+        config=base_config,
+        quantization_config=bnb_config,
         device_map="auto",
         trust_remote_code=True
     )
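The LoRA setup between this hunk and the next is unchanged context, so it does not appear in the diff. Based on the peft functions imported at the top of the file, the wiring typically looks like the sketch below; the rank, alpha, dropout, and target_modules values are assumptions, not the values in app.py:

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

def attach_lora(base_model):
    # Prepare the 4-bit model for training (casts norms, enables input grads, etc.).
    base_model = prepare_model_for_kbit_training(base_model)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,                                 # illustrative rank
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],  # assumption: depends on the model's layer names
    )
    return get_peft_model(base_model, lora_config)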
@@ -84,10 +88,11 @@ def finetune_small_subset():
     )
     lora_model = get_peft_model(base_model, lora_config)
 
+    # For this dataset, assume "problem" is the prompt and "solution" is the target.
     def tokenize_fn(ex):
         text = (
-            f"Instruction: {ex['instruction']}\n\n"
-            f"Response: {ex['response']}"
+            f"Problem: {ex['problem']}\n\n"
+            f"Solution: {ex['solution']}"
        )
         return tokenizer(text, truncation=True, max_length=512)
 
@@ -102,9 +107,9 @@ def finetune_small_subset():
         per_device_train_batch_size=1,
         gradient_accumulation_steps=2,
         logging_steps=5,
-        save_steps=999999,
+        save_steps=999999,  # High save interval
         save_total_limit=1,
-        fp16=False,
+        fp16=False,  # Set to True if supported by your hardware
     )
 
     trainer = Trainer(
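The glue between tokenize_fn and Trainer (mapping the dataset and choosing a collator) is unchanged context and not shown in this diff. A typical arrangement consistent with the pieces that are shown, plus a bf16 flag to pair with the bfloat16 compute dtype used for quantization, would be along these lines (an assumption, not the committed code):

import torch
from transformers import DataCollatorForLanguageModeling

# Map the prompt/target formatter over the subset and drop the raw columns.
train_ds = ds.map(tokenize_fn, remove_columns=ds.column_names)

# Causal-LM collator: pads batches and copies input_ids into labels.
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Since bnb_4bit_compute_dtype is bfloat16, bf16 (when the GPU supports it) is the
# natural companion to fp16=False in the hunk above.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()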
@@ -115,13 +120,15 @@ def finetune_small_subset():
     )
     trainer.train()
 
+    # Save the LoRA adapter and tokenizer.
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
+    # Reload the base model and attach the LoRA adapter for inference.
     base_model_2 = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
-        config=config,
+        config=base_config,
         quantization_config=bnb_config,
         device_map="auto",
         trust_remote_code=True
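The docstring promises that the saved adapter is reloaded for inference; the attach step itself sits in unchanged context just after this hunk. With PEFT it is conventionally done as below (the pipeline construction is an assumption about the surrounding code, not part of the diff):

from peft import PeftModel
from transformers import pipeline

# Attach the adapter saved to "finetuned_myr1" onto the freshly loaded 4-bit base model.
lora_model_2 = PeftModel.from_pretrained(base_model_2, "finetuned_myr1")

# Assumed wiring: inside finetune_small_subset() this would update the global TEXT_PIPELINE.
TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)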
@@ -140,8 +147,8 @@ def finetune_small_subset():
 
 def ensure_pipeline():
     """
-    If we haven't finetuned yet (TEXT_PIPELINE is None),
-    load the base model in 4-bit with NO LoRA.
+    If we haven't fine-tuned yet (i.e. TEXT_PIPELINE is None),
+    load the base model (without LoRA) in 4-bit mode.
     """
     global TEXT_PIPELINE
     if TEXT_PIPELINE is None:
@@ -151,12 +158,14 @@ def ensure_pipeline():
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
-        config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
+        base_config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
+        for key, value in configuration_deepseek.CONFIG_OVERRIDES.items():
+            setattr(base_config, key, value)
         tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
         base_model = AutoModelForCausalLM.from_pretrained(
             "wuhp/myr1",
             subfolder="myr1",
-            config=config,
+            config=base_config,
             quantization_config=bnb_config,
             device_map="auto",
             trust_remote_code=True
@@ -166,7 +175,7 @@ def ensure_pipeline():
 
 def ensure_comparison_pipeline():
     """
-    Load the DeepSeek model pipeline if not already loaded.
+    Load a reference DeepSeek model pipeline if not already loaded.
     """
     global COMPARISON_PIPELINE
     if COMPARISON_PIPELINE is None:
@@ -183,7 +192,7 @@ def ensure_comparison_pipeline():
 @spaces.GPU(duration=120)
 def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Simple single-prompt generation (no retrieval).
+    Direct generation without retrieval.
     """
     pipe = ensure_pipeline()
     out = pipe(
@@ -199,7 +208,7 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
 @spaces.GPU(duration=120)
 def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Compare local pipeline vs. DeepSeek side-by-side.
+    Compare outputs between your custom model and a reference DeepSeek model.
     """
     local_pipe = ensure_pipeline()
     comp_pipe = ensure_comparison_pipeline()
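predict and compare_models hand the prompt and sampling parameters to the text-generation pipeline; the rest of each function is unchanged context, but for reference a transformers text-generation pipeline returns a list of dicts, so the extraction step generally looks like this:

# Standard output shape of a transformers text-generation pipeline (values illustrative).
pipe = ensure_pipeline()
out = pipe("Hello", max_new_tokens=8, do_sample=True)
# out == [{"generated_text": "Hello ..."}]
text = out[0]["generated_text"]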
@@ -224,75 +233,51 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
 
 class ConversationRetriever:
     """
-    A simple in-memory store + FAISS for retrieval of conversation chunks.
-    Each chunk is embedded via SentenceTransformer. On a new user query,
-    we embed the query, do similarity search, and retrieve top-k relevant chunks.
+    A simple in-memory FAISS-based retriever.
+    Each text chunk is embedded using SentenceTransformer.
     """
-
     def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", embed_dim=384):
-        """
-        model_name: embedding model for messages
-        embed_dim: dimension of the embeddings from that model
-        """
         self.embed_model = SentenceTransformer(model_name)
         self.embed_dim = embed_dim
-
         self.index = faiss.IndexFlatL2(embed_dim)
-        self.texts = []
-        self.vectors = []
-        self.ids = []
-
+        self.texts = []
+        self.vectors = []
+        self.ids = []
         self.id_counter = 0
 
     def add_text(self, text):
-        """
-        Add a new text chunk to the vector store.
-        Could chunk it up if desired, but here we treat the entire text as one chunk.
-        """
         if not text.strip():
             return
-
         emb = self.embed_model.encode([text], convert_to_numpy=True)
-        vec = emb[0].astype(np.float32)
+        vec = emb[0].astype(np.float32)
         self.index.add(vec.reshape(1, -1))
-
         self.texts.append(text)
         self.vectors.append(vec)
         self.ids.append(self.id_counter)
-
         self.id_counter += 1
 
     def search(self, query, top_k=3):
-        """
-        Given a query, embed it, do similarity search in FAISS, return top-k texts.
-        """
         q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
         q_vec = q_emb[0].reshape(1, -1)
         distances, indices = self.index.search(q_vec, top_k)
-
         results = []
         for dist, idx in zip(distances[0], indices[0]):
-            if idx < len(self.texts):
+            if idx < len(self.texts):
                 results.append((self.texts[idx], dist))
         return results
 
-retriever = ConversationRetriever()
+retriever = ConversationRetriever()
 
 def build_rag_prompt(user_query, retrieved_chunks):
     """
-    Construct a prompt that includes:
-      - The user's new query
-      - A "Relevant Context" section from retrieved chunks
-      - "Assistant:" to let the model continue
-    Feel free to customize the formatting as you like.
+    Build a prompt for retrieval-augmented generation.
     """
     context_str = ""
     for i, (chunk, dist) in enumerate(retrieved_chunks):
-        context_str += f"Chunk #{i+1} (similarity score ~ {dist:.2f}):\n{chunk}\n\n"
-
+        context_str += f"Chunk #{i+1} (similarity ~ {dist:.2f}):\n{chunk}\n\n"
     prompt = (
         f"User's Query:\n{user_query}\n\n"
-        f"Relevant Context from Conversation:\n{context_str}"
+        f"Relevant Context:\n{context_str}"
         "Assistant:"
     )
     return prompt
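A quick usage sketch of the retriever defined above: add_text embeds and stores each turn, and search returns (text, L2-distance) pairs, where a smaller distance means a closer match.

# Usage sketch for the module-level retriever, using the class exactly as defined above.
retriever.add_text("User: How do I enable 4-bit loading?")
retriever.add_text("Assistant: Pass a BitsAndBytesConfig with load_in_4bit=True.")
for text, dist in retriever.search("enable 4-bit loading", top_k=2):
    print(f"{dist:.3f}  {text}")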
@@ -300,22 +285,13 @@ def build_rag_prompt(user_query, retrieved_chunks):
 @spaces.GPU(duration=120)
 def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Our RAG-based chat function. We'll:
-      1) Add user input to FAISS
-      2) Retrieve top-k relevant older messages from FAISS
-      3) Build a prompt that includes the relevant chunks + user query
-      4) Generate a response from the pipeline
-      5) Add the assistant's response to FAISS as well
+    Chat function with retrieval augmentation.
     """
     pipe = ensure_pipeline()
-
     retriever.add_text(f"User: {user_input}")
-
     top_k = 3
     results = retriever.search(user_input, top_k=top_k)
-
     prompt = build_rag_prompt(user_input, results)
-
     output = pipe(
         prompt,
         temperature=float(temperature),
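For reference, the prompt that chat_rag assembles via build_rag_prompt has this shape (toy chunk and illustrative distance):

# Example of the assembled RAG prompt.
chunks = [("User: How do I enable 4-bit loading?", 0.42)]
print(build_rag_prompt("What about 8-bit?", chunks))
# User's Query:
# What about 8-bit?
#
# Relevant Context:
# Chunk #1 (similarity ~ 0.42):
# User: How do I enable 4-bit loading?
#
# Assistant: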
@@ -331,16 +307,15 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
     assistant_reply = output.strip()
 
     retriever.add_text(f"Assistant: {assistant_reply}")
-
     history.append([user_input, assistant_reply])
     return history, history
 
+# Build the Gradio interface.
 with gr.Blocks() as demo:
-    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo")
+    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom DeepSeekV3 Model")
 
-    finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on Magpie subset (up to 5 min)")
+    finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
     status_box = gr.Textbox(label="Finetune Status")
-
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
     gr.Markdown("## Direct Generation (No Retrieval)")
@@ -349,19 +324,18 @@ with gr.Blocks() as demo:
     top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
     min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
     max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
-
-    output_box = gr.Textbox(label="myr1 Output", lines=8)
-    gen_btn = gr.Button("Generate with myr1")
+    output_box = gr.Textbox(label="DeepSeekV3 Output", lines=8)
+    gen_btn = gr.Button("Generate with DeepSeekV3")
     gen_btn.click(
         fn=predict,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
         outputs=output_box
     )
 
-    gr.Markdown("## Compare myr1 vs DeepSeek")
+    gr.Markdown("## Compare DeepSeekV3 vs Reference DeepSeek")
     compare_btn = gr.Button("Compare")
-    out_local = gr.Textbox(label="myr1 Output", lines=6)
-    out_deepseek = gr.Textbox(label="DeepSeek Output", lines=6)
+    out_local = gr.Textbox(label="DeepSeekV3 Output", lines=6)
+    out_deepseek = gr.Textbox(label="Reference DeepSeek Output", lines=6)
     compare_btn.click(
         fn=compare_models,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
@@ -372,15 +346,13 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
             chatbot = gr.Chatbot(label="RAG Chat")
-            chat_state = gr.State([])
-
+            chat_state = gr.State([])
             user_input = gr.Textbox(
                 show_label=False,
                 placeholder="Ask a question...",
                 lines=2
             )
             send_btn = gr.Button("Send")
-
             user_input.submit(
                 fn=chat_rag,
                 inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
@@ -392,4 +364,4 @@ with gr.Blocks() as demo:
         outputs=[chat_state, chatbot]
     )
 
-demo.launch()
+demo.launch()
 