uploaded all small directories

Files changed:
- __pycache__/st_utils.cpython-311.pyc +0 -0
- core/.DS_Store +0 -0
- core/__init__.py +1 -0
- core/__pycache__/__init__.cpython-311.pyc +0 -0
- core/__pycache__/builder_config.cpython-311.pyc +0 -0
- core/__pycache__/callback_manager.cpython-311.pyc +0 -0
- core/__pycache__/constants.cpython-311.pyc +0 -0
- core/__pycache__/param_cache.cpython-311.pyc +0 -0
- core/__pycache__/utils.cpython-311.pyc +0 -0
- core/agent_builder/.DS_Store +0 -0
- core/agent_builder/__init__.py +0 -0
- core/agent_builder/__pycache__/__init__.cpython-311.pyc +0 -0
- core/agent_builder/__pycache__/base.cpython-311.pyc +0 -0
- core/agent_builder/__pycache__/loader.cpython-311.pyc +0 -0
- core/agent_builder/__pycache__/multimodal.cpython-311.pyc +0 -0
- core/agent_builder/__pycache__/registry.cpython-311.pyc +0 -0
- core/agent_builder/base.py +250 -0
- core/agent_builder/loader.py +115 -0
- core/agent_builder/multimodal.py +256 -0
- core/agent_builder/registry.py +78 -0
- core/builder_config.py +20 -0
- core/callback_manager.py +70 -0
- core/constants.py +4 -0
- core/param_cache.py +156 -0
- core/utils.py +480 -0
- pages/.DS_Store +0 -0
- pages/4_🤖_ChatDoctor.py +126 -0
- tests/__init__.py +0 -0
__pycache__/st_utils.cpython-311.pyc
ADDED: Binary file (8.1 kB)

core/.DS_Store
ADDED: Binary file (8.2 kB)

core/__init__.py
ADDED
@@ -0,0 +1 @@
"""Init file."""

core/__pycache__/__init__.cpython-311.pyc
ADDED: Binary file (216 Bytes)

core/__pycache__/builder_config.cpython-311.pyc
ADDED: Binary file (541 Bytes)

core/__pycache__/callback_manager.cpython-311.pyc
ADDED: Binary file (3.39 kB)

core/__pycache__/constants.cpython-311.pyc
ADDED: Binary file (490 Bytes)

core/__pycache__/param_cache.cpython-311.pyc
ADDED: Binary file (6.89 kB)

core/__pycache__/utils.cpython-311.pyc
ADDED: Binary file (19 kB)

core/agent_builder/.DS_Store
ADDED: Binary file (6.15 kB)

core/agent_builder/__init__.py
ADDED: File without changes

core/agent_builder/__pycache__/__init__.cpython-311.pyc
ADDED: Binary file (206 Bytes)

core/agent_builder/__pycache__/base.cpython-311.pyc
ADDED: Binary file (10.4 kB)

core/agent_builder/__pycache__/loader.cpython-311.pyc
ADDED: Binary file (4.33 kB)

core/agent_builder/__pycache__/multimodal.cpython-311.pyc
ADDED: Binary file (12.1 kB)

core/agent_builder/__pycache__/registry.cpython-311.pyc
ADDED: Binary file (5.69 kB)
core/agent_builder/base.py
ADDED
@@ -0,0 +1,250 @@
"""Agent builder."""

from llama_index.llms import ChatMessage
from llama_index.prompts import ChatPromptTemplate
from typing import List, cast, Optional
from core.builder_config import BUILDER_LLM
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR
from abc import ABC, abstractmethod

from core.param_cache import ParamCache, RAGParams
from core.utils import (
    load_data,
    get_tool_objects,
    construct_agent,
)
from core.agent_builder.registry import AgentCacheRegistry


# System prompt tool
GEN_SYS_PROMPT_STR = """\
Task information is given below.

Given the task, please generate a system prompt for an OpenAI-powered bot \
to solve this task:
{task} \

Make sure the system prompt obeys the following requirements:
- Tells the bot to ALWAYS use tools given to solve the task. \
NEVER give an answer without using a tool.
- Does not reference a specific data source. \
The data source is implicit in any queries to the bot, \
and telling the bot to analyze a specific data source might confuse it given a \
user query.

"""

gen_sys_prompt_messages = [
    ChatMessage(
        role="system",
        content="You are helping to build a system prompt for another bot.",
    ),
    ChatMessage(role="user", content=GEN_SYS_PROMPT_STR),
]

GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)


class BaseRAGAgentBuilder(ABC):
    """Base RAG Agent builder class."""

    @property
    @abstractmethod
    def cache(self) -> ParamCache:
        """Cache."""

    @property
    @abstractmethod
    def agent_registry(self) -> AgentCacheRegistry:
        """Agent registry."""


class RAGAgentBuilder(BaseRAGAgentBuilder):
    """RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    Must pass in a cache. This cache will be modified as the agent is built.

    """

    def __init__(
        self,
        cache: Optional[ParamCache] = None,
        agent_registry: Optional[AgentCacheRegistry] = None,
    ) -> None:
        """Init params."""
        self._cache = cache or ParamCache()
        self._agent_registry = agent_registry or AgentCacheRegistry(
            str(AGENT_CACHE_DIR)
        )

    @property
    def cache(self) -> ParamCache:
        """Cache."""
        return self._cache

    @property
    def agent_registry(self) -> AgentCacheRegistry:
        """Agent registry."""
        return self._agent_registry

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = BUILDER_LLM
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        self._cache.system_prompt = response.message.content

        return f"System prompt created: {response.message.content}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        directory: Optional[str] = None,
        urls: Optional[List[str]] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or directory or urls should be specified.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            directory (Optional[str]): Directory to load files from.
            urls (Optional[List[str]]): List of urls to load.
                Defaults to None.

        """
        file_names = file_names or []
        urls = urls or []
        directory = directory or ""
        docs = load_data(file_names=file_names, directory=directory, urls=urls)
        self._cache.docs = docs
        self._cache.file_names = file_names
        self._cache.urls = urls
        self._cache.directory = directory
        return "Data loaded successfully."

    def add_web_tool(self) -> str:
        """Add a web tool to enable agent to solve a task."""
        # TODO: make this not hardcoded to a web tool
        # Set up Metaphor tool
        if "web_search" in self._cache.tools:
            return "Web tool already added."
        else:
            self._cache.tools.append("web_search")
            return "Web tool added successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.

        """
        rag_params = self._cache.rag_params
        return rag_params.dict()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.

        """
        new_dict = self._cache.rag_params.dict()
        new_dict.update(rag_params)
        rag_params_obj = RAGParams(**new_dict)
        self._cache.rag_params = rag_params_obj
        return "RAG parameters set successfully."

    def create_agent(self, agent_id: Optional[str] = None) -> str:
        """Create an agent.

        There are no parameters for this function because all the
        functions should have already been called to set up the agent.

        """
        if self._cache.system_prompt is None:
            raise ValueError("Must set system prompt before creating agent.")

        # construct additional tools
        additional_tools = get_tool_objects(self.cache.tools)
        agent, extra_info = construct_agent(
            cast(str, self._cache.system_prompt),
            cast(RAGParams, self._cache.rag_params),
            self._cache.docs,
            additional_tools=additional_tools,
        )

        # if agent_id not specified, randomly generate one
        agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
        self._cache.vector_index = extra_info["vector_index"]
        self._cache.agent_id = agent_id
        self._cache.agent = agent

        # save the cache to disk
        self._agent_registry.add_new_agent_cache(agent_id, self._cache)
        return "Agent created successfully."

    def update_agent(
        self,
        agent_id: str,
        system_prompt: Optional[str] = None,
        include_summarization: Optional[bool] = None,
        top_k: Optional[int] = None,
        chunk_size: Optional[int] = None,
        embed_model: Optional[str] = None,
        llm: Optional[str] = None,
        additional_tools: Optional[List] = None,
    ) -> None:
        """Update agent.

        Delete old agent by ID and create a new one.
        Optionally update the system prompt and RAG parameters.

        NOTE: Currently is manually called, not meant for agent use.

        """
        self._agent_registry.delete_agent_cache(self.cache.agent_id)

        # set agent id
        self.cache.agent_id = agent_id

        # set system prompt
        if system_prompt is not None:
            self.cache.system_prompt = system_prompt
        # get agent_builder
        # We call set_rag_params and create_agent, which will
        # update the cache
        # TODO: decouple functions from tool functions exposed to the agent
        rag_params_dict: Dict[str, Any] = {}
        if include_summarization is not None:
            rag_params_dict["include_summarization"] = include_summarization
        if top_k is not None:
            rag_params_dict["top_k"] = top_k
        if chunk_size is not None:
            rag_params_dict["chunk_size"] = chunk_size
        if embed_model is not None:
            rag_params_dict["embed_model"] = embed_model
        if llm is not None:
            rag_params_dict["llm"] = llm

        self.set_rag_params(**rag_params_dict)

        # update tools
        if additional_tools is not None:
            self.cache.tools = additional_tools

        # this will update the agent in the cache
        self.create_agent()
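For context, a minimal end-to-end sketch of how this builder is driven manually. The file path, agent id, and question below are hypothetical, and it assumes the Streamlit secrets that core/builder_config.py reads at import time are configured:

from core.agent_builder.base import RAGAgentBuilder

builder = RAGAgentBuilder()
builder.create_system_prompt("answer questions over a financial report")
builder.load_data(file_names=["data/report.pdf"])  # hypothetical file
builder.set_rag_params(top_k=4, chunk_size=512)
builder.create_agent(agent_id="report_agent")
# the built agent now lives in the cache (and was persisted via the registry)
print(str(builder.cache.agent.chat("What were the key findings?")))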
core/agent_builder/loader.py
ADDED
@@ -0,0 +1,115 @@
"""Loader agent."""

from typing import List, cast, Optional
from llama_index.tools import FunctionTool
from llama_index.agent.types import BaseAgent
from core.builder_config import BUILDER_LLM
from typing import Tuple, Callable
import streamlit as st

from core.param_cache import ParamCache
from core.utils import (
    load_meta_agent,
)
from core.agent_builder.registry import AgentCacheRegistry
from core.agent_builder.base import RAGAgentBuilder, BaseRAGAgentBuilder
from core.agent_builder.multimodal import MultimodalRAGAgentBuilder

####################
#### META Agent ####
####################

RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.

1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
5) Build the agent

This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")

"""


### DEFINE Agent ####
# NOTE: here we define a function that is dependent on the LLM,
# please make sure to update the LLM above if you change the function below


def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
    """Get list of builder agent tools to pass to the builder agent."""
    # see if metaphor api key is set, otherwise don't add web tool
    # TODO: refactor this later

    if "metaphor_key" in st.secrets:
        fns: List[Callable] = [
            agent_builder.create_system_prompt,
            agent_builder.load_data,
            agent_builder.add_web_tool,
            agent_builder.get_rag_params,
            agent_builder.set_rag_params,
            agent_builder.create_agent,
        ]
    else:
        fns = [
            agent_builder.create_system_prompt,
            agent_builder.load_data,
            agent_builder.get_rag_params,
            agent_builder.set_rag_params,
            agent_builder.create_agent,
        ]

    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
    return fn_tools


def _get_mm_builder_agent_tools(
    agent_builder: MultimodalRAGAgentBuilder,
) -> List[FunctionTool]:
    """Get list of builder agent tools to pass to the builder agent."""
    fns: List[Callable] = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]

    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
    return fn_tools


# define agent
def load_meta_agent_and_tools(
    cache: Optional[ParamCache] = None,
    agent_registry: Optional[AgentCacheRegistry] = None,
    is_multimodal: bool = False,
) -> Tuple[BaseAgent, BaseRAGAgentBuilder]:
    """Load meta agent and tools."""

    if is_multimodal:
        agent_builder: BaseRAGAgentBuilder = MultimodalRAGAgentBuilder(
            cache, agent_registry=agent_registry
        )
        fn_tools = _get_mm_builder_agent_tools(
            cast(MultimodalRAGAgentBuilder, agent_builder)
        )
        builder_agent = load_meta_agent(
            fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
        )
    else:
        # think of this as tools for the agent to use
        agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
        fn_tools = _get_builder_agent_tools(agent_builder)
        builder_agent = load_meta_agent(
            fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
        )

    return builder_agent, agent_builder
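A sketch of how the meta agent might be driven in turn; the builder methods above are exposed to it as function tools, so a single chat message can walk the whole flow. The prompt text is illustrative and the same secrets assumption applies:

from core.agent_builder.loader import load_meta_agent_and_tools

builder_agent, agent_builder = load_meta_agent_and_tools(is_multimodal=False)
# the meta agent calls create_system_prompt / load_data / create_agent as tools
reply = builder_agent.chat(
    "Build an agent that answers questions over data/report.pdf"
)
print(str(reply))
print(agent_builder.cache.agent_id)  # populated once create_agent has run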
core/agent_builder/multimodal.py
ADDED
@@ -0,0 +1,256 @@
"""Multimodal agent builder."""

from llama_index.llms import ChatMessage
from typing import List, cast, Optional
from core.builder_config import BUILDER_LLM
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR

from core.param_cache import ParamCache, RAGParams
from core.utils import (
    load_data,
    construct_mm_agent,
)
from core.agent_builder.registry import AgentCacheRegistry
from core.agent_builder.base import GEN_SYS_PROMPT_TMPL, BaseRAGAgentBuilder

from llama_index.chat_engine.types import BaseChatEngine

from llama_index.callbacks import trace_method
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    StreamingAgentChatResponse,
    AgentChatResponse,
)
from llama_index.llms.base import ChatResponse
from typing import Generator


class MultimodalChatEngine(BaseChatEngine):
    """Multimodal chat engine.

    This chat engine is a light wrapper around a query engine.
    Offers no real 'chat' functionality, is a beta feature.

    """

    def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
        """Init params."""
        self._mm_query_engine = mm_query_engine

    def reset(self) -> None:
        """Reset conversation state."""
        pass

    @trace_method("chat")
    def chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Main chat interface."""
        # just return the top-k results
        response = self._mm_query_engine.query(message)
        return AgentChatResponse(response=str(response))

    @trace_method("chat")
    def stream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Stream chat interface."""
        response = self._mm_query_engine.query(message)

        def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
            yield ChatResponse(message=ChatMessage(role="assistant", content=response))

        chat_stream = _chat_stream(str(response))
        return StreamingAgentChatResponse(chat_stream=chat_stream)

    @trace_method("chat")
    async def achat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Async version of main chat interface."""
        response = await self._mm_query_engine.aquery(message)
        return AgentChatResponse(response=str(response))

    @trace_method("chat")
    async def astream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Async version of main chat interface."""
        return self.stream_chat(message, chat_history)


class MultimodalRAGAgentBuilder(BaseRAGAgentBuilder):
    """Multimodal RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    Must pass in a cache. This cache will be modified as the agent is built.

    """

    def __init__(
        self,
        cache: Optional[ParamCache] = None,
        agent_registry: Optional[AgentCacheRegistry] = None,
    ) -> None:
        """Init params."""
        self._cache = cache or ParamCache()
        self._agent_registry = agent_registry or AgentCacheRegistry(
            str(AGENT_CACHE_DIR)
        )

    @property
    def cache(self) -> ParamCache:
        """Cache."""
        return self._cache

    @property
    def agent_registry(self) -> AgentCacheRegistry:
        """Agent registry."""
        return self._agent_registry

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = BUILDER_LLM
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        self._cache.system_prompt = response.message.content

        return f"System prompt created: {response.message.content}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        directory: Optional[str] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or directory should be specified.
        **NOTE**: urls not supported in multi-modal setting.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            directory (Optional[str]): Directory to load files from.

        """
        file_names = file_names or []
        directory = directory or ""
        docs = load_data(file_names=file_names, directory=directory)
        self._cache.docs = docs
        self._cache.file_names = file_names
        self._cache.directory = directory
        return "Data loaded successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.

        """
        rag_params = self._cache.rag_params
        return rag_params.dict()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.

        """
        new_dict = self._cache.rag_params.dict()
        new_dict.update(rag_params)
        rag_params_obj = RAGParams(**new_dict)
        self._cache.rag_params = rag_params_obj
        return "RAG parameters set successfully."

    def create_agent(self, agent_id: Optional[str] = None) -> str:
        """Create an agent.

        There are no parameters for this function because all the
        functions should have already been called to set up the agent.

        """
        if self._cache.system_prompt is None:
            raise ValueError("Must set system prompt before creating agent.")

        # construct additional tools
        agent, extra_info = construct_mm_agent(
            cast(str, self._cache.system_prompt),
            cast(RAGParams, self._cache.rag_params),
            self._cache.docs,
        )

        # if agent_id not specified, randomly generate one
        agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
        self._cache.builder_type = "multimodal"
        self._cache.vector_index = extra_info["vector_index"]
        self._cache.agent_id = agent_id
        self._cache.agent = agent

        # save the cache to disk
        self._agent_registry.add_new_agent_cache(agent_id, self._cache)
        return "Agent created successfully."

    def update_agent(
        self,
        agent_id: str,
        system_prompt: Optional[str] = None,
        include_summarization: Optional[bool] = None,
        top_k: Optional[int] = None,
        chunk_size: Optional[int] = None,
        embed_model: Optional[str] = None,
        llm: Optional[str] = None,
        additional_tools: Optional[List] = None,
    ) -> None:
        """Update agent.

        Delete old agent by ID and create a new one.
        Optionally update the system prompt and RAG parameters.

        NOTE: Currently is manually called, not meant for agent use.

        """
        self._agent_registry.delete_agent_cache(self.cache.agent_id)

        # set agent id
        self.cache.agent_id = agent_id

        # set system prompt
        if system_prompt is not None:
            self.cache.system_prompt = system_prompt
        # get agent_builder
        # We call set_rag_params and create_agent, which will
        # update the cache
        # TODO: decouple functions from tool functions exposed to the agent
        rag_params_dict: Dict[str, Any] = {}
        if include_summarization is not None:
            rag_params_dict["include_summarization"] = include_summarization
        if top_k is not None:
            rag_params_dict["top_k"] = top_k
        if chunk_size is not None:
            rag_params_dict["chunk_size"] = chunk_size
        if embed_model is not None:
            rag_params_dict["embed_model"] = embed_model
        if llm is not None:
            rag_params_dict["llm"] = llm

        self.set_rag_params(**rag_params_dict)

        # update tools
        if additional_tools is not None:
            self.cache.tools = additional_tools

        # this will update the agent in the cache
        self.create_agent()
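A parallel sketch for the multimodal builder; the directory path is hypothetical, load_data here accepts no urls, and the feature is flagged as beta above:

from core.agent_builder.multimodal import MultimodalRAGAgentBuilder

mm_builder = MultimodalRAGAgentBuilder()
mm_builder.create_system_prompt("answer questions about product photos")
mm_builder.load_data(directory="data/products")  # hypothetical image directory
mm_builder.create_agent(agent_id="mm_product_agent")
# the cached agent is a MultimodalChatEngine wrapping the query engine
print(str(mm_builder.cache.agent.chat("What color is the chair?")))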
core/agent_builder/registry.py
ADDED
@@ -0,0 +1,78 @@
"""Agent builder registry."""

from typing import List
from typing import Union
from pathlib import Path
import json
import shutil

from core.param_cache import ParamCache


class AgentCacheRegistry:
    """Registry for agent caches, on disk.

    Can register new agent caches, load agent caches, delete agent caches, etc.

    """

    def __init__(self, dir: Union[str, Path]) -> None:
        """Init params."""
        self._dir = dir

    def _add_agent_id_to_directory(self, agent_id: str) -> None:
        """Save agent id to directory."""
        full_path = Path(self._dir) / "agent_ids.json"
        if not full_path.exists():
            with open(full_path, "w") as f:
                json.dump({"agent_ids": [agent_id]}, f)
        else:
            with open(full_path, "r") as f:
                agent_ids = json.load(f)["agent_ids"]
            if agent_id in agent_ids:
                raise ValueError(f"Agent id {agent_id} already exists.")
            agent_ids_set = set(agent_ids)
            agent_ids_set.add(agent_id)
            with open(full_path, "w") as f:
                json.dump({"agent_ids": list(agent_ids_set)}, f)

    def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
        """Register agent."""
        # save the cache to disk
        agent_cache_path = f"{self._dir}/{agent_id}"
        cache.save_to_disk(agent_cache_path)
        # save to agent ids
        self._add_agent_id_to_directory(agent_id)

    def get_agent_ids(self) -> List[str]:
        """Get agent ids."""
        full_path = Path(self._dir) / "agent_ids.json"
        if not full_path.exists():
            return []
        with open(full_path, "r") as f:
            agent_ids = json.load(f)["agent_ids"]

        return agent_ids

    def get_agent_cache(self, agent_id: str) -> ParamCache:
        """Get agent cache."""
        full_path = Path(self._dir) / f"{agent_id}"
        if not full_path.exists():
            raise ValueError(f"Cache for agent {agent_id} does not exist.")
        cache = ParamCache.load_from_disk(str(full_path))
        return cache

    def delete_agent_cache(self, agent_id: str) -> None:
        """Delete agent cache."""
        # modify / resave agent_ids
        agent_ids = self.get_agent_ids()
        new_agent_ids = [id for id in agent_ids if id != agent_id]
        full_path = Path(self._dir) / "agent_ids.json"
        with open(full_path, "w") as f:
            json.dump({"agent_ids": new_agent_ids}, f)

        # remove agent cache
        full_path = Path(self._dir) / f"{agent_id}"
        if full_path.exists():
            # recursive delete
            shutil.rmtree(full_path)
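A sketch of the on-disk contract the registry implements: each agent gets a subdirectory under the registry dir, and agent_ids.json tracks the known ids. The directory path here is hypothetical:

from core.agent_builder.registry import AgentCacheRegistry

registry = AgentCacheRegistry("cache/agents")
for agent_id in registry.get_agent_ids():  # [] until agent_ids.json exists
    cache = registry.get_agent_cache(agent_id)  # loads <dir>/<agent_id>/cache.json
    print(agent_id, cache.rag_params.dict())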
core/builder_config.py
ADDED
@@ -0,0 +1,20 @@
"""Configuration."""
import streamlit as st
import os

### DEFINE BUILDER_LLM #####
## Uncomment the LLM you want to use to construct the meta agent

## OpenAI
from llama_index.llms import OpenAI

# set OpenAI Key - use Streamlit secrets
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
# load LLM
BUILDER_LLM = OpenAI(model="gpt-4-1106-preview")

# # Anthropic (make sure you `pip install anthropic`)
# from llama_index.llms import Anthropic
# # set Anthropic key
# os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
# BUILDER_LLM = Anthropic()
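Because this module reads st.secrets at import time, a Streamlit secrets file must exist before anything importing BUILDER_LLM runs. A sketch of .streamlit/secrets.toml with placeholder values; only openai_key is required by the code shown here, while the other keys are read in core/utils.py if those providers or the web tool are used:

openai_key = "sk-..."       # required by builder_config.py
# anthropic_key = "..."     # only if the Anthropic builder is uncommented
# metaphor_key = "..."      # only if the web search tool is enabled
# replicate_key = "..."     # only for "replicate:" llm strings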
core/callback_manager.py
ADDED
@@ -0,0 +1,70 @@
"""Streaming callback manager."""
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType

from typing import Optional, Dict, Any, List, Callable

STORAGE_DIR = "./storage"  # directory to cache the generated index
DATA_DIR = "./data"  # directory containing the documents to index


class StreamlitFunctionsCallbackHandler(BaseCallbackHandler):
    """Callback handler that outputs streamlit components given events."""

    def __init__(self, msg_handler: Callable[[str], Any]) -> None:
        """Initialize the base callback handler."""
        self.msg_handler = msg_handler
        super().__init__([], [])

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        if event_type == CBEventType.FUNCTION_CALL:
            if payload is None:
                raise ValueError("Payload cannot be None")
            arguments_str = payload["function_call"]
            tool_str = payload["tool"].name
            print_str = f"Calling function: {tool_str} with args: {arguments_str}\n\n"
            self.msg_handler(print_str)
        else:
            pass
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        pass
        # TODO: currently we don't need to do anything here
        # if event_type == CBEventType.FUNCTION_CALL:
        #     response = payload["function_call_response"]
        #     # Add this to queue
        #     print_str = (
        #         f"\n\nGot output: {response}\n"
        #         "========================\n\n"
        #     )
        # elif event_type == CBEventType.AGENT_STEP:
        #     # put response into queue
        #     self._queue.put(payload["response"])

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""
        pass
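This handler is wired into the OpenAI agent inside load_agent in core/utils.py below; the pattern, extracted for clarity (the message handler here is a simplified stand-in for the one utils.py defines):

import streamlit as st
from llama_index.callbacks import CallbackManager
from core.callback_manager import StreamlitFunctionsCallbackHandler

# forward every FUNCTION_CALL start event to a Streamlit info box
handler = StreamlitFunctionsCallbackHandler(lambda msg: st.info(msg))
callback_manager = CallbackManager([handler])
# then pass callback_manager to OpenAIAgent.from_tools(...), as load_agent does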
core/constants.py
ADDED
@@ -0,0 +1,4 @@
from pathlib import Path

AGENT_CACHE_DIR = Path(__file__).parent.parent / "cache" / "agents"
MESSAGES_CACHE_DIR = Path(__file__).parent.parent / "cache" / "messages"
core/param_cache.py
ADDED
@@ -0,0 +1,156 @@
"""Param cache."""

from pydantic import BaseModel, Field
from llama_index import (
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
)
from typing import List, cast, Optional
from llama_index.chat_engine.types import BaseChatEngine
from pathlib import Path
import json
import uuid
from core.utils import (
    load_data,
    get_tool_objects,
    construct_agent,
    RAGParams,
    construct_mm_agent,
)


class ParamCache(BaseModel):
    """Cache for RAG agent builder.

    Created a wrapper class around a dict in case we wanted to more explicitly
    type different items in the cache.

    """

    # arbitrary types
    class Config:
        arbitrary_types_allowed = True

    # system prompt
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for RAG agent."
    )
    # data
    file_names: List[str] = Field(
        default_factory=list, description="File names as data source (if specified)"
    )
    urls: List[str] = Field(
        default_factory=list, description="URLs as data source (if specified)"
    )
    directory: Optional[str] = Field(
        default=None, description="Directory as data source (if specified)"
    )

    docs: List = Field(default_factory=list, description="Documents for RAG agent.")
    # tools
    tools: List = Field(
        default_factory=list, description="Additional tools for RAG agent (e.g. web)"
    )
    # RAG params
    rag_params: RAGParams = Field(
        default_factory=RAGParams, description="RAG parameters for RAG agent."
    )

    # agent params
    builder_type: str = Field(
        default="default", description="Builder type (default, multimodal)."
    )
    vector_index: Optional[VectorStoreIndex] = Field(
        default=None, description="Vector index for RAG agent."
    )
    agent_id: str = Field(
        default_factory=lambda: f"Agent_{str(uuid.uuid4())}",
        description="Agent ID for RAG agent.",
    )
    agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.")

    def save_to_disk(self, save_dir: str) -> None:
        """Save cache to disk."""
        # NOTE: more complex than just calling dict() because we want to
        # only store serializable fields and be space-efficient

        dict_to_serialize = {
            "system_prompt": self.system_prompt,
            "file_names": self.file_names,
            "urls": self.urls,
            "directory": self.directory,
            # TODO: figure out tools
            "tools": self.tools,
            "rag_params": self.rag_params.dict(),
            "builder_type": self.builder_type,
            "agent_id": self.agent_id,
        }
        # store the vector store within the agent
        if self.vector_index is None:
            raise ValueError("Must specify vector index in order to save.")
        self.vector_index.storage_context.persist(Path(save_dir) / "storage")

        # if save_path directories don't exist, create it
        if not Path(save_dir).exists():
            Path(save_dir).mkdir(parents=True)
        with open(Path(save_dir) / "cache.json", "w") as f:
            json.dump(dict_to_serialize, f)

    @classmethod
    def load_from_disk(
        cls,
        save_dir: str,
    ) -> "ParamCache":
        """Load cache from disk."""
        with open(Path(save_dir) / "cache.json", "r") as f:
            cache_dict = json.load(f)

        storage_context = StorageContext.from_defaults(
            persist_dir=str(Path(save_dir) / "storage")
        )
        if cache_dict["builder_type"] == "multimodal":
            from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

            vector_index: VectorStoreIndex = cast(
                MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
            )
        else:
            vector_index = cast(
                VectorStoreIndex, load_index_from_storage(storage_context)
            )

        # replace rag params with RAGParams object
        cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])

        # add in the missing fields
        # load docs
        cache_dict["docs"] = load_data(
            file_names=cache_dict["file_names"],
            urls=cache_dict["urls"],
            directory=cache_dict["directory"],
        )
        # load agent from index
        additional_tools = get_tool_objects(cache_dict["tools"])

        if cache_dict["builder_type"] == "multimodal":
            vector_index = cast(MultiModalVectorStoreIndex, vector_index)
            agent, _ = construct_mm_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                mm_vector_index=vector_index,
            )
        else:
            agent, _ = construct_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                vector_index=vector_index,
                additional_tools=additional_tools,
                # TODO: figure out tools
            )
        cache_dict["vector_index"] = vector_index
        cache_dict["agent"] = agent

        return cls(**cache_dict)
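A sketch of the persistence roundtrip; the file path and agent id are hypothetical, the relative cache path assumes the script runs from the repo root (AGENT_CACHE_DIR resolves to <repo>/cache/agents), and save_to_disk requires a built vector index, which create_agent guarantees:

from core.agent_builder.base import RAGAgentBuilder
from core.param_cache import ParamCache

builder = RAGAgentBuilder()
builder.create_system_prompt("answer questions over a report")
builder.load_data(file_names=["data/report.pdf"])  # hypothetical file
builder.create_agent("demo_agent")  # persists the cache via the registry

# reload later; docs are re-read and the agent reconstructed from storage
restored = ParamCache.load_from_disk("cache/agents/demo_agent")
assert restored.agent_id == "demo_agent"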
core/utils.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils."""
|
2 |
+
|
3 |
+
from llama_index.llms import OpenAI, Anthropic, Replicate
|
4 |
+
from llama_index.llms.base import LLM
|
5 |
+
from llama_index.llms.utils import resolve_llm
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
import os
|
8 |
+
from llama_index.agent import OpenAIAgent, ReActAgent
|
9 |
+
from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER
|
10 |
+
from llama_index import (
|
11 |
+
VectorStoreIndex,
|
12 |
+
SummaryIndex,
|
13 |
+
ServiceContext,
|
14 |
+
Document,
|
15 |
+
)
|
16 |
+
from typing import List, cast, Optional
|
17 |
+
from llama_index import SimpleDirectoryReader
|
18 |
+
from llama_index.embeddings.utils import resolve_embed_model
|
19 |
+
from llama_index.tools import QueryEngineTool, ToolMetadata
|
20 |
+
from llama_index.agent.types import BaseAgent
|
21 |
+
from llama_index.chat_engine.types import BaseChatEngine
|
22 |
+
from llama_index.agent.react.formatter import ReActChatFormatter
|
23 |
+
from llama_index.llms.openai_utils import is_function_calling_model
|
24 |
+
from llama_index.chat_engine import CondensePlusContextChatEngine
|
25 |
+
from core.builder_config import BUILDER_LLM
|
26 |
+
from typing import Dict, Tuple, Any
|
27 |
+
import streamlit as st
|
28 |
+
|
29 |
+
from llama_index.callbacks import CallbackManager, trace_method
|
30 |
+
from core.callback_manager import StreamlitFunctionsCallbackHandler
|
31 |
+
from llama_index.schema import ImageNode, NodeWithScore
|
32 |
+
|
33 |
+
### BETA: Multi-modal
|
34 |
+
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
|
35 |
+
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
|
36 |
+
from llama_index.indices.multi_modal.retriever import (
|
37 |
+
MultiModalVectorIndexRetriever,
|
38 |
+
)
|
39 |
+
from llama_index.llms import ChatMessage
|
40 |
+
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
|
41 |
+
from llama_index.chat_engine.types import (
|
42 |
+
AGENT_CHAT_RESPONSE_TYPE,
|
43 |
+
StreamingAgentChatResponse,
|
44 |
+
AgentChatResponse,
|
45 |
+
)
|
46 |
+
from llama_index.llms.base import ChatResponse
|
47 |
+
from typing import Generator
|
48 |
+
|
49 |
+
|
50 |
+
class RAGParams(BaseModel):
|
51 |
+
"""RAG parameters.
|
52 |
+
|
53 |
+
Parameters used to configure a RAG pipeline.
|
54 |
+
|
55 |
+
"""
|
56 |
+
|
57 |
+
include_summarization: bool = Field(
|
58 |
+
default=False,
|
59 |
+
description=(
|
60 |
+
"Whether to include summarization in the RAG pipeline. (only for GPT-4)"
|
61 |
+
),
|
62 |
+
)
|
63 |
+
top_k: int = Field(
|
64 |
+
default=2, description="Number of documents to retrieve from vector store."
|
65 |
+
)
|
66 |
+
chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
|
67 |
+
embed_model: str = Field(
|
68 |
+
default="default", description="Embedding model to use (default is OpenAI)"
|
69 |
+
)
|
70 |
+
llm: str = Field(
|
71 |
+
default="gpt-4-1106-preview", description="LLM to use for summarization."
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def _resolve_llm(llm_str: str) -> LLM:
|
76 |
+
"""Resolve LLM."""
|
77 |
+
# TODO: make this less hardcoded with if-else statements
|
78 |
+
# see if there's a prefix
|
79 |
+
# - if there isn't, assume it's an OpenAI model
|
80 |
+
# - if there is, resolve it
|
81 |
+
tokens = llm_str.split(":")
|
82 |
+
if len(tokens) == 1:
|
83 |
+
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
|
84 |
+
llm: LLM = OpenAI(model=llm_str)
|
85 |
+
elif tokens[0] == "local":
|
86 |
+
llm = resolve_llm(llm_str)
|
87 |
+
elif tokens[0] == "openai":
|
88 |
+
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
|
89 |
+
llm = OpenAI(model=tokens[1])
|
90 |
+
elif tokens[0] == "anthropic":
|
91 |
+
os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
|
92 |
+
llm = Anthropic(model=tokens[1])
|
93 |
+
elif tokens[0] == "replicate":
|
94 |
+
os.environ["REPLICATE_API_KEY"] = st.secrets.replicate_key
|
95 |
+
llm = Replicate(model=tokens[1])
|
96 |
+
else:
|
97 |
+
raise ValueError(f"LLM {llm_str} not recognized.")
|
98 |
+
return llm
|
99 |
+
|
100 |
+
|
101 |
+
def load_data(
|
102 |
+
file_names: Optional[List[str]] = None,
|
103 |
+
directory: Optional[str] = None,
|
104 |
+
urls: Optional[List[str]] = None,
|
105 |
+
) -> List[Document]:
|
106 |
+
"""Load data."""
|
107 |
+
file_names = file_names or []
|
108 |
+
directory = directory or ""
|
109 |
+
urls = urls or []
|
110 |
+
|
111 |
+
# get number depending on whether specified
|
112 |
+
num_specified = sum(1 for v in [file_names, urls, directory] if v)
|
113 |
+
|
114 |
+
if num_specified == 0:
|
115 |
+
raise ValueError("Must specify either file_names or urls or directory.")
|
116 |
+
elif num_specified > 1:
|
117 |
+
raise ValueError("Must specify only one of file_names or urls or directory.")
|
118 |
+
elif file_names:
|
119 |
+
reader = SimpleDirectoryReader(input_files=file_names)
|
120 |
+
docs = reader.load_data()
|
121 |
+
elif directory:
|
122 |
+
reader = SimpleDirectoryReader(input_dir=directory)
|
123 |
+
docs = reader.load_data()
|
124 |
+
elif urls:
|
125 |
+
from llama_hub.web.simple_web.base import SimpleWebPageReader
|
126 |
+
|
127 |
+
# use simple web page reader from llamahub
|
128 |
+
loader = SimpleWebPageReader()
|
129 |
+
docs = loader.load_data(urls=urls)
|
130 |
+
else:
|
131 |
+
raise ValueError("Must specify either file_names or urls or directory.")
|
132 |
+
|
133 |
+
return docs
|
134 |
+
|
135 |
+
|
136 |
+
def load_agent(
|
137 |
+
tools: List,
|
138 |
+
llm: LLM,
|
139 |
+
system_prompt: str,
|
140 |
+
extra_kwargs: Optional[Dict] = None,
|
141 |
+
**kwargs: Any,
|
142 |
+
) -> BaseChatEngine:
|
143 |
+
"""Load agent."""
|
144 |
+
extra_kwargs = extra_kwargs or {}
|
145 |
+
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
|
146 |
+
# TODO: use default msg handler
|
147 |
+
# TODO: separate this from agent_utils.py...
|
148 |
+
def _msg_handler(msg: str) -> None:
|
149 |
+
"""Message handler."""
|
150 |
+
st.info(msg)
|
151 |
+
st.session_state.agent_messages.append(
|
152 |
+
{"role": "assistant", "content": msg, "msg_type": "info"}
|
153 |
+
)
|
154 |
+
|
155 |
+
# add streamlit callbacks (to inject events)
|
156 |
+
handler = StreamlitFunctionsCallbackHandler(_msg_handler)
|
157 |
+
callback_manager = CallbackManager([handler])
|
158 |
+
# get OpenAI Agent
|
159 |
+
agent: BaseChatEngine = OpenAIAgent.from_tools(
|
160 |
+
tools=tools,
|
161 |
+
llm=llm,
|
162 |
+
system_prompt=system_prompt,
|
163 |
+
**kwargs,
|
164 |
+
callback_manager=callback_manager,
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
if "vector_index" not in extra_kwargs:
|
168 |
+
raise ValueError(
|
169 |
+
"Must pass in vector index for CondensePlusContextChatEngine."
|
170 |
+
)
|
171 |
+
vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"])
|
172 |
+
rag_params = cast(RAGParams, extra_kwargs["rag_params"])
|
173 |
+
# use condense + context chat engine
|
174 |
+
agent = CondensePlusContextChatEngine.from_defaults(
|
175 |
+
vector_index.as_retriever(similarity_top_k=rag_params.top_k),
|
176 |
+
)
|
177 |
+
|
178 |
+
return agent
|
179 |
+
|
180 |
+
|
181 |
+
def load_meta_agent(
|
182 |
+
tools: List,
|
183 |
+
llm: LLM,
|
184 |
+
system_prompt: str,
|
185 |
+
extra_kwargs: Optional[Dict] = None,
|
186 |
+
**kwargs: Any,
|
187 |
+
) -> BaseAgent:
|
188 |
+
"""Load meta agent.
|
189 |
+
|
190 |
+
TODO: consolidate with load_agent.
|
191 |
+
|
192 |
+
The meta-agent *has* to perform tool-use.
|
193 |
+
|
194 |
+
"""
|
195 |
+
extra_kwargs = extra_kwargs or {}
|
196 |
+
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
|
197 |
+
# get OpenAI Agent
|
198 |
+
|
199 |
+
agent: BaseAgent = OpenAIAgent.from_tools(
|
200 |
+
tools=tools,
|
201 |
+
llm=llm,
|
202 |
+
system_prompt=system_prompt,
|
203 |
+
**kwargs,
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
agent = ReActAgent.from_tools(
|
207 |
+
tools=tools,
|
208 |
+
llm=llm,
|
209 |
+
react_chat_formatter=ReActChatFormatter(
|
210 |
+
system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER,
|
211 |
+
),
|
212 |
+
**kwargs,
|
213 |
+
)
|
214 |
+
|
215 |
+
return agent
|
216 |
+
|
217 |
+
|
218 |
+
def construct_agent(
|
219 |
+
system_prompt: str,
|
220 |
+
rag_params: RAGParams,
|
221 |
+
docs: List[Document],
|
222 |
+
vector_index: Optional[VectorStoreIndex] = None,
|
223 |
+
additional_tools: Optional[List] = None,
|
224 |
+
) -> Tuple[BaseChatEngine, Dict]:
|
225 |
+
"""Construct agent from docs / parameters / indices."""
|
226 |
+
extra_info = {}
|
227 |
+
additional_tools = additional_tools or []
|
228 |
+
|
229 |
+
# first resolve llm and embedding model
|
230 |
+
embed_model = resolve_embed_model(rag_params.embed_model)
|
231 |
+
# llm = resolve_llm(rag_params.llm)
|
232 |
+
# TODO: use OpenAI for now
|
233 |
+
# llm = OpenAI(model=rag_params.llm)
|
234 |
+
llm = _resolve_llm(rag_params.llm)
|
235 |
+
|
236 |
+
# first let's index the data with the right parameters
|
237 |
+
service_context = ServiceContext.from_defaults(
|
238 |
+
chunk_size=rag_params.chunk_size,
|
239 |
+
llm=llm,
|
240 |
+
embed_model=embed_model,
|
241 |
+
)
|
242 |
+
|
243 |
+
if vector_index is None:
|
244 |
+
vector_index = VectorStoreIndex.from_documents(
|
245 |
+
docs, service_context=service_context
|
246 |
+
)
|
247 |
+
else:
|
248 |
+
pass
|
249 |
+
|
250 |
+
extra_info["vector_index"] = vector_index
|
251 |
+
|
252 |
+
vector_query_engine = vector_index.as_query_engine(
|
253 |
+
similarity_top_k=rag_params.top_k
|
254 |
+
)
|
255 |
+
all_tools = []
|
256 |
+
vector_tool = QueryEngineTool(
|
257 |
+
query_engine=vector_query_engine,
|
258 |
+
metadata=ToolMetadata(
|
259 |
+
name="vector_tool",
|
260 |
+
description=("Use this tool to answer any user question over any data."),
|
261 |
+
),
|
262 |
+
)
|
263 |
+
all_tools.append(vector_tool)
|
264 |
+
if rag_params.include_summarization:
|
265 |
+
summary_index = SummaryIndex.from_documents(
|
266 |
+
docs, service_context=service_context
|
267 |
+
)
|
268 |
+
summary_query_engine = summary_index.as_query_engine()
|
269 |
+
summary_tool = QueryEngineTool(
|
270 |
+
query_engine=summary_query_engine,
|
271 |
+
metadata=ToolMetadata(
|
272 |
+
name="summary_tool",
|
273 |
+
description=(
|
274 |
+
"Use this tool for any user questions that ask "
|
275 |
+
"for a summarization of content"
|
276 |
+
),
|
277 |
+
),
|
278 |
+
)
|
279 |
+
all_tools.append(summary_tool)
|
280 |
+
|
281 |
+
# then we add tools
|
282 |
+
all_tools.extend(additional_tools)
|
283 |
+
|
284 |
+
# build agent
|
285 |
+
if system_prompt is None:
|
286 |
+
return "System prompt not set yet. Please set system prompt first."
|
287 |
+
|
288 |
+
agent = load_agent(
|
289 |
+
all_tools,
|
290 |
+
llm=llm,
|
291 |
+
system_prompt=system_prompt,
|
292 |
+
verbose=True,
|
293 |
+
extra_kwargs={"vector_index": vector_index, "rag_params": rag_params},
|
294 |
+
)
|
295 |
+
return agent, extra_info
|
296 |
+
|
297 |
+
|
def get_web_agent_tool() -> QueryEngineTool:
    """Get web agent tool.

    Wraps the Metaphor search tools with an OpenAI agent.

    """
    from llama_hub.tools.metaphor.base import MetaphorToolSpec

    # TODO: set metaphor API key
    metaphor_tool = MetaphorToolSpec(
        api_key=st.secrets.metaphor_key,
    )
    metaphor_tool_list = metaphor_tool.to_tool_list()

    # TODO: LoadAndSearch doesn't work yet
    # (search_and_retrieve_documents is the third tool in the list)
    # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults(
    #     metaphor_tool_list[2],
    # )

    # NOTE: requires OpenAI right now
    # We don't give the agent the unwrapped retrieve-document tools;
    # instead we pass the wrapped tools
    web_agent = OpenAIAgent.from_tools(
        # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]],
        metaphor_tool_list,
        llm=BUILDER_LLM,
        verbose=True,
    )

    # return the agent as a tool
    # TODO: tune description
    web_agent_tool = QueryEngineTool.from_defaults(
        web_agent,
        name="web_agent",
        description=(
            "This agent can answer questions by searching the web. "
            "Use this tool if the answer is ONLY likely to be found by "
            "searching the internet, especially for queries about recent events."
        ),
    )

    return web_agent_tool

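The returned tool can itself be handed to another agent, which decides when to invoke the web search. A minimal sketch, assuming `st.secrets` holds valid OpenAI and Metaphor keys (`outer_agent` is a hypothetical name):

    # hypothetical wiring -- the outer agent decides when to call the web agent
    web_tool = get_web_agent_tool()
    outer_agent = OpenAIAgent.from_tools([web_tool], llm=BUILDER_LLM, verbose=True)
    print(outer_agent.chat("What major tech news broke this week?"))
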
def get_tool_objects(tool_names: List[str]) -> List:
    """Get tool objects from tool names."""
    # construct additional tools
    tool_objs = []
    for tool_name in tool_names:
        if tool_name == "web_search":
            # build web agent
            tool_objs.append(get_web_agent_tool())
        else:
            raise ValueError(f"Tool {tool_name} not recognized.")

    return tool_objs

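`web_search` is the only recognized name right now; anything else raises. A quick sketch:

    additional_tools = get_tool_objects(["web_search"])  # ok
    get_tool_objects(["calculator"])  # raises ValueError: Tool calculator not recognized.
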
class MultimodalChatEngine(BaseChatEngine):
    """Multimodal chat engine.

    This chat engine is a light wrapper around a query engine:
    it offers no real 'chat' functionality and is a beta feature.

    """

    def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
        """Init params."""
        self._mm_query_engine = mm_query_engine

    def reset(self) -> None:
        """Reset conversation state (a no-op: no state is kept)."""

    @property
    def chat_history(self) -> List[ChatMessage]:
        """Chat history (always empty: no state is kept)."""
        return []

    @trace_method("chat")
    def chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Main chat interface."""
        # just return the top-k results
        response = self._mm_query_engine.query(message)
        return AgentChatResponse(
            response=str(response), source_nodes=response.source_nodes
        )

    @trace_method("chat")
    def stream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Stream chat interface."""
        response = self._mm_query_engine.query(message)

        def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
            # yield the full response as a single chunk
            yield ChatResponse(message=ChatMessage(role="assistant", content=response))

        chat_stream = _chat_stream(str(response))
        return StreamingAgentChatResponse(
            chat_stream=chat_stream, source_nodes=response.source_nodes
        )

    @trace_method("chat")
    async def achat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Async version of the main chat interface."""
        response = await self._mm_query_engine.aquery(message)
        return AgentChatResponse(
            response=str(response), source_nodes=response.source_nodes
        )

    @trace_method("chat")
    async def astream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Async version of the streaming chat interface."""
        return self.stream_chat(message, chat_history)

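Because `_chat_stream` yields the whole answer at once, "streaming" here delivers a single chunk. A minimal consumption sketch that iterates the underlying `chat_stream` directly (`engine` is a hypothetical instance):

    # hypothetical usage -- prints the one chunk the wrapper yields
    streaming_response = engine.stream_chat("Summarize the uploaded images")
    for chunk in streaming_response.chat_stream:
        print(chunk.message.content)
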
def construct_mm_agent(
    system_prompt: str,
    rag_params: RAGParams,
    docs: List[Document],
    mm_vector_index: Optional[VectorStoreIndex] = None,
    additional_tools: Optional[List] = None,
) -> Tuple[BaseChatEngine, Dict]:
    """Construct a multimodal agent from docs / parameters / indices.

    NOTE: the system prompt isn't used right now.

    """
    extra_info = {}
    additional_tools = additional_tools or []

    # first resolve the embedding model
    embed_model = resolve_embed_model(rag_params.embed_model)
    # TODO: use OpenAI for now
    os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
    openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=1500)

    # index the data with the right parameters
    service_context = ServiceContext.from_defaults(
        chunk_size=rag_params.chunk_size,
        embed_model=embed_model,
    )

    # build the multimodal vector index if one wasn't passed in
    if mm_vector_index is None:
        mm_vector_index = MultiModalVectorStoreIndex.from_documents(
            docs, service_context=service_context
        )

    mm_retriever = mm_vector_index.as_retriever(similarity_top_k=rag_params.top_k)
    mm_query_engine = SimpleMultiModalQueryEngine(
        cast(MultiModalVectorIndexRetriever, mm_retriever),
        multi_modal_llm=openai_mm_llm,
    )

    extra_info["vector_index"] = mm_vector_index

    # wrap the query engine in the lightweight chat engine above
    agent = MultimodalChatEngine(mm_query_engine)

    return agent, extra_info

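A minimal sketch of building and querying the multimodal agent. The values are illustrative only; the snippet assumes `docs` includes image documents and that an OpenAI key is configured:

    # hypothetical example -- parameter values are illustrative only
    rag_params = RAGParams(
        embed_model="local:BAAI/bge-small-en-v1.5", top_k=2, chunk_size=1024
    )
    mm_agent, extra_info = construct_mm_agent(
        system_prompt="",  # unused for now
        rag_params=rag_params,
        docs=docs,
    )
    print(mm_agent.chat("What is shown in the figures?"))
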
def get_image_and_text_nodes(
    nodes: List[NodeWithScore],
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
    """Split retrieved source nodes into image nodes and text nodes."""
    image_nodes = []
    text_nodes = []
    for res_node in nodes:
        if isinstance(res_node.node, ImageNode):
            image_nodes.append(res_node)
        else:
            text_nodes.append(res_node)
    return image_nodes, text_nodes
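
Downstream UI code (see the ChatDoctor page below) uses this helper to render image sources separately from text sources. A short sketch, assuming `response` came from an agent chat call:

    # hypothetical: split a response's sources and inspect the counts
    image_nodes, text_nodes = get_image_and_text_nodes(response.source_nodes)
    print(f"{len(image_nodes)} image source(s), {len(text_nodes)} text source(s)")
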
pages/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
pages/4_🤖_ChatDoctor.py
ADDED
@@ -0,0 +1,126 @@
"""Streamlit page showing builder config."""
import streamlit as st
from st_utils import add_sidebar, get_current_state
from core.utils import get_image_and_text_nodes
from llama_index.schema import MetadataMode
from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE
from typing import Dict, Optional
import pandas as pd


####################
#### STREAMLIT #####
####################


st.set_page_config(
    page_title="ChatDoctor: your virtual primary care physician assistant",
    page_icon="🤖💬",
    layout="centered",
    # initial_sidebar_state="auto",  # ggyimah set this to off
    menu_items=None,
)
st.title("ChatDoctor: your virtual primary care physician assistant")
# st.info(
#     "Welcome!!! My name is ChatDoctor and I am trained to provide medical diagnoses and advice.",
#     icon="ℹ️",
# )


current_state = get_current_state()
add_sidebar()

if "agent_messages" not in st.session_state.keys():
    # initialize the chat message history
    st.session_state.agent_messages = [
        {
            "role": "assistant",
            "content": "I am trained to provide medical diagnoses and advice. How may I help you today?",
        }
    ]


def display_sources(response: AGENT_CHAT_RESPONSE_TYPE) -> None:
    """Display image and text source nodes for a response."""
    image_nodes, text_nodes = get_image_and_text_nodes(response.source_nodes)
    if len(image_nodes) > 0 or len(text_nodes) > 0:
        with st.expander("Sources"):
            # show image nodes
            if len(image_nodes) > 0:
                st.subheader("Images")
                for image_node in image_nodes:
                    st.image(image_node.metadata["file_path"])

            if len(text_nodes) > 0:
                st.subheader("Text")
                sources_df_list = []
                for text_node in text_nodes:
                    sources_df_list.append(
                        {
                            "ID": text_node.id_,
                            "Text": text_node.node.get_content(
                                metadata_mode=MetadataMode.ALL
                            ),
                        }
                    )
                sources_df = pd.DataFrame(sources_df_list)
                st.dataframe(sources_df)


def add_to_message_history(
    role: str, content: str, extra: Optional[Dict] = None
) -> None:
    """Append a message to the chat history."""
    message = {"role": role, "content": str(content), "extra": extra}
    st.session_state.agent_messages.append(message)


def display_messages() -> None:
    """Display the prior chat messages."""
    for message in st.session_state.agent_messages:
        with st.chat_message(message["role"]):
            msg_type = message["msg_type"] if "msg_type" in message.keys() else "text"
            if msg_type == "text":
                st.write(message["content"])
            elif msg_type == "info":
                st.info(message["content"], icon="ℹ️")
            else:
                raise ValueError(f"Unknown message type: {msg_type}")

            # display sources
            if "extra" in message and isinstance(message["extra"], dict):
                if "response" in message["extra"].keys():
                    display_sources(message["extra"]["response"])


# if the agent is created, we can chat with it
if current_state.cache is not None and current_state.cache.agent is not None:
    st.info(f"Viewing config for agent: {current_state.cache.agent_id}", icon="ℹ️")
    agent = current_state.cache.agent

    # display prior messages
    display_messages()

    # don't process selected for now
    if prompt := st.chat_input("Your question"):
        # prompt for user input and save to chat history
        add_to_message_history("user", prompt)
        with st.chat_message("user"):
            st.write(prompt)

        # if the last message is not from the assistant, generate a new response
        if st.session_state.agent_messages[-1]["role"] != "assistant":
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = agent.chat(str(prompt))
                    st.write(str(response))

                    # display sources
                    # (multi-modal: image nodes may be present)
                    display_sources(response)

                    add_to_message_history(
                        "assistant", str(response), extra={"response": response}
                    )
else:
    st.info(
        "In the sidebar, select the ChatDoctor virtual agent "
        "(Agent_950acb55-056f-4324-957d-15e1c9b48695) to get started.\n"
    )
    st.info(
        "Since this app is running on a free basic server, it could take "
        "2 to 10 minutes for the virtual agent to join you.\nPlease be patient."
    )

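One quirk worth noting: `display_messages` understands an `info` message type, but `add_to_message_history` never sets `msg_type`, so every stored message renders as plain text. A hedged sketch of how the helper could expose it (the `msg_type` parameter is hypothetical, not part of the current code):

    # hypothetical extension -- lets callers store info-style messages
    def add_to_message_history(
        role: str, content: str, extra: Optional[Dict] = None, msg_type: str = "text"
    ) -> None:
        """Append a message, tagging how it should be rendered."""
        message = {
            "role": role,
            "content": str(content),
            "extra": extra,
            "msg_type": msg_type,
        }
        st.session_state.agent_messages.append(message)
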
tests/__init__.py
ADDED
File without changes
|