File size: 699 Bytes
3c98ba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer

# Fine-tune a SetFit model on a small subset of the training split, evaluate
# it on the test split, report the metrics, and push the model to the Hub.

DATASET_ID = "ayakiri/wolo-app-categories-to-description"
BASE_MODEL_ID = "sentence-transformers/paraphrase-mpnet-base-v2"
HUB_MODEL_ID = "ayakiri/wolo-app-categories-setfit-model"

# Number of training examples kept after shuffling (16).
# NOTE(review): the `8 * 2` spelling presumably means "8 examples x 2 classes"
# -- confirm. A plain shuffle + select takes the first 16 shuffled rows without
# looking at labels, so it does not by itself guarantee any per-class balance.
NUM_TRAIN_EXAMPLES = 8 * 2

dataset = load_dataset(DATASET_ID)

# The fixed seed keeps the shuffle, and therefore the selected subset,
# reproducible between runs.
train_ds = dataset["train"].shuffle(seed=42).select(range(NUM_TRAIN_EXAMPLES))
test_ds = dataset["test"]

model = SetFitModel.from_pretrained(BASE_MODEL_ID)

# NOTE(review): no column mapping is passed, so the trainer presumably relies
# on the dataset already exposing the column names it expects -- verify.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20,
    num_epochs=1,
)

trainer.train()
metrics = trainer.evaluate()
# Surface the evaluation result so the run reports model quality before the
# model is published.
print(metrics)

trainer.push_to_hub(HUB_MODEL_ID)