# Author: Agata Dobrzyniewicz
# Commit 3c98ba6: "model added"
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
# Hub identifiers for the training data, the base sentence-transformer, and
# the repository the fine-tuned model is uploaded to.
DATASET_ID = "ayakiri/wolo-app-categories-to-description"
BASE_MODEL_ID = "sentence-transformers/paraphrase-mpnet-base-v2"
OUTPUT_MODEL_ID = "ayakiri/wolo-app-categories-setfit-model"

# Size of the few-shot training subset and the seed used to shuffle before
# slicing it off the train split.
# NOTE(review): 8 * 2 reads like "8 examples x 2 classes", but taking the first
# 16 rows of a shuffled split does not guarantee per-class balance
# (setfit's `sample_dataset` samples per label) -- confirm the intent.
NUM_TRAIN_SAMPLES = 8 * 2
SHUFFLE_SEED = 42


def main():
    """Fine-tune a SetFit classifier, evaluate it, and push it to the Hub.

    Loads ``DATASET_ID``, trains on ``NUM_TRAIN_SAMPLES`` rows of the shuffled
    train split, evaluates on the full test split, prints the evaluation
    result, uploads the trained model to ``OUTPUT_MODEL_ID``, and returns the
    value produced by ``trainer.evaluate()``.
    """
    dataset = load_dataset(DATASET_ID)
    train_ds = (
        dataset["train"]
        .shuffle(seed=SHUFFLE_SEED)
        .select(range(NUM_TRAIN_SAMPLES))
    )
    test_ds = dataset["test"]

    model = SetFitModel.from_pretrained(BASE_MODEL_ID)
    # NOTE(review): SetFitTrainer is the pre-1.0 SetFit API; newer releases
    # deprecate it in favour of Trainer + TrainingArguments. Migrating implies
    # a setfit version change -- confirm the pinned version before switching.
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        loss_class=CosineSimilarityLoss,
        batch_size=16,
        num_iterations=20,
        num_epochs=1,
    )
    trainer.train()

    # Report the evaluation result rather than discarding it, so a bad run is
    # visible before the model is published.
    metrics = trainer.evaluate()
    print(f"Evaluation metrics: {metrics}")

    trainer.push_to_hub(OUTPUT_MODEL_ID)
    return metrics


if __name__ == "__main__":
    # Guarded so that merely importing this module does not download data,
    # train a model, and upload it to the Hub.
    main()