from langchain.chains import RetrievalQA
from langchain import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from dotenv import load_dotenv
from glob import glob
from tqdm import tqdm
import gradio as gr
import yaml

# Read the Hugging Face Hub token (and any other secrets) from a local .env file.
load_dotenv()
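# A minimal .env sketch (assumption: HUGGINGFACEHUB_API_TOKEN is the variable
# the HuggingFaceHub wrapper reads; the value below is a placeholder):
#
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxx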


def load_config():
    """Read runtime settings from config.yaml in the working directory."""
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)
    return config


config = load_config()
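
# The script expects a config.yaml next to it. A minimal sketch, with key names
# taken from the lookups in this file; the values (including the embedding
# model name) are illustrative assumptions, not part of the original:
#
#   embeddings:
#     name: sentence-transformers/all-MiniLM-L6-v2
#     device: cpu
#   TextSplitter:
#     chunk_size: 1000
#     chunk_overlap: 100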


def load_embeddings(model_name=config["embeddings"]["name"], model_kwargs=None):
    """Build the HuggingFace embedding function on the configured device."""
    if model_kwargs is None:  # avoid a mutable default argument
        model_kwargs = {"device": config["embeddings"]["device"]}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)


def load_documents(directory: str):
    """Load every PDF in `directory` and split it into overlapping chunks.

    Args:
        directory: path ending in a slash, e.g. "data/".
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config["TextSplitter"]["chunk_size"],
        chunk_overlap=config["TextSplitter"]["chunk_overlap"],
    )
    documents = []
    for item_path in tqdm(glob(directory + "*.pdf")):
        loader = PyPDFLoader(item_path)
        documents.extend(loader.load_and_split(text_splitter=text_splitter))
    return documents


template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
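
# For illustration: at query time the chain substitutes the retrieved chunks
# for {context} and the user query for {question}, equivalent to
#   QA_CHAIN_PROMPT.format(context="...", question="...")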


# Hub-hosted model used for generation.
repo_id = "google/flan-t5-xxl"


def get_llm():
    """Instantiate the hosted inference endpoint for `repo_id`."""
    return HuggingFaceHub(
        repo_id=repo_id, model_kwargs={"temperature": 0.5, "max_length": 200}
    )


def answer_question(question: str):
    """Answer `question` over the PDFs in data/ via retrieval-augmented QA."""
    # NOTE: the embeddings and the FAISS index are rebuilt on every call; for
    # anything beyond a demo, build the index once at startup and reuse it.
    embedding_function = load_embeddings()
    documents = load_documents("data/")

    db = FAISS.from_documents(documents, embedding_function)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

    qa_chain = RetrievalQA.from_chain_type(
        get_llm(),
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        return_source_documents=True,
    )

    output = qa_chain({"query": question})
    return output["result"]


with gr.Blocks() as demo:
    with gr.Tab("PdfChat"):
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=10)
            que = gr.Textbox(label="Ask a Question", lines=2)
            bttn = gr.Button(value="Submit")

        bttn.click(fn=answer_question, inputs=[que], outputs=[ans])


if __name__ == "__main__":
    demo.launch()