fasttext-classification / fasttext_fsc.py
paulhindemith's picture
commit files to HF hub
46f1d8c
from __future__ import annotations
import torch
from torchtyping import TensorType
from .fasttext_jp_embedding import FastTextJpModel, FastTextJpConfig
from transformers.modeling_outputs import SequenceClassifierOutput
class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
"""FastTextJpModelのConfig
"""
model_type = "fasttext_classification"
def __init__(self,
ngram: int | list[int] = 2,
tokenizer_class="FastTextJpTokenizer",
**kwargs):
"""初期化処理
Args:
ngram (int | list[int], optional):
文章を分割する際のNgram。
tokenizer_class (str, optional):
tokenizer_classを指定しないと、pipelineから読み込まれません。
config.jsonに記載されます。
"""
if isinstance(ngram, int):
self.ngrams = [ngram]
elif isinstance(ngram, list):
self.ngrams = ngram
else:
raise TypeError(f"got unknown type {type(ngram)}")
kwargs["tokenizer_class"] = tokenizer_class
super().__init__(**kwargs)
class NgramForSeuqenceClassification():
def __init__(self):
...
def __call__(self, sentence: TensorType["A", "vectors"],
candidate_label: TensorType["B", "vectors"],
ngram: int) -> TensorType[3]:
"""Ngramで文章を分けてコサイン類似度を算出する。
Args:
sentence (TensorType["A", "vectors"]): 文章ベクトル
candidate_label (TensorType["B", "vectors"]): ラベルベクトル
ngram (int): Ngram
Returns:
TensorType[3]:
文章の類似度。[Entailment, Neutral, Contradiction]
"""
sentence_ngrams = self.split_ngram(sentence, ngram)
candidate_label_mean = torch.mean(candidate_label, dim=0, keepdim=True)
p = self.cosine_similarity(sentence_ngrams, candidate_label_mean)
return torch.tensor([torch.log(p), -torch.inf, torch.log(1 - p)])
def cosine_similarity(
self, sentence_ngrams: TensorType["ngrams", "vectors"],
candidate_label_mean: TensorType[1, "vectors"]) -> TensorType[1]:
"""コサイン類似度を計算する。
Args:
sentence_ngrams (TensorType["ngrams", "vectors"]):
Ngram化された文章ベクトル
candidate_label_mean (TensorType[1, "vectors"]):
ラベルベクトル
Returns:
TensorType[1]: _description_
"""
res = torch.tensor(0.)
for i in range(len(sentence_ngrams)):
sw = sentence_ngrams[i]
p = torch.nn.functional.cosine_similarity(sw,
candidate_label_mean[0],
dim=0)
if p > res:
res = p
return res
def split_ngram(self, sentences: TensorType["A", "vectors"],
n: int) -> TensorType["ngrams", "vectors"]:
"""AとBの関連度を計算します。
Args:
sentences(TensorType["A", "vectors"]):
対象の文章
n(int):
ngram
Returns:
TensorType["ngrams", "vectors"]:
Ngram化された文章
"""
res = []
if len(sentences) <= n:
return torch.stack([torch.mean(sentences, dim=0, keepdim=False)])
for i in range(len(sentences) - n + 1):
ngram = sentences[i:i + n]
res.append(torch.mean(ngram, dim=0, keepdim=False))
return torch.stack(res)
class NgramsForSeuqenceClassification():
def __init__(self, config: FastTextForSeuqenceClassificationConfig):
self.max_ngrams = config.ngrams
self.ngram_layer = NgramForSeuqenceClassification()
def __call__(self, sentence: TensorType["A", "vectors"],
candidate_label: TensorType["B", "vectors"]) -> TensorType[3]:
"""AとBの関連度を計算します。
Args:
sentence(TensorType["A", "vectors"]):
対象の文章
candidate_label(TensorType["B", "vectors"]):
ラベルの文章
Returns:
TensorType[3]:
文章の類似度。[Entailment, Neutral, Contradiction]
"""
res = [-torch.inf, -torch.inf, -torch.inf]
for ngram in self.max_ngrams:
logit = self.ngram_layer(sentence, candidate_label, ngram)
if logit[0] > res[0]:
res = logit
return torch.tensor(res)
class BatchedNgramsForSeuqenceClassification():
def __init__(self, config: FastTextForSeuqenceClassificationConfig):
self.ngrams_layer = NgramsForSeuqenceClassification(config)
def __call__(
self,
last_hidden_state: TensorType["batch", "A+B", "vectors"],
token_type_ids: TensorType["batch", "A+B"],
attention_mask: TensorType["batch", "A+B"],
) -> TensorType["batch", 3]:
"""AとBの関連度を計算します。
Args:
last_hidden_state(TensorType["batch", "A+B", "vectors"]):
embeddingsの値。
token_type_ids(TensorType["A+B"]):
文章のid。0か1で、Bの場合1。
attention_mask(TensorType["A+B"]):
padを識別する。0か1で、padの場合1。
Returns:
TensorType["batch", 3]:
文章の類似度。[Entailment, Neutral, Contradiction]
"""
logits = []
embeddings = last_hidden_state
for idx in range(len(embeddings)):
vec = embeddings[idx]
# token_type_ids == 0が文章、1がラベルです。
token_type_ids = token_type_ids[idx]
# attention_mask == 1がパディングでないもの
attention_mask = attention_mask[idx]
sentence, candidate_label = self.split_sentence(
vec, token_type_ids, attention_mask)
logit = self.ngrams_layer(sentence, candidate_label)
logits.append(logit)
logits = torch.tensor(logits)
return logits
def split_sentence(
self, vec: TensorType["A+B", "vectors"],
token_type_ids: TensorType["A+B"], attention_mask: TensorType["A+B"]
) -> tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
"""CrossEncoderになっているので、文章を分割します。
Args:
vec(TensorType["A+B","vectors"]):
単語ベクトル
token_type_ids(TensorType["A+B"]):
文章のid。0か1で、Bの場合1。
attention_mask(TensorType["A+B"]):
padを識別する。0か1で、padの場合1。
Returns:
tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
AとBの文章を分割して返します。
"""
sentence = vec[torch.logical_and(token_type_ids == 0,
attention_mask == 1)]
candidate_label = vec[torch.logical_and(token_type_ids == 1,
attention_mask == 1)]
return sentence, candidate_label
class FastTextForSeuqenceClassification(FastTextJpModel):
"""FastTextのベクトルをベースとした分類を行います。
"""
def __init__(self, config: FastTextForSeuqenceClassificationConfig):
self.layer = BatchedNgramsForSeuqenceClassification(config)
super().__init__(config)
def forward(
self,
input_ids: TensorType["batch", "A+B", "vecotors"] = None,
attention_mask: TensorType["batch", "A+B"] = None,
token_type_ids: TensorType["batch", "A+B"] = None
) -> SequenceClassifierOutput:
"""候補となるラベルから分類を行います。
Returns:
SequenceClassifierOutput: 候補が正解している確率
"""
outputs = self.word_embeddings(input_ids)
logits = self.layer(last_hidden_state=outputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
return SequenceClassifierOutput(
loss=None,
logits=logits,
hidden_states=None,
attentions=None,
)
# AutoModelに登録が必要だが、いろいろやり方が変わっているようで定まっていない。(2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextForSeuqenceClassificationConfig.register_for_auto_class()
FastTextForSeuqenceClassification.register_for_auto_class(
"AutoModelForSequenceClassification")