Tuchuanhuhuhu committed · Commit 8043b80 · Parent(s): 7042605
添加PaLM支持(未完成) (Add PaLM support, unfinished)
- modules/models/PaLM.py +10 -0
- modules/models/base_model.py +18 -12
- modules/models/models.py +3 -0
- modules/presets.py +1 -0
modules/models/PaLM.py
ADDED
@@ -0,0 +1,10 @@
+from .base_model import BaseLLMModel, CallbackToIterator, ChuanhuCallbackHandler
+from langchain.chat_models import ChatGooglePalm
+
+class PaLM_Client(BaseLLMModel):
+    def __init__(self, model_name, user="") -> None:
+        super().__init__(model_name, user)
+        self.llm = ChatGooglePalm(google_api_key="")
+
+    def get_answer_at_once(self):
+        self.llm.generate(self.history)
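As the commit message notes, this client is unfinished: get_answer_at_once passes the raw history to self.llm.generate() and returns nothing, and the API key is a hard-coded empty string. The following is a hypothetical completion sketch, not part of this commit; it assumes the history is a list of {"role", "content"} dicts, that the key is supplied by the caller, and that langchain's ChatGooglePalm can be called with a list of chat messages.

# Hypothetical completion sketch (assumptions noted above), not the committed code.
from langchain.chat_models import ChatGooglePalm
from langchain.schema import AIMessage, HumanMessage, SystemMessage

def history_to_messages(history):
    # Map role/content dicts onto langchain chat message objects.
    role_map = {"system": SystemMessage, "user": HumanMessage, "assistant": AIMessage}
    return [role_map[m["role"]](content=m["content"]) for m in history]

class PaLMClientSketch:
    def __init__(self, api_key):
        # The key would come from configuration rather than a literal.
        self.llm = ChatGooglePalm(google_api_key=api_key)

    def get_answer_at_once(self, history):
        reply = self.llm(history_to_messages(history))  # returns an AIMessage
        return reply.content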
modules/models/base_model.py
CHANGED
@@ -20,6 +20,7 @@ from enum import Enum
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.callbacks.manager import BaseCallbackManager
+from langchain.callbacks import get_openai_callback
 
 from typing import Any, Dict, List, Optional, Union
 
@@ -108,6 +109,7 @@ class ModelType(Enum):
     MOSS = 5
     YuanAI = 6
     ChuanhuAgent = 7
+    PaLM = 8
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -129,6 +131,8 @@ class ModelType(Enum):
             model_type = ModelType.YuanAI
         elif "川虎助理" in model_name_lower:
             model_type = ModelType.ChuanhuAgent
+        elif "palm" in model_name_lower:
+            model_type = ModelType.PaLM
         else:
             model_type = ModelType.Unknown
         return model_type
@@ -262,18 +266,20 @@ class BaseLLMModel:
             status = i18n("索引构建完成")
             # Summarize the document
             logging.info(i18n("生成内容总结中……"))
-            … (12 removed lines; their previous content is not shown in this view)
+            with get_openai_callback() as cb:
+                os.environ["OPENAI_API_KEY"] = self.api_key
+                from langchain.chains.summarize import load_summarize_chain
+                from langchain.prompts import PromptTemplate
+                from langchain.chat_models import ChatOpenAI
+                from langchain.callbacks import StdOutCallbackHandler
+                prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
+                PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
+                llm = ChatOpenAI()
+                chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
+                summary = chain({"input_documents": list(index.docstore.__dict__["_dict"].values())}, return_only_outputs=True)["output_text"]
+                print(i18n("总结") + f": {summary}")
+                chatbot.append([i18n("上传了")+str(len(files))+"个文件", summary])
+            logging.info(cb)
             return gr.Files.update(), chatbot, status
 
     def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
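The summarization block added above is wrapped in get_openai_callback, which aggregates token usage and cost for every OpenAI call made inside the with block and is then written to the log via logging.info(cb). A stripped-down illustration of that pattern, assuming langchain is installed and OPENAI_API_KEY is set in the environment:

# Minimal illustration of the token-accounting pattern used above (a sketch,
# assuming langchain is installed and OPENAI_API_KEY is already set).
from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI()
with get_openai_callback() as cb:
    llm.predict("Summarize: the quick brown fox jumps over the lazy dog.")

# cb accumulates usage for all OpenAI calls made inside the block.
print(cb.total_tokens, cb.prompt_tokens, cb.completion_tokens, cb.total_cost)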
modules/models/models.py
CHANGED
@@ -606,6 +606,9 @@ def get_model(
         elif model_type == ModelType.ChuanhuAgent:
             from .ChuanhuAgent import ChuanhuAgent_Client
             model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
+        elif model_type == ModelType.PaLM:
+            from .PaLM import PaLM_Client
+            model = PaLM_Client(model_name, user_name=user_name)
         elif model_type == ModelType.Unknown:
             raise ValueError(f"未知模型: {model_name}")
         logging.info(msg)
modules/presets.py
CHANGED
@@ -68,6 +68,7 @@ ONLINE_MODELS = [
     "gpt-4-32k",
     "gpt-4-32k-0314",
     "xmchat",
+    "Google PaLM",
     "yuanai-1.0-base_10B",
     "yuanai-1.0-translate",
     "yuanai-1.0-dialog",
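Adding "Google PaLM" to ONLINE_MODELS is what makes the model selectable in the UI; the name then flows through ModelType.get_type, where the lower-cased substring check added in base_model.py selects the PaLM branch, and get_model in models.py instantiates PaLM_Client. A tiny sketch of that routing, mirroring the check from the diff (the route helper is illustrative, not repo code):

# Illustrative only: mirrors the elif chain added in ModelType.get_type.
def route(model_name):
    return "PaLM" if "palm" in model_name.lower() else "Unknown"

assert route("Google PaLM") == "PaLM"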