Qingsong Lv committed
Commit e06f497 · Parent: cde457b
fix mask and position bug for batch generation

modeling_chatglm.py (+27 -6)
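The change makes batched generation correct for left-padded inputs: a new get_pad_length helper counts each sequence's leading pad tokens, get_masks zeroes the attention rows and columns of those pads, get_position_ids shifts positions so the first real token sits at position 0, and prepare_inputs_for_generation applies the same pad correction at each decode step.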
@@ -662,6 +662,12 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
+    def get_pad_length(self, seq):
+        l = 0
+        while l < len(seq) and seq[l] == self.config.pad_token_id:
+            l += 1
+        return l
+
     def get_masks(self, input_ids, device):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
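The new get_pad_length helper counts a sequence's leading pad tokens; everything else in the commit builds on it. A minimal standalone sketch of its behavior (the PAD/gMASK/BOS ids below are illustrative placeholders, not guaranteed tokenizer values; the real method reads pad_token_id from self.config):

    import torch

    PAD, gMASK, BOS = 3, 130001, 130004  # placeholder token ids for illustration

    def get_pad_length(seq, pad_token_id=PAD):
        # Count leading pad tokens, exactly as the new method does.
        l = 0
        while l < len(seq) and seq[l] == pad_token_id:
            l += 1
        return l

    # A left-padded batch: two pads in the first row, none in the second.
    input_ids = torch.tensor([
        [PAD, PAD, 5, gMASK, BOS, 7],
        [5,   6,   5, gMASK, BOS, 7],
    ])
    print([get_pad_length(seq.tolist()) for seq in input_ids])  # -> [2, 0]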
@@ -669,6 +675,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         attention_mask.tril_()
         for i, context_length in enumerate(context_lengths):
             attention_mask[i, :, :context_length] = 1
+        pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
+        for i, pad_length in enumerate(pad_lengths):
+            attention_mask[i, :, :pad_length] = 0
+            attention_mask[i, :pad_length, :] = 0
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
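With pad lengths known, get_masks zeroes both the rows and the columns that belong to pads, so padded positions neither attend nor get attended to. A standalone sketch reusing get_pad_length and the batch from above; it assumes the elided allocation at line 668 is a torch.ones tensor, as in the released modeling_chatglm.py, and replaces self.config with plain arguments:

    def get_masks(input_ids, bos_token_id=BOS, pad_token_id=PAD):
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
        attention_mask = torch.ones(batch_size, seq_length, seq_length)  # assumed elided line 668
        attention_mask.tril_()                           # causal lower triangle
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1    # prompt attends bidirectionally
        pad_lengths = [get_pad_length(seq.tolist(), pad_token_id) for seq in input_ids]
        for i, pad_length in enumerate(pad_lengths):     # new in this commit
            attention_mask[i, :, :pad_length] = 0        # nothing attends to pads
            attention_mask[i, :pad_length, :] = 0        # pads attend to nothing
        attention_mask.unsqueeze_(1)                     # add a head dimension
        return (attention_mask < 0.5).bool()             # True marks masked-out pairs

    mask = get_masks(input_ids)  # shape (2, 1, 6, 6); row 0 masks out its two pads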
@@ -676,16 +686,22 @@
 
     def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
         batch_size, seq_length = input_ids.shape
+        pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            for i, context_length in enumerate(context_lengths):
-                position_ids[i, context_length:] = mask_positions[i]
+            position_ids = [torch.arange(seq_length-pad_length, dtype=torch.long, device=device) for pad_length in pad_lengths]
+            for i, (context_length, pad_length) in enumerate(zip(context_lengths, pad_lengths)):
+                position_ids[i][context_length-pad_length:] = mask_positions[i] - pad_length
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
             )) for context_length in context_lengths]
             block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = [torch.cat((
+                torch.zeros(pad_length, dtype=torch.long, device=device),
+                range_pos
+            )) for pad_length, range_pos in zip(pad_lengths, position_ids)]
+            position_ids = torch.stack(position_ids, dim=0)
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
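On the position side, the first channel of the 2D encoding now starts counting at the first real token (pads are filled with 0) and the mask position is shifted left by the pad length; the block channel is unchanged. A sketch of the patched logic on the same batch, reusing the helpers above:

    def get_position_ids(input_ids, mask_positions, bos_token_id=BOS, pad_token_id=PAD):
        batch_size, seq_length = input_ids.shape
        pad_lengths = [get_pad_length(seq.tolist(), pad_token_id) for seq in input_ids]
        context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
        # Position channel: 0, 1, 2, ... over the real tokens only.
        position_ids = [torch.arange(seq_length - pad_length, dtype=torch.long)
                        for pad_length in pad_lengths]
        for i, (context_length, pad_length) in enumerate(zip(context_lengths, pad_lengths)):
            position_ids[i][context_length - pad_length:] = mask_positions[i] - pad_length
        # Block channel: zeros over the prompt, then 1, 2, ... for generated tokens.
        block_position_ids = torch.stack([torch.cat((
            torch.zeros(context_length, dtype=torch.long),
            torch.arange(seq_length - context_length, dtype=torch.long) + 1
        )) for context_length in context_lengths], dim=0)
        # Prepend zeros so the padded rows line up again.
        position_ids = torch.stack([torch.cat((torch.zeros(pad_length, dtype=torch.long), r))
                                    for pad_length, r in zip(pad_lengths, position_ids)], dim=0)
        return torch.stack((position_ids, block_position_ids), dim=1)  # (batch, 2, seq)

    pos = get_position_ids(input_ids, mask_positions=[3, 3])  # gMASK sits at index 3
    print(pos[:, 0].tolist())  # [[0, 0, 0, 1, 1, 1], [0, 1, 2, 3, 3, 3]]
    print(pos[:, 1].tolist())  # [[0, 0, 0, 0, 1, 2], [0, 0, 0, 0, 1, 2]]

Before the fix, row 0 would have been numbered 0..5 straight through the pads, so the same token got a different position depending on how much padding the batch forced onto it.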
@@ -1094,15 +1110,20 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = attention_mask[:, :, -1:]
             else:
-                attention_mask = None
+                attention_mask = self.get_masks(
+                    input_ids,
+                    device=input_ids.device
+                )
+                attention_mask = attention_mask[:, :, -1:]
             if position_ids is not None:
                 position_ids = position_ids[..., -1:]
             else:
+                pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
                 context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
                 if self.position_encoding_2d:
                     position_ids = torch.tensor(
-                        [[mask_position, seq_length - context_length] for mask_position, context_length in
-                         zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+                        [[mask_position - pad_length, seq_length - context_length] for pad_length, mask_position, context_length in
+                         zip(pad_lengths, mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
                 else:
                     position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
                                                 device=input_ids.device).unsqueeze(-1)
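During incremental decoding (past key/values present), the patched branch rebuilds the full mask with get_masks and keeps only the last query column, and the 2D position ids reduce to the pad-corrected mask position plus the running block counter. Evaluating the patched tensor expression on the example values above:

    import torch

    seq_length = 6
    pad_lengths = [2, 0]
    mask_positions = [3, 3]     # index of gMASK in each padded row
    context_lengths = [4, 4]    # index of BOS in each padded row

    position_ids = torch.tensor(
        [[mask_position - pad_length, seq_length - context_length]
         for pad_length, mask_position, context_length in
         zip(pad_lengths, mask_positions, context_lengths)],
        dtype=torch.long).unsqueeze(-1)
    print(position_ids.squeeze(-1).tolist())  # -> [[1, 2], [3, 2]]

Row 0 now decodes from position 1 rather than 3, consistent with the pad-shifted position channel computed for the full sequence.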