import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hugging Face Inference API token. The default Llama 3 model is gated on the
# Hub, so the token must belong to an account that has been granted access.
api_token = os.getenv("HF_TOKEN")

DEFAULT_LLM = "meta-llama/Meta-Llama-3-8B-Instruct"


def load_and_create_db(files):
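    """Load the uploaded PDFs, split them into chunks, embed the chunks,
    and return a FAISS vector store together with a ready-to-use QA chain."""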
    if not files:
        return None, None

    try:
        # Accept only PDF uploads; reject anything else early.
        list_file_paths = []
        for file in files:
            file_name = file.name
            if file_name.lower().endswith('.pdf'):
                list_file_paths.append(file_name)
            else:
                raise ValueError(f"Unsupported file format: {file_name}. Please upload PDF files only.")

        if not list_file_paths:
            return None, None

        # Load every page of every uploaded document.
        loaders = [PyPDFLoader(path) for path in list_file_paths]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())

        # Split pages into overlapping chunks so retrieval can return
        # passages smaller than a full page.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,
            chunk_overlap=64
        )
        doc_splits = text_splitter.split_documents(pages)

        # Embed the chunks (HuggingFaceEmbeddings defaults to a
        # sentence-transformers model) and index them with FAISS.
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(doc_splits, embeddings)

        qa_chain = initialize_llmchain(vectordb)
        return vectordb, qa_chain
    except Exception as e:
        print(f"Error processing files: {str(e)}")
        return None, None


def initialize_llmchain(vector_db, temperature=0.5, max_tokens=4096, top_k=3):
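    """Wire the HF Inference endpoint, conversation memory, and the vector
    store into a ConversationalRetrievalChain."""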
    llm = HuggingFaceEndpoint(
        repo_id=DEFAULT_LLM,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,  # top-k sampling for generation, not the retriever's k
    )

    # output_key='answer' tells the memory which chain output to store,
    # since the chain also returns source documents.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain


def format_citation(source_doc):
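    """Return a source chunk's text and its 1-based page number."""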
    content = source_doc.page_content.strip()
    # PyPDFLoader stores 0-based page numbers; guard against missing metadata.
    page = source_doc.metadata.get("page", -1) + 1
    return content, page


def conversation(qa_chain, message, history):
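    """Answer one user message and surface up to three cited source chunks."""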
    if not qa_chain:
        return (None, gr.update(value=""), history, "", 0, "", 0, "", 0,
                "Please upload a document first.")

    # The chain's ConversationBufferMemory tracks the dialogue itself and
    # overrides any chat_history passed in, so only the new question is sent.
    response = qa_chain.invoke({"question": message})

    answer = response["answer"]
    # Some models echo the prompt template; keep only the final answer.
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1]

    # Append numbered citation markers, one per retrieved source chunk.
    sources = response["source_documents"][:3]
    modified_answer = answer + "".join(f" [{i + 1}]" for i in range(len(sources)))

    # Pad with empty values so the UI always receives three citation slots.
    citations = [format_citation(source) for source in sources]
    source1_content, page1 = citations[0] if len(citations) > 0 else ("", 0)
    source2_content, page2 = citations[1] if len(citations) > 1 else ("", 0)
    source3_content, page3 = citations[2] if len(citations) > 2 else ("", 0)

    new_history = history + [(message, modified_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            source1_content, page1, source2_content, page2, source3_content, page3, "")


def demo():
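    """Assemble the Gradio UI and launch the app."""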
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
        # Per-session state holding the vector store and the QA chain.
        vector_db = gr.State()
        qa_chain = gr.State()

        gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
        gr.Markdown("""
        <b>Query your PDF documents!</b> This AI agent performs retrieval-augmented generation (RAG)
        on PDF documents. <b>Please do not upload confidential documents.</b>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                document = gr.Files(
                    height=300,
                    file_count="multiple",
                    file_types=[".pdf"],
                    label="Upload PDF documents"
                )
                upload_status = gr.Textbox(label="Upload Status", interactive=False)

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=500)
                with gr.Accordion("Citations", open=False):
                    with gr.Row():
                        doc_source1 = gr.Textbox(label="[1]", lines=2, container=True, scale=20)
                        source1_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source2 = gr.Textbox(label="[2]", lines=2, container=True, scale=20)
                        source2_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source3 = gr.Textbox(label="[3]", lines=2, container=True, scale=20)
                        source3_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask a question about your documents...",
                        container=True
                    )
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.ClearButton([msg, chatbot], value="Clear")

        def handle_file_upload(files):
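            """Build the vector store and QA chain, reporting status to the UI."""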
            if not files:
                return None, None, "No files uploaded"
            try:
                vectordb, qa = load_and_create_db(files)
                if vectordb and qa:
                    return vectordb, qa, "Files successfully processed"
                return None, None, "Error processing files"
            except Exception as e:
                return None, None, f"Error: {str(e)}"

        document.upload(
            fn=handle_file_upload,
            inputs=[document],
            outputs=[vector_db, qa_chain, upload_status]
        )

        def clear_all():
            # Reset the three citation boxes and their page-number fields.
            return ["", 0, "", 0, "", 0]

        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot,
                     doc_source1, source1_page,
                     doc_source2, source2_page,
                     doc_source3, source3_page,
                     upload_status]
        )
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot,
                     doc_source1, source1_page,
                     doc_source2, source2_page,
                     doc_source3, source3_page,
                     upload_status]
        )
        clear_btn.click(
            clear_all,
            outputs=[doc_source1, source1_page,
                     doc_source2, source2_page,
                     doc_source3, source3_page]
        )

    demo.queue().launch(debug=True)


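# A minimal local run, assuming this file is saved as app.py and the usual
# dependencies for the imports above are installed:
#
#   pip install gradio langchain langchain-community faiss-cpu pypdf sentence-transformers
#   export HF_TOKEN=hf_xxx   # token with access to the gated Llama 3 repo
#   python app.py
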
if __name__ == "__main__":
    demo()