import gradio as gr import numpy as np import torch import faiss import spaces from datasets import load_dataset from peft import LoraConfig, PeftModel, TaskType, get_peft_model, prepare_model_for_kbit_training from sentence_transformers import SentenceTransformer from transformers import ( AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline, ) NUM_EXAMPLES_FOR_FINETUNING = 50 # Constant for the number of examples to use for finetuning TEXT_PIPELINE = None # Global to store the custom R1 text generation pipeline COMPARISON_PIPELINE = None # Global to store the official R1 text generation pipeline def _load_model_and_tokenizer(model_name: str, subfolder: str = None, quantization_config: BitsAndBytesConfig = None, device_map: str = "auto", trust_remote_code: bool = True) -> tuple[AutoModelForCausalLM, AutoTokenizer]: """ Helper function to load a causal language model and its tokenizer. Args: model_name (str): The name or path of the pretrained model. subfolder (str, optional): Subfolder within the model repository. Defaults to None. quantization_config (BitsAndBytesConfig, optional): Configuration for quantization. Defaults to None. device_map (str, optional): Device mapping for model loading. Defaults to "auto". trust_remote_code (bool, optional): Trust remote code for custom models. Defaults to True. Returns: tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer. """ config = AutoConfig.from_pretrained(model_name, subfolder=subfolder, trust_remote_code=trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder=subfolder, trust_remote_code=trust_remote_code) model = AutoModelForCausalLM.from_pretrained( model_name, subfolder=subfolder, config=config, quantization_config=quantization_config, device_map=device_map, trust_remote_code=trust_remote_code ) return model, tokenizer @spaces.GPU(duration=300) def finetune_small_subset() -> str: """ Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset. Steps: 1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code). 2) Applies 4-bit quantization and prepares for QLoRA training. 3) Fine-tunes on the dataset (mapping "problem" to prompt and "solution" to target). 4) Saves the LoRA adapter to "finetuned_myr1". 5) Reloads the adapter for inference. Returns: str: A message indicating finetuning completion. """ ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v0", split="train") ds = ds.select(range(min(NUM_EXAMPLES_FOR_FINETUNING, len(ds)))) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) base_model, tokenizer = _load_model_and_tokenizer( "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto" ) base_model = prepare_model_for_kbit_training(base_model) lora_config = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.05, bias="none", target_modules=["q_proj", "v_proj"], task_type=TaskType.CAUSAL_LM, ) lora_model = get_peft_model(base_model, lora_config) def tokenize_fn(ex): text = ( f"Problem: {ex['problem']}\n\n" f"Solution: {ex['solution']}" ) return tokenizer(text, truncation=True, max_length=512) ds = ds.map(tokenize_fn, batched=False, remove_columns=ds.column_names) ds.set_format("torch") collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) training_args = TrainingArguments( output_dir="finetuned_myr1", num_train_epochs=1, per_device_train_batch_size=1, gradient_accumulation_steps=2, logging_steps=5, save_steps=999999, save_total_limit=1, fp16=False, ) trainer = Trainer( model=lora_model, args=training_args, train_dataset=ds, data_collator=collator, ) trainer.train() trainer.model.save_pretrained("finetuned_myr1") tokenizer.save_pretrained("finetuned_myr1") base_model_2, tokenizer_2 = _load_model_and_tokenizer( "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto" ) base_model_2 = prepare_model_for_kbit_training(base_model_2) lora_model_2 = PeftModel.from_pretrained( base_model_2, "finetuned_myr1", ) global TEXT_PIPELINE TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer_2) return "Finetuning complete. Model loaded for inference." def ensure_pipeline() -> pipeline: """ Loads the base model (without LoRA) if no fine-tuned model is available. Returns: pipeline: The text generation pipeline using the custom R1 model. """ global TEXT_PIPELINE if TEXT_PIPELINE is None: bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) base_model, tokenizer = _load_model_and_tokenizer( "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto" ) TEXT_PIPELINE = pipeline("text-generation", model=base_model, tokenizer=tokenizer) return TEXT_PIPELINE def ensure_comparison_pipeline() -> pipeline: """ Loads the official R1 model pipeline if not already loaded. Returns: pipeline: The text generation pipeline using the official R1 model. """ global COMPARISON_PIPELINE if COMPARISON_PIPELINE is None: config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B") tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B") model = AutoModelForCausalLM.from_pretrained( "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", config=config, device_map="auto" ) COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer) return COMPARISON_PIPELINE @spaces.GPU(duration=120) def predict( prompt: str, temperature: float, top_p: float, min_new_tokens: int, max_new_tokens: int ) -> str: """ Direct generation without retrieval using the custom R1 model. Args: prompt (str): The input prompt for text generation. temperature (float): Sampling temperature. top_p (float): Top-p sampling probability. min_new_tokens (int): Minimum number of new tokens to generate. max_new_tokens (int): Maximum number of new tokens to generate. Returns: str: The generated text output with "Thinking Process" and "Solution" sections. """ pipe = ensure_pipeline() thinking_prefix = "**Thinking Process:**\n" solution_prefix = "\n**Solution:**\n" formatted_output = thinking_prefix output = pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True )[0]["generated_text"] formatted_output += output.strip() return formatted_output @spaces.GPU(duration=120) def compare_models( prompt: str, temperature: float, top_p: float, min_new_tokens: int, max_new_tokens: int ) -> tuple[str, str]: """ Compare outputs between your custom R1 model and the official R1 model. Args: prompt (str): The input prompt for text generation. temperature (float): Sampling temperature. top_p (float): Top-p sampling probability. min_new_tokens (int): Minimum number of new tokens to generate. max_new_tokens (int): Maximum number of new tokens to generate. Returns: tuple[str, str]: A tuple containing the formatted generated text from the custom R1 and official R1 models, each with "Thinking Process" and "Solution" sections. """ local_pipe = ensure_pipeline() comp_pipe = ensure_comparison_pipeline() def format_comparison_output(model_name, raw_output): thinking_prefix = f"**{model_name} - Thinking Process:**\n" solution_prefix = f"\n**{model_name} - Solution:**\n" formatted_output = thinking_prefix formatted_output += raw_output.strip() return formatted_output local_out_raw = local_pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True )[0]["generated_text"] comp_out_raw = comp_pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True )[0]["generated_text"] local_out_formatted = format_comparison_output("Custom R1", local_out_raw) comp_out_formatted = format_comparison_output("Official R1", comp_out_raw) return local_out_formatted, comp_out_formatted class ConversationRetriever: """ A FAISS-based retriever using SentenceTransformer for embedding. This class indexes text chunks using FAISS and SentenceTransformer embeddings to enable efficient similarity search for retrieval-augmented generation. """ def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", embed_dim: int = 384): """ Initializes the ConversationRetriever. Args: model_name (str, optional): Name of the SentenceTransformer model. Defaults to "sentence-transformers/all-MiniLM-L6-v2". embed_dim (int, optional): Dimensionality of the embeddings. Defaults to 384. """ self.embed_model = SentenceTransformer(model_name) self.embed_dim = embed_dim self.index = faiss.IndexFlatL2(embed_dim) self.texts = [] self.vectors = [] self.ids = [] self.id_counter = 0 def add_text(self, text: str): """ Adds text to the retriever's index. Args: text (str): The text to add. """ if not text.strip(): return emb = self.embed_model.encode([text], convert_to_numpy=True) vec = emb[0].astype(np.float32) self.index.add(vec.reshape(1, -1)) self.texts.append(text) self.vectors.append(vec) self.ids.append(self.id_counter) self.id_counter += 1 def search(self, query: str, top_k: int = 3) -> list[tuple[str, float]]: """ Searches the retriever index for texts similar to the query. Args: query (str): The query text. top_k (int, optional): Number of top results to retrieve. Defaults to 3. Returns: list[tuple[str, float]]: A list of tuples, where each tuple contains (text, distance). """ q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32) q_vec = q_emb[0].reshape(1, -1) distances, indices = self.index.search(q_vec, top_k) results = [] for dist, idx in zip(distances[0], indices[0]): if idx < len(self.texts): results.append((self.texts[idx], dist)) return results retriever = ConversationRetriever() def build_rag_prompt(user_query: str, retrieved_chunks: list[tuple[str, float]]) -> str: """ Builds a prompt for retrieval-augmented generation. Args: user_query (str): The user's input query. retrieved_chunks (list[tuple[str, float]]): List of retrieved text chunks and their distances. Returns: str: The formatted prompt string including instructions for step-by-step thinking and using context. """ context_str = "" if retrieved_chunks: context_str += "**Relevant Context:**\n" for i, (chunk, dist) in enumerate(retrieved_chunks): context_str += f"Chunk #{i+1} (similarity ~ {dist:.2f}):\n> {chunk}\n\n" prompt_instruction = "Please provide a detailed answer, showing your thinking process step-by-step before stating the final answer. Use the provided context if relevant." prompt = ( f"**User Query:**\n{user_query}\n\n" f"{context_str}\n" f"{prompt_instruction}\n\n" "**Answer:**\n" ) return prompt @spaces.GPU(duration=120) def chat_rag( user_input: str, history: list[list[str]], temperature: float, top_p: float, min_new_tokens: int, max_new_tokens: int ) -> tuple[list[list[str]], list[list[str]]]: """ Chat with retrieval augmentation using the custom R1 model. Args: user_input (str): The user's chat input. history (list[list[str]]): The chat history. temperature (float): Sampling temperature. top_p (float): Top-p sampling probability. min_new_tokens (int): Minimum number of new tokens to generate. max_new_tokens (int): Maximum number of new tokens to generate. Returns: tuple[list[list[str]], list[list[str]]]: Updated chat history and chatbot display history, with formatted assistant replies. """ pipe = ensure_pipeline() retriever.add_text(f"User: {user_input}") top_k = 3 results = retriever.search(user_input, top_k=top_k) prompt = build_rag_prompt(user_input, results) thinking_prefix = "**Thinking Process:**\n" solution_prefix = "\n**Solution:**\n" formatted_output = thinking_prefix output = pipe( prompt, temperature=float(temperature), top_p=float(top_p), min_new_tokens=int(min_new_tokens), max_new_tokens=int(max_new_tokens), do_sample=True )[0]["generated_text"] formatted_output += output.strip() assistant_reply = formatted_output if assistant_reply.startswith(prompt): assistant_reply = assistant_reply[len(prompt):].strip() else: assistant_reply = assistant_reply.strip() retriever.add_text(f"Assistant: {assistant_reply}") history.append([user_input, assistant_reply]) return history, history # Build the Gradio interface. with gr.Blocks() as demo: gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model") gr.Markdown("---") gr.Markdown("## ⚙️ Fine-tuning (Optional)") gr.Markdown("This section allows you to fine-tune the custom R1 model on a small subset of the ServiceNow dataset. This step is optional but can potentially improve the model's performance on ServiceNow-related tasks. **Note:** This process may take up to 5 minutes.") finetune_btn = gr.Button("🚀 Start Fine-tuning (QLoRA)") status_box = gr.Textbox(label="Fine-tuning Status", interactive=False) finetune_btn.click(fn=finetune_small_subset, outputs=status_box) gr.Markdown("---") gr.Markdown("## ✍️ Direct Generation (No Retrieval)") gr.Markdown("Enter a prompt below to generate text directly using the custom R1 model. This is standard text generation without retrieval augmentation.") prompt_in = gr.Textbox(lines=3, label="Input Prompt", placeholder="Enter your prompt here...") temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature (Creativity)") top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p (Sampling Nucleus)") min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens") max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens") output_box = gr.Textbox(label="Custom R1 Output", lines=8, interactive=False) gen_btn = gr.Button("✨ Generate Text") gen_btn.click( fn=predict, inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens], outputs=output_box ) gr.Markdown("---") gr.Markdown("## 🆚 Compare Custom R1 vs Official R1") gr.Markdown("Enter a prompt to compare the text generation of your fine-tuned custom R1 model with the official DeepSeek-R1-Distill-Llama-8B model.") compare_prompt_in = gr.Textbox(lines=3, label="Comparison Prompt", placeholder="Enter prompt for comparison...") compare_btn = gr.Button("⚖️ Compare Models") out_custom = gr.Textbox(label="Custom R1 Output", lines=6, interactive=False) out_official = gr.Textbox(label="Official R1 Output", lines=6, interactive=False) compare_btn.click( fn=compare_models, inputs=[compare_prompt_in, temperature, top_p, min_tokens, max_tokens], outputs=[out_custom, out_official] ) gr.Markdown("---") gr.Markdown("## 💬 Chat with Retrieval-Augmented Memory (RAG)") gr.Markdown("Chat with the custom R1 model, enhanced with a retrieval-augmented memory. The model will retrieve relevant information based on your queries to provide more informed responses.") with gr.Row(): with gr.Column(): chatbot = gr.Chatbot(label="RAG Chatbot") chat_state = gr.State([]) user_input = gr.Textbox( show_label=False, placeholder="Ask a question to the RAG Chatbot...", lines=2 ) send_btn = gr.Button("➡️ Send") user_input.submit( fn=chat_rag, inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens], outputs=[chat_state, chatbot] ) send_btn.click( fn=chat_rag, inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens], outputs=[chat_state, chatbot] ) gr.Markdown("---") demo.launch()