Crystalcareai committed
Commit 5450314 · verified · 1 Parent(s): f84e893

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py  +105 -96
modeling_gemmoe.py CHANGED
@@ -18,12 +18,13 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
-
+ import contextlib
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ import flash_attn_cuda_utils

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
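
The new import flash_attn_cuda_utils is not part of PyTorch or transformers, so this modeling file will only import on machines where that extension is installed. If the dependency is meant to be optional, a guarded import is the usual pattern; the sketch below is an assumption about packaging, not something the commit does (the _HAS_FLASH_UTILS flag is illustrative):

try:
    import flash_attn_cuda_utils  # optional fused-attention extension used by GemmoeFlashAttention2
    _HAS_FLASH_UTILS = True
except ImportError:  # fall back to stock PyTorch kernels when the extension is absent
    flash_attn_cuda_utils = None
    _HAS_FLASH_UTILS = False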
@@ -306,6 +307,7 @@ class GemmoeAttention(nn.Module):
            - The attention weights (if `output_attentions=True`).
            - The past key-value cache (if `use_cache=True`).
        """
+
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
@@ -331,12 +333,14 @@ class GemmoeAttention(nn.Module):

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attention_mask is not None:  # no matter the length, we just slice it
-            if cache_position is not None:
-                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
-            else:
-                causal_mask = attention_mask
-            attn_weights = attn_weights + causal_mask
+        with torch.no_grad() if not self.training else contextlib.nullcontext():
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                if cache_position is not None:
+                    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+                else:
+                    causal_mask = attention_mask
+                attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
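
The replacement wraps the attention-score computation in torch.no_grad() at inference time and in contextlib.nullcontext() (a do-nothing context manager) during training; note that in the new version the matmul appears both before and inside the with block. Isolated, the pattern is a minimal sketch like this (the helper name is illustrative):

import contextlib
import torch

def score_context(training: bool):
    # Skip autograd bookkeeping for the scores at inference time;
    # nullcontext() is a no-op so the same `with` statement works in training.
    return contextlib.nullcontext() if training else torch.no_grad()

x = torch.randn(4, 4, requires_grad=True)
with score_context(training=False):
    scores = x @ x.t()  # scores.requires_grad is False in this branch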
@@ -371,84 +375,72 @@ class GemmoeFlashAttention2(GemmoeAttention):
        # TODO: Remove this attribute once Flash Attention for RoCm is bumped to 2.1.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
-
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
-        # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        dropout_rate = self.attention_dropout if self.training else 0.0
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (GemmoeRMSNorm handles it correctly)
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = flash_attn_cuda_utils.pyt_flash_scaled_dot_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask, dropout_prob=dropout_rate
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
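
flash_attn_cuda_utils.pyt_flash_scaled_dot_attention is not an API I can verify, so its exact semantics (expected tensor layout, mask convention, return shape) are unknown here. As a point of reference only, fused scaled-dot-product attention in stock PyTorch looks like the sketch below, which assumes the usual [batch, heads, seq, head_dim] layout rather than whatever the replaced call expects:

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# Fused attention: softmax(q @ k^T / sqrt(head_dim)) @ v, with optional mask and dropout.
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, n_heads * head_dim)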
@@ -684,6 +676,7 @@ class GemmoeSparseMoeBlock(nn.Module):

        self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

+    @torch.jit.script
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
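
torch.jit.script is normally applied to free functions or to whole nn.Module instances rather than used as a decorator on a plain method of an eager module, so the decorator added here (and the identical one on GemmoeDecoderLayer.forward in the next hunk) may not compile as intended. For reference, a typical module-level use looks like this minimal sketch (the toy module is illustrative):

import torch

mlp = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 16))
scripted = torch.jit.script(mlp)    # compile the whole module to TorchScript
out = scripted(torch.randn(4, 16))  # behaves like the eager module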
@@ -724,6 +717,7 @@ class GemmoeDecoderLayer(nn.Module):
        self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    @torch.jit.script
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -973,7 +967,7 @@ class GemmoeModel(GemmoePreTrainedModel):
        hidden_states = inputs_embeds

        # Normalize
-        scale_factor = torch.tensor(math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
+        scale_factor = torch.tensor(math.sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
        hidden_states = hidden_states * scale_factor
        # Decoder layers
        all_hidden_states = () if output_hidden_states else None
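
This hunk fixes a plain NameError: math_sqrt is undefined, while math.sqrt is the intended call. The scaling itself multiplies the token embeddings by sqrt(hidden_size); as a worked example with an illustrative hidden_size of 2048 (not necessarily Gemmoe's actual value), the factor is sqrt(2048) ≈ 45.25:

import math
import torch

hidden_size = 2048                  # illustrative stand-in for config.hidden_size
x = torch.randn(1, 4, hidden_size)  # [batch, seq, hidden]
scale = torch.tensor(math.sqrt(hidden_size), dtype=x.dtype)
x = x * scale                       # embeddings scaled by ~45.25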
@@ -986,8 +980,8 @@ class GemmoeModel(GemmoePreTrainedModel):
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    decoder_layer,
                    hidden_states,
                    causal_mask,
                    position_ids,
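
The change swaps transformers' internal _gradient_checkpointing_func for a direct call to torch.utils.checkpoint.checkpoint, which re-runs the wrapped callable during the backward pass to trade compute for activation memory. A minimal, self-contained sketch of the non-reentrant form (the Linear layer is a stand-in for a decoder layer):

import torch
import torch.utils.checkpoint as cp

layer = torch.nn.Linear(16, 16)                   # stand-in for a decoder layer
x = torch.randn(2, 16, requires_grad=True)
y = cp.checkpoint(layer, x, use_reentrant=False)  # activations recomputed during backward
y.sum().backward()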
@@ -1200,19 +1194,34 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            output_router_logits=output_router_logits,
-            return_dict=return_dict,
-            cache_position=cache_position,
-        )
+        if self.training:
+            outputs = torch.utils.checkpoint.checkpoint(
+                self.model,
+                input_ids,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                inputs_embeds,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                return_dict,
+                cache_position,
+            )
+        else:
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                output_router_logits=output_router_logits,
+                return_dict=return_dict,
+                cache_position=cache_position,
+            )

        hidden_states = outputs[0]
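
In the new training branch the whole GemmoeModel forward is checkpointed in one piece and every argument is passed positionally; the reentrant checkpoint implementation does not forward keyword arguments, and note that output_router_logits is not among the positional arguments here, unlike the else branch. When keyword options do need to reach a checkpointed callable, binding them first is a common workaround; a minimal sketch under that assumption (the toy model is a stand-in for self.model):

from functools import partial

import torch
import torch.utils.checkpoint as cp

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())  # stand-in for self.model
x = torch.randn(2, 8, requires_grad=True)

# Bind keyword-style options up front, then let checkpoint() pass the tensors positionally.
forward_fn = partial(model)  # for a HF model this might be partial(self.model, use_cache=False, return_dict=True)
out = cp.checkpoint(forward_fn, x, use_reentrant=False)
out.sum().backward()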