File size: 2,964 Bytes
a8e9ac6
 
 
 
 
 
 
e2ca2ad
a8e9ac6
 
5da5dff
a8e9ac6
 
 
e2ca2ad
 
 
a8e9ac6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2ca2ad
a8e9ac6
 
e2ca2ad
a8e9ac6
e2ca2ad
a8e9ac6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
from glob import glob

import gradio as gr
import yaml
from dotenv import load_dotenv
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from tqdm import tqdm


# Load environment variables from a local .env file at import time.
# NOTE(review): presumably this supplies HUGGINGFACEHUB_API_TOKEN for
# HuggingFaceHub below — confirm the variable name against .env.
load_dotenv()


def load_config():
    """Read and parse ``config.yaml`` from the current working directory.

    Returns:
        The deserialized YAML content (typically a dict of settings).
    """
    with open('config.yaml', 'r') as cfg_file:
        return yaml.safe_load(cfg_file)


# Parse config.yaml once at import time; shared by the functions below.
config = load_config()


def load_embeddings(model_name=config["embeddings"]["name"],
                    model_kwargs=None):
    """Build a HuggingFaceEmbeddings instance from config defaults.

    Args:
        model_name: Embedding model identifier; defaults to
            ``config["embeddings"]["name"]``.
        model_kwargs: Extra keyword arguments for the embedding model.
            Defaults to ``{'device': config["embeddings"]["device"]}``.

    Returns:
        A configured ``HuggingFaceEmbeddings`` object.
    """
    # A dict literal as a default argument is shared across calls (mutable
    # default pitfall); use a None sentinel and build the dict per call.
    if model_kwargs is None:
        model_kwargs = {'device': config["embeddings"]["device"]}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)


def load_documents(directory: str):
    """Load every PDF in *directory* and split each into text chunks.

    Args:
        directory: Folder containing ``*.pdf`` files. A trailing separator
            (e.g. ``"data/"``) is accepted but no longer required.

    Returns:
        list: Chunked ``Document`` objects from all PDFs found (empty list
        if the directory has no PDFs).
    """
    # Chunking parameters come from config.yaml so they can be tuned
    # without code changes.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=config["TextSplitter"]["chunk_size"],
                                                   chunk_overlap=config["TextSplitter"]["chunk_overlap"])
    documents = []
    # os.path.join works with or without a trailing slash; the original
    # string concatenation silently found nothing for "data" (no slash).
    for item_path in tqdm(glob(os.path.join(directory, "*.pdf"))):
        loader = PyPDFLoader(item_path)
        documents.extend(loader.load_and_split(text_splitter=text_splitter))

    return documents


# Prompt for the "stuff" QA chain: {context} receives the retrieved document
# chunks concatenated together, {question} receives the user's query.
template = """Use the following pieces of context to answer the question at the end. 
If you don't know the answer, just say that you don't know, don't try to make up an answer. 
Use three sentences maximum and keep the answer as concise as possible. 
Always say "thanks for asking!" at the end of the answer. 
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

# Hugging Face Hub model repository used as the LLM backend.
repo_id = "google/flan-t5-xxl"


def get_llm():
    """Instantiate the HuggingFaceHub LLM for the module-level ``repo_id``.

    Returns:
        A ``HuggingFaceHub`` client configured with temperature 0.5 and a
        200-token maximum output length.
    """
    generation_params = {"temperature": 0.5, "max_length": 200}
    return HuggingFaceHub(repo_id=repo_id, model_kwargs=generation_params)


def _build_qa_chain():
    """Build the RetrievalQA chain: embeddings + FAISS index over data/ + LLM.

    This is the expensive step (loads and chunks every PDF, embeds all
    chunks); it is intended to run once and be cached by answer_question.
    """
    embedding_function = load_embeddings()
    documents = load_documents("data/")

    db = FAISS.from_documents(documents, embedding_function)

    # Plain similarity search returning the top 4 chunks per query.
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

    return RetrievalQA.from_chain_type(
        get_llm(),
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        return_source_documents=True
    )


# Lazily-built singleton chain. The original code re-embedded and re-indexed
# the whole corpus on EVERY question; caching makes follow-up questions cheap.
# Trade-off: PDFs added to data/ after the first question are not picked up
# until the process restarts.
_qa_chain = None


def answer_question(question: str):
    """Answer *question* using retrieval-augmented QA over the data/ PDFs.

    Args:
        question: The user's natural-language question.

    Returns:
        str: The LLM's answer text (the chain's "result" field).
    """
    global _qa_chain
    if _qa_chain is None:
        _qa_chain = _build_qa_chain()

    output = _qa_chain({"query": question})
    return output["result"]


# Gradio UI for PDFChat: a single tab with an answer box, a question box,
# and a submit button wired to answer_question.
with gr.Blocks() as demo:
    with gr.Tab("PdfChat"):
        with gr.Column():
            # Output textbox; placed above the input in the layout.
            ans = gr.Textbox(label="Answer", lines=10)

            que = gr.Textbox(label="Ask a Question", lines=2)

            bttn = gr.Button(value="Submit")

        # On click: pass the question text to answer_question and display
        # the returned answer string.
        bttn.click(fn=answer_question, inputs=[que], outputs=[ans])

if __name__ == "__main__":
    # Start the Gradio development server (blocks until interrupted).
    demo.launch()