import gradio as gr
import numpy as np
import torch
import faiss
import spaces

from datasets import load_dataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model, prepare_model_for_kbit_training
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

NUM_EXAMPLES_FOR_FINETUNING = 50  # Constant for the number of examples to use for finetuning
TEXT_PIPELINE = None  # Global to store the custom R1 text generation pipeline
COMPARISON_PIPELINE = None  # Global to store the official R1 text generation pipeline


def _load_model_and_tokenizer(
    model_name: str,
    subfolder: str = None,
    quantization_config: BitsAndBytesConfig = None,
    device_map: str = "auto",
    trust_remote_code: bool = True,
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Helper function to load a causal language model and its tokenizer.

    Args:
        model_name (str): The name or path of the pretrained model.
        subfolder (str, optional): Subfolder within the model repository. Defaults to None.
        quantization_config (BitsAndBytesConfig, optional): Configuration for quantization. Defaults to None.
        device_map (str, optional): Device mapping for model loading. Defaults to "auto".
        trust_remote_code (bool, optional): Trust remote code for custom models. Defaults to True.

    Returns:
        tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer.
    """
    config = AutoConfig.from_pretrained(model_name, subfolder=subfolder, trust_remote_code=trust_remote_code)
    tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder=subfolder, trust_remote_code=trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        subfolder=subfolder,
        config=config,
        quantization_config=quantization_config,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
    )
    return model, tokenizer


@spaces.GPU(duration=300)
def finetune_small_subset() -> str:
    """
    Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.

    Steps:
        1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
        2) Applies 4-bit quantization and prepares for QLoRA training.
        3) Fine-tunes on the dataset (mapping "problem" to prompt and "solution" to target).
        4) Saves the LoRA adapter to "finetuned_myr1".
        5) Reloads the adapter for inference.

    Returns:
        str: A message indicating finetuning completion.
""" ds = load_dataset("ServiceNow-AI/R1-Distill-SFT", "v0", split="train") ds = ds.select(range(min(NUM_EXAMPLES_FOR_FINETUNING, len(ds)))) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) base_model, tokenizer = _load_model_and_tokenizer( "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto" ) base_model = prepare_model_for_kbit_training(base_model) lora_config = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.05, bias="none", target_modules=["q_proj", "v_proj"], task_type=TaskType.CAUSAL_LM, ) lora_model = get_peft_model(base_model, lora_config) def tokenize_fn(ex): text = ( f"Problem: {ex['problem']}\n\n" f"Solution: {ex['solution']}" ) return tokenizer(text, truncation=True, max_length=512) ds = ds.map(tokenize_fn, batched=False, remove_columns=ds.column_names) ds.set_format("torch") collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) training_args = TrainingArguments( output_dir="finetuned_myr1", num_train_epochs=1, per_device_train_batch_size=1, gradient_accumulation_steps=2, logging_steps=5, save_steps=999999, save_total_limit=1, fp16=False, ) trainer = Trainer( model=lora_model, args=training_args, train_dataset=ds, data_collator=collator, ) trainer.train() trainer.model.save_pretrained("finetuned_myr1") tokenizer.save_pretrained("finetuned_myr1") base_model_2, tokenizer_2 = _load_model_and_tokenizer( "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto" ) base_model_2 = prepare_model_for_kbit_training(base_model_2) lora_model_2 = PeftModel.from_pretrained( base_model_2, "finetuned_myr1", ) global TEXT_PIPELINE TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer_2) return "Finetuning complete. Model loaded for inference." def ensure_pipeline() -> pipeline: """ Loads the base model (without LoRA) if no fine-tuned model is available. Returns: pipeline: The text generation pipeline using the custom R1 model. """ global TEXT_PIPELINE if TEXT_PIPELINE is None: bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) base_model, tokenizer = _load_model_and_tokenizer( "wuhp/myr1", subfolder="myr1", quantization_config=bnb_config, device_map="auto" ) TEXT_PIPELINE = pipeline("text-generation", model=base_model, tokenizer=tokenizer) return TEXT_PIPELINE def ensure_comparison_pipeline() -> pipeline: """ Loads the official R1 model pipeline if not already loaded. Returns: pipeline: The text generation pipeline using the official R1 model. """ global COMPARISON_PIPELINE if COMPARISON_PIPELINE is None: config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B") tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B") model = AutoModelForCausalLM.from_pretrained( "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", config=config, device_map="auto" ) COMPARISON_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer) return COMPARISON_PIPELINE @spaces.GPU(duration=120) def predict( prompt: str, temperature: float, top_p: float, min_new_tokens: int, max_new_tokens: int ) -> str: """ Direct generation without retrieval using the custom R1 model. Args: prompt (str): The input prompt for text generation. temperature (float): Sampling temperature. top_p (float): Top-p sampling probability. 
        min_new_tokens (int): Minimum number of new tokens to generate.
        max_new_tokens (int): Maximum number of new tokens to generate.

    Returns:
        str: The generated text output with "Thinking Process" and "Solution" sections.
    """
    pipe = ensure_pipeline()
    thinking_prefix = "**Thinking Process:**\n"
    solution_prefix = "\n**Solution:**\n"
    formatted_output = thinking_prefix
    output = pipe(
        prompt,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
    )[0]["generated_text"]
    formatted_output += output.strip()
    return formatted_output


@spaces.GPU(duration=120)
def compare_models(
    prompt: str,
    temperature: float,
    top_p: float,
    min_new_tokens: int,
    max_new_tokens: int,
) -> tuple[str, str]:
    """
    Compare outputs between your custom R1 model and the official R1 model.

    Args:
        prompt (str): The input prompt for text generation.
        temperature (float): Sampling temperature.
        top_p (float): Sampling top-p.
        min_new_tokens (int): Minimum number of new tokens to generate.
        max_new_tokens (int): Maximum number of new tokens to generate.

    Returns:
        tuple[str, str]: A tuple containing the formatted generated text from the custom R1
        and official R1 models, each with "Thinking Process" and "Solution" sections.
    """
    local_pipe = ensure_pipeline()
    comp_pipe = ensure_comparison_pipeline()

    def format_comparison_output(model_name, raw_output):
        thinking_prefix = f"**{model_name} - Thinking Process:**\n"
        solution_prefix = f"\n**{model_name} - Solution:**\n"
        formatted_output = thinking_prefix
        formatted_output += raw_output.strip()
        return formatted_output

    local_out_raw = local_pipe(
        prompt,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
    )[0]["generated_text"]
    comp_out_raw = comp_pipe(
        prompt,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
    )[0]["generated_text"]

    local_out_formatted = format_comparison_output("Custom R1", local_out_raw)
    comp_out_formatted = format_comparison_output("Official R1", comp_out_raw)
    return local_out_formatted, comp_out_formatted


class ConversationRetriever:
    """
    A FAISS-based retriever using SentenceTransformer for embedding.

    This class indexes text chunks using FAISS and SentenceTransformer embeddings
    to enable efficient similarity search for retrieval-augmented generation.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", embed_dim: int = 384):
        """
        Initializes the ConversationRetriever.

        Args:
            model_name (str, optional): Name of the SentenceTransformer model.
                Defaults to "sentence-transformers/all-MiniLM-L6-v2".
            embed_dim (int, optional): Dimensionality of the embeddings. Defaults to 384.
        """
        self.embed_model = SentenceTransformer(model_name)
        self.embed_dim = embed_dim
        self.index = faiss.IndexFlatL2(embed_dim)
        self.texts = []
        self.vectors = []
        self.ids = []
        self.id_counter = 0

    def add_text(self, text: str):
        """
        Adds text to the retriever's index.

        Args:
            text (str): The text to add.
        """
        if not text.strip():
            return
        emb = self.embed_model.encode([text], convert_to_numpy=True)
        vec = emb[0].astype(np.float32)
        self.index.add(vec.reshape(1, -1))
        self.texts.append(text)
        self.vectors.append(vec)
        self.ids.append(self.id_counter)
        self.id_counter += 1

    def search(self, query: str, top_k: int = 3) -> list[tuple[str, float]]:
        """
        Searches the retriever index for texts similar to the query.

        Args:
            query (str): The query text.
            top_k (int, optional): Number of top results to retrieve. Defaults to 3.

        Returns:
            list[tuple[str, float]]: A list of tuples, where each tuple contains (text, distance).
        """
        q_emb = self.embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
        q_vec = q_emb[0].reshape(1, -1)
        distances, indices = self.index.search(q_vec, top_k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with -1, so guard against negative indices.
            if 0 <= idx < len(self.texts):
                results.append((self.texts[idx], dist))
        return results


retriever = ConversationRetriever()


def build_rag_prompt(user_query: str, retrieved_chunks: list[tuple[str, float]]) -> str:
    """
    Builds a prompt for retrieval-augmented generation.

    Args:
        user_query (str): The user's input query.
        retrieved_chunks (list[tuple[str, float]]): List of retrieved text chunks and their distances.

    Returns:
        str: The formatted prompt string including instructions for step-by-step thinking
        and using context.
    """
    context_str = ""
    if retrieved_chunks:
        context_str += "**Relevant Context:**\n"
        for i, (chunk, dist) in enumerate(retrieved_chunks):
            context_str += f"Chunk #{i+1} (similarity ~ {dist:.2f}):\n> {chunk}\n\n"

    prompt_instruction = (
        "Please provide a detailed answer, showing your thinking process step-by-step "
        "before stating the final answer. Use the provided context if relevant."
    )
    prompt = (
        f"**User Query:**\n{user_query}\n\n"
        f"{context_str}\n"
        f"{prompt_instruction}\n\n"
        "**Answer:**\n"
    )
    return prompt


@spaces.GPU(duration=120)
def chat_rag(
    user_input: str,
    history: list[list[str]],
    temperature: float,
    top_p: float,
    min_new_tokens: int,
    max_new_tokens: int,
) -> tuple[list[list[str]], list[list[str]]]:
    """
    Chat with retrieval augmentation using the custom R1 model.

    Args:
        user_input (str): The user's chat input.
        history (list[list[str]]): The chat history.
        temperature (float): Sampling temperature.
        top_p (float): Sampling top-p.
        min_new_tokens (int): Minimum number of new tokens to generate.
        max_new_tokens (int): Maximum number of new tokens to generate.

    Returns:
        tuple[list[list[str]], list[list[str]]]: Updated chat history and chatbot display history,
        with formatted assistant replies.
    """
    pipe = ensure_pipeline()
    retriever.add_text(f"User: {user_input}")
    top_k = 3
    results = retriever.search(user_input, top_k=top_k)
    prompt = build_rag_prompt(user_input, results)

    thinking_prefix = "**Thinking Process:**\n"
    solution_prefix = "\n**Solution:**\n"
    output = pipe(
        prompt,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
    )[0]["generated_text"]

    # The pipeline echoes the prompt, so strip it before formatting the assistant reply.
    if output.startswith(prompt):
        output = output[len(prompt):]
    assistant_reply = thinking_prefix + output.strip()

    retriever.add_text(f"Assistant: {assistant_reply}")
    history.append([user_input, assistant_reply])
    return history, history
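
# Illustrative sketch only: shows how the retrieval pieces above (ConversationRetriever
# and build_rag_prompt) fit together on the CPU. The helper name and the sample sentence
# are hypothetical; nothing in the UI below calls this function.
def _demo_retrieval_roundtrip(question: str = "Where is the Eiffel Tower?") -> str:
    demo_retriever = ConversationRetriever()
    demo_retriever.add_text("User: The Eiffel Tower is in Paris.")
    chunks = demo_retriever.search(question, top_k=1)
    return build_rag_prompt(question, chunks)
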
# Build the Gradio interface with tabs.
with gr.Blocks(css="""
body {background-color: #f5f5f5; font-family: Arial, sans-serif;}
.gradio-container {max-width: 1000px; margin: auto; background: white; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);}
h1 {color: #333; text-align: center; font-size: 2rem;}
h2 {color: #444; margin-top: 10px; font-size: 1.5rem;}
.gr-tab {padding: 10px;}
""") as demo:
    gr.Markdown("# 🚀 QLoRA Fine-tuning & RAG Chat Demo")
    gr.Markdown(
        "Welcome to the enhanced **QLoRA fine-tuning and RAG-based chatbot interface**. "
        "This tool lets you fine-tune an AI model, generate text, and interact with a chatbot "
        "using retrieval-augmented responses."
    )

    with gr.Tabs():
        # Fine-tuning tab
        with gr.Tab(label="⚙️ Fine-tune Model"):
            gr.Markdown("### Train your custom R1 model")
            gr.Markdown("Fine-tune the model using QLoRA. This is **optional**, but recommended for better performance.")
            finetune_btn = gr.Button("Start Fine-tuning")
            finetune_output = gr.Textbox(label="Status", interactive=False)
            finetune_btn.click(finetune_small_subset, inputs=None, outputs=finetune_output)

        # Text Generation tab
        with gr.Tab(label="✍️ Text Generation"):
            gr.Markdown("### Generate text using your fine-tuned model")
            input_prompt = gr.Textbox(label="Enter Prompt", placeholder="Type something here...", lines=3)
            temp_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
            topp_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            min_tokens = gr.Slider(1, 1000, value=50, step=10, label="Min New Tokens")
            max_tokens = gr.Slider(1, 1000, value=200, step=10, label="Max New Tokens")
            generate_btn = gr.Button("Generate Text")
            output_box = gr.Textbox(label="Generated Output", lines=8, interactive=False)
            generate_btn.click(
                predict,
                inputs=[input_prompt, temp_slider, topp_slider, min_tokens, max_tokens],
                outputs=output_box,
            )

        # Model Comparison tab
        with gr.Tab(label="🆚 Compare Models"):
            gr.Markdown("### Compare text outputs from your fine-tuned model and the official model")
            compare_prompt = gr.Textbox(label="Enter Comparison Prompt", placeholder="Enter a prompt here...", lines=3)
            compare_temp = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
            compare_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            compare_min_tokens = gr.Slider(1, 1000, value=50, step=10, label="Min New Tokens")
            compare_max_tokens = gr.Slider(1, 1000, value=200, step=10, label="Max New Tokens")
            compare_btn = gr.Button("Compare Models")
            compare_output1 = gr.Textbox(label="Custom Model Output", lines=6, interactive=False)
            compare_output2 = gr.Textbox(label="Official Model Output", lines=6, interactive=False)
            compare_btn.click(
                compare_models,
                inputs=[compare_prompt, compare_temp, compare_topp, compare_min_tokens, compare_max_tokens],
                outputs=[compare_output1, compare_output2],
            )

        # Chatbot tab
        with gr.Tab(label="💬 AI Chatbot"):
            gr.Markdown("### Chat with an AI assistant using retrieval-augmented generation (RAG)")
            chatbot = gr.Chatbot(label="AI Chatbot", height=400)
            chat_input = gr.Textbox(placeholder="Ask me anything...", lines=2)
            chat_btn = gr.Button("Send")
            chat_output = gr.Chatbot(label="Chat History")
            chat_btn.click(
                chat_rag,
                inputs=[chat_input, chatbot, temp_slider, topp_slider, min_tokens, max_tokens],
                outputs=[chat_output, chatbot],
            )

demo.launch()