Kameshr committed on
Commit 7afda13 · verified · 1 Parent(s): 8e39574

Update app.py

Files changed (1)
  1. app.py +115 -133
app.py CHANGED
@@ -1,170 +1,152 @@
  import gradio as gr
  import os
-
- # Retrieve API token from environment variable
- api_token = os.getenv("HF_TOKEN")
-
- # Import required libraries from LangChain
- from langchain.llms import HuggingFaceHub
- from langchain.vectorstores import FAISS
- from langchain.document_loaders import PyPDFLoader
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFaceEndpoint
  from langchain.chains import ConversationalRetrievalChain
- from langchain.memory import ConversationSummaryBufferMemory
+ from langchain.memory import ConversationBufferMemory

- # Define the default LLM model to use
- llm_model = "meta-llama/Meta-Llama-3-8B-Instruct"
+ api_token = os.getenv("HF_TOKEN")
+ DEFAULT_LLM = "meta-llama/Meta-Llama-3-8B-Instruct"

- # Initialize the LLM-based retrieval chain
- # The retriever is set directly within this function
+ def load_and_create_db(list_file_obj):
+     # Create a list of documents
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+
+     # Load documents
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+
+     # Split documents
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1024,
+         chunk_overlap=64
+     )
+     doc_splits = text_splitter.split_documents(pages)
+
+     # Create vector database
+     embeddings = HuggingFaceEmbeddings()
+     vectordb = FAISS.from_documents(doc_splits, embeddings)
+     return vectordb

- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
-     # Configure the selected LLM with parameters using HuggingFaceHub
-     llm = HuggingFaceHub(
-         repo_id=llm_model,
+ def initialize_llmchain(vector_db, temperature=0.5, max_tokens=4096, top_k=3):
+     llm = HuggingFaceEndpoint(
+         repo_id=DEFAULT_LLM,
          huggingfacehub_api_token=api_token,
-         model_kwargs={
-             "temperature": temperature,
-             "max_new_tokens": max_tokens,
-             "top_k": top_k
-         }
+         temperature=temperature,
+         max_new_tokens=max_tokens,
+         top_k=top_k,
      )
-
-     # Use summary-based conversation memory for better performance
-     memory = ConversationSummaryBufferMemory(
-         llm=llm,
+
+     memory = ConversationBufferMemory(
          memory_key="chat_history",
-         output_key="answer"
+         output_key='answer',
+         return_messages=True
      )

-     # Combine LLM, retriever, and memory into a Conversational Retrieval Chain
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
-         retriever=retriever,
+         retriever=vector_db.as_retriever(),
+         chain_type="stuff",
          memory=memory,
          return_source_documents=True,
          verbose=False,
      )
      return qa_chain

- # Load and split PDF documents into manageable chunks
- def load_doc(list_file_path):
-     # Load each file using PyPDFLoader
-     loaders = [PyPDFLoader(x) for x in list_file_path]
-     pages = []
-     for loader in loaders:
-         pages.extend(loader.load())
-
-     # Split loaded pages into smaller chunks with overlap for better context
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=1024,
-         chunk_overlap=64
-     )
-     doc_splits = text_splitter.split_documents(pages)
-     return doc_splits
-
- # Create a vector database from document splits
- def create_db(splits):
-     embeddings = HuggingFaceEmbeddings()
-     vectordb = FAISS.from_documents(splits, embeddings)
-     return vectordb
+ def format_citation(source_doc):
+     content = source_doc.page_content.strip()
+     page = source_doc.metadata["page"] + 1
+     return f"[Page {page}] {content}"

- # Format chat history for display in chatbot UI
- def format_chat_history(message, chat_history):
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}")
-         formatted_chat_history.append(f"Assistant: {bot_message}")
-     return formatted_chat_history
+ def format_response_with_citations(answer, sources):
+     citations = [format_citation(source) for source in sources[:3]]
+     formatted_response = f"{answer}\n\nReferences:\n"
+     for idx, citation in enumerate(citations, 1):
+         formatted_response += f"^{idx}^ {citation}\n"
+     return formatted_response

- # Handle user queries and generate responses using the chatbot
  def conversation(qa_chain, message, history):
-     # Format chat history to include previous interactions
-     formatted_chat_history = format_chat_history(message, history)
-
-     # Invoke the QA chain with the user message and chat history
-     response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
-
-     # Extract the response answer, removing unnecessary labels if present
-     response_answer = response["answer"].split("Helpful Answer:")[-1] if "Helpful Answer:" in response["answer"] else response["answer"]
-
-     # Extract top 3 source documents for relevance display
-     source_documents = response["source_documents"][:3]
-     sources = [
-         {
-             "content": doc.page_content.strip(),
-             "page": doc.metadata.get("page", 0) + 1
-         } for doc in source_documents
-     ]
-
-     # Update chat history with the latest interaction
-     new_history = history + [(message, response_answer)]
-     return qa_chain, gr.update(value=""), new_history, sources
-
- # Set up the Gradio interface for the chatbot application
- def demo():
-     # Load and initialize a retriever with a placeholder database
-     placeholder_docs = load_doc(["placeholder.pdf"])
-     vector_db = create_db(placeholder_docs)
-     retriever = vector_db.as_retriever()
-
-     # Initialize the QA chain with default LLM parameters
-     qa_chain = initialize_llmchain(
-         llm_model=llm_model,
-         temperature=0.5,
-         max_tokens=1024,
-         top_k=3,
-         retriever=retriever
+     if not qa_chain:
+         return qa_chain, gr.update(value=""), history + [(message, "Please upload a document first.")]
+
+     formatted_history = []
+     for user_msg, bot_msg in history:
+         formatted_history.append(f"User: {user_msg}")
+         formatted_history.append(f"Assistant: {bot_msg}")
+
+     response = qa_chain.invoke({
+         "question": message,
+         "chat_history": formatted_history
+     })
+
+     answer = response["answer"]
+     if "Helpful Answer:" in answer:
+         answer = answer.split("Helpful Answer:")[-1]
+
+     formatted_response = format_response_with_citations(
+         answer,
+         response["source_documents"]
      )
+
+     new_history = history + [(message, formatted_response)]
+     return qa_chain, gr.update(value=""), new_history

-     with gr.Blocks() as demo:
-         # Persistent states for the vector database and QA chain
-         gr.State()
-
-         # Display app header and description
+ def demo():
+     with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
+         vector_db = gr.State()
+         qa_chain = gr.State()
+
          gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
          gr.Markdown("""
-         <b>Query your PDF documents!</b> Upload files to create a vector database and chat with the content. <b>Do not upload confidential documents.</b>
+         <b>Query your PDF documents!</b> This AI agent performs retrieval augmented generation (RAG)
+         on PDF documents. <b>Please do not upload confidential documents.</b>
          """)
-
+
          with gr.Row():
-             # Column for uploading files and configuring the pipeline
-             with gr.Column(scale=85):
-                 document = gr.Files(height=200, file_count="multiple", file_types=[".pdf"], label="Upload PDF documents")
-                 upload_btn = gr.Button("Upload and Process Documents")
-                 pipeline_status = gr.Textbox(value="Initialized", interactive=False, label="Status")
-
-             # Column for chatbot interaction
-             with gr.Column(scale=200):
-                 chatbot = gr.Chatbot(height=400, label="Chatbot")
-                 message = gr.Textbox(placeholder="Type your question here")
-                 submit_btn = gr.Button("Submit")
-                 clear_btn = gr.ClearButton([message, chatbot], value="Clear")
-                 relevant_context = gr.Textbox(label="Relevant Context", lines=3, interactive=False)
-
-         # Define action to process documents and update the retriever
-         upload_btn.click(
-             lambda file_obj: initialize_llmchain(
-                 llm_model=llm_model,
-                 temperature=0.5,
-                 max_tokens=1024,
-                 top_k=3,
-                 retriever=create_db(load_doc([file.name for file in file_obj if file is not None])).as_retriever()
-             ),
+             with gr.Column(scale=1):
+                 document = gr.Files(
+                     height=300,
+                     file_count="multiple",
+                     file_types=["pdf"],
+                     label="Upload PDF documents"
+                 )
+
+             with gr.Column(scale=2):
+                 chatbot = gr.Chatbot(height=600)
+                 with gr.Row():
+                     msg = gr.Textbox(
+                         placeholder="Ask a question about your documents...",
+                         container=True
+                     )
+                 with gr.Row():
+                     submit_btn = gr.Button("Submit")
+                     clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
+
+         # Automatically create vector DB and initialize chain on file upload
+         document.upload(
+             fn=lambda files: [load_and_create_db(files), initialize_llmchain(load_and_create_db(files))],
              inputs=[document],
-             outputs=[qa_chain]
+             outputs=[vector_db, qa_chain]
          )
-
-         # Define action to handle user queries
+
+         # Chatbot events
          submit_btn.click(
             conversation,
-             inputs=[qa_chain, message, chatbot],
-             outputs=[qa_chain, message, chatbot, relevant_context]
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[qa_chain, msg, chatbot]
          )
+         msg.submit(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[qa_chain, msg, chatbot]
+         )
+
+     demo.queue().launch(debug=True)

-     demo.launch()
-
- # Launch the application
  if __name__ == "__main__":
      demo()