Merge branch 'main' into pr/45
Browse files
- modeling_chatglm.py +2 -0
modeling_chatglm.py
CHANGED
@@ -970,6 +970,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
970 |
|
971 |
if attention_mask is None:
|
972 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
|
|
|
|
973 |
|
974 |
for i, layer in enumerate(self.layers):
|
975 |
|
|
|
970 |
|
971 |
if attention_mask is None:
|
972 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
973 |
+
else:
|
974 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
975 |
|
976 |
for i, layer in enumerate(self.layers):
|
977 |
|