Fix Chinese punctuation
Browse files- modeling_chatglm.py +18 -4
modeling_chatglm.py
CHANGED
@@ -4,6 +4,7 @@ import math
|
|
4 |
import copy
|
5 |
import os
|
6 |
import warnings
|
|
|
7 |
|
8 |
import torch
|
9 |
import torch.utils.checkpoint
|
@@ -1085,6 +1086,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1085 |
for layer_past in past
|
1086 |
)
|
1087 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1088 |
@torch.no_grad()
|
1089 |
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
1090 |
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
@@ -1107,8 +1123,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1107 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
1108 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
1109 |
response = tokenizer.decode(outputs)
|
1110 |
-
response =
|
1111 |
-
response = response.replace("[[训练时间]]", "2023年")
|
1112 |
history = history + [(query, response)]
|
1113 |
return response, history
|
1114 |
|
@@ -1134,8 +1149,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1134 |
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
1135 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
1136 |
response = tokenizer.decode(outputs)
|
1137 |
-
response =
|
1138 |
-
response = response.replace("[[训练时间]]", "2023年")
|
1139 |
new_history = history + [(query, response)]
|
1140 |
yield response, new_history
|
1141 |
|
|
|
4 |
import copy
|
5 |
import os
|
6 |
import warnings
|
7 |
+
import re
|
8 |
|
9 |
import torch
|
10 |
import torch.utils.checkpoint
|
|
|
1086 |
for layer_past in past
|
1087 |
)
|
1088 |
|
1089 |
+
def process_response(self, response):
|
1090 |
+
response = response.strip()
|
1091 |
+
response = response.replace("[[训练时间]]", "2023年")
|
1092 |
+
punkts = [
|
1093 |
+
[",", ","],
|
1094 |
+
["!", "!"],
|
1095 |
+
[":", ":"],
|
1096 |
+
[";", ";"],
|
1097 |
+
["\?", "?"],
|
1098 |
+
]
|
1099 |
+
for item in punkts:
|
1100 |
+
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
|
1101 |
+
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
|
1102 |
+
return response
|
1103 |
+
|
1104 |
@torch.no_grad()
|
1105 |
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
1106 |
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
|
|
1123 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
1124 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
1125 |
response = tokenizer.decode(outputs)
|
1126 |
+
response = self.process_response(response)
|
|
|
1127 |
history = history + [(query, response)]
|
1128 |
return response, history
|
1129 |
|
|
|
1149 |
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
1150 |
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
1151 |
response = tokenizer.decode(outputs)
|
1152 |
+
response = self.process_response(response)
|
|
|
1153 |
new_history = history + [(query, response)]
|
1154 |
yield response, new_history
|
1155 |
|