GGYIMAH1031 committed on
Commit 2e8e265 · verified · 1 Parent(s): a59ab81

uploaded all small directories

__pycache__/st_utils.cpython-311.pyc ADDED
Binary file (8.1 kB).
 
core/.DS_Store ADDED
Binary file (8.2 kB).
 
core/__init__.py ADDED
@@ -0,0 +1 @@
"""Init file."""
core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (216 Bytes).
 
core/__pycache__/builder_config.cpython-311.pyc ADDED
Binary file (541 Bytes).
 
core/__pycache__/callback_manager.cpython-311.pyc ADDED
Binary file (3.39 kB).
 
core/__pycache__/constants.cpython-311.pyc ADDED
Binary file (490 Bytes).
 
core/__pycache__/param_cache.cpython-311.pyc ADDED
Binary file (6.89 kB).
 
core/__pycache__/utils.cpython-311.pyc ADDED
Binary file (19 kB).
 
core/agent_builder/.DS_Store ADDED
Binary file (6.15 kB).
 
core/agent_builder/__init__.py ADDED
File without changes
core/agent_builder/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes).
 
core/agent_builder/__pycache__/base.cpython-311.pyc ADDED
Binary file (10.4 kB).
 
core/agent_builder/__pycache__/loader.cpython-311.pyc ADDED
Binary file (4.33 kB).
 
core/agent_builder/__pycache__/multimodal.cpython-311.pyc ADDED
Binary file (12.1 kB).
 
core/agent_builder/__pycache__/registry.cpython-311.pyc ADDED
Binary file (5.69 kB).
 
core/agent_builder/base.py ADDED
@@ -0,0 +1,250 @@
"""Agent builder."""

from llama_index.llms import ChatMessage
from llama_index.prompts import ChatPromptTemplate
from typing import List, cast, Optional
from core.builder_config import BUILDER_LLM
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR
from abc import ABC, abstractmethod

from core.param_cache import ParamCache, RAGParams
from core.utils import (
    load_data,
    get_tool_objects,
    construct_agent,
)
from core.agent_builder.registry import AgentCacheRegistry


# System prompt tool
GEN_SYS_PROMPT_STR = """\
Task information is given below.

Given the task, please generate a system prompt for an OpenAI-powered bot \
to solve this task:
{task} \

Make sure the system prompt obeys the following requirements:
- Tells the bot to ALWAYS use tools given to solve the task. \
NEVER give an answer without using a tool.
- Does not reference a specific data source. \
The data source is implicit in any queries to the bot, \
and telling the bot to analyze a specific data source might confuse it given a \
user query.

"""

gen_sys_prompt_messages = [
    ChatMessage(
        role="system",
        content="You are helping to build a system prompt for another bot.",
    ),
    ChatMessage(role="user", content=GEN_SYS_PROMPT_STR),
]

GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)


class BaseRAGAgentBuilder(ABC):
    """Base RAG Agent builder class."""

    @property
    @abstractmethod
    def cache(self) -> ParamCache:
        """Cache."""

    @property
    @abstractmethod
    def agent_registry(self) -> AgentCacheRegistry:
        """Agent registry."""


class RAGAgentBuilder(BaseRAGAgentBuilder):
    """RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    Must pass in a cache. This cache will be modified as the agent is built.

    """

    def __init__(
        self,
        cache: Optional[ParamCache] = None,
        agent_registry: Optional[AgentCacheRegistry] = None,
    ) -> None:
        """Init params."""
        self._cache = cache or ParamCache()
        self._agent_registry = agent_registry or AgentCacheRegistry(
            str(AGENT_CACHE_DIR)
        )

    @property
    def cache(self) -> ParamCache:
        """Cache."""
        return self._cache

    @property
    def agent_registry(self) -> AgentCacheRegistry:
        """Agent registry."""
        return self._agent_registry

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = BUILDER_LLM
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        self._cache.system_prompt = response.message.content

        return f"System prompt created: {response.message.content}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        directory: Optional[str] = None,
        urls: Optional[List[str]] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or directory or urls should be specified.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            directory (Optional[str]): Directory to load files from.
            urls (Optional[List[str]]): List of urls to load.
                Defaults to None.

        """
        file_names = file_names or []
        urls = urls or []
        directory = directory or ""
        docs = load_data(file_names=file_names, directory=directory, urls=urls)
        self._cache.docs = docs
        self._cache.file_names = file_names
        self._cache.urls = urls
        self._cache.directory = directory
        return "Data loaded successfully."

    def add_web_tool(self) -> str:
        """Add a web tool to enable agent to solve a task."""
        # TODO: make this not hardcoded to a web tool
        # Set up Metaphor tool
        if "web_search" in self._cache.tools:
            return "Web tool already added."
        else:
            self._cache.tools.append("web_search")
        return "Web tool added successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.

        """
        rag_params = self._cache.rag_params
        return rag_params.dict()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.

        """
        new_dict = self._cache.rag_params.dict()
        new_dict.update(rag_params)
        rag_params_obj = RAGParams(**new_dict)
        self._cache.rag_params = rag_params_obj
        return "RAG parameters set successfully."

    def create_agent(self, agent_id: Optional[str] = None) -> str:
        """Create an agent.

        There are no parameters for this function because all the
        functions should have already been called to set up the agent.

        """
        if self._cache.system_prompt is None:
            raise ValueError("Must set system prompt before creating agent.")

        # construct additional tools
        additional_tools = get_tool_objects(self.cache.tools)
        agent, extra_info = construct_agent(
            cast(str, self._cache.system_prompt),
            cast(RAGParams, self._cache.rag_params),
            self._cache.docs,
            additional_tools=additional_tools,
        )

        # if agent_id not specified, randomly generate one
        agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
        self._cache.vector_index = extra_info["vector_index"]
        self._cache.agent_id = agent_id
        self._cache.agent = agent

        # save the cache to disk
        self._agent_registry.add_new_agent_cache(agent_id, self._cache)
        return "Agent created successfully."

    def update_agent(
        self,
        agent_id: str,
        system_prompt: Optional[str] = None,
        include_summarization: Optional[bool] = None,
        top_k: Optional[int] = None,
        chunk_size: Optional[int] = None,
        embed_model: Optional[str] = None,
        llm: Optional[str] = None,
        additional_tools: Optional[List] = None,
    ) -> None:
        """Update agent.

        Delete old agent by ID and create a new one.
        Optionally update the system prompt and RAG parameters.

        NOTE: Currently is manually called, not meant for agent use.

        """
        self._agent_registry.delete_agent_cache(self.cache.agent_id)

        # set agent id
        self.cache.agent_id = agent_id

        # set system prompt
        if system_prompt is not None:
            self.cache.system_prompt = system_prompt
        # get agent_builder
        # We call set_rag_params and create_agent, which will
        # update the cache
        # TODO: decouple functions from tool functions exposed to the agent
        rag_params_dict: Dict[str, Any] = {}
        if include_summarization is not None:
            rag_params_dict["include_summarization"] = include_summarization
        if top_k is not None:
            rag_params_dict["top_k"] = top_k
        if chunk_size is not None:
            rag_params_dict["chunk_size"] = chunk_size
        if embed_model is not None:
            rag_params_dict["embed_model"] = embed_model
        if llm is not None:
            rag_params_dict["llm"] = llm

        self.set_rag_params(**rag_params_dict)

        # update tools
        if additional_tools is not None:
            self.cache.tools = additional_tools

        # this will update the agent in the cache
        self.create_agent()
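
Usage note (not part of the commit): the builder methods above are meant to be invoked as tools by the meta agent, but they can also be driven directly. A minimal sketch, assuming OpenAI credentials are configured via Streamlit secrets and using a hypothetical file path and task:

    from core.agent_builder.base import RAGAgentBuilder

    builder = RAGAgentBuilder()  # creates a fresh ParamCache and default registry
    builder.create_system_prompt("answer questions over a product FAQ")
    builder.load_data(file_names=["data/faq.pdf"])  # hypothetical path
    builder.set_rag_params(top_k=4, chunk_size=512)
    builder.create_agent()  # builds the index and persists the cache
    agent = builder.cache.agent  # ready for .chat(...)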
core/agent_builder/loader.py ADDED
@@ -0,0 +1,115 @@
"""Loader agent."""

from typing import List, cast, Optional
from llama_index.tools import FunctionTool
from llama_index.agent.types import BaseAgent
from core.builder_config import BUILDER_LLM
from typing import Tuple, Callable
import streamlit as st

from core.param_cache import ParamCache
from core.utils import (
    load_meta_agent,
)
from core.agent_builder.registry import AgentCacheRegistry
from core.agent_builder.base import RAGAgentBuilder, BaseRAGAgentBuilder
from core.agent_builder.multimodal import MultimodalRAGAgentBuilder

####################
#### META Agent ####
####################

RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.

1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
5) Build the agent

This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")

"""


### DEFINE Agent ####
# NOTE: here we define a function that is dependent on the LLM,
# please make sure to update the LLM above if you change the function below


def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
    """Get list of builder agent tools to pass to the builder agent."""
    # see if metaphor api key is set, otherwise don't add web tool
    # TODO: refactor this later

    if "metaphor_key" in st.secrets:
        fns: List[Callable] = [
            agent_builder.create_system_prompt,
            agent_builder.load_data,
            agent_builder.add_web_tool,
            agent_builder.get_rag_params,
            agent_builder.set_rag_params,
            agent_builder.create_agent,
        ]
    else:
        fns = [
            agent_builder.create_system_prompt,
            agent_builder.load_data,
            agent_builder.get_rag_params,
            agent_builder.set_rag_params,
            agent_builder.create_agent,
        ]

    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
    return fn_tools


def _get_mm_builder_agent_tools(
    agent_builder: MultimodalRAGAgentBuilder,
) -> List[FunctionTool]:
    """Get list of builder agent tools to pass to the builder agent."""
    fns: List[Callable] = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]

    fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
    return fn_tools


# define agent
def load_meta_agent_and_tools(
    cache: Optional[ParamCache] = None,
    agent_registry: Optional[AgentCacheRegistry] = None,
    is_multimodal: bool = False,
) -> Tuple[BaseAgent, BaseRAGAgentBuilder]:
    """Load meta agent and tools."""

    if is_multimodal:
        agent_builder: BaseRAGAgentBuilder = MultimodalRAGAgentBuilder(
            cache, agent_registry=agent_registry
        )
        fn_tools = _get_mm_builder_agent_tools(
            cast(MultimodalRAGAgentBuilder, agent_builder)
        )
        builder_agent = load_meta_agent(
            fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
        )
    else:
        # think of this as tools for the agent to use
        agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
        fn_tools = _get_builder_agent_tools(agent_builder)
        builder_agent = load_meta_agent(
            fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
        )

    return builder_agent, agent_builder
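
Usage note (not part of the commit): a sketch of driving the meta agent returned by `load_meta_agent_and_tools`. Since `_get_builder_agent_tools` reads `st.secrets`, this assumes a Streamlit secrets file is available even when run outside the app:

    from core.agent_builder.loader import load_meta_agent_and_tools

    builder_agent, agent_builder = load_meta_agent_and_tools()
    # the meta agent invokes the builder's methods as function tools
    response = builder_agent.chat("Build an agent over data/faq.pdf")  # hypothetical path
    print(str(response))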
core/agent_builder/multimodal.py ADDED
@@ -0,0 +1,256 @@
"""Multimodal agent builder."""

from llama_index.llms import ChatMessage
from typing import List, cast, Optional
from core.builder_config import BUILDER_LLM
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR

from core.param_cache import ParamCache, RAGParams
from core.utils import (
    load_data,
    construct_mm_agent,
)
from core.agent_builder.registry import AgentCacheRegistry
from core.agent_builder.base import GEN_SYS_PROMPT_TMPL, BaseRAGAgentBuilder

from llama_index.chat_engine.types import BaseChatEngine

from llama_index.callbacks import trace_method
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    StreamingAgentChatResponse,
    AgentChatResponse,
)
from llama_index.llms.base import ChatResponse
from typing import Generator


class MultimodalChatEngine(BaseChatEngine):
    """Multimodal chat engine.

    This chat engine is a light wrapper around a query engine.
    Offers no real 'chat' functionality, is a beta feature.

    """

    def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
        """Init params."""
        self._mm_query_engine = mm_query_engine

    def reset(self) -> None:
        """Reset conversation state."""
        pass

    @trace_method("chat")
    def chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Main chat interface."""
        # just return the top-k results
        response = self._mm_query_engine.query(message)
        return AgentChatResponse(response=str(response))

    @trace_method("chat")
    def stream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Stream chat interface."""
        response = self._mm_query_engine.query(message)

        def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
            yield ChatResponse(message=ChatMessage(role="assistant", content=response))

        chat_stream = _chat_stream(str(response))
        return StreamingAgentChatResponse(chat_stream=chat_stream)

    @trace_method("chat")
    async def achat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Async version of main chat interface."""
        response = await self._mm_query_engine.aquery(message)
        return AgentChatResponse(response=str(response))

    @trace_method("chat")
    async def astream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Async version of main chat interface."""
        return self.stream_chat(message, chat_history)


class MultimodalRAGAgentBuilder(BaseRAGAgentBuilder):
    """Multimodal RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    Must pass in a cache. This cache will be modified as the agent is built.

    """

    def __init__(
        self,
        cache: Optional[ParamCache] = None,
        agent_registry: Optional[AgentCacheRegistry] = None,
    ) -> None:
        """Init params."""
        self._cache = cache or ParamCache()
        self._agent_registry = agent_registry or AgentCacheRegistry(
            str(AGENT_CACHE_DIR)
        )

    @property
    def cache(self) -> ParamCache:
        """Cache."""
        return self._cache

    @property
    def agent_registry(self) -> AgentCacheRegistry:
        """Agent registry."""
        return self._agent_registry

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = BUILDER_LLM
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        self._cache.system_prompt = response.message.content

        return f"System prompt created: {response.message.content}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        directory: Optional[str] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or directory should be specified.
        **NOTE**: urls not supported in multi-modal setting.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            directory (Optional[str]): Directory to load files from.

        """
        file_names = file_names or []
        directory = directory or ""
        docs = load_data(file_names=file_names, directory=directory)
        self._cache.docs = docs
        self._cache.file_names = file_names
        self._cache.directory = directory
        return "Data loaded successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.

        """
        rag_params = self._cache.rag_params
        return rag_params.dict()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.

        """
        new_dict = self._cache.rag_params.dict()
        new_dict.update(rag_params)
        rag_params_obj = RAGParams(**new_dict)
        self._cache.rag_params = rag_params_obj
        return "RAG parameters set successfully."

    def create_agent(self, agent_id: Optional[str] = None) -> str:
        """Create an agent.

        There are no parameters for this function because all the
        functions should have already been called to set up the agent.

        """
        if self._cache.system_prompt is None:
            raise ValueError("Must set system prompt before creating agent.")

        # construct additional tools
        agent, extra_info = construct_mm_agent(
            cast(str, self._cache.system_prompt),
            cast(RAGParams, self._cache.rag_params),
            self._cache.docs,
        )

        # if agent_id not specified, randomly generate one
        agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
        self._cache.builder_type = "multimodal"
        self._cache.vector_index = extra_info["vector_index"]
        self._cache.agent_id = agent_id
        self._cache.agent = agent

        # save the cache to disk
        self._agent_registry.add_new_agent_cache(agent_id, self._cache)
        return "Agent created successfully."

    def update_agent(
        self,
        agent_id: str,
        system_prompt: Optional[str] = None,
        include_summarization: Optional[bool] = None,
        top_k: Optional[int] = None,
        chunk_size: Optional[int] = None,
        embed_model: Optional[str] = None,
        llm: Optional[str] = None,
        additional_tools: Optional[List] = None,
    ) -> None:
        """Update agent.

        Delete old agent by ID and create a new one.
        Optionally update the system prompt and RAG parameters.

        NOTE: Currently is manually called, not meant for agent use.

        """
        self._agent_registry.delete_agent_cache(self.cache.agent_id)

        # set agent id
        self.cache.agent_id = agent_id

        # set system prompt
        if system_prompt is not None:
            self.cache.system_prompt = system_prompt
        # get agent_builder
        # We call set_rag_params and create_agent, which will
        # update the cache
        # TODO: decouple functions from tool functions exposed to the agent
        rag_params_dict: Dict[str, Any] = {}
        if include_summarization is not None:
            rag_params_dict["include_summarization"] = include_summarization
        if top_k is not None:
            rag_params_dict["top_k"] = top_k
        if chunk_size is not None:
            rag_params_dict["chunk_size"] = chunk_size
        if embed_model is not None:
            rag_params_dict["embed_model"] = embed_model
        if llm is not None:
            rag_params_dict["llm"] = llm

        self.set_rag_params(**rag_params_dict)

        # update tools
        if additional_tools is not None:
            self.cache.tools = additional_tools

        # this will update the agent in the cache
        self.create_agent()
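
Usage note (not part of the commit): `MultimodalChatEngine` simply forwards each message to the wrapped query engine, so it can also be used standalone. A sketch, assuming an existing `MultiModalVectorStoreIndex` named `mm_index` and an `OpenAIMultiModal` LLM named `mm_llm`:

    retriever = mm_index.as_retriever(similarity_top_k=2)
    query_engine = SimpleMultiModalQueryEngine(retriever, multi_modal_llm=mm_llm)
    chat_engine = MultimodalChatEngine(query_engine)
    print(chat_engine.chat("Describe the attached diagram.").response)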
core/agent_builder/registry.py ADDED
@@ -0,0 +1,78 @@
"""Agent builder registry."""

from typing import List
from typing import Union
from pathlib import Path
import json
import shutil

from core.param_cache import ParamCache


class AgentCacheRegistry:
    """Registry for agent caches, on disk.

    Can register new agent caches, load agent caches, delete agent caches, etc.

    """

    def __init__(self, dir: Union[str, Path]) -> None:
        """Init params."""
        self._dir = dir

    def _add_agent_id_to_directory(self, agent_id: str) -> None:
        """Save agent id to directory."""
        full_path = Path(self._dir) / "agent_ids.json"
        if not full_path.exists():
            with open(full_path, "w") as f:
                json.dump({"agent_ids": [agent_id]}, f)
        else:
            with open(full_path, "r") as f:
                agent_ids = json.load(f)["agent_ids"]
            if agent_id in agent_ids:
                raise ValueError(f"Agent id {agent_id} already exists.")
            agent_ids_set = set(agent_ids)
            agent_ids_set.add(agent_id)
            with open(full_path, "w") as f:
                json.dump({"agent_ids": list(agent_ids_set)}, f)

    def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
        """Register agent."""
        # save the cache to disk
        agent_cache_path = f"{self._dir}/{agent_id}"
        cache.save_to_disk(agent_cache_path)
        # save to agent ids
        self._add_agent_id_to_directory(agent_id)

    def get_agent_ids(self) -> List[str]:
        """Get agent ids."""
        full_path = Path(self._dir) / "agent_ids.json"
        if not full_path.exists():
            return []
        with open(full_path, "r") as f:
            agent_ids = json.load(f)["agent_ids"]

        return agent_ids

    def get_agent_cache(self, agent_id: str) -> ParamCache:
        """Get agent cache."""
        full_path = Path(self._dir) / f"{agent_id}"
        if not full_path.exists():
            raise ValueError(f"Cache for agent {agent_id} does not exist.")
        cache = ParamCache.load_from_disk(str(full_path))
        return cache

    def delete_agent_cache(self, agent_id: str) -> None:
        """Delete agent cache."""
        # modify / resave agent_ids
        agent_ids = self.get_agent_ids()
        new_agent_ids = [id for id in agent_ids if id != agent_id]
        full_path = Path(self._dir) / "agent_ids.json"
        with open(full_path, "w") as f:
            json.dump({"agent_ids": new_agent_ids}, f)

        # remove agent cache
        full_path = Path(self._dir) / f"{agent_id}"
        if full_path.exists():
            # recursive delete
            shutil.rmtree(full_path)
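
Usage note (not part of the commit): a round trip through the registry, assuming `cache` is a `ParamCache` whose vector index has already been built (`ParamCache.save_to_disk` requires one):

    from core.agent_builder.registry import AgentCacheRegistry

    registry = AgentCacheRegistry("cache/agents")  # hypothetical directory
    registry.add_new_agent_cache("Agent_demo", cache)  # persists cache, records the id
    print(registry.get_agent_ids())                    # ['Agent_demo', ...]
    restored = registry.get_agent_cache("Agent_demo")  # rebuilds docs and agent
    registry.delete_agent_cache("Agent_demo")          # removes id entry and cache dir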
core/builder_config.py ADDED
@@ -0,0 +1,20 @@
"""Configuration."""
import streamlit as st
import os

### DEFINE BUILDER_LLM #####
## Uncomment the LLM you want to use to construct the meta agent

## OpenAI
from llama_index.llms import OpenAI

# set OpenAI Key - use Streamlit secrets
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
# load LLM
BUILDER_LLM = OpenAI(model="gpt-4-1106-preview")

# # Anthropic (make sure you `pip install anthropic`)
# from llama_index.llms import Anthropic
# # set Anthropic key
# os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
# BUILDER_LLM = Anthropic()
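
Note (an assumption, not part of the commit): accessing `st.secrets.openai_key` fails when no secrets file is present. One possible guard, using the same membership test the loader already applies to `metaphor_key`, is to fall back to an already-exported environment variable:

    import os
    import streamlit as st

    # prefer Streamlit secrets; otherwise rely on OPENAI_API_KEY already being set
    if "openai_key" in st.secrets:
        os.environ["OPENAI_API_KEY"] = st.secrets.openai_key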
core/callback_manager.py ADDED
@@ -0,0 +1,70 @@
"""Streaming callback manager."""
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType

from typing import Optional, Dict, Any, List, Callable

STORAGE_DIR = "./storage"  # directory to cache the generated index
DATA_DIR = "./data"  # directory containing the documents to index


class StreamlitFunctionsCallbackHandler(BaseCallbackHandler):
    """Callback handler that outputs streamlit components given events."""

    def __init__(self, msg_handler: Callable[[str], Any]) -> None:
        """Initialize the base callback handler."""
        self.msg_handler = msg_handler
        super().__init__([], [])

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        if event_type == CBEventType.FUNCTION_CALL:
            if payload is None:
                raise ValueError("Payload cannot be None")
            arguments_str = payload["function_call"]
            tool_str = payload["tool"].name
            print_str = f"Calling function: {tool_str} with args: {arguments_str}\n\n"
            self.msg_handler(print_str)
        else:
            pass
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        pass
        # TODO: currently we don't need to do anything here
        # if event_type == CBEventType.FUNCTION_CALL:
        #     response = payload["function_call_response"]
        #     # Add this to queue
        #     print_str = (
        #         f"\n\nGot output: {response}\n"
        #         "========================\n\n"
        #     )
        # elif event_type == CBEventType.AGENT_STEP:
        #     # put response into queue
        #     self._queue.put(payload["response"])

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""
        pass
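
Usage note (not part of the commit): the handler is attached through a `CallbackManager`, mirroring what `core.utils.load_agent` does when it builds an `OpenAIAgent`:

    import streamlit as st
    from llama_index.callbacks import CallbackManager
    from core.callback_manager import StreamlitFunctionsCallbackHandler

    # st.info is the message sink; any Callable[[str], Any] works
    handler = StreamlitFunctionsCallbackHandler(st.info)
    callback_manager = CallbackManager([handler])
    # then pass callback_manager to OpenAIAgent.from_tools(...), as in core/utils.py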
core/constants.py ADDED
@@ -0,0 +1,4 @@
from pathlib import Path

AGENT_CACHE_DIR = Path(__file__).parent.parent / "cache" / "agents"
MESSAGES_CACHE_DIR = Path(__file__).parent.parent / "cache" / "messages"
core/param_cache.py ADDED
@@ -0,0 +1,156 @@
"""Param cache."""

from pydantic import BaseModel, Field
from llama_index import (
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
)
from typing import List, cast, Optional
from llama_index.chat_engine.types import BaseChatEngine
from pathlib import Path
import json
import uuid
from core.utils import (
    load_data,
    get_tool_objects,
    construct_agent,
    RAGParams,
    construct_mm_agent,
)


class ParamCache(BaseModel):
    """Cache for RAG agent builder.

    Created a wrapper class around a dict in case we wanted to more explicitly
    type different items in the cache.

    """

    # arbitrary types
    class Config:
        arbitrary_types_allowed = True

    # system prompt
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for RAG agent."
    )
    # data
    file_names: List[str] = Field(
        default_factory=list, description="File names as data source (if specified)"
    )
    urls: List[str] = Field(
        default_factory=list, description="URLs as data source (if specified)"
    )
    directory: Optional[str] = Field(
        default=None, description="Directory as data source (if specified)"
    )

    docs: List = Field(default_factory=list, description="Documents for RAG agent.")
    # tools
    tools: List = Field(
        default_factory=list, description="Additional tools for RAG agent (e.g. web)"
    )
    # RAG params
    rag_params: RAGParams = Field(
        default_factory=RAGParams, description="RAG parameters for RAG agent."
    )

    # agent params
    builder_type: str = Field(
        default="default", description="Builder type (default, multimodal)."
    )
    vector_index: Optional[VectorStoreIndex] = Field(
        default=None, description="Vector index for RAG agent."
    )
    agent_id: str = Field(
        default_factory=lambda: f"Agent_{str(uuid.uuid4())}",
        description="Agent ID for RAG agent.",
    )
    agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.")

    def save_to_disk(self, save_dir: str) -> None:
        """Save cache to disk."""
        # NOTE: more complex than just calling dict() because we want to
        # only store serializable fields and be space-efficient

        dict_to_serialize = {
            "system_prompt": self.system_prompt,
            "file_names": self.file_names,
            "urls": self.urls,
            "directory": self.directory,
            # TODO: figure out tools
            "tools": self.tools,
            "rag_params": self.rag_params.dict(),
            "builder_type": self.builder_type,
            "agent_id": self.agent_id,
        }
        # store the vector store within the agent
        if self.vector_index is None:
            raise ValueError("Must specify vector index in order to save.")
        self.vector_index.storage_context.persist(Path(save_dir) / "storage")

        # if save_path directories don't exist, create it
        if not Path(save_dir).exists():
            Path(save_dir).mkdir(parents=True)
        with open(Path(save_dir) / "cache.json", "w") as f:
            json.dump(dict_to_serialize, f)

    @classmethod
    def load_from_disk(
        cls,
        save_dir: str,
    ) -> "ParamCache":
        """Load cache from disk."""
        with open(Path(save_dir) / "cache.json", "r") as f:
            cache_dict = json.load(f)

        storage_context = StorageContext.from_defaults(
            persist_dir=str(Path(save_dir) / "storage")
        )
        if cache_dict["builder_type"] == "multimodal":
            from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

            vector_index: VectorStoreIndex = cast(
                MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
            )
        else:
            vector_index = cast(
                VectorStoreIndex, load_index_from_storage(storage_context)
            )

        # replace rag params with RAGParams object
        cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])

        # add in the missing fields
        # load docs
        cache_dict["docs"] = load_data(
            file_names=cache_dict["file_names"],
            urls=cache_dict["urls"],
            directory=cache_dict["directory"],
        )
        # load agent from index
        additional_tools = get_tool_objects(cache_dict["tools"])

        if cache_dict["builder_type"] == "multimodal":
            from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

            vector_index = cast(MultiModalVectorStoreIndex, vector_index)
            agent, _ = construct_mm_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                mm_vector_index=vector_index,
            )
        else:
            agent, _ = construct_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                vector_index=vector_index,
                additional_tools=additional_tools,
                # TODO: figure out tools
            )
        cache_dict["vector_index"] = vector_index
        cache_dict["agent"] = agent

        return cls(**cache_dict)
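
Usage note (not part of the commit): persistence is deliberately split — the serializable fields go to cache.json while the vector store is persisted through its storage context, and `load_from_disk` re-loads the docs and reconstructs the agent rather than unpickling them. A sketch of the round trip (directory name hypothetical):

    cache.save_to_disk("cache/agents/Agent_demo")
    restored = ParamCache.load_from_disk("cache/agents/Agent_demo")
    assert restored.agent_id == cache.agent_id  # agent and index are rebuilt, not copied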
core/utils.py ADDED
@@ -0,0 +1,480 @@
"""Utils."""

from llama_index.llms import OpenAI, Anthropic, Replicate
from llama_index.llms.base import LLM
from llama_index.llms.utils import resolve_llm
from pydantic import BaseModel, Field
import os
from llama_index.agent import OpenAIAgent, ReActAgent
from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER
from llama_index import (
    VectorStoreIndex,
    SummaryIndex,
    ServiceContext,
    Document,
)
from typing import List, cast, Optional
from llama_index import SimpleDirectoryReader
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.agent.types import BaseAgent
from llama_index.chat_engine.types import BaseChatEngine
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.chat_engine import CondensePlusContextChatEngine
from core.builder_config import BUILDER_LLM
from typing import Dict, Tuple, Any
import streamlit as st

from llama_index.callbacks import CallbackManager, trace_method
from core.callback_manager import StreamlitFunctionsCallbackHandler
from llama_index.schema import ImageNode, NodeWithScore

### BETA: Multi-modal
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.indices.multi_modal.retriever import (
    MultiModalVectorIndexRetriever,
)
from llama_index.llms import ChatMessage
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
from llama_index.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    StreamingAgentChatResponse,
    AgentChatResponse,
)
from llama_index.llms.base import ChatResponse
from typing import Generator


class RAGParams(BaseModel):
    """RAG parameters.

    Parameters used to configure a RAG pipeline.

    """

    include_summarization: bool = Field(
        default=False,
        description=(
            "Whether to include summarization in the RAG pipeline. (only for GPT-4)"
        ),
    )
    top_k: int = Field(
        default=2, description="Number of documents to retrieve from vector store."
    )
    chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
    embed_model: str = Field(
        default="default", description="Embedding model to use (default is OpenAI)"
    )
    llm: str = Field(
        default="gpt-4-1106-preview", description="LLM to use for summarization."
    )


def _resolve_llm(llm_str: str) -> LLM:
    """Resolve LLM."""
    # TODO: make this less hardcoded with if-else statements
    # see if there's a prefix
    # - if there isn't, assume it's an OpenAI model
    # - if there is, resolve it
    tokens = llm_str.split(":")
    if len(tokens) == 1:
        os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
        llm: LLM = OpenAI(model=llm_str)
    elif tokens[0] == "local":
        llm = resolve_llm(llm_str)
    elif tokens[0] == "openai":
        os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
        llm = OpenAI(model=tokens[1])
    elif tokens[0] == "anthropic":
        os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
        llm = Anthropic(model=tokens[1])
    elif tokens[0] == "replicate":
        os.environ["REPLICATE_API_KEY"] = st.secrets.replicate_key
        llm = Replicate(model=tokens[1])
    else:
        raise ValueError(f"LLM {llm_str} not recognized.")
    return llm


def load_data(
    file_names: Optional[List[str]] = None,
    directory: Optional[str] = None,
    urls: Optional[List[str]] = None,
) -> List[Document]:
    """Load data."""
    file_names = file_names or []
    directory = directory or ""
    urls = urls or []

    # get number depending on whether specified
    num_specified = sum(1 for v in [file_names, urls, directory] if v)

    if num_specified == 0:
        raise ValueError("Must specify either file_names or urls or directory.")
    elif num_specified > 1:
        raise ValueError("Must specify only one of file_names or urls or directory.")
    elif file_names:
        reader = SimpleDirectoryReader(input_files=file_names)
        docs = reader.load_data()
    elif directory:
        reader = SimpleDirectoryReader(input_dir=directory)
        docs = reader.load_data()
    elif urls:
        from llama_hub.web.simple_web.base import SimpleWebPageReader

        # use simple web page reader from llamahub
        loader = SimpleWebPageReader()
        docs = loader.load_data(urls=urls)
    else:
        raise ValueError("Must specify either file_names or urls or directory.")

    return docs


def load_agent(
    tools: List,
    llm: LLM,
    system_prompt: str,
    extra_kwargs: Optional[Dict] = None,
    **kwargs: Any,
) -> BaseChatEngine:
    """Load agent."""
    extra_kwargs = extra_kwargs or {}
    if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
        # TODO: use default msg handler
        # TODO: separate this from agent_utils.py...
        def _msg_handler(msg: str) -> None:
            """Message handler."""
            st.info(msg)
            st.session_state.agent_messages.append(
                {"role": "assistant", "content": msg, "msg_type": "info"}
            )

        # add streamlit callbacks (to inject events)
        handler = StreamlitFunctionsCallbackHandler(_msg_handler)
        callback_manager = CallbackManager([handler])
        # get OpenAI Agent
        agent: BaseChatEngine = OpenAIAgent.from_tools(
            tools=tools,
            llm=llm,
            system_prompt=system_prompt,
            **kwargs,
            callback_manager=callback_manager,
        )
    else:
        if "vector_index" not in extra_kwargs:
            raise ValueError(
                "Must pass in vector index for CondensePlusContextChatEngine."
            )
        vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"])
        rag_params = cast(RAGParams, extra_kwargs["rag_params"])
        # use condense + context chat engine
        agent = CondensePlusContextChatEngine.from_defaults(
            vector_index.as_retriever(similarity_top_k=rag_params.top_k),
        )

    return agent


def load_meta_agent(
    tools: List,
    llm: LLM,
    system_prompt: str,
    extra_kwargs: Optional[Dict] = None,
    **kwargs: Any,
) -> BaseAgent:
    """Load meta agent.

    TODO: consolidate with load_agent.

    The meta-agent *has* to perform tool-use.

    """
    extra_kwargs = extra_kwargs or {}
    if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
        # get OpenAI Agent

        agent: BaseAgent = OpenAIAgent.from_tools(
            tools=tools,
            llm=llm,
            system_prompt=system_prompt,
            **kwargs,
        )
    else:
        agent = ReActAgent.from_tools(
            tools=tools,
            llm=llm,
            react_chat_formatter=ReActChatFormatter(
                system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER,
            ),
            **kwargs,
        )

    return agent


def construct_agent(
    system_prompt: str,
    rag_params: RAGParams,
    docs: List[Document],
    vector_index: Optional[VectorStoreIndex] = None,
    additional_tools: Optional[List] = None,
) -> Tuple[BaseChatEngine, Dict]:
    """Construct agent from docs / parameters / indices."""
    extra_info = {}
    additional_tools = additional_tools or []

    # first resolve llm and embedding model
    embed_model = resolve_embed_model(rag_params.embed_model)
    # llm = resolve_llm(rag_params.llm)
    # TODO: use OpenAI for now
    # llm = OpenAI(model=rag_params.llm)
    llm = _resolve_llm(rag_params.llm)

    # first let's index the data with the right parameters
    service_context = ServiceContext.from_defaults(
        chunk_size=rag_params.chunk_size,
        llm=llm,
        embed_model=embed_model,
    )

    if vector_index is None:
        vector_index = VectorStoreIndex.from_documents(
            docs, service_context=service_context
        )
    else:
        pass

    extra_info["vector_index"] = vector_index

    vector_query_engine = vector_index.as_query_engine(
        similarity_top_k=rag_params.top_k
    )
    all_tools = []
    vector_tool = QueryEngineTool(
        query_engine=vector_query_engine,
        metadata=ToolMetadata(
            name="vector_tool",
            description=("Use this tool to answer any user question over any data."),
        ),
    )
    all_tools.append(vector_tool)
    if rag_params.include_summarization:
        summary_index = SummaryIndex.from_documents(
            docs, service_context=service_context
        )
        summary_query_engine = summary_index.as_query_engine()
        summary_tool = QueryEngineTool(
            query_engine=summary_query_engine,
            metadata=ToolMetadata(
                name="summary_tool",
                description=(
                    "Use this tool for any user questions that ask "
                    "for a summarization of content"
                ),
            ),
        )
        all_tools.append(summary_tool)

    # then we add tools
    all_tools.extend(additional_tools)

    # build agent
    if system_prompt is None:
        # NOTE: raise instead of returning a status string, so the declared
        # Tuple[BaseChatEngine, Dict] return type actually holds
        raise ValueError("System prompt not set yet. Please set system prompt first.")

    agent = load_agent(
        all_tools,
        llm=llm,
        system_prompt=system_prompt,
        verbose=True,
        extra_kwargs={"vector_index": vector_index, "rag_params": rag_params},
    )
    return agent, extra_info


def get_web_agent_tool() -> QueryEngineTool:
    """Get web agent tool.

    Wrap with our load and search tool spec.

    """
    from llama_hub.tools.metaphor.base import MetaphorToolSpec

    # TODO: set metaphor API key
    metaphor_tool = MetaphorToolSpec(
        api_key=st.secrets.metaphor_key,
    )
    metaphor_tool_list = metaphor_tool.to_tool_list()

    # TODO: LoadAndSearch doesn't work yet
    # The search_and_retrieve_documents tool is the third in the tool list,
    # as seen above
    # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults(
    #     metaphor_tool_list[2],
    # )

    # NOTE: requires openai right now
    # We don't give the Agent our unwrapped retrieve document tools
    # instead passing the wrapped tools
    web_agent = OpenAIAgent.from_tools(
        # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]],
        metaphor_tool_list,
        llm=BUILDER_LLM,
        verbose=True,
    )

    # return agent as a tool
    # TODO: tune description
    web_agent_tool = QueryEngineTool.from_defaults(
        web_agent,
        name="web_agent",
        description="""
This agent can answer questions by searching the web. \
Use this tool if the answer is ONLY likely to be found by searching \
the internet, especially for queries about recent events.
        """,
    )

    return web_agent_tool


def get_tool_objects(tool_names: List[str]) -> List:
    """Get tool objects from tool names."""
    # construct additional tools
    tool_objs = []
    for tool_name in tool_names:
        if tool_name == "web_search":
            # build web agent
            tool_objs.append(get_web_agent_tool())
        else:
            raise ValueError(f"Tool {tool_name} not recognized.")

    return tool_objs


class MultimodalChatEngine(BaseChatEngine):
    """Multimodal chat engine.

    This chat engine is a light wrapper around a query engine.
    Offers no real 'chat' functionality, is a beta feature.

    """

    def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
        """Init params."""
        self._mm_query_engine = mm_query_engine

    def reset(self) -> None:
        """Reset conversation state."""
        pass

    @property
    def chat_history(self) -> List[ChatMessage]:
        return []

    @trace_method("chat")
    def chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Main chat interface."""
        # just return the top-k results
        response = self._mm_query_engine.query(message)
        return AgentChatResponse(
            response=str(response), source_nodes=response.source_nodes
        )

    @trace_method("chat")
    def stream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Stream chat interface."""
        response = self._mm_query_engine.query(message)

        def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
            yield ChatResponse(message=ChatMessage(role="assistant", content=response))

        chat_stream = _chat_stream(str(response))
        return StreamingAgentChatResponse(
            chat_stream=chat_stream, source_nodes=response.source_nodes
        )

    @trace_method("chat")
    async def achat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> AGENT_CHAT_RESPONSE_TYPE:
        """Async version of main chat interface."""
        response = await self._mm_query_engine.aquery(message)
        return AgentChatResponse(
            response=str(response), source_nodes=response.source_nodes
        )

    @trace_method("chat")
    async def astream_chat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> StreamingAgentChatResponse:
        """Async version of main chat interface."""
        return self.stream_chat(message, chat_history)


def construct_mm_agent(
    system_prompt: str,
    rag_params: RAGParams,
    docs: List[Document],
    mm_vector_index: Optional[VectorStoreIndex] = None,
    additional_tools: Optional[List] = None,
) -> Tuple[BaseChatEngine, Dict]:
    """Construct agent from docs / parameters / indices.

    NOTE: system prompt isn't used right now

    """
    extra_info = {}
    additional_tools = additional_tools or []

    # first resolve llm and embedding model
    embed_model = resolve_embed_model(rag_params.embed_model)
    # TODO: use OpenAI for now
    os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
    openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=1500)

    # first let's index the data with the right parameters
    service_context = ServiceContext.from_defaults(
        chunk_size=rag_params.chunk_size,
        embed_model=embed_model,
    )

    if mm_vector_index is None:
        mm_vector_index = MultiModalVectorStoreIndex.from_documents(
            docs, service_context=service_context
        )
    else:
        pass

    mm_retriever = mm_vector_index.as_retriever(similarity_top_k=rag_params.top_k)
    mm_query_engine = SimpleMultiModalQueryEngine(
        cast(MultiModalVectorIndexRetriever, mm_retriever),
        multi_modal_llm=openai_mm_llm,
    )

    extra_info["vector_index"] = mm_vector_index

    # use condense + context chat engine
    agent = MultimodalChatEngine(mm_query_engine)

    return agent, extra_info


def get_image_and_text_nodes(
    nodes: List[NodeWithScore],
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
    """Split retrieved nodes into image nodes and text nodes."""
    image_nodes = []
    text_nodes = []
    for res_node in nodes:
        if isinstance(res_node.node, ImageNode):
            image_nodes.append(res_node)
        else:
            text_nodes.append(res_node)
    return image_nodes, text_nodes
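
Usage note (not part of the commit): `_resolve_llm` dispatches on an optional `provider:` prefix. A few strings it accepts (model names are illustrative):

    _resolve_llm("gpt-4-1106-preview")     # no prefix -> treated as an OpenAI model
    _resolve_llm("openai:gpt-3.5-turbo")   # explicit OpenAI prefix
    _resolve_llm("anthropic:claude-2")     # Anthropic
    _resolve_llm("replicate:<model-slug>") # Replicate (slug is a placeholder)
    _resolve_llm("local:/path/to/model")   # delegated to llama_index's resolve_llm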
pages/.DS_Store ADDED
Binary file (6.15 kB).
 
pages/4_🤖_ChatDoctor.py ADDED
@@ -0,0 +1,126 @@
"""Streamlit page showing builder config."""
import streamlit as st
from st_utils import add_sidebar, get_current_state
from core.utils import get_image_and_text_nodes
from llama_index.schema import MetadataMode
from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE
from typing import Dict, Optional
import pandas as pd


####################
#### STREAMLIT #####
####################


st.set_page_config(
    page_title="ChatDoctor: your virtual primary care physician assistant",
    page_icon="🤖💬",
    layout="centered",
    # initial_sidebar_state="auto",  # ggyimah set this to off
    menu_items=None,
)
st.title("ChatDoctor: your virtual primary care physician assistant")
# st.info(
#     "Welcome!!! My name is ChatDoctor and I am trained to provide medical diagnoses and advice.",
#     icon="ℹ️",
# )


current_state = get_current_state()
add_sidebar()

if (
    "agent_messages" not in st.session_state.keys()
):  # Initialize the chat messages history
    st.session_state.agent_messages = [
        {
            "role": "assistant",
            "content": (
                "I am trained to provide medical diagnoses and advice. "
                "How may I help you, today?"
            ),
        }
    ]


def display_sources(response: AGENT_CHAT_RESPONSE_TYPE) -> None:
    image_nodes, text_nodes = get_image_and_text_nodes(response.source_nodes)
    if len(image_nodes) > 0 or len(text_nodes) > 0:
        with st.expander("Sources"):
            # get image nodes
            if len(image_nodes) > 0:
                st.subheader("Images")
                for image_node in image_nodes:
                    st.image(image_node.metadata["file_path"])

            if len(text_nodes) > 0:
                st.subheader("Text")
                sources_df_list = []
                for text_node in text_nodes:
                    sources_df_list.append(
                        {
                            "ID": text_node.id_,
                            "Text": text_node.node.get_content(
                                metadata_mode=MetadataMode.ALL
                            ),
                        }
                    )
                sources_df = pd.DataFrame(sources_df_list)
                st.dataframe(sources_df)


def add_to_message_history(
    role: str, content: str, extra: Optional[Dict] = None
) -> None:
    message = {"role": role, "content": str(content), "extra": extra}
    st.session_state.agent_messages.append(message)  # Add response to message history


def display_messages() -> None:
    """Display messages."""
    for message in st.session_state.agent_messages:  # Display the prior chat messages
        with st.chat_message(message["role"]):
            msg_type = message["msg_type"] if "msg_type" in message.keys() else "text"
            if msg_type == "text":
                st.write(message["content"])
            elif msg_type == "info":
                st.info(message["content"], icon="ℹ️")
            else:
                raise ValueError(f"Unknown message type: {msg_type}")

            # display sources
            if "extra" in message and isinstance(message["extra"], dict):
                if "response" in message["extra"].keys():
                    display_sources(message["extra"]["response"])


# if agent is created, then we can chat with it
if current_state.cache is not None and current_state.cache.agent is not None:
    st.info(f"Viewing config for agent: {current_state.cache.agent_id}", icon="ℹ️")
    agent = current_state.cache.agent

    # display prior messages
    display_messages()

    # don't process selected for now
    if prompt := st.chat_input(
        "Your question"
    ):  # Prompt for user input and save to chat history
        add_to_message_history("user", prompt)
        with st.chat_message("user"):
            st.write(prompt)

    # If last message is not from assistant, generate a new response
    if st.session_state.agent_messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = agent.chat(str(prompt))
                st.write(str(response))

                # display sources
                # Multi-modal: check if image nodes are present
                display_sources(response)

                add_to_message_history(
                    "assistant", str(response), extra={"response": response}
                )
else:
    st.info(
        "In the side bar, select the ChatDoctor virtual agent "
        "(Agent_950acb55-056f-4324-957d-15e1c9b48695) to get started."
    )
    st.info(
        "Since this app is running on a free basic server, it could take from "
        "2 to 10 minutes for the virtual agent to join you. Please be patient."
    )
+
tests/__init__.py ADDED
File without changes