# Hugging Face file-viewer header captured with this source:
# author yogjoshi14, commit 3b1a154 ("making workflows"), file size 4.87 kB.
import os
import pinecone
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Pinecone
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
import streamlit as st
st.set_page_config(page_title="chatbot")
st.title("Chat with Documents")

# Number of chunks the retriever returns for each question (search_kwargs "k").
num_of_top_selection = 3
# Text-splitter settings, in characters, used by split_documents().
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
# Vector dimension of the index created by recreate_index(); presumably the
# output size of the default HuggingFaceEmbeddings model -- confirm if the
# embedding model changes.
embedding_dim = 768

# Initialize Pinecone.
# os.environ is a mapping with no `getattribute` method; .get() returns the
# value, or None when the variable is unset.
pc = pinecone.Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
index_name = "qp-ai-assessment"
def recreate_index():
    """Drop the Pinecone index ``index_name`` if it exists, then create it afresh.

    The new index uses the cosine metric, ``embedding_dim``-dimensional
    vectors, and a pod-based spec in the environment named by the
    ``PINECONE_ENV`` environment variable. Progress is printed to stdout.
    """
    # Check if the index exists, and delete it if it does.
    existing_indexes = pc.list_indexes().names()
    print(existing_indexes)
    if index_name in existing_indexes:
        pc.delete_index(index_name)
        print(f"Deleted existing index: {index_name}")

    # Create a new index. The dimension comes from embedding_dim (768) so it
    # matches the HuggingFace embedding vectors -- not the 1536 dims of
    # OpenAI's text-embedding-ada-002.
    pc.create_index(
        name=index_name,
        metric="cosine",
        dimension=embedding_dim,
        spec=pinecone.PodSpec(environment=os.environ.get("PINECONE_ENV")),
    )
    print(f"Created new index: {index_name}")
def load_documents(pdf_docs):
    """Extract the text of every page of every PDF and return it as one string.

    Args:
        pdf_docs: Iterable of PDF sources accepted by ``PdfReader`` (in this
            app, the files returned by the sidebar ``st.file_uploader``).

    Returns:
        The text of all pages, in order, with consecutive pages separated by a
        newline. Returns ``""`` when ``pdf_docs`` is empty.
    """
    page_texts = []
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # Defensive: treat a missing (None) extraction result as empty text.
            page_texts.append(page.extract_text() or "")
    # Separate pages with "\n" so the last word of one page is not glued to the
    # first word of the next; join() also avoids repeated string concatenation.
    return "\n".join(page_texts)
def split_documents(documents):
    """Split a text into overlapping chunks wrapped as LangChain Documents.

    Args:
        documents: The full text to split (a single string, as produced by
            ``load_documents``).

    Returns:
        A list of ``Document`` objects, one per chunk. Chunks target at most
        ``CHUNK_SIZE`` characters, and neighbours overlap by up to
        ``CHUNK_OVERLAP``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    # create_documents() already runs split_text() on each input string. Pass the
    # whole text once instead of splitting it and then re-splitting every chunk.
    return text_splitter.create_documents([documents])
def embeddings_on_pinecone(texts):
    """Embed the chunks into the configured Pinecone index and return a retriever.

    Args:
        texts: LangChain ``Document`` chunks (as produced by ``split_documents``).

    Returns:
        A retriever over the index named by ``st.session_state.pinecone_index``
        that returns ``num_of_top_selection`` chunks per query.
    """
    # HuggingFace embeddings map each chunk of text to a numeric vector.
    embedder = HuggingFaceEmbeddings()
    store = Pinecone.from_documents(
        texts,
        embedder,
        index_name=st.session_state.pinecone_index,
    )
    return store.as_retriever(search_kwargs={"k": num_of_top_selection})
def query_llm(retriever, query):
    """Answer ``query`` with a conversational retrieval chain and record the turn.

    A ``ConversationalRetrievalChain`` is built on every call. It uses the
    Hugging Face Hub model ``google/flan-t5-xxl`` and ``retriever``, and
    receives ``st.session_state.messages`` (a list of ``(question, answer)``
    tuples) as its chat history.

    Args:
        retriever: Retriever returned by ``embeddings_on_pinecone``.
        query: The user's question.

    Returns:
        The answer string. As a side effect, ``(query, answer)`` is appended to
        ``st.session_state.messages``.
    """
    llm = HuggingFaceHub(
        repo_id="google/flan-t5-xxl",
        model_kwargs={"temperature": 0.5, "max_length": 512},
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )
    # Calling a chain directly (Chain.__call__) is deprecated in LangChain;
    # invoke() is the supported entry point and returns the same output dict.
    result = qa_chain.invoke(
        {"question": query, "chat_history": st.session_state.messages}
    )
    answer = result["answer"]
    st.session_state.messages.append((query, answer))
    return answer
def input_fields():
    """Store the Pinecone settings in session state and render the PDF uploader.

    The sidebar is used as the rendering context. The Pinecone API key and
    environment are read from the ``PINECONE_API_KEY`` and ``PINECONE_ENV``
    environment variables (``None`` when unset). The index name is the
    module-level ``index_name``. The uploader's current value (multiple PDFs
    allowed) is stored in ``st.session_state.source_docs``.
    """
    with st.sidebar:
        # os.environ is a mapping with no `getattribute` method; .get() returns
        # None for unset variables, which process_documents() then reports as a
        # missing field.
        st.session_state.pinecone_api_key = os.environ.get("PINECONE_API_KEY")
        st.session_state.pinecone_env = os.environ.get("PINECONE_ENV")
        st.session_state.pinecone_index = index_name
        st.session_state.source_docs = st.file_uploader(
            label="Upload Documents", type="pdf", accept_multiple_files=True
        )
def process_documents():
    """Index the uploaded PDFs in Pinecone and store a retriever in session state.

    Used as the "Submit Documents" button callback. If any Pinecone setting
    or the uploaded documents are missing or empty, a warning is shown and
    nothing else happens. Otherwise the documents are loaded, split and
    embedded, and the resulting retriever is saved as
    ``st.session_state.retriever``. Any exception raised while doing so is
    shown in the UI with ``st.error`` rather than propagated.
    """
    # .get() treats a never-set key like an empty value, so the user sees the
    # warning instead of an AttributeError.
    required_keys = ("pinecone_api_key", "pinecone_env", "pinecone_index", "source_docs")
    if not all(st.session_state.get(key) for key in required_keys):
        st.warning("Please upload the documents and provide the missing fields.")
        return

    try:
        # recreate_index() is not called here, so chunks from earlier
        # submissions presumably remain in the index alongside the new ones.
        documents = load_documents(st.session_state.source_docs)
        texts = split_documents(documents)
        st.session_state.retriever = embeddings_on_pinecone(texts)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        st.error(f"An error occurred: {e}")
def boot():
    """Render the app for one Streamlit run.

    Draws the sidebar inputs and the submit button, and initialises the chat
    history. It then replays stored question/answer pairs and, if a new
    question was entered, answers it with the retriever built by
    ``process_documents``.
    """
    input_fields()
    st.button("Submit Documents", on_click=process_documents)

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far; each entry is a (question, answer) tuple
    # appended by query_llm().
    for question, answer in st.session_state.messages:
        st.chat_message("human").write(question)
        st.chat_message("ai").write(answer)

    if query := st.chat_input():
        st.chat_message("human").write(query)
        # The retriever only exists after documents have been submitted; reading
        # st.session_state.retriever directly would raise AttributeError.
        retriever = st.session_state.get("retriever")
        if retriever is None:
            st.warning("Please submit your documents before asking a question.")
            return
        response = query_llm(retriever, query)
        st.chat_message("ai").write(response)
if __name__ == "__main__":
    # Entry point when this file runs as the main script (e.g. via `streamlit run`).
    boot()