{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune ModernBERT with Synthetic Data for RAG\n", "\n", "This notebook demonstrates the fine-tuning process of `modernbert-embed-base` using synthetic data tailored for the Retrieval-Augmented Generation (RAG) model.\n", "\n", "It provides a complete walkthrough of the fine-tuning process after generating synthetic data using the Synthetic Data Generator. For a comprehensive explanation of the methodology and additional details, refer to the blog post: [Fine-tune ModernBERT for RAG with Synthetic Data](https://huggingface.co/blog/fine-tune-modernbert-for-rag-with-synthetic-data)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting Started" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install the Dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install torch\n", "!pip install datasets\n", "!pip install sentence-transformers\n", "!pip install haystack-ai\n", "!pip install git+https://github.com/huggingface/transformers.git # for the latest version of transformers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import the Required Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader\n", "\n", "from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict\n", "\n", "\n", "from sentence_transformers import (\n", " SentenceTransformer,\n", " SentenceTransformerModelCardData,\n", " CrossEncoder,\n", " InputExample,\n", " SentenceTransformerTrainer,\n", ")\n", "from sentence_transformers.losses import TripletLoss\n", "from sentence_transformers.training_args import (\n", " SentenceTransformerTrainingArguments,\n", " BatchSamplers,\n", ")\n", "from sentence_transformers.evaluation import TripletEvaluator\n", "from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator\n", "\n", "\n", "from haystack import Document, Pipeline\n", "from haystack.document_stores.in_memory import InMemoryDocumentStore\n", "from haystack.components.embedders import (\n", " SentenceTransformersDocumentEmbedder,\n", " SentenceTransformersTextEmbedder,\n", ")\n", "from haystack.components.rankers import SentenceTransformersDiversityRanker\n", "from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n", "from haystack.components.builders import ChatPromptBuilder\n", "from haystack.components.generators.chat import HuggingFaceAPIChatGenerator\n", "from haystack.dataclasses import ChatMessage\n", "from haystack.utils import Secret\n", "from haystack.utils.hf import HFGenerationAPIType" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configure the Environment" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "MODEL = \"nomic-ai/modernbert-embed-base\"\n", "REPO_NAME = \"sdiazlor\" # your HF username here\n", "MODEL_NAME_BIENCODER = \"modernbert-embed-base-biencoder-human-rights\"\n", "MODEL_NAME_CROSSENCODER = \"modernbert-embed-base-crossencoder-human-rights\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: mps\n" ] } ], "source": [ "if torch.backends.mps.is_available():\n", " device = \"mps\"\n", "elif torch.cuda.is_available():\n", " device = \"cuda\"\n", "else:\n", " device = \"cpu\"\n", "\n", 
"print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-process the Synthetic Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n", " num_rows: 1000\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Combine the generated datasets from files and prompts\n", "\n", "dataset_rag_from_file = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-files\", split=\"train\")\n", "dataset_rag_from_prompt = load_dataset(f\"{REPO_NAME}/rag-human-rights-from-prompt\", split=\"train\")\n", "\n", "combined_rag_dataset = concatenate_datasets(\n", " [dataset_rag_from_file, dataset_rag_from_prompt]\n", ")\n", "\n", "combined_rag_dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],\n", " num_rows: 828\n", "})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Filter out examples with empty or NaN values\n", "\n", "def filter_empty_or_nan(example):\n", " return all(\n", " value is not None and str(value).strip() != \"\" for value in example.values()\n", " )\n", "\n", "filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)\n", "filtered_rag_dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset({\n", " features: ['anchor', 'positive', 'negative'],\n", " num_rows: 828\n", "})\n", "Dataset({\n", " features: ['anchor', 'positive'],\n", " num_rows: 828\n", "})\n" ] } ], "source": [ "# Rename, select and reorder columns according to the expected format for the SentenceTransformer and CrossEncoder models\n", "\n", "def rename_and_reorder_columns(dataset, rename_map, selected_columns):\n", " for old_name, new_name in rename_map.items():\n", " if old_name in dataset.column_names:\n", " dataset = dataset.rename_column(old_name, new_name)\n", " dataset = dataset.select_columns(selected_columns)\n", " return dataset\n", "\n", "clean_rag_dataset_biencoder = rename_and_reorder_columns(\n", " filtered_rag_dataset,\n", " rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\", \"negative_retrieval\": \"negative\"},\n", " selected_columns=[\"anchor\", \"positive\", \"negative\"],\n", ")\n", "\n", "clean_rag_dataset_crossencoder = rename_and_reorder_columns(\n", " filtered_rag_dataset,\n", " rename_map={\"context\": \"anchor\", \"positive_retrieval\": \"positive\"}, #TODO\n", " selected_columns=[\"anchor\", \"positive\"],\n", ")\n", "\n", "print(clean_rag_dataset_biencoder)\n", "print(clean_rag_dataset_crossencoder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 
] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "406c4d22f43f41d592d3b94da2955444", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/828 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Dataset({\n", " features: ['anchor', 'positive', 'score'],\n", " num_rows: 828\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Add scores to train the CrossEncoder model, which requires sentence pairs with a score indicating how related they are.\n", "# Check the available models: https://huggingface.co/spaces/mteb/leaderboard\n", "\n", "model_reranking = CrossEncoder(\n", " model_name=\"Snowflake/snowflake-arctic-embed-m-v1.5\", device=device\n", ")\n", "\n", "def add_reranking_scores(batch):\n", " pairs = list(zip(batch[\"anchor\"], batch[\"positive\"]))\n", " batch[\"score\"] = model_reranking.predict(pairs)\n", " return batch\n", "\n", "clean_rag_dataset_crossencoder = clean_rag_dataset_crossencoder.map(\n", " add_reranking_scores, batched=True, batch_size=250\n", ")\n", "clean_rag_dataset_crossencoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['anchor', 'positive', 'negative'],\n", " num_rows: 662\n", " })\n", " eval: Dataset({\n", " features: ['anchor', 'positive', 'negative'],\n", " num_rows: 166\n", " })\n", "})\n", "DatasetDict({\n", " train: Dataset({\n", " features: ['anchor', 'positive', 'score'],\n", " num_rows: 662\n", " })\n", " eval: Dataset({\n", " features: ['anchor', 'positive', 'score'],\n", " num_rows: 166\n", " })\n", "})\n" ] } ], "source": [ "# Split the datasets into training and evaluation sets\n", "def split_dataset(dataset, train_size=0.8, seed=42):\n", " train_eval_split = dataset.train_test_split(test_size=1 - train_size, seed=seed)\n", "\n", " dataset_dict = DatasetDict(\n", " {\"train\": train_eval_split[\"train\"], \"eval\": train_eval_split[\"test\"]}\n", " )\n", "\n", " return dataset_dict\n", "\n", "dataset_rag_biencoder = split_dataset(clean_rag_dataset_biencoder)\n", "dataset_rag_crossencoder = split_dataset(clean_rag_dataset_crossencoder)\n", "\n", "print(dataset_rag_biencoder)\n", "print(dataset_rag_crossencoder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the Bi-Encoder model for Retrieval" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the base model and create the SentenceTransformer model\n", "model_biencoder = SentenceTransformer(\n", " MODEL,\n", " model_card_data=SentenceTransformerModelCardData(\n", " language=\"en\",\n", " license=\"apache-2.0\",\n", " model_name=MODEL_NAME_BIENCODER,\n", " ),\n", ")\n", "model_biencoder.gradient_checkpointing_enable() # Enable gradient checkpointing to save memory" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Select the TripleLoss loss function which requires sentence triplets (anchor, positive, negative)\n", "# Check the available losses: https://sbert.net/docs/sentence_transformer/loss_overview.html\n", "\n", "loss_biencoder = TripletLoss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/transformers/training_args.py:2243: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n", " warnings.warn(\n" ] } ], "source": [ "# Define the training arguments for the SentenceTransformer model\n", "# Customize them as needed for your requirements\n", "\n", "training_args = SentenceTransformerTrainingArguments(\n", " output_dir=f\"models/{MODEL_NAME_BIENCODER}\",\n", " num_train_epochs=3,\n", " per_device_train_batch_size=4,\n", " gradient_accumulation_steps=4,\n", " per_device_eval_batch_size=4,\n", " warmup_ratio=0.1,\n", " learning_rate=2e-5,\n", " lr_scheduler_type=\"cosine\",\n", " fp16=False, # or True if stable on your MPS device\n", " bf16=False,\n", " batch_sampler=BatchSamplers.NO_DUPLICATES,\n", " eval_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " save_total_limit=2,\n", " logging_steps=100,\n", " load_best_model_at_end=True,\n", " use_mps_device=(device == \"mps\"),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define the evaluator to assess the performance of the model\n", "triplet_evaluator = TripletEvaluator(\n", " anchors=dataset_rag_biencoder[\"eval\"][\"anchor\"],\n", " positives=dataset_rag_biencoder[\"eval\"][\"positive\"],\n", " negatives=dataset_rag_biencoder[\"eval\"][\"negative\"],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n", " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n" ] }, { "data": { "text/html": [ "\n", "
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/sdiazlor/.pyenv/versions/3.11.4/envs/distilabel-tutorials/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n", " with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n" ] }, { "data": { "text/html": [ "\n", "<table border=\"1\">\n", " <tr><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th><th>Cosine Accuracy</th></tr>\n", " <tr><td>1</td><td>No log</td><td>3.655929</td><td>0.969880</td></tr>\n", " <tr><td>2</td><td>14.374000</td><td>3.498395</td><td>0.981928</td></tr>\n", "</table>"
],
"text/plain": [
"