
wuhp committed · Commit b26485f · verified · 1 Parent(s): b446d41

Update app.py

Files changed (1):
  1. app.py +48 -57
app.py CHANGED
@@ -1,73 +1,64 @@
  import gradio as gr
  import spaces
- import torch
- from transformers import Trainer, TrainingArguments
- from datasets import load_dataset
- from transformers import (
-     AutoConfig,
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     DataCollatorForLanguageModeling,
- )

- @spaces.GPU(duration=600)  # 10 minutes
- def run_finetuning():
-     # Load dataset
-     ds = load_dataset("Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B")
-     # maybe select a small subset (like 1000 rows) or you'll likely time out
-     ds_small = ds["train"].select(range(1000))

-     # Format example:
-     def format_row(ex):
-         return {"text": f"User: {ex['instruction']}\nAssistant: {ex['response']}"}
-     ds_small = ds_small.map(format_row)
-
-     # Load config/tokenizer/model with trust_remote_code
-     config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
-     tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True)
      model = AutoModelForCausalLM.from_pretrained(
          "wuhp/myr1",
          subfolder="myr1",
          config=config,
-         torch_dtype=torch.float16,
-         device_map="auto",
          trust_remote_code=True
      )

-     # Tokenize
-     def tokenize(ex):
-         return tokenizer(ex["text"], truncation=True, max_length=512)
-     ds_small = ds_small.map(tokenize, batched=True)

-     ds_small.set_format("torch")
-     collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

-     # Trainer
-     args = TrainingArguments(
-         output_dir="finetuned_model",
-         num_train_epochs=1,
-         per_device_train_batch_size=1,
-         logging_steps=5,
-         fp16=True,
-         save_strategy="no",
-     )
-     trainer = Trainer(
-         model=model,
-         args=args,
-         train_dataset=ds_small,
-         data_collator=collator,
-     )
-     trainer.train()
-
-     # Save
-     trainer.save_model("finetuned_model")
-     tokenizer.save_pretrained("finetuned_model")
-     return "Finetuning done!"
-
- # Then define a Gradio UI that calls run_finetuning
  with gr.Blocks() as demo:
-     btn = gr.Button("Run Finetuning (10 min max!)")
-     status = gr.Textbox(label="Status")
-     btn.click(fn=run_finetuning, inputs=None, outputs=status)

  demo.launch()
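The only artifact the removed version produced was the checkpoint written to finetuned_model/ by trainer.save_model and tokenizer.save_pretrained. As an illustration only (not part of this commit), a minimal sketch of reloading that directory for generation; it assumes the directory exists and that the custom wuhp/myr1 architecture is still resolvable with trust_remote_code=True:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hypothetical reload of the directory written by the removed run_finetuning()
tokenizer = AutoTokenizer.from_pretrained("finetuned_model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("finetuned_model", trust_remote_code=True)

# Prompt format mirrors the format_row() template used during fine-tuning
generate = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generate("User: Hello\nAssistant:", max_new_tokens=64)[0]["generated_text"])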
 
  import gradio as gr
  import spaces
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline

+ text_pipeline = None  # global var to hold our pipeline once loaded

+ @spaces.GPU(duration=120)  # request up to 120s GPU time to load the model
+ def load_model():
+     """
+     This function will run in a *child* process that has GPU allocated.
+     We can safely do device_map="auto" or .to("cuda") here.
+     """
+     config = AutoConfig.from_pretrained(
+         "wuhp/myr1",
+         subfolder="myr1",
+         trust_remote_code=True
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         "wuhp/myr1",
+         subfolder="myr1",
+         trust_remote_code=True
+     )
      model = AutoModelForCausalLM.from_pretrained(
          "wuhp/myr1",
          subfolder="myr1",
          config=config,
+         torch_dtype="auto",   # triggers GPU usage
+         device_map="auto",    # triggers GPU usage
          trust_remote_code=True
      )
+     text_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+     return text_pipe

+ def ensure_pipeline():
+     """
+     If we've never loaded the pipeline, call load_model() now.
+     If ZeroGPU has deallocated it, we might need to reload again.
+     """
+     global text_pipeline
+     if text_pipeline is None:
+         text_pipeline = load_model()  # <-- calls the GPU-wrapped function
+     return text_pipeline

+ @spaces.GPU(duration=60)  # up to 60s for each generate call
+ def predict(prompt, max_new_tokens=64):
+     """
+     Called when the user clicks 'Generate'; ensures the model is loaded,
+     then runs inference on GPU.
+     """
+     pipe = ensure_pipeline()
+     outputs = pipe(prompt, max_new_tokens=int(max_new_tokens))
+     return outputs[0]["generated_text"]

+ # Build the Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# ZeroGPU Inference Demo")
+     prompt = gr.Textbox(label="Prompt")
+     max_tok = gr.Slider(1, 256, value=64, step=1, label="Max New Tokens")
+     output = gr.Textbox(label="Generated Text")
+
+     generate_btn = gr.Button("Generate")
+     generate_btn.click(fn=predict, inputs=[prompt, max_tok], outputs=output)

  demo.launch()
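Once this new app.py is deployed, the Generate button is also reachable over the Gradio API, so the Space can be queried programmatically. A minimal sketch using gradio_client; the Space ID below is a placeholder (the actual Space name is not shown in this page), and the endpoint name is assumed to default to the wrapped function's name, /predict:

from gradio_client import Client

# Hypothetical Space ID; replace with the real "owner/space-name"
client = Client("wuhp/your-space-name")

# Arguments mirror the Gradio inputs: prompt text and max new tokens
result = client.predict("Hello, world", 64, api_name="/predict")
print(result)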