SorboBot / app.py
leo-bourrel's picture
feat: replace postgres with sqlite
5c20978
raw
history blame
5.6 kB
import html
import json
import os

import sqlalchemy
import sqlite_vss
import streamlit as st
import streamlit.components.v1 as components
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.embeddings import GPT4AllEmbeddings
from sqlalchemy import event

from chat_history import insert_chat_history, insert_chat_history_articles
from css import load_css
from custom_pgvector import CustomPGVector
from message import Message
# SQLAlchemy URL of the SQLite database file used by connect().
CONNECTION_STRING = "sqlite:///data/sorbobot.db"

# Page-level Streamlit setup: wide layout and the page title.
st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")

# Two columns with relative widths 2 and 1: the first hosts the chat, the
# second lists the source documents of the latest answer.
chat_column, doc_column = st.columns([2, 1])
def connect() -> sqlalchemy.engine.Connection:
    """Open a database connection with the sqlite-vss extension loaded.

    A new engine is created for ``CONNECTION_STRING`` and a ``connect`` event
    hook is registered on it, so that every DBAPI connection the engine opens
    gets ``sqlite_vss`` loaded before use.

    Returns:
        A connection obtained from the newly created engine.
    """
    db_engine = sqlalchemy.create_engine(CONNECTION_STRING)

    def _load_vss_extension(dbapi_connection, _connection_record):
        # Extension loading is switched on only for the duration of the load
        # call and switched off again right after.
        dbapi_connection.enable_load_extension(True)
        sqlite_vss.load(dbapi_connection)
        dbapi_connection.enable_load_extension(False)

    event.listen(db_engine, "connect", _load_vss_extension)
    return db_engine.connect()
# Module-level connection shared by the vector store built in
# initialize_session_state() and by the chat-history inserts in
# on_click_callback().
# NOTE(review): this statement is at script top level, so a new engine and
# connection are created each time the script body executes — confirm whether
# the connection should be cached or explicitly closed.
conn = connect()
def initialize_session_state():
    """Populate ``st.session_state`` with the keys the app relies on.

    Creates, only when missing:

    * ``history`` - an empty list of chat messages,
    * ``token_count`` - the running token total, starting at 0,
    * ``conversation`` - a ``ConversationalRetrievalChain`` wired to the
      article vector store, an OpenAI LLM and a buffer memory.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is not set in the environment when the
            conversation chain has to be built.
    """
    state = st.session_state

    if "history" not in state:
        state.history = []
    if "token_count" not in state:
        state.token_count = 0

    # The chain is built once per session; later runs reuse it.
    if "conversation" in state:
        return

    # Retriever over the "article" table, searching the abstract embeddings
    # through the shared module-level connection.
    article_retriever = CustomPGVector(
        embedding_function=GPT4AllEmbeddings(),
        table_name="article",
        column_name="abstract_embedding",
        connection=conn,
    ).as_retriever()

    completion_llm = OpenAI(
        temperature=0,
        openai_api_key=os.environ["OPENAI_API_KEY"],
        model="text-davinci-003",
    )

    # output_key="answer" tells the memory which chain output to store, since
    # the chain also returns the source documents.
    chat_memory = ConversationBufferMemory(
        output_key="answer",
        memory_key="chat_history",
        return_messages=True,
    )

    state.conversation = ConversationalRetrievalChain.from_llm(
        llm=completion_llm,
        retriever=article_retriever,
        verbose=True,
        memory=chat_memory,
        return_source_documents=True,
    )
def on_click_callback():
    """Handle a submit of the chat form.

    Sends the text stored under ``st.session_state.human_prompt`` to the
    conversation chain, appends the human and AI messages to the session
    history, adds the tokens reported by the OpenAI callback to the session
    total, and records the exchange and its source documents through
    ``insert_chat_history`` / ``insert_chat_history_articles``.
    """
    state = st.session_state

    with get_openai_callback() as usage:
        question = state.human_prompt
        result = state.conversation(question)

        state.history.append(Message("human", question))

        answer = result["answer"]
        sources = result["source_documents"]
        state.history.append(Message("ai", answer, documents=sources))

        state.token_count += usage.total_tokens

        # Store the exchange first, then link the retrieved documents to it
        # using the id returned by the first insert.
        history_id = insert_chat_history(conn, question, answer)
        insert_chat_history_articles(conn, history_id, sources)
# Run the css module's loader (presumably injects the stylesheet behind the
# chat-row / chat-bubble classes used below — see css.py), then make sure the
# session keys read by the rendering code (history, token_count, conversation)
# exist.
load_css()
initialize_session_state()
with chat_column:
    # Create the three areas first: message list, input form, and an empty
    # slot that later receives the token/debug caption.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    information_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            # Messages hold raw user input and raw LLM output and are rendered
            # with unsafe_allow_html=True, so they must be HTML-escaped to be
            # shown as text instead of being interpreted as markup.
            safe_message = html.escape(str(chat.message))
            div = f"""
<div class="chat-row
{'' if chat.origin == 'ai' else 'row-reverse'}">
<img class="chat-icon" src="./app/static/{
'ai_icon.png' if chat.origin == 'ai'
else 'user_icon.png'}"
width=32 height=32>
<div class="chat-bubble
{'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
&#8203;{safe_message}
</div>
</div>
"""
            st.markdown(div, unsafe_allow_html=True)

        # Three empty markdown elements after the messages act as spacing.
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        # The typed text is stored under session key "human_prompt", which
        # on_click_callback reads when the form is submitted.
        cols[0].text_input(
            "Chat",
            value="Hello bot",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    # Token total plus the raw memory buffer of the chain, for debugging.
    information_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation:
{st.session_state.conversation.memory.buffer}
"""
    )

    # Zero-sized HTML component whose script looks up the "Submit" button in
    # the parent document and clicks it when Enter is pressed.
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
switch (e.key) {
case 'Enter':
submitButton.click();
break;
}
});
</script>
""",
        height=0,
        width=0,
    )
with doc_column:
    past_messages = st.session_state.history
    # Nothing to list until at least one exchange has taken place.
    if past_messages:
        st.markdown("**Source documents**")
        # Only the documents attached to the most recent message are shown.
        for source in past_messages[-1].documents:
            # page_content is parsed as a JSON object; the keys read below are
            # title, doi, abstract, authors, keywords and distance.
            article = json.loads(source.page_content)
            panel = st.expander(article["title"])
            panel.markdown("**" + article["doi"] + "**")
            panel.markdown(article["abstract"])
            for label, field in (("Authors", "authors"), ("Keywords", "keywords")):
                panel.markdown("**" + label + "** : " + article[field])
            panel.markdown("**Distance** : " + str(article["distance"]))