Daniele Licari committed
Commit e5abd54 · 1 Parent(s): 0039f4f

Update README.md

Files changed (1): README.md +74 -1
README.md CHANGED
@@ -41,4 +41,77 @@ fill_mask("Il [MASK] ha chiesto revocarsi l'obbligo di pagamento")
 # {'sequence': "Il resistente ha chiesto revocarsi l'obbligo di pagamento", 'score': 0.039877112954854965},
 # {'sequence': "Il lavoratore ha chiesto revocarsi l'obbligo di pagamento", 'score': 0.028993653133511543},
 # {'sequence': "Il Ministero ha chiesto revocarsi l'obbligo di pagamento", 'score': 0.025297977030277252}]
- ```
+ ```
+ Here is how to use it for sentence similarity:
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ from sklearn.metrics.pairwise import cosine_similarity
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from textwrap import wrap
+
+ # Mean pooling - take the attention mask into account for correct averaging
+ def mean_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+     sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+     return sum_embeddings / sum_mask
+
+
+ # get sentence embeddings
+ def sentence_embeddings(sentences, model_name, max_length=512):
+     # load tokenizer and model
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name)
+
+     # tokenize sentences
+     encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
+
+     # compute token embeddings
+     with torch.no_grad():
+         model_output = model(**encoded_input)
+
+     # perform pooling (mean pooling in this case)
+     return mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
+
+
+ def plot_similarity(sentences, model_name):
+     # get sentence embeddings produced by the model
+     embeddings = sentence_embeddings(sentences, model_name)
+     # compute pairwise similarity scores using cosine similarity
+     corr = cosine_similarity(embeddings, embeddings)
+
+     # plot the similarity heatmap
+     sns.set(font_scale=1.2)
+     # wrap long sentences for the axis labels
+     labels = ['\n'.join(wrap(l, 40)) for l in sentences]
+     g = sns.heatmap(
+         corr,
+         xticklabels=labels,
+         yticklabels=labels,
+         vmax=1,
+         cmap="YlOrRd")
+     g.set_xticklabels(labels, rotation=90)
+     model_short_name = model_name.split('/')[-1]
+     g.set_title(f"Semantic Textual Similarity ({model_short_name})")
+     plt.show()
+
+
+ sent = [
+     # 1. "The court shall pronounce the judgment for the dissolution or termination of the civil effects of marriage."
+     "Il tribunale pronuncia la sentenza per lo scioglimento o la cessazione degli effetti civili del matrimonio",
+
+     # 2. "having regard to Articles 1, 2, 3 No. 2(b) and 4 Paragraph 13 of Law No. 898 of December 1, 1970, as later amended."
+     # NOTE: Law Dec. 1, 1970 No. 898 is on divorce
+     "visti gli articoli 1, 2, 3 n. 2 lett. b) e 4 comma 13 della legge 1 dicembre 1970 n. 898 e successive modifiche",
+
+     # 3. "The plaintiff has lost the case."
+     "Il ricorrente ha perso la causa"
+ ]
+
+
+ model_name = "dlicari/Italian-Legal-BERT"
+ plot_similarity(sent, model_name)
+ model_name = 'dbmdz/bert-base-italian-xxl-cased'
+ plot_similarity(sent, model_name)
+ ```
+ <img src="https://huggingface.co/dlicari/Italian-Legal-BERT/resolve/main/semantic_text_similarity.jpg" width="500"/>
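If you only need the raw similarity scores rather than the heatmap, here is a minimal sketch; it assumes the `sentence_embeddings` helper from the snippet above is already defined, and the sentence pair is taken from the example list:

```python
from sklearn.metrics.pairwise import cosine_similarity

# assumes sentence_embeddings() from the README snippet above is in scope
pair = [
    "Il tribunale pronuncia la sentenza per lo scioglimento o la cessazione degli effetti civili del matrimonio",
    "visti gli articoli 1, 2, 3 n. 2 lett. b) e 4 comma 13 della legge 1 dicembre 1970 n. 898 e successive modifiche",
]
embeddings = sentence_embeddings(pair, "dlicari/Italian-Legal-BERT")

# cosine similarity between the two sentence vectors; closer to 1 = more similar
score = cosine_similarity(embeddings[0:1], embeddings[1:2])[0, 0]
print(f"cosine similarity: {score:.4f}")
```

Since Law No. 898/1970 is the divorce statute, these two sentences should score noticeably higher against each other than against the unrelated third sentence, which is what the heatmap above visualizes.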