add example for RAG
- .gitignore +4 -1
- examples/fine-tune-modernbert-rag.ipynb +980 -0
.gitignore
CHANGED
@@ -161,4 +161,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-.DS_Store
+.DS_Store
+
+# examples
+models/
examples/fine-tune-modernbert-rag.ipynb
ADDED
# Fine-tune ModernBERT with Synthetic Data for RAG

This notebook demonstrates how to fine-tune `modernbert-embed-base` on synthetic data tailored for a Retrieval-Augmented Generation (RAG) pipeline.

It provides a complete walkthrough of the fine-tuning steps after generating synthetic data with the Synthetic Data Generator. For a comprehensive explanation of the methodology and additional details, refer to the blog post: [Fine-tune ModernBERT with Synthetic Data for RAG](https://huggingface.co/blog/fine-tune-modernbert-with-synthetic-data-for-rag).
## Getting Started

### Install the Dependencies

```python
!pip install torch
!pip install datasets
!pip install sentence-transformers
!pip install haystack-ai
!pip install git+https://github.com/huggingface/transformers.git  # for the latest version of transformers
```

### Import the Required Libraries

```python
import torch
from torch.utils.data import DataLoader

from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    CrossEncoder,
    InputExample,
    SentenceTransformerTrainer,
)
from sentence_transformers.losses import TripletLoss
from sentence_transformers.training_args import (
    SentenceTransformerTrainingArguments,
    BatchSamplers,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

from haystack import Document, Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from haystack.utils.hf import HFGenerationAPIType
```

### Configure the Environment

```python
MODEL = "nomic-ai/modernbert-embed-base"
REPO_NAME = "sdiazlor"  # your HF username here
MODEL_NAME_BIENCODER = "modernbert-embed-base-biencoder-human-rights"
MODEL_NAME_CROSSENCODER = "modernbert-embed-base-crossencoder-human-rights"
```

```python
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")
```

```
Using device: mps
```
## Pre-process the Synthetic Data

```python
# Combine the generated datasets from files and prompts

dataset_rag_from_file = load_dataset(f"{REPO_NAME}/rag-human-rights-from-files", split="train")
dataset_rag_from_prompt = load_dataset(f"{REPO_NAME}/rag-human-rights-from-prompt", split="train")

combined_rag_dataset = concatenate_datasets(
    [dataset_rag_from_file, dataset_rag_from_prompt]
)

combined_rag_dataset
```

```
Dataset({
    features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],
    num_rows: 1000
})
```
```python
# Filter out examples with empty or NaN values

def filter_empty_or_nan(example):
    return all(
        value is not None and str(value).strip() != "" for value in example.values()
    )

filtered_rag_dataset = combined_rag_dataset.filter(filter_empty_or_nan).shuffle(seed=42)
filtered_rag_dataset
```

```
Dataset({
    features: ['context', 'question', 'response', 'positive_retrieval', 'negative_retrieval', 'positive_reranking', 'negative_reranking'],
    num_rows: 828
})
```
```python
# Rename, select and reorder columns according to the expected format
# for the SentenceTransformer and CrossEncoder models

def rename_and_reorder_columns(dataset, rename_map, selected_columns):
    for old_name, new_name in rename_map.items():
        if old_name in dataset.column_names:
            dataset = dataset.rename_column(old_name, new_name)
    dataset = dataset.select_columns(selected_columns)
    return dataset

clean_rag_dataset_biencoder = rename_and_reorder_columns(
    filtered_rag_dataset,
    rename_map={"context": "anchor", "positive_retrieval": "positive", "negative_retrieval": "negative"},
    selected_columns=["anchor", "positive", "negative"],
)

clean_rag_dataset_crossencoder = rename_and_reorder_columns(
    filtered_rag_dataset,
    rename_map={"context": "anchor", "positive_retrieval": "positive"},
    selected_columns=["anchor", "positive"],
)

print(clean_rag_dataset_biencoder)
print(clean_rag_dataset_crossencoder)
```

```
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 828
})
Dataset({
    features: ['anchor', 'positive'],
    num_rows: 828
})
```
```python
# Add scores to train the CrossEncoder model, which requires sentence pairs
# with a score indicating how related they are.
# Check the available models: https://huggingface.co/spaces/mteb/leaderboard

model_reranking = CrossEncoder(
    model_name="Snowflake/snowflake-arctic-embed-m-v1.5", device=device
)

def add_reranking_scores(batch):
    pairs = list(zip(batch["anchor"], batch["positive"]))
    batch["score"] = model_reranking.predict(pairs)
    return batch

clean_rag_dataset_crossencoder = clean_rag_dataset_crossencoder.map(
    add_reranking_scores, batched=True, batch_size=250
)
clean_rag_dataset_crossencoder
```

```
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Dataset({
    features: ['anchor', 'positive', 'score'],
    num_rows: 828
})
```
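As a spot-check of the scoring step, `CrossEncoder.predict` takes a list of sentence pairs and returns one relevance score per pair; a minimal sketch (the pair below is an invented example and assumes the cell above has run):

```python
# Minimal spot-check (the pair below is an invented example).
# predict() returns one float per (anchor, positive) pair; these floats
# become the "score" column used to train the cross-encoder later on.
example_pairs = [
    (
        "Everyone has the right to freedom of thought, conscience and religion.",
        "What does the convention say about freedom of religion?",
    )
]
print(model_reranking.predict(example_pairs))  # e.g. array([0.42], dtype=float32); exact values vary
```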
```python
# Split the datasets into training and evaluation sets
def split_dataset(dataset, train_size=0.8, seed=42):
    train_eval_split = dataset.train_test_split(test_size=1 - train_size, seed=seed)

    dataset_dict = DatasetDict(
        {"train": train_eval_split["train"], "eval": train_eval_split["test"]}
    )

    return dataset_dict

dataset_rag_biencoder = split_dataset(clean_rag_dataset_biencoder)
dataset_rag_crossencoder = split_dataset(clean_rag_dataset_crossencoder)

print(dataset_rag_biencoder)
print(dataset_rag_crossencoder)
```

```
DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 662
    })
    eval: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 166
    })
})
DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'score'],
        num_rows: 662
    })
    eval: Dataset({
        features: ['anchor', 'positive', 'score'],
        num_rows: 166
    })
})
```
## Train the Bi-Encoder model for Retrieval

```python
# Load the base model and create the SentenceTransformer model
model_biencoder = SentenceTransformer(
    MODEL,
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name=MODEL_NAME_BIENCODER,
    ),
)
model_biencoder.gradient_checkpointing_enable()  # Enable gradient checkpointing to save memory
```

```python
# Select the TripletLoss loss function, which requires sentence triplets (anchor, positive, negative)
# Check the available losses: https://sbert.net/docs/sentence_transformer/loss_overview.html

loss_biencoder = TripletLoss
```
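For intuition, the triplet objective pulls the anchor embedding towards the positive and pushes it away from the negative by at least a margin; a sketch of the loss for an embedding model $f$ (to my understanding, sentence-transformers' `TripletLoss` defaults to Euclidean distance with a margin of 5):

```latex
% Triplet loss over (anchor, positive, negative) triplets for embedding model f
\mathcal{L}(a, p, n) = \max\bigl(\lVert f(a) - f(p)\rVert_2 - \lVert f(a) - f(n)\rVert_2 + \mathrm{margin},\; 0\bigr)
```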
```python
# Define the training arguments for the SentenceTransformer model
# Customize them as needed for your requirements

training_args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{MODEL_NAME_BIENCODER}",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=4,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    fp16=False,  # or True if stable on your MPS device
    bf16=False,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=100,
    load_best_model_at_end=True,
    use_mps_device=(device == "mps"),
)
```
```python
# Define the evaluator to assess the performance of the model
triplet_evaluator = TripletEvaluator(
    anchors=dataset_rag_biencoder["eval"]["anchor"],
    positives=dataset_rag_biencoder["eval"]["positive"],
    negatives=dataset_rag_biencoder["eval"]["negative"],
)
```
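Optionally, the evaluator can be run on the untrained model first to get a baseline to compare the fine-tuned accuracy against; a sketch (depending on the sentence-transformers version, the call returns a dict of metrics such as `cosine_accuracy` or a single accuracy float):

```python
# Optional baseline (sketch): evaluate the model before fine-tuning so the
# post-training cosine accuracy has a reference point.
baseline = triplet_evaluator(model_biencoder)
print(baseline)
```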
```python
# Train the model. This will take some time depending on the size of the dataset and the model
# Remember to adjust the training arguments according to your requirements

trainer = SentenceTransformerTrainer(
    model=model_biencoder,
    args=training_args,
    train_dataset=dataset_rag_biencoder["train"],
    eval_dataset=dataset_rag_biencoder["eval"],
    loss=loss_biencoder,
    evaluator=triplet_evaluator,
)
trainer.train()
```

Training progress (`[123/123 25:34, Epoch 2/3]`):

| Epoch | Training Loss | Validation Loss | Cosine Accuracy |
|------:|--------------:|----------------:|----------------:|
| 1     | No log        | 3.655929        | 0.969880        |
| 2     | 14.374000     | 3.498395        | 0.981928        |
```python
# Save the model to the local directory and push it to the Hub
model_biencoder.save_pretrained(f"models/{MODEL_NAME_BIENCODER}")
model_biencoder.push_to_hub(f"{REPO_NAME}/{MODEL_NAME_BIENCODER}")
```
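As a quick sanity check, the pushed model can be reloaded like any other Sentence Transformers model; a sketch with an invented query (assumes the push above succeeded; `modernbert-embed-base` produces 768-dimensional embeddings):

```python
# Sanity check (sketch): reload the fine-tuned bi-encoder from the Hub and embed a query.
reloaded_biencoder = SentenceTransformer(f"{REPO_NAME}/{MODEL_NAME_BIENCODER}", device=device)
query_embedding = reloaded_biencoder.encode("What rights does the convention protect?")
print(query_embedding.shape)  # expected: (768,)
```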
## Train the Cross-Encoder model for Ranking

```python
# Prepare the training and evaluation samples for the CrossEncoder model

train_samples = []
for row in dataset_rag_crossencoder["train"]:
    # 'score' is the float label the CrossEncoder learns to predict
    train_samples.append(
        InputExample(texts=[row["anchor"], row["positive"]], label=float(row["score"]))
    )

eval_samples = []
for row in dataset_rag_crossencoder["eval"]:
    eval_samples.append(
        InputExample(texts=[row["anchor"], row["positive"]], label=float(row["score"]))
    )

# Initialize the DataLoader for the training samples
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=4)
```

```python
# Initialize the CrossEncoder model. Set the number of labels to 1 for regression tasks
model_crossencoder = CrossEncoder(model_name=MODEL, num_labels=1)
```

```python
# Define the evaluator
evaluator = CECorrelationEvaluator.from_input_examples(eval_samples)
```
```python
# Train the CrossEncoder model

model_crossencoder.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=3,
    warmup_steps=500,
    output_path=f"models/{MODEL_NAME_CROSSENCODER}",
    save_best_model=True,
)
```

```python
# Save the model to the local directory and push it to the Hub
model_crossencoder.save_pretrained(f"models/{MODEL_NAME_CROSSENCODER}")
model_crossencoder.push_to_hub(f"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}")
```
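To see what the fine-tuned cross-encoder is for, here is a minimal re-ranking sketch (the query and passages are invented examples): score each (query, passage) pair and sort the passages by descending relevance.

```python
import numpy as np

# Re-ranking sketch (query and passages are invented examples).
query = "What does the convention say about torture?"
passages = [
    "No one shall be subjected to torture or to inhuman or degrading treatment or punishment.",
    "Everyone has the right to respect for his private and family life.",
]

# One relevance score per (query, passage) pair; higher means more relevant.
scores = model_crossencoder.predict([(query, passage) for passage in passages])
for idx in np.argsort(scores)[::-1]:
    print(f"{scores[idx]:.3f}  {passages[idx]}")
```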
## Build the RAG Pipeline

The following section is inspired by the Haystack tutorial; check it for further details: [Creating Your First QA Pipeline with Retrieval-Augmentation](https://haystack.deepset.ai/tutorials/27_first_rag_pipeline)

```python
# Add the documents to the DocumentStore
# Use the already chunked documents from the original datasets

df = combined_rag_dataset.to_pandas()
df = df.drop_duplicates(subset=["context"])  # drop duplicates based on the "context" column
# df = df.sample(n=100, random_state=42)  # optional: sample a subset of the dataset
dataset = Dataset.from_pandas(df)

docs = [Document(content=doc["context"]) for doc in dataset]
```
```python
# Initialize the document store and store the documents with the embeddings using our bi-encoder model

document_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(
    model=f"{REPO_NAME}/{MODEL_NAME_BIENCODER}",
)
doc_embedder.warm_up()

docs_with_embeddings = doc_embedder.run(docs)
document_store.write_documents(docs_with_embeddings["documents"])

text_embedder = SentenceTransformersTextEmbedder(
    model=f"{REPO_NAME}/{MODEL_NAME_BIENCODER}",
)
```

```python
# Initialize the retriever (our bi-encoder model) and the ranker (our cross-encoder model)

retriever = InMemoryEmbeddingRetriever(document_store)
ranker = SentenceTransformersDiversityRanker(
    model=f"{REPO_NAME}/{MODEL_NAME_CROSSENCODER}"
)
```
```python
# Define the prompt builder and the chat generator to interact with the models using the HF Serverless Inference API

template = [
    ChatMessage.from_user(
        """
Given the following information, answer the question.

Context:
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Question: {{question}}
Answer:
"""
    )
]

prompt_builder = ChatPromptBuilder(template=template)

chat_generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "meta-llama/Llama-3.1-8B-Instruct"},
    token=Secret.from_env_var("HF_TOKEN"),
)
```
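The builder can also be run on its own to inspect the exact prompt the LLM will receive once the Jinja loop inlines the retrieved documents; a sketch with an invented document and question:

```python
# Prompt-rendering sketch (document content and question are invented examples).
rendered = prompt_builder.run(
    documents=[Document(content="Article 3: Everyone has the right to life, liberty and security of person.")],
    question="What does Article 3 protect?",
)
print(rendered["prompt"][0].text)  # the user message with the document inlined
```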
```python
# Initialize the pipeline with the components

rag_pipeline = Pipeline()
rag_pipeline.add_component("text_embedder", text_embedder)
rag_pipeline.add_component("retriever", retriever)
rag_pipeline.add_component("ranker", ranker)
rag_pipeline.add_component("prompt_builder", prompt_builder)
rag_pipeline.add_component("llm", chat_generator)
```

```python
# Connect the components to each other

rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
rag_pipeline.connect("retriever.documents", "ranker.documents")
rag_pipeline.connect("ranker", "prompt_builder")
rag_pipeline.connect("prompt_builder.prompt", "llm.messages")
```

```
<haystack.core.pipeline.pipeline.Pipeline object at 0x32e75b4d0>
🚅 Components
  - text_embedder: SentenceTransformersTextEmbedder
  - retriever: InMemoryEmbeddingRetriever
  - ranker: SentenceTransformersDiversityRanker
  - prompt_builder: ChatPromptBuilder
  - llm: HuggingFaceAPIChatGenerator
🛤️ Connections
  - text_embedder.embedding -> retriever.query_embedding (List[float])
  - retriever.documents -> ranker.documents (List[Document])
  - ranker.documents -> prompt_builder.documents (List[Document])
  - prompt_builder.prompt -> llm.messages (List[ChatMessage])
```
```python
# Query the pipeline with a question that is not directly covered in your documents
question = "How many human rights are there?"

response = rag_pipeline.run(
    {
        "text_embedder": {"text": question},
        "prompt_builder": {"question": question},
        "ranker": {"query": question},
    }
)

print(response["llm"]["replies"][0].text)
```

```
It seems that there is not enough information given in the human rights protocols provided to accurately answer the question. However, we can inform you that there are several types of human rights documents that this could be referring too. Event the most widely respected declared world document on human rights for Example - Exernal and some Individual (Part 1 Art.) and some other attempted Separation apart include: The convention lists several key rights such as 

1. Right to Life 
2. Right to Liberty and Security 
3. Freedom from Torture 
4. Freedom from Slavery 
5. Right to a Fair Trial 
6. No Punishment without Law 
7. Respect for Family Life 
... (and throughout given information 44 protocals - are actually chapter and not... How is the answer 

Not possible to answer your question due to lack of information, however we can tell you Event the most widely respected declared world document on human rights.
```
```python
# Query the pipeline with a question that is covered in your documents
question = "What's the Right of Fair Trial?"

response = rag_pipeline.run(
    {
        "text_embedder": {"text": question},
        "prompt_builder": {"question": question},
        "ranker": {"query": question},
    }
)

print(response["llm"]["replies"][0].text)
```

```
The information you provided does not directly list the "Right of Fair Trial" but looking under articles of the Convention for the Protection of Human Rights and Fundamental Freedoms, Article 6, also known as the Right to a Fair Trial, gives a clear idea.

 Article 6. Right to a fair Trial

1. Everyone is entitled to a fair and public hearing within a reasonable time by an independent and impartial tribunal established by law.

2, everybody shall be presumed innocent until proven guilty by a final decision of a competent court.

3. Everyone charged with a criminal offence has the following minimum rights:

 a to be informed promptly, in a language which he understands and in detail, of the charges, if any, against him.
 b to have adequate time and facilities for the preparation of his defence.
 c to defend himself in person or through legal assistance of his own choosing or, if he has not sufficient means to pay for legal assistance, to be given it free when the interests of justice so require.
 d to be tried in his presence, and to defend himself in person or through legal assistance of his own choosing; to be informed, if he does not have legal assistance chosen or appointed under Article 5 Part 3 of this Convention, to communicate with the defence he has chosen
 e to have the free assistance of an interpreter if he cannot understand or speak the language used in court.

4. Everyone sentenced has the right to, review by a higher tribunal according to law

5. Everyone sentenced has the right to, take up or pursue his occupation.

6. Sentences may, also include restoration of rights or removal of disabilities
```