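# mpnet-base-nli-matryoshka / training_nli_matryoshka.py
# Author: tomaarsen (HF staff); commit 82dda1b "Create training_nli_matryoshka.py" (verified)
# File size: 3.32 kB; Hub malware scan: no virus found
#
# Matryoshka test: fine-tune microsoft/mpnet-base on SNLI + MultiNLI
# (anchor, entailment, contradiction) triplets using MultipleNegativesRankingLoss
# wrapped in MatryoshkaLoss, then evaluate on the STS Benchmark test split.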
from collections import defaultdict
from typing import Dict
import datasets
from datasets import Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
losses,
evaluation,
TrainingArguments
)
from sentence_transformers.models import Transformer, Pooling
def to_triplets(dataset):
    """Convert an NLI split into (anchor, positive, negative) triplets.

    Rows are grouped by premise, and each label is mapped to its hypothesis.
    A later row with the same premise and label overwrites an earlier one.
    A premise yields exactly one triplet, and only when it has both an
    entailment hypothesis (label 0) and a contradiction hypothesis (label 2).
    All other labels are ignored.

    Args:
        dataset: Iterable of rows with "premise", "hypothesis" and "label"
            keys (here, an SNLI or MultiNLI split).

    Returns:
        A ``Dataset`` with "anchor", "positive" and "negative" columns. Rows
        are ordered by the first appearance of each premise.
    """
    by_premise = defaultdict(dict)
    for row in dataset:
        by_premise[row["premise"]][row["label"]] = row["hypothesis"]

    # Keep only premises that have both an entailment (0) and a contradiction (2).
    complete = [
        (premise, hyps)
        for premise, hyps in by_premise.items()
        if 0 in hyps and 2 in hyps
    ]
    return Dataset.from_dict({
        "anchor": [premise for premise, _ in complete],
        "positive": [hyps[0] for _, hyps in complete],  # entailment
        "negative": [hyps[2] for _, hyps in complete],  # contradiction
    })
# AllNLI triplets: convert every SNLI / MultiNLI split with to_triplets.
# Each name is rebound to its converted form, so the raw rows are not kept.
snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict(
    {split: to_triplets(snli_ds[split]) for split in ("train", "validation", "test")}
)

multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict(
    {split: to_triplets(multi_nli_ds[split]) for split in ("train", "validation_matched")}
)

# Merge both corpora for train/validation; only SNLI supplies a test split here.
all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets(
        [snli_ds["train"], multi_nli_ds["train"]]
    ),
    "validation": datasets.concatenate_datasets(
        [snli_ds["validation"], multi_nli_ds["validation_matched"]]
    ),
    "test": snli_ds["test"],
})

# STS Benchmark sentence pairs used for evaluation. Gold scores are divided by 5
# further down (presumably mapping a 0-5 scale onto [0, 1]).
stsb_dev, stsb_test = (
    datasets.load_dataset("mteb/stsbenchmark-sts", split=split)
    for split in ("validation", "test")
)
training_args = TrainingArguments(
    # Output location and reproducibility.
    output_dir="checkpoints",
    seed=42,
    # Optimisation schedule: a single epoch with 10% linear warmup.
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # NOTE(review): bf16 needs hardware with bfloat16 support; confirm for the target GPU.
    bf16=True,
    # Logging / evaluation / checkpoint cadence (in optimiser steps).
    logging_steps=10,
    # NOTE(review): newer transformers releases spell this `eval_strategy`;
    # confirm against the pinned version.
    evaluation_strategy="steps",
    eval_steps=300,
    save_steps=1000,
    save_total_limit=2,
    # NOTE(review): load_best_model_at_end is not set, so this only drives
    # best-checkpoint tracking. Confirm the key matches what the "sts-dev"
    # evaluator reports (its name may be prefixed onto metric names).
    metric_for_best_model="spearman_cosine",
    greater_is_better=True,
)
# Mean-pooled MPNet encoder; inputs are truncated to 384 tokens.
transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])
tokenizer = model.tokenizer  # not referenced elsewhere in this script

# MultipleNegativesRankingLoss wrapped in MatryoshkaLoss, which (per the
# sentence-transformers docs) also applies it to truncated embedding prefixes.
# The largest dimension comes from the model rather than a hard-coded 768, so
# it always agrees with the pooling output. Smaller sizes that would not be a
# real truncation are dropped. For mpnet-base this is [768, 512, 256, 128, 64].
full_dim = model.get_sentence_embedding_dimension()
matryoshka_dims = [full_dim] + [dim for dim in (512, 256, 128, 64) if dim < full_dim]
loss = losses.MultipleNegativesRankingLoss(model)
loss = losses.MatryoshkaLoss(model, loss, matryoshka_dims)
# Cosine-similarity evaluator on the STS-B validation pairs. Gold scores are
# divided by 5 (assumes the dataset's 0-5 scale, giving [0, 1]).
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sentences1=stsb_dev["sentence1"],
    sentences2=stsb_dev["sentence2"],
    scores=[score / 5 for score in stsb_dev["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)
# Train on the AllNLI triplet train split. Periodic evaluation relies on
# dev_evaluator alone. all_nli_ds["validation"] could be passed as
# eval_dataset, but it is not used here.
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()
# Final check on the held-out STS-B test split, using the model as it stands
# after trainer.train(). Gold scores are divided by 5, as for the dev split.
test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sentences1=stsb_test["sentence1"],
    sentences2=stsb_test["sentence2"],
    scores=[score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)