DrishtiSharma commited on
Commit
739d571
·
verified ·
1 Parent(s): 106a3e5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from dotenv import load_dotenv
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
6
+ from langchain_core.messages import BaseMessage, HumanMessage
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_experimental.tools import PythonREPLTool
9
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import Chroma
12
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.runnables import RunnablePassthrough
15
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
16
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
+ from langgraph.graph import StateGraph, END
18
+ from typing import Annotated, Sequence, TypedDict
19
+ import functools
20
+ import operator
21
+
22
+ # Load environment variables
23
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
24
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
25
+
26
+ if not OPENAI_API_KEY or not TAVILY_API_KEY:
27
+ st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
28
+ st.stop()
29
+
30
+ # Initialize API keys and LLM
31
+ llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
32
+
33
+ # Utility Functions
34
+ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
35
+ prompt = ChatPromptTemplate.from_messages([
36
+ ("system", system_prompt),
37
+ MessagesPlaceholder(variable_name="messages"),
38
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
39
+ ])
40
+ agent = create_openai_tools_agent(llm, tools, prompt)
41
+ return AgentExecutor(agent=agent, tools=tools)
42
+
43
+ def agent_node(state, agent, name):
44
+ result = agent.invoke(state)
45
+ return {"messages": [HumanMessage(content=result["output"], name=name)]}
46
+
47
+ @tool
48
+ def RAG(state):
49
+ st.session_state.outputs.append('-> Calling RAG ->')
50
+ question = state
51
+ template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
52
+ prompt = ChatPromptTemplate.from_template(template)
53
+ retrieval_chain = (
54
+ {"context": retriever, "question": RunnablePassthrough()} |
55
+ prompt |
56
+ llm |
57
+ StrOutputParser()
58
+ )
59
+ result = retrieval_chain.invoke(question)
60
+ return result
61
+
62
+ # Load Tools and Retriever
63
+ tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
64
+ python_repl_tool = PythonREPLTool()
65
+
66
+ # File Upload Section
67
+ st.title("Multi-Agent Workflow Demonstration")
68
+ uploaded_files = st.file_uploader("Upload your source files (TXT)", accept_multiple_files=True, type=['txt'])
69
+
70
+ if uploaded_files:
71
+ docs = []
72
+ for uploaded_file in uploaded_files:
73
+ content = uploaded_file.read().decode("utf-8")
74
+ docs.append(TextLoader(file_path=None, content=content).load()[0])
75
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
76
+ new_docs = text_splitter.split_documents(documents=docs)
77
+ embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
78
+ db = Chroma.from_documents(new_docs, embeddings)
79
+ retriever = db.as_retriever(search_kwargs={"k": 4})
80
+ else:
81
+ retriever = None
82
+ st.warning("Please upload at least one text file to proceed.")
83
+ st.stop()
84
+
85
+ # Create Agents
86
+ research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
87
+ code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
88
+ RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")
89
+
90
+ research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
91
+ code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
92
+ rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
93
+
94
+ members = ["RAG", "Researcher", "Coder"]
95
+ system_prompt = (
96
+ "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
97
+ "Use RAG tool for Japan or Sports questions."
98
+ )
99
+ options = ["FINISH"] + members
100
+ function_def = {
101
+ "name": "route", "description": "Select the next role.",
102
+ "parameters": {
103
+ "title": "routeSchema", "type": "object",
104
+ "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"]
105
+ }
106
+ }
107
+ prompt = ChatPromptTemplate.from_messages([
108
+ ("system", system_prompt),
109
+ MessagesPlaceholder(variable_name="messages"),
110
+ ("system", "Given the conversation above, who should act next? Select one of: {options}"),
111
+ ]).partial(options=str(options), members=", ".join(members))
112
+
113
+ supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
114
+
115
+ # Build Workflow
116
+ class AgentState(TypedDict):
117
+ messages: Annotated[Sequence[BaseMessage], operator.add]
118
+ next: str
119
+
120
+ workflow = StateGraph(AgentState)
121
+ workflow.add_node("Researcher", research_node)
122
+ workflow.add_node("Coder", code_node)
123
+ workflow.add_node("RAG", rag_node)
124
+ workflow.add_node("supervisor", supervisor_chain)
125
+
126
+ for member in members:
127
+ workflow.add_edge(member, "supervisor")
128
+ conditional_map = {k: k for k in members}
129
+ conditional_map["FINISH"] = END
130
+ workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
131
+ workflow.set_entry_point("supervisor")
132
+ graph = workflow.compile()
133
+
134
+ # Streamlit UI
135
+ if 'outputs' not in st.session_state:
136
+ st.session_state.outputs = []
137
+
138
+ user_input = st.text_area("Enter your task or question:")
139
+
140
+ def run_workflow(task):
141
+ st.session_state.outputs.clear()
142
+ st.session_state.outputs.append(f"User Input: {task}")
143
+ for state in graph.stream({"messages": [HumanMessage(content=task)]}):
144
+ if "__end__" not in state:
145
+ st.session_state.outputs.append(str(state))
146
+ st.session_state.outputs.append("----")
147
+
148
+ if st.button("Run Workflow"):
149
+ if user_input:
150
+ run_workflow(user_input)
151
+ else:
152
+ st.warning("Please enter a task or question.")
153
+
154
+ st.subheader("Workflow Output:")
155
+ for output in st.session_state.outputs:
156
+ st.text(output)