import json
from typing import List, Optional

from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig

from utils import tool_config_from_file
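# Note: tool_config_from_file(tool_name) is assumed to return the registered tool's config
# (name, description, parameters); _tool_history below attaches these configs to the system
# message so ChatGLM3 can emit native tool calls for them.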


class ChatGLM3(LLM):
    """LangChain LLM wrapper around a locally loaded ChatGLM3 model with native tool calling."""

    max_token: int = 8192
    do_sample: bool = False
    temperature: float = 0.8
    top_p: float = 0.8
    tokenizer: object = None
    model: object = None
    history: List = []
    tool_names: List = []
    has_search: bool = False

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"

    def load_model(self, model_name_or_path=None):
        """Load the ChatGLM3 tokenizer and model from a local path or Hugging Face model id."""
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).half().cuda()
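        # Note: .half().cuda() assumes an NVIDIA GPU with enough memory for the fp16 weights;
        # on a CPU-only machine the model would have to be loaded with .float() instead.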

    def _tool_history(self, prompt: str):
        """Parse the agent prompt into a ChatGLM3 system message (with tools) plus the user query."""
        ans = []

        # The structured-chat prompt lists the available tools between fixed markers; recover
        # the tool names and look up each tool's registered config.
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names

        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                raise ValueError(
                    f"Tool {tool} config not found! Its description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = prompt.split("Human: ")[-1].strip()
        return ans, query
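
    # Based on the split markers above, the incoming prompt is assumed to be LangChain's
    # structured-chat agent prompt, roughly of the form:
    #
    #   Answer the following questions as best you can. You have access to the following tools:
    #
    #   <tool name>: <tool description>
    #   ...
    #
    #   Use a json blob to specify a tool ...
    #   ...
    #   Human: <user question>
    #
    # Only the tool list and the final "Human:" turn are used; the rest is discarded.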

    def _extract_observation(self, prompt: str):
        """Append the latest tool Observation from the agent prompt to the chat history."""
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

    def _extract_tool(self):
        """Convert the model's latest reply into the Action JSON blob the agent expects."""
        # A reply whose metadata names one of the registered tools is a tool call; extract the
        # single string argument and wrap it as an agent Action.
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""

        # Otherwise treat the reply as the final answer.
        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action:
```
{json.dumps(final_answer_json, ensure_ascii=False)}
```"""

    def _call(self, prompt: str, history: List = [], stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")

        # First call of an agent run: parse the tool list and user question out of the prompt.
        # Later calls (after a tool has produced an Observation) only append that observation
        # to the history and let the model continue.
        if not self.has_search:
            self.history, query = self._tool_history(prompt)
        else:
            self._extract_observation(prompt)
            query = ""

        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response
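

# Minimal usage sketch (not part of the original file): wiring this wrapper into a LangChain
# structured-chat agent. MODEL_PATH and the Calculator import are placeholders/assumptions taken
# from the accompanying demo; the tool's config must also be resolvable by tool_config_from_file,
# otherwise _tool_history raises.
if __name__ == "__main__":
    from langchain.agents import AgentType, initialize_agent

    from Tool.Calculator import Calculator  # hypothetical tool implementation from the demo

    MODEL_PATH = "THUDM/chatglm3-6b"  # placeholder: point at your local checkpoint

    llm = ChatGLM3()
    llm.load_model(MODEL_PATH)

    agent = initialize_agent(
        [Calculator()],
        llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    print(agent.run("Calculate 34 * 34"))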