"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""

from dataclasses import dataclass
from typing import Tuple, Optional

import torch
from tokenizers.pre_tokenizers import Whitespace
from torch import Tensor


@dataclass
class DifferenceSample:
    tokens_a: Tuple[str, ...]
    tokens_b: Tuple[str, ...]
    labels_a: Tuple[float, ...]
    labels_b: Optional[Tuple[float, ...]]

    def add_whitespace(self, encoding_a, encoding_b):
        self.tokens_a = self._add_whitespace(self.tokens_a, encoding_a)
        self.tokens_b = self._add_whitespace(self.tokens_b, encoding_b)

    def _add_whitespace(self, tokens, encoding) -> Tuple[str, ...]:
        # `encoding` is a sequence of (token, (start, end)) pairs, as returned
        # by `Whitespace().pre_tokenize_str`; re-append the whitespace between
        # each token and the next one based on the character offsets.
        assert len(tokens) == len(encoding)
        new_tokens = []
        for i in range(len(encoding)):
            token = tokens[i]
            if i < len(encoding) - 1:
                cur_end = encoding[i][1][1]
                next_start = encoding[i + 1][1][0]
                token += " " * (next_start - cur_end)
            new_tokens.append(token)
        return tuple(new_tokens)

    # For rendering with Jinja2
    @property
    def token_labels_a(self) -> Tuple[Tuple[str, float], ...]:
        return tuple(zip(self.tokens_a, self.labels_a))

    @property
    def token_labels_b(self) -> Tuple[Tuple[str, float], ...]:
        # labels_b may be None (see the Optional annotation above); fall back
        # to an empty tuple so rendering does not crash.
        return tuple(zip(self.tokens_b, self.labels_b or ()))

    @property
    def min(self) -> float:
        # Guard against labels_b being None when computing the global range.
        return min(self.labels_a + (self.labels_b or ()))

    @property
    def max(self) -> float:
        return max(self.labels_a + (self.labels_b or ()))
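

# A minimal usage sketch (not part of the original module; the sentences and
# per-token scores below are invented for illustration): `add_whitespace`
# uses the (start, end) character offsets produced by the `Whitespace`
# pre-tokenizer to re-attach the gaps between tokens, so that joining the
# tokens reproduces the original strings.
def _demo_difference_sample() -> None:
    pre_tokenizer = Whitespace()
    text_a = "This is a test."
    text_b = "This was a test."
    encoding_a = pre_tokenizer.pre_tokenize_str(text_a)
    encoding_b = pre_tokenizer.pre_tokenize_str(text_b)
    sample = DifferenceSample(
        tokens_a=tuple(token for token, _ in encoding_a),
        tokens_b=tuple(token for token, _ in encoding_b),
        labels_a=(0.0, 0.9, 0.0, 0.0, 0.0),  # invented difference scores
        labels_b=(0.0, 0.9, 0.0, 0.0, 0.0),
    )
    sample.add_whitespace(encoding_a, encoding_b)
    assert "".join(sample.tokens_a) == text_a
    assert "".join(sample.tokens_b) == text_b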


def tokenize(text: str) -> Tuple[str, ...]:
    """
    Apply Moses-like tokenization (whitespace splitting plus punctuation
    separation) to a string.
    """
    whitespace_tokenizer = Whitespace()
    output = whitespace_tokenizer.pre_tokenize_str(text)
    # [('This', (0, 4)), ('is', (5, 7)), ('a', (8, 9)), ('test', (10, 14)), ('.', (14, 15))]
    tokens = [str(token[0]) for token in output]
    return tuple(tokens)
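

# Quick sanity check of `tokenize` (illustrative, not part of the upstream
# file): the `Whitespace` pre-tokenizer splits on whitespace and also
# separates punctuation, so the final period becomes its own token.
def _demo_tokenize() -> None:
    assert tokenize("This is a test.") == ("This", "is", "a", "test", ".")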


def cos_sim(a: Tensor, b: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L31

    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))
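

# Illustrative check (not part of the original module): with a batch of 3
# vectors and a batch of 2 vectors, `cos_sim` returns the full 3x2 matrix of
# cosine similarities; 1-D inputs are treated as single-row batches.
def _demo_cos_sim() -> None:
    a = torch.randn(3, 8)
    b = torch.randn(2, 8)
    sim = cos_sim(a, b)
    assert sim.shape == (3, 2)
    assert torch.all(sim.abs() <= 1.0 + 1e-6)  # cosine values lie in [-1, 1]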


def pairwise_dot_score(a: Tensor, b: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L73

    Computes the pairwise dot-product dot_prod(a[i], b[i])
    :return: Vector with res[i] = dot_prod(a[i], b[i])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    return (a * b).sum(dim=-1)
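

# Illustrative check (not part of the original module): unlike `cos_sim`,
# this score is row-wise, so both batches must have the same length and the
# result holds one score per row pair.
def _demo_pairwise_dot_score() -> None:
    a = torch.randn(4, 8)
    b = torch.randn(4, 8)
    scores = pairwise_dot_score(a, b)
    assert scores.shape == (4,)
    torch.testing.assert_close(scores[0], a[0] @ b[0])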


def normalize_embeddings(embeddings: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L101

    Normalizes the embeddings matrix, so that each sentence embedding has unit length
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)
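

# Illustrative check (not part of the original module): after normalization,
# every row of the embedding matrix has unit L2 norm.
def _demo_normalize_embeddings() -> None:
    norms = normalize_embeddings(torch.randn(5, 8)).norm(p=2, dim=1)
    torch.testing.assert_close(norms, torch.ones(5))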


def pairwise_cos_sim(a: Tensor, b: Tensor):
    """
    Copied from https://github.com/UKPLab/sentence-transformers/blob/d928410803bb90f555926d145ee7ad3bd1373a83/sentence_transformers/util.py#L87

    Computes the pairwise cosine similarity cos_sim(a[i], b[i])
    :return: Vector with res[i] = cos_sim(a[i], b[i])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    return pairwise_dot_score(normalize_embeddings(a), normalize_embeddings(b))
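

# Illustrative check (not part of the original module): the pairwise cosine
# similarity of two equal-sized batches equals the diagonal of the full
# `cos_sim` matrix, up to floating-point rounding.
def _demo_pairwise_cos_sim() -> None:
    a = torch.randn(4, 8)
    b = torch.randn(4, 8)
    torch.testing.assert_close(pairwise_cos_sim(a, b), cos_sim(a, b).diagonal())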