Tuchuanhuhuhu commited on
Commit
8043b80
·
1 Parent(s): 7042605

添加PaLM支持(未完成)

Browse files
modules/models/PaLM.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_model import BaseLLMModel, CallbackToIterator, ChuanhuCallbackHandler
2
+ from langchain.chat_models import ChatGooglePalm
3
+
4
+ class PaLM_Client(BaseLLMModel):
5
+ def __init__(self, model_name, user="") -> None:
6
+ super().__init__(model_name, user)
7
+ self.llm = ChatGooglePalm(google_api_key="")
8
+
9
+ def get_answer_at_once(self):
10
+ self.llm.generate(self.history)
modules/models/base_model.py CHANGED
@@ -20,6 +20,7 @@ from enum import Enum
20
 
21
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
22
  from langchain.callbacks.manager import BaseCallbackManager
 
23
 
24
  from typing import Any, Dict, List, Optional, Union
25
 
@@ -108,6 +109,7 @@ class ModelType(Enum):
108
  MOSS = 5
109
  YuanAI = 6
110
  ChuanhuAgent = 7
 
111
 
112
  @classmethod
113
  def get_type(cls, model_name: str):
@@ -129,6 +131,8 @@ class ModelType(Enum):
129
  model_type = ModelType.YuanAI
130
  elif "川虎助理" in model_name_lower:
131
  model_type = ModelType.ChuanhuAgent
 
 
132
  else:
133
  model_type = ModelType.Unknown
134
  return model_type
@@ -262,18 +266,20 @@ class BaseLLMModel:
262
  status = i18n("索引构建完成")
263
  # Summarize the document
264
  logging.info(i18n("生成内容总结中……"))
265
- os.environ["OPENAI_API_KEY"] = self.api_key
266
- from langchain.chains.summarize import load_summarize_chain
267
- from langchain.prompts import PromptTemplate
268
- from langchain.chat_models import ChatOpenAI
269
- from langchain.callbacks import StdOutCallbackHandler
270
- prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
271
- PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
272
- llm = ChatOpenAI()
273
- chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
274
- summary = chain({"input_documents": list(index.docstore.__dict__["_dict"].values())}, return_only_outputs=True)["output_text"]
275
- print(i18n("总结") + f": {summary}")
276
- chatbot.append([i18n("上传了")+len(files)+"个文件", summary])
 
 
277
  return gr.Files.update(), chatbot, status
278
 
279
  def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
 
20
 
21
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
22
  from langchain.callbacks.manager import BaseCallbackManager
23
+ from langchain.callbacks import get_openai_callback
24
 
25
  from typing import Any, Dict, List, Optional, Union
26
 
 
109
  MOSS = 5
110
  YuanAI = 6
111
  ChuanhuAgent = 7
112
+ PaLM = 8
113
 
114
  @classmethod
115
  def get_type(cls, model_name: str):
 
131
  model_type = ModelType.YuanAI
132
  elif "川虎助理" in model_name_lower:
133
  model_type = ModelType.ChuanhuAgent
134
+ elif "palm" in model_name_lower:
135
+ model_type = ModelType.PaLM
136
  else:
137
  model_type = ModelType.Unknown
138
  return model_type
 
266
  status = i18n("索引构建完成")
267
  # Summarize the document
268
  logging.info(i18n("生成内容总结中……"))
269
+ with get_openai_callback() as cb:
270
+ os.environ["OPENAI_API_KEY"] = self.api_key
271
+ from langchain.chains.summarize import load_summarize_chain
272
+ from langchain.prompts import PromptTemplate
273
+ from langchain.chat_models import ChatOpenAI
274
+ from langchain.callbacks import StdOutCallbackHandler
275
+ prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
276
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
277
+ llm = ChatOpenAI()
278
+ chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
279
+ summary = chain({"input_documents": list(index.docstore.__dict__["_dict"].values())}, return_only_outputs=True)["output_text"]
280
+ print(i18n("总结") + f": {summary}")
281
+ chatbot.append([i18n("上传了")+str(len(files))+"个文件", summary])
282
+ logging.info(cb)
283
  return gr.Files.update(), chatbot, status
284
 
285
  def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
modules/models/models.py CHANGED
@@ -606,6 +606,9 @@ def get_model(
606
  elif model_type == ModelType.ChuanhuAgent:
607
  from .ChuanhuAgent import ChuanhuAgent_Client
608
  model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
 
 
 
609
  elif model_type == ModelType.Unknown:
610
  raise ValueError(f"未知模型: {model_name}")
611
  logging.info(msg)
 
606
  elif model_type == ModelType.ChuanhuAgent:
607
  from .ChuanhuAgent import ChuanhuAgent_Client
608
  model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
609
+ elif model_type == ModelType.PaLM:
610
+ from .PaLM import PaLM_Client
611
+ model = PaLM_Client(model_name, user_name=user_name)
612
  elif model_type == ModelType.Unknown:
613
  raise ValueError(f"未知模型: {model_name}")
614
  logging.info(msg)
modules/presets.py CHANGED
@@ -68,6 +68,7 @@ ONLINE_MODELS = [
68
  "gpt-4-32k",
69
  "gpt-4-32k-0314",
70
  "xmchat",
 
71
  "yuanai-1.0-base_10B",
72
  "yuanai-1.0-translate",
73
  "yuanai-1.0-dialog",
 
68
  "gpt-4-32k",
69
  "gpt-4-32k-0314",
70
  "xmchat",
71
+ "Google PaLM",
72
  "yuanai-1.0-base_10B",
73
  "yuanai-1.0-translate",
74
  "yuanai-1.0-dialog",