nickmuchi committed on
Commit
8cbab56
Β·
1 Parent(s): 92a6a0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -144
app.py CHANGED
@@ -4,27 +4,26 @@ import streamlit as st
4
  from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
5
  from langchain.vectorstores.faiss import FAISS
6
  from huggingface_hub import snapshot_download
7
- from langchain.chat_models import ChatOpenAI
8
- from langchain.prompts.chat import (
9
- ChatPromptTemplate,
10
- SystemMessagePromptTemplate,
11
- AIMessagePromptTemplate,
12
- HumanMessagePromptTemplate,
13
- )
14
- from langchain.schema import (
15
- AIMessage,
16
- HumanMessage,
17
- SystemMessage
18
  )
 
 
 
 
19
 
20
- from langchain.chains import ConversationalRetrievalChain
21
- from langchain.chains.llm import LLMChain
22
- from langchain.callbacks.base import CallbackManager
23
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
24
- from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
25
- from langchain.chains.question_answering import load_qa_chain
26
 
27
- st.set_page_config(page_title="CFA Level 1", page_icon="πŸ“–")
 
 
 
 
 
28
 
29
  #Load API Key
30
  api_key = os.environ["OPENAI_API_KEY"]
@@ -34,9 +33,10 @@ with st.sidebar:
34
  book = st.radio("Embedding Model: ",
35
  ["Sbert"]
36
  )
37
-
 
38
  #load embedding models
39
- @st.experimental_singleton(show_spinner=True)
40
  def load_embedding_models(model):
41
 
42
  if model == 'Sbert':
@@ -53,14 +53,10 @@ def load_embedding_models(model):
53
 
54
  return emb
55
 
56
- st.title(f"Talk to CFA Level 1 Book")
57
- st.markdown("#### Have a conversation with the CFA Curriculum by the CFA Institute πŸ™Š")
58
-
59
-
60
  embeddings = load_embedding_models(book)
61
 
62
  ##### functionss ####
63
- @st.experimental_singleton(show_spinner=False)
64
  def load_vectorstore(_embeddings):
65
  # download from hugging face
66
  cache_dir="cfa_level_1_cache"
@@ -83,122 +79,75 @@ def load_vectorstore(_embeddings):
83
  print(target_path)
84
 
85
  # load faiss
86
- docsearch = FAISS.load_local(folder_path=target_path, embeddings=_embeddings)
87
-
88
- return docsearch
89
-
90
-
91
- @st.experimental_memo(show_spinner=False)
92
- def load_prompt():
93
- system_template="""You are an expert in finance, economics, investing, ethics, derivatives and markets.
94
- Use the following pieces of context to answer the users question. If you don't know the answer,
95
- just say that you don't know, don't try to make up an answer.
96
- ALWAYS return a "sources" part in your answer.
97
- The "sources" part should be a reference to the source of the context from which you got your answer. List all sources used
98
-
99
- You can use other sources to answer the question if and only if the given context does not have sufficient and relevant information to answer
100
- the question.
101
-
102
- The output should be a markdown code snippet formatted in the following schema:
103
- ```json
104
- {{
105
- answer: is foo
106
- sources: xyz
107
- }}
108
- ```
109
- Begin!
110
- ----------------
111
- {context}"""
112
- messages = [
113
- SystemMessagePromptTemplate.from_template(system_template),
114
- HumanMessagePromptTemplate.from_template("{question}")
115
- ]
116
- prompt = ChatPromptTemplate.from_messages(messages)
117
-
118
- return prompt
119
-
120
-
121
- @st.experimental_singleton(show_spinner=False)
122
- def load_chain():
123
- '''Load langchain Conversational Retrieval Chain'''
124
-
125
- vectorstore = load_vectorstore(embeddings)
126
- llm = ChatOpenAI(temperature=0, model_name='gpt-4-32k-0613')
127
- streaming_llm = ChatOpenAI(model_name='gpt-4-32k-0613',
128
- streaming=True,
129
- callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
130
- verbose=True,
131
- temperature=0)
132
-
133
- question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
134
- doc_chain = load_qa_chain(streaming_llm, chain_type="stuff", prompt=load_prompt())
135
-
136
- qa = ConversationalRetrievalChain(
137
- retriever=vectorstore.as_retriever(),
138
- combine_docs_chain=doc_chain,
139
- question_generator=question_generator,
140
- return_source_documents=True)
141
-
142
- return qa
143
-
144
- chat_history = []
145
-
146
- def get_answer(question):
147
- '''Generate an answer from the chain'''
148
-
149
- chain = load_chain()
150
- result = chain({"question": question, "chat_history": chat_history})
151
-
152
- answer = result["answer"]
153
-
154
- # pages
155
- unique_sources = set()
156
- for item in result['source_documents']:
157
- unique_sources.add(item.metadata['source'].split(',')[1])
158
-
159
- unique_pages = ""
160
- for item in unique_sources:
161
- unique_pages += str(item) + ", "
162
-
163
- # will look like 1, 2, 3,
164
- pages = unique_pages[:-2] # removes the last comma and space
165
-
166
- # source text
167
- full_source = ""
168
- for item in result['source_documents']:
169
- full_source += f"- **{item.metadata['source']}**" + "\n" + item.page_content + "\n\n"
170
-
171
- # will look like:
172
- # - Page: {number}
173
- # {extracted text from book}
174
- extract = full_source
175
-
176
- return answer, pages, extract
177
-
178
-
179
- ##### main ####
180
- user_input = st.text_area("Your question", "What is an MBS and who are the main issuers and investors of the MBS market?", key="input")
181
-
182
- col1, col2 = st.columns([10, 1])
183
-
184
- # show question
185
- col1.write(f"**You:** {user_input}")
186
-
187
- # ask button to the right of the displayed question
188
- ask = col2.button("Ask", type="primary")
189
-
190
-
191
- if ask:
192
-
193
- with st.spinner("this can take about a minute for your first question because some models have to be downloaded πŸ₯ΊπŸ‘‰πŸ»πŸ‘ˆπŸ»"):
194
- try:
195
- answer, pages, extract = get_answer(question=user_input)
196
- except Exception as e:
197
- st.write(f"Error with Download: {e}")
198
- st.stop()
199
-
200
- st.write(f"{answer}")
201
-
202
- # sources
203
- with st.expander(label = f"From: {pages}", expanded = False):
204
- st.markdown(extract)
 
4
  from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
5
  from langchain.vectorstores.faiss import FAISS
6
  from huggingface_hub import snapshot_download
7
+
8
+ from langchain.callbacks import StreamlitCallbackHandler
9
+ from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
10
+ from langchain.agents.agent_toolkits import create_retriever_tool
11
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
12
+ AgentTokenBufferMemory,
 
 
 
 
 
13
  )
14
+ from langchain.chat_models import ChatOpenAI
15
+ from langchain.schema import SystemMessage, AIMessage, HumanMessage
16
+ from langchain.prompts import MessagesPlaceholder
17
+ from langsmith import Client
18
 
19
+ client = Client()
 
 
 
 
 
20
 
21
+ st.set_page_config(
22
+ page_title="Chat with CFA Level 1",
23
+ page_icon="πŸ“–",
24
+ layout="wide",
25
+ initial_sidebar_state="collapsed",
26
+ )
27
 
28
  #Load API Key
29
  api_key = os.environ["OPENAI_API_KEY"]
 
33
  book = st.radio("Embedding Model: ",
34
  ["Sbert"]
35
  )
36
+
37
+
38
  #load embedding models
39
+ @st.cache_resource(show_spinner=True)
40
  def load_embedding_models(model):
41
 
42
  if model == 'Sbert':
 
53
 
54
  return emb
55
 
 
 
 
 
56
  embeddings = load_embedding_models(book)
57
 
58
  ##### functionss ####
59
+ @st.cache_resource(show_spinner=False)
60
  def load_vectorstore(_embeddings):
61
  # download from hugging face
62
  cache_dir="cfa_level_1_cache"
 
79
  print(target_path)
80
 
81
  # load faiss
82
+ vectorstore = FAISS.load_local(folder_path=target_path, embeddings=_embeddings)
83
+
84
+ return vectorstore.as_retriever(search_kwargs={"k": 4})
85
+
86
+ tool = create_retriever_tool(
87
+ load_vectorstore(),
88
+ "search_cfa_docs",
89
+ "Searches and returns documents regarding the CFA level 1 curriculum. CFA is a rigorous program for investment professionals which covers topics such as ethics, corporate finance, economics, fixed income, equities and derivatives markets. You do not know anything about the CFA program, so if you are ever asked about CFA material or curriculum you should use this tool.",
90
+ )
91
+ tools = [tool]
92
+ llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")
93
+ message = SystemMessage(
94
+ content=(
95
+ "You are a helpful chatbot who is tasked with answering questions about the CFA level 1 program. "
96
+ "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the CFA program and materials. "
97
+ "If there is any ambiguity, you probably assume they are about that."
98
+ )
99
+ )
100
+
101
+ prompt = OpenAIFunctionsAgent.create_prompt(
102
+ system_message=message,
103
+ extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
104
+ )
105
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
106
+ agent_executor = AgentExecutor(
107
+ agent=agent,
108
+ tools=tools,
109
+ verbose=True,
110
+ return_intermediate_steps=True,
111
+ )
112
+ memory = AgentTokenBufferMemory(llm=llm)
113
+ starter_message = "Ask me anything about the CFA Level 1 Curriculum!"
114
+ if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
115
+ st.session_state["messages"] = [AIMessage(content=starter_message)]
116
+
117
+
118
+ def send_feedback(run_id, score):
119
+ client.create_feedback(run_id, "user_score", score=score)
120
+
121
+
122
+ for msg in st.session_state.messages:
123
+ if isinstance(msg, AIMessage):
124
+ st.chat_message("assistant").write(msg.content)
125
+ elif isinstance(msg, HumanMessage):
126
+ st.chat_message("user").write(msg.content)
127
+ memory.chat_memory.add_message(msg)
128
+
129
+
130
+ if prompt := st.chat_input(placeholder=starter_message):
131
+ st.chat_message("user").write(prompt)
132
+ with st.chat_message("assistant"):
133
+ st_callback = StreamlitCallbackHandler(st.container())
134
+ response = agent_executor(
135
+ {"input": prompt, "history": st.session_state.messages},
136
+ callbacks=[st_callback],
137
+ include_run_info=True,
138
+ )
139
+ st.session_state.messages.append(AIMessage(content=response["output"]))
140
+ st.write(response["output"])
141
+ memory.save_context({"input": prompt}, response)
142
+ st.session_state["messages"] = memory.buffer
143
+ run_id = response["__run"].run_id
144
+
145
+ col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
146
+ with col_text:
147
+ st.text("Feedback:")
148
+
149
+ with col1:
150
+ st.button("πŸ‘", on_click=send_feedback, args=(run_id, 1))
151
+
152
+ with col2:
153
+ st.button("πŸ‘Ž", on_click=send_feedback, args=(run_id, 0))