# app.py
# Last change: "Update app.py" by captain-awesome (revision 38e6184, verified)
# File size: 5.17 kB
from langchain_community.llms import CTransformers
from langchain.agents import Tool
from langchain.agents import AgentType, initialize_agent
from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import tempfile
import os
import streamlit as st
import timeit
def main():
    """Run the Streamlit document-comparison app.

    Lets the user upload PDF files, loads them with the matching LangChain
    loader, and answers a free-text question through a zero-shot ReAct agent
    whose single tool is a RetrievalQA chain over a FAISS index of the
    uploaded documents. The elapsed time and the agent's answer are written
    to the page; any exception raised while answering is shown via
    ``st.error``.
    """
    FILE_LOADER_MAPPING = {
        "pdf": (PyPDFLoader, {}),
        # Add more mappings for other file extensions and loaders as needed
    }

    st.title("Document Comparison with Q&A using Agents")

    # Upload files
    uploaded_files = st.file_uploader(
        "Upload your documents", type=["pdf"], accept_multiple_files=True
    )
    loaded_documents = []
    if uploaded_files:
        loaded_documents = _load_uploaded_documents(uploaded_files, FILE_LOADER_MAPPING)

    st.write("Ask question to get comparison from the documents:")
    query = st.text_input("Ask a question:")

    if not st.button("Get Answer"):
        return
    if not query:
        st.warning("Please enter a question.")
        return
    if not loaded_documents:
        # Building a FAISS index from an empty list fails with an opaque
        # error, so tell the user what is actually missing instead.
        st.warning("Please upload at least one PDF document before asking a question.")
        return

    # Load model, create vector database, and retrieve answer.
    # This is the UI boundary, so any failure is reported on the page rather
    # than crashing the script run.
    try:
        start = timeit.default_timer()
        agent = _build_agent(loaded_documents)
        response = agent.run(query)
        end = timeit.default_timer()

        st.write("Elapsed time:")
        st.write(end - start)
        st.write("Bot Response:")
        st.write(response)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


def _load_uploaded_documents(uploaded_files, loader_mapping):
    """Write each uploaded file to a temporary directory and load it.

    Args:
        uploaded_files: The objects returned by ``st.file_uploader``.
        loader_mapping: Maps a lower-case file extension (without the dot) to
            a ``(loader_class, loader_kwargs)`` pair.

    Returns:
        A list with the documents produced by every supported file's loader.
        Files whose extension is not in ``loader_mapping`` are skipped with a
        warning on the page.
    """
    documents = []
    # The loaders take a file path, so the uploads are spilled to disk first.
    # The directory is removed as soon as every file has been loaded.
    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
            st.write(f"Uploaded: {ext}")

            if ext not in loader_mapping:
                st.warning(f"Unsupported file extension: {ext}, the app currently only supports pdf")
                continue

            loader_class, loader_args = loader_mapping[ext]
            # The name comes from the client; keep only its final component so
            # the file cannot be written outside the temporary directory.
            file_path = os.path.join(td, os.path.basename(uploaded_file.name))
            with open(file_path, "wb") as temp_file:
                # getvalue() returns the whole buffer regardless of the current
                # stream position, so a rerun never writes an empty file.
                temp_file.write(uploaded_file.getvalue())

            # Use Langchain loader to process the file
            loader = loader_class(file_path, **loader_args)
            documents.extend(loader.load())
    return documents


def _build_agent(documents):
    """Create the LLM, index ``documents`` in FAISS, and return a ReAct agent.

    Args:
        documents: Loaded LangChain documents to split, embed, and index.

    Returns:
        The agent produced by ``initialize_agent`` with one retrieval tool.
    """
    # Generation settings go through ``config``, the parameter CTransformers
    # documents for them; they are not declared fields on the wrapper itself.
    llm = CTransformers(
        model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        model_type="mistral",
        config={"max_new_tokens": 1048, "temperature": 0.3},
    )
    print("LLM Initialized...")

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunked_documents = text_splitter.split_documents(documents)
    retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()

    # Wrap the retriever in a Tool so the agent can call it.
    tools = [
        Tool(
            name="Comparison tool",
            description="useful when you want to answer questions about the uploaded documents",
            func=RetrievalQA.from_chain_type(llm=llm, retriever=retriever),
        )
    ]
    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()