wuhp committed
Commit 13b1681 · verified · 1 Parent(s): b7d0639

Update app.py

Files changed (1): app.py (+108 -55)
app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
 import spaces
-from datasets import load_dataset
 import torch
+
+from datasets import load_dataset
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -9,131 +10,183 @@ from transformers import (
     DataCollatorForLanguageModeling,
     Trainer,
     TrainingArguments,
-    pipeline
+    pipeline,
+    BitsAndBytesConfig,  # for 4-bit config
 )
 
+# PEFT (LoRA / QLoRA)
+from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+
+
 ##############################################################################
-# GLOBALS / ZERO-GPU APPROACH
+# ZeroGPU + QLoRA Example
 ##############################################################################
-# We store a global pipeline after finetuning (if any).
 TEXT_PIPELINE = None
+NUM_EXAMPLES = 50  # We'll train on 50 lines of WikiText-2 for demonstration
 
-# We'll train on only 50 examples from WikiText-2 to keep it short.
-NUM_EXAMPLES = 50
-
-
-@spaces.GPU(duration=600)  # up to 600 seconds (10 minutes) for mini-finetraining
+@spaces.GPU(duration=600)  # up to 10 min
 def finetune_small_subset():
     """
-    1) Loads 'wuhp/myr1' in 8-bit,
-    2) Takes 50 examples from WikiText-2,
-    3) Finetunes for 1 epoch,
-    4) Saves to 'finetuned_myr1/',
-    5) Reloads the new model into a pipeline for inference.
+    1) Loads 'wuhp/myr1' in 4-bit quantization (QLoRA style),
+    2) Adds LoRA adapters (trainable),
+    3) Trains on 50 lines of WikiText-2,
+    4) Saves LoRA adapter to 'finetuned_myr1',
+    5) Reloads LoRA adapters for inference in a pipeline.
     """
 
-    # 1) Load dataset
+    # --- 1) Load WikiText-2 subset ---
     ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-    # Keep only 50 to fit ephemeral GPU time
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
-    # 2) Load config, tokenizer, model
+    # We'll define tokenize_fn after we have the tokenizer
+
+    # --- 2) Setup 4-bit quantization with BitsAndBytes ---
+    # This is QLoRA approach: we load the base model in 4-bit
+    # and attach LoRA adapters for training.
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if preferred
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",  # "nf4" is standard for QLoRA
+    )
+
     config = AutoConfig.from_pretrained(
-        "wuhp/myr1",
+        "wuhp/myr1",
         subfolder="myr1",
         trust_remote_code=True
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        "wuhp/myr1",
+        "wuhp/myr1",
         subfolder="myr1",
         trust_remote_code=True
     )
-    # 8-bit loading via bitsandbytes
-    model = AutoModelForCausalLM.from_pretrained(
+
+    # Load model in 4-bit
+    base_model = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
         config=config,
-        load_in_8bit=True,  # <--- 8-bit
-        device_map="auto",  # let HF manage device placement
+        quantization_config=bnb_config,  # <--- QLoRA 4-bit
+        device_map="auto",
         trust_remote_code=True
     )
 
-    # 3) Tokenize
+    # Prepare the model for k-bit training (QLoRA)
+    # This step disables dropout on some layers, sets up gradients for LN, etc.
+    base_model = prepare_model_for_kbit_training(base_model)
+
+    # --- 3) Create LoRA config & wrap the base model in LoRA adapter ---
+    # For LLaMA-like models, "q_proj" and "v_proj" are typical. If your model is different,
+    # adjust target_modules accordingly (maybe "c_attn", "W_pack", "query_key_value", etc.)
+    lora_config = LoraConfig(
+        r=16,
+        lora_alpha=32,
+        lora_dropout=0.05,
+        bias="none",
+        target_modules=["q_proj", "v_proj"],  # Adjust if your model uses different layer names
+        task_type=TaskType.CAUSAL_LM,
+    )
+    lora_model = get_peft_model(base_model, lora_config)
+
+    # --- 4) Tokenize dataset ---
     def tokenize_fn(ex):
         return tokenizer(ex["text"], truncation=True, max_length=512)
 
     ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
     ds.set_format("torch")
 
+    # Data collator
     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-    # 4) TrainingArguments: no fp16 to avoid half-precision gradient issues
+    # Training args
     training_args = TrainingArguments(
         output_dir="finetuned_myr1",
         num_train_epochs=1,
         per_device_train_batch_size=1,
         gradient_accumulation_steps=2,
-        logging_steps=10,
-        save_steps=999999,  # skip mid-training saves
+        logging_steps=5,
+        save_steps=999999,
         save_total_limit=1,
-        fp16=False,  # <--- disable FP16
+        fp16=False,  # We'll rely on bnb_4bit/bfloat16 for the base model
     )
 
-    # 5) Trainer
+    # Trainer
     trainer = Trainer(
-        model=model,
+        model=lora_model,
        args=training_args,
         train_dataset=ds,
         data_collator=collator,
     )
 
-    # 6) Train
+    # --- 5) Train ---
     trainer.train()
 
-    # 7) Save final model
-    trainer.save_model("finetuned_myr1")
+    # Save LoRA adapter + tokenizer
+    # The 'save_model' would save only the LoRA adapter if using PEFT
+    trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    # 8) Reload the newly finetuned model as a pipeline (for inference)
-    finetuned_model = AutoModelForCausalLM.from_pretrained(
-        "finetuned_myr1",
+    # --- 6) Reload the base model in 4-bit, then merge or apply the LoRA adapter for inference
+    # We'll do the same approach, then load adapter from 'finetuned_myr1'
+    base_model_2 = AutoModelForCausalLM.from_pretrained(
+        "wuhp/myr1",
+        subfolder="myr1",
+        config=config,
+        quantization_config=bnb_config,
         device_map="auto",
         trust_remote_code=True
     )
+    base_model_2 = prepare_model_for_kbit_training(base_model_2)
+
+    # Re-inject LoRA
+    # If your LoRA was saved in the same folder, you can do:
+    #   from peft import PeftModel
+    #   lora_model_2 = PeftModel.from_pretrained(base_model_2, "finetuned_myr1")
+    # or you can do get_peft_model and pass the weights, etc.
 
+    # But we can reuse 'get_peft_model' + load the LoRA weights
+    lora_model_2 = get_peft_model(base_model_2, lora_config)
+    lora_model_2.load_adapter("finetuned_myr1")
+
+    # Create pipeline
     global TEXT_PIPELINE
-    TEXT_PIPELINE = pipeline("text-generation", model=finetuned_model, tokenizer=tokenizer)
-    return "Finetuning complete! Model reloaded for inference."
+    TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)
+
+    return "Finetuning complete (QLoRA + LoRA). Model loaded for inference."
 
 
 def ensure_pipeline():
     """
-    If no pipeline yet, load the original model from wuhp/myr1 for inference.
-    (In 8-bit or normal float? We can do normal float here for a simpler approach.)
+    If we haven't finetuned yet (TEXT_PIPELINE is None),
+    load the base model in 4-bit with NO LoRA.
     """
     global TEXT_PIPELINE
     if TEXT_PIPELINE is None:
-        tokenizer = AutoTokenizer.from_pretrained(
-            "wuhp/myr1",
-            subfolder="myr1",
-            trust_remote_code=True
+        # Just load base model in 4-bit
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
        )
-        model = AutoModelForCausalLM.from_pretrained(
+        config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
+        base_model = AutoModelForCausalLM.from_pretrained(
             "wuhp/myr1",
             subfolder="myr1",
-            trust_remote_code=True,
-            load_in_8bit=True,  # load in 8-bit also for inference
-            device_map="auto"
+            config=config,
+            quantization_config=bnb_config,
+            device_map="auto",
+            trust_remote_code=True
         )
-        TEXT_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        TEXT_PIPELINE = pipeline("text-generation", model=base_model, tokenizer=tokenizer)
     return TEXT_PIPELINE
 
 
-@spaces.GPU(duration=120)  # up to 120s for text generation
+@spaces.GPU(duration=120)  # up to 2 min for text generation
 def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Generates text from either the finetuned pipeline (if it exists) or the base model.
-    Allows user to adjust temperature, top_p, min/max tokens.
+    Generates text from the finetuned (LoRA) model if present, else the base model.
     """
     pipe = ensure_pipeline()
     out = pipe(
@@ -149,13 +202,13 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
 
 # Build Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## ZeroGPU: Mini-Finetune with 8-bit + Extended Generation")
+    gr.Markdown("## ZeroGPU QLoRA Example for wuhp/myr1")
 
-    finetune_btn = gr.Button("Finetune on 50 lines of WikiText-2 (up to 10 min)")
+    finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on 50 lines of WikiText-2 (up to 10 min)")
     status_box = gr.Textbox(label="Finetune Status")
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    gr.Markdown("After finetuning, or even without it, generate text below:")
+    gr.Markdown("Then generate text below (or skip finetuning to see base model).")
 
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
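
A few notes on the new QLoRA code, with small sketches that are not part of the commit.

The LoRA config targets ["q_proj", "v_proj"], and the in-code comment warns that other architectures use different projection names. A quick way to check is to list the Linear module names of the base model and then confirm what the adapter actually trains. The sketch below uses facebook/opt-125m as a stand-in, since wuhp/myr1's layer names are not shown in this diff; on the Space you would load "wuhp/myr1" with subfolder="myr1" and trust_remote_code=True instead.

import torch.nn as nn
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Stand-in checkpoint; swap in the real repo id when running in the Space.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Distinct names of Linear submodules; attention projections
# (q_proj/v_proj, query_key_value, ...) are the usual LoRA targets.
linear_names = sorted({name.split(".")[-1]
                       for name, module in model.named_modules()
                       if isinstance(module, nn.Linear)})
print(linear_names)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
    target_modules=["q_proj", "v_proj"], task_type=TaskType.CAUSAL_LM,
)
# Prints something like "trainable params: ... || all params: ... || trainable%: ..."
get_peft_model(model, lora_config).print_trainable_parameters()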
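
The reload step in finetune_small_subset() uses get_peft_model plus load_adapter, and its own comments point out the PeftModel.from_pretrained alternative. A minimal sketch of that variant, reusing the same repo id, subfolder, quantization config, and adapter directory as the app:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
base = AutoModelForCausalLM.from_pretrained(
    "wuhp/myr1", subfolder="myr1",
    quantization_config=bnb_config, device_map="auto", trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)

# Attaches the LoRA weights saved by trainer.model.save_pretrained("finetuned_myr1")
model = PeftModel.from_pretrained(base, "finetuned_myr1")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)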
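
The body of predict() is cut off by the last hunk, so the exact generation call is not visible here. A hypothetical call wiring the four sliders into the pipeline might look like the following; the keyword arguments are standard transformers generation parameters, not taken from the diff:

# Hypothetical usage sketch only; the real predict() body is truncated in this view.
out = pipe(
    "Hello from ZeroGPU",
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    min_new_tokens=10,
    max_new_tokens=100,
)
print(out[0]["generated_text"])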