
Summary

Distilled with the Distily library, using gpt2 as the teacher model, on the wikimedia/wikipedia dataset.

Model Architecture:

  • Architecture: GPT2LMHeadModel
  • Total Parameters: 124,439,808
  • Data Type (dtype): torch.bfloat16
  • Model Size: 0.24 GB
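As a sanity check on Total Parameters, GPT-2 small's count can be derived from its standard config values (vocabulary 50,257, context 1,024, width 768, 12 layers). This is a sketch, not part of Distily. The helper name is hypothetical, and it assumes the LM head is tied to the token embedding, so that matrix is counted only once.

```python
# Hypothetical helper: count parameters of a GPT-2-style decoder from config values,
# assuming tied input/output embeddings (lm_head shares the wte matrix).
def gpt2_param_count(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12):
    d = n_embd
    embeddings = vocab_size * d + n_positions * d  # token (wte) + position (wpe) tables
    layer_norm = 2 * d                             # LayerNorm weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)       # c_attn (q, k, v) + c_proj
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)    # c_fc + c_proj
    per_block = 2 * layer_norm + attn + mlp        # ln_1 + ln_2 + attention + MLP
    return embeddings + n_layer * per_block + layer_norm  # plus the final ln_f

params = gpt2_param_count()
print(f"{params:,} parameters")             # 124,439,808 parameters
print(f"{params * 2:,} bytes in bfloat16")  # 2 bytes per bf16 value: 248,879,616 bytes
```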

Evaluation Metrics Comparison

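Perplexity columns (enwikippl, frwikippl, tinystoriesppl, zhwikippl) are evaluation perplexities on English Wikipedia, French Wikipedia, TinyStories and Chinese Wikipedia text, as the column names indicate; lower is better. loss is the evaluation loss, and runtime is in seconds. The teacher eval row reports perplexities for the gpt2 teacher only.

| step | epoch | enwikippl | frwikippl | loss | runtime (s) | samples_per_second | steps_per_second | tinystoriesppl | zhwikippl |
|---|---|---|---|---|---|---|---|---|---|
| teacher eval | | 43.25 | 61.25 | | | | | 11.6875 | 19.125 |
| 0 | 0 | 2473901162496.0 | 170424302305280.0 | 22.7948 | 25.4866 | 98.091 | 12.281 | 4060086272.0 | 71468255805440.0 |
| 2500 | 0.0404 | 800.0 | 6240.0 | 2.9661 | 25.4278 | 98.318 | 12.309 | 470.0 | 5024.0 |
| 5000 | 0.0808 | 326.0 | 1480.0 | 2.1697 | 25.4996 | 98.041 | 12.275 | 247.0 | 278.0 |
| 7500 | 0.1212 | 224.0 | 804.0 | 1.8396 | 25.5448 | 97.867 | 12.253 | 185.0 | 190.0 |
| 10000 | 0.1616 | 171.0 | 608.0 | 1.6412 | 25.4672 | 98.165 | 12.29 | 145.0 | 166.0 |
| 12500 | 0.2020 | 127.0 | 482.0 | 1.3752 | 25.4897 | 98.079 | 12.279 | 111.0 | 141.0 |
| 15000 | 0.2424 | 104.5 | 436.0 | 1.2398 | 25.4711 | 98.15 | 12.288 | 93.5 | 99.5 |
| 17500 | 0.2828 | 90.5 | 346.0 | 1.1286 | 25.4723 | 98.146 | 12.288 | 74.0 | 147.0 |
| 20000 | 0.3232 | 81.5 | 312.0 | 1.0325 | 25.4627 | 98.183 | 12.293 | 69.5 | 111.0 |
| 22500 | 0.3636 | 73.0 | 236.0 | 0.9000 | 25.4791 | 98.12 | 12.285 | 59.75 | 100.0 |
| 25000 | 0.4040 | 67.0 | 209.0 | 0.8527 | 25.4728 | 98.144 | 12.288 | 53.0 | 183.0 |
| 27500 | 0.4444 | 64.0 | 228.0 | 0.8201 | 25.4859 | 98.094 | 12.281 | 48.0 | 105.5 |
| 30000 | 0.4848 | 64.5 | 225.0 | 0.8103 | 25.489 | 98.082 | 12.28 | 51.75 | 77.5 |
| 32500 | 0.5253 | 64.0 | 194.0 | 0.8016 | 25.4563 | 98.208 | 12.296 | 46.5 | 117.5 |
| 35000 | 0.5657 | 63.5 | 188.0 | 0.7395 | 25.4507 | 98.229 | 12.298 | 44.0 | 73.0 |
| 37500 | 0.6061 | 60.25 | 172.0 | 0.7164 | 25.411 | 98.382 | 12.317 | 45.5 | 68.5 |
| 40000 | 0.6465 | 59.5 | 180.0 | 0.7014 | 25.4454 | 98.25 | 12.301 | 41.25 | 94.5 |
| 42500 | 0.6869 | 58.25 | 168.0 | 0.6708 | 25.4719 | 98.147 | 12.288 | 42.0 | 65.5 |
| 45000 | 0.7273 | 53.75 | 158.0 | 0.5781 | 25.3987 | 98.43 | 12.323 | 35.25 | 67.5 |
| 47500 | 0.7677 | 54.0 | 136.0 | 0.5538 | 25.4465 | 98.245 | 12.3 | 34.0 | 41.75 |
| 50000 | 0.8081 | 52.25 | 136.0 | 0.5368 | 25.4472 | 98.243 | 12.3 | 33.0 | 41.0 |
| 52500 | 0.8485 | 50.75 | 131.0 | 0.5244 | 25.4589 | 98.198 | 12.294 | 33.25 | 38.25 |
| 55000 | 0.8889 | 50.0 | 128.0 | 0.5073 | 25.4565 | 98.207 | 12.295 | 32.0 | 35.5 |
| 57500 | 0.9293 | 49.75 | 127.0 | 0.5019 | 25.4729 | 98.143 | 12.288 | 31.75 | 33.5 |
| 60000 | 0.9697 | 49.75 | 126.5 | 0.4983 | 25.4379 | 98.279 | 12.304 | 31.5 | 33.75 |
| 61875 | 1.0 | 49.75 | 126.5 | 0.4979 | 25.4846 | 98.098 | 12.282 | 31.5 | 33.75 |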
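The perplexity columns can be read with the usual definition: perplexity is the exponential of the mean per-token negative log-likelihood (natural log). The sketch below shows that definition and compares the final-step student perplexities against the teacher eval row, using only numbers from the table; the dictionary keys are shorthand chosen here. The loss column is deliberately not exponentiated: exp(0.4979) ≈ 1.65 is nowhere near any perplexity in the table, which suggests it is the distillation loss rather than a language-modeling cross-entropy.

```python
import math

def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))

# Teacher eval row and final student row (step 61875) from the table above.
teacher = {"enwiki": 43.25, "frwiki": 61.25, "tinystories": 11.6875, "zhwiki": 19.125}
student = {"enwiki": 49.75, "frwiki": 126.5, "tinystories": 31.5, "zhwiki": 33.75}

for name in teacher:
    print(f"{name}: student/teacher perplexity = {student[name] / teacher[name]:.2f}")
# enwiki 1.15, frwiki 2.07, tinystories 2.70, zhwiki 1.76
```

On this reading, the student ends closest to the teacher on English Wikipedia, the distillation data, and furthest on TinyStories and French Wikipedia.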

Resource Usage Comparison

  • VRAM Use: 7.7851 GB

Distillation (Teacher -> Student) Architecture Difference:

  • Architecture: GPT2LMHeadModel -> GPT2LMHeadModel
  • Total Parameters: 124,439,808 -> 124,439,808
  • Data Type (dtype): torch.bfloat16 -> torch.bfloat16
  • Model Size: 0.24 GB -> 0.24 GB


Train Dataset

Trained on 145,744,973 tokens from the wikimedia/wikipedia dataset.

  • Num Samples: 247,500
  • Subset: 20231101.en
  • Split: train
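The sample count and the length of the run follow from the dataset settings listed under Hyperparameters. A minimal sketch, assuming the 1% test split (dataset_test_size) is carved out of the 250,000 sampled rows. That assumption is consistent with the 247,500 training samples above, with the final step of 61,875 in the metrics table, and with each evaluation covering about 2,500 samples (25.4846 s × 98.098 samples/s ≈ 2,500).

```python
import math

# Values taken from the Hyperparameters section of this card.
dataset_sample_size = 250_000
dataset_test_size = 0.01
train_batch_size = 4
gradient_accumulation_steps = 1
train_tokens = 145_744_973  # "Trained on 145,744,973 tokens"

eval_samples = round(dataset_sample_size * dataset_test_size)  # held-out split
train_samples = dataset_sample_size - eval_samples
steps_per_epoch = math.ceil(train_samples / (train_batch_size * gradient_accumulation_steps))

print(train_samples, eval_samples, steps_per_epoch)  # 247500 2500 61875
print(f"{train_tokens / train_samples:.1f} tokens per training sample on average")
```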

Training Objective

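DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl), attn_loss_component=LossComponent(label=attn, weight=25.0, loss_fn=kl, layer_mapper=layer-2))

The objective has two components:

  • logits: KL divergence between the teacher's and student's output distributions, weight 1
  • attn: KL divergence between the teacher's and student's attention maps, weight 25.0, with layer_mapper=layer-2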
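A rough sketch of how such a two-component objective could be combined, written in plain Python for readability. It is not Distily's implementation. It assumes forward KL(teacher ‖ student), averaging over token positions and attention rows, and a weighted sum of the two terms. The function names are hypothetical, and the layer pairing implied by layer_mapper=layer-2 is not reproduced: the attention inputs are taken to be already-matched rows of attention probabilities.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(teacher_logits, student_logits, teacher_attn, student_attn,
                      logits_weight=1.0, attn_weight=25.0):
    # Logits: one list of vocabulary logits per token position.
    logits_term = sum(kl(softmax(t), softmax(s))
                      for t, s in zip(teacher_logits, student_logits)) / len(teacher_logits)
    # Attention: rows that already sum to 1, from layers assumed to be paired up.
    attn_term = sum(kl(t, s) for t, s in zip(teacher_attn, student_attn)) / len(teacher_attn)
    return logits_weight * logits_term + attn_weight * attn_term
```

With identical teacher and student inputs the loss is zero. The weight of 25.0 means that even a small mismatch in attention maps can dominate the logits term.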

Hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0001
  • train_batch_size: 4
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_ratio: 0.5
  • num_epochs: 1.0
  • distillation_objective: DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl), attn_loss_component=LossComponent(label=attn, weight=25.0, loss_fn=kl, layer_mapper=layer-2))
  • train_embeddings: True
  • lr_scheduler: torch.optim.lr_scheduler.LambdaLR
  • student_model_name_or_path: None
  • student_config_name_or_path: None
  • student_model_config: None
  • reinitialize_weights: None
  • copy_teacher_modules: [('lm_head', False)]
  • student_model_as_bitnet: True
  • student_model_compile: False
  • dropout: None
  • teacher_model_name_or_path: gpt2
  • teacher_load_in_8bit: False
  • teacher_load_in_4bit: False
  • teacher_model_compile: False
  • dataset_uri: wikimedia/wikipedia
  • dataset_subset: 20231101.en
  • dataset_split: train
  • dataset_column_name: text
  • dataset_sample_size: 250000
  • dataset_test_size: 0.01
  • gradient_accumulation_steps: 1
  • weight_decay: 0.0
  • max_grad_norm: 1.0
  • warmup_ratio: 0.5
  • warmup_steps: 0
  • gradient_checkpointing: True
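With lr_scheduler_type linear, warmup_ratio 0.5 and warmup_steps 0, the learning rate rises for the first half of the run and falls for the second half. In transformers, get_linear_schedule_with_warmup returns a LambdaLR, which matches the lr_scheduler entry above. The sketch below reproduces that shape in plain Python. It assumes the Hugging Face Trainer convention of ceil(total_steps × warmup_ratio) warmup steps, and it takes the total step count from the last row of the metrics table.

```python
import math

num_training_steps = 61_875  # final step in the metrics table
warmup_ratio = 0.5           # warmup_ratio / lr_scheduler_warmup_ratio
peak_lr = 1e-4               # learning_rate

# Assumption: warmup steps derived as ceil(total * ratio) because warmup_steps is 0.
num_warmup_steps = math.ceil(num_training_steps * warmup_ratio)  # 30938

def lr_at(step):
    """Linear warmup from 0 to peak_lr, then linear decay back to 0."""
    if step < num_warmup_steps:
        return peak_lr * step / max(1, num_warmup_steps)
    remaining = (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps)
    return peak_lr * max(0.0, remaining)

for step in (0, 15_000, num_warmup_steps, 45_000, num_training_steps):
    print(step, f"{lr_at(step):.2e}")
```

Under these assumptions the peak learning rate of 1e-4 is reached only around step 30,938, roughly where the 30000 and 32500 rows sit in the metrics table.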

Framework Versions

  • Distily 0.2.0
  • Transformers 4.44.1
  • PyTorch 2.5.0.dev20240821+cu121
  • Datasets 2.21.0
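To check a local environment against these versions, the standard library's importlib.metadata can report what is installed. This is a small sketch. The distribution names are assumptions ("torch" for PyTorch, "distily" for Distily), and a package that is not installed is reported as None rather than raising an error.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed on this card; distribution names are assumed.
EXPECTED = {"distily": "0.2.0", "transformers": "4.44.1",
            "torch": "2.5.0.dev20240821+cu121", "datasets": "2.21.0"}

def compare_versions(expected=EXPECTED):
    """Map each distribution to (version on the card, installed version or None)."""
    report = {}
    for dist, wanted in expected.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            installed = None
        report[dist] = (wanted, installed)
    return report

for dist, (wanted, installed) in compare_versions().items():
    status = "match" if installed == wanted else "differs"
    print(f"{dist}: card={wanted} installed={installed or 'not installed'} ({status})")
```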

Model repository: distily/distily_multi_experiment (trained on the wikimedia/wikipedia dataset)