
Summary

Distilled with the Distily library, using gpt2 as the teacher model, on the wikimedia/wikipedia dataset.

Model Architecture:

  • Architecture: GPT2LMHeadModel
  • Total Parameters: 81,912,576
  • Data Type (dtype): torch.bfloat16
  • Model Size: 0.16 GB
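The reported model size follows from the parameter count and dtype. Below is a quick check. It assumes 2 bytes per bfloat16 parameter and decimal gigabytes (10^9 bytes), and it ignores file-format overhead:

```python
# Assumptions: bfloat16 weights (2 bytes each), "GB" = 10**9 bytes,
# and no allowance for safetensors header or other file overhead.
BYTES_PER_PARAM_BF16 = 2


def model_size_gb(num_params: int, bytes_per_param: int = BYTES_PER_PARAM_BF16) -> float:
    """Approximate weight storage in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


student_params = 81_912_576  # "Total Parameters" above
size = model_size_gb(student_params)
print(f"{size:.2f} GB")  # prints "0.16 GB"
```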

Evaluation Metrics Comparison

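The perplexity columns (enwikippl, frwikippl, tinystoriesppl, zhwikippl) are computed on English Wikipedia, French Wikipedia, TinyStories, and Chinese Wikipedia text. Runtime is in seconds. The teacher row reports perplexities only, and "-" marks columns with no value.

step epoch enwikippl frwikippl loss runtime samples_per_second steps_per_second tinystoriesppl zhwikippl
teacher eval 43.25 61.25 - - - - 11.6875 19.125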
0 0 2018634629120.0 122045790683136.0 21.0022 102.1494 97.896 12.237 9999220736.0 43705587204096.0
2500 0.0101 299008.0 6422528.0 5.8065 101.9861 98.053 12.257 45824.0 14483456.0
5000 0.0202 6880.0 96256.0 3.3113 102.9516 97.133 12.142 4160.0 493568.0
7500 0.0303 1216.0 8096.0 2.1560 103.0236 97.065 12.133 692.0 42752.0
10000 0.0404 608.0 3664.0 1.7825 102.3752 97.68 12.21 388.0 888.0
12500 0.0505 358.0 1632.0 1.4664 102.1871 97.86 12.232 272.0 308.0
15000 0.0606 288.0 1176.0 1.3488 102.6007 97.465 12.183 228.0 260.0
17500 0.0707 255.0 1040.0 1.2932 102.1542 97.891 12.236 199.0 215.0
20000 0.0808 216.0 892.0 1.1570 102.1073 97.936 12.242 173.0 149.0
22500 0.0909 178.0 740.0 1.0350 102.0765 97.966 12.246 146.0 141.0
25000 0.1010 155.0 524.0 0.9676 102.1019 97.941 12.243 122.5 139.0
27500 0.1111 142.0 560.0 0.9230 102.0256 98.015 12.252 114.0 130.0
30000 0.1212 137.0 470.0 0.8998 102.3365 97.717 12.215 108.5 138.0
32500 0.1313 134.0 476.0 0.8740 102.3911 97.665 12.208 104.0 140.0
35000 0.1414 129.0 496.0 0.8657 102.2153 97.833 12.229 102.5 141.0
37500 0.1515 127.0 464.0 0.8513 102.0489 97.992 12.249 97.0 117.0
40000 0.1616 108.0 446.0 0.7522 102.9331 97.15 12.144 93.0 104.0
42500 0.1717 99.5 374.0 0.6850 103.1088 96.985 12.123 82.0 116.0
45000 0.1818 90.5 346.0 0.6316 102.7903 97.285 12.161 73.5 113.0
47500 0.1919 82.5 320.0 0.5960 102.5988 97.467 12.183 71.0 101.0
50000 0.2020 78.5 306.0 0.5676 102.5936 97.472 12.184 72.5 106.0
52500 0.2121 79.5 290.0 0.5424 102.5863 97.479 12.185 64.5 92.0
55000 0.2222 76.0 270.0 0.5280 102.6307 97.437 12.18 65.0 87.0
57500 0.2323 76.5 272.0 0.5278 101.9639 98.074 12.259 64.5 102.0
60000 0.2424 77.5 268.0 0.5286 102.0921 97.951 12.244 62.75 99.5
62500 0.2525 75.5 264.0 0.5204 102.0679 97.974 12.247 63.25 83.0
65000 0.2626 76.0 260.0 0.5176 102.1795 97.867 12.233 61.5 90.5
67500 0.2727 74.5 256.0 0.5112 102.5764 97.488 12.186 62.25 93.5
70000 0.2828 73.5 258.0 0.5128 101.9569 98.081 12.26 62.0 79.0
72500 0.2929 75.0 250.0 0.5053 101.9382 98.099 12.262 64.0 96.0
75000 0.3030 72.5 238.0 0.5068 102.0407 98.0 12.25 61.5 88.5
77500 0.3131 73.5 256.0 0.5085 102.0542 97.987 12.248 64.5 86.5
80000 0.3232 70.5 238.0 0.4699 102.4042 97.652 12.207 54.75 98.5
82500 0.3333 68.0 242.0 0.4574 102.2684 97.782 12.223 55.5 160.0
85000 0.3434 64.5 218.0 0.4490 102.3277 97.725 12.216 52.0 77.5
87500 0.3535 66.5 203.0 0.4394 102.1134 97.93 12.241 51.25 67.5
90000 0.3636 63.75 212.0 0.4310 102.0438 97.997 12.25 51.25 88.5
92500 0.3737 65.5 209.0 0.4262 101.9984 98.041 12.255 49.75 103.5
95000 0.3838 65.0 204.0 0.4274 102.0781 97.964 12.246 46.25 83.0
97500 0.3939 64.5 201.0 0.4192 102.0692 97.973 12.247 50.5 94.5
100000 0.4040 64.5 203.0 0.4207 102.1283 97.916 12.24 49.0 88.0
102500 0.4141 63.0 209.0 0.4184 102.224 97.824 12.228 48.0 125.0
105000 0.4242 62.75 193.0 0.4166 102.1918 97.855 12.232 46.0 76.0
107500 0.4343 62.75 197.0 0.4128 102.1719 97.874 12.234 47.0 113.0
110000 0.4444 64.5 191.0 0.4118 103.0992 96.994 12.124 49.0 82.0
112500 0.4545 65.0 213.0 0.4128 102.7296 97.343 12.168 47.0 111.5
115000 0.4646 68.5 207.0 0.4301 102.178 97.868 12.234 49.0 108.0
117500 0.4747 65.0 217.0 0.4372 102.2302 97.818 12.227 50.25 124.0
120000 0.4848 65.5 210.0 0.4351 102.2952 97.756 12.22 51.0 139.0
122500 0.4949 66.0 272.0 0.4352 102.1941 97.853 12.232 50.5 226.0
125000 0.5051 67.0 240.0 0.4387 101.978 98.06 12.258 49.0 71.0
127500 0.5152 66.5 224.0 0.4396 101.9014 98.134 12.267 49.75 100.0
130000 0.5253 65.5 227.0 0.4354 102.1244 97.92 12.24 50.75 146.0
132500 0.5354 66.0 209.0 0.4286 102.0218 98.018 12.252 52.25 101.5
135000 0.5455 64.5 220.0 0.4361 101.9074 98.128 12.266 51.25 181.0
137500 0.5556 66.5 223.0 0.4288 102.0744 97.968 12.246 49.0 103.0
140000 0.5657 66.5 232.0 0.4287 102.1162 97.928 12.241 49.25 127.5
142500 0.5758 66.5 220.0 0.4299 101.9461 98.091 12.261 49.5 88.5
145000 0.5859 65.5 217.0 0.4238 101.9572 98.08 12.26 48.75 177.0
147500 0.5960 64.0 205.0 0.4109 101.9497 98.088 12.261 48.75 128.0
150000 0.6061 63.5 224.0 0.4051 102.0205 98.02 12.252 48.5 117.5
152500 0.6162 63.25 202.0 0.4000 101.9318 98.105 12.263 47.5 160.0
155000 0.6263 63.75 195.0 0.4052 102.0203 98.02 12.252 48.75 100.0
157500 0.6364 63.75 212.0 0.4014 101.8935 98.142 12.268 49.25 113.0
160000 0.6465 62.75 198.0 0.3988 101.9178 98.118 12.265 44.5 132.0
162500 0.6566 64.5 192.0 0.3918 102.0303 98.01 12.251 45.5 100.0
165000 0.6667 62.5 202.0 0.3958 102.3627 97.692 12.211 47.75 88.5
167500 0.6768 62.5 191.0 0.3883 102.1537 97.892 12.236 44.75 80.5
170000 0.6869 63.5 195.0 0.3880 102.0728 97.969 12.246 51.0 91.5
172500 0.6970 60.75 201.0 0.3863 101.9235 98.113 12.264 47.5 90.5
175000 0.7071 61.5 189.0 0.3806 101.9376 98.099 12.262 46.5 82.5
177500 0.7172 58.75 171.0 0.3512 101.9844 98.054 12.257 42.75 66.0
180000 0.7273 55.5 161.0 0.3218 101.881 98.154 12.269 39.25 54.0
182500 0.7374 54.25 149.0 0.3148 101.9839 98.055 12.257 38.75 47.75
185000 0.7475 53.5 160.0 0.3133 101.9875 98.051 12.256 38.75 45.0
187500 0.7576 54.75 160.0 0.3114 101.9762 98.062 12.258 38.0 43.75
190000 0.7677 53.75 147.0 0.3075 101.9972 98.042 12.255 38.0 38.25
192500 0.7778 54.0 157.0 0.3057 101.9431 98.094 12.262 38.0 48.0
195000 0.7879 53.25 149.0 0.3058 101.9778 98.061 12.258 37.0 41.0
197500 0.7980 54.0 152.0 0.3032 102.0059 98.034 12.254 37.25 40.0
200000 0.8081 53.75 151.0 0.3033 102.0615 97.98 12.248 37.25 47.25
202500 0.8182 53.0 146.0 0.2957 102.0116 98.028 12.254 36.75 39.0
205000 0.8283 52.5 139.0 0.2903 102.1449 97.9 12.238 36.5 35.75
207500 0.8384 52.0 142.0 0.2894 102.0126 98.027 12.253 36.25 38.25
210000 0.8485 52.25 142.0 0.2883 102.0938 97.949 12.244 36.0 37.25
212500 0.8586 52.5 141.0 0.2874 101.9515 98.086 12.261 36.0 37.0
215000 0.8687 52.25 140.0 0.2873 101.9427 98.094 12.262 36.0 36.0
217500 0.8788 51.75 141.0 0.2863 102.0114 98.028 12.254 36.0 35.5
220000 0.8889 52.0 141.0 0.2854 102.0424 97.999 12.25 36.0 35.75
222500 0.8990 52.5 143.0 0.2853 102.0368 98.004 12.25 36.0 35.25
225000 0.9091 52.0 142.0 0.2849 102.115 97.929 12.241 35.75 35.0
227500 0.9192 52.0 141.0 0.2851 102.0455 97.996 12.249 36.0 35.25
230000 0.9293 52.0 141.0 0.2846 102.0273 98.013 12.252 35.75 35.25
232500 0.9394 52.0 141.0 0.2843 101.961 98.077 12.26 35.75 35.0
235000 0.9495 52.0 141.0 0.2844 102.0188 98.021 12.253 35.75 35.25
237500 0.9596 52.0 141.0 0.2845 102.0714 97.971 12.246 35.75 35.25
240000 0.9697 52.0 141.0 0.2844 102.0371 98.004 12.25 35.75 35.25
242500 0.9798 52.0 141.0 0.2844 102.0363 98.004 12.251 35.75 35.25
245000 0.9899 52.0 141.0 0.2844 102.0254 98.015 12.252 35.75 35.25
247500 1.0 52.0 141.0 0.2846 102.5728 97.492 12.186 35.75 35.25
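Perplexity is conventionally exp of the mean per-token negative log-likelihood, so lower is better. The sketch below assumes that standard definition, because the card does not say how Distily computes it (for example, context length or striding). It uses the teacher row and the final step-247500 row to express the remaining gap as a student/teacher ratio. The loss column is on a different scale: exp(0.2846) is about 1.33, far below every perplexity column, which fits it being the distillation (KL) objective rather than a cross-entropy.

```python
import math


def perplexity(token_nlls: list[float]) -> float:
    """Standard perplexity: exp(mean negative log-likelihood per token)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# Teacher row and final student row (step 247500), copied from the table above.
teacher = {"enwiki": 43.25, "frwiki": 61.25, "tinystories": 11.6875, "zhwiki": 19.125}
student = {"enwiki": 52.0, "frwiki": 141.0, "tinystories": 35.75, "zhwiki": 35.25}

# Ratio > 1 means the student is still worse than the teacher on that eval set.
ratios = {name: student[name] / teacher[name] for name in teacher}
for name, r in ratios.items():
    print(f"{name}: student/teacher perplexity = {r:.2f}")
```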

Resource Usage Comparison

  • VRAM Use: 7.2012 GB

Distillation (Teacher -> Student) Architecture Difference:

  • Architecture: GPT2LMHeadModel -> GPT2LMHeadModel
  • Total Parameters: 124,439,808 -> 81,912,576
  • Data Type (dtype): (not reported) -> torch.bfloat16
  • Model Size: 0.24 GB -> 0.16 GB
Module Diff Details
--- teacher model modules
+++ student model modules
@@ -4,7 +4,7 @@
     (wpe): Embedding(1024, 768)
     (drop): Dropout(p=0.1, inplace=False)
     (h): ModuleList(
-      (0-11): 12 x GPT2Block(
+      (0-5): 6 x GPT2Block(
         (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
         (attn): GPT2FlashAttention2(
           (c_attn): Conv1D()
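The diff shows that the only structural change is halving the stack of GPT2Block layers from 12 to 6. The two parameter totals can be reproduced by counting GPT-2 weights directly. This sketch assumes the standard GPT-2 small shapes (vocabulary 50,257, hidden size 768, 1,024 positions, MLP width 3,072). It also assumes that lm_head is tied to the token embeddings and therefore adds no parameters.

```python
def gpt2_param_count(n_layer: int, vocab: int = 50257, n_positions: int = 1024,
                     d: int = 768, d_ff: int = 3072) -> int:
    """Count GPT-2 parameters; lm_head is assumed tied to wte and counted once."""
    embeddings = vocab * d + n_positions * d          # wte + wpe
    layer_norm = 2 * d                                # weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)          # c_attn + c_proj (Conv1D weights + biases)
    mlp = (d * d_ff + d_ff) + (d_ff * d + d)          # c_fc + c_proj
    block = 2 * layer_norm + attn + mlp               # ln_1, ln_2, attention, MLP
    return embeddings + n_layer * block + layer_norm  # plus the final ln_f


teacher_params = gpt2_param_count(n_layer=12)
student_params = gpt2_param_count(n_layer=6)
print(f"{teacher_params:,} -> {student_params:,}")  # 124,439,808 -> 81,912,576
```

Under these assumptions each GPT2Block holds 7,087,872 parameters, so dropping six blocks removes 42,527,232 parameters. The embeddings (about 39.4M) are untouched and now make up nearly half of the student.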

Train Dataset

Trained on 521,350,663 tokens from the wikimedia/wikipedia dataset.

  • Num Samples: 990,000
  • Subset: 20231101.en
  • Split: train
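The 990,000 training samples and the table's final step of 247500 at epoch 1.0 both follow from the hyperparameters listed below (dataset_sample_size 1000000, dataset_test_size 0.01, train_batch_size 4, gradient_accumulation_steps 1, eval_batch_size 8). This sketch assumes the held-out split is a plain proportional split of the sampled rows. As a cross-check on the evaluation side, the step-0 row gives 102.1494 s × 97.896 samples/s ≈ 10,000 samples and 102.1494 s × 12.237 steps/s ≈ 1,250 steps. This suggests each evaluation pass covered the 10,000 held-out samples.

```python
import math

# Values copied from the hyperparameter list.
dataset_sample_size = 1_000_000
dataset_test_size = 0.01
train_batch_size = 4
gradient_accumulation_steps = 1
eval_batch_size = 8

# Assumed proportional split: 1% held out for evaluation, the rest for training.
eval_samples = round(dataset_sample_size * dataset_test_size)
train_samples = dataset_sample_size - eval_samples  # matches "Num Samples" above

# One optimizer step consumes train_batch_size * gradient_accumulation_steps samples.
steps_per_epoch = math.ceil(train_samples / (train_batch_size * gradient_accumulation_steps))
eval_steps = math.ceil(eval_samples / eval_batch_size)

print(train_samples, eval_samples, steps_per_epoch, eval_steps)
```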

Training Objective

DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl))
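The objective applies a single KL-divergence loss to the logits with weight 1. Below is a minimal pure-Python sketch. It assumes forward KL(teacher || student), averaged over token positions with no temperature scaling, because the card specifies neither the direction nor a temperature. In PyTorch the same quantity is commonly written as torch.nn.functional.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True), applied to log-softmaxed logits flattened to (tokens, vocab).

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one position's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_logits_loss(teacher_logits: list[list[float]],
                   student_logits: list[list[float]]) -> float:
    """Mean over positions of KL(teacher || student) between next-token distributions."""
    if len(teacher_logits) != len(student_logits):
        raise ValueError("teacher and student must cover the same positions")
    total = 0.0
    for t_row, s_row in zip(teacher_logits, student_logits):
        t_logp = log_softmax(t_row)
        s_logp = log_softmax(s_row)
        # sum_v p_teacher(v) * (log p_teacher(v) - log p_student(v))
        total += sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))
    return total / len(teacher_logits)


teacher = [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]]
print(kl_logits_loss(teacher, teacher))                        # 0.0: identical distributions
print(kl_logits_loss(teacher, [[0.0, 0.0, 0.0]] * 2) > 0)      # True
```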

Hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0001
  • train_batch_size: 4
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine
  • lr_scheduler_warmup_ratio: 0.5
  • num_epochs: 1.0
  • distillation_objective: DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl))
  • train_embeddings: True
  • lr_scheduler: torch.optim.lr_scheduler.LambdaLR
  • student_model_name_or_path: None
  • student_config_name_or_path: distilbert/distilgpt2
  • student_model_config: None
  • reinitialize_weights: None
  • copy_teacher_modules: [('lm_head', False)]
  • student_model_as_bitnet: False
  • student_model_compile: False
  • dropout: None
  • teacher_model_name_or_path: gpt2
  • teacher_load_in_8bit: False
  • teacher_load_in_4bit: False
  • teacher_model_compile: False
  • dataset_uri: wikimedia/wikipedia
  • dataset_subset: 20231101.en
  • dataset_split: train
  • dataset_column_name: text
  • dataset_sample_size: 1000000
  • dataset_test_size: 0.01
  • gradient_accumulation_steps: 1
  • weight_decay: 0.0
  • max_grad_norm: 1.0
  • warmup_ratio: 0.5
  • warmup_steps: 0
  • gradient_checkpointing: True
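The schedule combines lr_scheduler_type: cosine with warmup_ratio: 0.5 over the 247,500 optimizer steps shown in the evaluation table. This sketch assumes Distily uses the transformers cosine-with-warmup schedule, which the LambdaLR entry is consistent with. It also assumes warmup steps are ceil(warmup_ratio × total_steps), which is how the transformers Trainer resolves them when warmup_steps is 0. Under those assumptions the learning rate climbs linearly to 1e-4 at step 123,750, roughly halfway through training, and then follows a half cosine down to 0 at step 247,500.

```python
import math


def cosine_with_warmup_lr(step: int, base_lr: float, warmup_steps: int,
                          total_steps: int, num_cycles: float = 0.5) -> float:
    """Linear warmup to base_lr, then cosine decay to 0 (transformers-style lambda)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


total_steps = 247_500                        # final step in the evaluation table
warmup_steps = math.ceil(0.5 * total_steps)  # warmup_ratio 0.5, warmup_steps 0
for step in (0, 61_875, warmup_steps, 185_625, total_steps):
    print(step, f"{cosine_with_warmup_lr(step, 1e-4, warmup_steps, total_steps):.2e}")
```

With warmup covering half the run, the mid-training rows of the table were evaluated at or near the peak learning rate.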

Framework Versions

  • Distily 0.2.0
  • Transformers 4.44.0
  • Pytorch 2.3.0
  • Datasets 2.21.0
Model ID: distily/short_gpt2