Set ignore_index for CrossEntropyLoss
Browse files- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -1124,7 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1124 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1125 |
shift_labels = labels[..., 1:].contiguous()
|
1126 |
# Flatten the tokens
|
1127 |
-
loss_fct = CrossEntropyLoss()
|
1128 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1129 |
|
1130 |
lm_logits = lm_logits.to(hidden_states.dtype)
|
|
|
1124 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1125 |
shift_labels = labels[..., 1:].contiguous()
|
1126 |
# Flatten the tokens
|
1127 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
1128 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1129 |
|
1130 |
lm_logits = lm_logits.to(hidden_states.dtype)
|