julien-c HF staff commited on
Commit
2a130d4
·
1 Parent(s): cd4e66b

Migrate model card from transformers-repo

Browse files

Read announcement at https://discuss.huggingface.co/t/announcement-all-model-cards-will-be-migrated-to-hf-co-model-repos/2755
Original file history: https://github.com/huggingface/transformers/commits/master/model_cards/ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli/README.md

Files changed (1) hide show
  1. README.md +82 -0
README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - snli
4
+ - anli
5
+ - multi_nli
6
+ - multi_nli_mismatch
7
+ - fever
8
+ license: mit
9
+ ---
10
+ This is a strong pre-trained RoBERTa-Large NLI model.
11
+
12
+ The training data is a combination of well-known NLI datasets: [`SNLI`](https://nlp.stanford.edu/projects/snli/), [`MNLI`](https://cims.nyu.edu/~sbowman/multinli/), [`FEVER-NLI`](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md), [`ANLI (R1, R2, R3)`](https://github.com/facebookresearch/anli).
13
+ Other pre-trained NLI models including `RoBERTa`, `ALBert`, `BART`, `ELECTRA`, `XLNet` are also available.
14
+
15
+ Trained by [Yixin Nie](https://easonnie.github.io), [original source](https://github.com/facebookresearch/anli).
16
+
17
+ Try the code snippet below.
18
+ ```
19
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
20
+ import torch
21
+
22
+ if __name__ == '__main__':
23
+ max_length = 256
24
+
25
+ premise = "Two women are embracing while holding to go packages."
26
+ hypothesis = "The men are fighting outside a deli."
27
+
28
+ hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
29
+ # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
30
+ # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
31
+ # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
32
+ # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"
33
+
34
+ tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
35
+ model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
36
+
37
+ tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis,
38
+ max_length=max_length,
39
+ return_token_type_ids=True, truncation=True)
40
+
41
+ input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
42
+ # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
43
+ token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0)
44
+ attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0)
45
+
46
+ outputs = model(input_ids,
47
+ attention_mask=attention_mask,
48
+ token_type_ids=token_type_ids,
49
+ labels=None)
50
+ # Note:
51
+ # "id2label": {
52
+ # "0": "entailment",
53
+ # "1": "neutral",
54
+ # "2": "contradiction"
55
+ # },
56
+
57
+ predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist() # batch_size only one
58
+
59
+ print("Premise:", premise)
60
+ print("Hypothesis:", hypothesis)
61
+ print("Entailment:", predicted_probability[0])
62
+ print("Neutral:", predicted_probability[1])
63
+ print("Contradiction:", predicted_probability[2])
64
+ ```
65
+
66
+ More in [here](https://github.com/facebookresearch/anli/blob/master/src/hg_api/interactive_eval.py).
67
+
68
+ Citation:
69
+ ```
70
+ @inproceedings{nie-etal-2020-adversarial,
71
+ title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding",
72
+ author = "Nie, Yixin and
73
+ Williams, Adina and
74
+ Dinan, Emily and
75
+ Bansal, Mohit and
76
+ Weston, Jason and
77
+ Kiela, Douwe",
78
+ booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
79
+ year = "2020",
80
+ publisher = "Association for Computational Linguistics",
81
+ }
82
+ ```