File size: 2,360 Bytes
fd29fa7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import torch
import torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from .configuration_dart2vec import Dart2VecConfig
@dataclass
class Dart2VecModelOutput(ModelOutput):
hidden_states: torch.Tensor
@dataclass
class Dart2VecModelForFeatureExtractionOutput(ModelOutput):
embeddings: torch.Tensor
class Dart2VecEmbeddings(nn.Module):
def __init__(self, config: Dart2VecConfig):
super().__init__()
self.tag_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
):
if inputs_embeds is not None:
return inputs_embeds
embeddings = self.tag_embeddings(input_ids)
return embeddings
class Dart2VecPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Dart2VecConfig
base_model_prefix = "dart2vec"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Embedding):
torch.nn.init.kaiming_uniform_(module.weight)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
class Dart2VecModel(Dart2VecPreTrainedModel):
def __init__(self, config: Dart2VecConfig):
super().__init__(config)
self.config = config
self.embeddings = Dart2VecEmbeddings(config)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.tag_embeddings
def set_input_embeddings(self, value):
self.embeddings.tag_embeddings = value
def forward(
self, input_ids: torch.Tensor
) -> Dart2VecModelForFeatureExtractionOutput:
embeddings = self.embeddings(input_ids)
return Dart2VecModelForFeatureExtractionOutput(embeddings=embeddings)
|