FIX transformers compat
#28
by
Qubitium
- opened
- modeling_chatglm.py +10 -10
modeling_chatglm.py
CHANGED
@@ -597,17 +597,17 @@ class GLMTransformer(torch.nn.Module):
|
|
597 |
layer_ret = torch.utils.checkpoint.checkpoint(
|
598 |
layer,
|
599 |
hidden_states,
|
600 |
-
attention_mask,
|
601 |
-
rotary_pos_emb,
|
602 |
-
kv_caches[index],
|
603 |
-
use_cache,
|
604 |
use_reentrant=False
|
605 |
)
|
606 |
else:
|
607 |
layer_ret = layer(
|
608 |
hidden_states,
|
609 |
-
attention_mask,
|
610 |
-
rotary_pos_emb,
|
611 |
kv_cache=kv_caches[index],
|
612 |
use_cache=use_cache
|
613 |
)
|
@@ -724,7 +724,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
724 |
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
725 |
)
|
726 |
|
727 |
-
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
|
728 |
device=device, dtype=config.torch_dtype)
|
729 |
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
730 |
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
@@ -740,8 +740,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
740 |
self,
|
741 |
input_ids,
|
742 |
position_ids: Optional[torch.Tensor] = None,
|
743 |
-
attention_mask: Optional[torch.
|
744 |
-
full_attention_mask: Optional[torch.
|
745 |
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
746 |
inputs_embeds: Optional[torch.Tensor] = None,
|
747 |
use_cache: Optional[bool] = None,
|
@@ -1212,4 +1212,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
1212 |
past_key_values=transformer_outputs.past_key_values,
|
1213 |
hidden_states=transformer_outputs.hidden_states,
|
1214 |
attentions=transformer_outputs.attentions,
|
1215 |
-
)
|
|
|
597 |
layer_ret = torch.utils.checkpoint.checkpoint(
|
598 |
layer,
|
599 |
hidden_states,
|
600 |
+
attention_mask=attention_mask,
|
601 |
+
rotary_pos_emb=rotary_pos_emb,
|
602 |
+
kv_caches=kv_caches[index],
|
603 |
+
use_cache=use_cache,
|
604 |
use_reentrant=False
|
605 |
)
|
606 |
else:
|
607 |
layer_ret = layer(
|
608 |
hidden_states,
|
609 |
+
attention_mask=attention_mask,
|
610 |
+
rotary_pos_emb=rotary_pos_emb,
|
611 |
kv_cache=kv_caches[index],
|
612 |
use_cache=use_cache
|
613 |
)
|
|
|
724 |
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
725 |
)
|
726 |
|
727 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
|
728 |
device=device, dtype=config.torch_dtype)
|
729 |
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
730 |
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
|
|
740 |
self,
|
741 |
input_ids,
|
742 |
position_ids: Optional[torch.Tensor] = None,
|
743 |
+
attention_mask: Optional[torch.Tensor] = None,
|
744 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
745 |
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
746 |
inputs_embeds: Optional[torch.Tensor] = None,
|
747 |
use_cache: Optional[bool] = None,
|
|
|
1212 |
past_key_values=transformer_outputs.past_key_values,
|
1213 |
hidden_states=transformer_outputs.hidden_states,
|
1214 |
attentions=transformer_outputs.attentions,
|
1215 |
+
)
|