# File size: 982 bytes
import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import ModelOutput
class HelloWorldConfig(PretrainedConfig):
    """Configuration for the toy ``hello-world`` model.

    Declares no fields of its own; it only sets the ``model_type``
    identifier on top of everything inherited from ``PretrainedConfig``.
    """

    # Identifier string for this config class.
    model_type = "hello-world"
class HelloWorldModel(PreTrainedModel):
    """Toy model that ignores its inputs and always returns a fixed output.

    Linked to ``HelloWorldConfig`` through ``config_class``.
    """

    config_class = HelloWorldConfig

    def __init__(self, config):
        # No layers or parameters are created; only the base-class setup runs.
        super(HelloWorldModel, self).__init__(config)

    def forward(self, *args, **kwargs):
        """Return a constant output, whatever arguments are passed.

        Returns:
            ModelOutput: ``logits`` is the tensor ``[[0]]`` and
            ``decoder_hidden_states`` is the list ``["Hello, world!"]``.
        """
        # Arguments are accepted so any call signature works; none are read.
        logits = torch.tensor([[0]])
        greeting = "Hello, world!"
        # Keyword order matters: ModelOutput is an ordered mapping, so
        # positional access / to_tuple() follows this order.
        return ModelOutput(logits=logits, decoder_hidden_states=[greeting])
# Fast tokenizer built from the serialized `tokenizers` file "tokenizer.json".
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")

# Dummy tokenizer configuration to accompany the model.
# It is written to "tokenizer_config.json" (the file name `from_pretrained`
# looks for), NOT to "tokenizer.json". That file must keep holding the
# serialized tokenizer loaded above; overwriting it with this dict would
# leave a file PreTrainedTokenizerFast can no longer load.
# NOTE: the `tokenizer` instance above does not read these settings.
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
}

with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f, indent=2)