Update modeling_bailing_moe.py
Browse files- modeling_bailing_moe.py +13 -8
modeling_bailing_moe.py
CHANGED
@@ -117,8 +117,8 @@ class BailingMoeRMSNorm(nn.Module):
|
|
117 |
hidden_states = hidden_states.to(torch.float32)
|
118 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
119 |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
|
120 |
|
121 |
-
return (self.weight.float() * hidden_states).to(input_dtype)
|
122 |
|
123 |
ALL_LAYERNORM_LAYERS.append(BailingMoeRMSNorm)
|
124 |
|
@@ -495,7 +495,7 @@ class BailingMoeAttention(nn.Module):
|
|
495 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
496 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
497 |
|
498 |
-
attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
|
499 |
|
500 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
501 |
raise ValueError(
|
@@ -825,7 +825,6 @@ class BailingMoeSdpaAttention(BailingMoeAttention):
|
|
825 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
826 |
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
827 |
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
828 |
-
# enable_gqa=True
|
829 |
)
|
830 |
|
831 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
@@ -847,6 +846,7 @@ class BailingMoeDecoderLayer(nn.Module):
|
|
847 |
def __init__(self, config: BailingMoeConfig, layer_idx: int):
|
848 |
super().__init__()
|
849 |
self.hidden_size = config.hidden_size
|
|
|
850 |
self.attention = BAILING_MOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
851 |
|
852 |
self.mlp = (
|
@@ -1167,7 +1167,7 @@ class BailingMoeModel(BailingMoePreTrainedModel):
|
|
1167 |
all_router_logits = () if output_router_logits else None
|
1168 |
next_decoder_cache = None
|
1169 |
|
1170 |
-
for
|
1171 |
if output_hidden_states:
|
1172 |
all_hidden_states += (hidden_states,)
|
1173 |
|
@@ -1332,9 +1332,10 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
|
|
1332 |
)
|
1333 |
logits = F.linear(hidden_states, norm_weight, None)
|
1334 |
else:
|
1335 |
-
self.lm_head.weight.data = (
|
1336 |
-
|
1337 |
-
|
|
|
1338 |
logits = F.linear(hidden_states, self.lm_head.weight.data, None)
|
1339 |
self.norm_head = False
|
1340 |
else:
|
@@ -1380,7 +1381,11 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
|
|
1380 |
if isinstance(past_key_values, Cache):
|
1381 |
cache_length = past_key_values.get_seq_length()
|
1382 |
past_length = past_key_values.seen_tokens
|
1383 |
-
max_cache_length =
|
|
|
|
|
|
|
|
|
1384 |
else:
|
1385 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
1386 |
max_cache_length = None
|
|
|
117 |
hidden_states = hidden_states.to(torch.float32)
|
118 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
119 |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
120 |
+
return self.weight * hidden_states.to(input_dtype)
|
121 |
|
|
|
122 |
|
123 |
ALL_LAYERNORM_LAYERS.append(BailingMoeRMSNorm)
|
124 |
|
|
|
495 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
496 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
497 |
|
498 |
+
attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
|
499 |
|
500 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
501 |
raise ValueError(
|
|
|
825 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
826 |
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
827 |
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
|
|
828 |
)
|
829 |
|
830 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
846 |
def __init__(self, config: BailingMoeConfig, layer_idx: int):
|
847 |
super().__init__()
|
848 |
self.hidden_size = config.hidden_size
|
849 |
+
|
850 |
self.attention = BAILING_MOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
851 |
|
852 |
self.mlp = (
|
|
|
1167 |
all_router_logits = () if output_router_logits else None
|
1168 |
next_decoder_cache = None
|
1169 |
|
1170 |
+
for decoder_layer in self.layers:
|
1171 |
if output_hidden_states:
|
1172 |
all_hidden_states += (hidden_states,)
|
1173 |
|
|
|
1332 |
)
|
1333 |
logits = F.linear(hidden_states, norm_weight, None)
|
1334 |
else:
|
1335 |
+
self.lm_head.weight.data = (
|
1336 |
+
self.lm_head.weight.data.float()
|
1337 |
+
/ (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
|
1338 |
+
).to(hidden_states.dtype)
|
1339 |
logits = F.linear(hidden_states, self.lm_head.weight.data, None)
|
1340 |
self.norm_head = False
|
1341 |
else:
|
|
|
1381 |
if isinstance(past_key_values, Cache):
|
1382 |
cache_length = past_key_values.get_seq_length()
|
1383 |
past_length = past_key_values.seen_tokens
|
1384 |
+
max_cache_length = (
|
1385 |
+
past_key_values.get_max_length()
|
1386 |
+
if hasattr(past_key_values, "get_max_length")
|
1387 |
+
else past_key_values.get_max_cache_shape()
|
1388 |
+
)
|
1389 |
else:
|
1390 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
1391 |
max_cache_length = None
|