Kameshr committed
Commit fff0339 · verified · 1 Parent(s): 6bf9f75

Upload 2 files

Files changed (2):
  1. app.py +144 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,144 @@
import gradio as gr
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory

api_token = os.getenv("HF_TOKEN")

list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_and_process_docs(list_file_path):
    # Load every PDF and split the pages into overlapping chunks.
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64
    )
    return text_splitter.split_documents(pages)

def create_vector_db(splits):
    embeddings = HuggingFaceEmbeddings()
    return FAISS.from_documents(splits, embeddings)

def initialize_qa_chain(llm_model, vector_db, temperature=0.5, max_tokens=4096, top_k=3):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )

def format_response_with_citations(response_text, sources):
    # Append each retrieved chunk as a numbered citation with its page number.
    formatted_response = response_text
    for idx, source in enumerate(sources, 1):
        citation_marker = f"[{idx}]"
        formatted_response += f"\n\n{citation_marker} (Page {source.metadata['page'] + 1}): {source.page_content.strip()}"
    return formatted_response

def chat(qa_chain, message, history):
    # Answer gracefully if no documents have been uploaded yet, instead of
    # crashing on a None state.
    if qa_chain is None:
        return qa_chain, "", history + [(message, "Please upload your PDF documents first.")]

    # ConversationalRetrievalChain's default history formatter accepts
    # (human, ai) tuples; pre-formatted strings would raise a ValueError.
    formatted_history = [(user_msg, bot_msg) for user_msg, bot_msg in history]

    response = qa_chain.invoke({
        "question": message,
        "chat_history": formatted_history
    })

    answer = response["answer"]
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1]

    formatted_response = format_response_with_citations(
        answer,
        response["source_documents"][:3]
    )

    return qa_chain, "", history + [(message, formatted_response)]

def demo():
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red")) as demo:
        qa_chain = gr.State()

        gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
        gr.Markdown("""Query your PDF documents with citation support.
        **Please do not upload confidential documents.**""")

        with gr.Row():
            with gr.Column():
                document = gr.Files(
                    height=100,
                    file_count="multiple",
                    file_types=["pdf"],
                    label="Upload PDF Documents"
                )
                llm_choice = gr.Radio(
                    list_llm_simple,
                    label="Select Language Model",
                    value=list_llm_simple[0],
                    type="index"
                )

            with gr.Column():
                chatbot = gr.Chatbot(height=500)
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents",
                    container=True
                )
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.ClearButton([msg, chatbot])

        def initialize_system(files, llm_idx):
            if not files:
                return None
            file_paths = [f.name for f in files]
            splits = load_and_process_docs(file_paths)
            vector_db = create_vector_db(splits)
            return initialize_qa_chain(list_llm[llm_idx], vector_db)

        # Auto-initialize when files are uploaded and model is selected
        document.change(
            initialize_system,
            inputs=[document, llm_choice],
            outputs=[qa_chain]
        )
        llm_choice.change(
            initialize_system,
            inputs=[document, llm_choice],
            outputs=[qa_chain]
        )

        # Chat interactions
        msg.submit(chat, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot])
        submit_btn.click(chat, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot])

    return demo.queue()

if __name__ == "__main__":
    demo().launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
torch
transformers
sentence-transformers
langchain
langchain-community
tqdm
accelerate
pypdf
faiss-gpu
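
For quick verification outside the Gradio UI, the pipeline committed in app.py can be exercised headlessly. The sketch below is illustrative only, not part of the commit; it assumes a valid HF_TOKEN in the environment and a local sample.pdf (both hypothetical), and imports the helper functions defined above:

    # headless smoke test -- illustrative sketch under the assumptions above
    from app import load_and_process_docs, create_vector_db, initialize_qa_chain, list_llm

    splits = load_and_process_docs(["sample.pdf"])  # hypothetical input file
    vector_db = create_vector_db(splits)
    qa_chain = initialize_qa_chain(list_llm[0], vector_db)

    result = qa_chain.invoke({"question": "What is this document about?", "chat_history": []})
    print(result["answer"])
    for doc in result["source_documents"][:3]:
        # Each source carries a 0-indexed page number in its metadata.
        print(doc.metadata.get("page"), doc.page_content[:80])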