Commit d708c00 by Tuchuanhuhuhu
Parent: 86018c8

支持ChatGLM (Support ChatGLM)

Files changed:
- .gitignore         +3 -0
- modules/models.py  +53 -0
- modules/presets.py +3 -0
- requirements.txt   +2 -0
.gitignore CHANGED
@@ -133,7 +133,10 @@ dmypy.json
 # Mac system file
 **/.DS_Store
 
+# 配置文件/模型文件
 api_key.txt
 config.json
 auth.json
+models/
+lora/
 .idea
modules/models.py CHANGED
@@ -8,6 +8,7 @@ import os
 import sys
 import requests
 import urllib3
+import platform
 
 from tqdm import tqdm
 import colorama
@@ -191,6 +192,55 @@ class OpenAIClient(BaseLLMModel):
                 # logging.error(f"Error: {e}")
                 continue
 
+class ChatGLM_Client(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        model_path = None
+    ) -> None:
+        super().__init__(
+            model_name=model_name
+        )
+        from transformers import AutoTokenizer, AutoModel
+        import torch
+        system_name = platform.system()
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_path = f"models/{model_name}"
+        if model_path is not None:
+            model_source = model_path
+        else:
+            model_source = f"THUDM/{model_name}"
+        self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
+        if torch.cuda.is_available():
+            # run on CUDA
+            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).half().cuda()
+        elif system_name == "Darwin" and model_path is not None:
+            # running on macOS and model already downloaded
+            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().to('mps')
+        else:
+            # run on CPU
+            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).float()
+        model = model.eval()
+        self.model = model
+
+    def _get_glm_style_input(self):
+        history = [x["content"] for x in self.history]
+        query = history.pop()
+        return history, query
+
+    def get_answer_at_once(self):
+        history, query = self._get_glm_style_input()
+        response, _ = self.model.chat(self.tokenizer, query, history=history)
+        return response
+
+    def get_answer_stream_iter(self):
+        history, query = self._get_glm_style_input()
+        for response, history in self.model.stream_chat(self.tokenizer, query, history, max_length=self.token_upper_limit, top_p=self.top_p,
+                                                        temperature=self.temperature):
+            yield response
+
 
 def get_model(
     model_name, access_key=None, temperature=None, top_p=None, system_prompt=None
@@ -198,6 +248,7 @@ def get_model(
     msg = f"模型设置为了: {model_name}"
     logging.info(msg)
     model_type = ModelType.get_type(model_name)
+    del model
     if model_type == ModelType.OpenAI:
         model = OpenAIClient(
             model_name=model_name,
@@ -206,6 +257,8 @@ def get_model(
             temperature=temperature,
             top_p=top_p,
         )
+    elif model_type == ModelType.ChatGLM:
+        model = ChatGLM_Client(model_name)
     return model, msg
 
 
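The new client leans on ChatGLM-6B's remote code: chat() and stream_chat() are methods attached to the model object via trust_remote_code=True, not part of the stock transformers API. Below is a minimal standalone sketch of the same load-and-chat flow the class implements (assumes the THUDM/chatglm-6b Hub repo is reachable; the prompt string is illustrative only):

# Sketch only: mirrors the device selection in ChatGLM_Client.__init__ above
# (fp16 on CUDA, fp16 on Apple MPS, fp32 on CPU).
import platform

import torch
from transformers import AutoModel, AutoTokenizer

model_source = "THUDM/chatglm-6b"  # or a local checkout such as models/chatglm-6b
tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)

if torch.cuda.is_available():
    model = AutoModel.from_pretrained(model_source, trust_remote_code=True).half().cuda()
elif platform.system() == "Darwin":
    model = AutoModel.from_pretrained(model_source, trust_remote_code=True).half().to("mps")
else:
    model = AutoModel.from_pretrained(model_source, trust_remote_code=True).float()
model = model.eval()

# chat() returns one complete reply; stream_chat() yields progressively longer
# partial replies, which is what get_answer_stream_iter() forwards to the UI.
response, history = model.chat(tokenizer, "你好", history=[])
print(response)

One thing worth verifying against the ChatGLM model card: _get_glm_style_input() hands chat() a flat list of message strings as history, while THUDM's published examples pass a list of (query, response) pairs.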
modules/presets.py CHANGED
@@ -57,6 +57,9 @@ MODELS = [
     "gpt-4-0314",
     "gpt-4-32k",
     "gpt-4-32k-0314",
+    "chatglm-6b",
+    "chatglm-6b-int4",
+    "chatglm-6b-int4-qe"
 ] # 可选的模型
 
 MODEL_TOKEN_LIMIT = {
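These names only take effect because get_model() dispatches on ModelType.get_type(model_name), which this diff does not touch. A plausible sketch of such a matcher, with every member and matching rule assumed rather than taken from the repo:

from enum import Enum

class ModelType(Enum):
    # Members inferred from the get_model() branches above; the repo's own
    # class may define more.
    Unknown = -1
    OpenAI = 0
    ChatGLM = 1

    @classmethod
    def get_type(cls, model_name: str) -> "ModelType":
        name = model_name.lower()
        if "gpt" in name:
            return cls.OpenAI
        if "chatglm" in name:
            return cls.ChatGLM
        return cls.Unknown

assert ModelType.get_type("chatglm-6b-int4-qe") is ModelType.ChatGLM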
requirements.txt CHANGED
@@ -13,3 +13,5 @@ markdown
 PyPDF2
 pdfplumber
 pandas
+transformers
+torch