Spaces:
wuhp
/
Running on Zero

myr1-2 / app.py
wuhp's picture
Update app.py
b446d41 verified
raw
history blame
2.22 kB
import gradio as gr
import spaces
import torch
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
)
@spaces.GPU(duration=600)  # 10 minutes
def run_finetuning():
    """Fine-tune the ``wuhp/myr1`` causal LM on a 1000-row Magpie subset.

    Loads the first 1000 training rows of the Magpie reasoning dataset,
    renders each row as a ``User: ... / Assistant: ...`` prompt, tokenizes
    to at most 512 tokens, runs one epoch with ``Trainer`` and saves the
    model and tokenizer to the ``finetuned_model`` directory.

    Returns:
        str: ``"Finetuning done!"`` once training and saving have finished.
    """
    # Load dataset
    ds = load_dataset("Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B")
    # Keep only a small subset (1000 rows); the full set would likely exceed
    # the time budget requested in the decorator above.
    ds_small = ds["train"].select(range(1000))

    def format_row(ex):
        """Render one instruction/response pair as a single training text."""
        return {"text": f"User: {ex['instruction']}\nAssistant: {ex['response']}"}

    ds_small = ds_small.map(format_row)

    # Load config/tokenizer/model with trust_remote_code
    config = AutoConfig.from_pretrained(
        "wuhp/myr1", subfolder="myr1", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "wuhp/myr1", subfolder="myr1", trust_remote_code=True
    )
    # The LM collator pads every batch and raises when the tokenizer has no
    # padding token, so fall back to EOS if none is configured.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        # Keep master weights in float32: fp16=True below enables AMP with a
        # gradient scaler, which refuses to unscale gradients of fp16
        # parameters ("Attempting to unscale FP16 gradients").
        torch_dtype=torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )

    def tokenize(ex):
        """Tokenize a batch of formatted texts, truncating to 512 tokens."""
        return tokenizer(ex["text"], truncation=True, max_length=512)

    # Drop the raw string columns so only the tokenizer outputs remain in the
    # dataset handed to the collator.
    ds_small = ds_small.map(
        tokenize,
        batched=True,
        remove_columns=ds_small.column_names,
    )
    ds_small.set_format("torch")

    # mlm=False selects the causal-LM objective for the collator.
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Trainer
    args = TrainingArguments(
        output_dir="finetuned_model",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        logging_steps=5,
        fp16=True,
        save_strategy="no",  # no intermediate checkpoints; saved once below
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds_small,
        data_collator=collator,
    )
    trainer.train()

    # Save
    trainer.save_model("finetuned_model")
    tokenizer.save_pretrained("finetuned_model")
    return "Finetuning done!"
# Gradio front end: one button that triggers run_finetuning, with the
# function's return value shown in a status textbox.
with gr.Blocks() as demo:
    start_button = gr.Button("Run Finetuning (10 min max!)")
    status_box = gr.Textbox(label="Status")
    start_button.click(run_finetuning, inputs=None, outputs=status_box)

demo.launch()