Add logit processor for NaN or Inf scores
modeling_chatglm.py CHANGED (+17 -3)
@@ -3,6 +3,7 @@
 import math
 import copy
 import os
+import time
 
 import torch
 import torch.utils.checkpoint
@@ -23,8 +24,10 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
 )
 from transformers.modeling_utils import PreTrainedModel
-
 from transformers.utils import logging
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
 from .configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
@@ -44,6 +47,14 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 20005] = 1e5
+        return scores
+
+
 def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
     """Load tf checkpoints in a pytorch model."""
     try:
@@ -1078,11 +1089,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
-             do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
+             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
         if history is None:
             history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, **kwargs}
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
         if not history:
             prompt = query
         else:
|