Spaces:
wuhp
/
Running on Zero

wuhp committed on
Commit
b446d41
·
verified ·
1 Parent(s): 5755412

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -27
app.py CHANGED
@@ -1,52 +1,73 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
 
 
4
  from transformers import (
5
  AutoConfig,
6
  AutoTokenizer,
7
  AutoModelForCausalLM,
8
- pipeline
9
  )
10
 
# 1) GPU-dependent work is wrapped with the `spaces.GPU` decorator
@spaces.GPU(duration=60)  # default is 60s, can increase if needed
def load_pipeline():
    """Build a text-generation pipeline for the ``wuhp/myr1`` checkpoint.

    The config, tokenizer and causal-LM weights are all fetched from the
    ``myr1`` subfolder of the ``wuhp/myr1`` repo with ``trust_remote_code``
    enabled. The model is loaded in half precision with ``device_map="auto"``.

    Returns:
        The ``transformers`` ``"text-generation"`` pipeline wrapping the
        loaded model and tokenizer.
    """
    repo_id = "wuhp/myr1"
    # Every from_pretrained call below shares the same hub location options.
    hub_kwargs = {"subfolder": "myr1", "trust_remote_code": True}

    model_config = AutoConfig.from_pretrained(repo_id, **hub_kwargs)
    tok = AutoTokenizer.from_pretrained(repo_id, **hub_kwargs)
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        config=model_config,
        torch_dtype=torch.float16,  # half precision
        device_map="auto",
        **hub_kwargs,
    )

    # optional: load generation config if you have generation_config.json
    return pipeline("text-generation", model=lm, tokenizer=tok)


# Built once at import time and reused by every request
text_pipeline = load_pipeline()
 
 
 
 
 
35
 
def predict(prompt, max_new_tokens=64):
    """Generate a sampled continuation of *prompt* with the shared pipeline.

    Args:
        prompt: Text handed to the pipeline as the generation prefix.
        max_new_tokens: Generation budget; converted with ``int()`` before
            being passed to the pipeline.

    Returns:
        The ``"generated_text"`` field of the first pipeline result.
    """
    generations = text_pipeline(
        prompt,
        do_sample=True,
        temperature=0.7,
        max_new_tokens=int(max_new_tokens),
    )
    first_result = generations[0]
    return first_result["generated_text"]
 
 
 
 
 
 
41
 
# 2) Gradio front end: prompt and token budget in, generated text out
with gr.Blocks() as demo:
    gr.Markdown(value="## My LLM Inference (ZeroGPU)")
    prompt = gr.Textbox(label="Prompt")
    max_nt = gr.Slider(
        minimum=1,
        maximum=200,
        value=64,
        step=1,
        label="Max New Tokens",
    )
    output = gr.Textbox(label="Generated Text")

    btn = gr.Button(value="Generate")
    # A click sends both inputs to predict() and displays its return value.
    btn.click(predict, [prompt, max_nt], output)

demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from transformers import Trainer, TrainingArguments
5
+ from datasets import load_dataset
6
  from transformers import (
7
  AutoConfig,
8
  AutoTokenizer,
9
  AutoModelForCausalLM,
10
+ DataCollatorForLanguageModeling,
11
  )
12
 
@spaces.GPU(duration=600)  # 10 minutes
def run_finetuning():
    """Fine-tune ``wuhp/myr1`` on a 1000-row slice of a Magpie reasoning set.

    Steps: load the dataset and keep the first 1000 training rows, render each
    row as a ``User: ... / Assistant: ...`` text, tokenize it to at most 512
    tokens, train for one epoch with ``Trainer`` (batch size 1, fp16 mixed
    precision), then write the model and tokenizer to ``finetuned_model``.

    Returns:
        The status string ``"Finetuning done!"`` once training has finished
        and the model and tokenizer have been saved.
    """
    # Load dataset and keep only a small subset (1000 rows); the full set
    # would likely not finish inside the GPU window requested above.
    ds = load_dataset("Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B")
    ds_small = ds["train"].select(range(1000))

    # Flatten each instruction/response pair into a single training text.
    def format_row(ex):
        return {"text": f"User: {ex['instruction']}\nAssistant: {ex['response']}"}

    ds_small = ds_small.map(format_row)

    # Load config/tokenizer/model with trust_remote_code
    config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)

    # The language-modeling collator pads every batch, which raises when the
    # tokenizer defines no pad token; fall back to EOS if one is available.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        # Keep the trainable weights in fp32. `fp16=True` below enables mixed
        # precision with a gradient scaler, and that scaler refuses to unscale
        # gradients of fp16 parameters ("Attempting to unscale FP16 gradients").
        torch_dtype=torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )

    # Tokenize
    def tokenize(ex):
        return tokenizer(ex["text"], truncation=True, max_length=512)

    # Drop the raw string columns so only the tokenizer outputs reach the
    # collator once the dataset is switched to torch format.
    ds_small = ds_small.map(
        tokenize,
        batched=True,
        remove_columns=ds_small.column_names,
    )

    ds_small.set_format("torch")
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Trainer
    args = TrainingArguments(
        output_dir="finetuned_model",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        logging_steps=5,
        fp16=True,
        save_strategy="no",  # no intermediate checkpoints; saved once below
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds_small,
        data_collator=collator,
    )
    trainer.train()

    # Save
    trainer.save_model("finetuned_model")
    tokenizer.save_pretrained("finetuned_model")
    return "Finetuning done!"
 
 
66
 
# Gradio UI: a single button that triggers run_finetuning
with gr.Blocks() as demo:
    btn = gr.Button(value="Run Finetuning (10 min max!)")
    status = gr.Textbox(label="Status")
    # No input components; the function's return value goes to the status box.
    btn.click(run_finetuning, None, status)

demo.launch()