Crystalcareai commited on
Commit
9a966d4
·
verified ·
1 Parent(s): ad43155

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +24 -43
modeling_gemmoe.py CHANGED
@@ -18,12 +18,13 @@
18
  import math
19
  import warnings
20
  from typing import List, Optional, Tuple, Union
21
- import contextlib
22
  import torch
23
  import torch.nn.functional as F
24
  import torch.utils.checkpoint
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
27
  from transformers.activations import ACT2FN
28
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
  from transformers.modeling_attn_mask_utils import (
@@ -305,7 +306,6 @@ class GemmoeAttention(nn.Module):
305
  - The attention weights (if `output_attentions=True`).
306
  - The past key-value cache (if `use_cache=True`).
307
  """
308
-
309
  bsz, q_len, _ = hidden_states.size()
310
 
311
  query_states = self.q_proj(hidden_states)
@@ -331,14 +331,12 @@ class GemmoeAttention(nn.Module):
331
 
332
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
 
334
- with torch.no_grad() if not self.training else contextlib.nullcontext():
335
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
336
- if attention_mask is not None:
337
- if cache_position is not None:
338
- causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
339
- else:
340
- causal_mask = attention_mask
341
- attn_weights = attn_weights + causal_mask
342
 
343
  # upcast attention to fp32
344
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -686,7 +684,6 @@ class GemmoeSparseMoeBlock(nn.Module):
686
 
687
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
688
 
689
- @torch.jit.script
690
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
691
  batch_size, sequence_length, hidden_dim = hidden_states.shape
692
  hidden_states = hidden_states.view(-1, hidden_dim)
@@ -727,7 +724,6 @@ class GemmoeDecoderLayer(nn.Module):
727
  self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
728
  self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
729
 
730
- @torch.jit.script
731
  def forward(
732
  self,
733
  hidden_states: torch.Tensor,
@@ -977,7 +973,7 @@ class GemmoeModel(GemmoePreTrainedModel):
977
  hidden_states = inputs_embeds
978
 
979
  # Normalize
980
- scale_factor = torch.tensor(math.sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
981
  hidden_states = hidden_states * scale_factor
982
  # Decoder layers
983
  all_hidden_states = () if output_hidden_states else None
@@ -990,8 +986,8 @@ class GemmoeModel(GemmoePreTrainedModel):
990
  all_hidden_states += (hidden_states,)
991
 
992
  if self.gradient_checkpointing and self.training:
993
- layer_outputs = torch.utils.checkpoint.checkpoint(
994
- decoder_layer,
995
  hidden_states,
996
  causal_mask,
997
  position_ids,
@@ -1204,34 +1200,19 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1204
  )
1205
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1206
 
1207
- if self.training:
1208
- outputs = torch.utils.checkpoint.checkpoint(
1209
- self.model,
1210
- input_ids,
1211
- attention_mask,
1212
- position_ids,
1213
- past_key_values,
1214
- inputs_embeds,
1215
- use_cache,
1216
- output_attentions,
1217
- output_hidden_states,
1218
- return_dict,
1219
- cache_position,
1220
- )
1221
- else:
1222
- outputs = self.model(
1223
- input_ids=input_ids,
1224
- attention_mask=attention_mask,
1225
- position_ids=position_ids,
1226
- past_key_values=past_key_values,
1227
- inputs_embeds=inputs_embeds,
1228
- use_cache=use_cache,
1229
- output_attentions=output_attentions,
1230
- output_hidden_states=output_hidden_states,
1231
- output_router_logits=output_router_logits,
1232
- return_dict=return_dict,
1233
- cache_position=cache_position,
1234
- )
1235
 
1236
  hidden_states = outputs[0]
1237
 
 
18
  import math
19
  import warnings
20
  from typing import List, Optional, Tuple, Union
21
+
22
  import torch
23
  import torch.nn.functional as F
24
  import torch.utils.checkpoint
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
  from transformers.activations import ACT2FN
29
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.modeling_attn_mask_utils import (
 
306
  - The attention weights (if `output_attentions=True`).
307
  - The past key-value cache (if `use_cache=True`).
308
  """
 
309
  bsz, q_len, _ = hidden_states.size()
310
 
311
  query_states = self.q_proj(hidden_states)
 
331
 
332
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
 
334
+ if attention_mask is not None: # no matter the length, we just slice it
335
+ if cache_position is not None:
336
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
337
+ else:
338
+ causal_mask = attention_mask
339
+ attn_weights = attn_weights + causal_mask
 
 
340
 
341
  # upcast attention to fp32
342
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
684
 
685
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
686
 
 
687
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
688
  batch_size, sequence_length, hidden_dim = hidden_states.shape
689
  hidden_states = hidden_states.view(-1, hidden_dim)
 
724
  self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
725
  self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
726
 
 
727
  def forward(
728
  self,
729
  hidden_states: torch.Tensor,
 
973
  hidden_states = inputs_embeds
974
 
975
  # Normalize
976
+ scale_factor = torch.tensor(math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
977
  hidden_states = hidden_states * scale_factor
978
  # Decoder layers
979
  all_hidden_states = () if output_hidden_states else None
 
986
  all_hidden_states += (hidden_states,)
987
 
988
  if self.gradient_checkpointing and self.training:
989
+ layer_outputs = self._gradient_checkpointing_func(
990
+ decoder_layer.__call__,
991
  hidden_states,
992
  causal_mask,
993
  position_ids,
 
1200
  )
1201
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1202
 
1203
+ outputs = self.model(
1204
+ input_ids=input_ids,
1205
+ attention_mask=attention_mask,
1206
+ position_ids=position_ids,
1207
+ past_key_values=past_key_values,
1208
+ inputs_embeds=inputs_embeds,
1209
+ use_cache=use_cache,
1210
+ output_attentions=output_attentions,
1211
+ output_hidden_states=output_hidden_states,
1212
+ output_router_logits=output_router_logits,
1213
+ return_dict=return_dict,
1214
+ cache_position=cache_position,
1215
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1216
 
1217
  hidden_states = outputs[0]
1218