wuhp committed
Commit c8df7a5 · verified · 1 parent: 4df6952

Update app.py

Files changed (1):
  1. app.py +16 -30
app.py CHANGED
@@ -20,10 +20,6 @@ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
 from sentence_transformers import SentenceTransformer
 
-# Import your custom configuration overrides.
-# For example, your configuration_deepseek.py might export a dictionary called CONFIG_OVERRIDES.
-import configuration_deepseek
-
 # Global variables for pipelines and settings.
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
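
Note: the commit removes configuration_deepseek and its CONFIG_OVERRIDES dict entirely. If a local override is ever needed again, the same effect can be had inline on the loaded config. A minimal sketch (the override key below is a hypothetical example, not something this commit sets):

    overrides = {"use_cache": False}  # hypothetical override, not part of this commit
    for key, value in overrides.items():
        setattr(base_config, key, value)
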
@@ -32,13 +28,14 @@ NUM_EXAMPLES = 1000
 @spaces.GPU(duration=300)
 def finetune_small_subset():
     """
-    1) Loads your custom model ("wuhp/myr1") in 4-bit quantization (QLoRA style),
-    2) Adds LoRA adapters (trainable),
-    3) Fine-tunes on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset,
-    4) Saves the LoRA adapter to "finetuned_myr1",
-    5) Reloads the LoRA adapter for inference.
+    Fine-tunes the custom DeepSeekV3 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
+    Steps:
+      1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
+      2) Applies 4-bit quantization and prepares for QLoRA training.
+      3) Fine-tunes on the dataset (mapping "problem" to prompt and "solution" to target).
+      4) Saves the LoRA adapter to "finetuned_myr1".
+      5) Reloads the adapter for inference.
     """
-    # Load the new dataset.
     ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", split="train")
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
@@ -49,15 +46,13 @@ def finetune_small_subset():
         bnb_4bit_quant_type="nf4",
     )
 
-    # Load the base configuration from your model repository.
+    # Load the custom model configuration from the repository.
     base_config = AutoConfig.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
         trust_remote_code=True,
     )
-    # Apply your custom overrides (from configuration_deepseek.py).
-    for key, value in configuration_deepseek.CONFIG_OVERRIDES.items():
-        setattr(base_config, key, value)
+    # (Optionally apply local overrides here if needed.)
 
     tokenizer = AutoTokenizer.from_pretrained(
         "wuhp/myr1",
@@ -65,8 +60,6 @@
         trust_remote_code=True
     )
 
-    # Load the model. With trust_remote_code=True, your custom model class (e.g. DeepseekV3ForCausalLM)
-    # will be loaded from the repository.
     base_model = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
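
Note: steps 1 and 2 of the new docstring (4-bit load, QLoRA preparation) come together roughly as below. This is a sketch, not the literal file contents: the LoRA hyperparameters and the bfloat16 compute dtype are assumptions, since the hunks shown here do not include them.

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: dtype not visible in this diff
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1", subfolder="myr1", trust_remote_code=True,
        quantization_config=bnb_config, device_map="auto",
    )
    base_model = prepare_model_for_kbit_training(base_model)  # stabilizes k-bit training
    lora_config = LoraConfig(  # assumption: illustrative hyperparameters, not from the diff
        task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32, lora_dropout=0.05,
    )
    lora_model = get_peft_model(base_model, lora_config)  # only the LoRA weights train
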
@@ -88,7 +81,6 @@
     )
     lora_model = get_peft_model(base_model, lora_config)
 
-    # For this dataset, assume "problem" is the prompt and "solution" is the target.
     def tokenize_fn(ex):
         text = (
             f"Problem: {ex['problem']}\n\n"
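
Note: only the first template line is visible in this hunk. The mapping of "problem" to prompt and "solution" to target plausibly completes as below; the "Solution:" suffix, the max_length, and the labels-mirror-inputs choice are assumptions.

    def tokenize_fn(ex):
        text = (
            f"Problem: {ex['problem']}\n\n"
            f"Solution: {ex['solution']}"   # assumption: target appended after the prompt
        )
        out = tokenizer(text, truncation=True, max_length=512)
        out["labels"] = out["input_ids"].copy()  # causal LM: labels mirror inputs
        return out

    ds = ds.map(tokenize_fn, remove_columns=ds.column_names)
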
@@ -107,9 +99,9 @@
         per_device_train_batch_size=1,
         gradient_accumulation_steps=2,
         logging_steps=5,
-        save_steps=999999,  # High save interval
+        save_steps=999999,
         save_total_limit=1,
-        fp16=False,  # Set to True if supported by your hardware
+        fp16=False,
     )
 
     trainer = Trainer(
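
Note: the deleted inline comment ("Set to True if supported by your hardware") still applies to fp16. A hedged variant that enables mixed precision only when the GPU supports it, using bf16 rather than fp16 (a sketch; output_dir is not visible in this hunk and is assumed):

    import torch
    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="finetuned_myr1",   # assumption: not shown in the hunk
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
        save_steps=999999,             # effectively disables mid-run checkpoints
        save_total_limit=1,
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    )
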
@@ -120,11 +112,9 @@
     )
     trainer.train()
 
-    # Save the LoRA adapter and tokenizer.
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    # Reload the base model and attach the LoRA adapter for inference.
     base_model_2 = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
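
Note: step 5 of the docstring (reloading the adapter for inference) typically goes through peft's PeftModel once the base model is re-loaded; a minimal sketch consistent with the surrounding code (keyword arguments beyond those visible are assumptions):

    from peft import PeftModel

    base_model_2 = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1", subfolder="myr1", trust_remote_code=True,
        quantization_config=bnb_config, device_map="auto",  # assumption
    )
    lora_model_2 = PeftModel.from_pretrained(base_model_2, "finetuned_myr1")
    lora_model_2.eval()  # inference only; adapter weights stay frozen
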
@@ -147,8 +137,7 @@
 
 def ensure_pipeline():
     """
-    If we haven't fine-tuned yet (i.e. TEXT_PIPELINE is None),
-    load the base model (without LoRA) in 4-bit mode.
+    Loads the base model (without LoRA) if no fine-tuned model is available.
     """
     global TEXT_PIPELINE
     if TEXT_PIPELINE is None:
@@ -159,8 +148,6 @@
             bnb_4bit_quant_type="nf4",
         )
         base_config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
-        for key, value in configuration_deepseek.CONFIG_OVERRIDES.items():
-            setattr(base_config, key, value)
         tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
         base_model = AutoModelForCausalLM.from_pretrained(
             "wuhp/myr1",
@@ -175,7 +162,7 @@
 
 def ensure_comparison_pipeline():
     """
-    Load a reference DeepSeek model pipeline if not already loaded.
+    Loads a reference DeepSeek model pipeline if not already loaded.
     """
     global COMPARISON_PIPELINE
     if COMPARISON_PIPELINE is None:
@@ -233,8 +220,7 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
 
 class ConversationRetriever:
     """
-    A simple in-memory FAISS-based retriever.
-    Each text chunk is embedded using SentenceTransformer.
+    A FAISS-based retriever using SentenceTransformer for embedding.
     """
     def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", embed_dim=384):
         self.embed_model = SentenceTransformer(model_name)
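
Note: the hunk shows only the constructor. Insertion and lookup for such a retriever would look roughly like this; the method names, the IndexFlatL2 choice, and the texts list are assumptions beyond what the diff shows.

    import faiss
    import numpy as np

    # assumption: __init__ also sets self.index = faiss.IndexFlatL2(embed_dim)
    # and self.texts = []

    def add_text(self, text):
        vec = self.embed_model.encode([text]).astype(np.float32)
        self.index.add(vec)   # FAISS expects float32 arrays of shape (n, embed_dim)
        self.texts.append(text)

    def search(self, query, top_k=3):
        vec = self.embed_model.encode([query]).astype(np.float32)
        dists, ids = self.index.search(vec, top_k)
        return [(self.texts[i], d) for i, d in zip(ids[0], dists[0]) if i != -1]
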
@@ -270,7 +256,7 @@ retriever = ConversationRetriever()
 
 def build_rag_prompt(user_query, retrieved_chunks):
     """
-    Build a prompt for retrieval-augmented generation.
+    Builds a prompt for retrieval-augmented generation.
     """
     context_str = ""
     for i, (chunk, dist) in enumerate(retrieved_chunks):
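
Note: a plausible completion of the prompt builder; everything past the visible loop header is an assumption about the template.

    def build_rag_prompt(user_query, retrieved_chunks):
        """Builds a prompt for retrieval-augmented generation."""
        context_str = ""
        for i, (chunk, dist) in enumerate(retrieved_chunks):
            context_str += f"[Context {i+1} | distance {dist:.2f}]\n{chunk}\n\n"
        return f"{context_str}User: {user_query}\nAssistant:"
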
@@ -285,7 +271,7 @@ def build_rag_prompt(user_query, retrieved_chunks):
 @spaces.GPU(duration=120)
 def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Chat function with retrieval augmentation.
+    Chat with retrieval augmentation.
     """
     pipe = ensure_pipeline()
     retriever.add_text(f"User: {user_input}")
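
Note: end to end, a retrieval-augmented turn adds the user message to the index, retrieves neighbors, builds the prompt, generates, and indexes the reply. A sketch under the same assumptions as above (retriever.search and the pipeline call shape are not shown in this diff):

    def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
        pipe = ensure_pipeline()
        retriever.add_text(f"User: {user_input}")
        top_chunks = retriever.search(user_input, top_k=3)  # assumption: see retriever sketch
        prompt = build_rag_prompt(user_input, top_chunks)
        reply = pipe(
            prompt, do_sample=True, temperature=temperature, top_p=top_p,
            min_new_tokens=min_new_tokens, max_new_tokens=max_new_tokens,
        )[0]["generated_text"]
        retriever.add_text(f"Assistant: {reply}")
        history.append((user_input, reply))
        return history
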
 