{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune ModernBERT with Synthetic Data for RAG\n", "\n", "This notebook demonstrates the fine-tuning process of `modernbert-embed-base` using synthetic data tailored for the Retrieval-Augmented Generation (RAG) model.\n", "\n", "It provides a complete walkthrough of the fine-tuning process after generating synthetic data using the Synthetic Data Generator. For a comprehensive explanation of the methodology and additional details, refer to the blog post: [Fine-tune ModernBERT for RAG with Synthetic Data](https://huggingface.co/blog/fine-tune-modernbert-for-rag-with-synthetic-data)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting Started" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install the Dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install torch\n", "!pip install datasets\n", "!pip install sentence-transformers\n", "!pip install haystack-ai\n", "!pip install git+https://github.com/huggingface/transformers.git # for the latest version of transformers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import the Required Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader\n", "\n", "from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict\n", "\n", "\n", "from sentence_transformers import (\n", " SentenceTransformer,\n", " SentenceTransformerModelCardData,\n", " CrossEncoder,\n", " InputExample,\n", " SentenceTransformerTrainer,\n", ")\n", "from sentence_transformers.losses import TripletLoss\n", "from sentence_transformers.training_args import (\n", " SentenceTransformerTrainingArguments,\n", " BatchSamplers,\n", ")\n", "from sentence_transformers.evaluation import TripletEvaluator\n", "from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator\n", "\n", "\n", "from haystack import Document, Pipeline\n", "from haystack.document_stores.in_memory import InMemoryDocumentStore\n", "from haystack.components.embedders import (\n", " SentenceTransformersDocumentEmbedder,\n", " SentenceTransformersTextEmbedder,\n", ")\n", "from haystack.components.rankers import SentenceTransformersDiversityRanker\n", "from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n", "from haystack.components.builders import ChatPromptBuilder\n", "from haystack.components.generators.chat import HuggingFaceAPIChatGenerator\n", "from haystack.dataclasses import ChatMessage\n", "from haystack.utils import Secret\n", "from haystack.utils.hf import HFGenerationAPIType" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configure the Environment" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "MODEL = \"nomic-ai/modernbert-embed-base\"\n", "REPO_NAME = \"sdiazlor\" # your HF username here\n", "MODEL_NAME_BIENCODER = \"modernbert-embed-base-biencoder-human-rights\"\n", "MODEL_NAME_CROSSENCODER = \"modernbert-embed-base-crossencoder-human-rights\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: mps\n" ] } ], "source": [ "if torch.backends.mps.is_available():\n", " device = \"mps\"\n", "elif torch.cuda.is_available():\n", " device = \"cuda\"\n", "else:\n", " device = \"cpu\"\n", "\n", 
"print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-process the Synthetic Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n", " num_rows: 1000\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Combine the generated datasets from files and prompts\n", "\n", "dataset_rag_from_file = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-files\", split=\"train\")\n", "dataset_rag_from_prompt = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-prompt\", split=\"train\")\n", "\n", "combined_rag_dataset = concatenate_datasets(\n", " [dataset_rag_from_file, dataset_rag_from_prompt]\n", ")\n", "\n", "combined_rag_dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n", " num_rows: 828\n", "})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Filter out examples with empty or NaN values\n", "\n", "def filter_empty_or_nan(example):\n", " return all(\n", " value is not None and str(value).strip() != \"\" for value in example.values()\n", " )\n", "\n", "filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)\n", "filtered_rag_dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset({\n", " features: ['anchor', 'positive', 'negative'],\n", " num_rows: 828\n", "})\n", "Dataset({\n", " features: ['anchor', 'positive'],\n", " num_rows: 828\n", "})\n" ] } ], "source": [ "# Rename, select and reorder columns according to the expected format for the SentenceTransformer and CrossEncoder models\n", "\n", "def rename_and_reorder_columns(dataset, rename_map, selected_columns):\n", " for old_name, new_name in rename_map.items():\n", " if old_name in dataset.column_names:\n", " dataset = dataset.rename_column(old_name, new_name)\n", " dataset = dataset.select_columns(selected_columns)\n", " return dataset\n", "\n", "clean_rag_dataset_biencoder = rename_and_reorder_columns(\n", " filtered_rag_dataset,\n", " rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\", \"negative_retrieval\": \"negative\"},\n", " selected_columns=[\"anchor\", \"positive\", \"negative\"],\n", ")\n", "\n", "clean_rag_dataset_crossencoder = rename_and_reorder_columns(\n", " filtered_rag_dataset,\n", " rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\"}, #TODO\n", " selected_columns=[\"anchor\", \"positive\"],\n", ")\n", "\n", "print(clean_rag_dataset_biencoder)\n", "print(clean_rag_dataset_crossencoder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 
] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "406c4d22f43f41d592d3b94da2955444", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/828 [00:00\n", " \n", " \n", " [123/123 25:34, Epoch 2/3]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossCosine Accuracy
1No log3.6559290.969880
214.3740003.4983950.981928

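 ], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The table above is the log of a `SentenceTransformerTrainer` run. As a reference, here is a minimal sketch of a comparable bi-encoder training setup, assuming `TripletLoss`, a small held-out split for the `TripletEvaluator`, and illustrative hyperparameters (the original values may differ):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch of the bi-encoder fine-tuning loop (assumptions noted above)\n", "dataset_biencoder = clean_rag_dataset_biencoder.train_test_split(test_size=0.1, seed=42)\n", "\n", "model_biencoder = SentenceTransformer(MODEL, device=device)\n", "loss = TripletLoss(model_biencoder)\n", "\n", "args = SentenceTransformerTrainingArguments(\n", "    output_dir=MODEL_NAME_BIENCODER,\n", "    num_train_epochs=3,\n", "    per_device_train_batch_size=16,\n", "    batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid repeated anchors within a batch\n", "    eval_strategy=\"epoch\",\n", ")\n", "\n", "evaluator = TripletEvaluator(\n", "    anchors=dataset_biencoder[\"test\"][\"anchor\"],\n", "    positives=dataset_biencoder[\"test\"][\"positive\"],\n", "    negatives=dataset_biencoder[\"test\"][\"negative\"],\n", ")\n", "\n", "trainer = SentenceTransformerTrainer(\n", "    model=model_biencoder,\n", "    args=args,\n", "    train_dataset=dataset_biencoder[\"train\"],\n", "    eval_dataset=dataset_biencoder[\"test\"],\n", "    loss=loss,\n", "    evaluator=evaluator,\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [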
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "faad6e9752f34babadff7a966ae55d87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/1 [00:00\n", "🚅 Components\n", " - text_embedder: SentenceTransformersTextEmbedder\n", " - retriever: InMemoryEmbeddingRetriever\n", " - ranker: SentenceTransformersDiversityRanker\n", " - prompt_builder: ChatPromptBuilder\n", " - llm: HuggingFaceAPIChatGenerator\n", "🛤️ Connections\n", " - text_embedder.embedding -> retriever.query_embedding (List[float])\n", " - retriever.documents -> ranker.documents (List[Document])\n", " - ranker.documents -> prompt_builder.documents (List[Document])\n", " - prompt_builder.prompt -> llm.messages (List[ChatMessage])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Connect the components to each other\n", "\n", "rag_pipeline.connect(\"text_embedder.embedding\", \"retriever.query_embedding\")\n", "rag_pipeline.connect(\"retriever.documents\", \"ranker.documents\")\n", "rag_pipeline.connect(\"ranker\", \"prompt_builder\")\n", "rag_pipeline.connect(\"prompt_builder.prompt\", \"llm.messages\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "80c813c847524f1493067f6dbe65c725", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/1 [00:00