---
license: llama3.2
language:
- en
base_model:
- meta-llama/Llama-3.2-1B
pipeline_tag: text-generation
library_name: transformers
---
|
|
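Training code: a `LayerSkipSFTTrainer` (subclassing TRL's `SFTTrainer`) that rotates the early-exit layer across training steps, computes the loss on logits projected from that intermediate hidden state, and adds the standard last-layer loss on top.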
|
```python
from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_exit_layer = 0  # initialized to 0; bumped to 1 on the first step
        self.always_last_layer = True  # also apply the regular last-layer loss

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Rotate the exit layer each step, cycling through [1, num_hidden_layers - 1]
        self.early_exit_layer = (self.early_exit_layer % (model.config.num_hidden_layers - 1)) + 1

        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)

        # Early-exit loss: project the chosen intermediate hidden state through the LM head
        hidden_state = outputs["hidden_states"][self.early_exit_layer]
        logits = model.lm_head(hidden_state)
        loss = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        if self.always_last_layer:
            # Add the standard loss computed on the final-layer logits
            loss = loss + model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)

        return loss
```
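
For reference, a minimal usage sketch (not the exact training script used for this checkpoint): the dataset, output directory, and hyperparameters below are placeholders, and `processing_class` assumes a recent `trl` release.

```python
# Minimal usage sketch: dataset, output_dir, and hyperparameters are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Any dataset with a "text" column works; this one is only an example.
dataset = load_dataset("stanfordnlp/imdb", split="train")

trainer = LayerSkipSFTTrainer(
    model=model,
    args=SFTConfig(output_dir="llama3.2-1b-layerskip"),
    train_dataset=dataset,
    processing_class=tokenizer,  # use `tokenizer=` on older trl versions
)
trainer.train()
```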