Tuchuanhuhuhu committed
Commit 8fdf34e · 1 parent: f079043

Add GPT Index

Files changed (7):
  1. ChuanhuChatbot.py +7 -2
  2. chat_func.py +447 -0
  3. llama_func.py +201 -0
  4. overwrites.py +97 -0
  5. presets.py +47 -15
  6. requirements.txt +3 -1
  7. utils.py +39 -405
ChuanhuChatbot.py CHANGED
@@ -6,9 +6,11 @@ import sys
 import argparse
 from utils import *
 from presets import *
+from overwrites import *
+from chat_func import *

 logging.basicConfig(
-    level=logging.INFO,
+    level=logging.DEBUG,
     format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
 )

@@ -49,6 +51,7 @@ else:
     authflag = True

 gr.Chatbot.postprocess = postprocess
+PromptHelper.compact_text_chunks = compact_text_chunks

 with open("custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()
@@ -165,7 +168,7 @@ with gr.Blocks(
             label="实时传输回答", value=True, visible=enable_streaming_option
         )
         use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
-        index_files = gr.File(label="上传索引文件", type="file", multiple=True)
+        index_files = gr.Files(label="上传索引文件", type="file", multiple=True)

         with gr.Tab(label="Prompt"):
             systemPromptTxt = gr.Textbox(
@@ -286,6 +289,7 @@ with gr.Blocks(
             use_streaming_checkbox,
             model_select_dropdown,
             use_websearch_checkbox,
+            index_files
         ],
         [chatbot, history, status_display, token_count],
         show_progress=True,
@@ -306,6 +310,7 @@ with gr.Blocks(
             use_streaming_checkbox,
             model_select_dropdown,
             use_websearch_checkbox,
+            index_files
         ],
         [chatbot, history, status_display, token_count],
         show_progress=True,
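
Both patched names above are plain monkey-patches: postprocess (defined in overwrites.py below) and compact_text_chunks are free functions whose first parameter is self, so assigning them onto gr.Chatbot and PromptHelper swaps the method for every instance. A minimal sketch of the pattern, using a hypothetical Greeter class rather than anything from this repo:

class Greeter:
    def greet(self):
        return "hello"

def greet_loudly(self):
    # a free function; `self` binds normally once the function lives on the class
    return "HELLO!"

Greeter.greet = greet_loudly  # the same move as gr.Chatbot.postprocess = postprocess
print(Greeter().greet())  # prints HELLO!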
chat_func.py ADDED
@@ -0,0 +1,447 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
+import logging
+import json
+import gradio as gr
+
+# import openai
+import os
+import traceback
+import requests
+
+# import markdown
+import csv
+import mdtex2html
+from pypinyin import lazy_pinyin
+from presets import *
+from llama_func import *
+from utils import *
+import tiktoken
+from tqdm import tqdm
+import colorama
+from llama_index import (
+    GPTSimpleVectorIndex,
+    GPTTreeIndex,
+    GPTKeywordTableIndex,
+    GPTListIndex,
+)
+from llama_index import SimpleDirectoryReader, download_loader
+from llama_index import (
+    Document,
+    LLMPredictor,
+    PromptHelper,
+    QuestionAnswerPrompt,
+    RefinePrompt,
+)
+from langchain.llms import OpenAIChat, OpenAI
+from duckduckgo_search import ddg
+import datetime
+
+# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
+
+if TYPE_CHECKING:
+    from typing import TypedDict
+
+    class DataframeData(TypedDict):
+        headers: List[str]
+        data: List[List[str | int | bool]]
+
+
+initial_prompt = "You are a helpful assistant."
+API_URL = "https://api.openai.com/v1/chat/completions"
+HISTORY_DIR = "history"
+TEMPLATES_DIR = "templates"
+
+
+def get_response(
+    openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
+):
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {openai_api_key}",
+    }
+
+    history = [construct_system(system_prompt), *history]
+
+    payload = {
+        "model": selected_model,
+        "messages": history,  # [{"role": "user", "content": f"{inputs}"}],
+        "temperature": temperature,  # 1.0,
+        "top_p": top_p,  # 1.0,
+        "n": 1,
+        "stream": stream,
+        "presence_penalty": 0,
+        "frequency_penalty": 0,
+    }
+    if stream:
+        timeout = timeout_streaming
+    else:
+        timeout = timeout_all
+
+    # Read proxy settings from the environment
+    http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
+    https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
+
+    # Use them if present
+    proxies = {}
+    if http_proxy:
+        logging.info(f"Using HTTP proxy: {http_proxy}")
+        proxies["http"] = http_proxy
+    if https_proxy:
+        logging.info(f"Using HTTPS proxy: {https_proxy}")
+        proxies["https"] = https_proxy
+
+    # Send the request through the proxy if one is configured, otherwise with the defaults
+    if proxies:
+        response = requests.post(
+            API_URL,
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=timeout,
+            proxies=proxies,
+        )
+    else:
+        response = requests.post(
+            API_URL,
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=timeout,
+        )
+    return response
+
+
+def stream_predict(
+    openai_api_key,
+    system_prompt,
+    history,
+    inputs,
+    chatbot,
+    all_token_counts,
+    top_p,
+    temperature,
+    selected_model,
+):
+    def get_return_value():
+        return chatbot, history, status_text, all_token_counts
+
+    logging.info("实时回答模式")
+    partial_words = ""
+    counter = 0
+    status_text = "开始实时传输回答……"
+    history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    chatbot.append((parse_text(inputs), ""))
+    user_token_count = 0
+    if len(all_token_counts) == 0:
+        system_prompt_token_count = count_token(construct_system(system_prompt))
+        user_token_count = (
+            count_token(construct_user(inputs)) + system_prompt_token_count
+        )
+    else:
+        user_token_count = count_token(construct_user(inputs))
+    all_token_counts.append(user_token_count)
+    logging.info(f"输入token计数: {user_token_count}")
+    yield get_return_value()
+    try:
+        response = get_response(
+            openai_api_key,
+            system_prompt,
+            history,
+            temperature,
+            top_p,
+            True,
+            selected_model,
+        )
+    except requests.exceptions.ConnectTimeout:
+        status_text = (
+            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
+        )
+        yield get_return_value()
+        return
+    except requests.exceptions.ReadTimeout:
+        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
+        yield get_return_value()
+        return
+
+    yield get_return_value()
+    error_json_str = ""
+
+    for chunk in tqdm(response.iter_lines()):
+        if counter == 0:
+            counter += 1
+            continue
+        counter += 1
+        # check whether each line is non-empty
+        if chunk:
+            chunk = chunk.decode()
+            chunklength = len(chunk)
+            try:
+                chunk = json.loads(chunk[6:])
+            except json.JSONDecodeError:
+                logging.info(chunk)
+                error_json_str += chunk
+                status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
+                yield get_return_value()
+                continue
+            # decode each line as response data is in bytes
+            if chunklength > 6 and "delta" in chunk["choices"][0]:
+                finish_reason = chunk["choices"][0]["finish_reason"]
+                status_text = construct_token_message(
+                    sum(all_token_counts), stream=True
+                )
+                if finish_reason == "stop":
+                    yield get_return_value()
+                    break
+                try:
+                    partial_words = (
+                        partial_words + chunk["choices"][0]["delta"]["content"]
+                    )
+                except KeyError:
+                    status_text = (
+                        standard_error_msg
+                        + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
+                        + str(sum(all_token_counts))
+                    )
+                    yield get_return_value()
+                    break
+                history[-1] = construct_assistant(partial_words)
+                chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
+                all_token_counts[-1] += 1
+                yield get_return_value()
+
+
+def predict_all(
+    openai_api_key,
+    system_prompt,
+    history,
+    inputs,
+    chatbot,
+    all_token_counts,
+    top_p,
+    temperature,
+    selected_model,
+):
+    logging.info("一次性回答模式")
+    history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    chatbot.append((parse_text(inputs), ""))
+    all_token_counts.append(count_token(construct_user(inputs)))
+    try:
+        response = get_response(
+            openai_api_key,
+            system_prompt,
+            history,
+            temperature,
+            top_p,
+            False,
+            selected_model,
+        )
+    except requests.exceptions.ConnectTimeout:
+        status_text = (
+            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
+        )
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.ProxyError:
+        status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.SSLError:
+        status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
+    response = json.loads(response.text)
+    content = response["choices"][0]["message"]["content"]
+    history[-1] = construct_assistant(content)
+    chatbot[-1] = (parse_text(inputs), parse_text(content))
+    total_token_count = response["usage"]["total_tokens"]
+    all_token_counts[-1] = total_token_count - sum(all_token_counts)
+    status_text = construct_token_message(total_token_count)
+    return chatbot, history, status_text, all_token_counts
+
+
+def predict(
+    openai_api_key,
+    system_prompt,
+    history,
+    inputs,
+    chatbot,
+    all_token_counts,
+    top_p,
+    temperature,
+    stream=False,
+    selected_model=MODELS[0],
+    use_websearch_checkbox=False,
+    files=None,
+    should_check_token_count=True,
+):  # repetition_penalty, top_k
+    logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
+    if files:
+        msg = "构建索引中……(这可能需要比较久的时间)"
+        logging.info(msg)
+        yield chatbot, history, msg, all_token_counts
+        index = construct_index(openai_api_key, file_src=files)
+        msg = "索引构建完成,获取回答中……"
+        yield chatbot, history, msg, all_token_counts
+        history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
+        yield chatbot, history, status_text, all_token_counts
+        return
+    if use_websearch_checkbox:
+        results = ddg(inputs, max_results=3)
+        web_results = []
+        for idx, result in enumerate(results):
+            logging.info(f"搜索结果{idx + 1}:{result}")
+            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
+        web_results = "\n\n".join(web_results)
+        inputs = (
+            replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
+            .replace("{query}", inputs)
+            .replace("{web_results}", web_results)
+        )
+    if len(openai_api_key) != 51:
+        status_text = standard_error_msg + no_apikey_msg
+        logging.info(status_text)
+        chatbot.append((parse_text(inputs), ""))
+        if len(history) == 0:
+            history.append(construct_user(inputs))
+            history.append("")
+            all_token_counts.append(0)
+        else:
+            history[-2] = construct_user(inputs)
+        yield chatbot, history, status_text, all_token_counts
+        return
+    if stream:
+        yield chatbot, history, "开始生成回答……", all_token_counts
+    if stream:
+        logging.info("使用流式传输")
+        iter = stream_predict(
+            openai_api_key,
+            system_prompt,
+            history,
+            inputs,
+            chatbot,
+            all_token_counts,
+            top_p,
+            temperature,
+            selected_model,
+        )
+        for chatbot, history, status_text, all_token_counts in iter:
+            yield chatbot, history, status_text, all_token_counts
+    else:
+        logging.info("不使用流式传输")
+        chatbot, history, status_text, all_token_counts = predict_all(
+            openai_api_key,
+            system_prompt,
+            history,
+            inputs,
+            chatbot,
+            all_token_counts,
+            top_p,
+            temperature,
+            selected_model,
+        )
+        yield chatbot, history, status_text, all_token_counts
+    logging.info(f"传输完毕。当前token计数为{all_token_counts}")
+    if len(history) > 1 and history[-1]["content"] != inputs:
+        logging.info(
+            "回答为:"
+            + colorama.Fore.BLUE
+            + f"{history[-1]['content']}"
+            + colorama.Style.RESET_ALL
+        )
+    if stream:
+        max_token = max_token_streaming
+    else:
+        max_token = max_token_all
+    if sum(all_token_counts) > max_token and should_check_token_count:
+        status_text = f"精简token中{all_token_counts}/{max_token}"
+        logging.info(status_text)
+        yield chatbot, history, status_text, all_token_counts
+        iter = reduce_token_size(
+            openai_api_key,
+            system_prompt,
+            history,
+            chatbot,
+            all_token_counts,
+            top_p,
+            temperature,
+            stream=False,
+            selected_model=selected_model,
+            hidden=True,
+        )
+        for chatbot, history, status_text, all_token_counts in iter:
+            status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
+            yield chatbot, history, status_text, all_token_counts
+
+
+def retry(
+    openai_api_key,
+    system_prompt,
+    history,
+    chatbot,
+    token_count,
+    top_p,
+    temperature,
+    stream=False,
+    selected_model=MODELS[0],
+):
+    logging.info("重试中……")
+    if len(history) == 0:
+        yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
+        return
+    history.pop()
+    inputs = history.pop()["content"]
+    token_count.pop()
+    iter = predict(
+        openai_api_key,
+        system_prompt,
+        history,
+        inputs,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        stream=stream,
+        selected_model=selected_model,
+    )
+    logging.info("重试完毕")
+    for x in iter:
+        yield x
+
+
+def reduce_token_size(
+    openai_api_key,
+    system_prompt,
+    history,
+    chatbot,
+    token_count,
+    top_p,
+    temperature,
+    stream=False,
+    selected_model=MODELS[0],
+    hidden=False,
+):
+    logging.info("开始减少token数量……")
+    iter = predict(
+        openai_api_key,
+        system_prompt,
+        history,
+        summarize_prompt,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        stream=stream,
+        selected_model=selected_model,
+        should_check_token_count=False,
+    )
+    logging.info(f"chatbot: {chatbot}")
+    for chatbot, history, status_text, previous_token_count in iter:
+        history = history[-2:]
+        token_count = previous_token_count[-1:]
+        if hidden:
+            chatbot.pop()
+        yield chatbot, history, construct_token_message(
+            sum(token_count), stream=stream
+        ), token_count
+    logging.info("减少token数量完毕")
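
predict, retry, and reduce_token_size are all generators that yield a (chatbot, history, status_text, all_token_counts) tuple after every step, which is what lets the Gradio front end repaint the chat while tokens stream in. A rough usage sketch outside Gradio (the sk-... key is a placeholder, and presets.py/utils.py must be importable):

chatbot, history, token_counts = [], [], []
for chatbot, history, status, token_counts in predict(
    openai_api_key="sk-...",  # placeholder, substitute a real key
    system_prompt="You are a helpful assistant.",
    history=history,
    inputs="你好",
    chatbot=chatbot,
    all_token_counts=token_counts,
    top_p=1.0,
    temperature=1.0,
    stream=True,
):
    print(status)  # "开始实时传输回答……", then a running token count
print(history[-1]["content"])  # the assistant's final reply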
llama_func.py ADDED
@@ -0,0 +1,201 @@
+import os
+import hashlib
+import logging
+import sys
+
+from llama_index import (
+    GPTSimpleVectorIndex,
+    GPTTreeIndex,
+    GPTKeywordTableIndex,
+    GPTListIndex,
+)
+from llama_index import SimpleDirectoryReader, download_loader
+from llama_index import (
+    Document,
+    LLMPredictor,
+    PromptHelper,
+    QuestionAnswerPrompt,
+    RefinePrompt,
+)
+from langchain.llms import OpenAIChat, OpenAI
+from duckduckgo_search import ddg
+import colorama
+
+from presets import *
+from utils import *
+
+
+def get_documents(file_src):
+    documents = []
+    index_name = ""
+    logging.debug("Loading documents...")
+    logging.debug(f"file_src: {file_src}")
+    for file in file_src:
+        logging.debug(f"file: {file.name}")
+        index_name += file.name
+        if os.path.splitext(file.name)[1] == ".pdf":
+            logging.debug("Loading PDF...")
+            CJKPDFReader = download_loader("CJKPDFReader")
+            loader = CJKPDFReader()
+            documents += loader.load_data(file=file.name)
+        elif os.path.splitext(file.name)[1] == ".docx":
+            logging.debug("Loading DOCX...")
+            DocxReader = download_loader("DocxReader")
+            loader = DocxReader()
+            documents += loader.load_data(file=file.name)
+        elif os.path.splitext(file.name)[1] == ".epub":
+            logging.debug("Loading EPUB...")
+            EpubReader = download_loader("EpubReader")
+            loader = EpubReader()
+            documents += loader.load_data(file=file.name)
+        else:
+            logging.debug("Loading text file...")
+            with open(file.name, "r", encoding="utf-8") as f:
+                text = add_space(f.read())
+                documents += [Document(text)]
+    # index_name is a concatenation of file names, not a path on disk,
+    # so hash the string itself rather than passing it to sha1sum()
+    index_name = hashlib.sha1(index_name.encode("utf-8")).hexdigest()
+    return documents, index_name
+
+
+def construct_index(
+    api_key,
+    file_src,
+    max_input_size=4096,
+    num_outputs=1,
+    max_chunk_overlap=20,
+    chunk_size_limit=600,
+    embedding_limit=None,
+    separator=" ",
+    num_children=10,
+    max_keywords_per_chunk=10,
+):
+    os.environ["OPENAI_API_KEY"] = api_key
+    chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
+    embedding_limit = None if embedding_limit == 0 else embedding_limit
+    separator = " " if separator == "" else separator
+
+    llm_predictor = LLMPredictor(
+        llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
+    )
+    prompt_helper = PromptHelper(
+        max_input_size,
+        num_outputs,
+        max_chunk_overlap,
+        embedding_limit,
+        chunk_size_limit,
+        separator=separator,
+    )
+    documents, index_name = get_documents(file_src)
+    if os.path.exists(f"./index/{index_name}.json"):
+        logging.info("找到了缓存的索引文件,加载中……")
+        return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
+    else:
+        try:
+            logging.debug("构建索引中……")
+            index = GPTSimpleVectorIndex(
+                documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
+            )
+            os.makedirs("./index", exist_ok=True)
+            index.save_to_disk(f"./index/{index_name}.json")
+            return index
+        except Exception as e:
+            logging.error(e)
+            return None
+
+
+def chat_ai(
+    api_key,
+    index,
+    question,
+    context,
+    chatbot,
+):
+    os.environ["OPENAI_API_KEY"] = api_key
+
+    logging.info(f"Question: {question}")
+
+    response, status_text = ask_ai(
+        api_key,
+        index,
+        question,
+        replace_today(PROMPT_TEMPLATE),
+        REFINE_TEMPLATE,
+        SIM_K,
+        INDEX_QUERY_TEMPRATURE,
+        context,
+    )
+    if response is None:
+        status_text = "查询失败,请换个问法试试"
+        # return three values so callers can unpack consistently
+        return context, chatbot, status_text
+
+    context.append({"role": "user", "content": question})
+    context.append({"role": "assistant", "content": response})
+    chatbot.append((question, response))
+
+    os.environ["OPENAI_API_KEY"] = ""
+    return context, chatbot, status_text
+
+
+def ask_ai(
+    api_key,
+    index,
+    question,
+    prompt_tmpl,
+    refine_tmpl,
+    sim_k=1,
+    temprature=0,
+    prefix_messages=[],
+):
+    os.environ["OPENAI_API_KEY"] = api_key
+
+    logging.debug("Index file found")
+    logging.debug("Querying index...")
+    llm_predictor = LLMPredictor(
+        llm=OpenAI(
+            temperature=temprature,
+            model_name="gpt-3.5-turbo-0301",
+            prefix_messages=prefix_messages,
+        )
+    )
+
+    response = None  # initialize to avoid UnboundLocalError
+    qa_prompt = QuestionAnswerPrompt(prompt_tmpl)
+    rf_prompt = RefinePrompt(refine_tmpl)
+    response = index.query(
+        question,
+        llm_predictor=llm_predictor,
+        similarity_top_k=sim_k,
+        text_qa_template=qa_prompt,
+        refine_template=rf_prompt,
+        response_mode="compact",
+    )
+
+    if response is not None:
+        logging.info(f"Response: {response}")
+        ret_text = response.response
+        ret_text += "\n----------\n"
+        nodes = []
+        for idx, node in enumerate(response.source_nodes):
+            brief = node.source_text[:25].replace("\n", "")
+            nodes.append(
+                f"<details><summary>[{idx+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
+            )
+        ret_text += "\n\n".join(nodes)
+        logging.info(
+            f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
+        )
+        os.environ["OPENAI_API_KEY"] = ""
+        return ret_text, f"查询消耗了{llm_predictor.last_token_usage} tokens"
+    else:
+        logging.warning("No response found, returning None")
+        os.environ["OPENAI_API_KEY"] = ""
+        # return a pair so callers unpacking (response, status_text) don't crash
+        return None, ""
+
+
+def add_space(text):
+    punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
+    for cn_punc, en_punc in punctuations.items():
+        text = text.replace(cn_punc, en_punc)
+    return text
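
End to end: get_documents picks a loader by file extension, construct_index builds a GPTSimpleVectorIndex and caches it under ./index/<sha1>.json, and ask_ai queries it with the templates from presets.py. A hedged sketch of driving this by hand; notes.txt is a hypothetical local file, and the SimpleNamespace wrapper just mimics the .name attribute of Gradio's upload objects:

from types import SimpleNamespace

files = [SimpleNamespace(name="notes.txt")]  # stand-in for a Gradio upload
index = construct_index("sk-...", file_src=files)  # placeholder key
if index is not None:
    answer, usage = ask_ai(
        "sk-...",
        index,
        "文档的主要结论是什么?",
        replace_today(PROMPT_TEMPLATE),
        REFINE_TEMPLATE,
        sim_k=SIM_K,
        temprature=INDEX_QUERY_TEMPRATURE,
    )
    print(answer)
    print(usage)  # e.g. "查询消耗了1234 tokens"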
overwrites.py ADDED
@@ -0,0 +1,97 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
+import os
+import re
+import sys
+import csv
+import json
+import logging
+import datetime
+import traceback
+from pathlib import Path
+
+import requests
+import gradio as gr
+import mdtex2html
+import tiktoken
+import colorama
+from tqdm import tqdm
+from pypinyin import lazy_pinyin
+from duckduckgo_search import ddg
+
+import llama_index
+from llama_index import (
+    Document,
+    LLMPredictor,
+    Prompt,
+    PromptHelper,
+    QuestionAnswerPrompt,
+    RefinePrompt,
+    GPTListIndex,
+    GPTTreeIndex,
+    GPTKeywordTableIndex,
+    GPTSimpleVectorIndex,
+    SimpleDirectoryReader,
+    download_loader,
+)
+from llama_index.composability import ComposableGraph
+from langchain.llms import OpenAIChat, OpenAI
+
+from presets import *
+from llama_func import *
+
+
+def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
+    logging.debug("Compacting text chunks...🚀🚀🚀")
+    combined_str = [c.strip() for c in text_chunks if c.strip()]
+    combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
+    combined_str = "\n\n".join(combined_str)
+    # resplit based on self.max_chunk_overlap
+    text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
+    return text_splitter.split_text(combined_str)
+
+
+def postprocess(
+    self, y: List[Tuple[str | None, str | None]]
+) -> List[Tuple[str | None, str | None]]:
+    """
+    Parameters:
+        y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
+    Returns:
+        List of tuples representing the message and response. Each message and response will be a string of HTML.
+    """
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            # None if message is None else markdown.markdown(message),
+            # None if response is None else markdown.markdown(response),
+            None if message is None else message,
+            None if response is None else mdtex2html.convert(response, extensions=['fenced_code', 'codehilite', 'tables']),
+        )
+    return y
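
The point of the compact_text_chunks override is that every retrieved chunk gets a [n] label before being packed into the prompt, so the model can cite sources in the [number] notation that PROMPT_TEMPLATE (presets.py, below) asks for. The string-munging half of the function is easy to check standalone:

text_chunks = ["  第一段内容\n", "", "第二段内容"]
combined = [c.strip() for c in text_chunks if c.strip()]  # drop empty chunks
combined = [f"[{i+1}] {c}" for i, c in enumerate(combined)]  # label for citation
print("\n\n".join(combined))
# [1] 第一段内容
#
# [2] 第二段内容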
presets.py CHANGED
@@ -1,4 +1,23 @@
 # -*- coding:utf-8 -*-
+# Error messages
+standard_error_msg = "☹️发生了错误:"  # standard prefix for error messages
+error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。"  # error while fetching a reply
+connection_timeout_prompt = "连接超时,无法获取对话。"  # connection timeout
+read_timeout_prompt = "读取超时,无法获取对话。"  # read timeout
+proxy_error_prompt = "代理错误,无法获取对话。"  # proxy error
+ssl_error_prompt = "SSL错误,无法获取对话。"  # SSL error
+no_apikey_msg = "API key长度不是51位,请检查是否输入正确。"  # shown when the API key is not 51 characters long
+
+max_token_streaming = 3500  # max tokens in streaming mode
+timeout_streaming = 30  # timeout in streaming mode, in seconds
+max_token_all = 3500  # max tokens in non-streaming mode
+timeout_all = 200  # timeout in non-streaming mode, in seconds
+enable_streaming_option = True  # whether to show the checkbox for streaming replies
+HIDE_MY_KEY = False  # set to True to hide your API key in the UI
+
+SIM_K = 5
+INDEX_QUERY_TEMPRATURE = 1.0
+
 title = """<h1 align="left" style="min-width:200px; margin-top:0;">川虎ChatGPT 🚀</h1>"""
 description = """\
 <div align="center" style="margin:16px 0">
@@ -12,6 +31,7 @@ description = """\
 """

 summarize_prompt = "你是谁?我们刚才聊了什么?"  # prompt used to summarize the conversation
+
 MODELS = [
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-0301",
@@ -21,7 +41,8 @@ MODELS = [
     "gpt-4-32k-0314",
 ]  # available models

-websearch_prompt = """\
+
+WEBSEARCH_PTOMPT_TEMPLATE = """\
 Web search results:

 {web_results}
@@ -31,18 +52,29 @@ Instructions: Using the provided web search results, write a comprehensive reply
 Query: {query}
 Reply in 中文"""

-# Error messages
-standard_error_msg = "☹️发生了错误:"  # standard prefix for error messages
-error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。"  # error while fetching a reply
-connection_timeout_prompt = "连接超时,无法获取对话。"  # connection timeout
-read_timeout_prompt = "读取超时,无法获取对话。"  # read timeout
-proxy_error_prompt = "代理错误,无法获取对话。"  # proxy error
-ssl_error_prompt = "SSL错误,无法获取对话。"  # SSL error
-no_apikey_msg = "API key长度不是51位,请检查是否输入正确。"  # shown when the API key is not 51 characters long
+PROMPT_TEMPLATE = """\
+Context information is below.
+---------------------
+{context_str}
+---------------------
+Using the provided context information, write a comprehensive reply to the given query.
+Make sure to cite results using [number] notation after the reference.
+If the provided context information refer to multiple subjects with the same name, write separate answers for each subject.
+Use prior knowledge only if the given context didn't provide enough information.
+Today is {current_date}.
+Answer the question: {query_str}
+Reply in 中文
+"""

-max_token_streaming = 3500  # max tokens in streaming mode
-timeout_streaming = 30  # timeout in streaming mode, in seconds
-max_token_all = 3500  # max tokens in non-streaming mode
-timeout_all = 200  # timeout in non-streaming mode, in seconds
-enable_streaming_option = True  # whether to show the checkbox for streaming replies
-HIDE_MY_KEY = False  # set to True to hide your API key in the UI
+REFINE_TEMPLATE = """\
+The original question is as follows: {query_str}
+We have provided an existing answer: {existing_answer}
+We have the opportunity to refine the existing answer
+(only if needed) with some more context below.
+------------
+{context_msg}
+------------
+Given the new context, refine the original answer to better
+Answer in the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch.
+If the context isn't useful, return the original answer.
+"""
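
PROMPT_TEMPLATE and REFINE_TEMPLATE deliberately leave {context_str}, {query_str}, {existing_answer}, and {context_msg} for llama_index to fill at query time; only {current_date} is substituted up front, by the replace_today helper added to utils.py below. A trimmed, self-contained illustration (the template string here is shortened for the demo):

import datetime

template = "Today is {current_date}.\nAnswer the question: {query_str}"

def replace_today(prompt):  # same body as the utils.py helper
    today = datetime.datetime.today().strftime("%Y-%m-%d")
    return prompt.replace("{current_date}", today)

print(replace_today(template))
# e.g. "Today is 2023-03-17." — {query_str} survives for llama_index to fill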
requirements.txt CHANGED
@@ -6,4 +6,6 @@ socksio
 tqdm
 colorama
 duckduckgo_search
-Pygments
+Pygments
+llama_index
+langchain
utils.py CHANGED
@@ -18,8 +18,25 @@ from presets import *
 import tiktoken
 from tqdm import tqdm
 import colorama
+import os
+from llama_index import (
+    GPTSimpleVectorIndex,
+    GPTTreeIndex,
+    GPTKeywordTableIndex,
+    GPTListIndex,
+)
+from llama_index import SimpleDirectoryReader, download_loader
+from llama_index import (
+    Document,
+    LLMPredictor,
+    PromptHelper,
+    QuestionAnswerPrompt,
+    RefinePrompt,
+)
+from langchain.llms import OpenAIChat, OpenAI
 from duckduckgo_search import ddg
 import datetime
+import hashlib

 # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")

@@ -37,27 +54,6 @@ HISTORY_DIR = "history"
 TEMPLATES_DIR = "templates"


-def postprocess(
-    self, y: List[Tuple[str | None, str | None]]
-) -> List[Tuple[str | None, str | None]]:
-    """
-    Parameters:
-        y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
-    Returns:
-        List of tuples representing the message and response. Each message and response will be a string of HTML.
-    """
-    if y is None:
-        return []
-    for i, (message, response) in enumerate(y):
-        y[i] = (
-            # None if message is None else markdown.markdown(message),
-            # None if response is None else markdown.markdown(response),
-            None if message is None else message,
-            None if response is None else mdtex2html.convert(response, extensions=['fenced_code', 'codehilite', 'tables']),
-        )
-    return y
-
-
 def count_token(message):
     encoding = tiktoken.get_encoding("cl100k_base")
     input_str = f"role: {message['role']}, content: {message['content']}"
@@ -102,389 +98,6 @@ def construct_token_message(token, stream=False):
     return f"Token 计数: {token}"


-def get_response(
-    openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
-):
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {openai_api_key}",
-    }
-
-    history = [construct_system(system_prompt), *history]
-
-    payload = {
-        "model": selected_model,
-        "messages": history,  # [{"role": "user", "content": f"{inputs}"}],
-        "temperature": temperature,  # 1.0,
-        "top_p": top_p,  # 1.0,
-        "n": 1,
-        "stream": stream,
-        "presence_penalty": 0,
-        "frequency_penalty": 0,
-    }
-    if stream:
-        timeout = timeout_streaming
-    else:
-        timeout = timeout_all
-
-    # Read proxy settings from the environment
-    http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
-    https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
-
-    # Use them if present
-    proxies = {}
-    if http_proxy:
-        logging.info(f"Using HTTP proxy: {http_proxy}")
-        proxies["http"] = http_proxy
-    if https_proxy:
-        logging.info(f"Using HTTPS proxy: {https_proxy}")
-        proxies["https"] = https_proxy
-
-    # Send the request through the proxy if one is configured, otherwise with the defaults
-    if proxies:
-        response = requests.post(
-            API_URL,
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=timeout,
-            proxies=proxies,
-        )
-    else:
-        response = requests.post(
-            API_URL,
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=timeout,
-        )
-    return response
-
-
-def stream_predict(
-    openai_api_key,
-    system_prompt,
-    history,
-    inputs,
-    chatbot,
-    all_token_counts,
-    top_p,
-    temperature,
-    selected_model,
-):
-    def get_return_value():
-        return chatbot, history, status_text, all_token_counts
-
-    logging.info("实时回答模式")
-    partial_words = ""
-    counter = 0
-    status_text = "开始实时传输回答……"
-    history.append(construct_user(inputs))
-    history.append(construct_assistant(""))
-    chatbot.append((parse_text(inputs), ""))
-    user_token_count = 0
-    if len(all_token_counts) == 0:
-        system_prompt_token_count = count_token(construct_system(system_prompt))
-        user_token_count = (
-            count_token(construct_user(inputs)) + system_prompt_token_count
-        )
-    else:
-        user_token_count = count_token(construct_user(inputs))
-    all_token_counts.append(user_token_count)
-    logging.info(f"输入token计数: {user_token_count}")
-    yield get_return_value()
-    try:
-        response = get_response(
-            openai_api_key,
-            system_prompt,
-            history,
-            temperature,
-            top_p,
-            True,
-            selected_model,
-        )
-    except requests.exceptions.ConnectTimeout:
-        status_text = (
-            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-        )
-        yield get_return_value()
-        return
-    except requests.exceptions.ReadTimeout:
-        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
-        yield get_return_value()
-        return
-
-    yield get_return_value()
-    error_json_str = ""
-
-    for chunk in tqdm(response.iter_lines()):
-        if counter == 0:
-            counter += 1
-            continue
-        counter += 1
-        # check whether each line is non-empty
-        if chunk:
-            chunk = chunk.decode()
-            chunklength = len(chunk)
-            try:
-                chunk = json.loads(chunk[6:])
-            except json.JSONDecodeError:
-                logging.info(chunk)
-                error_json_str += chunk
-                status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
-                yield get_return_value()
-                continue
-            # decode each line as response data is in bytes
-            if chunklength > 6 and "delta" in chunk["choices"][0]:
-                finish_reason = chunk["choices"][0]["finish_reason"]
-                status_text = construct_token_message(
-                    sum(all_token_counts), stream=True
-                )
-                if finish_reason == "stop":
-                    yield get_return_value()
-                    break
-                try:
-                    partial_words = (
-                        partial_words + chunk["choices"][0]["delta"]["content"]
-                    )
-                except KeyError:
-                    status_text = (
-                        standard_error_msg
-                        + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
-                        + str(sum(all_token_counts))
-                    )
-                    yield get_return_value()
-                    break
-                history[-1] = construct_assistant(partial_words)
-                chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
-                all_token_counts[-1] += 1
-                yield get_return_value()
-
-
-def predict_all(
-    openai_api_key,
-    system_prompt,
-    history,
-    inputs,
-    chatbot,
-    all_token_counts,
-    top_p,
-    temperature,
-    selected_model,
-):
-    logging.info("一次性回答模式")
-    history.append(construct_user(inputs))
-    history.append(construct_assistant(""))
-    chatbot.append((parse_text(inputs), ""))
-    all_token_counts.append(count_token(construct_user(inputs)))
-    try:
-        response = get_response(
-            openai_api_key,
-            system_prompt,
-            history,
-            temperature,
-            top_p,
-            False,
-            selected_model,
-        )
-    except requests.exceptions.ConnectTimeout:
-        status_text = (
-            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-        )
-        return chatbot, history, status_text, all_token_counts
-    except requests.exceptions.ProxyError:
-        status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
-        return chatbot, history, status_text, all_token_counts
-    except requests.exceptions.SSLError:
-        status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
-        return chatbot, history, status_text, all_token_counts
-    response = json.loads(response.text)
-    content = response["choices"][0]["message"]["content"]
-    history[-1] = construct_assistant(content)
-    chatbot[-1] = (parse_text(inputs), parse_text(content))
-    total_token_count = response["usage"]["total_tokens"]
-    all_token_counts[-1] = total_token_count - sum(all_token_counts)
-    status_text = construct_token_message(total_token_count)
-    return chatbot, history, status_text, all_token_counts
-
-
-def predict(
-    openai_api_key,
-    system_prompt,
-    history,
-    inputs,
-    chatbot,
-    all_token_counts,
-    top_p,
-    temperature,
-    stream=False,
-    selected_model=MODELS[0],
-    use_websearch_checkbox=False,
-    should_check_token_count=True,
-):  # repetition_penalty, top_k
-    logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
-    if use_websearch_checkbox:
-        results = ddg(inputs, max_results=3)
-        web_results = []
-        for idx, result in enumerate(results):
-            logging.info(f"搜索结果{idx + 1}:{result}")
-            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
-        web_results = "\n\n".join(web_results)
-        today = datetime.datetime.today().strftime("%Y-%m-%d")
-        inputs = (
-            websearch_prompt.replace("{current_date}", today)
-            .replace("{query}", inputs)
-            .replace("{web_results}", web_results)
-        )
-    if len(openai_api_key) != 51:
-        status_text = standard_error_msg + no_apikey_msg
-        logging.info(status_text)
-        chatbot.append((parse_text(inputs), ""))
-        if len(history) == 0:
-            history.append(construct_user(inputs))
-            history.append("")
-            all_token_counts.append(0)
-        else:
-            history[-2] = construct_user(inputs)
-        yield chatbot, history, status_text, all_token_counts
-        return
-    if stream:
-        yield chatbot, history, "开始生成回答……", all_token_counts
-    if stream:
-        logging.info("使用流式传输")
-        iter = stream_predict(
-            openai_api_key,
-            system_prompt,
-            history,
-            inputs,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            selected_model,
-        )
-        for chatbot, history, status_text, all_token_counts in iter:
-            yield chatbot, history, status_text, all_token_counts
-    else:
-        logging.info("不使用流式传输")
-        chatbot, history, status_text, all_token_counts = predict_all(
-            openai_api_key,
-            system_prompt,
-            history,
-            inputs,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            selected_model,
-        )
-        yield chatbot, history, status_text, all_token_counts
-    logging.info(f"传输完毕。当前token计数为{all_token_counts}")
-    if len(history) > 1 and history[-1]["content"] != inputs:
-        logging.info(
-            "回答为:"
-            + colorama.Fore.BLUE
-            + f"{history[-1]['content']}"
-            + colorama.Style.RESET_ALL
-        )
-    if stream:
-        max_token = max_token_streaming
-    else:
-        max_token = max_token_all
-    if sum(all_token_counts) > max_token and should_check_token_count:
-        status_text = f"精简token中{all_token_counts}/{max_token}"
-        logging.info(status_text)
-        yield chatbot, history, status_text, all_token_counts
-        iter = reduce_token_size(
-            openai_api_key,
-            system_prompt,
-            history,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            stream=False,
-            selected_model=selected_model,
-            hidden=True,
-        )
-        for chatbot, history, status_text, all_token_counts in iter:
-            status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
-            yield chatbot, history, status_text, all_token_counts
-
-
-def retry(
-    openai_api_key,
-    system_prompt,
-    history,
-    chatbot,
-    token_count,
-    top_p,
-    temperature,
-    stream=False,
-    selected_model=MODELS[0],
-):
-    logging.info("重试中……")
-    if len(history) == 0:
-        yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
-        return
-    history.pop()
-    inputs = history.pop()["content"]
-    token_count.pop()
-    iter = predict(
-        openai_api_key,
-        system_prompt,
-        history,
-        inputs,
-        chatbot,
-        token_count,
-        top_p,
-        temperature,
-        stream=stream,
-        selected_model=selected_model,
-    )
-    logging.info("重试完毕")
-    for x in iter:
-        yield x
-
-
-def reduce_token_size(
-    openai_api_key,
-    system_prompt,
-    history,
-    chatbot,
-    token_count,
-    top_p,
-    temperature,
-    stream=False,
-    selected_model=MODELS[0],
-    hidden=False,
-):
-    logging.info("开始减少token数量……")
-    iter = predict(
-        openai_api_key,
-        system_prompt,
-        history,
-        summarize_prompt,
-        chatbot,
-        token_count,
-        top_p,
-        temperature,
-        stream=stream,
-        selected_model=selected_model,
-        should_check_token_count=False,
-    )
-    logging.info(f"chatbot: {chatbot}")
-    for chatbot, history, status_text, previous_token_count in iter:
-        history = history[-2:]
-        token_count = previous_token_count[-1:]
-        if hidden:
-            chatbot.pop()
-        yield chatbot, history, construct_token_message(
-            sum(token_count), stream=stream
-        ), token_count
-    logging.info("减少token数量完毕")
-
-
 def delete_last_conversation(chatbot, history, previous_token_count):
     if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
         logging.info("由于包含报错信息,只删除chatbot记录")
@@ -643,6 +256,7 @@ def reset_state():
 def reset_textbox():
     return gr.update(value="")

+
 def reset_default():
     global API_URL
     API_URL = "https://api.openai.com/v1/chat/completions"
@@ -650,6 +264,7 @@ def reset_default():
     os.environ.pop("https_proxy", None)
     return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"

+
 def change_api_url(url):
     global API_URL
     API_URL = url
@@ -657,22 +272,41 @@ def change_api_url(url):
     logging.info(msg)
     return msg

+
 def change_proxy(proxy):
     os.environ["HTTPS_PROXY"] = proxy
     msg = f"代理更改为了{proxy}"
     logging.info(msg)
     return msg

+
 def hide_middle_chars(s):
     if len(s) <= 8:
         return s
     else:
         head = s[:4]
         tail = s[-4:]
-        hidden = '*' * (len(s) - 8)
+        hidden = "*" * (len(s) - 8)
         return head + hidden + tail

+
 def submit_key(key):
     msg = f"API密钥更改为了{hide_middle_chars(key)}"
     logging.info(msg)
     return key, msg
+
+
+def sha1sum(filename):
+    # SHA-1 of a file's contents, read in 64 KB blocks
+    sha1 = hashlib.sha1()
+    with open(filename, "rb") as f:
+        while True:
+            data = f.read(65536)
+            if not data:
+                break
+            sha1.update(data)
+    return sha1.hexdigest()
+
+
+def replace_today(prompt):
+    today = datetime.datetime.today().strftime("%Y-%m-%d")
+    return prompt.replace("{current_date}", today)
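
A quick sanity check of the two helpers added above, assuming write access to the working directory (demo.txt is a throwaway name):

with open("demo.txt", "w", encoding="utf-8") as f:
    f.write("hello")

print(sha1sum("demo.txt"))
# aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d, the SHA-1 of the file contents

print(hide_middle_chars("sk-abcdefghijklmnop"))
# "sk-a" + 11 * "*" + "mnop": only the first and last four characters survive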