Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,27 +4,26 @@ import streamlit as st
|
|
4 |
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
|
5 |
from langchain.vectorstores.faiss import FAISS
|
6 |
from huggingface_hub import snapshot_download
|
7 |
-
|
8 |
-
from langchain.
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
)
|
14 |
-
from langchain.schema import (
|
15 |
-
AIMessage,
|
16 |
-
HumanMessage,
|
17 |
-
SystemMessage
|
18 |
)
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
from langchain.chains.llm import LLMChain
|
22 |
-
from langchain.callbacks.base import CallbackManager
|
23 |
-
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
24 |
-
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
25 |
-
from langchain.chains.question_answering import load_qa_chain
|
26 |
|
27 |
-
st.set_page_config(
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
#Load API Key
|
30 |
api_key = os.environ["OPENAI_API_KEY"]
|
@@ -34,9 +33,10 @@ with st.sidebar:
|
|
34 |
book = st.radio("Embedding Model: ",
|
35 |
["Sbert"]
|
36 |
)
|
37 |
-
|
|
|
38 |
#load embedding models
|
39 |
-
@st.
|
40 |
def load_embedding_models(model):
|
41 |
|
42 |
if model == 'Sbert':
|
@@ -53,14 +53,10 @@ def load_embedding_models(model):
|
|
53 |
|
54 |
return emb
|
55 |
|
56 |
-
st.title(f"Talk to CFA Level 1 Book")
|
57 |
-
st.markdown("#### Have a conversation with the CFA Curriculum by the CFA Institute π")
|
58 |
-
|
59 |
-
|
60 |
embeddings = load_embedding_models(book)
|
61 |
|
62 |
##### functionss ####
|
63 |
-
@st.
|
64 |
def load_vectorstore(_embeddings):
|
65 |
# download from hugging face
|
66 |
cache_dir="cfa_level_1_cache"
|
@@ -83,122 +79,75 @@ def load_vectorstore(_embeddings):
|
|
83 |
print(target_path)
|
84 |
|
85 |
# load faiss
|
86 |
-
|
87 |
-
|
88 |
-
return
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
def
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
unique_pages = ""
|
160 |
-
for item in unique_sources:
|
161 |
-
unique_pages += str(item) + ", "
|
162 |
-
|
163 |
-
# will look like 1, 2, 3,
|
164 |
-
pages = unique_pages[:-2] # removes the last comma and space
|
165 |
-
|
166 |
-
# source text
|
167 |
-
full_source = ""
|
168 |
-
for item in result['source_documents']:
|
169 |
-
full_source += f"- **{item.metadata['source']}**" + "\n" + item.page_content + "\n\n"
|
170 |
-
|
171 |
-
# will look like:
|
172 |
-
# - Page: {number}
|
173 |
-
# {extracted text from book}
|
174 |
-
extract = full_source
|
175 |
-
|
176 |
-
return answer, pages, extract
|
177 |
-
|
178 |
-
|
179 |
-
##### main ####
|
180 |
-
user_input = st.text_area("Your question", "What is an MBS and who are the main issuers and investors of the MBS market?", key="input")
|
181 |
-
|
182 |
-
col1, col2 = st.columns([10, 1])
|
183 |
-
|
184 |
-
# show question
|
185 |
-
col1.write(f"**You:** {user_input}")
|
186 |
-
|
187 |
-
# ask button to the right of the displayed question
|
188 |
-
ask = col2.button("Ask", type="primary")
|
189 |
-
|
190 |
-
|
191 |
-
if ask:
|
192 |
-
|
193 |
-
with st.spinner("this can take about a minute for your first question because some models have to be downloaded π₯Ίππ»ππ»"):
|
194 |
-
try:
|
195 |
-
answer, pages, extract = get_answer(question=user_input)
|
196 |
-
except Exception as e:
|
197 |
-
st.write(f"Error with Download: {e}")
|
198 |
-
st.stop()
|
199 |
-
|
200 |
-
st.write(f"{answer}")
|
201 |
-
|
202 |
-
# sources
|
203 |
-
with st.expander(label = f"From: {pages}", expanded = False):
|
204 |
-
st.markdown(extract)
|
|
|
4 |
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
|
5 |
from langchain.vectorstores.faiss import FAISS
|
6 |
from huggingface_hub import snapshot_download
|
7 |
+
|
8 |
+
from langchain.callbacks import StreamlitCallbackHandler
|
9 |
+
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
|
10 |
+
from langchain.agents.agent_toolkits import create_retriever_tool
|
11 |
+
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
|
12 |
+
AgentTokenBufferMemory,
|
|
|
|
|
|
|
|
|
|
|
13 |
)
|
14 |
+
from langchain.chat_models import ChatOpenAI
|
15 |
+
from langchain.schema import SystemMessage, AIMessage, HumanMessage
|
16 |
+
from langchain.prompts import MessagesPlaceholder
|
17 |
+
from langsmith import Client
|
18 |
|
19 |
+
# LangSmith client — used later by send_feedback() to record user
# thumbs-up/down scores against a specific agent run.
client = Client()

# Page chrome must be configured before any other st.* call is made.
st.set_page_config(
    page_title="Chat with CFA Level 1",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="collapsed",
)

#Load API Key
# NOTE(review): raises KeyError if OPENAI_API_KEY is unset — presumably
# intentional fail-fast behavior; confirm before deploying.
api_key = os.environ["OPENAI_API_KEY"]
|
|
|
33 |
book = st.radio("Embedding Model: ",
|
34 |
["Sbert"]
|
35 |
)
|
36 |
+
|
37 |
+
|
38 |
#load embedding models
|
39 |
+
@st.cache_resource(show_spinner=True)
|
40 |
def load_embedding_models(model):
|
41 |
|
42 |
if model == 'Sbert':
|
|
|
53 |
|
54 |
return emb
|
55 |
|
|
|
|
|
|
|
|
|
56 |
embeddings = load_embedding_models(book)
|
57 |
|
58 |
##### functions ####
|
59 |
+
@st.cache_resource(show_spinner=False)
|
60 |
def load_vectorstore(_embeddings):
|
61 |
# download from hugging face
|
62 |
cache_dir="cfa_level_1_cache"
|
|
|
79 |
print(target_path)
|
80 |
|
81 |
# load faiss
|
82 |
+
vectorstore = FAISS.load_local(folder_path=target_path, embeddings=_embeddings)
|
83 |
+
|
84 |
+
return vectorstore.as_retriever(search_kwargs={"k": 4})
|
85 |
+
|
86 |
+
# --- Retrieval tool ---------------------------------------------------------
# Wrap the FAISS retriever as an agent tool so the LLM can decide when to
# search the curriculum instead of answering from its own weights.
# BUG FIX: load_vectorstore(_embeddings) requires an embeddings argument;
# the previous call `load_vectorstore()` raised TypeError at startup.
# Pass the module-level `embeddings` produced by load_embedding_models(book).
tool = create_retriever_tool(
    load_vectorstore(embeddings),
    "search_cfa_docs",
    "Searches and returns documents regarding the CFA level 1 curriculum. CFA is a rigorous program for investment professionals which covers topics such as ethics, corporate finance, economics, fixed income, equities and derivatives markets. You do not know anything about the CFA program, so if you are ever asked about CFA material or curriculum you should use this tool.",
)
tools = [tool]

# Streaming GPT-4 chat model; streaming=True lets StreamlitCallbackHandler
# render tokens as they arrive.
llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")

# System prompt steering the agent toward CFA-related interpretation of
# ambiguous questions.
message = SystemMessage(
    content=(
        "You are a helpful chatbot who is tasked with answering questions about the CFA level 1 program. "
        "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the CFA program and materials. "
        "If there is any ambiguity, you probably assume they are about that."
    )
)

# "history" placeholder is filled from st.session_state.messages on each turn.
prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    # Needed so the response dict carries tool-call steps and run info.
    return_intermediate_steps=True,
)
memory = AgentTokenBufferMemory(llm=llm)

starter_message = "Ask me anything about the CFA Level 1 Curriculum!"
# Initialize (or reset, via the sidebar button) the persisted chat history.
if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
    st.session_state["messages"] = [AIMessage(content=starter_message)]


def send_feedback(run_id, score):
    """Record a user feedback score for the given LangSmith run."""
    client.create_feedback(run_id, "user_score", score=score)


# Replay the stored conversation and rebuild the token-buffer memory from it
# on every rerun (Streamlit re-executes the whole script per interaction).
for msg in st.session_state.messages:
    if isinstance(msg, AIMessage):
        st.chat_message("assistant").write(msg.content)
    elif isinstance(msg, HumanMessage):
        st.chat_message("user").write(msg.content)
    memory.chat_memory.add_message(msg)


if prompt := st.chat_input(placeholder=starter_message):
    st.chat_message("user").write(prompt)
    with st.chat_message("assistant"):
        # Stream the agent's intermediate reasoning/tool calls into this
        # container as they happen.
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent_executor(
            {"input": prompt, "history": st.session_state.messages},
            callbacks=[st_callback],
            # include_run_info=True exposes response["__run"].run_id below.
            include_run_info=True,
        )
        st.session_state.messages.append(AIMessage(content=response["output"]))
        st.write(response["output"])
        # Persist the updated memory buffer as the canonical history.
        memory.save_context({"input": prompt}, response)
        st.session_state["messages"] = memory.buffer
        run_id = response["__run"].run_id

        # Feedback row: thumbs up/down tied to this run's LangSmith id.
        col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
        with col_text:
            st.text("Feedback:")

        with col1:
            st.button("π", on_click=send_feedback, args=(run_id, 1))

        with col2:
            st.button("π", on_click=send_feedback, args=(run_id, 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|