import json
import os

import sqlalchemy
import sqlite_vss
import streamlit as st
import streamlit.components.v1 as components
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.embeddings import GPT4AllEmbeddings
from sqlalchemy import event

from chat_history import insert_chat_history, insert_chat_history_articles
from css import load_css
from custom_pgvector import CustomPGVector
from message import Message

CONNECTION_STRING = "sqlite:///data/sorbobot.db"

st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")

chat_column, doc_column = st.columns([2, 1])


def connect() -> sqlalchemy.engine.Connection:
    """Open a SQLite connection with the sqlite-vss vector extension loaded."""
    engine = sqlalchemy.create_engine(CONNECTION_STRING)

    @event.listens_for(engine, "connect")
    def receive_connect(connection, _):
        # Load the sqlite-vss extension on every new DBAPI connection.
        connection.enable_load_extension(True)
        sqlite_vss.load(connection)
        connection.enable_load_extension(False)

    conn = engine.connect()
    return conn


conn = connect()


def initialize_session_state():
    """Set up chat history, token counter, and the retrieval chain once per session."""
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "conversation" not in st.session_state:
        embeddings = GPT4AllEmbeddings()

        # Vector store backed by the article table's abstract embeddings.
        db = CustomPGVector(
            embedding_function=embeddings,
            table_name="article",
            column_name="abstract_embedding",
            connection=conn,
        )
        retriever = db.as_retriever()

        llm = OpenAI(
            temperature=0,
            openai_api_key=os.environ["OPENAI_API_KEY"],
            model="text-davinci-003",
        )

        memory = ConversationBufferMemory(
            output_key="answer", memory_key="chat_history", return_messages=True
        )

        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            verbose=True,
            memory=memory,
            return_source_documents=True,
        )


def on_click_callback():
    """Send the user's prompt to the chain, record the exchange, and persist it."""
    with get_openai_callback() as cb:
        human_prompt = st.session_state.human_prompt

        llm_response = st.session_state.conversation(human_prompt)

        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(
            Message(
                "ai",
                llm_response["answer"],
                documents=llm_response["source_documents"],
            )
        )
        st.session_state.token_count += cb.total_tokens

        # Store the exchange and its source articles in the chat history tables.
        history_id = insert_chat_history(conn, human_prompt, llm_response["answer"])
        insert_chat_history_articles(conn, history_id, llm_response["source_documents"])


load_css()
initialize_session_state()

with chat_column:
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    information_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            div = f"""