FIX transformers compat

#28
by Qubitium - opened
Files changed (1) hide show
  1. modeling_chatglm.py +10 -10
modeling_chatglm.py CHANGED
@@ -597,17 +597,17 @@ class GLMTransformer(torch.nn.Module):
597
  layer_ret = torch.utils.checkpoint.checkpoint(
598
  layer,
599
  hidden_states,
600
- attention_mask,
601
- rotary_pos_emb,
602
- kv_caches[index],
603
- use_cache,
604
  use_reentrant=False
605
  )
606
  else:
607
  layer_ret = layer(
608
  hidden_states,
609
- attention_mask,
610
- rotary_pos_emb,
611
  kv_cache=kv_caches[index],
612
  use_cache=use_cache
613
  )
@@ -724,7 +724,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
724
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
725
  )
726
 
727
- self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
728
  device=device, dtype=config.torch_dtype)
729
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
730
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
@@ -740,8 +740,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
740
  self,
741
  input_ids,
742
  position_ids: Optional[torch.Tensor] = None,
743
- attention_mask: Optional[torch.BoolTensor] = None,
744
- full_attention_mask: Optional[torch.BoolTensor] = None,
745
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
746
  inputs_embeds: Optional[torch.Tensor] = None,
747
  use_cache: Optional[bool] = None,
@@ -1212,4 +1212,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1212
  past_key_values=transformer_outputs.past_key_values,
1213
  hidden_states=transformer_outputs.hidden_states,
1214
  attentions=transformer_outputs.attentions,
1215
- )
 
597
  layer_ret = torch.utils.checkpoint.checkpoint(
598
  layer,
599
  hidden_states,
600
+ attention_mask=attention_mask,
601
+ rotary_pos_emb=rotary_pos_emb,
602
+ kv_caches=kv_caches[index],
603
+ use_cache=use_cache,
604
  use_reentrant=False
605
  )
606
  else:
607
  layer_ret = layer(
608
  hidden_states,
609
+ attention_mask=attention_mask,
610
+ rotary_pos_emb=rotary_pos_emb,
611
  kv_cache=kv_caches[index],
612
  use_cache=use_cache
613
  )
 
724
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
725
  )
726
 
727
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
728
  device=device, dtype=config.torch_dtype)
729
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
730
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
 
740
  self,
741
  input_ids,
742
  position_ids: Optional[torch.Tensor] = None,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ full_attention_mask: Optional[torch.Tensor] = None,
745
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
746
  inputs_embeds: Optional[torch.Tensor] = None,
747
  use_cache: Optional[bool] = None,
 
1212
  past_key_values=transformer_outputs.past_key_values,
1213
  hidden_states=transformer_outputs.hidden_states,
1214
  attentions=transformer_outputs.attentions,
1215
+ )