|
""" |
|
Source: https://github.com/ZurichNLP/recognizing-semantic-differences |
|
MIT License |
|
Copyright (c) 2023 University of Zurich |
|
""" |
|
|
|
from typing import List |
|
|
|
import torch |
|
|
|
from recognizers.feature_based import FeatureExtractionRecognizer |
|
from recognizers.utils import DifferenceSample, cos_sim |
|
|
|
|
|
class DiffAlign(FeatureExtractionRecognizer):
    """Recognizer that labels words by how poorly they align with the other sentence.

    For each sentence pair, every subword of one sentence is compared with every
    subword of the other sentence via cosine similarity. A subword's difference
    score is ``1 - (its best similarity to any subword on the other side)``.
    Subword scores are then converted to word-level labels by the inherited
    ``_subword_labels_to_word_labels`` helper.
    """

    def __str__(self):
        """Return a short description naming the underlying model and layer."""
        return f"DiffAlign(model={self.pipeline.model.name_or_path}, layer={self.layer})"

    @torch.no_grad()
    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        """Predict word-level difference labels for aligned sentence pairs.

        Args:
            a: First sentences; ``a[i]`` is paired with ``b[i]``.
            b: Second sentences, indexed in parallel with ``a``.
            **kwargs: Forwarded unchanged to ``self.encode_batch``.

        Returns:
            One ``DifferenceSample`` per entry of ``a``. Tokens are the
            whitespace-split words of each sentence; labels are the word-level
            scores derived from the subword scores.
        """
        outputs_a = self.encode_batch(a, **kwargs)
        outputs_b = self.encode_batch(b, **kwargs)
        subwords_by_words_a = [self._get_subwords_by_word(sentence) for sentence in a]
        subwords_by_words_b = [self._get_subwords_by_word(sentence) for sentence in b]

        samples = []
        # Indexed by position in `a` (rather than zipping) so that a shorter `b`
        # still fails loudly instead of being silently truncated.
        for i, sentence_a in enumerate(a):
            sentence_b = b[i]
            # NOTE(review): assumes cos_sim returns a pairwise matrix with one row
            # per subword of sentence a and one column per subword of sentence b
            # -- confirm against recognizers.utils.cos_sim.
            cosine_similarities = cos_sim(outputs_a[i], outputs_b[i])
            # Best match on the other side: reduce over dim 1 for side a,
            # over dim 0 for side b.
            max_similarities_a = torch.max(cosine_similarities, dim=1).values
            max_similarities_b = torch.max(cosine_similarities, dim=0).values
            # A high best-match similarity yields a low difference score.
            labels_a = self._subword_labels_to_word_labels(
                1 - max_similarities_a, subwords_by_words_a[i]
            )
            labels_b = self._subword_labels_to_word_labels(
                1 - max_similarities_b, subwords_by_words_b[i]
            )
            samples.append(DifferenceSample(
                tokens_a=tuple(sentence_a.split()),
                tokens_b=tuple(sentence_b.split()),
                labels_a=tuple(labels_a),
                labels_b=tuple(labels_b),
            ))
        return samples
|
|