Remove hardcode bos_token_id
modeling_chatglm.py  CHANGED  (+11, -13)
@@ -753,9 +753,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    @staticmethod
-    def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -766,9 +765,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -823,14 +822,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
-
-        MASK, gMASK = 150000, 150001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
 
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -838,6 +831,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
@@ -941,7 +939,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -979,7 +977,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
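For context, here is a minimal, self-contained sketch of the logic this commit touches, written against plain PyTorch rather than the repo's classes. BOS_TOKEN_ID and pick_mask_position are illustrative stand-ins, not names from modeling_chatglm.py; the two lines between tril_() and the final bool conversion in get_masks are not shown in this diff and are assumed from the surrounding, unchanged code. In the patched model the bos id comes from self.config.bos_token_id instead of the literal 150004.

import torch

# Stand-in values, for the demo only: 150004/150000/150001 are the bos/[MASK]/[gMASK]
# ids of the original ChatGLM-6B vocabulary. The patched model reads the bos id from
# self.config.bos_token_id; the two mask ids remain literals in the diff.
BOS_TOKEN_ID = 150004
MASK, gMASK = 150000, 150001


def get_masks(seq, device, bos_token_id=BOS_TOKEN_ID):
    """Prefix-LM mask: bidirectional attention over the context up to and including
    the bos token, causal attention afterwards."""
    context_length = seq.index(bos_token_id) + 1
    attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
    attention_mask.tril_()
    attention_mask[..., :context_length - 1] = 1   # context tokens attend to each other fully
    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool() # True = position that may not be attended to
    return attention_mask


def pick_mask_position(seq):
    """Mirrors the block the commit moves under `if position_ids is None:`:
    prefer [MASK] if present, otherwise [gMASK], and locate it in the sequence."""
    mask_token = MASK if MASK in seq else gMASK
    return seq.index(mask_token)


# A toy prompt: four "text" tokens, then [gMASK], then bos.
seq = [5, 6, 7, 8, gMASK, BOS_TOKEN_ID]
print(get_masks(seq, device="cpu").shape)  # torch.Size([1, 1, 6, 6])
print(pick_mask_position(seq))             # 4

The sketch only needs torch to run. The point of the change it illustrates is that reading the bos id from the config lets the same mask and position-id code work with a tokenizer whose special-token ids differ from the original ChatGLM-6B vocabulary, instead of silently depending on the hardcoded 150004.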