Update README.md

## Reproducibility

The following script can be used to reproduce this model. All credit goes to [Tom Aarsen](https://huggingface.co/tomaarsen) for this fine-tuning approach and code. We make a few modifications to the original code, namely:

- We start from a pre-trained Model2Vec model ([potion-base-32M](https://huggingface.co/minishlab/potion-base-32M)).
- We reduce the dataset size by a factor of 10. During our experiments we found that the full dataset was not needed for the model to converge.
- We decrease the learning rate and train for 3 epochs instead of 1. A high learning rate wipes out the benefit of starting from a pre-trained model.
```python
import random
import logging

from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
import wandb

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets(factor: int = 1):
    """
    Loads train and eval datasets from disk if available. Otherwise, downloads
    them from Hugging Face, preprocesses, and saves them to disk. If `factor` is
    greater than 1, returns a fraction (1/factor) of each dataset subset.

    :param factor: The factor by which the data is reduced. If factor=1, no reduction is performed.
    :return: (train_dataset: DatasetDict, eval_dataset: DatasetDict)
    """
    try:
        # Try loading from disk
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
    except FileNotFoundError:
        print("Prebuilt datasets not found on disk. Building from scratch...")

        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset(
            "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
            "triplet",
            split="train"
        )
        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        # ...
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset(
            "sentence-transformers/s2orc",
            "title-abstract-pair",
            split="train[:100000]"  # limit to 100k
        )
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset(
            "sentence-transformers/all-nli",
            "triplet",
            split="train"
        )
        allnli_eval_dataset = load_dataset(
            "sentence-transformers/all-nli",
            "triplet",
            split="dev"
        )
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        # ...

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset(
            "nthakur/swim-ir-monolingual",
            "en",
            split="train"
        ).select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        # NOTE: 20 negatives
        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset(
            "sentence-transformers/pubmedqa",
            "triplet-20",
            split="train"
        )
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading miracl dataset...")
        miracl_dataset = load_dataset(
            "sentence-transformers/miracl",
            "en-triplet-all",
            split="train"
        )
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mldr dataset...")
        mldr_dataset = load_dataset(
            "sentence-transformers/mldr",
            "en-triplet-all",
            split="train"
        )
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset(
            "sentence-transformers/mr-tydi",
            "en-triplet-all",
            split="train"
        )
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict({
            # ...
        })
        eval_dataset = DatasetDict({
            # ...
            "mr_tydi": mr_tydi_eval_dataset,
        })

        # Save to disk for next time
        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")

        # Quit to avoid memory overhead on large datasets
        quit()
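        # NOTE: the first run only builds and caches the datasets; re-running the script
        # then loads them from disk and proceeds with training.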

    # Reduce the dataset if factor > 1
    if factor > 1:
        for subset_name in train_dataset:
            ds = train_dataset[subset_name].shuffle(seed=42)
            new_len = len(ds) // factor
            train_dataset[subset_name] = ds.select(range(new_len))

        for subset_name in eval_dataset:
            ds = eval_dataset[subset_name].shuffle(seed=42)
            new_len = len(ds) // factor
            eval_dataset[subset_name] = ds.select(range(new_len))

    return train_dataset, eval_dataset


def main():
    wandb.init(entity="minishlab", project="minishlab")
    # 1. Load a model to finetune
    static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-32M")
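    # `from_model2vec` initializes the static embedding layer from the distilled
    # potion-base-32M vectors, so training starts from a pre-trained model rather than from scratch.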

    # 2. Initialize the SentenceTransformer model
    model_name = "potion-retrieval-32M"
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            # ...
            model_name=model_name,
        ),
    )

    # 3. Load training & evaluation datasets
    # NOTE: we reduce the total dataset size by a factor of 10
    train_dataset, eval_dataset = load_train_eval_datasets(factor=10)
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512])
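    # MultipleNegativesRankingLoss uses in-batch negatives; wrapping it in MatryoshkaLoss
    # additionally trains the embeddings to remain useful when truncated to 32-512 dimensions.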

    # 5. Specify training arguments
    run_name = model_name
    epochs = 3
    lr = 0.05
    args = SentenceTransformerTrainingArguments(
        output_dir=f"models/{run_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
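        # NO_DUPLICATES avoids repeating the same text within a batch (important for in-batch negatives),
        # and PROPORTIONAL samples from each training dataset in proportion to its size.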
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        # ...
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,
        report_to=["wandb"],
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_mean_cosine_ndcg@10",
        greater_is_better=True,
    )

    # 6. Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)
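    # Evaluating before training gives a NanoBEIR baseline for the pre-trained potion-base-32M embeddings.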

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 8. Evaluate the trained model and save
    evaluator(model)
    model.save_pretrained(f"models/{run_name}/final")


if __name__ == "__main__":
    main()
```
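
After training, the final checkpoint is written to `models/potion-retrieval-32M/final` and can be loaded like any other Sentence Transformers model. A minimal usage sketch (the local path and example texts below are illustrative; loading the published checkpoint from the Hub by its model ID works the same way):

```python
from sentence_transformers import SentenceTransformer

# Load the fine-tuned static embedding model from the script's output directory
model = SentenceTransformer("models/potion-retrieval-32M/final")

# Encode a query and a few documents, then rank the documents by similarity
query_embedding = model.encode("What is a static embedding model?")
doc_embeddings = model.encode([
    "Static embedding models map each token to a fixed vector and average them.",
    "The capital of France is Paris.",
])
similarities = model.similarity(query_embedding, doc_embeddings)
print(similarities)
```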