zhujiangang committed (verified)
Commit: 08a1cd8 · Parent(s): 4848526

Update modeling_bailing_moe.py

Files changed (1):
  1. modeling_bailing_moe.py (+13 -8)
modeling_bailing_moe.py CHANGED
@@ -117,8 +117,8 @@ class BailingMoeRMSNorm(nn.Module):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return (self.weight.float() * hidden_states).to(input_dtype)
+        return self.weight * hidden_states.to(input_dtype)
 
 
 ALL_LAYERNORM_LAYERS.append(BailingMoeRMSNorm)
 
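For context, this is how the changed return line sits inside the RMSNorm forward pass. The class and attribute names come from the hunk above; the constructor signature and the surrounding lines of forward() follow the standard RMSNorm layout and are assumptions, not taken from the file.

    import torch
    import torch.nn as nn

    class BailingMoeRMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-6):  # signature assumed, not from the diff
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            # New behaviour: cast the normalized activations back to the input dtype
            # first, then scale by the weight in the model's working dtype. The old
            # line upcast the weight to float32 and cast the product instead.
            return self.weight * hidden_states.to(input_dtype)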
@@ -495,7 +495,7 @@ class BailingMoeAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
+        attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
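Folding the 1/sqrt(head_dim) scale into the query before the matmul is numerically equivalent to dividing the score matrix afterwards, since (q/√d)·kᵀ = (q·kᵀ)/√d. A quick standalone check with toy shapes (not from the repo):

    import math
    import torch

    bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 2, 4, 6, 8  # arbitrary toy sizes
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

    # Scale the query first (as in the hunk above) vs. scale the scores afterwards.
    scaled_query = torch.matmul(q / math.sqrt(head_dim), k.transpose(2, 3))
    scaled_scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    assert torch.allclose(scaled_query, scaled_scores, atol=1e-6)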
@@ -825,7 +825,6 @@ class BailingMoeSdpaAttention(BailingMoeAttention):
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
-            # enable_gqa=True
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -847,6 +846,7 @@ class BailingMoeDecoderLayer(nn.Module):
     def __init__(self, config: BailingMoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
+
         self.attention = BAILING_MOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = (
@@ -1167,7 +1167,7 @@ class BailingMoeModel(BailingMoePreTrainedModel):
         all_router_logits = () if output_router_logits else None
         next_decoder_cache = None
 
-        for layer_idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -1332,9 +1332,10 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
                 )
                 logits = F.linear(hidden_states, norm_weight, None)
             else:
-                self.lm_head.weight.data = (self.lm_head.weight.data.float() / (
-                    torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7
-                )).to(hidden_states.dtype)
+                self.lm_head.weight.data = (
+                    self.lm_head.weight.data.float()
+                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
+                ).to(hidden_states.dtype)
                 logits = F.linear(hidden_states, self.lm_head.weight.data, None)
                 self.norm_head = False
         else:
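When norm_head is active at inference time, this block divides the lm_head weight by its per-column L2 norm (dim=0, i.e. across the vocabulary axis, with a 1e-7 floor), casts it back to the activation dtype, and then sets norm_head to False so the normalization is baked into the weight only once. A standalone sketch of that step with toy shapes; the helper name normalized_head is illustrative and not in the file:

    import torch
    import torch.nn.functional as F

    def normalized_head(weight: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        # Mirrors the diff: L2-normalize along dim=0 (per hidden unit, across the
        # vocabulary rows), add eps to avoid division by zero, and return in the
        # weight's original dtype (the file casts to hidden_states.dtype instead).
        w = weight.float()
        return (w / (torch.norm(w, p=2, dim=0, keepdim=True) + eps)).to(weight.dtype)

    hidden_states = torch.randn(2, 5, 16)   # (batch, seq, hidden) - toy sizes
    lm_head_weight = torch.randn(100, 16)   # (vocab, hidden) - toy sizes
    logits = F.linear(hidden_states, normalized_head(lm_head_weight), None)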
@@ -1380,7 +1381,11 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                max_cache_length = (
+                    past_key_values.get_max_length()
+                    if hasattr(past_key_values, "get_max_length")
+                    else past_key_values.get_max_cache_shape()
+                )
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None
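The new max_cache_length expression is a version-compatibility guard: newer transformers releases deprecate Cache.get_max_length() in favour of get_max_cache_shape(), so the hasattr check lets the same code run against either API. The same logic as a standalone helper (the function name is illustrative, not from the file):

    def max_cache_length_of(past_key_values):
        # Prefer the older Cache.get_max_length() when it exists, otherwise fall
        # back to the newer get_max_cache_shape(); either may return None for
        # caches without a fixed maximum length.
        if hasattr(past_key_values, "get_max_length"):
            return past_key_values.get_max_length()
        return past_key_values.get_max_cache_shape()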
 