jvamvas's picture
Basic implementation
fda57dd
raw
history blame
10.7 kB
"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences
MIT License
Copyright (c) 2023 University of Zurich
"""
import itertools
from copy import deepcopy
from typing import Union, List
import torch
from transformers import Pipeline, FeatureExtractionPipeline
from recognizers.feature_based import FeatureExtractionRecognizer, Ngram
from recognizers.utils import DifferenceSample, pairwise_cos_sim, cos_sim
class DiffDel(FeatureExtractionRecognizer):
def __init__(self,
model_name_or_path: str = None,
pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
layer: int = -1,
batch_size: int = 16,
min_n: int = 1,
max_n: int = 1, # Inclusive
):
super().__init__(model_name_or_path, pipeline, layer, batch_size)
assert min_n <= max_n
self.min_n = min_n
self.max_n = max_n
def __str__(self):
return f"DiffDel(model={self.pipeline.model.name_or_path}, layer={self.layer}, " \
f"min_n={self.min_n}, max_n={self.max_n})"
@torch.no_grad()
def _predict_all(self,
a: List[str],
b: List[str],
**kwargs,
) -> List[DifferenceSample]:
outputs_a = self.encode_batch(a, **kwargs)
outputs_b = self.encode_batch(b, **kwargs)
subwords_by_words_a = [self._get_subwords_by_word(sentence) for sentence in a]
subwords_by_words_b = [self._get_subwords_by_word(sentence) for sentence in b]
ngrams_a = [self._get_ngrams(subwords_by_word) for subwords_by_word in subwords_by_words_a]
ngrams_b = [self._get_ngrams(subwords_by_word) for subwords_by_word in subwords_by_words_b]
sentence_embeddings_a = self._get_full_sentence_embeddings(outputs_a, [list(itertools.chain.from_iterable(subwords)) for subwords in subwords_by_words_a])
sentence_embeddings_b = self._get_full_sentence_embeddings(outputs_b, [list(itertools.chain.from_iterable(subwords)) for subwords in subwords_by_words_b])
full_similarities = pairwise_cos_sim(sentence_embeddings_a, sentence_embeddings_b)
all_labels_a = []
all_labels_b = []
for i in range(len(a)):
partial_embeddings_a = self._get_partial_sentence_embeddings_for_sample(outputs_a[i], ngrams_a[i])
partial_embeddings_b = self._get_partial_sentence_embeddings_for_sample(outputs_b[i], ngrams_b[i])
partial_similarities_a = cos_sim(partial_embeddings_a, sentence_embeddings_b[i].unsqueeze(0)).squeeze(1)
partial_similarities_b = cos_sim(partial_embeddings_b, sentence_embeddings_a[i].unsqueeze(0)).squeeze(1)
ngram_labels_a = (partial_similarities_a - full_similarities[i] + 1) / 2
ngram_labels_b = (partial_similarities_b - full_similarities[i] + 1) / 2
subword_labels_a = self._distribute_ngram_labels_to_subwords(ngram_labels_a, ngrams_a[i])
subword_labels_b = self._distribute_ngram_labels_to_subwords(ngram_labels_b, ngrams_b[i])
labels_a = self._subword_labels_to_word_labels(subword_labels_a, subwords_by_words_a[i])
labels_b = self._subword_labels_to_word_labels(subword_labels_b, subwords_by_words_b[i])
all_labels_a.append(labels_a)
all_labels_b.append(labels_b)
samples = []
for i in range(len(a)):
samples.append(DifferenceSample(
tokens_a=tuple(a[i].split()),
tokens_b=tuple(b[i].split()),
labels_a=tuple(all_labels_a[i]),
labels_b=tuple(all_labels_b[i]),
))
return samples
def _get_full_sentence_embeddings(self, token_embeddings: torch.Tensor, include_subwords: List[List[int]]) -> torch.Tensor:
"""
:param token_embeddings: batch x seq_len x dim
:param include_subwords: batch x num_subwords
:return: A tensor of shape batch x dim
"""
pool_mask = torch.zeros(token_embeddings.shape[0], token_embeddings.shape[1], device=token_embeddings.device)
for i, subword_indices in enumerate(include_subwords):
pool_mask[i, subword_indices] = 1
sentence_embeddings = self._pool(token_embeddings, pool_mask)
return sentence_embeddings
def _get_partial_sentence_embeddings_for_sample(self, token_embeddings: torch.Tensor, ngrams: List[Ngram]) -> torch.Tensor:
"""
:param token_embeddings: seq_len x dim
:param ngrams: num_ngrams x n
:return: A tensor of shape num_ngrams x dim
"""
pool_mask = torch.zeros(len(ngrams), token_embeddings.shape[0], device=token_embeddings.device)
pool_mask[:, list(itertools.chain.from_iterable(ngrams))] = 1
for i, subword_indices in enumerate(ngrams):
pool_mask[i, subword_indices] = 0
partial_embeddings = self._pool(token_embeddings.unsqueeze(0).repeat(len(ngrams), 1, 1), pool_mask)
return partial_embeddings
def _distribute_ngram_labels_to_subwords(self, ngram_labels: torch.Tensor, ngrams: List[Ngram]) -> torch.Tensor:
"""
:param ngram_labels: num_ngrams
:param ngrams: num_ngrams x n
:return: num_subwords
"""
max_subword_idx = max(itertools.chain.from_iterable(ngrams))
subword_contributions = torch.zeros(max_subword_idx + 1, device=ngram_labels.device)
contribution_count = torch.zeros(max_subword_idx + 1, device=ngram_labels.device)
for i, ngram in enumerate(ngrams):
subword_contributions[ngram] += ngram_labels[i] / len(ngram)
contribution_count[ngram] += 1 / len(ngram)
subword_contributions /= contribution_count
return subword_contributions
class DiffDelWithReencode(FeatureExtractionRecognizer):
"""
Version of DiffDel that encodes the partial sentences from scratch (instead of encoding the full sentence once and
then excluding hidden states from the mean)
"""
def __init__(self,
model_name_or_path: str = None,
pipeline: Union[FeatureExtractionPipeline, Pipeline] = None,
layer: int = -1,
batch_size: int = 16,
):
super().__init__(model_name_or_path, pipeline, layer, batch_size)
def __str__(self):
return f"DiffDelWithReencode(model={self.pipeline.model.name_or_path}, layer={self.layer})"
@torch.no_grad()
def _predict_all(self,
a: List[str],
b: List[str],
**kwargs,
) -> List[DifferenceSample]:
a_words = [sentence.split() for sentence in a]
b_words = [sentence.split() for sentence in b]
a_words_partial = []
b_words_partial = []
for words in a_words:
for i, word in enumerate(words):
partial = deepcopy(words)
del partial[i]
a_words_partial.append(partial)
for words in b_words:
for i, word in enumerate(words):
partial = deepcopy(words)
del partial[i]
b_words_partial.append(partial)
a_partial = [" ".join([word for word in words if word]) for words in a_words_partial]
b_partial = [" ".join([word for word in words if word]) for words in b_words_partial]
a_num_partial = [len(words) for words in a_words]
b_num_partial = [len(words) for words in b_words]
a_embedding_full = self._encode_and_pool(a, **kwargs)
b_embedding_full = self._encode_and_pool(b, **kwargs)
a_embeddings_partial = []
b_embeddings_partial = []
for i in range(0, len(a_partial), self.batch_size):
a_embeddings_partial_batch = self._encode_and_pool(a_partial[i:i + self.batch_size], **kwargs)
a_embeddings_partial.append(a_embeddings_partial_batch)
for i in range(0, len(b_partial), self.batch_size):
b_embeddings_partial_batch = self._encode_and_pool(b_partial[i:i + self.batch_size], **kwargs)
b_embeddings_partial.append(b_embeddings_partial_batch)
a_embeddings_partial = torch.cat(a_embeddings_partial, dim=0)
b_embeddings_partial = torch.cat(b_embeddings_partial, dim=0)
labels_a = []
labels_b = []
similarity_full = pairwise_cos_sim(a_embedding_full, b_embedding_full)
for i in range(len(a)):
a_embeddings_partial_i = a_embeddings_partial[sum(a_num_partial[:i]):sum(a_num_partial[:i + 1])]
similarities_partial = pairwise_cos_sim(a_embeddings_partial_i, b_embedding_full[i].unsqueeze(0)).squeeze(0)
labels = (similarities_partial - similarity_full[i] + 1) / 2
labels = labels.detach().cpu().tolist()
if isinstance(labels, float):
labels = [labels]
assert len(labels) == len(a_words[i])
labels_a.append(labels)
for i in range(len(b)):
b_embeddings_partial_i = b_embeddings_partial[sum(b_num_partial[:i]):sum(b_num_partial[:i + 1])]
similarities_partial = pairwise_cos_sim(b_embeddings_partial_i, a_embedding_full[i].unsqueeze(0)).squeeze(0)
labels = (similarities_partial - similarity_full[i] + 1) / 2
labels = labels.detach().cpu().tolist()
if isinstance(labels, float):
labels = [labels]
assert len(labels) == len(b_words[i])
labels_b.append(labels)
samples = []
for i in range(len(a)):
samples.append(DifferenceSample(
tokens_a=tuple(a_words[i]),
tokens_b=tuple(b_words[i]),
labels_a=tuple(labels_a[i]),
labels_b=tuple(labels_b[i]),
))
return samples
def _encode_and_pool(self, sentences: List[str], **kwargs) -> torch.Tensor:
model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
model_inputs = model_inputs.to(self.pipeline.device)
outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
if self.layer == "mean":
token_embeddings = torch.stack(outputs.hidden_states, dim=0).mean(dim=0)
else:
assert isinstance(self.layer, int)
token_embeddings = outputs.hidden_states[self.layer]
mask = model_inputs["attention_mask"]
sentence_embeddings = torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)
return sentence_embeddings