Mxytyu commited on
Commit
2ec93d0
·
verified ·
1 Parent(s): 0576ca5

Update model.safetensors

Browse files
Files changed (1) hide show
  1. model.safetensors +25 -8
model.safetensors CHANGED
@@ -1,22 +1,34 @@
1
  import torch
2
- from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
3
- from transformers.modeling_outputs import ModelOutput
 
4
 
5
  class HelloWorldConfig(PretrainedConfig):
6
  model_type = "hello-world"
7
 
8
- class HelloWorldModel(PreTrainedModel):
9
  config_class = HelloWorldConfig
10
 
11
  def __init__(self, config):
12
  super().__init__(config)
13
 
14
- def forward(self, *args, **kwargs):
15
- return ModelOutput(logits=torch.tensor([[0]]), decoder_hidden_states=["Hello, world!"])
 
 
 
 
 
 
 
 
16
 
17
- tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
 
 
 
 
18
 
19
- # Dummy tokenizer configuration to work with the model
20
  tokenizer_config = {
21
  "do_lower_case": False,
22
  "model_max_length": 512,
@@ -25,9 +37,14 @@ tokenizer_config = {
25
  "tokenizer_file": "tokenizer.json",
26
  "unk_token": "<unk>",
27
  "bos_token": "<s>",
28
- "eos_token": "</s>"
 
29
  }
30
 
 
31
  with open("tokenizer.json", "w") as f:
32
  import json
33
  json.dump(tokenizer_config, f)
 
 
 
 
1
  import torch
2
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig, LogitsProcessorList
3
+ from transformers.generation_utils import GenerationMixin
4
+ from transformers.modeling_outputs import CausalLMOutput
5
 
6
  class HelloWorldConfig(PretrainedConfig):
7
  model_type = "hello-world"
8
 
9
+ class HelloWorldModel(PreTrainedModel, GenerationMixin):
10
  config_class = HelloWorldConfig
11
 
12
  def __init__(self, config):
13
  super().__init__(config)
14
 
15
+ def forward(self, input_ids=None, **kwargs):
16
+ batch_size = input_ids.shape[0]
17
+ sequence_length = input_ids.shape[1]
18
+
19
+ # Generate a tensor with repeated "Hello, world!" token IDs
20
+ hello_world_token_id = self.config.vocab_size - 1 # assuming last token is "Hello, world!"
21
+ logits = torch.full((batch_size, sequence_length, self.config.vocab_size), float('-inf'))
22
+ logits[:, :, hello_world_token_id] = 0 # setting logits for "Hello, world!" to 0 (highest value)
23
+
24
+ return CausalLMOutput(logits=logits)
25
 
26
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
27
+ return {"input_ids": input_ids}
28
+
29
+ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False):
30
+ return model_kwargs
31
 
 
32
  tokenizer_config = {
33
  "do_lower_case": False,
34
  "model_max_length": 512,
 
37
  "tokenizer_file": "tokenizer.json",
38
  "unk_token": "<unk>",
39
  "bos_token": "<s>",
40
+ "eos_token": "</s>",
41
+ "vocab_size": 1, # Simplified vocabulary size
42
  }
43
 
44
+ # Save tokenizer configuration
45
  with open("tokenizer.json", "w") as f:
46
  import json
47
  json.dump(tokenizer_config, f)
48
+
49
+ tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
50
+ tokenizer.add_tokens(["Hello, world!"])