# Matryoshka test: fine-tune microsoft/mpnet-base on AllNLI (SNLI + MultiNLI)
# triplets with a MatryoshkaLoss-wrapped MultipleNegativesRankingLoss, then
# report STS-benchmark similarity metrics on the test split.
from collections import defaultdict
from typing import Dict

import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    TrainingArguments,
)
from sentence_transformers.models import Transformer, Pooling

# NLI label ids used to build triplets (per the original inline comments).
ENTAILMENT = 0
CONTRADICTION = 2

# Embedding sizes MatryoshkaLoss trains for, largest first.
MATRYOSHKA_DIMS = [768, 512, 256, 128, 64]


def to_triplets(dataset):
    """Convert an NLI split into (anchor, positive, negative) triplets.

    Hypotheses are grouped by premise and label. A premise yields a triplet
    only if it has both an entailment (label 0) and a contradiction (label 2)
    hypothesis. Other labels are collected but never used. If a premise has
    several hypotheses with the same label, the last one seen wins.

    Args:
        dataset: Iterable of rows with "premise", "hypothesis" and "label" keys.

    Returns:
        A ``datasets.Dataset`` with "anchor", "positive" and "negative" columns.
    """
    premises: Dict[str, Dict[int, str]] = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]

    anchors, positives, negatives = [], [], []
    for premise, by_label in premises.items():
        if ENTAILMENT in by_label and CONTRADICTION in by_label:
            anchors.append(premise)
            positives.append(by_label[ENTAILMENT])
            negatives.append(by_label[CONTRADICTION])

    return Dataset.from_dict({
        "anchor": anchors,
        "positive": positives,
        "negative": negatives,
    })


snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    split: to_triplets(snli_ds[split])
    for split in ("train", "validation", "test")
})

multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    split: to_triplets(multi_nli_ds[split])
    for split in ("train", "validation_matched")
})

all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets(
        [snli_ds["train"], multi_nli_ds["train"]]
    ),
    "validation": datasets.concatenate_datasets(
        [snli_ds["validation"], multi_nli_ds["validation_matched"]]
    ),
    "test": snli_ds["test"],
})

stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

training_args = TrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=42,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=10,
    # NOTE(review): newer `transformers` releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=300,
    save_steps=1000,
    save_total_limit=2,
    # Must name a metric the evaluator actually reports. The evaluator below is
    # named "sts-dev"; sentence-transformers v3 prefixes its metric keys with
    # that name, so an unprefixed "spearman_cosine" is not found when an eval
    # step coincides with a save step. Confirm against the installed version.
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
)

transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])

# In-batch-negatives ranking loss, applied at each truncated embedding size.
loss = losses.MultipleNegativesRankingLoss(model)
loss = losses.MatryoshkaLoss(model, loss, MATRYOSHKA_DIMS)

# STS-B gold scores are divided by 5 to rescale them (the data is assumed to
# be on a 0-5 scale; TODO confirm).
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

# Evaluation during training uses only the STS-B dev evaluator. The NLI
# validation triplets in all_nli_ds["validation"] are built but deliberately
# not passed as eval_dataset.
trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    loss=loss,
)
trainer.train()

test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)