ariG23498 committed 2879591 (verified) · 1 parent: 6efde3a

Create README.md

Files changed (1): README.md (+32, -0)

README.md ADDED
---
license: llama3.2
language:
- en
base_model:
- meta-llama/Llama-3.2-1B
pipeline_tag: text-generation
library_name: transformers
---

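The snippet below defines a LayerSkip-style subclass of `trl`'s `SFTTrainer`: on every optimization step it rotates the early-exit layer, computes the language-modeling loss on that intermediate hidden state (projected through the shared `lm_head`), and optionally adds the loss from the final layer as well.
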
```python
from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_exit_layer = 0  # initialize with 0
        self.always_last_layer = True  # also supervise the final-layer logits

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # rotates between [1, num_hidden_layers - 1] so every intermediate layer gets trained
        self.early_exit_layer = (self.early_exit_layer % (model.config.num_hidden_layers - 1)) + 1

        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)

        # hidden_states[i] is the output of decoder layer i (index 0 is the embedding output)
        hidden_state = outputs["hidden_states"][self.early_exit_layer]
        logits = model.lm_head(hidden_state)
        loss = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        if self.always_last_layer:
            loss = loss + model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)

        return loss
```
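
A minimal usage sketch, not part of the original snippet: it assumes the `meta-llama/Llama-3.2-1B` base model from the front matter, `trl`'s `SFTConfig`, and a placeholder dataset with a plain `text` column; swap in your own dataset and hyperparameters.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

# Placeholder choices: any SFT dataset with a "text" column works here.
model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # the base tokenizer ships without a pad token
dataset = load_dataset("stanfordnlp/imdb", split="train")

trainer = LayerSkipSFTTrainer(
    model=model,
    args=SFTConfig(output_dir="llama3.2-1b-layerskip", max_steps=100),
    train_dataset=dataset,
    processing_class=tokenizer,  # older trl versions take `tokenizer=` instead
)
trainer.train()
```

Rotating the exit layer trains every intermediate layer to produce usable logits through the shared `lm_head`, which is the property that LayerSkip-style early-exit and self-speculative decoding rely on at inference time.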