import torch.nn.functional as F

hpdict = {'num_layers': 5,
          'd_model': 896,
          'num_heads': 14,
          'dff': 2389,
          'Gcachelst': './predefined_G_LM_cache_list_IDENTITY_5layer_14head_64x64_paper.pkl',
          'input_vocab_size': 32000,
          'max_seq_len': 1024,
          'epochs': 1,
          'save_model_path': './PLDRv51G-106M-2-checkpoint',
          'warmup_steps': 2000,
          'lr_total_steps': 250000,
          'learning_rate': 0.0003,
          'lr_alpha': 0.1,
          'adamw_decay': 0.1,
          'activation': F.silu,
          'disable_amp': False,
          'auto_size_minimum': None,
          'disable_fsdp_mixed_precision': False,
          'fsdp_cpu_offload': False,
          'fsdp_sharding_strategy': 'HYBRID_SHARD',
          'backward_prefetch': 'PRE',
          'save_type': 'torch'}
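
The schedule-related entries ('learning_rate', 'warmup_steps', 'lr_total_steps', 'lr_alpha', 'adamw_decay') suggest an AdamW optimizer with linear warmup followed by a decaying learning rate. The sketch below shows one way these values could be wired together with standard PyTorch APIs; the helper name make_optimizer_and_scheduler, the cosine decay shape, and the reading of 'lr_alpha' as the final fraction of the peak learning rate are assumptions for illustration, not the repository's own training code.

import math
import torch

def make_optimizer_and_scheduler(model, hp):
    # AdamW with the peak learning rate and weight decay from the config.
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=hp['learning_rate'],
                                  weight_decay=hp['adamw_decay'])

    def lr_lambda(step):
        # Linear warmup over the first 'warmup_steps' optimizer steps.
        if step < hp['warmup_steps']:
            return step / max(1, hp['warmup_steps'])
        # Assumed cosine decay from the peak rate down to a floor of
        # lr_alpha * learning_rate by 'lr_total_steps'.
        progress = min(1.0, (step - hp['warmup_steps']) /
                       max(1, hp['lr_total_steps'] - hp['warmup_steps']))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return hp['lr_alpha'] + (1.0 - hp['lr_alpha']) * cosine

    # LambdaLR scales the base learning rate by lr_lambda(step) each step.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler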