Agata Dobrzyniewicz commited on
Commit
3c98ba6
1 Parent(s): dcea567

model added

Browse files
model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from sentence_transformers.losses import CosineSimilarityLoss
3
+
4
+ from setfit import SetFitModel, SetFitTrainer
5
+
6
+ dataset = load_dataset("ayakiri/wolo-app-categories-to-description")
7
+
8
+ train_ds = dataset["train"].shuffle(seed=42).select(range(8 * 2))
9
+ test_ds = dataset["test"]
10
+
11
+ model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
12
+
13
+ trainer = SetFitTrainer(
14
+ model=model,
15
+ train_dataset=train_ds,
16
+ eval_dataset=test_ds,
17
+ loss_class=CosineSimilarityLoss,
18
+ batch_size=16,
19
+ num_iterations=20,
20
+ num_epochs=1
21
+ )
22
+
23
+ trainer.train()
24
+ metrics = trainer.evaluate()
25
+
26
+ trainer.push_to_hub("ayakiri/wolo-app-categories-setfit-model")
runs/Feb01_12-42-17_DESKTOP-S8RJVAJ/events.out.tfevents.1706787738.DESKTOP-S8RJVAJ ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb9aab020a6e8863b120be040b3823a6494c90cf0dcb3e5985e6b8ec32c94b30
3
+ size 2752
runs/Feb01_12-47-52_DESKTOP-S8RJVAJ/events.out.tfevents.1706788072.DESKTOP-S8RJVAJ ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5afdd071d0a65cedd62e21d513025a3f7a0f1525f388da6afb9161d99a6f5070
3
+ size 2752