wuhp committed on
Commit 09f030f · verified · 1 Parent(s): adb5084

Update app.py

Files changed (1)
  1. app.py +160 -62
app.py CHANGED
@@ -1,44 +1,74 @@
import gradio as gr
- import spaces
import torch
import faiss
- import numpy as np

from datasets import load_dataset
from transformers import (
    AutoConfig,
-     AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
-     BitsAndBytesConfig,
)

- from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel

- from sentence_transformers import SentenceTransformer

- # Global variables for pipelines and settings.
- TEXT_PIPELINE = None
- COMPARISON_PIPELINE = None
- NUM_EXAMPLES = 50

@spaces.GPU(duration=300)
- def finetune_small_subset():
    """
    Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
    Steps:
      1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
      2) Applies 4-bit quantization and prepares for QLoRA training.
      3) Fine-tunes on the dataset (mapping "problem" to prompt and "solution" to target).
      4) Saves the LoRA adapter to "finetuned_myr1".
      5) Reloads the adapter for inference.
    """
    # Specify the configuration ("v0" or "v1") explicitly.
    ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v0", split="train")
-     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
@@ -48,26 +78,8 @@ def finetune_small_subset():
    )

    # Load the custom model configuration from the repository.
-     base_config = AutoConfig.from_pretrained(
-         "wuhp/myr1",
-         subfolder="myr1",
-         trust_remote_code=True,
-     )
-     # (Optionally apply local overrides here if needed.)
-
-     tokenizer = AutoTokenizer.from_pretrained(
-         "wuhp/myr1",
-         subfolder="myr1",
-         trust_remote_code=True
-     )
-
-     base_model = AutoModelForCausalLM.from_pretrained(
-         "wuhp/myr1",
-         subfolder="myr1",
-         config=base_config,
-         quantization_config=bnb_config,
-         device_map="auto",
-         trust_remote_code=True
    )

    base_model = prepare_model_for_kbit_training(base_model)
@@ -100,8 +112,8 @@ def finetune_small_subset():
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
-         save_steps=999999,
-         save_total_limit=1,
        fp16=False,
    )

@@ -116,13 +128,8 @@ def finetune_small_subset():
    trainer.model.save_pretrained("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

-     base_model_2 = AutoModelForCausalLM.from_pretrained(
-         "wuhp/myr1",
-         subfolder="myr1",
-         config=base_config,
-         quantization_config=bnb_config,
-         device_map="auto",
-         trust_remote_code=True
    )
    base_model_2 = prepare_model_for_kbit_training(base_model_2)

@@ -132,13 +139,17 @@ def finetune_small_subset():
    )

    global TEXT_PIPELINE
-     TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)

    return "Finetuning complete. Model loaded for inference."

- def ensure_pipeline():
    """
    Loads the base model (without LoRA) if no fine-tuned model is available.
    """
    global TEXT_PIPELINE
    if TEXT_PIPELINE is None:
@@ -148,22 +159,19 @@ def ensure_pipeline():
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
-         base_config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
-         tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
-         base_model = AutoModelForCausalLM.from_pretrained(
-             "wuhp/myr1",
-             subfolder="myr1",
-             config=base_config,
-             quantization_config=bnb_config,
-             device_map="auto",
-             trust_remote_code=True
        )
        TEXT_PIPELINE = pipeline("text-generation", model=base_model, tokenizer=tokenizer)
    return TEXT_PIPELINE

- def ensure_comparison_pipeline():
    """
    Loads the official R1 model pipeline if not already loaded.
    """
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
@@ -177,10 +185,27 @@ def ensure_comparison_pipeline():
        COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return COMPARISON_PIPELINE

@spaces.GPU(duration=120)
- def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Direct generation without retrieval using the custom R1 model.
    """
    pipe = ensure_pipeline()
    out = pipe(
@@ -193,10 +218,27 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    )
    return out[0]["generated_text"]

@spaces.GPU(duration=120)
- def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Compare outputs between your custom R1 model and the official R1 model.
    """
    local_pipe = ensure_pipeline()
    comp_pipe = ensure_comparison_pipeline()
@@ -219,11 +261,22 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    )
    return local_out[0]["generated_text"], comp_out[0]["generated_text"]

class ConversationRetriever:
    """
    A FAISS-based retriever using SentenceTransformer for embedding.
    """
-     def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", embed_dim=384):
        self.embed_model = SentenceTransformer(model_name)
        self.embed_dim = embed_dim
        self.index = faiss.IndexFlatL2(embed_dim)
@@ -232,7 +285,13 @@ class ConversationRetriever:
        self.ids = []
        self.id_counter = 0

-     def add_text(self, text):
        if not text.strip():
            return
        emb = self.embed_model.encode([text], convert_to_numpy=True)
@@ -243,7 +302,17 @@ class ConversationRetriever:
        self.ids.append(self.id_counter)
        self.id_counter += 1

-     def search(self, query, top_k=3):
        q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
        q_vec = q_emb[0].reshape(1, -1)
        distances, indices = self.index.search(q_vec, top_k)
@@ -253,11 +322,20 @@ class ConversationRetriever:
            results.append((self.texts[idx], dist))
        return results

retriever = ConversationRetriever()

- def build_rag_prompt(user_query, retrieved_chunks):
    """
    Builds a prompt for retrieval-augmented generation.
    """
    context_str = ""
    for i, (chunk, dist) in enumerate(retrieved_chunks):
@@ -269,10 +347,29 @@ def build_rag_prompt(user_query, retrieved_chunks):
    )
    return prompt

@spaces.GPU(duration=120)
- def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
    """
-     Chat with retrieval augmentation.
    """
    pipe = ensure_pipeline()
    retriever.add_text(f"User: {user_input}")
@@ -297,6 +394,7 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
    history.append([user_input, assistant_reply])
    return history, history

# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
@@ -351,4 +449,4 @@ with gr.Blocks() as demo:
        outputs=[chat_state, chatbot]
    )

- demo.launch()

@@ -1,44 +1,74 @@
import gradio as gr
+ import numpy as np
import torch
+
import faiss
+ import spaces

from datasets import load_dataset
+ from peft import LoraConfig, PeftModel, TaskType, get_peft_model, prepare_model_for_kbit_training
+ from sentence_transformers import SentenceTransformer
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

+ NUM_EXAMPLES_FOR_FINETUNING = 50  # Constant for the number of examples to use for finetuning
+ TEXT_PIPELINE = None  # Global to store the custom R1 text generation pipeline
+ COMPARISON_PIPELINE = None  # Global to store the official R1 text generation pipeline


+ def _load_model_and_tokenizer(model_name: str, subfolder: str = None, quantization_config: BitsAndBytesConfig = None, device_map: str = "auto", trust_remote_code: bool = True) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+     """
+     Helper function to load a causal language model and its tokenizer.
+
+     Args:
+         model_name (str): The name or path of the pretrained model.
+         subfolder (str, optional): Subfolder within the model repository. Defaults to None.
+         quantization_config (BitsAndBytesConfig, optional): Configuration for quantization. Defaults to None.
+         device_map (str, optional): Device mapping for model loading. Defaults to "auto".
+         trust_remote_code (bool, optional): Trust remote code for custom models. Defaults to True.
+
+     Returns:
+         tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer.
+     """
+     config = AutoConfig.from_pretrained(model_name, subfolder=subfolder, trust_remote_code=trust_remote_code)
+     tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder=subfolder, trust_remote_code=trust_remote_code)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         subfolder=subfolder,
+         config=config,
+         quantization_config=quantization_config,
+         device_map=device_map,
+         trust_remote_code=trust_remote_code
+     )
+     return model, tokenizer
+

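For context, a minimal sketch of how this helper is meant to be called for QLoRA-style loading. The full BitsAndBytesConfig used by the app is collapsed in this diff, so the bnb_4bit_compute_dtype value below is an assumption, not taken from the hidden lines:

# Sketch only; bnb_4bit_compute_dtype is assumed, the other fields appear elsewhere in the diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption
)
model, tokenizer = _load_model_and_tokenizer(
    "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto"
)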
@spaces.GPU(duration=300)
+ def finetune_small_subset() -> str:
    """
    Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
+
    Steps:
      1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
      2) Applies 4-bit quantization and prepares for QLoRA training.
      3) Fine-tunes on the dataset (mapping "problem" to prompt and "solution" to target).
      4) Saves the LoRA adapter to "finetuned_myr1".
      5) Reloads the adapter for inference.
+
+     Returns:
+         str: A message indicating finetuning completion.
    """
    # Specify the configuration ("v0" or "v1") explicitly.
    ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v0", split="train")
+     ds = ds.select(range(min(NUM_EXAMPLES_FOR_FINETUNING, len(ds))))

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
@@ -48,26 +78,8 @@ def finetune_small_subset():
    )

    # Load the custom model configuration from the repository.
+     base_model, tokenizer = _load_model_and_tokenizer(
+         "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto"
    )

    base_model = prepare_model_for_kbit_training(base_model)
@@ -100,8 +112,8 @@ def finetune_small_subset():
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
+         save_steps=999999,  # Save infrequently to avoid filling up disk during demo
+         save_total_limit=1,  # Keep only the last saved checkpoint
        fp16=False,
    )

@@ -116,13 +128,8 @@ def finetune_small_subset():
    trainer.model.save_pretrained("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

+     base_model_2, tokenizer_2 = _load_model_and_tokenizer(  # Re-load base model for inference adapter application
+         "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto"
    )
    base_model_2 = prepare_model_for_kbit_training(base_model_2)

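The LoRA wiring around these hunks (the LoraConfig, Trainer setup, and the adapter reload) is collapsed in this diff. A hedged sketch of the usual peft pattern the surrounding code implies; the hyperparameter values are assumptions, not the repository's actual settings:

# Sketch, assuming the standard QLoRA/peft pattern; the real values live in the collapsed lines.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,               # assumed rank
    lora_alpha=32,      # assumed scaling factor
    lora_dropout=0.05,  # assumed dropout
)
lora_model = get_peft_model(base_model, lora_config)  # adapter wrapped around the base model for training
# After training and save_pretrained("finetuned_myr1"), the saved adapter is re-attached for inference:
lora_model_2 = PeftModel.from_pretrained(base_model_2, "finetuned_myr1")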
@@ -132,13 +139,17 @@ def finetune_small_subset():
    )

    global TEXT_PIPELINE
+     TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer_2)  # Use tokenizer_2 here to be consistent

    return "Finetuning complete. Model loaded for inference."

+
+ def ensure_pipeline() -> pipeline:
    """
    Loads the base model (without LoRA) if no fine-tuned model is available.
+
+     Returns:
+         pipeline: The text generation pipeline using the custom R1 model.
    """
    global TEXT_PIPELINE
    if TEXT_PIPELINE is None:
@@ -148,22 +159,19 @@ def ensure_pipeline():
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
+         base_model, tokenizer = _load_model_and_tokenizer(
+             "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto"
        )
        TEXT_PIPELINE = pipeline("text-generation", model=base_model, tokenizer=tokenizer)
    return TEXT_PIPELINE

+
+ def ensure_comparison_pipeline() -> pipeline:
    """
    Loads the official R1 model pipeline if not already loaded.
+
+     Returns:
+         pipeline: The text generation pipeline using the official R1 model.
    """
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
@@ -177,10 +185,27 @@ def ensure_comparison_pipeline():
        COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return COMPARISON_PIPELINE

+
@spaces.GPU(duration=120)
+ def predict(
+     prompt: str,
+     temperature: float,
+     top_p: float,
+     min_new_tokens: int,
+     max_new_tokens: int
+ ) -> str:
    """
    Direct generation without retrieval using the custom R1 model.
+
+     Args:
+         prompt (str): The input prompt for text generation.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p sampling probability.
+         min_new_tokens (int): Minimum number of new tokens to generate.
+         max_new_tokens (int): Maximum number of new tokens to generate.
+
+     Returns:
+         str: The generated text output.
    """
    pipe = ensure_pipeline()
    out = pipe(
@@ -193,10 +218,27 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    )
    return out[0]["generated_text"]

+
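The keyword arguments passed to pipe() inside predict() are collapsed above. As a hedged illustration only, a transformers text-generation pipeline is typically driven with these sampling controls; the actual argument list in app.py may differ:

# Illustrative call; the real arguments used in predict() are in the collapsed lines.
pipe = ensure_pipeline()
out = pipe(
    "Explain QLoRA in one paragraph.",
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    min_new_tokens=10,
    max_new_tokens=200,
)
print(out[0]["generated_text"])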
@spaces.GPU(duration=120)
+ def compare_models(
+     prompt: str,
+     temperature: float,
+     top_p: float,
+     min_new_tokens: int,
+     max_new_tokens: int
+ ) -> tuple[str, str]:
    """
    Compare outputs between your custom R1 model and the official R1 model.
+
+     Args:
+         prompt (str): The input prompt for text generation.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p sampling probability.
+         min_new_tokens (int): Minimum number of new tokens to generate.
+         max_new_tokens (int): Maximum number of new tokens to generate.
+
+     Returns:
+         tuple[str, str]: A tuple containing the generated text from the custom R1 and official R1 models.
    """
    local_pipe = ensure_pipeline()
    comp_pipe = ensure_comparison_pipeline()
@@ -219,11 +261,22 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    )
    return local_out[0]["generated_text"], comp_out[0]["generated_text"]

+
class ConversationRetriever:
    """
    A FAISS-based retriever using SentenceTransformer for embedding.
+
+     This class indexes text chunks using FAISS and SentenceTransformer embeddings
+     to enable efficient similarity search for retrieval-augmented generation.
    """
+     def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", embed_dim: int = 384):
+         """
+         Initializes the ConversationRetriever.
+
+         Args:
+             model_name (str, optional): Name of the SentenceTransformer model. Defaults to "sentence-transformers/all-MiniLM-L6-v2".
+             embed_dim (int, optional): Dimensionality of the embeddings. Defaults to 384.
+         """
        self.embed_model = SentenceTransformer(model_name)
        self.embed_dim = embed_dim
        self.index = faiss.IndexFlatL2(embed_dim)
@@ -232,7 +285,13 @@ class ConversationRetriever:
        self.ids = []
        self.id_counter = 0

+     def add_text(self, text: str):
+         """
+         Adds text to the retriever's index.
+
+         Args:
+             text (str): The text to add.
+         """
        if not text.strip():
            return
        emb = self.embed_model.encode([text], convert_to_numpy=True)
@@ -243,7 +302,17 @@ class ConversationRetriever:
        self.ids.append(self.id_counter)
        self.id_counter += 1

+     def search(self, query: str, top_k: int = 3) -> list[tuple[str, float]]:
+         """
+         Searches the retriever index for texts similar to the query.
+
+         Args:
+             query (str): The query text.
+             top_k (int, optional): Number of top results to retrieve. Defaults to 3.
+
+         Returns:
+             list[tuple[str, float]]: A list of tuples, where each tuple contains (text, distance).
+         """
        q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
        q_vec = q_emb[0].reshape(1, -1)
        distances, indices = self.index.search(q_vec, top_k)
@@ -253,11 +322,20 @@ class ConversationRetriever:
            results.append((self.texts[idx], dist))
        return results

+
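A small usage sketch of the retriever defined above (the method names are from the diff; the strings are hypothetical). IndexFlatL2 returns raw L2 distances, so smaller values mean closer matches:

# Usage sketch for the class above; example texts are hypothetical.
r = ConversationRetriever()
r.add_text("User: How do I fine-tune the model with QLoRA?")
r.add_text("Assistant: Load the base model in 4-bit and train a LoRA adapter on top.")
for text, dist in r.search("QLoRA fine-tuning", top_k=2):
    print(f"{dist:.3f}  {text}")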
retriever = ConversationRetriever()

+
+ def build_rag_prompt(user_query: str, retrieved_chunks: list[tuple[str, float]]) -> str:
    """
    Builds a prompt for retrieval-augmented generation.
+
+     Args:
+         user_query (str): The user's input query.
+         retrieved_chunks (list[tuple[str, float]]): List of retrieved text chunks and their distances.
+
+     Returns:
+         str: The formatted prompt string.
    """
    context_str = ""
    for i, (chunk, dist) in enumerate(retrieved_chunks):
@@ -269,10 +347,29 @@ def build_rag_prompt(user_query, retrieved_chunks):
    )
    return prompt

+
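How the retrieval pieces compose at call time. The prompt template body itself is collapsed above, so only the call pattern implied by the visible code is shown; the query string is hypothetical:

# Call pattern implied by the code above; the exact prompt template is in the collapsed lines.
query = "Which dataset is used for finetuning?"
chunks = retriever.search(query, top_k=3)   # list of (text, L2 distance) pairs
prompt = build_rag_prompt(query, chunks)
reply = ensure_pipeline()(prompt, max_new_tokens=200)[0]["generated_text"]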
@spaces.GPU(duration=120)
+ def chat_rag(
+     user_input: str,
+     history: list[list[str]],
+     temperature: float,
+     top_p: float,
+     min_new_tokens: int,
+     max_new_tokens: int
+ ) -> tuple[list[list[str]], list[list[str]]]:
    """
+     Chat with retrieval augmentation using the custom R1 model.
+
+     Args:
+         user_input (str): The user's chat input.
+         history (list[list[str]]): The chat history.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p sampling probability.
+         min_new_tokens (int): Minimum number of new tokens to generate.
+         max_new_tokens (int): Maximum number of new tokens to generate.
+
+     Returns:
+         tuple[list[list[str]], list[list[str]]]: Updated chat history and chatbot display history.
    """
    pipe = ensure_pipeline()
    retriever.add_text(f"User: {user_input}")
@@ -297,6 +394,7 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
    history.append([user_input, assistant_reply])
    return history, history

+
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
@@ -351,4 +449,4 @@ with gr.Blocks() as demo:
        outputs=[chat_state, chatbot]
    )

+ demo.launch()
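The Blocks layout between the Markdown header and demo.launch() is collapsed in this diff. A hedged sketch of the kind of hookup the visible outputs=[chat_state, chatbot] line implies; the widget names and slider ranges below are hypothetical, not the app's actual layout:

# Hypothetical widgets; only the outputs=[chat_state, chatbot] wiring is visible in the diff.
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
    chatbot = gr.Chatbot()
    chat_state = gr.State([])
    user_box = gr.Textbox(label="Your message")
    temperature = gr.Slider(0.0, 1.5, value=0.7, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
    min_tokens = gr.Slider(1, 128, value=10, step=1, label="Min new tokens")
    max_tokens = gr.Slider(16, 1024, value=200, step=1, label="Max new tokens")
    send_btn = gr.Button("Send")
    send_btn.click(
        chat_rag,
        inputs=[user_box, chat_state, temperature, top_p, min_tokens, max_tokens],
        outputs=[chat_state, chatbot],
    )
demo.launch()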