File size: 5,815 Bytes
be1f39f
d837f95
7b26205
fc86a8c
7f721d2
fc86a8c
 
 
28ccd64
d837f95
681a5c2
0712e72
84ad3fa
2e34be4
649c1e6
28ccd64
5c2d16e
be1f39f
 
 
 
 
 
 
 
 
 
 
 
7f721d2
be1f39f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d55e9be
7f721d2
 
 
 
be1f39f
 
 
3999e7c
 
 
 
 
b371097
 
 
 
 
 
 
 
 
3999e7c
bf68d2b
 
b600572
e8e8cf3
 
 
bf68d2b
 
 
 
3999e7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ee177
7f721d2
3999e7c
3ad696e
3999e7c
 
 
 
 
 
b26141b
3999e7c
 
 
 
 
 
 
 
649c1e6
3999e7c
 
 
 
 
 
0f7c92d
e8b8b3e
649c1e6
3999e7c
 
 
 
 
 
 
7f721d2
 
be1f39f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

from langchain_community.llms import CTransformers
from ctransformers import AutoModelForCausalLM
from langchain.agents import Tool
from langchain.agents import AgentType, initialize_agent
from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import tempfile
import os
import streamlit as st
import timeit
from langchain.callbacks.tracers import ConsoleCallbackHandler

# tt
def main():
    """Streamlit app: upload PDF documents, index them in a FAISS vector
    store, and answer comparison questions via a ZERO_SHOT_REACT agent
    backed by a local CTransformers (llama.cpp) model.

    No parameters, no return value; all interaction happens through the
    Streamlit UI. Exceptions raised during model/agent setup or inference
    are caught and surfaced with ``st.error``.
    """

    # Map file extensions to (loader class, loader kwargs).
    FILE_LOADER_MAPPING = {
        "pdf": (PyPDFLoader, {})
        # Add more mappings for other file extensions and loaders as needed
    }

    st.title("Document Comparison with Q&A using Agents")

    # Upload files
    uploaded_files = st.file_uploader("Upload your documents", type=["pdf"], accept_multiple_files=True)
    loaded_documents = []

    if uploaded_files:
        # Stage the uploads on disk inside a temporary directory so the
        # Langchain loaders (which expect file paths) can read them.
        with tempfile.TemporaryDirectory() as td:
            for uploaded_file in uploaded_files:
                st.write(f"Uploaded: {uploaded_file.name}")
                ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
                st.write(f"Uploaded: {ext}")

                # Check if the extension is in FILE_LOADER_MAPPING
                if ext in FILE_LOADER_MAPPING:
                    loader_class, loader_args = FILE_LOADER_MAPPING[ext]

                    # Save the uploaded file to the temporary directory
                    file_path = os.path.join(td, uploaded_file.name)
                    with open(file_path, 'wb') as temp_file:
                        temp_file.write(uploaded_file.read())

                    # Use Langchain loader to process the file
                    loader = loader_class(file_path, **loader_args)
                    loaded_documents.extend(loader.load())
                else:
                    st.warning(f"Unsupported file extension: {ext}, the app currently only supports pdf")

        st.write("Ask question to get comparison from the documents:")
        query = st.text_input("Ask a question:")

        if st.button("Get Answer"):
            if not query:
                st.warning("Please enter a question.")
            elif not loaded_documents:
                # Guard: every upload was unsupported, so there is nothing to
                # index — FAISS.from_documents raises on an empty list.
                st.warning("No supported documents were loaded; please upload at least one PDF.")
            else:
                # Load model, set prompts, create vector database, and retrieve answer
                try:
                    start = timeit.default_timer()

                    llm = CTransformers(
                        model="TheBloke/Llama-2-7B-Chat-GGUF",
                        model_file="llama-2-7b-chat.Q3_K_S.gguf",
                        model_type="llama",
                        max_new_tokens=300,
                        temperature=0.3,
                        lib="avx2",  # for CPU
                    )
                    print("LLM Initialized...")

                    # CPU-only BGE embeddings; normalization off to match the
                    # original configuration.
                    embeddings = HuggingFaceBgeEmbeddings(
                        model_name="BAAI/bge-large-en",
                        model_kwargs={'device': 'cpu'},
                        encode_kwargs={'normalize_embeddings': False},
                    )

                    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
                    chunked_documents = text_splitter.split_documents(loaded_documents)
                    retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()

                    # Wrap the retrieval chain in a Tool.
                    # BUG FIX: Tool.func must be a callable; the original
                    # passed the RetrievalQA chain object itself instead of
                    # its .run method.
                    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
                    tools = [
                        Tool(
                            name="Comparison tool",
                            description="useful when you want to answer questions about the uploaded documents",
                            func=qa_chain.run,
                        )
                    ]

                    agent = initialize_agent(
                        tools=tools,
                        llm=llm,
                        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                        verbose=True,
                    )

                    response = agent.run({"input": query})

                    # BUG FIX: stop the timer AFTER inference so the reported
                    # elapsed time includes the agent run (the original
                    # measured only the setup phase).
                    end = timeit.default_timer()
                    st.write("Elapsed time:")
                    st.write(end - start)

                    st.write("Bot Response:")
                    st.write(response)

                except Exception as e:
                    # Top-level UI boundary: surface the failure to the user
                    # rather than crashing the Streamlit session.
                    st.error(f"An error occurred: {str(e)}")

# Script entry point: launch the Streamlit app when run directly.
if __name__ == "__main__":
    main()