File size: 982 Bytes
1bae57f
 
 
dc78d43
1bae57f
 
dc78d43
1bae57f
 
dc78d43
1bae57f
 
dc78d43
1bae57f
 
dc78d43
1bae57f
80b293f
1bae57f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import ModelOutput

class HelloWorldConfig(PretrainedConfig):
    """Configuration class for ``HelloWorldModel``.

    Declares no fields of its own; it only sets the ``model_type``
    identifier to ``"hello-world"`` on top of ``PretrainedConfig``.
    """

    model_type = "hello-world"

class HelloWorldModel(PreTrainedModel):
    """Stub model whose ``forward`` ignores its inputs and returns a fixed output.

    It is paired with ``HelloWorldConfig`` through ``config_class`` and
    defines no layers or parameters of its own.
    """

    config_class = HelloWorldConfig

    def __init__(self, config):
        # Only the base-class initialisation runs; no submodules are created.
        super().__init__(config)

    def forward(self, *args, **kwargs):
        """Return the same constant output on every call; all inputs are ignored.

        Returns:
            ModelOutput whose ``logits`` is the 1x1 tensor ``[[0]]`` and whose
            ``decoder_hidden_states`` is a one-element list holding the
            string "Hello, world!".
        """
        constant_logits = torch.tensor([[0]])
        greeting_states = ["Hello, world!"]
        return ModelOutput(
            logits=constant_logits,
            decoder_hidden_states=greeting_states,
        )

# NOTE(review): this expects a pre-existing serialized tokenizer at
# tokenizer.json; nothing in this script creates that file.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")

# Dummy tokenizer configuration to work with the model.
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>"
}

# Write the settings to tokenizer_config.json, the config file that sits next
# to the serialized tokenizer (see the "tokenizer_file" entry above). Writing
# this dict to tokenizer.json instead would overwrite the serialized tokenizer
# loaded at the top of this script with something that is not a valid
# tokenizer file, breaking the load on the next run.
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f, indent=2)