leo-bourrel commited on
Commit
c3b9b9a
·
1 Parent(s): 389b0ad

feat: first draft of conversational retrieval

Browse files
Files changed (1) hide show
  1. app.py +29 -8
app.py CHANGED
@@ -5,10 +5,15 @@ import streamlit.components.v1 as components
5
  from css import load_css
6
  from langchain import OpenAI
7
  from langchain.callbacks import get_openai_callback
8
- from langchain.chains import ConversationChain
9
- from langchain.chains.conversation.memory import ConversationSummaryMemory
 
 
10
  from message import Message
11
 
 
 
 
12
 
13
  def initialize_session_state():
14
  if "history" not in st.session_state:
@@ -16,21 +21,37 @@ def initialize_session_state():
16
  if "token_count" not in st.session_state:
17
  st.session_state.token_count = 0
18
  if "conversation" not in st.session_state:
 
 
 
 
 
 
 
 
 
 
19
  llm = OpenAI(
20
  temperature=0,
21
  openai_api_key=os.environ["OPENAI_API_KEY"],
22
- model_name="text-davinci-003",
23
  )
24
- st.session_state.conversation = ConversationChain(
25
- llm=llm,
26
- memory=ConversationSummaryMemory(llm=llm),
 
27
  )
28
 
29
 
30
  def on_click_callback():
31
  with get_openai_callback() as cb:
32
  human_prompt = st.session_state.human_prompt
33
- llm_response = st.session_state.conversation.run(human_prompt)
 
 
 
 
 
34
  st.session_state.history.append(Message("human", human_prompt))
35
  st.session_state.history.append(Message("ai", llm_response))
36
  st.session_state.token_count += cb.total_tokens
@@ -84,7 +105,7 @@ information_placeholder.caption(
84
  f"""
85
  Used {st.session_state.token_count} tokens \n
86
  Debug Langchain conversation:
87
- {st.session_state.conversation.memory.buffer}
88
  """
89
  )
90
 
 
5
  from css import load_css
6
  from langchain import OpenAI
7
  from langchain.callbacks import get_openai_callback
8
+ from langchain.chains import ConversationalRetrievalChain
9
+ from langchain.chains.conversation.memory import ConversationBufferMemory
10
+ from langchain.embeddings.openai import OpenAIEmbeddings
11
+ from langchain.vectorstores.pgvector import PGVector
12
  from message import Message
13
 
14
+ CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
15
+ COLLECTION_NAME = ""
16
+
17
 
18
  def initialize_session_state():
19
  if "history" not in st.session_state:
 
21
  if "token_count" not in st.session_state:
22
  st.session_state.token_count = 0
23
  if "conversation" not in st.session_state:
24
+ embeddings = OpenAIEmbeddings()
25
+
26
+ store = PGVector(
27
+ collection_name=COLLECTION_NAME,
28
+ connection_string=CONNECTION_STRING,
29
+ embedding_function=embeddings,
30
+ )
31
+
32
+ retriever = store.as_retriever()
33
+
34
  llm = OpenAI(
35
  temperature=0,
36
  openai_api_key=os.environ["OPENAI_API_KEY"],
37
+ model="text-davinci-003",
38
  )
39
+
40
+ st.session_state.memory = ConversationBufferMemory()
41
+ st.session_state.conversation = ConversationalRetrievalChain.from_llm(
42
+ llm=llm, retriever=retriever
43
  )
44
 
45
 
46
  def on_click_callback():
47
  with get_openai_callback() as cb:
48
  human_prompt = st.session_state.human_prompt
49
+ llm_response = st.session_state.conversation.run(
50
+ {
51
+ "question": human_prompt,
52
+ "chat_history": st.session_state.memory.buffer,
53
+ }
54
+ )
55
  st.session_state.history.append(Message("human", human_prompt))
56
  st.session_state.history.append(Message("ai", llm_response))
57
  st.session_state.token_count += cb.total_tokens
 
105
  f"""
106
  Used {st.session_state.token_count} tokens \n
107
  Debug Langchain conversation:
108
+ {st.session_state.memory.buffer}
109
  """
110
  )
111