Pringled committed commit 0972166 (verified) · parent: 4df227c

Update README.md

Files changed (1): README.md (+291, -0)
README.md CHANGED
@@ -7,6 +7,20 @@ tags:
- static-embeddings
---

```
Average (All)        49.73
Average (MTEB)       49.76
Classification       59.56
Clustering           30.55
PairClassification   76.38
Reranking            50.05
Retrieval            36.35
STS                  73.22
Summarization        28.85
PEARL                49.31
WordSim              50.02
```
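
The block above lists average scores per task category (the MTEB categories plus the PEARL and WordSim suites). As a rough, hedged sketch only, the MTEB portion of such numbers can be produced with the `mteb` library; the exact task selection and aggregation used for this model are not specified here, so the snippet below is illustrative rather than the actual evaluation setup:

```python
# Illustrative sketch: run a subset of MTEB tasks against a model.
# The model path below is a placeholder, and the task selection is NOT the
# exact set used to compute the averages listed above.
import mteb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("path/or/hub-id-of-this-model")  # placeholder

# Select a couple of English task categories and run the benchmark.
tasks = mteb.get_tasks(task_types=["Retrieval", "STS"], languages=["eng"])
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, output_folder="mteb_results")
```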

# potion-retrieval-512dim-60kvocab-v1replica-v1 Model Card

  This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of a Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical.
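
As a minimal usage sketch (the Hub repository id below is an assumption based on the model name in this card; substitute the actual id), the model can be loaded with the `model2vec` library and used to embed text directly:

```python
# Minimal usage sketch. NOTE: the repository id is assumed from the model name
# in this card; replace it with the actual Hugging Face Hub id if it differs.
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-retrieval-512dim-60kvocab-v1replica-v1")

# Encoding static embeddings is essentially a token lookup plus pooling,
# which is why it stays fast even on CPU.
embeddings = model.encode([
    "What is Model2Vec?",
    "Model2Vec distills a Sentence Transformer into static embeddings.",
])
print(embeddings.shape)  # (num_sentences, embedding_dim)
```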
 
@@ -72,4 +86,281 @@ Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
  year = {2024},
  url = {https://github.com/MinishLab/model2vec},
}
```

## Reproducibility

The script below can be used to reproduce the training of this model:

```python
import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
import wandb
from transformers import AutoTokenizer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import CosineAnnealingLR

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
        squad_train_dataset: Dataset = squad_dataset_dict["train"]
        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
        paq_train_dataset: Dataset = paq_dataset_dict["train"]
        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
        print("Loaded paq dataset.")

        print("Loading trivia_qa dataset...")
        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
        print("Loaded trivia_qa dataset.")

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        # NOTE: 20 negatives
        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading miracl dataset...")
        miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mldr dataset...")
        mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict({
            "gooaq": gooaq_train_dataset,
            "msmarco": msmarco_train_dataset,
            "squad": squad_train_dataset,
            "s2orc": s2orc_train_dataset,
            "allnli": allnli_train_dataset,
            "paq": paq_train_dataset,
            "trivia_qa": trivia_qa_train_dataset,
            "msmarco_10m": msmarco_10m_train_dataset,
            "swim_ir": swim_ir_train_dataset,
            "pubmedqa": pubmedqa_train_dataset,
            "miracl": miracl_train_dataset,
            "mldr": mldr_train_dataset,
            "mr_tydi": mr_tydi_train_dataset,
        })
        eval_dataset = DatasetDict({
            "gooaq": gooaq_eval_dataset,
            "msmarco": msmarco_eval_dataset,
            "squad": squad_eval_dataset,
            "s2orc": s2orc_eval_dataset,
            "allnli": allnli_eval_dataset,
            "paq": paq_eval_dataset,
            "trivia_qa": trivia_qa_eval_dataset,
            "msmarco_10m": msmarco_10m_eval_dataset,
            "swim_ir": swim_ir_eval_dataset,
            "pubmedqa": pubmedqa_eval_dataset,
            "miracl": miracl_eval_dataset,
            "mldr": mldr_eval_dataset,
            "mr_tydi": mr_tydi_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")

        # The `train_test_split` calls have put a lot of the datasets in memory, while we want them to just be on disk
        quit()


def load_train_eval_datasets_reduced():
    # 1. Load the full datasets from disk
    train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
    eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")

    factor = 10
    for subset_name in train_dataset:
        ds = train_dataset[subset_name]
        ds = ds.shuffle(seed=42)  # Shuffle for a random subset
        new_len = len(ds) // factor  # Keep 1/10th
        ds = ds.select(range(new_len))
        train_dataset[subset_name] = ds

    for subset_name in eval_dataset:
        ds = eval_dataset[subset_name]
        ds = ds.shuffle(seed=42)
        new_len = len(ds) // factor
        ds = ds.select(range(new_len))
        eval_dataset[subset_name] = ds

    return train_dataset, eval_dataset


def main():
    wandb.init(entity="minishlab", project="minishlab")

    # 1. Load a Model2Vec model to finetune
    static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-512dim-60kvocab")

    # 2. Initialize the SentenceTransformer model as usual, with (optional) model card data
    model_name = "potion-retrieval-512dim-60kvocab-v1"
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="MIT",
            model_name=model_name,
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets_reduced()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512])

    # 5. (Optional) Specify training arguments
    run_name = model_name
    epochs = 3
    lr = 0.05
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=epochs,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=250,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        report_to=["wandb"],
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_mean_cosine_ndcg@10",
        greater_is_better=True,
    )

    # 6. (Optional) Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # (Optional) Evaluate the trained model on the evaluator after training
    evaluator(model)

    # 8. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")


if __name__ == "__main__":
    main()
```
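
The script saves a regular Sentence Transformers checkpoint containing the `StaticEmbedding` module, so the trained model can be reloaded and used directly. A small usage sketch (the path simply mirrors the `output_dir` and `run_name` used above):

```python
# Reload the checkpoint written by the training script above and use it
# for retrieval-style similarity scoring.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("models/potion-retrieval-512dim-60kvocab-v1/final")

query_emb = model.encode(["how fast are static embeddings?"])
doc_emb = model.encode([
    "Static embeddings are computed with a simple lookup, so encoding is very fast.",
    "The weather is nice today.",
])
print(model.similarity(query_emb, doc_emb))  # cosine similarity scores
```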