Qingsong Lv committed
Commit e06f497 · 1 Parent(s): cde457b

fix mask and position bug for batch generation

Files changed (1)
  1. modeling_chatglm.py +27 -6
modeling_chatglm.py CHANGED
@@ -662,6 +662,12 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
+    def get_pad_length(self, seq):
+        l = 0
+        while l < len(seq) and seq[l] == self.config.pad_token_id:
+            l += 1
+        return l
+
     def get_masks(self, input_ids, device):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
@@ -669,6 +675,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         attention_mask.tril_()
         for i, context_length in enumerate(context_lengths):
             attention_mask[i, :, :context_length] = 1
+        pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
+        for i, pad_length in enumerate(pad_lengths):
+            attention_mask[i, :, :pad_length] = 0
+            attention_mask[i, :pad_length, :] = 0
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
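The padding fix above can be exercised on its own. Below is a minimal standalone sketch of the new masking logic, assuming left-padded batches; PAD and BOS are illustrative stand-ins for config.pad_token_id and config.bos_token_id, and the torch.ones initialisation mirrors the unchanged get_masks context that the diff does not show. Pad rows and pad columns are both zeroed, so padded positions neither attend nor get attended to.

# Minimal standalone sketch of the padding-aware mask, assuming left-padded
# batches. PAD and BOS are illustrative stand-ins for config.pad_token_id and
# config.bos_token_id.
import torch

PAD, BOS = 0, 9

def get_pad_length(seq, pad_token_id=PAD):
    # Number of leading pad tokens in a left-padded sequence.
    l = 0
    while l < len(seq) and seq[l] == pad_token_id:
        l += 1
    return l

def get_masks(input_ids, device="cpu"):
    batch_size, seq_length = input_ids.shape
    context_lengths = [seq.tolist().index(BOS) for seq in input_ids]
    attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
    attention_mask.tril_()
    for i, context_length in enumerate(context_lengths):
        attention_mask[i, :, :context_length] = 1   # bidirectional attention over the prompt
    # The fix: padding rows/columns must not attend or be attended to.
    pad_lengths = [get_pad_length(seq.tolist()) for seq in input_ids]
    for i, pad_length in enumerate(pad_lengths):
        attention_mask[i, :, :pad_length] = 0
        attention_mask[i, :pad_length, :] = 0
    attention_mask.unsqueeze_(1)
    return (attention_mask < 0.5).bool()            # True = masked out

batch = torch.tensor([[PAD, PAD, 5, 6, BOS, 7],     # shorter prompt, left-padded
                      [4, 5, 6, 7, BOS, 8]])        # full-length prompt
mask = get_masks(batch)
print(mask.shape)                      # torch.Size([2, 1, 6, 6])
print(mask[0, 0, :, :2].all().item())  # True: pad columns of the padded row are fully masked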
 
@@ -676,16 +686,22 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
     def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
         batch_size, seq_length = input_ids.shape
+        pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            for i, context_length in enumerate(context_lengths):
-                position_ids[i, context_length:] = mask_positions[i]
+            position_ids = [torch.arange(seq_length-pad_length, dtype=torch.long, device=device) for pad_length in pad_lengths]
+            for i, (context_length, pad_length) in enumerate(zip(context_lengths, pad_lengths)):
+                position_ids[i][context_length-pad_length:] = mask_positions[i] - pad_length
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
             )) for context_length in context_lengths]
             block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = [torch.cat((
+                torch.zeros(pad_length, dtype=torch.long, device=device),
+                range_pos
+            )) for pad_length, range_pos in zip(pad_lengths, position_ids)]
+            position_ids = torch.stack(position_ids, dim=0)
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
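The same toy batch can be pushed through the corrected 2D position ids. The sketch below reuses the illustrative PAD/BOS ids and an assumed mask_positions list (in the model this is the per-sequence index of the [MASK]/[gMASK] token); expected output is shown in the trailing comment. The positions assigned to the real tokens no longer depend on how much left padding a sequence carries, which is the point of the fix.

# Sketch of the corrected 2D position ids for a left-padded batch. PAD/BOS and
# mask_positions are illustrative; get_pad_length is the same helper as above.
import torch

PAD, BOS = 0, 9

def get_pad_length(seq, pad_token_id=PAD):
    l = 0
    while l < len(seq) and seq[l] == pad_token_id:
        l += 1
    return l

def get_position_ids_2d(input_ids, mask_positions, device="cpu"):
    batch_size, seq_length = input_ids.shape
    pad_lengths = [get_pad_length(seq.tolist()) for seq in input_ids]
    context_lengths = [seq.tolist().index(BOS) for seq in input_ids]
    # First channel: positions counted from the first real token, clamped to the
    # (pad-corrected) mask position once the prompt ends.
    position_ids = [torch.arange(seq_length - pad_length, dtype=torch.long, device=device)
                    for pad_length in pad_lengths]
    for i, (context_length, pad_length) in enumerate(zip(context_lengths, pad_lengths)):
        position_ids[i][context_length - pad_length:] = mask_positions[i] - pad_length
    # Second channel: 0 over the prompt, then 1, 2, ... for generated tokens.
    block_position_ids = [torch.cat((
        torch.zeros(context_length, dtype=torch.long, device=device),
        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
    )) for context_length in context_lengths]
    block_position_ids = torch.stack(block_position_ids, dim=0)
    # Pad the first channel back to seq_length with zeros in the pad slots.
    position_ids = [torch.cat((torch.zeros(pad_length, dtype=torch.long, device=device), p))
                    for pad_length, p in zip(pad_lengths, position_ids)]
    position_ids = torch.stack(position_ids, dim=0)
    return torch.stack((position_ids, block_position_ids), dim=1)   # (batch, 2, seq_length)

batch = torch.tensor([[PAD, PAD, 5, 6, BOS, 7],
                      [4, 5, 6, 7, BOS, 8]])
mask_positions = [3, 3]   # assumed [gMASK] index in each row
print(get_position_ids_2d(batch, mask_positions).tolist())
# [[[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 2]],
#  [[0, 1, 2, 3, 3, 3], [0, 0, 0, 0, 1, 2]]]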
 
@@ -1094,15 +1110,20 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = attention_mask[:, :, -1:]
             else:
-                attention_mask = None
+                attention_mask = self.get_masks(
+                    input_ids,
+                    device=input_ids.device
+                )
+                attention_mask[:, :, -1:]
             if position_ids is not None:
                 position_ids = position_ids[..., -1:]
             else:
+                pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
                 context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
                 if self.position_encoding_2d:
                     position_ids = torch.tensor(
-                        [[mask_position, seq_length - context_length] for mask_position, context_length in
-                         zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+                        [[mask_position - pad_length, seq_length - context_length] for pad_length, mask_position, context_length in
+                         zip(pad_lengths, mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
                 else:
                     position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
                                                 device=input_ids.device).unsqueeze(-1)
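During incremental decoding, prepare_inputs_for_generation only needs position ids for the newest token, and the else branch above now subtracts the pad length from the mask position. A short sketch of that per-step computation, again with made-up sequences and [gMASK] indices:

# Per-step position ids for incremental decoding (2D case), assuming the same
# illustrative PAD/BOS ids; seqs and mask_positions are made-up examples.
import torch

PAD, BOS = 0, 9
seqs = [[PAD, PAD, 5, 6, 3, BOS, 7, 11],   # left-padded prompt + 2 generated tokens
        [4, 5, 6, 7, 2, 3, BOS, 8]]        # unpadded prompt + 1 generated token
seq_length = len(seqs[0])
mask_positions = [4, 5]                    # assumed [gMASK] index per sequence

def get_pad_length(seq):
    l = 0
    while l < len(seq) and seq[l] == PAD:
        l += 1
    return l

pad_lengths = [get_pad_length(seq) for seq in seqs]
context_lengths = [seq.index(BOS) for seq in seqs]

# Column 0: pad-corrected mask position; column 1: number of tokens at or after
# <bos>, i.e. the block position of the newest token.
position_ids = torch.tensor(
    [[mask_position - pad_length, seq_length - context_length]
     for pad_length, mask_position, context_length in zip(pad_lengths, mask_positions, context_lengths)],
    dtype=torch.long).unsqueeze(-1)
print(position_ids.squeeze(-1).tolist())   # [[2, 3], [5, 2]]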