import gradio as gr
import spaces
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
    BitsAndBytesConfig,
)

# PEFT (LoRA / QLoRA)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel

##############################################################################
# ZeroGPU + QLoRA Example
##############################################################################

TEXT_PIPELINE = None        # Pipeline for wuhp/myr1 (fine-tuned or base)
COMPARISON_PIPELINE = None  # Pipeline for the DeepSeek model

NUM_EXAMPLES = 50  # We'll train on 50 rows for demonstration


@spaces.GPU(duration=300)  # up to 5 min
def finetune_small_subset():
    """
    1) Loads 'wuhp/myr1' in 4-bit quantization (QLoRA style),
    2) Adds LoRA adapters (trainable),
    3) Trains on a small subset of the Magpie dataset,
    4) Saves the LoRA adapter to 'finetuned_myr1',
    5) Reloads the LoRA adapter for inference in a pipeline.
    """
    # --- 1) Load a small subset of the Magpie dataset ---
    ds = load_dataset(
        "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B",
        split="train"
    )

    # For demonstration, pick a single conversation_id
    unique_ids = list(set(ds["conversation_id"]))
    single_id = unique_ids[0]
    ds = ds.filter(lambda x: x["conversation_id"] == single_id)

    # Then select only NUM_EXAMPLES from that subset
    ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    # --- 2) Set up 4-bit quantization with BitsAndBytes ---
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    config = AutoConfig.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        quantization_config=bnb_config,  # <--- QLoRA 4-bit
        device_map="auto",
        trust_remote_code=True
    )

    # Prepare the quantized model for k-bit training
    base_model = prepare_model_for_kbit_training(base_model)

    # --- 3) Create LoRA config & wrap the base model in LoRA ---
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    lora_model = get_peft_model(base_model, lora_config)
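
    # Optional sanity check: PEFT can report how many parameters are actually
    # trainable. With r=16 LoRA applied only to q_proj/v_proj, this should be a
    # small fraction of the full model, while the 4-bit base weights stay frozen.
    lora_model.print_trainable_parameters()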
""" text = ( f"Instruction: {ex['instruction']}\n\n" f"Response: {ex['response']}" ) return tokenizer(text, truncation=True, max_length=512) ds = ds.map(tokenize_fn, batched=False, remove_columns=ds.column_names) ds.set_format("torch") collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) # Training args training_args = TrainingArguments( output_dir="finetuned_myr1", num_train_epochs=1, per_device_train_batch_size=1, gradient_accumulation_steps=2, logging_steps=5, save_steps=999999, # effectively don't save mid-epoch save_total_limit=1, fp16=False, # rely on bfloat16 from quantization ) # Trainer trainer = Trainer( model=lora_model, args=training_args, train_dataset=ds, data_collator=collator, ) # --- 5) Train --- trainer.train() # --- 6) Save LoRA adapter + tokenizer --- trainer.model.save_pretrained("finetuned_myr1") tokenizer.save_pretrained("finetuned_myr1") # --- 7) Reload the base model + LoRA adapter for inference base_model_2 = AutoModelForCausalLM.from_pretrained( "wuhp/myr1", subfolder="myr1", config=config, quantization_config=bnb_config, device_map="auto", trust_remote_code=True ) base_model_2 = prepare_model_for_kbit_training(base_model_2) lora_model_2 = PeftModel.from_pretrained( base_model_2, "finetuned_myr1", ) global TEXT_PIPELINE TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer) return "Finetuning complete. Model loaded for inference." def ensure_pipeline(): """ If we haven't finetuned yet (TEXT_PIPELINE is None), load the base model in 4-bit with NO LoRA. """ global TEXT_PIPELINE if TEXT_PIPELINE is None: bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) config = AutoConfig.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("wuhp/myr1", subfolder="myr1", trust_remote_code=True) base_model = AutoModelForCausalLM.from_pretrained( "wuhp/myr1", subfolder="myr1", config=config, quantization_config=bnb_config, device_map="auto", trust_remote_code=True ) TEXT_PIPELINE = pipeline("text-generation", model=base_model, tokenizer=tokenizer) return TEXT_PIPELINE def ensure_comparison_pipeline(): """ Load the DeepSeek model pipeline if not already loaded. """ global COMPARISON_PIPELINE if COMPARISON_PIPELINE is None: # If you prefer 4-bit, you can define BitsAndBytesConfig here, # but let's keep it simpler for demonstration (fp16 or bf16). config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B") tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B") model = AutoModelForCausalLM.from_pretrained( "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", config=config, device_map="auto" ) COMPARISON_PIPELINE = pipeline( "text-generation", model=model, tokenizer=tokenizer ) return COMPARISON_PIPELINE @spaces.GPU(duration=120) # up to 2 min for text generation def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens): """ Generates text from the fine-tuned (LoRA) model if present, else the base model. 
""" pipe = ensure_pipeline() out = pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True ) return out[0]["generated_text"] @spaces.GPU(duration=120) # up to 2 min for text generation def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens): """ Generates text side-by-side from the local myr1 pipeline (fine-tuned or base) AND from the DeepSeek model. Returns two strings. """ local_pipe = ensure_pipeline() comp_pipe = ensure_comparison_pipeline() local_out = local_pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True ) local_text = local_out[0]["generated_text"] comp_out = comp_pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True ) comp_text = comp_out[0]["generated_text"] return local_text, comp_text # Build Gradio UI with gr.Blocks() as demo: gr.Markdown("# QLoRA Fine-tuning & Comparison Demo") gr.Markdown("**Fine-tune wuhp/myr1** on a small subset of the Magpie dataset, then generate or compare output with the DeepSeek model.") finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on Magpie subset (up to 5 min)") status_box = gr.Textbox(label="Finetune Status") finetune_btn.click(fn=finetune_small_subset, outputs=status_box) gr.Markdown("### Generate with myr1 (fine-tuned if done, else base)") prompt_in = gr.Textbox(lines=3, label="Prompt") temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature") top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p") min_tokens = gr.Slider(50, 1024, value=50, step=10, label="Min New Tokens") max_tokens = gr.Slider(50, 1024, value=200, step=50, label="Max New Tokens") output_box = gr.Textbox(label="myr1 Output", lines=8) gen_btn = gr.Button("Generate with myr1") gen_btn.click( fn=predict, inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens], outputs=output_box ) gr.Markdown("### Compare myr1 vs DeepSeek side-by-side") compare_btn = gr.Button("Compare") out_local = gr.Textbox(label="myr1 Output", lines=8) out_deepseek = gr.Textbox(label="DeepSeek Output", lines=8) compare_btn.click( fn=compare_models, inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens], outputs=[out_local, out_deepseek] ) demo.launch()