Akhil2507 commited on
Commit
a8e9ac6
·
1 Parent(s): a1b9a25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import RetrievalQA
2
+ from langchain import HuggingFaceHub
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
6
+ from langchain.document_loaders import PyPDFLoader
7
+ from langchain.vectorstores import FAISS
8
+ from glob import glob
9
+ from tqdm import tqdm
10
+ import yaml
11
+
12
+
13
+ def load_config():
14
+ with open('config.yaml', 'r') as file:
15
+ config = yaml.safe_load(file)
16
+ return config
17
+
18
+
19
+ config = load_config()
20
+
21
+
22
+ def load_embeddings(model_name=config["embeddings"]["name"],
23
+ model_kwargs={'device': config["embeddings"]["device"]}):
24
+ return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
25
+
26
+
27
+ def load_documents(directory: str):
28
+ """Loads all documents from a directory and returns a list of Document objects
29
+ args: directory format = directory/
30
+ """
31
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=config["TextSplitter"]["chunk_size"],
32
+ chunk_overlap=config["TextSplitter"]["chunk_overlap"])
33
+ documents = []
34
+ for item_path in tqdm(glob(directory + "*.pdf")):
35
+ loader = PyPDFLoader(item_path)
36
+ documents.extend(loader.load_and_split(text_splitter=text_splitter))
37
+
38
+ return documents
39
+
40
+
41
+ template = """Use the following pieces of context to answer the question at the end.
42
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
43
+ Use three sentences maximum and keep the answer as concise as possible.
44
+ Always say "thanks for asking!" at the end of the answer.
45
+ {context}
46
+ Question: {question}
47
+ Helpful Answer:"""
48
+ QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
49
+
50
+ repo_id = "google/flan-t5-xxl"
51
+
52
+
53
+ def get_llm():
54
+ llm = HuggingFaceHub(
55
+ repo_id=repo_id, model_kwargs={"temperature": 0.5, "max_length": 200}
56
+ )
57
+ return llm
58
+
59
+
60
+ def answer_question(question: str):
61
+ embedding_function = load_embeddings()
62
+ documents = load_documents("data/")
63
+
64
+ db = FAISS.from_documents(documents, embedding_function)
65
+
66
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
67
+
68
+ qa_chain = RetrievalQA.from_chain_type(
69
+ get_llm(),
70
+ retriever=retriever,
71
+ chain_type="stuff",
72
+ chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
73
+ return_source_documents=True
74
+ )
75
+
76
+ output = qa_chain({"query": question})
77
+ return output["result"]
78
+
79
+
80
+ # Gradio UI for PDFChat
81
+ with gr.Blocks() as demo:
82
+ with gr.Tab("PdfChat"):
83
+ with gr.Row():
84
+ ans = gr.Textbox(label="Answer", lines=10)
85
+
86
+ que = gr.Textbox(label="Ask a Question", lines=3)
87
+
88
+ bttn = gr.Button(label="Submit")
89
+
90
+ bttn.click(fn=answer_question, inputs=[que], outputs=[ans])
91
+
92
+ if __name__ == "__main__":
93
+ demo.launch()