|
from __future__ import annotations |
|
import torch |
|
from torchtyping import TensorType |
|
from .fasttext_jp_embedding import FastTextJpModel, FastTextJpConfig |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
|
class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
    """Configuration for the FastText-based sequence classification model.

    Extends ``FastTextJpConfig`` with the n-gram window sizes used when a
    sentence is compared against a candidate label.
    """

    model_type = "fasttext_classification"

    def __init__(self,
                 ngram: int | list[int] = 2,
                 tokenizer_class="FastTextJpTokenizer",
                 **kwargs):
        """Store the n-gram sizes and forward everything else to the parent.

        Args:
            ngram (int | list[int], optional):
                N-gram size(s) used when splitting a sentence. A single int
                is wrapped into a one-element list; a list is stored as is.
            tokenizer_class (str, optional):
                Name of the tokenizer class. Per the original author's note,
                the pipeline cannot load the tokenizer unless this is set,
                and the value is written to config.json.
            **kwargs:
                Passed through unchanged to the parent config.

        Raises:
            TypeError: If ``ngram`` is neither an int nor a list.
        """
        if isinstance(ngram, list):
            ngrams = ngram
        elif isinstance(ngram, int):
            ngrams = [ngram]
        else:
            raise TypeError(f"got unknown type {type(ngram)}")

        # Assigned before the parent initialiser runs, as in the original.
        self.ngrams = ngrams
        super().__init__(**kwargs, tokenizer_class=tokenizer_class)
|
|
|
|
|
class NgramForSeuqenceClassification():
    """Score one sentence against one candidate label for a single n-gram size.

    The sentence's token vectors are averaged over sliding windows of
    ``ngram`` tokens. The largest cosine similarity between any window and
    the mean label vector is used as the entailment probability.
    """

    def __init__(self):
        ...

    def __call__(self, sentence: TensorType["A", "vectors"],
                 candidate_label: TensorType["B", "vectors"],
                 ngram: int) -> TensorType[3]:
        """Split the sentence into n-grams and score it against the label.

        Args:
            sentence (TensorType["A", "vectors"]): Sentence token vectors.
            candidate_label (TensorType["B", "vectors"]): Label token vectors.
            ngram (int): Window size used to split the sentence.

        Returns:
            TensorType[3]:
                Log-scores ordered as [Entailment, Neutral, Contradiction].
                Neutral is always ``-inf``.
        """
        sentence_ngrams = self.split_ngram(sentence, ngram)
        candidate_label_mean = torch.mean(candidate_label, dim=0, keepdim=True)
        # cosine_similarity() returns a value in [0, 1], so neither logarithm
        # below can receive a negative argument and produce NaN.
        p = self.cosine_similarity(sentence_ngrams, candidate_label_mean)
        return torch.tensor([torch.log(p), -torch.inf, torch.log(1 - p)])

    def cosine_similarity(
            self, sentence_ngrams: TensorType["ngrams", "vectors"],
            candidate_label_mean: TensorType[1, "vectors"]) -> TensorType[1]:
        """Return the best cosine similarity between any n-gram and the label.

        Args:
            sentence_ngrams (TensorType["ngrams", "vectors"]):
                One averaged vector per n-gram window of the sentence.
            candidate_label_mean (TensorType[1, "vectors"]):
                Mean vector of the candidate label.

        Returns:
            TensorType[1]:
                Scalar tensor holding the maximum similarity, limited to the
                range [0, 1]. Negative similarities and NaN rows (e.g. from an
                empty label) count as 0.
        """
        similarities = torch.nn.functional.cosine_similarity(
            sentence_ngrams, candidate_label_mean, dim=1)
        if similarities.numel() == 0:
            return torch.tensor(0.)
        # NaN rows must not win the maximum; treat them as "no similarity".
        similarities = torch.nan_to_num(similarities, nan=0.0)
        # Floating-point rounding can push a cosine slightly above 1, which
        # would turn log(1 - p) into NaN in __call__, so clamp the result.
        return torch.clamp(similarities.max(), min=0.0, max=1.0)

    def split_ngram(self, sentences: TensorType["A", "vectors"],
                    n: int) -> TensorType["ngrams", "vectors"]:
        """Average the token vectors over every sliding window of ``n`` tokens.

        Args:
            sentences (TensorType["A", "vectors"]):
                Token vectors of the sentence.
            n (int):
                Window (n-gram) size.

        Returns:
            TensorType["ngrams", "vectors"]:
                One mean vector per window. When the sentence has ``n`` tokens
                or fewer, a single row holding the mean of all tokens.
        """
        # A sentence no longer than n still yields exactly one window that
        # covers every token.
        window_count = max(len(sentences) - n + 1, 1)
        return torch.stack([
            torch.mean(sentences[start:start + n], dim=0, keepdim=False)
            for start in range(window_count)
        ])
|
|
|
|
|
class NgramsForSeuqenceClassification():
    """Score a sentence against a label using several n-gram sizes.

    Each configured n-gram size is scored independently and the result with
    the highest entailment log-score is kept.
    """

    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
        self.max_ngrams = config.ngrams
        self.ngram_layer = NgramForSeuqenceClassification()

    def __call__(self, sentence: TensorType["A", "vectors"],
                 candidate_label: TensorType["B", "vectors"]) -> TensorType[3]:
        """Compute how strongly the sentence relates to the candidate label.

        Args:
            sentence (TensorType["A", "vectors"]):
                Token vectors of the sentence.
            candidate_label (TensorType["B", "vectors"]):
                Token vectors of the label.

        Returns:
            TensorType[3]:
                Log-scores ordered as [Entailment, Neutral, Contradiction],
                taken from the n-gram size with the highest entailment score.
                All ``-inf`` only when no n-gram size is configured.
        """
        best = None
        for ngram in self.max_ngrams:
            logit = self.ngram_layer(sentence, candidate_label, ngram)
            # Always accept the first candidate. A strict comparison against
            # an all -inf seed would discard a result whose entailment score
            # is -inf and lose its finite contradiction score with it.
            if best is None or logit[0] > best[0]:
                best = logit
        if best is None:
            return torch.tensor([-torch.inf, -torch.inf, -torch.inf])
        # Return an independent copy without copy-constructing via
        # torch.tensor(), which emits a warning for tensor inputs.
        return torch.as_tensor(best).detach().clone()
|
|
|
|
|
class BatchedNgramsForSeuqenceClassification():
    """Apply the n-gram classifier to every sample of a batch."""

    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
        self.ngrams_layer = NgramsForSeuqenceClassification(config)

    def __call__(
        self,
        last_hidden_state: TensorType["batch", "A+B", "vectors"],
        token_type_ids: TensorType["batch", "A+B"],
        attention_mask: TensorType["batch", "A+B"],
    ) -> TensorType["batch", 3]:
        """Compute how strongly sentence A relates to label B per sample.

        Args:
            last_hidden_state (TensorType["batch", "A+B", "vectors"]):
                Embedding vectors of the concatenated sentence and label.
            token_type_ids (TensorType["batch", "A+B"]):
                Segment ids: 0 for sentence tokens, 1 for label tokens.
            attention_mask (TensorType["batch", "A+B"]):
                Positions with value 1 are used; all others are dropped.

        Returns:
            TensorType["batch", 3]:
                Per-sample log-scores ordered as
                [Entailment, Neutral, Contradiction].
        """
        logits = []
        # Iterate the three tensors in lockstep. Rebinding token_type_ids or
        # attention_mask to a single row inside the loop would make every
        # later iteration index into that row instead of the batch.
        for vec, sample_type_ids, sample_mask in zip(
                last_hidden_state, token_type_ids, attention_mask):
            sentence, candidate_label = self.split_sentence(
                vec, sample_type_ids, sample_mask)
            logits.append(self.ngrams_layer(sentence, candidate_label))
        if not logits:
            # torch.stack() rejects an empty list; keep the documented shape.
            return torch.empty((0, 3))
        # Each entry is a 3-element tensor, so stack them into (batch, 3).
        return torch.stack(logits)

    def split_sentence(
        self, vec: TensorType["A+B", "vectors"],
        token_type_ids: TensorType["A+B"], attention_mask: TensorType["A+B"]
    ) -> tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
        """Split a cross-encoder style input into sentence and label vectors.

        Args:
            vec (TensorType["A+B", "vectors"]):
                Word vectors of the concatenated input.
            token_type_ids (TensorType["A+B"]):
                Segment ids: 0 for sentence tokens, 1 for label tokens.
            attention_mask (TensorType["A+B"]):
                Positions with value 1 are kept; all others are dropped.

        Returns:
            tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
                The sentence vectors and the label vectors, in that order.
        """
        kept = attention_mask == 1
        sentence = vec[torch.logical_and(token_type_ids == 0, kept)]
        candidate_label = vec[torch.logical_and(token_type_ids == 1, kept)]
        return sentence, candidate_label
|
|
|
|
|
class FastTextForSeuqenceClassification(FastTextJpModel):
    """Classifier built on top of FastText word vectors."""

    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
        # The scoring layer is attached before the parent initialiser runs;
        # this order is kept exactly as the original code had it.
        self.layer = BatchedNgramsForSeuqenceClassification(config)
        super().__init__(config)

    def forward(
        self,
        input_ids: TensorType["batch", "A+B", "vecotors"] = None,
        attention_mask: TensorType["batch", "A+B"] = None,
        token_type_ids: TensorType["batch", "A+B"] = None
    ) -> SequenceClassifierOutput:
        """Classify the input against its candidate label.

        ``input_ids`` is passed to ``self.word_embeddings`` (not defined in
        this class; presumably provided by ``FastTextJpModel``) and the
        result is scored by the batched n-gram layer.

        Returns:
            SequenceClassifierOutput:
                Only ``logits`` is populated; ``loss``, ``hidden_states`` and
                ``attentions`` are ``None``.
        """
        embeddings = self.word_embeddings(input_ids)
        scores = self.layer(
            last_hidden_state=embeddings,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return SequenceClassifierOutput(
            loss=None,
            logits=scores,
            hidden_states=None,
            attentions=None,
        )
|
|
|
|
|
|
|
|
|
# Register the custom config and model for auto-class loading; the model is
# registered under "AutoModelForSequenceClassification".
# NOTE(review): register_for_auto_class is presumably inherited from the
# Hugging Face Transformers base classes -- confirm in fasttext_jp_embedding.
FastTextForSeuqenceClassificationConfig.register_for_auto_class()
FastTextForSeuqenceClassification.register_for_auto_class(
    "AutoModelForSequenceClassification"
)
|
|