Spaces:
Sleeping
Sleeping
import json | |
import os | |
import sqlalchemy | |
import sqlite_vss | |
import streamlit as st | |
import streamlit.components.v1 as components | |
from langchain import OpenAI | |
from langchain.callbacks import get_openai_callback | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.chains.conversation.memory import ConversationBufferMemory | |
from langchain.embeddings import GPT4AllEmbeddings | |
from sqlalchemy import event | |
from chat_history import insert_chat_history, insert_chat_history_articles | |
from css import load_css | |
from custom_pgvector import CustomPGVector | |
from message import Message | |
# SQLAlchemy URL of the SQLite database file; the connection built from it is
# used both for vector search over the ``article`` table and for chat-history inserts.
CONNECTION_STRING = "sqlite:///data/sorbobot.db"

# st.set_page_config must be the first Streamlit command executed by the script,
# so keep it ahead of every other st.* call.
st.set_page_config(layout="wide")
# Title (French): "Sorbobot - The future of interactive scientific research".
st.title("Sorbobot - Le futur de la recherche scientifique interactive")

# Two-column layout: chat on the left (2/3 of the width), source documents on the right (1/3).
chat_column, doc_column = st.columns([2, 1])
def connect() -> sqlalchemy.engine.Connection:
    """Open a connection to ``CONNECTION_STRING`` with the sqlite-vss extension loaded.

    A new engine is created and a ``"connect"`` listener is registered on it so
    that every raw DBAPI connection the engine opens has ``sqlite_vss`` loaded
    before SQLAlchemy hands it out.

    Returns:
        A new ``sqlalchemy.engine.Connection`` on that engine.
    """
    engine = sqlalchemy.create_engine(CONNECTION_STRING)

    # The listener must be attached to the engine: previously it was only
    # defined (and ``event`` imported but unused), so the extension was never
    # loaded and vector-search SQL run through this connection could not work.
    @event.listens_for(engine, "connect")
    def receive_connect(connection, _):
        # ``connection`` is the raw sqlite3 DBAPI connection, which exposes
        # ``enable_load_extension``.
        connection.enable_load_extension(True)
        sqlite_vss.load(connection)
        # Disable extension loading again so SQL statements cannot load
        # arbitrary shared libraries afterwards.
        connection.enable_load_extension(False)

    return engine.connect()


# NOTE(review): Streamlit re-executes this script on every interaction, so this
# creates a fresh engine + connection per rerun and never closes the old ones;
# consider caching it (e.g. st.cache_resource) — confirm thread-safety first.
conn = connect()
def initialize_session_state():
    """Populate ``st.session_state`` with the keys the UI reads on each rerun.

    Only keys that are missing are created, so repeated calls are cheap:

    * ``history`` -- an empty list that collects ``Message`` objects;
    * ``token_count`` -- an integer running total, starting at 0;
    * ``conversation`` -- a ``ConversationalRetrievalChain`` that answers with
      OpenAI ``text-davinci-003`` (temperature 0), retrieves from the
      ``article.abstract_embedding`` column through ``conn``, keeps chat memory
      under ``chat_history`` and returns its source documents.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is not set when the chain must be built.
    """
    state = st.session_state
    for key, initial in (("history", []), ("token_count", 0)):
        if key not in state:
            setattr(state, key, initial)

    if "conversation" in state:
        return

    # Vector store over the article abstracts, queried through the shared connection.
    vector_store = CustomPGVector(
        embedding_function=GPT4AllEmbeddings(),
        table_name="article",
        column_name="abstract_embedding",
        connection=conn,
    )
    search = vector_store.as_retriever()

    language_model = OpenAI(
        model="text-davinci-003",
        temperature=0,
        openai_api_key=os.environ["OPENAI_API_KEY"],
    )
    # output_key="answer" tells the memory which chain output to store, since
    # the chain also returns source documents.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    state.conversation = ConversationalRetrievalChain.from_llm(
        llm=language_model,
        retriever=search,
        memory=chat_memory,
        return_source_documents=True,
        verbose=True,
    )
def on_click_callback():
    """Answer the submitted prompt and record the exchange.

    Runs ``st.session_state.human_prompt`` through the conversation chain while
    tracking OpenAI usage. It then appends the human and AI messages to
    ``history``, adds the consumed tokens to ``token_count``, and persists the
    question, answer and cited source documents through ``conn``.
    """
    state = st.session_state
    transcript = state.history

    with get_openai_callback() as usage:
        question = state.human_prompt
        result = state.conversation(question)
        transcript.append(Message("human", question))
        transcript.append(
            Message("ai", result["answer"], documents=result["source_documents"])
        )
        state.token_count += usage.total_tokens

    # Persist the turn, then link it to the articles the answer was based on.
    saved_id = insert_chat_history(conn, question, result["answer"])
    insert_chat_history_articles(conn, saved_id, result["source_documents"])
# Inject the app's custom stylesheet; presumably it defines the chat-row /
# chat-bubble / chat-icon classes used in the chat markup below — confirm in css.py.
load_css()
# Ensure history, token_count and conversation exist before the UI reads them.
initialize_session_state()
import html


def _chat_bubble_html(chat) -> str:
    """Return the HTML for one chat row (icon + bubble) for ``chat``.

    ``chat.message`` holds raw user input (the submitted prompt) or model
    output, and the markup is rendered with ``unsafe_allow_html=True``. The
    text is therefore HTML-escaped so it is shown literally instead of being
    interpreted as markup or script. The markup is emitted without leading
    indentation so Markdown never treats it as an indented code block.
    """
    is_ai = chat.origin == "ai"
    row_class = "" if is_ai else "row-reverse"
    icon = "ai_icon.png" if is_ai else "user_icon.png"
    bubble_class = "ai-bubble" if is_ai else "human-bubble"
    safe_text = html.escape(str(chat.message))
    # \u200b (zero-width space) precedes the text, as in the original template.
    return (
        f'<div class="chat-row {row_class}">'
        f'<img class="chat-icon" src="./app/static/{icon}" width=32 height=32>'
        f'<div class="chat-bubble {bubble_class}">\u200b{safe_text}</div>'
        "</div>"
    )


with chat_column:
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    information_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            st.markdown(_chat_bubble_html(chat), unsafe_allow_html=True)
        # Blank spacer lines between the conversation and the input form.
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            value="Hello bot",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    # NOTE(review): this debug caption shows the full conversation memory to end users.
    information_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation:
{st.session_state.conversation.memory.buffer}
"""
    )

    # Pressing Enter anywhere in the page clicks the "Submit" button. The
    # button is looked up on each key press (it may be re-rendered after the
    # script first runs) and the click is skipped if it cannot be found,
    # instead of throwing on an undefined reference.
    # NOTE(review): '.stFormSubmitButton' is included for newer Streamlit
    # releases that wrap form submit buttons differently — confirm.
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;
streamlitDoc.addEventListener('keydown', function(e) {
    if (e.key !== 'Enter') {
        return;
    }
    const submitButton = Array.from(
        streamlitDoc.querySelectorAll(
            '.stButton > button, .stFormSubmitButton > button'
        )
    ).find(el => el.innerText === 'Submit');
    if (submitButton) {
        submitButton.click();
    }
});
</script>
""",
        height=0,
        width=0,
    )
# Right-hand panel: one expander per source document cited by the latest answer.
with doc_column:
    transcript = st.session_state.history
    if transcript:
        st.markdown("**Source documents**")
        # The newest entry is the AI message, which carries the retrieved documents.
        for source in transcript[-1].documents:
            # page_content is a JSON object with title/doi/abstract/authors/keywords/distance.
            record = json.loads(source.page_content)
            panel = st.expander(record["title"])
            panel.markdown("**" + record["doi"] + "**")
            panel.markdown(record["abstract"])
            for label, field in (
                ("**Authors** : ", "authors"),
                ("**Keywords** : ", "keywords"),
            ):
                panel.markdown(label + record[field])
            panel.markdown("**Distance** : " + str(record["distance"]))