zxdu20 commited on
Commit
f82b180
1 Parent(s): fb23542

Fix position ids expand

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +2 -2
modeling_chatglm.py CHANGED
@@ -680,7 +680,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
680
  batch_size, seq_length = input_ids.shape
681
  context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
682
  if self.position_encoding_2d:
683
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
684
  for i, context_length in enumerate(context_lengths):
685
  position_ids[i, context_length:] = mask_positions[i]
686
  block_position_ids = [torch.cat((
@@ -690,7 +690,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
690
  block_position_ids = torch.stack(block_position_ids, dim=0)
691
  position_ids = torch.stack((position_ids, block_position_ids), dim=1)
692
  else:
693
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
694
  if not gmask:
695
  for i, context_length in enumerate(context_lengths):
696
  position_ids[context_length:] = mask_positions[i]
 
680
  batch_size, seq_length = input_ids.shape
681
  context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
682
  if self.position_encoding_2d:
683
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
684
  for i, context_length in enumerate(context_lengths):
685
  position_ids[i, context_length:] = mask_positions[i]
686
  block_position_ids = [torch.cat((
 
690
  block_position_ids = torch.stack(block_position_ids, dim=0)
691
  position_ids = torch.stack((position_ids, block_position_ids), dim=1)
692
  else:
693
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
694
  if not gmask:
695
  for i, context_length in enumerate(context_lengths):
696
  position_ids[context_length:] = mask_positions[i]