import gradio as gr
import spaces
from datasets import load_dataset
import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline
)

##############################################################################
# GLOBALS / ZERO-GPU APPROACH
##############################################################################

# We store a global pipeline after finetuning (if any).
TEXT_PIPELINE = None

# We'll train on only 50 examples from WikiText-2 to keep it short.
NUM_EXAMPLES = 50


@spaces.GPU(duration=600)  # up to 600 seconds (10 minutes) for mini-finetuning
def finetune_small_subset():
    """
    1) Loads 'wuhp/myr1' in 8-bit,
    2) Takes 50 examples from WikiText-2,
    3) Finetunes for 1 epoch,
    4) Saves to 'finetuned_myr1/',
    5) Reloads the new model into a pipeline for inference.
    """
    # 1) Load dataset
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # Keep only 50 examples to fit the ephemeral GPU time
    ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    # 2) Load config, tokenizer, model
    config = AutoConfig.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True
    )
    # 8-bit loading via bitsandbytes
    model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        load_in_8bit=True,   # <--- 8-bit
        device_map="auto",   # let HF manage device placement
        trust_remote_code=True
    )

    # 3) Tokenize
    def tokenize_fn(ex):
        return tokenizer(ex["text"], truncation=True, max_length=512)

    ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
    ds.set_format("torch")

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # 4) TrainingArguments: no fp16 to avoid half-precision gradient issues
    training_args = TrainingArguments(
        output_dir="finetuned_myr1",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=10,
        save_steps=999999,   # skip mid-training saves
        save_total_limit=1,
        fp16=False,          # <--- disable FP16
    )

    # 5) Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        data_collator=collator,
    )

    # 6) Train
    trainer.train()

    # 7) Save final model
    trainer.save_model("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

    # 8) Reload the newly finetuned model as a pipeline (for inference)
    finetuned_model = AutoModelForCausalLM.from_pretrained(
        "finetuned_myr1",
        device_map="auto",
        trust_remote_code=True
    )
    global TEXT_PIPELINE
    TEXT_PIPELINE = pipeline("text-generation", model=finetuned_model, tokenizer=tokenizer)

    return "Finetuning complete! Model reloaded for inference."


def ensure_pipeline():
    """
    If no pipeline exists yet, load the original model from wuhp/myr1 for inference
    (also in 8-bit, matching the finetuning path).
    """
    global TEXT_PIPELINE
    if TEXT_PIPELINE is None:
        tokenizer = AutoTokenizer.from_pretrained(
            "wuhp/myr1",
            subfolder="myr1",
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            "wuhp/myr1",
            subfolder="myr1",
            trust_remote_code=True,
            load_in_8bit=True,   # load in 8-bit for inference as well
            device_map="auto"
        )
        TEXT_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return TEXT_PIPELINE


@spaces.GPU(duration=120)  # up to 120 seconds for text generation
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Generates text from either the finetuned pipeline (if it exists) or the base model.
    Allows user to adjust temperature, top_p, min/max tokens.
    """
    pipe = ensure_pipeline()
    out = pipe(
        prompt,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        do_sample=True
    )
    return out[0]["generated_text"]


# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## ZeroGPU: Mini-Finetune with 8-bit + Extended Generation")

    finetune_btn = gr.Button("Finetune on 50 lines of WikiText-2 (up to 10 min)")
    status_box = gr.Textbox(label="Finetune Status")
    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)

    gr.Markdown("After finetuning, or even without it, generate text below:")

    prompt_in = gr.Textbox(lines=3, label="Prompt")
    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
    min_tokens = gr.Slider(260, 5000, value=260, step=10, label="Min New Tokens")
    max_tokens = gr.Slider(260, 5000, value=500, step=50, label="Max New Tokens")

    output_box = gr.Textbox(label="Generated Text", lines=12)
    gen_btn = gr.Button("Generate")

    gen_btn.click(
        fn=predict,
        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=output_box
    )

demo.launch()
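
# NOTE: a minimal sketch (kept as a comment so nothing runs after the blocking
# `demo.launch()` call) of the same 8-bit load expressed via `BitsAndBytesConfig`.
# Recent transformers releases prefer passing `quantization_config` instead of the
# bare `load_in_8bit=True` argument; this is an optional alternative, not something
# the app above requires:
#
#   from transformers import BitsAndBytesConfig
#
#   bnb_config = BitsAndBytesConfig(load_in_8bit=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       "wuhp/myr1",
#       subfolder="myr1",
#       quantization_config=bnb_config,
#       device_map="auto",
#       trust_remote_code=True,
#   )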