"""
Source: https://github.com/ZurichNLP/recognizing-semantic-differences

MIT License

Copyright (c) 2023 University of Zurich
"""
import itertools
from typing import List, Optional, Union

import torch
import transformers
from transformers import FeatureExtractionPipeline, Pipeline

from recognizers.base import DifferenceRecognizer
from recognizers.utils import DifferenceSample

Ngram = List[int]

class FeatureExtractionRecognizer(DifferenceRecognizer):

    def __init__(self,
                 model_name_or_path: Optional[str] = None,
                 pipeline: Optional[Union[FeatureExtractionPipeline, Pipeline]] = None,
                 layer: int = -1,
                 batch_size: int = 16,
                 ):
        # Either a model name/path or a ready-made feature-extraction pipeline must be given.
        assert model_name_or_path is not None or pipeline is not None
        if pipeline is None:
            pipeline = transformers.pipeline(
                model=model_name_or_path,
                task="feature-extraction",
            )
        self.pipeline = pipeline
        self.layer = layer  # index of the hidden layer to read token embeddings from (-1 = last layer)
        self.batch_size = batch_size

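    # Usage sketch: this class is abstract (_predict_all below raises NotImplementedError),
    # so predictions go through a concrete subclass. The subclass name below is hypothetical
    # and "xlm-roberta-base" is only an example model:
    #
    #   recognizer = SomeConcreteRecognizer(model_name_or_path="xlm-roberta-base", layer=8)
    #   sample = recognizer.predict("She sold the car.", "She sold the bike.")
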
    def encode_batch(self, sentences: List[str], **kwargs) -> torch.Tensor:
        model_inputs = self.pipeline.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        model_inputs = model_inputs.to(self.pipeline.device)
        outputs = self.pipeline.model(**model_inputs, output_hidden_states=True, **kwargs)
        # Return the token embeddings from the configured hidden layer.
        return outputs.hidden_states[self.layer]

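    # For example, encoding two sentences that are padded to 8 subword tokens with a model
    # whose hidden size is 768 yields a tensor of shape (2, 8, 768): one vector per subword
    # position, taken from hidden layer self.layer.
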
    def predict(self,
                a: str,
                b: str,
                **kwargs,
                ) -> DifferenceSample:
        return self.predict_all([a], [b], **kwargs)[0]

    def predict_all(self,
                    a: List[str],
                    b: List[str],
                    **kwargs,
                    ) -> List[DifferenceSample]:
        # Split the sentence pairs into batches and delegate the prediction to _predict_all.
        samples = []
        for i in range(0, len(a), self.batch_size):
            samples.extend(self._predict_all(
                a[i:i + self.batch_size],
                b[i:i + self.batch_size],
                **kwargs,
            ))
        return samples

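    # For instance, with the default batch_size of 16, a list of 50 sentence pairs is
    # processed in four calls to _predict_all, covering 16, 16, 16 and 2 pairs.
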
    @torch.no_grad()
    def _predict_all(self,
                     a: List[str],
                     b: List[str],
                     **kwargs,
                     ) -> List[DifferenceSample]:
        # To be implemented by subclasses.
        raise NotImplementedError

    def _pool(self, token_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param token_embeddings: batch x seq_len x dim
        :param mask: batch x seq_len; 1 if the token should be included in the pooling
        :return: batch x dim

        Only sum the masked embeddings; do not divide by the number of tokens, because
        cosine similarity is invariant to the magnitude of the vectors.
        """
        return torch.sum(token_embeddings * mask.unsqueeze(-1), dim=1)

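    # Worked example (sketch): for a single sequence with token embeddings
    # [[1., 0.], [2., 2.], [5., 5.]] and mask [1, 1, 0], the pooled vector is [3., 2.];
    # the masked-out third position does not contribute. Dividing by the token count
    # (here 2) would only rescale the vector and leave any cosine similarity unchanged.
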
    def _get_subwords_by_word(self, sentence: str) -> List[Ngram]:
        """
        :return: For each word in the sentence, the positions of the subwords that make up the word.
        """
        batch_encoding = self.pipeline.tokenizer(
            sentence,
            padding=True,
            truncation=True,
        )
        subword_ids: List[List[int]] = []

        for subword_idx in range(len(batch_encoding.encodings[0].word_ids)):
            # Special tokens have no word id and are skipped.
            if batch_encoding.encodings[0].word_ids[subword_idx] is None:
                continue
            char_idx = batch_encoding.encodings[0].offsets[subword_idx][0]
            if isinstance(self.pipeline.tokenizer, transformers.XLMRobertaTokenizerFast) or \
                    isinstance(self.pipeline.tokenizer, transformers.XLMRobertaTokenizer):
                # SentencePiece marks word-initial subwords with "▁".
                token = batch_encoding.encodings[0].tokens[subword_idx]
                is_tail = not token.startswith("▁") and token not in self.pipeline.tokenizer.all_special_tokens
            elif isinstance(self.pipeline.tokenizer, transformers.RobertaTokenizerFast) or \
                    isinstance(self.pipeline.tokenizer, transformers.RobertaTokenizer):
                # Byte-level BPE marks subwords that follow a space with "Ġ".
                token = batch_encoding.encodings[0].tokens[subword_idx]
                is_tail = not token.startswith("Ġ") and token not in self.pipeline.tokenizer.all_special_tokens
            else:
                # Fallback: a subword continues the previous word if its character offset
                # starts exactly where the previous subword ended.
                is_tail = char_idx > 0 and char_idx == batch_encoding.encodings[0].offsets[subword_idx - 1][1]
            if is_tail and len(subword_ids) > 0:
                subword_ids[-1].append(subword_idx)
            else:
                subword_ids.append([subword_idx])
        return subword_ids

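    # Illustration (hypothetical tokenization; the exact subword split depends on the
    # vocabulary): if an XLM-R tokenizer splits "unbelievable news" into
    # ["<s>", "▁un", "believ", "able", "▁news", "</s>"], the special tokens are skipped
    # and the method returns [[1, 2, 3], [4]]: positions 1-3 form the first word and
    # position 4 forms the second.
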
    def _get_ngrams(self, subwords_by_word: List[Ngram]) -> List[Ngram]:
        """
        :return: For each subword ngram in the sentence, the positions of the subwords that make up the ngram.
        """
        subwords = list(itertools.chain.from_iterable(subwords_by_word))

        # self.min_n and self.max_n are expected to be set by subclasses that work with
        # subword ngrams; note that ngrams may cross word boundaries.
        min_n = min(self.min_n, len(subwords))
        ngrams = []
        for n in range(min_n, self.max_n + 1):
            for i in range(len(subwords) - n + 1):
                ngrams.append(subwords[i:i + n])
        return ngrams

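    # Worked example: with subwords_by_word = [[1], [2, 3], [4]], min_n = 1 and max_n = 2,
    # the flattened subword positions are [1, 2, 3, 4] and the returned ngrams are
    # [1], [2], [3], [4], [1, 2], [2, 3], [3, 4].
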
    def _subword_labels_to_word_labels(self, subword_labels: torch.Tensor, subwords_by_words: List[Ngram]) -> List[float]:
        """
        :param subword_labels: num_subwords
        :param subwords_by_words: num_words x num_subwords
        :return: num_words
        """
        # A word's label is the mean of the labels of its subwords.
        labels = []
        for subword_indices in subwords_by_words:
            label = subword_labels[subword_indices].mean().item()
            labels.append(label)
        return labels

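    # Worked example (sketch): with subword_labels = torch.tensor([0.2, 0.8, 0.4]) and
    # subwords_by_words = [[0], [1, 2]], the word-level labels are [0.2, 0.6]
    # (0.6 being the mean of 0.8 and 0.4).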