Update README.md
Browse files
README.md
CHANGED
@@ -7,6 +7,20 @@ tags:
|
|
7 |
- static-embeddings
|
8 |
---
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
# potion-retrieval-512dim-60kvocab-v1replica-v1 Model Card
|
11 |
|
12 |
This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of a Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical.
|
@@ -72,4 +86,281 @@ Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) i
|
|
72 |
year = {2024},
|
73 |
url = {https://github.com/MinishLab/model2vec},
|
74 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
```
|
|
|
7 |
- static-embeddings
|
8 |
---
|
9 |
|
10 |
+
```
|
11 |
+
Average (All) 49.73
|
12 |
+
Average (MTEB) 49.76
|
13 |
+
Classification 59.56
|
14 |
+
Clustering 30.55
|
15 |
+
PairClassification 76.38
|
16 |
+
Reranking 50.05
|
17 |
+
Retrieval 36.35
|
18 |
+
STS 73.22
|
19 |
+
Summarization 28.85
|
20 |
+
PEARL 49.31
|
21 |
+
WordSim 50.02
|
22 |
+
```
|
23 |
+
|
24 |
# potion-retrieval-512dim-60kvocab-v1replica-v1 Model Card
|
25 |
|
26 |
This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of a Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical.
|
|
|
86 |
year = {2024},
|
87 |
url = {https://github.com/MinishLab/model2vec},
|
88 |
}
|
89 |
+
```
|
90 |
+
|
91 |
+
## Reproducibility
|
92 |
+
|
93 |
+
```python
|
94 |
+
import random
|
95 |
+
import logging
|
96 |
+
from datasets import load_dataset, Dataset, DatasetDict
|
97 |
+
from sentence_transformers import (
|
98 |
+
SentenceTransformer,
|
99 |
+
SentenceTransformerTrainer,
|
100 |
+
SentenceTransformerTrainingArguments,
|
101 |
+
SentenceTransformerModelCardData,
|
102 |
+
)
|
103 |
+
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
|
104 |
+
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
|
105 |
+
from sentence_transformers.evaluation import NanoBEIREvaluator
|
106 |
+
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
|
107 |
+
import wandb
|
108 |
+
from transformers import AutoTokenizer
|
109 |
+
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
|
110 |
+
from torch.optim import AdamW
|
111 |
+
from transformers import get_linear_schedule_with_warmup
|
112 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
113 |
+
|
114 |
+
logging.basicConfig(
|
115 |
+
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
|
116 |
+
)
|
117 |
+
random.seed(12)
|
118 |
+
|
119 |
+
|
120 |
+
def load_train_eval_datasets():
|
121 |
+
"""
|
122 |
+
Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.
|
123 |
+
|
124 |
+
Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
|
125 |
+
"""
|
126 |
+
try:
|
127 |
+
train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
|
128 |
+
eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
|
129 |
+
return train_dataset, eval_dataset
|
130 |
+
except FileNotFoundError:
|
131 |
+
print("Loading gooaq dataset...")
|
132 |
+
gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
|
133 |
+
gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
|
134 |
+
gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
|
135 |
+
gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
|
136 |
+
print("Loaded gooaq dataset.")
|
137 |
+
|
138 |
+
print("Loading msmarco dataset...")
|
139 |
+
msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
|
140 |
+
msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
|
141 |
+
msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
|
142 |
+
msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
|
143 |
+
print("Loaded msmarco dataset.")
|
144 |
+
|
145 |
+
print("Loading squad dataset...")
|
146 |
+
squad_dataset = load_dataset("sentence-transformers/squad", split="train")
|
147 |
+
squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
|
148 |
+
squad_train_dataset: Dataset = squad_dataset_dict["train"]
|
149 |
+
squad_eval_dataset: Dataset = squad_dataset_dict["test"]
|
150 |
+
print("Loaded squad dataset.")
|
151 |
+
|
152 |
+
print("Loading s2orc dataset...")
|
153 |
+
s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
|
154 |
+
s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
|
155 |
+
s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
|
156 |
+
s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
|
157 |
+
print("Loaded s2orc dataset.")
|
158 |
+
|
159 |
+
print("Loading allnli dataset...")
|
160 |
+
allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
|
161 |
+
allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
|
162 |
+
print("Loaded allnli dataset.")
|
163 |
+
|
164 |
+
print("Loading paq dataset...")
|
165 |
+
paq_dataset = load_dataset("sentence-transformers/paq", split="train")
|
166 |
+
paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
|
167 |
+
paq_train_dataset: Dataset = paq_dataset_dict["train"]
|
168 |
+
paq_eval_dataset: Dataset = paq_dataset_dict["test"]
|
169 |
+
print("Loaded paq dataset.")
|
170 |
+
|
171 |
+
print("Loading trivia_qa dataset...")
|
172 |
+
trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
|
173 |
+
trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
|
174 |
+
trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
|
175 |
+
trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
|
176 |
+
print("Loaded trivia_qa dataset.")
|
177 |
+
|
178 |
+
print("Loading msmarco_10m dataset...")
|
179 |
+
msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
|
180 |
+
msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
|
181 |
+
msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
|
182 |
+
msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
|
183 |
+
print("Loaded msmarco_10m dataset.")
|
184 |
+
|
185 |
+
print("Loading swim_ir dataset...")
|
186 |
+
swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
|
187 |
+
swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
|
188 |
+
swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
|
189 |
+
swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
|
190 |
+
print("Loaded swim_ir dataset.")
|
191 |
+
|
192 |
+
# NOTE: 20 negatives
|
193 |
+
print("Loading pubmedqa dataset...")
|
194 |
+
pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
|
195 |
+
pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
|
196 |
+
pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
|
197 |
+
pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
|
198 |
+
print("Loaded pubmedqa dataset.")
|
199 |
+
|
200 |
+
# NOTE: A lot of overlap with anchor/positives
|
201 |
+
print("Loading miracl dataset...")
|
202 |
+
miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
|
203 |
+
miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
|
204 |
+
miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
|
205 |
+
miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
|
206 |
+
print("Loaded miracl dataset.")
|
207 |
+
|
208 |
+
# NOTE: A lot of overlap with anchor/positives
|
209 |
+
print("Loading mldr dataset...")
|
210 |
+
mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
|
211 |
+
mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
|
212 |
+
mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
|
213 |
+
mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
|
214 |
+
print("Loaded mldr dataset.")
|
215 |
+
|
216 |
+
# NOTE: A lot of overlap with anchor/positives
|
217 |
+
print("Loading mr_tydi dataset...")
|
218 |
+
mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
|
219 |
+
mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
|
220 |
+
mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
|
221 |
+
mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
|
222 |
+
print("Loaded mr_tydi dataset.")
|
223 |
+
|
224 |
+
train_dataset = DatasetDict({
|
225 |
+
"gooaq": gooaq_train_dataset,
|
226 |
+
"msmarco": msmarco_train_dataset,
|
227 |
+
"squad": squad_train_dataset,
|
228 |
+
"s2orc": s2orc_train_dataset,
|
229 |
+
"allnli": allnli_train_dataset,
|
230 |
+
"paq": paq_train_dataset,
|
231 |
+
"trivia_qa": trivia_qa_train_dataset,
|
232 |
+
"msmarco_10m": msmarco_10m_train_dataset,
|
233 |
+
"swim_ir": swim_ir_train_dataset,
|
234 |
+
"pubmedqa": pubmedqa_train_dataset,
|
235 |
+
"miracl": miracl_train_dataset,
|
236 |
+
"mldr": mldr_train_dataset,
|
237 |
+
"mr_tydi": mr_tydi_train_dataset,
|
238 |
+
})
|
239 |
+
eval_dataset = DatasetDict({
|
240 |
+
"gooaq": gooaq_eval_dataset,
|
241 |
+
"msmarco": msmarco_eval_dataset,
|
242 |
+
"squad": squad_eval_dataset,
|
243 |
+
"s2orc": s2orc_eval_dataset,
|
244 |
+
"allnli": allnli_eval_dataset,
|
245 |
+
"paq": paq_eval_dataset,
|
246 |
+
"trivia_qa": trivia_qa_eval_dataset,
|
247 |
+
"msmarco_10m": msmarco_10m_eval_dataset,
|
248 |
+
"swim_ir": swim_ir_eval_dataset,
|
249 |
+
"pubmedqa": pubmedqa_eval_dataset,
|
250 |
+
"miracl": miracl_eval_dataset,
|
251 |
+
"mldr": mldr_eval_dataset,
|
252 |
+
"mr_tydi": mr_tydi_eval_dataset,
|
253 |
+
})
|
254 |
+
|
255 |
+
train_dataset.save_to_disk("datasets/train_dataset")
|
256 |
+
eval_dataset.save_to_disk("datasets/eval_dataset")
|
257 |
+
|
258 |
+
# The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
|
259 |
+
quit()
|
260 |
+
|
261 |
+
|
262 |
+
|
263 |
+
def load_train_eval_datasets_reduced():
|
264 |
+
# 1. Load the full datasets from disk
|
265 |
+
train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
|
266 |
+
eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
|
267 |
+
|
268 |
+
factor = 10
|
269 |
+
for subset_name in train_dataset:
|
270 |
+
ds = train_dataset[subset_name]
|
271 |
+
ds = ds.shuffle(seed=42) # Shuffle for a random subset
|
272 |
+
new_len = len(ds) // factor # Keep 1/10th
|
273 |
+
ds = ds.select(range(new_len))
|
274 |
+
train_dataset[subset_name] = ds
|
275 |
+
|
276 |
+
for subset_name in eval_dataset:
|
277 |
+
ds = eval_dataset[subset_name]
|
278 |
+
ds = ds.shuffle(seed=42)
|
279 |
+
new_len = len(ds) // factor
|
280 |
+
ds = ds.select(range(new_len))
|
281 |
+
eval_dataset[subset_name] = ds
|
282 |
+
|
283 |
+
return train_dataset, eval_dataset
|
284 |
+
|
285 |
+
def main():
|
286 |
+
wandb.init(entity="minishlab", project="minishlab")
|
287 |
+
# 1. Load a model to finetune with 2. (Optional) model card data
|
288 |
+
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-512dim-60kvocab")
|
289 |
+
# 2. Initialize the SentenceTransformer model as usual
|
290 |
+
|
291 |
+
model_name = "potion-retrieval-512dim-60kvocab-v1"
|
292 |
+
model = SentenceTransformer(
|
293 |
+
modules=[static_embedding],
|
294 |
+
model_card_data=SentenceTransformerModelCardData(
|
295 |
+
language="en",
|
296 |
+
license="MIT",
|
297 |
+
model_name=model_name,
|
298 |
+
),
|
299 |
+
)
|
300 |
+
# 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
|
301 |
+
train_dataset, eval_dataset = load_train_eval_datasets_reduced()
|
302 |
+
print(train_dataset)
|
303 |
+
|
304 |
+
# 4. Define a loss function
|
305 |
+
loss = MultipleNegativesRankingLoss(model)
|
306 |
+
loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512])
|
307 |
+
|
308 |
+
|
309 |
+
# 5. (Optional) Specify training arguments
|
310 |
+
run_name = model_name
|
311 |
+
epochs = 3
|
312 |
+
lr = 0.05
|
313 |
+
args = SentenceTransformerTrainingArguments(
|
314 |
+
# Required parameter:
|
315 |
+
output_dir=f"models/{run_name}",
|
316 |
+
# Optional training parameters:
|
317 |
+
num_train_epochs=epochs,
|
318 |
+
per_device_train_batch_size=2048,
|
319 |
+
per_device_eval_batch_size=2048,
|
320 |
+
learning_rate=lr,
|
321 |
+
warmup_ratio=0.1,
|
322 |
+
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
|
323 |
+
bf16=True, # Set to True if you have a GPU that supports BF16
|
324 |
+
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
|
325 |
+
multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
|
326 |
+
# Optional tracking/debugging parameters:
|
327 |
+
eval_strategy="steps",
|
328 |
+
eval_steps=250,
|
329 |
+
save_strategy="steps",
|
330 |
+
save_steps=250,
|
331 |
+
save_total_limit=2,
|
332 |
+
logging_steps=250,
|
333 |
+
logging_first_step=True,
|
334 |
+
run_name=run_name, # Will be used in W&B if `wandb` is installed
|
335 |
+
report_to=["wandb"],
|
336 |
+
load_best_model_at_end=True,
|
337 |
+
metric_for_best_model="eval_NanoBEIR_mean_cosine_ndcg@10",
|
338 |
+
greater_is_better=True,
|
339 |
+
#
|
340 |
+
)
|
341 |
+
|
342 |
+
# 6. (Optional) Create an evaluator & evaluate the base model
|
343 |
+
evaluator = NanoBEIREvaluator()
|
344 |
+
|
345 |
+
evaluator(model)
|
346 |
+
|
347 |
+
# 7. Create a trainer & train
|
348 |
+
trainer = SentenceTransformerTrainer(
|
349 |
+
model=model,
|
350 |
+
args=args,
|
351 |
+
train_dataset=train_dataset,
|
352 |
+
eval_dataset=eval_dataset,
|
353 |
+
loss=loss,
|
354 |
+
evaluator=evaluator,
|
355 |
+
)
|
356 |
+
trainer.train()
|
357 |
+
|
358 |
+
# (Optional) Evaluate the trained model on the evaluator after training
|
359 |
+
evaluator(model)
|
360 |
+
|
361 |
+
# 8. Save the trained model
|
362 |
+
model.save_pretrained(f"models/{run_name}/final")
|
363 |
+
|
364 |
+
if __name__ == "__main__":
|
365 |
+
main()
|
366 |
```
|