Luciferalive committed
Commit 806e7bf · verified · 1 Parent(s): 437db54

Create app.py

Files changed (1):
  app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
+ import streamlit as st
+ from langchain.chains import LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain_community.llms import HuggingFaceEndpoint
+ from pdfminer.high_level import extract_text
+ import docx2txt
+ import io
+ import re
+ from typing import List
+ from langchain.vectorstores import Chroma
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import SentenceTransformerEmbeddings
+ from sentence_transformers import SentenceTransformer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import numpy as np
+ import os
+ import boto3
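+ 
+ # Dependency note (assumed package set): streamlit, langchain,
+ # langchain-community, pdfminer.six, docx2txt, chromadb,
+ # sentence-transformers, scikit-learn, numpy, and boto3.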
+ 
+ # AWS credentials are read from the environment; never hardcode secrets in source
+ access_key = os.getenv("AWS_ACCESS_KEY_ID")
+ secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+ 
+ # S3 bucket details
+ bucket_name = 'sentinelx-prod'
+ prefix = 'LOTO/Documents/LOTOFormDocuments/'
+ 
+ HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+ 
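+ # Note: boto3 can also resolve credentials on its own (environment variables,
+ # the shared credentials file, or an IAM role), so explicit keys are only
+ # needed when that default chain isn't configured.
+ 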
+ def extract_text_from_pdf(pdf_content):
+     return extract_text(io.BytesIO(pdf_content))
+ 
+ def extract_text_from_doc(doc_content):
+     return docx2txt.process(io.BytesIO(doc_content))
+ 
+ def preprocess_text(text):
+     # Normalize whitespace, strip non-ASCII characters and punctuation, lowercase
+     text = text.replace('\n', ' ').replace('\r', ' ')
+     text = re.sub(r'[^\x00-\x7F]+', ' ', text)
+     text = text.lower()
+     text = re.sub(r'[^\w\s]', '', text)
+     text = re.sub(r'\s+', ' ', text).strip()
+     return text
+ 
+ def process_files(file_contents: List[bytes]):
+     all_text = ""
+     for file_content in file_contents:
+         # Route on the PDF magic bytes; everything else is treated as a Word document
+         if file_content.startswith(b'%PDF'):
+             extracted_text = extract_text_from_pdf(file_content)
+         else:
+             extracted_text = extract_text_from_doc(file_content)
+         preprocessed_text = preprocess_text(extracted_text)
+         all_text += preprocessed_text + " "
+     return all_text
+ 
+ def compute_cosine_similarity_scores(query, retrieved_docs):
+     model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
+     # Encode to numpy arrays so sklearn's cosine_similarity can consume them
+     query_embedding = model.encode([query])
+     doc_embeddings = model.encode(retrieved_docs)
+     # cosine_similarity returns a (1, n_docs) score matrix
+     cosine_scores = cosine_similarity(query_embedding, doc_embeddings)
+     readable_scores = [{"doc": doc, "score": float(score)} for doc, score in zip(retrieved_docs, cosine_scores.flatten())]
+     return readable_scores
+ 
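+ # Example usage (hypothetical inputs; scores are illustrative only):
+ #   compute_cosine_similarity_scores("lockout tagout procedure",
+ #                                    ["isolate energy sources", "visitor policy"])
+ #   -> [{"doc": "isolate energy sources", "score": 0.52},
+ #       {"doc": "visitor policy", "score": 0.08}]
+ 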
+ def answer_query_with_similarity(query):
+     try:
+         # Fetch files from S3
+         s3 = boto3.client('s3', aws_access_key_id=access_key, aws_secret_access_key=secret_key)
+         objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+ 
+         file_contents = []
+         for obj in objects.get('Contents', []):
+             if not obj['Key'].endswith('/'):  # Skip directories
+                 response = s3.get_object(Bucket=bucket_name, Key=obj['Key'])
+                 file_content = response['Body'].read()
+                 file_contents.append(file_content)
+ 
+         all_text = process_files(file_contents)
+ 
+         embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         texts = text_splitter.split_text(all_text)
+ 
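+         # Note: chunk_overlap=200 makes neighboring chunks share context, so a
+         # sentence split at a chunk boundary still appears whole in at least
+         # one chunk before it is indexed below.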
+         # Build the vector store once and query it directly; reloading it from
+         # the persist directory on the same request is redundant
+         vector_store = Chroma.from_texts(texts, embeddings, collection_metadata={"hnsw:space": "cosine"}, persist_directory="stores/insurance_cosine")
+         print("Vector DB Successfully Created!")
+ 
+         docs = vector_store.similarity_search(query)
+         print(f"\n\nDocuments retrieved: {len(docs)}")
+ 
+         if not docs:
+             print("No documents match the query.")
+             return None
+ 
+         docs_content = [doc.page_content for doc in docs]
+         for i, content in enumerate(docs_content, start=1):
+             print(f"\nDocument {i}: {content[:200]}...")  # Preview first 200 chars
+ 
+         cosine_similarity_scores = compute_cosine_similarity_scores(query, docs_content)
+         for score in cosine_similarity_scores:
+             print(f"\nDocument Score: {score['score']}")
+ 
+         all_docs_content = " ".join(docs_content)
+ 
+         template = """
+         ### [INST] Instruction: You are an AI assistant named Goose. Your purpose is to provide accurate, relevant, and helpful information to users in a friendly, warm, and supportive manner, similar to ChatGPT. When responding to queries, keep the following guidelines in mind:
+         When someone says hi or makes small talk, respond with a single sentence.
+         Retrieve relevant information from your knowledge base to formulate accurate and informative responses.
+         Always maintain a positive, friendly, and encouraging tone in your interactions with users.
+         Write crisp, clear answers; don't include unnecessary content.
+         Answer only the question asked; don't hallucinate or print any preamble.
+         After providing the answer, always ask in the next paragraph whether any other help is needed.
+         Bullet-point formatting is preferred.
+         Remember, your goal is to be a reliable, friendly, and supportive AI assistant that provides accurate information while creating a positive user experience, just like ChatGPT. Adapt your communication style to best suit each user's needs and preferences.
+         ### Docs : {docs}
+         ### Question : {question}
+         """
+         # Keep {docs} and {question} as template variables and fill them at run time
+         prompt = PromptTemplate(template=template, input_variables=["docs", "question"])
+ 
+         repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+         llm = HuggingFaceEndpoint(repo_id=repo_id,
+                                   temperature=0.1,
+                                   huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
+                                   top_p=0.15,
+                                   max_new_tokens=256,
+                                   repetition_penalty=1.1)
+         llm_chain = LLMChain(prompt=prompt, llm=llm)
+ 
+         answer = llm_chain.run(docs=all_docs_content, question=query)
+         cleaned_answer = answer.split("Answer:")[-1].strip()
+         print(f"\n\nAnswer: {cleaned_answer}")
+ 
+         return cleaned_answer
+     except Exception as e:
+         print("An error occurred while getting the answer: ", str(e))
+         return None
+ 
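+ # Note: answer_query_with_similarity re-downloads every document from S3 and
+ # rebuilds the vector store on each call, which is slow; a longer-lived app
+ # would likely build the store once at startup (e.g. with st.cache_resource)
+ # and only embed the incoming query.
+ 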
+ def main():
+     st.title("Document Query App")
+ 
+     query = st.text_input("Enter your query:")
+ 
+     if st.button("Get Answer"):
+         if query:
+             response = answer_query_with_similarity(query)
+             if response:
+                 st.write("Answer:", response)
+             else:
+                 st.write("No answer found.")
+         else:
+             st.write("Please provide a query.")
+ 
+ if __name__ == "__main__":
+     main()
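+ 
+ # To run locally (assumed environment):
+ #   export AWS_ACCESS_KEY_ID=... AWS_SECRET_ACCESS_KEY=...
+ #   export HUGGINGFACEHUB_API_TOKEN=...
+ #   streamlit run app.py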