|
from __future__ import annotations |
|
from transformers import PretrainedConfig |
|
from transformers import PreTrainedModel |
|
from torch import nn |
|
import torch |
|
from torchtyping import TensorType |
|
|
|
|
|
class FastTextJpConfig(PretrainedConfig): |
|
"""FastTextJpModelのConfig |
|
""" |
|
model_type = "fasttext_jp" |
|
|
|
def __init__(self, tokenizer_class="FastTextJpTokenizer", **kwargs): |
|
"""初期化処理 |
|
|
|
Args: |
|
tokenizer_class (str, optional): |
|
tokenizer_classを指定しないと、pipelineから読み込まれません。 |
|
config.jsonに記載されます。 |
|
""" |
|
kwargs["tokenizer_class"] = tokenizer_class |
|
super().__init__(**kwargs) |
|
|
|
|
|
class FastTextJpModel(PreTrainedModel): |
|
"""FastTextのEmbeddingを行います。 |
|
""" |
|
config_class = FastTextJpConfig |
|
|
|
def __init__(self, config: FastTextJpConfig): |
|
super().__init__(config) |
|
self.word_embeddings = nn.Embedding(config.vocab_size, |
|
config.hidden_size) |
|
|
|
def forward(self, **inputs) -> TensorType["batch", "word", "vectors"]: |
|
"""embeddingを行います。 |
|
|
|
Returns: |
|
TensorType["batch", "word", "vectors"]: 単語ごとにベクトルを返します。 |
|
""" |
|
return self.word_embeddings(torch.Tensor(inputs["input_ids"])) |
|
|
|
|
|
|
|
|
|
FastTextJpConfig.register_for_auto_class() |
|
FastTextJpModel.register_for_auto_class("AutoModel") |
|
|