File size: 3,167 Bytes
d63c9ee
 
dfc4889
d63c9ee
 
dfc4889
d63c9ee
 
 
 
 
dfc4889
 
 
d63c9ee
 
 
 
 
 
 
 
 
 
 
dfc4889
 
 
 
 
 
d63c9ee
 
 
 
 
 
 
 
dfc4889
 
 
 
 
 
 
d63c9ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc4889
 
 
 
 
 
 
 
 
 
 
d63c9ee
 
 
 
 
 
 
 
 
 
 
 
 
6b192cb
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from langchain.agents import Tool, AgentType, initialize_agent
from langchain.memory import ConversationBufferMemory
# from langchain.utilities import DuckDuckGoSearchAPIWrapper
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor
from langchain import hub
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain.tools.render import render_text_description
import os
from tools.kg_search import lookup_kg
from tools.tavily_search import tavily_search
from tools.tavily_search_v2 import tavily_search, tavily_qna_search

from dotenv import load_dotenv
from langchain.agents import Tool
from langchain_core.prompts import PromptTemplate

load_dotenv()

# The Gemini chat model is configured through GOOGLE_API_KEY, while this
# project keeps the key as GEMINI_API_KEY (e.g. in .env). Copy it across when
# present. Assigning the raw os.getenv() result would put None into
# os.environ, which raises an opaque "TypeError: str expected, not NoneType"
# when the variable is missing. Instead, accept an already-exported
# GOOGLE_API_KEY, or fail with an explicit message.
_gemini_api_key = os.getenv("GEMINI_API_KEY")
if _gemini_api_key is not None:
    os.environ["GOOGLE_API_KEY"] = _gemini_api_key
elif "GOOGLE_API_KEY" not in os.environ:
    raise RuntimeError(
        "GEMINI_API_KEY is not set; add it to the environment or the .env file."
    )

# Deterministic (temperature 0) Gemini model shared by the agent below.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash-latest",
    temperature=0,
)

# search = DuckDuckGoSearchAPIWrapper()
#
# search_tool = Tool(name="Current Search",
#                    func=search.run,
#                    description="Useful when you need to answer questions about detail jobs information or search a job."
#                    )

# Tools available to the ReAct agent. Each tool's ``name`` is how the agent
# refers to it when choosing an action. Its ``description`` tells the model
# when the tool is appropriate.
kg_query = Tool(
    name="Query Knowledge Graph",
    description="Useful for when you need to answer questions about job posts.",
    func=lookup_kg,
)

web_search = Tool(
    name="Web Search",
    description="Useful for when you need to search for external information.",
    func=tavily_qna_search,
)

# Order matters only for how the tools are listed in the rendered prompt.
tools = [kg_query, web_search]


# Load the ReAct prompt template. The path is relative to the current working
# directory, so the script must be run from the directory containing
# ``prompts/``. Decode explicitly as UTF-8. Without ``encoding`` the platform
# locale is used (e.g. cp1252 on Windows), which can garble or reject
# non-ASCII prompt text.
with open("prompts/react_prompt_v2.txt", "r", encoding="utf-8") as template_file:
    react_template = template_file.read()

react_prompt = PromptTemplate(
    input_variables=["tools", "tool_names", "input", "agent_scratchpad", "chat_history"],
    template=react_template,
)

# Pre-fill the tool-related variables once. The remaining variables ("input",
# "agent_scratchpad", "chat_history") are supplied on every agent call.
prompt = react_prompt.partial(
    tools=render_text_description(tools),
    tool_names=", ".join(t.name for t in tools),
)

# Stop generation before the model writes its own "Observation" line. The
# real observation comes from running the chosen tool.
llm_with_stop = llm.bind(stop=["\nObservation"])

# ReAct agent expressed as a pipeline:
#   executor inputs -> prompt variables -> prompt -> LLM -> parsed action.
# The executor's ``intermediate_steps`` (prior actions and their results) are
# rendered as text into the ``agent_scratchpad`` prompt slot.
agent = (
    {
        "input": lambda inputs: inputs["input"],
        "agent_scratchpad": lambda inputs: format_log_to_str(
            inputs["intermediate_steps"]
        ),
        "chat_history": lambda inputs: inputs["chat_history"],
    }
    | prompt
    | llm_with_stop
    | ReActSingleInputOutputParser()
)

# Shared conversation buffer. ``memory_key`` matches the "chat_history" value
# the pipeline above reads from its inputs.
memory = ConversationBufferMemory(memory_key="chat_history")

# Default executor bound to the shared ``memory`` above. ``get_react_agent``
# builds one around caller-supplied memory instead.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,
    verbose=True,
)


def get_react_agent(memory):
    """Build a new verbose ReAct ``AgentExecutor`` around the given memory.

    Reuses the module-level ``agent`` pipeline and ``tools``. Unlike the
    module-level ``agent_executor``, which is tied to the shared ``memory``
    object, it lets the caller choose the memory.

    Args:
        memory: Conversation memory for the executor. Presumably it should
            provide ``chat_history`` like the module-level memory does,
            because the agent pipeline reads that key. Verify with callers.

    Returns:
        A freshly constructed ``AgentExecutor``.
    """
    return AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True)

# result = agent_executor.invoke({"input": "Have any company recruit Machine Learning jobs?"})
# print(result)

# result = agent_chain.run(input = "Have any company recruit Machine Learning jobs?")
# print(result)

# question = {
#     "input": "What did I just ask?"
# }
#
# result = agent_executor.invoke(question)
# print(result)

# if __name__ == "__main__":
#     while True:
#         try:
#             question = input("> ")
#             result = agent_executor.invoke({
#                 "input": question
#             })
#         except:
#             break