johnsmith253325 committed
Commit 4cae7cc · Parent: cebe276

feat: support ChatGLM3, resolve #941

modules/models/ChatGLM.py CHANGED
@@ -4,6 +4,8 @@ import logging
 import os
 import platform
 
+import gc
+import torch
 import colorama
 
 from ..index_func import *
@@ -18,6 +20,7 @@ class ChatGLM_Client(BaseLLMModel):
         import torch
         from transformers import AutoModel, AutoTokenizer
         global CHATGLM_TOKENIZER, CHATGLM_MODEL
+        self.deinitialize()
         if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
             system_name = platform.system()
             model_path = None
@@ -53,7 +56,12 @@ class ChatGLM_Client(BaseLLMModel):
         model = model.eval()
         CHATGLM_MODEL = model
 
-    def _get_glm_style_input(self):
+    def _get_glm3_style_input(self):
+        history = self.history
+        query = history.pop()["content"]
+        return history, query
+
+    def _get_glm2_style_input(self):
         history = [x["content"] for x in self.history]
         query = history.pop()
         logging.debug(colorama.Fore.YELLOW +
@@ -65,6 +73,12 @@ class ChatGLM_Client(BaseLLMModel):
                    for i in range(0, len(history), 2)]
         return history, query
 
+    def _get_glm_style_input(self):
+        if "glm2" in self.model_name:
+            return self._get_glm2_style_input()
+        else:
+            return self._get_glm3_style_input()
+
     def get_answer_at_once(self):
         history, query = self._get_glm_style_input()
         response, _ = CHATGLM_MODEL.chat(
@@ -82,3 +96,12 @@ class ChatGLM_Client(BaseLLMModel):
             temperature=self.temperature,
         ):
             yield response
+
+    def deinitialize(self):
+        # Release GPU memory
+        global CHATGLM_MODEL, CHATGLM_TOKENIZER
+        CHATGLM_MODEL = None
+        CHATGLM_TOKENIZER = None
+        gc.collect()
+        torch.cuda.empty_cache()
+        logging.info("ChatGLM model deinitialized")
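For context on the split above: ChatGLM2's chat() expects conversation history as (query, response) tuples, which _get_glm2_style_input builds by pairing adjacent turns, while ChatGLM3's chat() expects role/content dicts, so _get_glm3_style_input can hand self.history over after popping the final user turn. A minimal standalone sketch of the two formats and of the "glm2" substring dispatch; the sample strings and the glm_style helper are illustrative, not code from the repo:

# Illustrative only: the same one-turn exchange in each history format.
glm2_history = [
    ("Hello", "Hi! How can I help?"),  # ChatGLM2: (query, response) tuples
]
glm3_history = [
    {"role": "user", "content": "Hello"},  # ChatGLM3: role/content dicts,
    {"role": "assistant", "content": "Hi! How can I help?"},  # same shape as self.history
]

# Hypothetical helper mirroring the branch in _get_glm_style_input:
def glm_style(model_name: str) -> str:
    return "glm2" if "glm2" in model_name else "glm3"

assert glm_style("chatglm2-6b") == "glm2"
assert glm_style("chatglm3-6b-32k") == "glm3"

Both builders feed the same CHATGLM_MODEL.chat() call in get_answer_at_once, so only the history shape differs between generations.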
modules/models/base_model.py CHANGED
@@ -847,6 +847,11 @@ class BaseLLMModel:
         """
         return gr.update()
 
+    def deinitialize(self):
+        """deinitialize the model, implement if needed
+        """
+        pass
+
 
 class Base_Chat_Langchain_Client(BaseLLMModel):
     def __init__(self, model_name, user_name=""):
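The new hook is a deliberate no-op in the base class, so callers can invoke deinitialize() on any model without type checks; subclasses that hold GPU state override it, as ChatGLM_Client does above, and ChatGLM_Client.__init__ now calls self.deinitialize() before its global cache check so that a reload first frees previously loaded weights. A sketch of the override pattern, with a hypothetical subclass name:

import gc

class SomeLocalClient(BaseLLMModel):  # hypothetical subclass
    def deinitialize(self):
        # Drop references to heavyweight state, then reclaim memory.
        self._model = None
        gc.collect()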
modules/presets.py CHANGED
@@ -87,6 +87,8 @@ LOCAL_MODELS = [
     "chatglm-6b-int4-ge",
     "chatglm2-6b",
     "chatglm2-6b-int4",
+    "chatglm3-6b",
+    "chatglm3-6b-32k",
     "StableLM",
     "MOSS",
     "Llama-2-7B-Chat",