Pringled committed · Commit e315f16 (verified) · Parent: d9007cf

Update README.md

Files changed (1):
  1. README.md +101 -68
README.md CHANGED
@@ -87,6 +87,12 @@ Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) i

 ## Reproducibility

+The following script can be used to reproduce this model. All credits go to [Tom Aarsen](https://huggingface.co/tomaarsen) for this fine-tuning approach and code. We make a few modifications to the original code, namely:
+
+- We start with a pre-trained Model2Vec model ([potion-base-32M](https://huggingface.co/minishlab/potion-base-32M)).
+- We reduce the dataset size by a factor of 10. During experiments we saw that we didn't need the full dataset for the model to converge.
+- We decrease the learning rate and train for 3 epochs instead of 1. Using a high learning rate wipes the effects of using a pre-trained model.
+
 ```python
 import random
 import logging
@@ -102,29 +108,29 @@ from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatch
 from sentence_transformers.evaluation import NanoBEIREvaluator
 from sentence_transformers.models.StaticEmbedding import StaticEmbedding
 import wandb
-from transformers import AutoTokenizer
-from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
-from torch.optim import AdamW
-from transformers import get_linear_schedule_with_warmup
-from torch.optim.lr_scheduler import CosineAnnealingLR

 logging.basicConfig(
     format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
 )
 random.seed(12)

-
-def load_train_eval_datasets():
+
+def load_train_eval_datasets(factor: int = 1):
     """
-    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.
+    Loads train and eval datasets from disk if available. Otherwise, downloads
+    them from Hugging Face, preprocesses, and saves them to disk. If `factor` is
+    greater than 1, returns a fraction (1/factor) of each dataset subset.

-    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
+    :param factor: The factor by which the data is reduced. If factor=1, no reduction is performed.
+    :return: (train_dataset: DatasetDict, eval_dataset: DatasetDict)
     """
     try:
+        # Try loading from disk
         train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
         eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
-        return train_dataset, eval_dataset
     except FileNotFoundError:
+        print("Prebuilt datasets not found on disk. Building from scratch...")
+
         print("Loading gooaq dataset...")
         gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
         gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
@@ -133,7 +139,11 @@ def load_train_eval_datasets():
         print("Loaded gooaq dataset.")

         print("Loading msmarco dataset...")
-        msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
+        msmarco_dataset = load_dataset(
+            "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
+            "triplet",
+            split="train"
+        )
         msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
         msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
         msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
@@ -147,15 +157,27 @@ def load_train_eval_datasets():
         print("Loaded squad dataset.")

         print("Loading s2orc dataset...")
-        s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
+        s2orc_dataset = load_dataset(
+            "sentence-transformers/s2orc",
+            "title-abstract-pair",
+            split="train[:100000]"  # limit to 100k
+        )
         s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
         s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
         s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
         print("Loaded s2orc dataset.")

         print("Loading allnli dataset...")
-        allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
-        allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
+        allnli_train_dataset = load_dataset(
+            "sentence-transformers/all-nli",
+            "triplet",
+            split="train"
+        )
+        allnli_eval_dataset = load_dataset(
+            "sentence-transformers/all-nli",
+            "triplet",
+            split="dev"
+        )
         print("Loaded allnli dataset.")

         print("Loading paq dataset...")
@@ -174,21 +196,33 @@ def load_train_eval_datasets():

         print("Loading msmarco_10m dataset...")
         msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
-        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
+        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(
+            test_size=10_000, seed=12
+        )
         msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
         msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
         print("Loaded msmarco_10m dataset.")

         print("Loading swim_ir dataset...")
-        swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
-        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
+        swim_ir_dataset = load_dataset(
+            "nthakur/swim-ir-monolingual",
+            "en",
+            split="train"
+        ).select_columns(["query", "text"])
+        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(
+            test_size=10_000, seed=12
+        )
         swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
         swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
         print("Loaded swim_ir dataset.")

         # NOTE: 20 negatives
         print("Loading pubmedqa dataset...")
-        pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
+        pubmedqa_dataset = load_dataset(
+            "sentence-transformers/pubmedqa",
+            "triplet-20",
+            split="train"
+        )
         pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
         pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
         pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
@@ -196,7 +230,11 @@ def load_train_eval_datasets():

         # NOTE: A lot of overlap with anchor/positives
         print("Loading miracl dataset...")
-        miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
+        miracl_dataset = load_dataset(
+            "sentence-transformers/miracl",
+            "en-triplet-all",
+            split="train"
+        )
         miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
         miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
         miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
@@ -204,7 +242,11 @@ def load_train_eval_datasets():

         # NOTE: A lot of overlap with anchor/positives
         print("Loading mldr dataset...")
-        mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
+        mldr_dataset = load_dataset(
+            "sentence-transformers/mldr",
+            "en-triplet-all",
+            split="train"
+        )
         mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
         mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
         mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
@@ -212,7 +254,11 @@ def load_train_eval_datasets():

         # NOTE: A lot of overlap with anchor/positives
         print("Loading mr_tydi dataset...")
-        mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
+        mr_tydi_dataset = load_dataset(
+            "sentence-transformers/mr-tydi",
+            "en-triplet-all",
+            split="train"
+        )
         mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
         mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
         mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
@@ -249,43 +295,35 @@ def load_train_eval_datasets():
             "mr_tydi": mr_tydi_eval_dataset,
         })

+        # Save to disk for next time
         train_dataset.save_to_disk("datasets/train_dataset")
         eval_dataset.save_to_disk("datasets/eval_dataset")
-
-        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
+
+        # Quit to avoid memory overhead on large datasets
         quit()
-
-
-
-def load_train_eval_datasets_reduced():
-    # 1. Load the full datasets from disk
-    train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
-    eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
-
-    factor = 10
-    for subset_name in train_dataset:
-        ds = train_dataset[subset_name]
-        ds = ds.shuffle(seed=42)  # Shuffle for a random subset
-        new_len = len(ds) // factor  # Keep 1/10th
-        ds = ds.select(range(new_len))
-        train_dataset[subset_name] = ds
-
-    for subset_name in eval_dataset:
-        ds = eval_dataset[subset_name]
-        ds = ds.shuffle(seed=42)
-        new_len = len(ds) // factor
-        ds = ds.select(range(new_len))
-        eval_dataset[subset_name] = ds
-
+
+    # Reduce the dataset if factor > 1
+    if factor > 1:
+        for subset_name in train_dataset:
+            ds = train_dataset[subset_name].shuffle(seed=42)
+            new_len = len(ds) // factor
+            train_dataset[subset_name] = ds.select(range(new_len))
+
+        for subset_name in eval_dataset:
+            ds = eval_dataset[subset_name].shuffle(seed=42)
+            new_len = len(ds) // factor
+            eval_dataset[subset_name] = ds.select(range(new_len))
+
     return train_dataset, eval_dataset

+
 def main():
     wandb.init(entity="minishlab", project="minishlab")
-    # 1. Load a model to finetune with 2. (Optional) model card data
-    static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-512dim-60kvocab")
-    # 2. Initialize the SentenceTransformer model as usual
+    # 1. Load a model to finetune
+    static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-32M")

-    model_name = "potion-retrieval-512dim-60kvocab-v1"
+    # 2. Initialize the SentenceTransformer model
+    model_name = "potion-retrieval-32M"
     model = SentenceTransformer(
         modules=[static_embedding],
         model_card_data=SentenceTransformerModelCardData(
@@ -294,33 +332,31 @@ def main():
             model_name=model_name,
         ),
     )
-    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
-    train_dataset, eval_dataset = load_train_eval_datasets_reduced()
+
+    # 3. Load training & evaluation datasets
+    # NOTE: we reduce the total dataset size by a factor of 10
+    train_dataset, eval_dataset = load_train_eval_datasets(factor=10)
     print(train_dataset)

     # 4. Define a loss function
     loss = MultipleNegativesRankingLoss(model)
     loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512])

-
-    # 5. (Optional) Specify training arguments
+    # 5. Specify training arguments
     run_name = model_name
     epochs = 3
-    lr = 0.05
+    lr = 0.05
     args = SentenceTransformerTrainingArguments(
-        # Required parameter:
         output_dir=f"models/{run_name}",
-        # Optional training parameters:
         num_train_epochs=epochs,
         per_device_train_batch_size=2048,
         per_device_eval_batch_size=2048,
         learning_rate=lr,
         warmup_ratio=0.1,
-        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
-        bf16=True,  # Set to True if you have a GPU that supports BF16
-        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+        fp16=False,
+        bf16=True,
+        batch_sampler=BatchSamplers.NO_DUPLICATES,
         multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
-        # Optional tracking/debugging parameters:
         eval_strategy="steps",
         eval_steps=250,
         save_strategy="steps",
@@ -328,17 +364,15 @@ def main():
         save_total_limit=2,
         logging_steps=250,
         logging_first_step=True,
-        run_name=run_name,  # Will be used in W&B if `wandb` is installed
+        run_name=run_name,
         report_to=["wandb"],
         load_best_model_at_end=True,
         metric_for_best_model="eval_NanoBEIR_mean_cosine_ndcg@10",
         greater_is_better=True,
-        #
     )

-    # 6. (Optional) Create an evaluator & evaluate the base model
+    # 6. Create an evaluator & evaluate the base model
     evaluator = NanoBEIREvaluator()
-
     evaluator(model)

     # 7. Create a trainer & train
@@ -352,12 +386,11 @@ def main():
     )
     trainer.train()

-    # (Optional) Evaluate the trained model on the evaluator after training
+    # 8. Evaluate the trained model and save
    evaluator(model)
-
-    # 8. Save the trained model
     model.save_pretrained(f"models/{run_name}/final")

+
 if __name__ == "__main__":
     main()
 ```
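
Once the script above has finished, the fine-tuned model is written to `models/potion-retrieval-32M/final` (following the default `run_name`). Below is a minimal usage sketch, assuming that local output path; the `truncate_dim` values correspond to the Matryoshka dimensions used during training (32, 64, 128, 256, 512).

```python
from sentence_transformers import SentenceTransformer

# Load the fine-tuned static embedding model from the training script's output path.
# truncate_dim is optional; it truncates embeddings to one of the Matryoshka dimensions.
model = SentenceTransformer("models/potion-retrieval-32M/final", truncate_dim=256)

# Embed a query and a few candidate passages, then rank them by similarity.
query_embedding = model.encode("how do static embedding models work?")
passage_embeddings = model.encode([
    "Static embedding models average precomputed token vectors, so no transformer forward pass is needed.",
    "The msmarco dataset contains web search queries and passages.",
])
print(model.similarity(query_embedding, passage_embeddings))
```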