Pedrampedram committed on
Commit
37e6628
1 Parent(s): 8bb4085

Upload 4 files

Files changed (5)
  1. .gitattributes +1 -0
  2. app.py +18 -0
  3. dataset.tsv +3 -0
  4. question_processing.py +91 -0
  5. requirements.txt +5 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ dataset.tsv filter=lfs diff=lfs merge=lfs -text
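
The added rule routes dataset.tsv through Git LFS, so the repository stores only a small pointer file in its place. That is why the dataset.tsv entry below shows an LFS pointer (an oid plus a size of 28,877,451 bytes, roughly 29 MB) rather than the TSV rows themselves.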
app.py ADDED
@@ -0,0 +1,18 @@
+ import streamlit as st
+ from question_processing import process_question
+
+ st.title("Question Answering System")
+ st.write("Enter your question and get an answer from the pre-trained model.")
+
+ # Input field for the user's question
+ question = st.text_input("Please enter your question:")
+
+ # Process the question and display the answer(s) when the user clicks the "Submit" button
+ if st.button("Submit"):
+     if question:
+         answers = process_question(question)
+         for answer in answers:
+             st.write("Answer:", answer)
+             st.write("---")
+     else:
+         st.write("Please enter a question.")
dataset.tsv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e88b7f47f3494171367face846d1dcaf2710854870b076d6d419b8bae720bf1
+ size 28877451
question_processing.py ADDED
@@ -0,0 +1,91 @@
+ # Import necessary libraries
+ import os
+ import textwrap
+ import pandas as pd
+ from langchain import HuggingFaceHub
+ from langchain.document_loaders import TextLoader
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain.chains.question_answering import load_qa_chain
+ from transformers import AutoTokenizer
+
+ def wrap_text_preserve_newlines(text, width=110):
+     lines = text.split('\n')
+     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
+     wrapped_text = '\n'.join(wrapped_lines)
+     return wrapped_text
+
+ def split_into_chunks(text, tokenizer, max_tokens=500):
+     tokens = tokenizer.encode(text, return_tensors="pt").squeeze()
+     token_chunks = []
+
+     current_chunk = []
+     current_chunk_len = 0
+     for token in tokens:
+         token_len = len(tokenizer.decode(token.item()))
+         if current_chunk_len + token_len + 1 > max_tokens:
+             token_chunks.append(tokenizer.decode(current_chunk))
+             current_chunk = []
+             current_chunk_len = 0
+         current_chunk.append(token.item())
+         current_chunk_len += token_len + 1
+
+     if current_chunk:
+         token_chunks.append(tokenizer.decode(current_chunk))
+
+     return token_chunks
+
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
+
+ class TextDocument:
+     def __init__(self, content, id, metadata=None):
+         self.page_content = content
+         self.metadata = metadata if metadata is not None else {}
+         self.metadata['id'] = id
+
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = "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"
+
+ data_frame = pd.read_csv("dataset.tsv", sep="\t", nrows=1000)
+ data = data_frame.to_dict(orient="records")
+ documents = [TextDocument(content=str(item["answer"]), id=item["id"]) for item in data]
+ text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)
+ docs = text_splitter.split_documents(documents)
+ embeddings = HuggingFaceEmbeddings()
+ db = FAISS.from_documents(docs, embeddings)
+
+ llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.75, "max_length": 2048})
+ chain = load_qa_chain(llm, chain_type="refine")
+
+ def truncate_answer(answer, question, tokenizer, max_total_tokens=1000):
+     special_tokens = 2
+     question_tokens = len(tokenizer.encode(question, return_tensors="pt").squeeze())
+     max_answer_tokens = max_total_tokens - question_tokens - special_tokens
+     answer_tokens = tokenizer.encode(answer, return_tensors="pt").squeeze()
+     truncated_answer = tokenizer.decode(answer_tokens[:max_answer_tokens])
+     return truncated_answer
+
+ def combined_length_exceeds_limit(question, answer, tokenizer, model_token_limit=1024):
+     special_tokens = 2
+     question_tokens = len(tokenizer.encode(question, return_tensors="pt").squeeze())
+     answer_tokens = len(tokenizer.encode(answer, return_tensors="pt").squeeze())
+     return question_tokens + answer_tokens > (model_token_limit - special_tokens)
+
+ def process_question(query):
+     answers = []
+
+     docs = db.similarity_search(query)
+     most_similar_doc = docs[0]
+     print(f"Most similar answer: \n{wrap_text_preserve_newlines(str(most_similar_doc.page_content))}\n")
+
+     query_chunks = split_into_chunks(query, tokenizer, max_tokens=500)
+
+     for query_chunk in query_chunks:
+         if combined_length_exceeds_limit(query_chunk, str(docs[0].page_content), tokenizer):
+             print("The combined length of the question and answer exceeds the model's token limit.")
+         else:
+             truncated_answer = truncate_answer(str(docs[0].page_content), query_chunk, tokenizer, max_total_tokens=500)
+             result = chain.run(input_documents=[TextDocument(content=truncated_answer, id=docs[0].metadata['id'])], question=query_chunk)
+             answers.append(result)
+
+     return answers
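
Two details worth noting. First, despite its name and its max_tokens parameter, split_into_chunks accumulates len(tokenizer.decode(token)) per token, so it actually budgets by decoded character count, not token count; truncate_answer, by contrast, really does count tokens, reserving max_total_tokens - question_tokens - 2 tokens for the answer (the 2 covering special tokens). Second, the HUGGINGFACEHUB_API_TOKEN is hard-coded at module level; it would normally come from an environment variable or a Space secret. A rough sketch of checking the chunking behavior (hypothetical; it uses flan-t5-small as a lighter stand-in for flan-t5-xxl, and importing question_processing runs its full module-level setup):

    # Hypothetical check of split_into_chunks' length accounting.
    from transformers import AutoTokenizer
    from question_processing import split_into_chunks

    tok = AutoTokenizer.from_pretrained("google/flan-t5-small")  # stand-in tokenizer
    chunks = split_into_chunks("word " * 400, tok, max_tokens=500)
    print([len(c) for c in chunks])  # chunk sizes cluster near 500 characters, not 500 tokens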
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers==4.27.1
+ torch>=1.13.1
+ datasets==2.10.1
+
+ tqdm==4.65.0
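
Note that requirements.txt does not cover everything the code imports: app.py needs streamlit, and question_processing.py needs pandas and langchain, plus (typically) sentence-transformers for HuggingFaceEmbeddings, faiss-cpu for the FAISS store, and huggingface_hub for the hosted flan-t5-xxl endpoint. Something like `pip install streamlit pandas langchain sentence-transformers faiss-cpu huggingface_hub` would be needed on top of the pinned packages above.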