# PDFChat / app.py
# Author: Akhil2507
# Commit: e2ca2ad ("Update app.py")
import os
from glob import glob

import gradio as gr
import yaml
from dotenv import load_dotenv
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from tqdm import tqdm
load_dotenv()
def load_config():
with open('config.yaml', 'r') as file:
config = yaml.safe_load(file)
return config
config = load_config()
def load_embeddings(model_name=config["embeddings"]["name"],
model_kwargs={'device': config["embeddings"]["device"]}):
return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
def load_documents(directory: str):
"""Loads all documents from a directory and returns a list of Document objects
args: directory format = directory/
"""
text_splitter = RecursiveCharacterTextSplitter(chunk_size=config["TextSplitter"]["chunk_size"],
chunk_overlap=config["TextSplitter"]["chunk_overlap"])
documents = []
for item_path in tqdm(glob(directory + "*.pdf")):
loader = PyPDFLoader(item_path)
documents.extend(loader.load_and_split(text_splitter=text_splitter))
return documents
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
repo_id = "google/flan-t5-xxl"
def get_llm():
llm = HuggingFaceHub(
repo_id=repo_id, model_kwargs={"temperature": 0.5, "max_length": 200}
)
return llm
def answer_question(question: str):
embedding_function = load_embeddings()
documents = load_documents("data/")
db = FAISS.from_documents(documents, embedding_function)
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
qa_chain = RetrievalQA.from_chain_type(
get_llm(),
retriever=retriever,
chain_type="stuff",
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
return_source_documents=True
)
output = qa_chain({"query": question})
return output["result"]
# Gradio UI for PDFChat
with gr.Blocks() as demo:
with gr.Tab("PdfChat"):
with gr.Column():
ans = gr.Textbox(label="Answer", lines=10)
que = gr.Textbox(label="Ask a Question", lines=2)
bttn = gr.Button(value="Submit")
bttn.click(fn=answer_question, inputs=[que], outputs=[ans])
if __name__ == "__main__":
demo.launch()