Tuchuanhuhuhu committed
Commit 5879508 · Parent: a8cb0a3

Add a file-index feature to the Chuanhu Assistant (川虎助理)
modules/models/ChuanhuAgent.py CHANGED
@@ -30,7 +30,12 @@ from collections import deque
 
 from .base_model import BaseLLMModel, CallbackToIterator, ChuanhuCallbackHandler
 from ..config import default_chuanhu_assistant_model
-from ..presets import SUMMARIZE_PROMPT
+from ..presets import SUMMARIZE_PROMPT, i18n
+from ..index_func import construct_index
+
+from langchain.callbacks import get_openai_callback
+import os
+import gradio as gr
 import logging
 
 class WebBrowsingInput(BaseModel):
@@ -50,6 +55,8 @@ class ChuanhuAgent_Client(BaseLLMModel):
         self.cheap_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0, model_name="gpt-3.5-turbo")
         PROMPT = PromptTemplate(template=SUMMARIZE_PROMPT, input_variables=["text"])
         self.summarize_chain = load_summarize_chain(self.cheap_llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
+        self.index_summary = None
+        self.index = None
         if "Pro" in self.model_name:
             self.tools = load_tools(["google-search-results-json", "llm-math", "arxiv", "wikipedia", "wolfram-alpha"], llm=self.llm)
         else:
@@ -73,6 +80,39 @@ class ChuanhuAgent_Client(BaseLLMModel):
             )
         )
 
+    def handle_file_upload(self, files, chatbot, language):
+        """if the model accepts multi modal input, implement this function"""
+        status = gr.Markdown.update()
+        if files:
+            index = construct_index(self.api_key, file_src=files)
+            assert index is not None, "获取索引失败"
+            self.index = index
+            status = i18n("索引构建完成")
+            # Summarize the document
+            logging.info(i18n("生成内容总结中……"))
+            with get_openai_callback() as cb:
+                os.environ["OPENAI_API_KEY"] = self.api_key
+                from langchain.chains.summarize import load_summarize_chain
+                from langchain.prompts import PromptTemplate
+                from langchain.chat_models import ChatOpenAI
+                prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
+                PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
+                llm = ChatOpenAI()
+                chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
+                summary = chain({"input_documents": list(index.docstore.__dict__["_dict"].values())}, return_only_outputs=True)["output_text"]
+                logging.info(f"Summary: {summary}")
+                self.index_summary = summary
+                logging.info(cb)
+        return gr.Files.update(), chatbot, status
+
+    def query_index(self, query):
+        if self.index is not None:
+            retriever = self.index.as_retriever()
+            qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=retriever)
+            return qa.run(query)
+        else:
+            "Error during query."
+
     def summary(self, text):
         texts = Document(page_content=text)
         texts = self.text_splitter.split_documents([texts])
@@ -119,6 +159,16 @@ class ChuanhuAgent_Client(BaseLLMModel):
         it = CallbackToIterator()
         manager = BaseCallbackManager(handlers=[ChuanhuCallbackHandler(it.callback)])
         def thread_func():
+            tools = self.tools
+            if self.index is not None:
+                tools.append(
+                    Tool.from_function(
+                        func=self.query_index,
+                        name="Query Knowledge Base",
+                        description=f"useful when you need to know about: {self.index_summary}",
+                        args_schema=WebBrowsingInput
+                    )
+                )
             agent = initialize_agent(self.tools, self.llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)
             reply = agent.run(input=f"{question} Reply in 简体中文")
             it.callback(reply)
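The net effect of this change: once an index has been built from uploaded files, the agent gains an extra "Query Knowledge Base" tool that answers from the index through a RetrievalQA chain. Below is a minimal standalone sketch of that pattern. The FAISS/OpenAIEmbeddings index is only a stand-in assumption (the commit builds its index via construct_index, whose internals are not shown in this diff), and it requires OPENAI_API_KEY plus the faiss package to actually run.

# Sketch of the retrieval-tool pattern this commit adds to ChuanhuAgent_Client.
# Assumption: the index behaves like a LangChain vector store with .as_retriever();
# FAISS.from_texts below merely stands in for construct_index().
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.agents import AgentType, Tool, initialize_agent

llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")

# Stand-in for the index returned after a file upload.
index = FAISS.from_texts(
    ["Chuanhu Assistant can search the web, arxiv and wikipedia, and do math."],
    OpenAIEmbeddings(),
)

# query_index() in the commit does exactly this: retriever -> "stuff" RetrievalQA -> run.
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=index.as_retriever())

tools = [
    Tool.from_function(
        func=qa.run,
        name="Query Knowledge Base",
        description="useful when you need to know about the uploaded documents",
    )
]

agent = initialize_agent(tools, llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
print(agent.run("What can the Chuanhu Assistant do?"))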
modules/models/base_model.py CHANGED
@@ -20,7 +20,6 @@ from enum import Enum
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.callbacks.manager import BaseCallbackManager
-from langchain.callbacks import get_openai_callback
 
 from typing import Any, Dict, List, Optional, Union
 
@@ -264,22 +263,6 @@ class BaseLLMModel:
         if files:
             index = construct_index(self.api_key, file_src=files)
             status = i18n("索引构建完成")
-            # Summarize the document
-            logging.info(i18n("生成内容总结中……"))
-            with get_openai_callback() as cb:
-                os.environ["OPENAI_API_KEY"] = self.api_key
-                from langchain.chains.summarize import load_summarize_chain
-                from langchain.prompts import PromptTemplate
-                from langchain.chat_models import ChatOpenAI
-                from langchain.callbacks import StdOutCallbackHandler
-                prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
-                PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
-                llm = ChatOpenAI()
-                chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
-                summary = chain({"input_documents": list(index.docstore.__dict__["_dict"].values())}, return_only_outputs=True)["output_text"]
-                print(i18n("总结") + f": {summary}")
-                chatbot.append([i18n("上传了")+str(len(files))+"个文件", summary])
-                logging.info(cb)
         return gr.Files.update(), chatbot, status
 
     def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
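For reference, the block removed here is the same map_reduce summarization that now lives in ChuanhuAgent_Client.handle_file_upload. The standalone sketch below shows that pattern together with the token/cost accounting provided by get_openai_callback; the two sample Documents are placeholders, not project code, and running it assumes OPENAI_API_KEY is set.

# Sketch of the map_reduce summarization plus token accounting moved out of
# BaseLLMModel.handle_file_upload; the documents here are illustrative only.
from langchain.callbacks import get_openai_callback
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate

docs = [
    Document(page_content="First uploaded chunk ..."),
    Document(page_content="Second uploaded chunk ..."),
]

prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN English:"
PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])

with get_openai_callback() as cb:
    chain = load_summarize_chain(
        ChatOpenAI(),
        chain_type="map_reduce",            # summarize each chunk, then combine the partial summaries
        return_intermediate_steps=True,
        map_prompt=PROMPT,
        combine_prompt=PROMPT,
    )
    summary = chain({"input_documents": docs}, return_only_outputs=True)["output_text"]

print(summary)
print(cb)  # token counts and estimated cost for the summarization calls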