Tuchuanhuhuhu committed on
Commit
93defe7
·
1 Parent(s): 11750f0

chore: extract the classes in models.py into their own modules

modules/models/ChatGLM.py ADDED
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import logging
+import os
+import platform
+
+import colorama
+
+from ..index_func import *
+from ..presets import *
+from ..utils import *
+from .base_model import BaseLLMModel
+
+
+class ChatGLM_Client(BaseLLMModel):
+    def __init__(self, model_name, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+        import torch
+        from transformers import AutoModel, AutoTokenizer
+        global CHATGLM_TOKENIZER, CHATGLM_MODEL
+        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
+            system_name = platform.system()
+            model_path = None
+            if os.path.exists("models"):
+                model_dirs = os.listdir("models")
+                if model_name in model_dirs:
+                    model_path = f"models/{model_name}"
+            if model_path is not None:
+                model_source = model_path
+            else:
+                model_source = f"THUDM/{model_name}"
+            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
+                model_source, trust_remote_code=True
+            )
+            quantified = False
+            if "int4" in model_name:
+                quantified = True
+            model = AutoModel.from_pretrained(
+                model_source, trust_remote_code=True
+            )
+            if torch.cuda.is_available():
+                # run on CUDA
+                logging.info("CUDA is available, using CUDA")
+                model = model.half().cuda()
+            # MPS acceleration still has some issues, so it is not used for now
+            elif system_name == "Darwin" and model_path is not None and not quantified:
+                logging.info("Running on macOS, using MPS")
+                # running on macOS and model already downloaded
+                model = model.half().to("mps")
+            else:
+                logging.info("GPU is not available, using CPU")
+                model = model.float()
+            model = model.eval()
+            CHATGLM_MODEL = model
+
+    def _get_glm_style_input(self):
+        history = [x["content"] for x in self.history]
+        query = history.pop()
+        logging.debug(colorama.Fore.YELLOW +
+                      f"{history}" + colorama.Fore.RESET)
+        assert (
+            len(history) % 2 == 0
+        ), f"History should be even length. current history is: {history}"
+        history = [[history[i], history[i + 1]]
+                   for i in range(0, len(history), 2)]
+        return history, query
+
+    def get_answer_at_once(self):
+        history, query = self._get_glm_style_input()
+        response, _ = CHATGLM_MODEL.chat(
+            CHATGLM_TOKENIZER, query, history=history)
+        return response, len(response)
+
+    def get_answer_stream_iter(self):
+        history, query = self._get_glm_style_input()
+        for response, history in CHATGLM_MODEL.stream_chat(
+            CHATGLM_TOKENIZER,
+            query,
+            history,
+            max_length=self.token_upper_limit,
+            top_p=self.top_p,
+            temperature=self.temperature,
+        ):
+            yield response
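
Aside: the history reshaping above is the crux of this new file. ChatGLM's chat/stream_chat APIs take the history as [user, assistant] pairs rather than the role/content dicts the app stores, so _get_glm_style_input pops the trailing query and pairs up the rest. A minimal self-contained sketch of that transformation (the sample data is illustrative, not from the commit):

history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is ChatGLM?"},
]
contents = [x["content"] for x in history]
query = contents.pop()                # "What is ChatGLM?"
assert len(contents) % 2 == 0         # must be complete user/assistant pairs
pairs = [[contents[i], contents[i + 1]] for i in range(0, len(contents), 2)]
print(pairs, query)                   # [['Hi', 'Hello!']] What is ChatGLM?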
modules/models/{Google_PaLM.py → GooglePaLM.py} RENAMED
@@ -1,6 +1,7 @@
 from .base_model import BaseLLMModel
 import google.generativeai as palm
 
+
 class Google_PaLM_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
@@ -18,9 +19,11 @@ class Google_PaLM_Client(BaseLLMModel):
     def get_answer_at_once(self):
         palm.configure(api_key=self.api_key)
         messages = self._get_palm_style_input()
-        response = palm.chat(context=self.system_prompt, messages=messages, temperature=self.temperature, top_p=self.top_p)
+        response = palm.chat(context=self.system_prompt, messages=messages,
+                             temperature=self.temperature, top_p=self.top_p)
         if response.last is not None:
             return response.last, len(response.last)
         else:
-            reasons = '\n\n'.join(reason['reason'].name for reason in response.filters)
-            return "由于下面的原因,Google 拒绝返回 PaLM 的回答:\n\n" + reasons, 0
+            reasons = '\n\n'.join(
+                reason['reason'].name for reason in response.filters)
+            return "由于下面的原因,Google 拒绝返回 PaLM 的回答:\n\n" + reasons, 0
modules/models/LLaMA.py ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import json
+import os
+
+from ..index_func import *
+from ..presets import *
+from ..utils import *
+from .base_model import BaseLLMModel
+
+
+class LLaMA_Client(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        lora_path=None,
+        user_name=""
+    ) -> None:
+        super().__init__(model_name=model_name, user=user_name)
+        from lmflow.args import (DatasetArguments, InferencerArguments,
+                                 ModelArguments)
+        from lmflow.datasets.dataset import Dataset
+        from lmflow.models.auto_model import AutoModel
+        from lmflow.pipeline.auto_pipeline import AutoPipeline
+
+        self.max_generation_token = 1000
+        self.end_string = "\n\n"
+        # We don't need input data
+        data_args = DatasetArguments(dataset_path=None)
+        self.dataset = Dataset(data_args)
+        self.system_prompt = ""
+
+        global LLAMA_MODEL, LLAMA_INFERENCER
+        if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
+            model_path = None
+            if os.path.exists("models"):
+                model_dirs = os.listdir("models")
+                if model_name in model_dirs:
+                    model_path = f"models/{model_name}"
+            if model_path is not None:
+                model_source = model_path
+            else:
+                model_source = f"decapoda-research/{model_name}"
+                # raise Exception(f"models目录下没有这个模型: {model_name}")
+            if lora_path is not None:
+                lora_path = f"lora/{lora_path}"
+            model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
+                                        use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+            pipeline_args = InferencerArguments(
+                local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
+
+            with open(pipeline_args.deepspeed, "r", encoding="utf-8") as f:
+                ds_config = json.load(f)
+            LLAMA_MODEL = AutoModel.get_model(
+                model_args,
+                tune_strategy="none",
+                ds_config=ds_config,
+            )
+            LLAMA_INFERENCER = AutoPipeline.get_pipeline(
+                pipeline_name="inferencer",
+                model_args=model_args,
+                data_args=data_args,
+                pipeline_args=pipeline_args,
+            )
+
+    def _get_llama_style_input(self):
+        history = []
+        instruction = ""
+        if self.system_prompt:
+            instruction = (f"Instruction: {self.system_prompt}\n")
+        for x in self.history:
+            if x["role"] == "user":
+                history.append(f"{instruction}Input: {x['content']}")
+            else:
+                history.append(f"Output: {x['content']}")
+        context = "\n\n".join(history)
+        context += "\n\nOutput: "
+        return context
+
+    def get_answer_at_once(self):
+        context = self._get_llama_style_input()
+
+        input_dataset = self.dataset.from_dict(
+            {"type": "text_only", "instances": [{"text": context}]}
+        )
+
+        output_dataset = LLAMA_INFERENCER.inference(
+            model=LLAMA_MODEL,
+            dataset=input_dataset,
+            max_new_tokens=self.max_generation_token,
+            temperature=self.temperature,
+        )
+
+        response = output_dataset.to_dict()["instances"][0]["text"]
+        return response, len(response)
+
+    def get_answer_stream_iter(self):
+        context = self._get_llama_style_input()
+        partial_text = ""
+        step = 1
+        for _ in range(0, self.max_generation_token, step):
+            input_dataset = self.dataset.from_dict(
+                {"type": "text_only", "instances": [
+                    {"text": context + partial_text}]}
+            )
+            output_dataset = LLAMA_INFERENCER.inference(
+                model=LLAMA_MODEL,
+                dataset=input_dataset,
+                max_new_tokens=step,
+                temperature=self.temperature,
+            )
+            response = output_dataset.to_dict()["instances"][0]["text"]
+            if response == "" or response == self.end_string:
+                break
+            partial_text += response
+            yield partial_text
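
Aside: _get_llama_style_input is the part worth reading closely. It flattens the chat history into lmflow's Instruction/Input/Output text format and leaves a dangling "Output: " for the model to complete. A minimal sketch with illustrative sample data (not from the commit):

system_prompt = "Answer concisely."
history = [
    {"role": "user", "content": "Name a prime."},
    {"role": "assistant", "content": "7"},
    {"role": "user", "content": "Another?"},
]
instruction = f"Instruction: {system_prompt}\n" if system_prompt else ""
parts = [f"{instruction}Input: {x['content']}" if x["role"] == "user"
         else f"Output: {x['content']}" for x in history]
context = "\n\n".join(parts) + "\n\nOutput: "
print(context)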
modules/models/OpenAI.py ADDED
@@ -0,0 +1,270 @@
+from __future__ import annotations
+
+import json
+import logging
+
+import colorama
+import requests
+
+from .. import shared
+from ..config import retrieve_proxy, sensitive_id, usage_limit
+from ..index_func import *
+from ..presets import *
+from ..utils import *
+from .base_model import BaseLLMModel
+
+
+class OpenAIClient(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        api_key,
+        system_prompt=INITIAL_SYSTEM_PROMPT,
+        temperature=1.0,
+        top_p=1.0,
+        user_name=""
+    ) -> None:
+        super().__init__(
+            model_name=model_name,
+            temperature=temperature,
+            top_p=top_p,
+            system_prompt=system_prompt,
+            user=user_name
+        )
+        self.api_key = api_key
+        self.need_api_key = True
+        self._refresh_header()
+
+    def get_answer_stream_iter(self):
+        response = self._get_response(stream=True)
+        if response is not None:
+            iter = self._decode_chat_response(response)
+            partial_text = ""
+            for i in iter:
+                partial_text += i
+                yield partial_text
+        else:
+            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
+
+    def get_answer_at_once(self):
+        response = self._get_response()
+        response = json.loads(response.text)
+        content = response["choices"][0]["message"]["content"]
+        total_token_count = response["usage"]["total_tokens"]
+        return content, total_token_count
+
+    def count_token(self, user_input):
+        input_token_count = count_token(construct_user(user_input))
+        if self.system_prompt is not None and len(self.all_token_counts) == 0:
+            system_prompt_token_count = count_token(
+                construct_system(self.system_prompt)
+            )
+            return input_token_count + system_prompt_token_count
+        return input_token_count
+
+    def billing_info(self):
+        try:
+            curr_time = datetime.datetime.now()
+            last_day_of_month = get_last_day_of_month(
+                curr_time).strftime("%Y-%m-%d")
+            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
+            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
+            try:
+                usage_data = self._get_billing_data(usage_url)
+            except Exception as e:
+                # logging.error(f"获取API使用情况失败: " + str(e))
+                if "Invalid authorization header" in str(e):
+                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
+                elif "Incorrect API key provided: sess" in str(e):
+                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
+                return i18n("**获取API使用情况失败**")
+            # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
+            rounded_usage = round(usage_data["total_usage"] / 100, 5)
+            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
+            from ..webui import get_html
+
+            # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
+            return get_html("billing_info.html").format(
+                label=i18n("本月使用金额"),
+                usage_percent=usage_percent,
+                rounded_usage=rounded_usage,
+                usage_limit=usage_limit
+            )
+        except requests.exceptions.ConnectTimeout:
+            status_text = (
+                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+            )
+            return status_text
+        except requests.exceptions.ReadTimeout:
+            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+            return status_text
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            logging.error(i18n("获取API使用情况失败:") + str(e))
+            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
+
+    def set_token_upper_limit(self, new_upper_limit):
+        pass
+
+    @shared.state.switching_api_key  # this decorator does nothing unless multi-account mode is enabled
+    def _get_response(self, stream=False):
+        openai_api_key = self.api_key
+        system_prompt = self.system_prompt
+        history = self.history
+        logging.debug(colorama.Fore.YELLOW +
+                      f"{history}" + colorama.Fore.RESET)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {openai_api_key}",
+        }
+
+        if system_prompt is not None:
+            history = [construct_system(system_prompt), *history]
+
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n_choices,
+            "stream": stream,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+        }
+
+        if self.max_generation_token is not None:
+            payload["max_tokens"] = self.max_generation_token
+        if self.stop_sequence is not None:
+            payload["stop"] = self.stop_sequence
+        if self.logit_bias is not None:
+            payload["logit_bias"] = self.logit_bias
+        if self.user_identifier:
+            payload["user"] = self.user_identifier
+
+        if stream:
+            timeout = TIMEOUT_STREAMING
+        else:
+            timeout = TIMEOUT_ALL
+
+        # if a custom api-host is configured, send the request there; otherwise use the default
+        if shared.state.completion_url != COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
+
+        with retrieve_proxy():
+            try:
+                response = requests.post(
+                    shared.state.completion_url,
+                    headers=headers,
+                    json=payload,
+                    stream=stream,
+                    timeout=timeout,
+                )
+            except:
+                return None
+        return response
+
+    def _refresh_header(self):
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {sensitive_id}",
+        }
+
+    def _get_billing_data(self, billing_url):
+        with retrieve_proxy():
+            response = requests.get(
+                billing_url,
+                headers=self.headers,
+                timeout=TIMEOUT_ALL,
+            )
+
+        if response.status_code == 200:
+            data = response.json()
+            return data
+        else:
+            raise Exception(
+                f"API request failed with status code {response.status_code}: {response.text}"
+            )
+
+    def _decode_chat_response(self, response):
+        error_msg = ""
+        for chunk in response.iter_lines():
+            if chunk:
+                chunk = chunk.decode()
+                chunk_length = len(chunk)
+                try:
+                    chunk = json.loads(chunk[6:])
+                except:
+                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
+                    error_msg += chunk
+                    continue
+                if chunk_length > 6 and "delta" in chunk["choices"][0]:
+                    if chunk["choices"][0]["finish_reason"] == "stop":
+                        break
+                    try:
+                        yield chunk["choices"][0]["delta"]["content"]
+                    except Exception as e:
+                        # logging.error(f"Error: {e}")
+                        continue
+        if error_msg:
+            raise Exception(error_msg)
+
+    def set_key(self, new_access_key):
+        ret = super().set_key(new_access_key)
+        self._refresh_header()
+        return ret
+
+    def _single_query_at_once(self, history, temperature=1.0):
+        timeout = TIMEOUT_ALL
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+            "temperature": f"{temperature}",
+        }
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+        }
+        # if a custom api-host is configured, send the request there; otherwise use the default
+        if shared.state.completion_url != COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
+
+        with retrieve_proxy():
+            response = requests.post(
+                shared.state.completion_url,
+                headers=headers,
+                json=payload,
+                stream=False,
+                timeout=timeout,
+            )
+
+        return response
+
+    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
+        if len(self.history) == 2 and not single_turn_checkbox:
+            user_question = self.history[0]["content"]
+            if name_chat_method == i18n("模型自动总结(消耗tokens)"):
+                ai_answer = self.history[1]["content"]
+                try:
+                    history = [
+                        {"role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
+                        {"role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
+                    ]
+                    response = self._single_query_at_once(
+                        history, temperature=0.0)
+                    response = json.loads(response.text)
+                    content = response["choices"][0]["message"]["content"]
+                    filename = replace_special_symbols(content) + ".json"
+                except Exception as e:
+                    logging.info(f"自动命名失败。{e}")
+                    filename = replace_special_symbols(user_question)[
+                        :16] + ".json"
+                return self.rename_chat_history(filename, chatbot, user_name)
+            elif name_chat_method == i18n("第一条提问"):
+                filename = replace_special_symbols(user_question)[
+                    :16] + ".json"
+                return self.rename_chat_history(filename, chatbot, user_name)
+            else:
+                return gr.update()
+        else:
+            return gr.update()
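
Aside: _decode_chat_response assumes the OpenAI server-sent-event framing, where each non-empty line looks like data: {...json...}; chunk[6:] strips the "data: " prefix, and the "delta" fragments are accumulated into the growing reply until finish_reason is "stop". A minimal offline sketch of that decoding loop (the byte strings are illustrative, not captured traffic):

import json

lines = [
    b'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": null}]}',
    b'data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}',
]
partial = ""
for raw in lines:
    chunk = json.loads(raw.decode()[6:])  # drop the "data: " prefix
    choice = chunk["choices"][0]
    if choice["finish_reason"] == "stop":
        break
    partial += choice["delta"].get("content", "")
print(partial)  # Hello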
modules/models/XMChat.py ADDED
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import base64
+import json
+import logging
+import os
+import uuid
+from io import BytesIO
+
+import requests
+from PIL import Image
+
+from ..index_func import *
+from ..presets import *
+from ..utils import *
+from .base_model import BaseLLMModel
+
+
+class XMChatClient(BaseLLMModel):
+    def __init__(self, api_key, user_name=""):
+        super().__init__(model_name="xmchat", user=user_name)
+        self.api_key = api_key
+        self.session_id = None
+        self.reset()
+        self.image_bytes = None
+        self.image_path = None
+        self.xm_history = []
+        self.url = "https://xmbot.net/web"
+        self.last_conv_id = None
+
+    def reset(self):
+        self.session_id = str(uuid.uuid4())
+        self.last_conv_id = None
+        return [], "已重置"
+
+    def image_to_base64(self, image_path):
+        # open and load the image
+        img = Image.open(image_path)
+
+        # get the image's width and height
+        width, height = img.size
+
+        # compute the scale ratio so the longest side does not exceed 2048 pixels
+        max_dimension = 2048
+        scale_ratio = min(max_dimension / width, max_dimension / height)
+
+        if scale_ratio < 1:
+            # resize the image by the scale ratio
+            new_width = int(width * scale_ratio)
+            new_height = int(height * scale_ratio)
+            img = img.resize((new_width, new_height), Image.ANTIALIAS)
+
+        # convert the image to JPEG binary data
+        buffer = BytesIO()
+        if img.mode == "RGBA":
+            img = img.convert("RGB")
+        img.save(buffer, format='JPEG')
+        binary_image = buffer.getvalue()
+
+        # Base64-encode the binary data
+        base64_image = base64.b64encode(binary_image).decode('utf-8')
+
+        return base64_image
+
+    def try_read_image(self, filepath):
+        def is_image_file(filepath):
+            # check whether the file is an image
+            valid_image_extensions = [
+                ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
+            file_extension = os.path.splitext(filepath)[1].lower()
+            return file_extension in valid_image_extensions
+
+        if is_image_file(filepath):
+            logging.info(f"读取图片文件: {filepath}")
+            self.image_bytes = self.image_to_base64(filepath)
+            self.image_path = filepath
+        else:
+            self.image_bytes = None
+            self.image_path = None
+
+    def like(self):
+        if self.last_conv_id is None:
+            return "点赞失败,你还没发送过消息"
+        data = {
+            "uuid": self.last_conv_id,
+            "appraise": "good"
+        }
+        requests.post(self.url, json=data)
+        return "👍点赞成功,感谢反馈~"
+
+    def dislike(self):
+        if self.last_conv_id is None:
+            return "点踩失败,你还没发送过消息"
+        data = {
+            "uuid": self.last_conv_id,
+            "appraise": "bad"
+        }
+        requests.post(self.url, json=data)
+        return "👎点踩成功,感谢反馈~"
+
+    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
+        fake_inputs = real_inputs
+        display_append = ""
+        limited_context = False
+        return limited_context, fake_inputs, display_append, real_inputs, chatbot
+
+    def handle_file_upload(self, files, chatbot, language):
+        """if the model accepts multi modal input, implement this function"""
+        if files:
+            for file in files:
+                if file.name:
+                    logging.info(f"尝试读取图像: {file.name}")
+                    self.try_read_image(file.name)
+            if self.image_path is not None:
+                chatbot = chatbot + [((self.image_path,), None)]
+            if self.image_bytes is not None:
+                logging.info("使用图片作为输入")
+                # XMChat can actually only handle one image per conversation round
+                self.reset()
+                conv_id = str(uuid.uuid4())
+                data = {
+                    "user_id": self.api_key,
+                    "session_id": self.session_id,
+                    "uuid": conv_id,
+                    "data_type": "imgbase64",
+                    "data": self.image_bytes
+                }
+                response = requests.post(self.url, json=data)
+                response = json.loads(response.text)
+                logging.info(f"图片回复: {response['data']}")
+        return None, chatbot, None
+
+    def get_answer_at_once(self):
+        question = self.history[-1]["content"]
+        conv_id = str(uuid.uuid4())
+        self.last_conv_id = conv_id
+        data = {
+            "user_id": self.api_key,
+            "session_id": self.session_id,
+            "uuid": conv_id,
+            "data_type": "text",
+            "data": question
+        }
+        response = requests.post(self.url, json=data)
+        try:
+            response = json.loads(response.text)
+            return response["data"], len(response["data"])
+        except Exception as e:
+            return response.text, len(response.text)
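
Aside: image_to_base64 above still uses Image.ANTIALIAS, which newer Pillow releases removed. A minimal sketch of the same preprocessing (downscale so the longest side is at most 2048 px, re-encode as JPEG, Base64-encode) written against current Pillow; the helper name to_b64_jpeg is hypothetical:

import base64
from io import BytesIO

from PIL import Image

def to_b64_jpeg(path: str, max_dim: int = 2048) -> str:
    img = Image.open(path)
    scale = min(max_dim / img.width, max_dim / img.height)
    if scale < 1:
        # LANCZOS replaces the removed ANTIALIAS constant in Pillow >= 10
        img = img.resize((int(img.width * scale), int(img.height * scale)),
                         Image.Resampling.LANCZOS)
    if img.mode == "RGBA":
        img = img.convert("RGB")  # JPEG has no alpha channel
    buf = BytesIO()
    img.save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")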
modules/models/models.py CHANGED
@@ -1,599 +1,19 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, List
 
 import logging
-import json
-import commentjson as cjson
 import os
-import sys
-import requests
-import urllib3
-import platform
-import base64
-from io import BytesIO
-from PIL import Image
 
-from tqdm import tqdm
 import colorama
-import asyncio
-import aiohttp
-from enum import Enum
-import uuid
+import commentjson as cjson
+
+from modules import config
 
-from ..presets import *
 from ..index_func import *
+from ..presets import *
 from ..utils import *
-from .. import shared
-from ..config import retrieve_proxy, usage_limit, sensitive_id
-from modules import config
 from .base_model import BaseLLMModel, ModelType
 
 
-class OpenAIClient(BaseLLMModel):
-    def __init__(
-        self,
-        model_name,
-        api_key,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
-        user_name=""
-    ) -> None:
-        super().__init__(
-            model_name=model_name,
-            temperature=temperature,
-            top_p=top_p,
-            system_prompt=system_prompt,
-            user=user_name
-        )
-        self.api_key = api_key
-        self.need_api_key = True
-        self._refresh_header()
-
-    def get_answer_stream_iter(self):
-        response = self._get_response(stream=True)
-        if response is not None:
-            iter = self._decode_chat_response(response)
-            partial_text = ""
-            for i in iter:
-                partial_text += i
-                yield partial_text
-        else:
-            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
-
-    def get_answer_at_once(self):
-        response = self._get_response()
-        response = json.loads(response.text)
-        content = response["choices"][0]["message"]["content"]
-        total_token_count = response["usage"]["total_tokens"]
-        return content, total_token_count
-
-    def count_token(self, user_input):
-        input_token_count = count_token(construct_user(user_input))
-        if self.system_prompt is not None and len(self.all_token_counts) == 0:
-            system_prompt_token_count = count_token(
-                construct_system(self.system_prompt)
-            )
-            return input_token_count + system_prompt_token_count
-        return input_token_count
-
-    def billing_info(self):
-        try:
-            curr_time = datetime.datetime.now()
-            last_day_of_month = get_last_day_of_month(
-                curr_time).strftime("%Y-%m-%d")
-            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
-            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
-            try:
-                usage_data = self._get_billing_data(usage_url)
-            except Exception as e:
-                # logging.error(f"获取API使用情况失败: " + str(e))
-                if "Invalid authorization header" in str(e):
-                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
-                elif "Incorrect API key provided: sess" in str(e):
-                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
-                return i18n("**获取API使用情况失败**")
-            # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
-            rounded_usage = round(usage_data["total_usage"] / 100, 5)
-            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
-            from ..webui import get_html
-            # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
-            return get_html("billing_info.html").format(
-                label = i18n("本月使用金额"),
-                usage_percent = usage_percent,
-                rounded_usage = rounded_usage,
-                usage_limit = usage_limit
-            )
-        except requests.exceptions.ConnectTimeout:
-            status_text = (
-                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
-            )
-            return status_text
-        except requests.exceptions.ReadTimeout:
-            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
-            return status_text
-        except Exception as e:
-            import traceback
-            traceback.print_exc()
-            logging.error(i18n("获取API使用情况失败:") + str(e))
-            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
-
-    def set_token_upper_limit(self, new_upper_limit):
-        pass
-
-    @shared.state.switching_api_key  # this decorator does nothing unless multi-account mode is enabled
-    def _get_response(self, stream=False):
-        openai_api_key = self.api_key
-        system_prompt = self.system_prompt
-        history = self.history
-        logging.debug(colorama.Fore.YELLOW +
-                      f"{history}" + colorama.Fore.RESET)
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {openai_api_key}",
-        }
-
-        if system_prompt is not None:
-            history = [construct_system(system_prompt), *history]
-
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "n": self.n_choices,
-            "stream": stream,
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-        }
-
-        if self.max_generation_token is not None:
-            payload["max_tokens"] = self.max_generation_token
-        if self.stop_sequence is not None:
-            payload["stop"] = self.stop_sequence
-        if self.logit_bias is not None:
-            payload["logit_bias"] = self.logit_bias
-        if self.user_identifier:
-            payload["user"] = self.user_identifier
-
-        if stream:
-            timeout = TIMEOUT_STREAMING
-        else:
-            timeout = TIMEOUT_ALL
-
-        # if a custom api-host is configured, send the request there; otherwise use the default
-        if shared.state.completion_url != COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
-
-        with retrieve_proxy():
-            try:
-                response = requests.post(
-                    shared.state.completion_url,
-                    headers=headers,
-                    json=payload,
-                    stream=stream,
-                    timeout=timeout,
-                )
-            except:
-                return None
-        return response
-
-    def _refresh_header(self):
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {sensitive_id}",
-        }
-
-
-    def _get_billing_data(self, billing_url):
-        with retrieve_proxy():
-            response = requests.get(
-                billing_url,
-                headers=self.headers,
-                timeout=TIMEOUT_ALL,
-            )
-
-        if response.status_code == 200:
-            data = response.json()
-            return data
-        else:
-            raise Exception(
-                f"API request failed with status code {response.status_code}: {response.text}"
-            )
-
-    def _decode_chat_response(self, response):
-        error_msg = ""
-        for chunk in response.iter_lines():
-            if chunk:
-                chunk = chunk.decode()
-                chunk_length = len(chunk)
-                try:
-                    chunk = json.loads(chunk[6:])
-                except:
-                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
-                    error_msg += chunk
-                    continue
-                if chunk_length > 6 and "delta" in chunk["choices"][0]:
-                    if chunk["choices"][0]["finish_reason"] == "stop":
-                        break
-                    try:
-                        yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        # logging.error(f"Error: {e}")
-                        continue
-        if error_msg:
-            raise Exception(error_msg)
-
-    def set_key(self, new_access_key):
-        ret = super().set_key(new_access_key)
-        self._refresh_header()
-        return ret
-
-    def _single_query_at_once(self, history, temperature=1.0):
-        timeout = TIMEOUT_ALL
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "temperature": f"{temperature}",
-        }
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-        }
-        # if a custom api-host is configured, send the request there; otherwise use the default
-        if shared.state.completion_url != COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
-
-        with retrieve_proxy():
-            response = requests.post(
-                shared.state.completion_url,
-                headers=headers,
-                json=payload,
-                stream=False,
-                timeout=timeout,
-            )
-
-        return response
-
-
-    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
-        if len(self.history) == 2 and not single_turn_checkbox:
-            user_question = self.history[0]["content"]
-            if name_chat_method == i18n("模型自动总结(消耗tokens)"):
-                ai_answer = self.history[1]["content"]
-                try:
-                    history = [
-                        { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
-                        { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
-                    ]
-                    response = self._single_query_at_once(history, temperature=0.0)
-                    response = json.loads(response.text)
-                    content = response["choices"][0]["message"]["content"]
-                    filename = replace_special_symbols(content) + ".json"
-                except Exception as e:
-                    logging.info(f"自动命名失败。{e}")
-                    filename = replace_special_symbols(user_question)[:16] + ".json"
-                return self.rename_chat_history(filename, chatbot, user_name)
-            elif name_chat_method == i18n("第一条提问"):
-                filename = replace_special_symbols(user_question)[:16] + ".json"
-                return self.rename_chat_history(filename, chatbot, user_name)
-            else:
-                return gr.update()
-        else:
-            return gr.update()
-
-
-class ChatGLM_Client(BaseLLMModel):
-    def __init__(self, model_name, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        from transformers import AutoTokenizer, AutoModel
-        import torch
-        global CHATGLM_TOKENIZER, CHATGLM_MODEL
-        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
-            system_name = platform.system()
-            model_path = None
-            if os.path.exists("models"):
-                model_dirs = os.listdir("models")
-                if model_name in model_dirs:
-                    model_path = f"models/{model_name}"
-            if model_path is not None:
-                model_source = model_path
-            else:
-                model_source = f"THUDM/{model_name}"
-            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
-                model_source, trust_remote_code=True
-            )
-            quantified = False
-            if "int4" in model_name:
-                quantified = True
-            model = AutoModel.from_pretrained(
-                model_source, trust_remote_code=True
-            )
-            if torch.cuda.is_available():
-                # run on CUDA
-                logging.info("CUDA is available, using CUDA")
-                model = model.half().cuda()
-            # MPS acceleration still has some issues, so it is not used for now
-            elif system_name == "Darwin" and model_path is not None and not quantified:
-                logging.info("Running on macOS, using MPS")
-                # running on macOS and model already downloaded
-                model = model.half().to("mps")
-            else:
-                logging.info("GPU is not available, using CPU")
-                model = model.float()
-            model = model.eval()
-            CHATGLM_MODEL = model
-
-    def _get_glm_style_input(self):
-        history = [x["content"] for x in self.history]
-        query = history.pop()
-        logging.debug(colorama.Fore.YELLOW +
-                      f"{history}" + colorama.Fore.RESET)
-        assert (
-            len(history) % 2 == 0
-        ), f"History should be even length. current history is: {history}"
-        history = [[history[i], history[i + 1]]
-                   for i in range(0, len(history), 2)]
-        return history, query
-
-    def get_answer_at_once(self):
-        history, query = self._get_glm_style_input()
-        response, _ = CHATGLM_MODEL.chat(
-            CHATGLM_TOKENIZER, query, history=history)
-        return response, len(response)
-
-    def get_answer_stream_iter(self):
-        history, query = self._get_glm_style_input()
-        for response, history in CHATGLM_MODEL.stream_chat(
-            CHATGLM_TOKENIZER,
-            query,
-            history,
-            max_length=self.token_upper_limit,
-            top_p=self.top_p,
-            temperature=self.temperature,
-        ):
-            yield response
-
-
-class LLaMA_Client(BaseLLMModel):
-    def __init__(
-        self,
-        model_name,
-        lora_path=None,
-        user_name=""
-    ) -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        from lmflow.datasets.dataset import Dataset
-        from lmflow.pipeline.auto_pipeline import AutoPipeline
-        from lmflow.models.auto_model import AutoModel
-        from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
-
-        self.max_generation_token = 1000
-        self.end_string = "\n\n"
-        # We don't need input data
-        data_args = DatasetArguments(dataset_path=None)
-        self.dataset = Dataset(data_args)
-        self.system_prompt = ""
-
-        global LLAMA_MODEL, LLAMA_INFERENCER
-        if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
-            model_path = None
-            if os.path.exists("models"):
-                model_dirs = os.listdir("models")
-                if model_name in model_dirs:
-                    model_path = f"models/{model_name}"
-            if model_path is not None:
-                model_source = model_path
-            else:
-                model_source = f"decapoda-research/{model_name}"
-                # raise Exception(f"models目录下没有这个模型: {model_name}")
-            if lora_path is not None:
-                lora_path = f"lora/{lora_path}"
-            model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
-                                        use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
-            pipeline_args = InferencerArguments(
-                local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
-
-            with open(pipeline_args.deepspeed, "r", encoding="utf-8") as f:
-                ds_config = json.load(f)
-            LLAMA_MODEL = AutoModel.get_model(
-                model_args,
-                tune_strategy="none",
-                ds_config=ds_config,
-            )
-            LLAMA_INFERENCER = AutoPipeline.get_pipeline(
-                pipeline_name="inferencer",
-                model_args=model_args,
-                data_args=data_args,
-                pipeline_args=pipeline_args,
-            )
-
-    def _get_llama_style_input(self):
-        history = []
-        instruction = ""
-        if self.system_prompt:
-            instruction = (f"Instruction: {self.system_prompt}\n")
-        for x in self.history:
-            if x["role"] == "user":
-                history.append(f"{instruction}Input: {x['content']}")
-            else:
-                history.append(f"Output: {x['content']}")
-        context = "\n\n".join(history)
-        context += "\n\nOutput: "
-        return context
-
-    def get_answer_at_once(self):
-        context = self._get_llama_style_input()
-
-        input_dataset = self.dataset.from_dict(
-            {"type": "text_only", "instances": [{"text": context}]}
-        )
-
-        output_dataset = LLAMA_INFERENCER.inference(
-            model=LLAMA_MODEL,
-            dataset=input_dataset,
-            max_new_tokens=self.max_generation_token,
-            temperature=self.temperature,
-        )
-
-        response = output_dataset.to_dict()["instances"][0]["text"]
-        return response, len(response)
-
-    def get_answer_stream_iter(self):
-        context = self._get_llama_style_input()
-        partial_text = ""
-        step = 1
-        for _ in range(0, self.max_generation_token, step):
-            input_dataset = self.dataset.from_dict(
-                {"type": "text_only", "instances": [
-                    {"text": context + partial_text}]}
-            )
-            output_dataset = LLAMA_INFERENCER.inference(
-                model=LLAMA_MODEL,
-                dataset=input_dataset,
-                max_new_tokens=step,
-                temperature=self.temperature,
-            )
-            response = output_dataset.to_dict()["instances"][0]["text"]
-            if response == "" or response == self.end_string:
-                break
-            partial_text += response
-            yield partial_text
-
-
-class XMChat(BaseLLMModel):
-    def __init__(self, api_key, user_name=""):
-        super().__init__(model_name="xmchat", user=user_name)
-        self.api_key = api_key
-        self.session_id = None
-        self.reset()
-        self.image_bytes = None
-        self.image_path = None
-        self.xm_history = []
-        self.url = "https://xmbot.net/web"
-        self.last_conv_id = None
-
-    def reset(self):
-        self.session_id = str(uuid.uuid4())
-        self.last_conv_id = None
-        return [], "已重置"
-
-    def image_to_base64(self, image_path):
-        # open and load the image
-        img = Image.open(image_path)
-
-        # get the image's width and height
-        width, height = img.size
-
-        # compute the scale ratio so the longest side does not exceed 2048 pixels
-        max_dimension = 2048
-        scale_ratio = min(max_dimension / width, max_dimension / height)
-
-        if scale_ratio < 1:
-            # resize the image by the scale ratio
-            new_width = int(width * scale_ratio)
-            new_height = int(height * scale_ratio)
-            img = img.resize((new_width, new_height), Image.ANTIALIAS)
-
-        # convert the image to JPEG binary data
-        buffer = BytesIO()
-        if img.mode == "RGBA":
-            img = img.convert("RGB")
-        img.save(buffer, format='JPEG')
-        binary_image = buffer.getvalue()
-
-        # Base64-encode the binary data
-        base64_image = base64.b64encode(binary_image).decode('utf-8')
-
-        return base64_image
-
-    def try_read_image(self, filepath):
-        def is_image_file(filepath):
-            # check whether the file is an image
-            valid_image_extensions = [
-                ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
-            file_extension = os.path.splitext(filepath)[1].lower()
-            return file_extension in valid_image_extensions
-
-        if is_image_file(filepath):
-            logging.info(f"读取图片文件: {filepath}")
-            self.image_bytes = self.image_to_base64(filepath)
-            self.image_path = filepath
-        else:
-            self.image_bytes = None
-            self.image_path = None
-
-    def like(self):
-        if self.last_conv_id is None:
-            return "点赞失败,你还没发送过消息"
-        data = {
-            "uuid": self.last_conv_id,
-            "appraise": "good"
-        }
-        requests.post(self.url, json=data)
-        return "👍点赞成功,感谢反馈~"
-
-    def dislike(self):
-        if self.last_conv_id is None:
-            return "点踩失败,你还没发送过消息"
-        data = {
-            "uuid": self.last_conv_id,
-            "appraise": "bad"
-        }
-        requests.post(self.url, json=data)
-        return "👎点踩成功,感谢反馈~"
-
-    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
-        fake_inputs = real_inputs
-        display_append = ""
-        limited_context = False
-        return limited_context, fake_inputs, display_append, real_inputs, chatbot
-
-    def handle_file_upload(self, files, chatbot, language):
-        """if the model accepts multi modal input, implement this function"""
-        if files:
-            for file in files:
-                if file.name:
-                    logging.info(f"尝试读取图像: {file.name}")
-                    self.try_read_image(file.name)
-            if self.image_path is not None:
-                chatbot = chatbot + [((self.image_path,), None)]
-            if self.image_bytes is not None:
-                logging.info("使用图片作为输入")
-                # XMChat can actually only handle one image per conversation round
-                self.reset()
-                conv_id = str(uuid.uuid4())
-                data = {
-                    "user_id": self.api_key,
-                    "session_id": self.session_id,
-                    "uuid": conv_id,
-                    "data_type": "imgbase64",
-                    "data": self.image_bytes
-                }
-                response = requests.post(self.url, json=data)
-                response = json.loads(response.text)
-                logging.info(f"图片回复: {response['data']}")
-        return None, chatbot, None
-
-    def get_answer_at_once(self):
-        question = self.history[-1]["content"]
-        conv_id = str(uuid.uuid4())
-        self.last_conv_id = conv_id
-        data = {
-            "user_id": self.api_key,
-            "session_id": self.session_id,
-            "uuid": conv_id,
-            "data_type": "text",
-            "data": question
-        }
-        response = requests.post(self.url, json=data)
-        try:
-            response = json.loads(response.text)
-            return response["data"], len(response["data"])
-        except Exception as e:
-            return response.text, len(response.text)
-
-
 def get_model(
     model_name,
     lora_model_path=None,
@@ -616,6 +36,7 @@ def get_model(
     try:
         if model_type == ModelType.OpenAI:
            logging.info(f"正在加载OpenAI模型: {model_name}")
+            from .OpenAI import OpenAIClient
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAIClient(
                model_name=model_name,
@@ -627,6 +48,7 @@ def get_model(
            )
        elif model_type == ModelType.ChatGLM:
            logging.info(f"正在加载ChatGLM模型: {model_name}")
+            from .ChatGLM import ChatGLM_Client
            model = ChatGLM_Client(model_name, user_name=user_name)
        elif model_type == ModelType.LLaMA and lora_model_path == "":
            msg = f"现在请为 {model_name} 选择LoRA模型"
@@ -637,6 +59,7 @@ def get_model(
            lora_choices = ["No LoRA"] + lora_choices
        elif model_type == ModelType.LLaMA and lora_model_path != "":
            logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
+            from .LLaMA import LLaMA_Client
            dont_change_lora_selector = True
            if lora_model_path == "No LoRA":
                lora_model_path = None
@@ -646,9 +69,10 @@ def get_model(
            model = LLaMA_Client(
                model_name, lora_model_path, user_name=user_name)
        elif model_type == ModelType.XMChat:
+            from .XMChat import XMChatClient
            if os.environ.get("XMCHAT_API_KEY") != "":
                access_key = os.environ.get("XMCHAT_API_KEY")
-            model = XMChat(api_key=access_key, user_name=user_name)
+            model = XMChatClient(api_key=access_key, user_name=user_name)
        elif model_type == ModelType.StableLM:
            from .StableLM import StableLM_Client
            model = StableLM_Client(model_name, user_name=user_name)
@@ -657,29 +81,35 @@ def get_model(
            model = MOSS_Client(model_name, user_name=user_name)
        elif model_type == ModelType.YuanAI:
            from .inspurai import Yuan_Client
-            model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
+            model = Yuan_Client(model_name, api_key=access_key,
+                                user_name=user_name, system_prompt=system_prompt)
        elif model_type == ModelType.Minimax:
            from .minimax import MiniMax_Client
            if os.environ.get("MINIMAX_API_KEY") != "":
                access_key = os.environ.get("MINIMAX_API_KEY")
-            model = MiniMax_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
+            model = MiniMax_Client(
+                model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
        elif model_type == ModelType.ChuanhuAgent:
            from .ChuanhuAgent import ChuanhuAgent_Client
-            model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
+            model = ChuanhuAgent_Client(
+                model_name, access_key, user_name=user_name)
        elif model_type == ModelType.GooglePaLM:
-            from .Google_PaLM import Google_PaLM_Client
+            from .GooglePaLM import Google_PaLM_Client
            access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
-            model = Google_PaLM_Client(model_name, access_key, user_name=user_name)
+            model = Google_PaLM_Client(
+                model_name, access_key, user_name=user_name)
        elif model_type == ModelType.LangchainChat:
-            from .azure import Azure_OpenAI_Client
+            from .Azure import Azure_OpenAI_Client
            model = Azure_OpenAI_Client(model_name, user_name=user_name)
        elif model_type == ModelType.Midjourney:
            from .midjourney import Midjourney_Client
            mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
-            model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
+            model = Midjourney_Client(
+                model_name, mj_proxy_api_secret, user_name=user_name)
        elif model_type == ModelType.Spark:
            from .spark import Spark_Client
-            model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv("SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
+            model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
+                "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
        elif model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
        logging.info(msg)
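
Aside: the net effect of this commit is that models.py keeps only get_model, and each client class is imported lazily inside its dispatch branch, so heavy optional dependencies (torch, transformers, lmflow, ...) load only for the backend actually selected. The pattern in miniature (illustrative module and function names, not the project's actual layout):

def get_client_class(model_type: str):
    if model_type == "openai":
        from models.OpenAI import OpenAIClient     # imported only on this branch
        return OpenAIClient
    elif model_type == "chatglm":
        from models.ChatGLM import ChatGLM_Client  # torch/transformers load here
        return ChatGLM_Client
    raise ValueError(f"unknown model type: {model_type}")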