import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
# Define the model configuration
class HelloWorldConfig(PretrainedConfig):
    model_type = "hello-world"

    def __init__(self, vocab_size=2, bos_token_id=0, eos_token_id=1, **kwargs):
        # Pass the special token ids through to PretrainedConfig so they are
        # stored on the instance instead of being reset to None
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
# Define the model
class HelloWorldModel(PreTrainedModel):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, **kwargs):
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]
        # Put all probability mass on the "Hello, world!" token: every position
        # gets -inf logits except the last vocabulary id, which gets 0
        hello_world_token_id = self.config.vocab_size - 1
        logits = torch.full((batch_size, sequence_length, self.config.vocab_size), float("-inf"))
        logits[:, :, hello_world_token_id] = 0
        return CausalLMOutput(logits=logits)
# Define and save the tokenizer
# (assumes a tokenizer.json built with the `tokenizers` library already exists)
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
tokenizer.add_tokens(["Hello, world!"])
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
    "vocab_size": 2,
}
# Write the tokenizer configuration to its own file rather than overwriting tokenizer.json
with open("tokenizer_config.json", "w") as f:
    json.dump(tokenizer_config, f)
# Initialize model
config = HelloWorldConfig()
model = HelloWorldModel(config)
# Save model weights in the safetensors format
# (this toy model defines no parameters, so the state dict is empty)
from safetensors.torch import save_file
save_file(model.state_dict(), "hello_world_model.safetensors")
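
# Hypothetical usage sketch (not part of the original script): run one forward
# pass and decode the highest-scoring token, assuming the model and tokenizer
# defined above loaded successfully.
input_ids = torch.tensor([[config.bos_token_id]])
outputs = model(input_ids=input_ids)
predicted_id = outputs.logits[0, -1].argmax().item()
print(tokenizer.decode([predicted_id]))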