Commit
·
189fb58
1
Parent(s):
589b81c
Upload model
Browse files- config.json +1 -1
- model.safetensors +3 -0
- modeling_backpack_gpt2_nli.py +9 -1
config.json
CHANGED
@@ -42,7 +42,7 @@
|
|
42 |
"summary_type": "cls_index",
|
43 |
"summary_use_proj": true,
|
44 |
"torch_dtype": "float32",
|
45 |
-
"transformers_version": "4.
|
46 |
"use_cache": true,
|
47 |
"vocab_size": 50264
|
48 |
}
|
|
|
42 |
"summary_type": "cls_index",
|
43 |
"summary_use_proj": true,
|
44 |
"torch_dtype": "float32",
|
45 |
+
"transformers_version": "4.35.2",
|
46 |
"use_cache": true,
|
47 |
"vocab_size": 50264
|
48 |
}
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a540f442266937e77756ce8730b0f0492dc80e51c21522a7e86984e3c6d21bfd
|
3 |
+
size 682724236
|
modeling_backpack_gpt2_nli.py
CHANGED
@@ -52,4 +52,12 @@ class BackpackGPT2NLIModel(GPT2PreTrainedModel):
|
|
52 |
loss = self.loss_func(flat_logits, flat_labels)
|
53 |
return {'logits': logits, 'loss': loss}
|
54 |
else:
|
55 |
-
return {'logits': logits}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
loss = self.loss_func(flat_logits, flat_labels)
|
53 |
return {'logits': logits, 'loss': loss}
|
54 |
else:
|
55 |
+
return {'logits': logits}
|
56 |
+
|
57 |
+
|
58 |
+
def predict(self, input_ids=None, attention_mask=None):
|
59 |
+
logits = self.forward(input_ids, attention_mask, labels=None)
|
60 |
+
p = torch.argmax(p, axis=1)
|
61 |
+
labels = [self.config.id2label[index] for index in p]
|
62 |
+
return labels
|
63 |
+
|