import gc

# `spaces` is kept ahead of torch (as in the original ordering): ZeroGPU needs
# to be loaded before CUDA is initialised, so do not reorder these imports.
import gradio as gr
import spaces
import torch
import faiss
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
    BitsAndBytesConfig,
)

# PEFT (LoRA / QLoRA)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel

# For embeddings
from sentence_transformers import SentenceTransformer

##############################################################################
# QLoRA Demo Setup
##############################################################################

BASE_REPO = "wuhp/myr1"
BASE_SUBFOLDER = "myr1"
ADAPTER_DIR = "finetuned_myr1"
COMPARISON_REPO = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
DATASET_REPO = "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B"

TEXT_PIPELINE = None
COMPARISON_PIPELINE = None
NUM_EXAMPLES = 50  # We'll train on 50 rows for demonstration


def _make_bnb_config():
    """Return the 4-bit NF4 (QLoRA-style) quantization config used for myr1."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )


def _load_tokenizer():
    """Load the myr1 tokenizer from the Hub."""
    return AutoTokenizer.from_pretrained(
        BASE_REPO, subfolder=BASE_SUBFOLDER, trust_remote_code=True
    )


def _load_base_model_4bit():
    """Load the myr1 base model in 4-bit quantization, placed with device_map="auto"."""
    config = AutoConfig.from_pretrained(
        BASE_REPO, subfolder=BASE_SUBFOLDER, trust_remote_code=True
    )
    return AutoModelForCausalLM.from_pretrained(
        BASE_REPO,
        subfolder=BASE_SUBFOLDER,
        config=config,
        quantization_config=_make_bnb_config(),  # <--- QLoRA 4-bit
        device_map="auto",
        trust_remote_code=True,
    )


@spaces.GPU(duration=300)
def finetune_small_subset():
    """
    QLoRA fine-tune of myr1 on a small Magpie subset, then load it for inference.

    1) Takes the first NUM_EXAMPLES rows of the Magpie dataset,
    2) Loads 'wuhp/myr1' in 4-bit quantization (QLoRA style),
    3) Adds trainable LoRA adapters on q_proj / v_proj,
    4) Trains for one epoch and saves the LoRA adapter + tokenizer to
       'finetuned_myr1',
    5) Frees the training model, reloads the 4-bit base + saved adapter and
       installs it as the global TEXT_PIPELINE.

    Returns:
        A status string for the UI.
    """
    global TEXT_PIPELINE

    # --- 1) Load a small, deterministic subset of the Magpie dataset ---
    # (Previously this picked `list(set(conversation_ids))[0]`, whose order
    # changes between processes due to hash randomisation, and then filtered
    # down to that single conversation instead of NUM_EXAMPLES rows.)
    ds = load_dataset(DATASET_REPO, split="train")
    ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    # --- 2) Load tokenizer + 4-bit base model ---
    tokenizer = _load_tokenizer()
    if tokenizer.pad_token is None:
        # DataCollatorForLanguageModeling pads via the tokenizer, which raises
        # when no pad token is defined.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = _load_base_model_4bit()
    base_model = prepare_model_for_kbit_training(base_model)

    # --- 3) Create LoRA config & wrap the base model in LoRA ---
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    lora_model = get_peft_model(base_model, lora_config)

    # --- 4) Tokenize dataset ---
    def tokenize_fn(ex):
        text = (
            f"Instruction: {ex['instruction']}\n\n"
            f"Response: {ex['response']}"
        )
        return tokenizer(text, truncation=True, max_length=512)

    ds = ds.map(tokenize_fn, batched=False, remove_columns=ds.column_names)
    ds.set_format("torch")

    # mlm=False -> labels are the input ids (causal LM objective).
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=ADAPTER_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
        save_strategy="no",  # the adapter is saved explicitly below
        fp16=False,
        report_to="none",  # don't pick up wandb/tensorboard integrations on the Space
    )

    trainer = Trainer(
        model=lora_model,
        args=training_args,
        train_dataset=ds,
        data_collator=collator,
    )
    trainer.train()

    # --- 5) Save LoRA adapter + tokenizer ---
    trainer.model.save_pretrained(ADAPTER_DIR)
    tokenizer.save_pretrained(ADAPTER_DIR)

    # Best-effort release of the training copy before loading a second 4-bit
    # model, so both are not resident on the GPU at the same time.
    del trainer, lora_model, base_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- 6) Reload for inference ---
    # prepare_model_for_kbit_training is a training-only preparation, so it is
    # not applied to the inference model.
    base_model_2 = _load_base_model_4bit()
    lora_model_2 = PeftModel.from_pretrained(base_model_2, ADAPTER_DIR)

    TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)
    return "Finetuning complete. Model loaded for inference."


def ensure_pipeline():
    """
    Return the global myr1 text-generation pipeline, loading it if needed.

    If we haven't finetuned yet (TEXT_PIPELINE is None), this loads the base
    model in 4-bit with NO LoRA adapter and caches it in TEXT_PIPELINE.
    """
    global TEXT_PIPELINE
    if TEXT_PIPELINE is None:
        TEXT_PIPELINE = pipeline(
            "text-generation",
            model=_load_base_model_4bit(),
            tokenizer=_load_tokenizer(),
        )
    return TEXT_PIPELINE


def ensure_comparison_pipeline():
    """
    Return the DeepSeek comparison pipeline, loading and caching it if needed.

    The weights are loaded in bfloat16. transformers has historically defaulted
    to float32 when no dtype is given, which doubles the memory for an 8B model.
    """
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
        config = AutoConfig.from_pretrained(COMPARISON_REPO)
        tokenizer = AutoTokenizer.from_pretrained(COMPARISON_REPO)
        model = AutoModelForCausalLM.from_pretrained(
            COMPARISON_REPO,
            config=config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return COMPARISON_PIPELINE


def _generation_kwargs(temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Convert raw UI slider values into text-generation keyword arguments.

    The temperature slider goes down to 0.0, but transformers rejects a
    non-positive temperature when sampling. A temperature <= 0 is therefore
    treated as a request for greedy decoding (do_sample=False).
    """
    kwargs = {
        "min_new_tokens": int(min_new_tokens),
        "max_new_tokens": int(max_new_tokens),
    }
    temperature = float(temperature)
    if temperature > 0.0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        kwargs["do_sample"] = False
    return kwargs


@spaces.GPU(duration=120)
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Simple single-prompt generation (no retrieval).

    Returns the pipeline's `generated_text` (prompt + continuation).
    """
    pipe = ensure_pipeline()
    out = pipe(prompt, **_generation_kwargs(temperature, top_p, min_new_tokens, max_new_tokens))
    return out[0]["generated_text"]


@spaces.GPU(duration=120)
def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Compare local pipeline vs. DeepSeek side-by-side.

    Returns:
        (myr1 generated_text, DeepSeek generated_text), both generated with the
        same settings.
    """
    local_pipe = ensure_pipeline()
    comp_pipe = ensure_comparison_pipeline()
    gen_kwargs = _generation_kwargs(temperature, top_p, min_new_tokens, max_new_tokens)

    local_out = local_pipe(prompt, **gen_kwargs)
    comp_out = comp_pipe(prompt, **gen_kwargs)
    return local_out[0]["generated_text"], comp_out[0]["generated_text"]


###############################################################################
# Retrieval-Augmented Memory with FAISS
###############################################################################
class ConversationRetriever:
    """
    A simple in-memory store + FAISS index for retrieving conversation chunks.

    Each chunk is embedded with a SentenceTransformer model and added to an
    exact L2 index (faiss.IndexFlatL2). A query is embedded the same way and
    the nearest chunks are returned; a smaller distance means more similar.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", embed_dim=384):
        """
        Args:
            model_name: SentenceTransformer model used to embed messages.
            embed_dim: Dimension of that model's embeddings. It must match the
                model, otherwise adding vectors to the FAISS index fails.
        """
        self.embed_model = SentenceTransformer(model_name)
        self.embed_dim = embed_dim

        # Vectors live in FAISS; text/metadata live in parallel Python lists,
        # aligned by insertion order (FAISS label i <-> self.texts[i]).
        # For a real app, you'd probably want a more robust store.
        self.index = faiss.IndexFlatL2(embed_dim)
        self.texts = []    # raw text chunks
        self.vectors = []  # redundant copy of the vectors (simpler to inspect)
        self.ids = []      # integer id per chunk
        self.id_counter = 0

    def _embed(self, text):
        """Embed one string and return a float32 array of shape (1, dim)."""
        emb = self.embed_model.encode([text], convert_to_numpy=True)
        return emb.astype(np.float32).reshape(1, -1)

    def add_text(self, text):
        """
        Add `text` to the store as a single chunk (no sub-chunking).

        Blank or whitespace-only text is ignored.
        """
        if not text.strip():
            return
        vec = self._embed(text)
        self.index.add(vec)
        self.texts.append(text)
        self.vectors.append(vec[0])
        self.ids.append(self.id_counter)
        self.id_counter += 1

    def search(self, query, top_k=3):
        """
        Return up to `top_k` (text, l2_distance) pairs nearest to `query`.

        Results are ordered nearest-first. Returns [] when the store is empty
        or top_k <= 0.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        k = min(top_k, self.index.ntotal)
        distances, indices = self.index.search(self._embed(query), k)
        # FAISS pads missing results with label -1. -1 is a valid Python index
        # (it would silently return the last chunk), so it is rejected explicitly.
        return [
            (self.texts[idx], float(dist))
            for dist, idx in zip(distances[0], indices[0])
            if 0 <= idx < len(self.texts)
        ]


###############################################################################
# Build a Chat that uses RAG
###############################################################################
# NOTE(review): this single global store is shared by every Gradio session, so
# one user's messages can be retrieved into another user's prompt. Consider
# per-session stores if the Space is public.
retriever = ConversationRetriever()  # global retriever instance


def build_rag_prompt(user_query, retrieved_chunks):
    """
    Construct the generation prompt for a RAG chat turn.

    The prompt has three parts:
      - the user's new query,
      - a "Relevant Context" section built from the retrieved chunks,
      - a trailing "Assistant:" so the model continues as the assistant.

    Args:
        user_query: The new user message.
        retrieved_chunks: Iterable of (text, l2_distance) pairs, as returned by
            ConversationRetriever.search.
    """
    # IndexFlatL2 yields L2 distances (lower = closer), not similarity scores.
    context_str = "".join(
        f"Chunk #{i} (L2 distance ~ {dist:.2f}):\n{chunk}\n\n"
        for i, (chunk, dist) in enumerate(retrieved_chunks, start=1)
    )
    return (
        f"User's Query:\n{user_query}\n\n"
        f"Relevant Context from Conversation:\n{context_str}"
        "Assistant:"
    )


@spaces.GPU(duration=120)
def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Run one RAG-based chat turn.

    1) Retrieve the top-k most relevant *earlier* messages from FAISS.
    2) Add the user input to FAISS. This happens after the search so the query
       does not retrieve itself and waste a context slot.
    3) Build a prompt from the retrieved chunks + user query.
    4) Generate a response; only the newly generated text is kept.
    5) Add the assistant's response to FAISS as well.

    Returns:
        (history, history): the updated history, for both the gr.State and the
        gr.Chatbot outputs. Empty input leaves the history unchanged.
    """
    history = list(history or [])
    if not user_input or not user_input.strip():
        return history, history

    pipe = ensure_pipeline()

    # 1) Retrieve relevant earlier messages.
    top_k = 3
    results = retriever.search(user_input, top_k=top_k)

    # 2) Store the new user message for future turns.
    retriever.add_text(f"User: {user_input}")

    # 3) Build final prompt.
    prompt = build_rag_prompt(user_input, results)

    # 4) Generate. return_full_text=False gives only the continuation, instead
    # of relying on the decoded output starting with the exact prompt string.
    output = pipe(
        prompt,
        return_full_text=False,
        **_generation_kwargs(temperature, top_p, min_new_tokens, max_new_tokens),
    )[0]["generated_text"]
    assistant_reply = output.strip()

    # 5) Add the assistant's response to the DB as well.
    retriever.add_text(f"Assistant: {assistant_reply}")

    # 6) Update the chat history for display in the Gradio Chatbot.
    history.append([user_input, assistant_reply])
    return history, history


###############################################################################
# Gradio UI
###############################################################################
with gr.Blocks() as demo:
    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo")

    finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on Magpie subset (up to 5 min)")
    status_box = gr.Textbox(label="Finetune Status")
    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)

    # Simple generation UI (no retrieval):
    gr.Markdown("## Direct Generation (No Retrieval)")
    prompt_in = gr.Textbox(lines=3, label="Prompt")
    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
    min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
    max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")

    output_box = gr.Textbox(label="myr1 Output", lines=8)
    gen_btn = gr.Button("Generate with myr1")
    gen_btn.click(
        fn=predict,
        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=output_box,
    )

    # Comparison UI:
    gr.Markdown("## Compare myr1 vs DeepSeek")
    compare_btn = gr.Button("Compare")
    out_local = gr.Textbox(label="myr1 Output", lines=6)
    out_deepseek = gr.Textbox(label="DeepSeek Output", lines=6)
    compare_btn.click(
        fn=compare_models,
        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=[out_local, out_deepseek],
    )

    # RAG-based Chat
    gr.Markdown("## Chat with Retrieval-Augmented Memory")
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(label="RAG Chat")
            chat_state = gr.State([])  # chat history passed into / returned from chat_rag

    user_input = gr.Textbox(
        show_label=False,
        placeholder="Ask a question...",
        lines=2,
    )
    send_btn = gr.Button("Send")

    # On user submit (Enter) or button click, call chat_rag.
    user_input.submit(
        fn=chat_rag,
        inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
        outputs=[chat_state, chatbot],
    )
    send_btn.click(
        fn=chat_rag,
        inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
        outputs=[chat_state, chatbot],
    )

demo.launch()