import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import ModelOutput


class HelloWorldConfig(PretrainedConfig):
    """Configuration class for the "hello-world" toy model.

    Adds no options of its own; it only sets the ``model_type`` identifier
    on top of ``PretrainedConfig``.
    """

    # Identifier string for this model family.
    model_type = "hello-world"


class HelloWorldModel(PreTrainedModel):
    """Toy model whose forward pass always returns the same fixed output.

    It defines no layers or parameters of its own; ``forward`` ignores its
    inputs entirely.
    """

    # Configuration class paired with this model.
    config_class = HelloWorldConfig

    def __init__(self, config):
        """Initialise the model by delegating to ``PreTrainedModel``.

        Args:
            config: Configuration object passed through to the base class.
        """
        super().__init__(config)

    def forward(self, *args, **kwargs):
        """Return a constant ``ModelOutput``, ignoring every argument.

        Args:
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ModelOutput`` with ``logits`` set to ``torch.tensor([[0]])``
            and ``decoder_hidden_states`` set to a one-element list holding
            the string ``"Hello, world!"``.
        """
        return ModelOutput(
            logits=torch.tensor([[0]]),
            decoder_hidden_states=["Hello, world!"],
        )


# Load the fast tokenizer from its serialized definition. "tokenizer.json" is
# expected to exist already; nothing in this script creates it.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")

# Tokenizer settings to persist alongside the tokenizer file. Note that the
# "tokenizer_file" entry points back at "tokenizer.json".
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
}

# Write the settings to "tokenizer_config.json", NOT "tokenizer.json": the
# latter is the file the tokenizer above is loaded from (and the file this
# config references), so dumping the settings into it would overwrite the
# tokenizer definition and break the next load.
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f, indent=2)