RAG / app.py
Kameshr's picture
Update app.py
4e6f4e8 verified
import gradio as gr
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# Hugging Face Hub repo id of the chat model queried through HuggingFaceEndpoint.
DEFAULT_LLM = "meta-llama/Meta-Llama-3-8B-Instruct"
# Hugging Face access token from the HF_TOKEN environment variable (None when unset).
api_token = os.environ.get("HF_TOKEN")
def load_and_create_db(files):
    """Index uploaded PDF files in FAISS and build a QA chain over them.

    Args:
        files: Uploaded files, either plain path strings or objects exposing
            the file path as ``.name`` (as Gradio file components deliver them).

    Returns:
        ``(vectordb, qa_chain)``: the FAISS index and the chain built by
        ``initialize_llmchain``. Returns ``(None, None)`` when *files* is empty,
        or when loading/splitting/indexing fails (the error is printed).

    Raises:
        ValueError: If any file does not have a ``.pdf`` extension. Validation
            happens before any work is done so the caller can show the
            message to the user.
    """
    if not files:
        return None, None

    # Validate outside the try below: inside it, this user-facing message would
    # be swallowed by the generic handler and replaced with (None, None).
    list_file_paths = []
    for file in files:
        file_name = getattr(file, "name", file)
        if not str(file_name).lower().endswith(".pdf"):
            raise ValueError(f"Unsupported file format: {file_name}. Please upload PDF files only.")
        list_file_paths.append(file_name)

    try:
        # Load every page of every PDF as a LangChain document.
        pages = []
        for path in list_file_paths:
            pages.extend(PyPDFLoader(path).load())

        # Split into overlapping chunks for retrieval.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,
            chunk_overlap=64,
        )
        doc_splits = text_splitter.split_documents(pages)
        if not doc_splits:
            # Report this clearly rather than via an opaque failure from
            # FAISS.from_documents on an empty list (e.g. PDFs without a text layer).
            raise ValueError("No extractable text found in the uploaded PDF files.")

        # Embed the chunks and build the vector index.
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(doc_splits, embeddings)

        qa_chain = initialize_llmchain(vectordb)
        return vectordb, qa_chain
    except Exception as e:
        # Best-effort boundary for the UI: log and signal failure with (None, None).
        print(f"Error processing files: {str(e)}")
        return None, None
def initialize_llmchain(vector_db, temperature=0.5, max_tokens=4096, top_k=3):
    """Create a ConversationalRetrievalChain that answers from *vector_db*.

    Args:
        vector_db: Vector store to retrieve from; ``as_retriever()`` is called
            with its default settings.
        temperature: Sampling temperature forwarded to the HF endpoint.
        max_tokens: Forwarded as ``max_new_tokens`` to cap the generated length.
        top_k: Forwarded to the endpoint's ``top_k`` sampling parameter. This is
            NOT the number of retrieved chunks.

    Returns:
        The chain. It carries a ConversationBufferMemory keyed
        ``chat_history`` that records the ``answer`` output, uses the "stuff"
        combine strategy, and returns source documents with each answer.
    """
    generation_settings = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    endpoint = HuggingFaceEndpoint(
        repo_id=DEFAULT_LLM,
        huggingfacehub_api_token=api_token,
        **generation_settings,
    )
    history_store = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        endpoint,
        retriever=retriever,
        chain_type="stuff",
        memory=history_store,
        return_source_documents=True,
        verbose=False,
    )
def format_citation(source_doc):
    """Return ``(text, page)`` for one retrieved source document.

    *text* is the document's ``page_content`` with surrounding whitespace
    removed. *page* is ``metadata["page"] + 1``, i.e. a 1-based page number
    assuming the loader stores a 0-based index (as PyPDFLoader does). When the
    metadata has no page, *page* is 0, the same placeholder used for empty
    citation slots, instead of raising ``KeyError``.
    """
    content = source_doc.page_content.strip()
    page_index = source_doc.metadata.get("page")
    page = 0 if page_index is None else page_index + 1
    return content, page
def conversation(qa_chain, message, history):
    """Answer *message* with *qa_chain* and refresh the citation panel.

    Args:
        qa_chain: Chain from ``initialize_llmchain``; falsy until documents
            have been processed.
        message: The user's question.
        history: Chatbot history as ``(user, assistant)`` pairs. ``None`` is
            treated as an empty history.

    Returns:
        A 10-tuple containing, in order:
        - the chain;
        - an update that clears the input textbox;
        - the updated history;
        - content and page for up to three citations, with ``("", 0)``
          filling unused slots;
        - a status message (``""`` on success).
    """
    history = list(history or [])

    if not qa_chain:
        return (None, gr.update(value=""), history, "", 0, "", 0, "", 0,
                "Please upload a document first.")

    if not message or not message.strip():
        # Don't spend an LLM call on an empty question; leave chat and citations as they are.
        unchanged = [gr.update() for _ in range(6)]
        return (qa_chain, gr.update(value=""), history, *unchanged,
                "Please enter a question.")

    # LangChain's ConversationalRetrievalChain accepts history turns as
    # (human, ai) tuples or message objects. A flat list of "User: ..."
    # strings is not a supported format. NOTE: when the chain has memory
    # attached, LangChain loads chat_history from that memory instead of
    # this value.
    chat_history = [(user_msg, bot_msg) for user_msg, bot_msg in history]
    response = qa_chain.invoke({"question": message, "chat_history": chat_history})

    answer = response["answer"]
    if "Helpful Answer:" in answer:
        # Presumably the endpoint can echo the prompt, which ends with this
        # marker; keep only the text generated after the last occurrence.
        answer = answer.split("Helpful Answer:")[-1]
    answer = answer.strip()

    max_citations = 3  # the UI has exactly three citation slots
    sources = response["source_documents"][:max_citations]

    # Append one "[n]" marker per cited source, e.g. "... [1] [2] [3]".
    markers = "".join(f" [{number}]" for number in range(1, len(sources) + 1))
    new_history = history + [(message, answer + markers)]

    # Flatten (content, page) pairs, padding missing sources with ("", 0).
    citations = [format_citation(doc) for doc in sources]
    citations += [("", 0)] * (max_citations - len(citations))
    citation_fields = [field for content_page in citations for field in content_page]

    return (qa_chain, gr.update(value=""), new_history, *citation_fields, "")
def demo():
    """Build and launch the Gradio interface for the RAG PDF chatbot.

    Per-session ``gr.State`` values hold the FAISS index and the QA chain.
    - Uploading PDFs (re)builds both.
    - Removing the uploads drops both.
    - Submitting a question runs ``conversation``.
    - Clear empties the chat, the citation panel, and the chain's own
      conversation memory.
    """
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as app:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
        gr.Markdown(
            "\n"
            "<b>Query your PDF documents!</b> This AI agent performs retrieval augmented generation (RAG)\n"
            "on PDF documents. <b>Please do not upload confidential documents.</b>\n"
        )
        with gr.Row():
            with gr.Column(scale=1):
                document = gr.Files(
                    height=300,
                    file_count="multiple",
                    file_types=[".pdf"],
                    label="Upload PDF documents",
                )
                upload_status = gr.Textbox(label="Upload Status", interactive=False)
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=500)
                with gr.Accordion("Citations", open=False):
                    with gr.Row():
                        doc_source1 = gr.Textbox(label="[1]", lines=2, container=True, scale=20)
                        source1_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source2 = gr.Textbox(label="[2]", lines=2, container=True, scale=20)
                        source2_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source3 = gr.Textbox(label="[3]", lines=2, container=True, scale=20)
                        source3_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask a question about your documents...",
                        container=True,
                    )
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.ClearButton([msg, chatbot], value="Clear")

        def handle_file_upload(files):
            """(Re)build the index and chain for the uploaded PDFs and report status."""
            if not files:
                return None, None, "No files uploaded"
            try:
                vectordb, qa = load_and_create_db(files)
            except Exception as e:
                # Report any failure in the status box instead of failing the event.
                return None, None, f"Error: {str(e)}"
            if vectordb is not None and qa is not None:
                return vectordb, qa, "Files successfully processed"
            return None, None, "Error processing files"

        def handle_file_clear():
            """Drop the index and chain once the uploaded files are removed."""
            return None, None, "No files uploaded"

        # Automatically create vector DB and initialize chain on file upload;
        # forget them when the uploads are cleared so stale documents aren't queried.
        file_outputs = [vector_db, qa_chain, upload_status]
        document.upload(fn=handle_file_upload, inputs=[document], outputs=file_outputs)
        document.clear(fn=handle_file_clear, outputs=file_outputs)

        def clear_all(chain):
            """Reset the citation panel and the chain's conversation memory.

            The chain carries its own ConversationBufferMemory (see
            initialize_llmchain). Clearing only the visible chat would leave
            the old turns in that memory for later questions.
            """
            memory = getattr(chain, "memory", None)
            if memory is not None:
                memory.clear()
            return [chain, "", 0, "", 0, "", 0]

        # Chatbot events: the Submit button and Enter in the textbox behave identically.
        chat_inputs = [qa_chain, msg, chatbot]
        chat_outputs = [qa_chain, msg, chatbot,
                        doc_source1, source1_page,
                        doc_source2, source2_page,
                        doc_source3, source3_page,
                        upload_status]
        for trigger in (submit_btn.click, msg.submit):
            trigger(conversation, inputs=chat_inputs, outputs=chat_outputs)

        # Runs alongside the ClearButton's own reset of msg and chatbot.
        clear_btn.click(
            clear_all,
            inputs=[qa_chain],
            outputs=[qa_chain,
                     doc_source1, source1_page,
                     doc_source2, source2_page,
                     doc_source3, source3_page],
        )

    app.queue().launch(debug=True)
# Build and launch the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    demo()