from langchain.chains import RetrievalQA
from langchain import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from dotenv import load_dotenv
from glob import glob
from tqdm import tqdm
import gradio as gr
import yaml

# Load environment variables; HuggingFaceHub picks up HUGGINGFACEHUB_API_TOKEN.
load_dotenv()


def load_config():
    """Read runtime settings from config.yaml."""
    with open('config.yaml', 'r') as file:
        config = yaml.safe_load(file)
    return config


config = load_config()


def load_embeddings(model_name=config["embeddings"]["name"],
                    model_kwargs={'device': config["embeddings"]["device"]}):
    """Build the sentence-embedding model used to index and query the PDFs."""
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)


def load_documents(directory: str):
    """Load all PDFs in a directory and return a list of chunked Document objects.

    Args:
        directory: path ending with a trailing slash, e.g. "data/"
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config["TextSplitter"]["chunk_size"],
        chunk_overlap=config["TextSplitter"]["chunk_overlap"])
    documents = []
    for item_path in tqdm(glob(directory + "*.pdf")):
        loader = PyPDFLoader(item_path)
        documents.extend(loader.load_and_split(text_splitter=text_splitter))
    return documents


template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

repo_id = "google/flan-t5-xxl"


def get_llm():
    """Instantiate the hosted LLM that generates the answers."""
    llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={"temperature": 0.5, "max_length": 200}
    )
    return llm


def answer_question(question: str):
    """Index the PDFs, retrieve the top-k relevant chunks, and answer the question.

    Note: embeddings and the FAISS index are rebuilt on every call; for
    anything beyond a demo, build the index once at startup and reuse it.
    """
    embedding_function = load_embeddings()
    documents = load_documents("data/")
    db = FAISS.from_documents(documents, embedding_function)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
    qa_chain = RetrievalQA.from_chain_type(
        get_llm(),
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        return_source_documents=True
    )
    output = qa_chain({"query": question})
    return output["result"]


# Gradio UI for PDFChat
with gr.Blocks() as demo:
    with gr.Tab("PdfChat"):
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=10)
            que = gr.Textbox(label="Ask a Question", lines=2)
            bttn = gr.Button(value="Submit")

    # Event handlers must be registered inside the Blocks context.
    bttn.click(fn=answer_question, inputs=[que], outputs=[ans])

if __name__ == "__main__":
    demo.launch()
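
# ---------------------------------------------------------------------------
# Setup notes (a sketch; the values below are illustrative assumptions, not
# taken from this script):
#
# config.yaml is expected next to this file with the keys read above, e.g.:
#
#   embeddings:
#     name: sentence-transformers/all-MiniLM-L6-v2   # any HF embedding model
#     device: cpu                                    # or "cuda"
#   TextSplitter:
#     chunk_size: 1000
#     chunk_overlap: 100
#
# A .env file should provide HUGGINGFACEHUB_API_TOKEN, which HuggingFaceHub
# uses to authenticate against the Hugging Face Inference API, and the PDFs
# to index go in the data/ directory passed to load_documents().
# ---------------------------------------------------------------------------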