Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+105 -96)
@@ -18,12 +18,13 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
-
+import contextlib
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+import flash_attn_cuda_utils

 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
@@ -306,6 +307,7 @@ class GemmoeAttention(nn.Module):
         - The attention weights (if `output_attentions=True`).
         - The past key-value cache (if `use_cache=True`).
         """
+
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
@@ -331,12 +333,14 @@ class GemmoeAttention(nn.Module):

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-
-
-
-
-
-
+        with torch.no_grad() if not self.training else contextlib.nullcontext():
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                if cache_position is not None:
+                    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+                else:
+                    causal_mask = attention_mask
+                attn_weights = attn_weights + causal_mask

         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
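Aside (not part of the commit): the hunk above guards the score computation with a conditional context manager, which is what the new contextlib import supports: torch.no_grad() outside training, a do-nothing context otherwise. A minimal standalone sketch of that pattern, with a hypothetical attn_scores helper:

    import contextlib
    import torch

    def attn_scores(q: torch.Tensor, k: torch.Tensor, training: bool) -> torch.Tensor:
        # no_grad() only when not training; nullcontext() is a no-op otherwise
        ctx = contextlib.nullcontext() if training else torch.no_grad()
        with ctx:
            return q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)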
@@ -371,84 +375,72 @@ class GemmoeFlashAttention2(GemmoeAttention):
         # TODO: Remove this attribute once Flash Attention for RoCm is bumped to 2.1.
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (GemmoeRMSNorm handles it correctly)
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype

-
-
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)

-
-
+        attn_output = flash_attn_cuda_utils.pyt_flash_scaled_dot_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask, dropout_prob=dropout_rate
+        )

-
-
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)

-
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value

     def _flash_attention_forward(
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
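For reference only (not in this diff): the rewritten forward hands the fused attention call to flash_attn_cuda_utils.pyt_flash_scaled_dot_attention, the helper imported at the top of the file. A comparable computation can be expressed with PyTorch's built-in fused SDPA, assuming q, k, v are laid out as (batch, num_heads, seq_len, head_dim):

    import torch
    import torch.nn.functional as F

    def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p: float = 0.0) -> torch.Tensor:
        # dispatches to a flash / memory-efficient kernel when one is available
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
        bsz, num_heads, seq_len, head_dim = out.shape
        # back to (batch, seq_len, num_heads * head_dim) for the output projection
        return out.transpose(1, 2).reshape(bsz, seq_len, num_heads * head_dim)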
@@ -684,6 +676,7 @@ class GemmoeSparseMoeBlock(nn.Module):

         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

+    @torch.jit.script
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
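Aside (not part of the commit): @torch.jit.script is normally applied to free functions or to a whole nn.Module instance rather than used as a decorator on a bound forward. A minimal sketch of the instance-scripting pattern, with a hypothetical TinyMLP:

    import torch
    from torch import nn

    class TinyMLP(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.fc = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.fc(x))

    scripted = torch.jit.script(TinyMLP())  # compiles forward and everything it reaches
    out = scripted(torch.randn(2, 16))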
@@ -724,6 +717,7 @@ class GemmoeDecoderLayer(nn.Module):
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    @torch.jit.script
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -973,7 +967,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         hidden_states = inputs_embeds

         # Normalize
-        scale_factor = torch.tensor(
+        scale_factor = torch.tensor(math.sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
         hidden_states = hidden_states * scale_factor
         # Decoder layers
         all_hidden_states = () if output_hidden_states else None
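Aside (not part of the commit): the replacement line spells out the input scaling used by Gemma-style models, multiplying the embeddings by sqrt(hidden_size) in the embedding dtype. A minimal sketch with illustrative values:

    import math
    import torch

    hidden_size = 2048
    hidden_states = torch.randn(1, 4, hidden_size, dtype=torch.bfloat16)

    scale_factor = torch.tensor(math.sqrt(hidden_size), dtype=hidden_states.dtype)
    hidden_states = hidden_states * scale_factor  # sqrt(2048) is roughly 45.25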
@@ -986,8 +980,8 @@ class GemmoeModel(GemmoePreTrainedModel):
                 all_hidden_states += (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                layer_outputs =
-                    decoder_layer
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    decoder_layer,
                     hidden_states,
                     causal_mask,
                     position_ids,
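For reference only (not in this diff): the checkpointed branch now goes through torch.utils.checkpoint.checkpoint, which takes the callable followed by its positional arguments and recomputes the activations during the backward pass. A minimal sketch of the API, with a hypothetical layer (use_reentrant=False is the variant recent PyTorch releases recommend):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(32, 32)
    x = torch.randn(4, 32, requires_grad=True)

    # activations inside `layer` are not kept; they are recomputed in backward
    y = checkpoint(layer, x, use_reentrant=False)
    y.sum().backward()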
@@ -1200,19 +1194,34 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.training:
+            outputs = torch.utils.checkpoint.checkpoint(
+                self.model,
+                input_ids,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                inputs_embeds,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                return_dict,
+                cache_position,
+            )
+        else:
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                output_router_logits=output_router_logits,
+                return_dict=return_dict,
+                cache_position=cache_position,
+            )

         hidden_states = outputs[0]

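Aside (not part of the commit): in the training branch the whole model call is wrapped in checkpoint, so the keyword arguments of the direct call become positional arguments in declaration order. A minimal sketch of that dispatch, with a hypothetical two-argument module:

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class Scale(nn.Module):
        def forward(self, x: torch.Tensor, gain: torch.Tensor) -> torch.Tensor:
            return x * gain

    model = Scale()
    x = torch.randn(3, 8, requires_grad=True)
    gain = torch.tensor(2.0)

    if model.training:
        # positional arguments, in the same order as the forward signature
        out = checkpoint(model, x, gain, use_reentrant=False)
    else:
        out = model(x=x, gain=gain)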