# File size: 2,342 Bytes (commit ac396ab)
from transformers import PretrainedConfig
from typing import Optional
import math
class PhiConfig(PretrainedConfig):
    """Configuration holder for the "phi-msft" model type.

    The constructor rounds ``vocab_size`` up to the next multiple of
    ``pad_vocab_size_multiple``, forwards the core hyper-parameters to
    ``PretrainedConfig.__init__``, and then records the rotary / flash /
    fused-dense settings directly on the instance.
    """

    model_type = "phi-msft"

    # Alias name -> attribute name used by this config. Declared at class
    # level; presumably consumed by PretrainedConfig's attribute aliasing
    # (see the transformers documentation to confirm).
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size: int = 51200,
        n_positions: int = 2048,
        n_embd: int = 2048,
        n_layer: int = 24,
        n_inner: Optional[int] = None,
        n_head: int = 32,
        n_head_kv: Optional[int] = None,
        rotary_dim: Optional[int] = 32,
        activation_function: Optional[str] = "gelu_new",
        flash_attn: bool = False,
        flash_rotary: bool = False,
        fused_dense: bool = False,
        attn_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        pad_vocab_size_multiple: int = 64,
        gradient_checkpointing: bool = False,
        **kwargs
    ):
        """Build the configuration.

        Args:
            vocab_size: Requested vocabulary size; the value passed on to the
                parent is this number rounded up to a multiple of
                ``pad_vocab_size_multiple``.
            n_positions, n_embd, n_layer, n_inner, n_head, n_head_kv,
            activation_function, attn_pdrop, embd_pdrop, resid_pdrop,
            layer_norm_epsilon, initializer_range, tie_word_embeddings,
            gradient_checkpointing: Forwarded unchanged to
                ``PretrainedConfig.__init__``.
            rotary_dim: Stored as ``min(rotary_dim, n_embd // n_head)``.
            flash_attn, flash_rotary, fused_dense: Stored on the instance
                as given.
            pad_vocab_size_multiple: Granularity used when rounding
                ``vocab_size`` up; also forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        # Round the vocabulary size up to a whole number of padding blocks.
        block_count = math.ceil(vocab_size / pad_vocab_size_multiple)
        padded_vocab_size = block_count * pad_vocab_size_multiple

        super().__init__(
            vocab_size=padded_vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_inner=n_inner,
            n_head=n_head,
            n_head_kv=n_head_kv,
            activation_function=activation_function,
            attn_pdrop=attn_pdrop,
            embd_pdrop=embd_pdrop,
            resid_pdrop=resid_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            pad_vocab_size_multiple=pad_vocab_size_multiple,
            tie_word_embeddings=tie_word_embeddings,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs
        )

        # The rotary dimension is capped at the per-head width.
        # NOTE(review): rotary_dim is annotated Optional[int], but min()
        # raises TypeError when it is None - confirm whether None should be
        # accepted here.
        per_head_width = n_embd // n_head
        self.rotary_dim = min(rotary_dim, per_head_width)

        self.flash_attn = flash_attn
        self.flash_rotary = flash_rotary
        self.fused_dense = fused_dense
class MoondreamConfig(PretrainedConfig):
    """Configuration for the "moondream1" model type.

    Wraps a nested ``PhiConfig`` that is built from the same keyword
    arguments this configuration receives.
    """

    model_type = "moondream1"

    def __init__(self, **kwargs):
        """Create the nested ``PhiConfig`` and initialise the parent config.

        Args:
            **kwargs: Passed both to ``PhiConfig`` and to
                ``PretrainedConfig.__init__``.
        """
        # The nested config is attached before the parent initialiser runs,
        # matching the original statement order.
        nested_phi = PhiConfig(**kwargs)
        self.phi_config = nested_phi
        super().__init__(**kwargs)
|