Basic implementation
- app.py +79 -4
- get_max_min_values.py +21 -0
- recognizers/__init__.py +2 -0
- recognizers/base.py +36 -0
- recognizers/diff_align.py +48 -0
- recognizers/diff_del.py +217 -0
- recognizers/feature_based.py +136 -0
- recognizers/utils.py +129 -0
- result_template.html +47 -0
- tests.py +26 -0
app.py
CHANGED
@@ -1,9 +1,84 @@
from pathlib import Path

import gradio as gr
from jinja2 import Environment
from tokenizers.pre_tokenizers import Whitespace
from transformers import pipeline

from recognizers import DiffAlign, DiffDel


def load_pipeline(model_name_or_path: str = "ZurichNLP/unsup-simcse-xlm-roberta-base"):
    return pipeline("feature-extraction", model=model_name_or_path)


def generate_diff(text_a: str, text_b: str, method: str):
    global my_pipeline
    if my_pipeline is None:
        my_pipeline = load_pipeline()

    if method == "DiffAlign":
        diff = DiffAlign(pipeline=my_pipeline)
        min_value = 0.3758048415184021 - 0.37
        max_value = 1.045647144317627 - 0.1
    elif method == "DiffDel":
        diff = DiffDel(pipeline=my_pipeline)
        min_value = 0.4864141941070556
        max_value = 0.5012983083724976 + 0.025
    else:
        raise ValueError(f"Unknown method: {method}")

    encoding_a = tokenizer.pre_tokenize_str(text_a)
    encoding_b = tokenizer.pre_tokenize_str(text_b)

    result = diff.predict(
        a=" ".join([token[0] for token in encoding_a]),
        b=" ".join([token[0] for token in encoding_b]),
    )

    result.add_whitespace(encoding_a, encoding_b)

    # Normalize labels based on empirical min/max values
    result.labels_a = tuple([(label - min_value) / (max_value - min_value) for label in result.labels_a])
    result.labels_b = tuple([(label - min_value) / (max_value - min_value) for label in result.labels_b])

    # Round labels to the range 0, 1, ..., 10
    result.labels_a = tuple([round(min(10, label * 10)) for label in result.labels_a])
    result.labels_b = tuple([round(min(10, label * 10)) for label in result.labels_b])

    template_path = Path(__file__).parent / "result_template.html"
    template = Environment().from_string(template_path.read_text())
    html_dir = Path(__file__).parent / "html_out"
    html_dir.mkdir(exist_ok=True)

    html_a = template.render(token_labels=result.token_labels_a)
    html_b = template.render(token_labels=result.token_labels_b)
    return str(html_a), str(html_b)


my_pipeline = None
tokenizer = Whitespace()

with gr.Blocks() as demo:
    with gr.Row():
        text_a = gr.Textbox(label="Text A", value="Chinese shares close higher Friday.", lines=2)
        text_b = gr.Textbox(label="Text B", value="Les actions chinoises clôturent en baisse mercredi.", lines=2)
    with gr.Row():
        method = gr.Dropdown(choices=["DiffAlign", "DiffDel"], label="Comparison Method", value="DiffAlign")
    with gr.Row():
        with gr.Column(variant="panel"):
            output_a = gr.HTML(label="Result for text A", show_label=True)
        with gr.Column(variant="panel"):
            output_b = gr.HTML(label="Result for text B", show_label=True)
    with gr.Row():
        submit_btn = gr.Button(label="Generate Diff")
    submit_btn.click(
        fn=generate_diff,
        inputs=[text_a, text_b, method],
        outputs=[output_a, output_b],
    )


if my_pipeline is None:
    my_pipeline = load_pipeline()
demo.launch()
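Note on the hard-coded constants in generate_diff: the raw per-word labels are rescaled with empirical min/max values (apparently produced by get_max_min_values.py below and then nudged by hand) and bucketed into the 0-10 highlight classes that result_template.html styles. A minimal sketch of that mapping for the DiffAlign constants; the raw label 0.42 is hypothetical, purely for illustration:

# Hypothetical raw label from DiffAlign, mapped to a highlight bucket as in generate_diff
min_value = 0.3758048415184021 - 0.37
max_value = 1.045647144317627 - 0.1
raw_label = 0.42
normalized = (raw_label - min_value) / (max_value - min_value)  # roughly in [0, 1]
bucket = round(min(10, normalized * 10))  # integer class rendered as .highlight-<bucket>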
get_max_min_values.py
ADDED
@@ -0,0 +1,21 @@
"""
Get similarities of similar and dissimilar pairs.
The values are used for normalizing the colors in the visualization.
"""
from app import load_pipeline
from recognizers import DiffAlign, DiffDel

similar_pair = ("Hello!", "Hi!")
dissimilar_pair = ("Hello!", "asdf")

pipeline = load_pipeline()
diff_align = DiffAlign(pipeline=pipeline)
diff_del = DiffDel(pipeline=pipeline)

print("Similar pair:")
print(diff_align.predict(*similar_pair).min)
print(diff_del.predict(*similar_pair).min)

print("Dissimilar pair:")
print(diff_align.predict(*dissimilar_pair).max)
print(diff_del.predict(*dissimilar_pair).max)
recognizers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from recognizers.diff_align import DiffAlign
from recognizers.diff_del import DiffDel
recognizers/base.py
ADDED
@@ -0,0 +1,36 @@
"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""

from typing import List

from tqdm import tqdm

from recognizers.utils import DifferenceSample


class DifferenceRecognizer:

    def __str__(self):
        raise NotImplementedError

    def predict(self,
                a: str,
                b: str,
                **kwargs,
                ) -> DifferenceSample:
        raise NotImplementedError

    def predict_all(self,
                    a: List[str],
                    b: List[str],
                    **kwargs,
                    ) -> List[DifferenceSample]:
        assert len(a) == len(b)
        predictions = []
        for i in tqdm(list(range(len(a)))):
            prediction = self.predict(a[i], b[i], **kwargs)
            predictions.append(prediction)
        return predictions
recognizers/diff_align.py
ADDED
@@ -0,0 +1,48 @@
"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""

from typing import List

import torch

from recognizers.feature_based import FeatureExtractionRecognizer
from recognizers.utils import DifferenceSample, cos_sim


class DiffAlign(FeatureExtractionRecognizer):

    def __str__(self):
        return f"DiffAlign(model={self.pipeline.model.name_or_path}, layer={self.layer})"

    @torch.no_grad()
    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        outputs_a = self.encode_batch(a, **kwargs)
        outputs_b = self.encode_batch(b, **kwargs)
        subwords_by_words_a = [self._get_subwords_by_word(sentence) for sentence in a]
        subwords_by_words_b = [self._get_subwords_by_word(sentence) for sentence in b]
        subword_labels_a = []
        subword_labels_b = []
        for i in range(len(a)):
            cosine_similarities = cos_sim(outputs_a[i], outputs_b[i])
            max_similarities_a = torch.max(cosine_similarities, dim=1).values
            max_similarities_b = torch.max(cosine_similarities, dim=0).values
            subword_labels_a.append(1 - max_similarities_a)
            subword_labels_b.append(1 - max_similarities_b)
        samples = []
        for i in range(len(a)):
            labels_a = self._subword_labels_to_word_labels(subword_labels_a[i], subwords_by_words_a[i])
            labels_b = self._subword_labels_to_word_labels(subword_labels_b[i], subwords_by_words_b[i])
            samples.append(DifferenceSample(
                tokens_a=tuple(a[i].split()),
                tokens_b=tuple(b[i].split()),
                labels_a=tuple(labels_a),
                labels_b=tuple(labels_b),
            ))
        return samples
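A minimal usage sketch for DiffAlign outside the Gradio app, mirroring what app.py and get_max_min_values.py do (it assumes the ZurichNLP/unsup-simcse-xlm-roberta-base checkpoint can be downloaded):

from transformers import pipeline

from recognizers import DiffAlign

# Same feature-extraction pipeline that load_pipeline() builds in app.py
my_pipeline = pipeline("feature-extraction", model="ZurichNLP/unsup-simcse-xlm-roberta-base")
recognizer = DiffAlign(pipeline=my_pipeline)

# Inputs are pre-tokenized, whitespace-separated strings; a higher label means the word
# is less well matched by any subword of the other sentence.
sample = recognizer.predict("Chinese shares close higher Friday .",
                            "Les actions chinoises clôturent en baisse mercredi .")
for token, label in sample.token_labels_a:
    print(f"{token}\t{label:.3f}")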
recognizers/diff_del.py
ADDED
@@ -0,0 +1,217 @@
"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""

import itertools
from copy import deepcopy
from typing import Union, List

import torch
from transformers import Pipeline, FeatureExtractionPipeline

from recognizers.feature_based import FeatureExtractionRecognizer, Ngram
from recognizers.utils import DifferenceSample, pairwise_cos_sim, cos_sim


class DiffDel(FeatureExtractionRecognizer):

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 min_n: int = 1,
                 max_n: int = 1,  # Inclusive
                 ):
        super().__init__(model_name_or_path, pipeline, layer, batch_size)
        assert min_n <= max_n
        self.min_n = min_n
        self.max_n = max_n

    def __str__(self):
        return f"DiffDel(model={self.pipeline.model.name_or_path}, layer={self.layer}, " \
               f"min_n={self.min_n}, max_n={self.max_n})"

    @torch.no_grad()
    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        outputs_a = self.encode_batch(a, **kwargs)
        outputs_b = self.encode_batch(b, **kwargs)
        subwords_by_words_a = [self._get_subwords_by_word(sentence) for sentence in a]
        subwords_by_words_b = [self._get_subwords_by_word(sentence) for sentence in b]
        ngrams_a = [self._get_ngrams(subwords_by_word) for subwords_by_word in subwords_by_words_a]
        ngrams_b = [self._get_ngrams(subwords_by_word) for subwords_by_word in subwords_by_words_b]
        sentence_embeddings_a = self._get_full_sentence_embeddings(outputs_a, [list(itertools.chain.from_iterable(subwords)) for subwords in subwords_by_words_a])
        sentence_embeddings_b = self._get_full_sentence_embeddings(outputs_b, [list(itertools.chain.from_iterable(subwords)) for subwords in subwords_by_words_b])
        full_similarities = pairwise_cos_sim(sentence_embeddings_a, sentence_embeddings_b)

        all_labels_a = []
        all_labels_b = []
        for i in range(len(a)):
            partial_embeddings_a = self._get_partial_sentence_embeddings_for_sample(outputs_a[i], ngrams_a[i])
            partial_embeddings_b = self._get_partial_sentence_embeddings_for_sample(outputs_b[i], ngrams_b[i])
            partial_similarities_a = cos_sim(partial_embeddings_a, sentence_embeddings_b[i].unsqueeze(0)).squeeze(1)
            partial_similarities_b = cos_sim(partial_embeddings_b, sentence_embeddings_a[i].unsqueeze(0)).squeeze(1)
            ngram_labels_a = (partial_similarities_a - full_similarities[i] + 1) / 2
            ngram_labels_b = (partial_similarities_b - full_similarities[i] + 1) / 2
            subword_labels_a = self._distribute_ngram_labels_to_subwords(ngram_labels_a, ngrams_a[i])
            subword_labels_b = self._distribute_ngram_labels_to_subwords(ngram_labels_b, ngrams_b[i])
            labels_a = self._subword_labels_to_word_labels(subword_labels_a, subwords_by_words_a[i])
            labels_b = self._subword_labels_to_word_labels(subword_labels_b, subwords_by_words_b[i])
            all_labels_a.append(labels_a)
            all_labels_b.append(labels_b)

        samples = []
        for i in range(len(a)):
            samples.append(DifferenceSample(
                tokens_a=tuple(a[i].split()),
                tokens_b=tuple(b[i].split()),
                labels_a=tuple(all_labels_a[i]),
                labels_b=tuple(all_labels_b[i]),
            ))
        return samples

    def _get_full_sentence_embeddings(self, token_embeddings: torch.Tensor, include_subwords: List[List[int]]) -> torch.Tensor:
        """
        :param token_embeddings: batch x seq_len x dim
        :param include_subwords: batch x num_subwords
        :return: A tensor of shape batch x dim
        """
        pool_mask = torch.zeros(token_embeddings.shape[0], token_embeddings.shape[1], device=token_embeddings.device)
        for i, subword_indices in enumerate(include_subwords):
            pool_mask[i, subword_indices] = 1
        sentence_embeddings = self._pool(token_embeddings, pool_mask)
        return sentence_embeddings

    def _get_partial_sentence_embeddings_for_sample(self, token_embeddings: torch.Tensor, ngrams: List[Ngram]) -> torch.Tensor:
        """
        :param token_embeddings: seq_len x dim
        :param ngrams: num_ngrams x n
        :return: A tensor of shape num_ngrams x dim
        """
        pool_mask = torch.zeros(len(ngrams), token_embeddings.shape[0], device=token_embeddings.device)
        pool_mask[:, list(itertools.chain.from_iterable(ngrams))] = 1
        for i, subword_indices in enumerate(ngrams):
            pool_mask[i, subword_indices] = 0
        partial_embeddings = self._pool(token_embeddings.unsqueeze(0).repeat(len(ngrams), 1, 1), pool_mask)
        return partial_embeddings

    def _distribute_ngram_labels_to_subwords(self, ngram_labels: torch.Tensor, ngrams: List[Ngram]) -> torch.Tensor:
        """
        :param ngram_labels: num_ngrams
        :param ngrams: num_ngrams x n
        :return: num_subwords
        """
        max_subword_idx = max(itertools.chain.from_iterable(ngrams))
        subword_contributions = torch.zeros(max_subword_idx + 1, device=ngram_labels.device)
        contribution_count = torch.zeros(max_subword_idx + 1, device=ngram_labels.device)
        for i, ngram in enumerate(ngrams):
            subword_contributions[ngram] += ngram_labels[i] / len(ngram)
            contribution_count[ngram] += 1 / len(ngram)
        subword_contributions /= contribution_count
        return subword_contributions


class DiffDelWithReencode(FeatureExtractionRecognizer):
    """
    Version of DiffDel that encodes the partial sentences from scratch (instead of encoding the full
    sentence once and then excluding hidden states from the mean).
    """

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 ):
        super().__init__(model_name_or_path, pipeline, layer, batch_size)

    def __str__(self):
        return f"DiffDelWithReencode(model={self.pipeline.model.name_or_path}, layer={self.layer})"

    @torch.no_grad()
    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        a_words = [sentence.split() for sentence in a]
        b_words = [sentence.split() for sentence in b]
        a_words_partial = []
        b_words_partial = []
        for words in a_words:
            for i, word in enumerate(words):
                partial = deepcopy(words)
                del partial[i]
                a_words_partial.append(partial)
        for words in b_words:
            for i, word in enumerate(words):
                partial = deepcopy(words)
                del partial[i]
                b_words_partial.append(partial)
        a_partial = [" ".join([word for word in words if word]) for words in a_words_partial]
        b_partial = [" ".join([word for word in words if word]) for words in b_words_partial]
        a_num_partial = [len(words) for words in a_words]
        b_num_partial = [len(words) for words in b_words]
        a_embedding_full = self._encode_and_pool(a, **kwargs)
        b_embedding_full = self._encode_and_pool(b, **kwargs)
        a_embeddings_partial = []
        b_embeddings_partial = []
        for i in range(0, len(a_partial), self.batch_size):
            a_embeddings_partial_batch = self._encode_and_pool(a_partial[i:i + self.batch_size], **kwargs)
            a_embeddings_partial.append(a_embeddings_partial_batch)
        for i in range(0, len(b_partial), self.batch_size):
            b_embeddings_partial_batch = self._encode_and_pool(b_partial[i:i + self.batch_size], **kwargs)
            b_embeddings_partial.append(b_embeddings_partial_batch)
        a_embeddings_partial = torch.cat(a_embeddings_partial, dim=0)
        b_embeddings_partial = torch.cat(b_embeddings_partial, dim=0)

        labels_a = []
        labels_b = []
        similarity_full = pairwise_cos_sim(a_embedding_full, b_embedding_full)
        for i in range(len(a)):
            a_embeddings_partial_i = a_embeddings_partial[sum(a_num_partial[:i]):sum(a_num_partial[:i + 1])]
            similarities_partial = pairwise_cos_sim(a_embeddings_partial_i, b_embedding_full[i].unsqueeze(0)).squeeze(0)
            labels = (similarities_partial - similarity_full[i] + 1) / 2
            labels = labels.detach().cpu().tolist()
            if isinstance(labels, float):
                labels = [labels]
            assert len(labels) == len(a_words[i])
            labels_a.append(labels)
        for i in range(len(b)):
            b_embeddings_partial_i = b_embeddings_partial[sum(b_num_partial[:i]):sum(b_num_partial[:i + 1])]
            similarities_partial = pairwise_cos_sim(b_embeddings_partial_i, a_embedding_full[i].unsqueeze(0)).squeeze(0)
            labels = (similarities_partial - similarity_full[i] + 1) / 2
            labels = labels.detach().cpu().tolist()
            if isinstance(labels, float):
                labels = [labels]
            assert len(labels) == len(b_words[i])
            labels_b.append(labels)

        samples = []
        for i in range(len(a)):
            samples.append(DifferenceSample(
                tokens_a=tuple(a_words[i]),
                tokens_b=tuple(b_words[i]),
                labels_a=tuple(labels_a[i]),
                labels_b=tuple(labels_b[i]),
            ))
        return samples

    def _encode_and_pool(self, sentences: List[str], **kwargs) -> torch.Tensor:
        model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        model_inputs = model_inputs.to(self.pipeline.device)
        outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
        if self.layer == "mean":
            token_embeddings = torch.stack(outputs.hidden_states, dim=0).mean(dim=0)
        else:
            assert isinstance(self.layer, int)
            token_embeddings = outputs.hidden_states[self.layer]
        mask = model_inputs["attention_mask"]
        sentence_embeddings = torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)
        return sentence_embeddings
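Note: DiffDel labels each n-gram by how much deleting it moves the sentence-level similarity, via (partial_similarity - full_similarity + 1) / 2, so unimportant words land near 0.5 and words whose removal makes the pair more similar score higher. A tiny worked sketch with hypothetical similarity values:

# Hypothetical cosine similarities, purely to illustrate the DiffDel label formula
full_sim = 0.82     # full sentence A vs. full sentence B
partial_sim = 0.90  # sentence A with one n-gram deleted vs. full sentence B
label = (partial_sim - full_sim + 1) / 2
print(label)  # 0.54 -> deleting this n-gram increases the similarity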
recognizers/feature_based.py
ADDED
@@ -0,0 +1,136 @@
"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""

import itertools
from typing import List, Union

import torch
import transformers
from transformers import FeatureExtractionPipeline, Pipeline

from recognizers.base import DifferenceRecognizer
from recognizers.utils import DifferenceSample

Ngram = List[int]  # A span of subword indices


class FeatureExtractionRecognizer(DifferenceRecognizer):

    def __init__(self,
                 model_name_or_path: str = None,
                 pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 ):
        assert model_name_or_path is not None or pipeline is not None
        if pipeline is None:
            pipeline = transformers.pipeline(
                model=model_name_or_path,
                task="feature-extraction",
            )
        self.pipeline = pipeline
        self.layer = layer
        self.batch_size = batch_size

    def encode_batch(self, sentences: List[str], **kwargs) -> torch.Tensor:
        model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        model_inputs = model_inputs.to(self.pipeline.device)
        outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
        return outputs.hidden_states[self.layer]

    def predict(self,
                a: str,
                b: str,
                **kwargs,
                ) -> DifferenceSample:
        return self.predict_all([a], [b], **kwargs)[0]

    def predict_all(self,
                    a: List[str],
                    b: List[str],
                    **kwargs,
                    ) -> List[DifferenceSample]:
        samples = []
        for i in range(0, len(a), self.batch_size):
            samples.extend(self._predict_all(
                a[i:i + self.batch_size],
                b[i:i + self.batch_size],
                **kwargs,
            ))
        return samples

    @torch.no_grad()
    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        raise NotImplementedError

    def _pool(self, token_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param token_embeddings: batch x seq_len x dim
        :param mask: batch x seq_len; 1 if token should be included in the pooling
        :return: batch x dim
        Do only sum and do not divide by the number of tokens because cosine similarity is length-invariant.
        """
        return torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)

    def _get_subwords_by_word(self, sentence: str) -> List[Ngram]:
        """
        :return: For each word in the sentence, the positions of the subwords that make up the word.
        """
        batch_encoding = self.pipeline.tokenizer(
            sentence,
            padding=True,
            truncation=True,
        )
        subword_ids: List[List[int]] = []

        for subword_idx in range(len(batch_encoding.encodings[0].word_ids)):
            if batch_encoding.encodings[0].word_ids[subword_idx] is None:  # Special token
                continue
            char_idx = batch_encoding.encodings[0].offsets[subword_idx][0]
            if isinstance(self.pipeline.tokenizer, transformers.XLMRobertaTokenizerFast) or \
                    isinstance(self.pipeline.tokenizer, transformers.XLMRobertaTokenizer):
                token = batch_encoding.encodings[0].tokens[subword_idx]
                is_tail = not token.startswith("▁") and token not in self.pipeline.tokenizer.all_special_tokens
            elif isinstance(self.pipeline.tokenizer, transformers.RobertaTokenizerFast) or \
                    isinstance(self.pipeline.tokenizer, transformers.RobertaTokenizer):
                token = batch_encoding.encodings[0].tokens[subword_idx]
                is_tail = not token.startswith("Ġ") and token not in self.pipeline.tokenizer.all_special_tokens
            else:
                is_tail = char_idx > 0 and char_idx == batch_encoding.encodings[0].offsets[subword_idx - 1][1]
            if is_tail and len(subword_ids) > 0:
                subword_ids[-1].append(subword_idx)
            else:
                subword_ids.append([subword_idx])
        return subword_ids

    def _get_ngrams(self, subwords_by_word: List[Ngram]) -> List[Ngram]:
        """
        :return: For each subword ngram in the sentence, the positions of the subwords that make up the ngram.
        """
        subwords = list(itertools.chain.from_iterable(subwords_by_word))
        # Always return at least one ngram (reduce n if necessary)
        min_n = min(self.min_n, len(subwords))
        ngrams = []
        for n in range(min_n, self.max_n + 1):
            for i in range(len(subwords) - n + 1):
                ngrams.append(subwords[i:i + n])
        return ngrams

    def _subword_labels_to_word_labels(self, subword_labels: torch.Tensor, subwords_by_words: List[Ngram]) -> List[float]:
        """
        :param subword_labels: num_subwords
        :param subwords_by_words: num_words x num_subwords
        :return: num_words
        """
        labels = []
        for subword_indices in subwords_by_words:
            label = subword_labels[subword_indices].mean().item()
            labels.append(label)
        return labels
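Note: _get_subwords_by_word rebuilds the whitespace-separated words from the subword offsets (with special handling of the XLM-R "▁" and RoBERTa "Ġ" prefix markers), which is what lets the recognizers report one label per word. A small inspection sketch, under the same checkpoint assumption as above:

from transformers import pipeline

from recognizers.diff_align import DiffAlign

my_pipeline = pipeline("feature-extraction", model="ZurichNLP/unsup-simcse-xlm-roberta-base")
recognizer = DiffAlign(pipeline=my_pipeline)

# Prints one inner list of subword positions per whitespace-separated word
print(recognizer._get_subwords_by_word("Chinese shares close higher Friday ."))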
recognizers/utils.py
ADDED
@@ -0,0 +1,129 @@
"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""

from dataclasses import dataclass
from typing import Tuple, Optional

import torch
from tokenizers.pre_tokenizers import Whitespace
from torch import Tensor


@dataclass
class DifferenceSample:
    tokens_a: Tuple[str, ...]
    tokens_b: Tuple[str, ...]
    labels_a: Tuple[float, ...]
    labels_b: Optional[Tuple[float, ...]]

    def add_whitespace(self, encoding_a, encoding_b):
        self.tokens_a = self._add_whitespace(self.tokens_a, encoding_a)
        self.tokens_b = self._add_whitespace(self.tokens_b, encoding_b)

    def _add_whitespace(self, tokens, encoding) -> Tuple[str, ...]:
        assert len(tokens) == len(encoding)
        new_tokens = []
        for i in range(len(encoding)):
            token = tokens[i]
            if i < len(encoding) - 1:
                cur_end = encoding[i][1][1]
                next_start = encoding[i + 1][1][0]
                token += " " * (next_start - cur_end)
            new_tokens.append(token)
        return tuple(new_tokens)

    # For rendering with Jinja2
    @property
    def token_labels_a(self) -> Tuple[Tuple[str, float], ...]:
        return tuple(zip(self.tokens_a, self.labels_a))

    @property
    def token_labels_b(self) -> Tuple[Tuple[str, float], ...]:
        return tuple(zip(self.tokens_b, self.labels_b))

    @property
    def min(self) -> float:
        return min(self.labels_a + self.labels_b)

    @property
    def max(self) -> float:
        return max(self.labels_a + self.labels_b)


def tokenize(text: str) -> Tuple[str]:
    """
    Apply Moses-like tokenization to a string.
    """
    whitespace_tokenizer = Whitespace()
    output = whitespace_tokenizer.pre_tokenize_str(text)
    # [('This', (0, 4)), ('is', (5, 7)), ('a', (8, 9)), ('test', (10, 14)), ('.', (14, 15))]
    tokens = [str(token[0]) for token in output]
    return tuple(tokens)


def cos_sim(a: Tensor, b: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L31

    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def pairwise_dot_score(a: Tensor, b: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L73

    Computes the pairwise dot-product dot_prod(a[i], b[i])
    :return: Vector with res[i] = dot_prod(a[i], b[i])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    return (a * b).sum(dim=-1)


def normalize_embeddings(embeddings: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L101

    Normalizes the embeddings matrix, so that each sentence embedding has unit length
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)


def pairwise_cos_sim(a: Tensor, b: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L87

    Computes the pairwise cossim cos_sim(a[i], b[i])
    :return: Vector with res[i] = cos_sim(a[i], b[i])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    return pairwise_dot_score(normalize_embeddings(a), normalize_embeddings(b))
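Note: cos_sim returns the full similarity matrix between two sets of vectors (DiffAlign uses it to align subwords across the pair), while pairwise_cos_sim compares row i with row i (DiffDel uses it for sentence-level similarity). A quick shape check, assuming only torch and random inputs:

import torch

from recognizers.utils import cos_sim, pairwise_cos_sim

a = torch.randn(2, 8)
b = torch.randn(3, 8)
print(cos_sim(a, b).shape)               # torch.Size([2, 3]): every row of a vs. every row of b
print(pairwise_cos_sim(a, b[:2]).shape)  # torch.Size([2]): row i of a vs. row i of b[:2]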
result_template.html
ADDED
@@ -0,0 +1,47 @@
<p>
{% for token, label in token_labels %}<span class="highlight-{{ label }}">{{ token }}</span>{% endfor %}
</p>


<style>
.highlight-1 {
    background: linear-gradient(90deg, transparent, rgba(255, 245, 235, 0.05), transparent);
}

.highlight-2 {
    background: linear-gradient(90deg, transparent, rgba(254, 230, 206, 0.10), transparent);
}

.highlight-3 {
    background: linear-gradient(90deg, transparent, rgba(253, 208, 162, 0.15), transparent);
}

.highlight-4 {
    background: linear-gradient(90deg, transparent, rgba(253, 141, 60, 0.20), transparent);
}

.highlight-5 {
    background: linear-gradient(90deg, transparent, rgba(241, 105, 19, 0.25), transparent);
}

.highlight-6 {
    background: linear-gradient(90deg, transparent, rgba(217, 72, 1, 0.30), transparent);
}

.highlight-7 {
    background: linear-gradient(90deg, transparent, rgba(127, 39, 4, 0.35), transparent);
}

.highlight-8 {
    background: linear-gradient(90deg, transparent, rgba(127, 39, 4, 0.40), transparent);
}

.highlight-9 {
    background: linear-gradient(90deg, transparent, rgba(127, 39, 4, 0.45), transparent);
}

.highlight-10 {
    background: linear-gradient(90deg, transparent, rgba(127, 39, 4, 0.50), transparent);
}

</style>
tests.py
ADDED
@@ -0,0 +1,26 @@
from unittest import TestCase

from tokenizers.pre_tokenizers import Whitespace

from recognizers.utils import DifferenceSample


class DifferenceSampleTestCase(TestCase):

    def setUp(self):
        self.text_a = "Chinese shares close higher Friday."
        self.text_b = "Les actions chinoises clôturent en baisse mercredi."
        self.tokenizer = Whitespace()
        self.encoding_a = self.tokenizer.pre_tokenize_str(self.text_a)
        self.encoding_b = self.tokenizer.pre_tokenize_str(self.text_b)
        self.result = DifferenceSample(
            tokens_a=tuple([token[0] for token in self.encoding_a]),
            tokens_b=tuple([token[0] for token in self.encoding_b]),
            labels_a=tuple([0.1 for _ in range(len(self.encoding_a))]),
            labels_b=tuple([0.1 for _ in range(len(self.encoding_b))]),
        )

    def test_add_whitespace(self):
        self.result.add_whitespace(self.encoding_a, self.encoding_b)
        self.assertEqual("".join(self.result.tokens_a), self.text_a)
        self.assertEqual("".join(self.result.tokens_b), self.text_b)