from __future__ import annotations

import torch
from torch import nn
from torchtyping import TensorType
from transformers import PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput

from .fasttext_jp_embedding import FastTextJpConfig, FastTextJpModel


class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
    """Config for the FastText-based sequence classification model."""

    model_type = "fasttext_jp"

    def __init__(self,
                 ngram: int = 2,
                 tokenizer_class="FastTextJpTokenizer",
                 **kwargs):
        """Initializer.

        Args:
            ngram (int, optional):
                N-gram size used when splitting a sentence into word groups.
            tokenizer_class (str, optional):
                Unless tokenizer_class is specified, the model cannot be loaded
                through a pipeline. The value is written into config.json.
        """
        self.ngram = ngram
        kwargs["tokenizer_class"] = tokenizer_class
        super().__init__(**kwargs)


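# Illustrative sketch (not part of the original module): because the constructor
# copies tokenizer_class into the kwargs passed on to FastTextJpConfig, the value
# ends up as a config attribute and is serialized into config.json, which is what
# lets a pipeline pick up the custom tokenizer.
#
#   config = FastTextForSeuqenceClassificationConfig(ngram=2)
#   assert config.tokenizer_class == "FastTextJpTokenizer"
#   assert config.ngram == 2

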
class FastTextForSeuqenceClassification(FastTextJpModel):
    """Performs classification based on FastText vectors."""

    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
        # Keep the n-gram size from the config; it is a plain attribute, so it is
        # safe to set before the parent initializer builds the embeddings.
        self.ngram = config.ngram
        super().__init__(config)

    def forward(self, **inputs) -> SequenceClassifierOutput:
        """Performs classification against the candidate labels.

        Returns:
            SequenceClassifierOutput: probability that each candidate is correct
        """
        input_ids = inputs["input_ids"]
        outputs = self.word_embeddings(input_ids)

        logits = []
        for idx in range(len(outputs)):
            output = outputs[idx]
            # token_type_ids == 0 marks the sentence, == 1 marks the candidate
            # label; attention_mask == 1 filters out padding.
            token_type_ids = inputs["token_type_ids"][idx]
            attention_mask = inputs["attention_mask"][idx]

            sentence = output[torch.logical_and(token_type_ids == 0,
                                                attention_mask == 1)]
            candidate_label = output[torch.logical_and(token_type_ids == 1,
                                                       attention_mask == 1)]
            sentence_words = self.split_ngram(sentence, self.ngram)
            candidate_label_mean = torch.mean(candidate_label,
                                              dim=-2,
                                              keepdim=True)
            p = self.cosine_similarity(sentence_words, candidate_label_mean)
            logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
        logits = torch.FloatTensor(logits)
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

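    # Note on the logits layout (an interpretation, not asserted by the original
    # code): each row looks like an NLI-style triple with the middle (neutral)
    # entry pinned to -inf, so softmaxing a row recovers p and 1 - p at the outer
    # positions. Worked example:
    #
    #   import math
    #   p = 0.8
    #   row = torch.tensor([math.log(p), -math.inf, math.log(1 - p)])
    #   row.softmax(dim=0)  # -> tensor([0.8000, 0.0000, 0.2000])
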
    def cosine_similarity(
            self, sentence_words: TensorType["words", "vectors"],
            candidate_label_means: TensorType[1, "vectors"]) -> TensorType[1]:
        # Score the candidate label by the best-matching sentence n-gram
        # (maximum cosine similarity, floored at 0).
        res = torch.tensor(0.)
        for sw in sentence_words:
            p = torch.nn.functional.cosine_similarity(sw,
                                                      candidate_label_means[0],
                                                      dim=0)
            if p > res:
                res = p
        return res

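    # A vectorized equivalent of the loop above (a sketch, assuming the same
    # shapes: sentence_words is (words, vectors), candidate_label_means is
    # (1, vectors)); clamping keeps the implicit floor at 0:
    #
    #   sims = torch.nn.functional.cosine_similarity(
    #       sentence_words, candidate_label_means, dim=-1)
    #   res = torch.clamp(sims.max(), min=0.)
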
    def split_ngram(self, sentences: TensorType["word", "vectors"],
                    n: int) -> TensorType["word", "vectors"]:
        # Slide a window of n consecutive word vectors over the sentence and
        # average each window into a single n-gram vector.
        res = []
        for i in range(len(sentences) - n + 1):
            ngram = sentences[i:i + n]
            res.append(torch.mean(ngram, dim=0, keepdim=False))
        return torch.stack(res)

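    # For example (illustrative values): with four word vectors and n=2,
    # split_ngram produces three averaged bigram vectors:
    #
    #   words = torch.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.]])
    #   split_ngram(words, 2)
    #   # -> tensor([[2., 2.], [4., 4.], [6., 6.]])

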
FastTextForSeuqenceClassificationConfig.register_for_auto_class()
FastTextForSeuqenceClassification.register_for_auto_class(
    "AutoModelForSequenceClassification")

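# A loading sketch (the Hub repository id below is a placeholder; this file does
# not name the actual repo). With the auto-class registrations above, the model
# can be pulled in through a zero-shot-classification pipeline with remote code
# enabled:
#
#   from transformers import pipeline
#   classifier = pipeline("zero-shot-classification",
#                         model="<namespace>/<this-model-repo>",
#                         trust_remote_code=True)
#   classifier("これはテストです。", candidate_labels=["テスト", "スポーツ"])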