Tuchuanhuhuhu committed on
Commit d708c00 · 1 Parent(s): 86018c8

Support ChatGLM

Files changed (4)
  1. .gitignore +3 -0
  2. modules/models.py +53 -0
  3. modules/presets.py +3 -0
  4. requirements.txt +2 -0
.gitignore CHANGED
@@ -133,7 +133,10 @@ dmypy.json
 # Mac system file
 **/.DS_Store
 
+# 配置文件/模型文件
 api_key.txt
 config.json
 auth.json
+models/
+lora/
 .idea
modules/models.py CHANGED
@@ -8,6 +8,7 @@ import os
 import sys
 import requests
 import urllib3
+import platform
 
 from tqdm import tqdm
 import colorama
@@ -191,6 +192,55 @@ class OpenAIClient(BaseLLMModel):
                 # logging.error(f"Error: {e}")
                 continue
 
+class ChatGLM_Client(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        model_path=None
+    ) -> None:
+        super().__init__(
+            model_name=model_name
+        )
+        from transformers import AutoTokenizer, AutoModel
+        import torch
+        system_name = platform.system()
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_path = f"models/{model_name}"
+        if model_path is not None:
+            model_source = model_path
+        else:
+            model_source = f"THUDM/{model_name}"
+        self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
+        if torch.cuda.is_available():
+            # run on CUDA
+            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).half().cuda()
+        elif system_name == "Darwin" and model_path is not None:
+            # running on macOS and model already downloaded
+            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().to('mps')
+        else:
+            # run on CPU
+            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).float()
+        model = model.eval()
+        self.model = model
+
+    def _get_glm_style_input(self):
+        history = [x["content"] for x in self.history]
+        query = history.pop()
+        return history, query
+
+    def get_answer_at_once(self):
+        history, query = self._get_glm_style_input()
+        response, _ = self.model.chat(self.tokenizer, query, history=history)
+        return response
+
+    def get_answer_stream_iter(self):
+        history, query = self._get_glm_style_input()
+        for response, history in self.model.stream_chat(self.tokenizer, query, history, max_length=self.token_upper_limit, top_p=self.top_p,
+                                                         temperature=self.temperature):
+            yield response
+
 
 def get_model(
     model_name, access_key=None, temperature=None, top_p=None, system_prompt=None
@@ -198,6 +248,7 @@ def get_model(
     msg = f"模型设置为了: {model_name}"
     logging.info(msg)
     model_type = ModelType.get_type(model_name)
+    del model
     if model_type == ModelType.OpenAI:
         model = OpenAIClient(
             model_name=model_name,
@@ -206,6 +257,8 @@ def get_model(
             temperature=temperature,
             top_p=top_p,
         )
+    elif model_type == ModelType.ChatGLM:
+        model = ChatGLM_Client(model_name)
     return model, msg
 
 
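The ChatGLM_Client added above drives the remote-code interface that the THUDM ChatGLM checkpoints expose through transformers. Below is a minimal, self-contained sketch of the same load / chat / stream_chat flow outside the client class; the prompt strings and the sampling values are illustrative placeholders rather than values taken from this repository, and the first run downloads the weights from the Hugging Face Hub unless a local models/chatglm-6b directory is supplied.

# Minimal sketch of the ChatGLM flow used by ChatGLM_Client, assuming the
# remote-code chat()/stream_chat() API of THUDM/chatglm-6b as relied on in the diff.
import torch
from transformers import AutoTokenizer, AutoModel

model_source = "THUDM/chatglm-6b"  # or a local path such as "models/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
model = AutoModel.from_pretrained(model_source, trust_remote_code=True)
# Mirror the device branches in ChatGLM_Client.__init__: half precision on GPU, float32 on CPU.
model = model.half().cuda() if torch.cuda.is_available() else model.float()
model = model.eval()

# One-shot answer, as in get_answer_at_once().
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)

# Streaming answer, as in get_answer_stream_iter(); each step yields the partial reply so far.
# max_length / top_p / temperature here are illustrative, not the repository's defaults.
for partial, history in model.stream_chat(tokenizer, "What can you do?", history,
                                           max_length=2048, top_p=0.7, temperature=0.95):
    pass
print(partial)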
modules/presets.py CHANGED
@@ -57,6 +57,9 @@ MODELS = [
     "gpt-4-0314",
     "gpt-4-32k",
     "gpt-4-32k-0314",
+    "chatglm-6b",
+    "chatglm-6b-int4",
+    "chatglm-6b-int4-qe"
 ] # 可选的模型
 
  MODEL_TOKEN_LIMIT = {
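The three new entries make the ChatGLM checkpoints selectable, but routing them to the new ChatGLM_Client branch in get_model() depends on ModelType.get_type(), which this diff does not show. A hypothetical sketch of how such a name-to-type mapping could look; the enum members and the prefix matching below are assumptions for illustration, not the repository's actual implementation:

from enum import Enum

class ModelType(Enum):
    # Hypothetical members; the real enum in the repository may differ.
    OpenAI = 0
    ChatGLM = 1

    @classmethod
    def get_type(cls, model_name: str) -> "ModelType":
        # Assumed prefix match: any "chatglm*" name is routed to the ChatGLM branch.
        if model_name.lower().startswith("chatglm"):
            return cls.ChatGLM
        return cls.OpenAI

assert ModelType.get_type("chatglm-6b-int4") is ModelType.ChatGLM
assert ModelType.get_type("gpt-4-32k") is ModelType.OpenAI

With a mapping along these lines, get_model("chatglm-6b") in the models.py diff above would return a ChatGLM_Client instance together with the log message.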
requirements.txt CHANGED
@@ -13,3 +13,5 @@ markdown
 PyPDF2
 pdfplumber
 pandas
+transformers
+torch