ziyingsk committed
Commit b4853bc · verified · 1 parent: e16699e

Update app.py

Files changed (1)
  1. app.py +142 -138
app.py CHANGED
@@ -1,138 +1,142 @@
- import os
- import streamlit as st
- from dotenv import load_dotenv
- import itertools
- from pinecone import Pinecone
- from langchain_community.llms import HuggingFaceHub
- from langchain.chains import LLMChain
- from langchain_community.document_loaders import PyPDFDirectoryLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.prompts import PromptTemplate
- from sentence_transformers import SentenceTransformer
- import torch
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
- import logging
-
- # Set up environment, Pinecone is a database
- load_dotenv()  # Load the .env file
- cache_dir = os.getenv("CACHE_DIR")  # Directory for cache
- Huggingface_token = os.getenv("API_TOKEN")  # Huggingface API key
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))  # Database API key
- index = pc.Index(os.getenv("Index_Name"))  # Database index name
-
- # Initialize embedding model (LLM will be saved to cache_dir if assigned)
- embedding_model = "all-mpnet-base-v2"  # See https://www.sbert.net/docs/pretrained_models.html
-
- if cache_dir:
-     embedding = SentenceTransformer(embedding_model, cache_folder=cache_dir)
- else:
-     embedding = SentenceTransformer(embedding_model)
-
- # Read the PDF files, split them into chunks, and embed them
- def read_doc(file_path):
-     file_loader = PyPDFDirectoryLoader(file_path)
-     documents = file_loader.load()
-     return documents
-
- def chunk_data(docs, chunk_size=300, chunk_overlap=50):
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-     doc = text_splitter.split_documents(docs)
-     return doc
-
- # Save embeddings to database
- def chunks(iterable, batch_size=100):
-     """A helper function to break an iterable into chunks of size batch_size."""
-     it = iter(iterable)
-     chunk = tuple(itertools.islice(it, batch_size))
-     while chunk:
-         yield chunk
-         chunk = tuple(itertools.islice(it, batch_size))
-
- # Streamlit interface start, uploading file
- st.title("RAG-Anwendung (RAG Application)")
- st.caption("Diese Anwendung kann Ihnen helfen, kostenlos Fragen zu PDF-Dateien zu stellen. (This application can help you ask questions about PDF files for free.)")
-
- uploaded_file = st.file_uploader("Wählen Sie eine PDF-Datei, das Laden kann eine Weile dauern. (Choose a PDF file, loading might take a while.)", type="pdf")
- if uploaded_file is not None:
-     # Ensure the temp directory exists and is empty
-     temp_dir = "tempDir"
-     if os.path.exists(temp_dir):
-         for file in os.listdir(temp_dir):
-             file_path = os.path.join(temp_dir, file)
-             if os.path.isfile(file_path):
-                 os.remove(file_path)
-             elif os.path.isdir(file_path):
-                 os.rmdir(file_path)  # Only removes empty directories
-
-     os.makedirs(temp_dir, exist_ok=True)
-
-     # Save the uploaded file temporarily
-     temp_file_path = os.path.join(temp_dir, uploaded_file.name)
-     with open(temp_file_path, "wb") as f:
-         f.write(uploaded_file.getbuffer())
-     doc = read_doc(temp_dir + "/")
-     documents = chunk_data(docs=doc)
-     texts = [document.page_content for document in documents]
-     pdf_vectors = embedding.encode(texts)
-     vector_count = len(documents)
-     example_data_generator = map(lambda i: (f'id-{i}', pdf_vectors[i], {"text": texts[i]}), range(vector_count))
-     if 'ns1' in index.describe_index_stats()['namespaces']:
-         index.delete(delete_all=True, namespace='ns1')
-     for ids_vectors_chunk in chunks(example_data_generator, batch_size=100):
-         index.upsert(vectors=ids_vectors_chunk, namespace='ns1')
-
- # Search query-related context
- sample_query = st.text_input("Stellen Sie eine Frage zu dem PDF: (Ask a question related to the PDF:)")
- if st.button("Abschicken (Submit)"):
-     if uploaded_file is not None and sample_query:
-         query_vector = embedding.encode(sample_query).tolist()
-         query_search = index.query(vector=query_vector, top_k=5, include_metadata=True)
-
-         matched_contents = [match["metadata"]["text"] for match in query_search["matches"]]
-
-         # Rerank
-         rerank_model = "BAAI/bge-reranker-v2-m3"
-         if cache_dir:
-             tokenizer = AutoTokenizer.from_pretrained(rerank_model, cache_dir=cache_dir)
-             model = AutoModelForSequenceClassification.from_pretrained(rerank_model, cache_dir=cache_dir)
-         else:
-             tokenizer = AutoTokenizer.from_pretrained(rerank_model)
-             model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
-         model.eval()
-
-         pairs = [[sample_query, content] for content in matched_contents]
-         with torch.no_grad():
-             inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=300)
-             scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
-         matched_contents = [content for _, content in sorted(zip(scores, matched_contents), key=lambda x: x[0], reverse=True)]
-         matched_contents = matched_contents[0]
-         del model
-         torch.cuda.empty_cache()
-
-         # Display matched contents after reranking
-         st.markdown("### Möglicherweise relevante Abschnitte aus dem PDF (Potentially relevant sections from the PDF):")
-         st.write(matched_contents)
-
-         # Get answer
-         query_model = "meta-llama/Meta-Llama-3-8B-Instruct"
-         llm_huggingface = HuggingFaceHub(repo_id=query_model, model_kwargs={"temperature": 0.7, "max_length": 500})
-
-         prompt_template = PromptTemplate(input_variables=['query', 'context'], template="{query}, Beim Beantworten der Frage bitte mit dem Wort 'Antwort:' beginnen,unter Berücksichtigung des folgenden Kontexts: \n\n{context}")
-
-         prompt = prompt_template.format(query=sample_query, context=matched_contents)
-         chain = LLMChain(llm=llm_huggingface, prompt=prompt_template)
-         result = chain.run(query=sample_query, context=matched_contents)
-
-         # Polish answer
-         result = result.replace(prompt, "")
-         special_start = "Antwort:"
-         start_index = result.find(special_start)
-         if start_index != -1:
-             result = result[start_index + len(special_start):].lstrip()
-         else:
-             result = result.lstrip()
-
-         # Display the final answer with a note about limitations
-         st.markdown("### Antwort (Answer):")
-         st.write(result)
-         st.markdown("**Hinweis:** Aufgrund begrenzter Rechenleistung kann das große Sprachmodell möglicherweise keine vollständige Antwort liefern. (Note: Due to limited computational power, the large language model might not be able to provide a complete response.)")
+ import os
+ import streamlit as st
+ from dotenv import load_dotenv
+ import itertools
+ from pinecone import Pinecone
+ from langchain_community.llms import HuggingFaceHub
+ from langchain.chains import LLMChain
+ from langchain_community.document_loaders import PyPDFDirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.prompts import PromptTemplate
+ from sentence_transformers import SentenceTransformer
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ import logging
+
+ # Set up environment, Pinecone is a database
+ st.write(
+     "Have environment variables been set:",
+     os.environ["API_TOKEN"] == st.secrets["HUGGINGFACEHUB_API_TOKEN"],
+     os.environ["PINECONE_API_KEY"] == st.secrets["PINECONE_API_KEY"],
+     os.environ["Index_Name"] == st.secrets["Index_Name"])
+ cache_dir = os.getenv("CACHE_DIR")  # Directory for cache
+ Huggingface_token = os.getenv("API_TOKEN")  # Huggingface API key
+ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))  # Database API key
+ index = pc.Index(os.getenv("Index_Name"))  # Database index name
+
+ # Initialize embedding model (LLM will be saved to cache_dir if assigned)
+ embedding_model = "all-mpnet-base-v2"  # See https://www.sbert.net/docs/pretrained_models.html
+
+ if cache_dir:
+     embedding = SentenceTransformer(embedding_model, cache_folder=cache_dir)
+ else:
+     embedding = SentenceTransformer(embedding_model)
+
+ # Read the PDF files, split them into chunks, and embed them
+ def read_doc(file_path):
+     file_loader = PyPDFDirectoryLoader(file_path)
+     documents = file_loader.load()
+     return documents
+
+ def chunk_data(docs, chunk_size=300, chunk_overlap=50):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     doc = text_splitter.split_documents(docs)
+     return doc
+
+ # Save embeddings to database
+ def chunks(iterable, batch_size=100):
+     """A helper function to break an iterable into chunks of size batch_size."""
+     it = iter(iterable)
+     chunk = tuple(itertools.islice(it, batch_size))
+     while chunk:
+         yield chunk
+         chunk = tuple(itertools.islice(it, batch_size))
+
+ # Streamlit interface start, uploading file
+ st.title("RAG-Anwendung (RAG Application)")
+ st.caption("Diese Anwendung kann Ihnen helfen, kostenlos Fragen zu PDF-Dateien zu stellen. (This application can help you ask questions about PDF files for free.)")
+
+ uploaded_file = st.file_uploader("Wählen Sie eine PDF-Datei, das Laden kann eine Weile dauern. (Choose a PDF file, loading might take a while.)", type="pdf")
+ if uploaded_file is not None:
+     # Ensure the temp directory exists and is empty
+     temp_dir = "tempDir"
+     if os.path.exists(temp_dir):
+         for file in os.listdir(temp_dir):
+             file_path = os.path.join(temp_dir, file)
+             if os.path.isfile(file_path):
+                 os.remove(file_path)
+             elif os.path.isdir(file_path):
+                 os.rmdir(file_path)  # Only removes empty directories
+
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Save the uploaded file temporarily
+     temp_file_path = os.path.join(temp_dir, uploaded_file.name)
+     with open(temp_file_path, "wb") as f:
+         f.write(uploaded_file.getbuffer())
+     doc = read_doc(temp_dir + "/")
+     documents = chunk_data(docs=doc)
+     texts = [document.page_content for document in documents]
+     pdf_vectors = embedding.encode(texts)
+     vector_count = len(documents)
+     example_data_generator = map(lambda i: (f'id-{i}', pdf_vectors[i], {"text": texts[i]}), range(vector_count))
+     if 'ns1' in index.describe_index_stats()['namespaces']:
+         index.delete(delete_all=True, namespace='ns1')
+     for ids_vectors_chunk in chunks(example_data_generator, batch_size=100):
+         index.upsert(vectors=ids_vectors_chunk, namespace='ns1')
+
+ # Search query-related context
+ sample_query = st.text_input("Stellen Sie eine Frage zu dem PDF: (Ask a question related to the PDF:)")
+ if st.button("Abschicken (Submit)"):
+     if uploaded_file is not None and sample_query:
+         query_vector = embedding.encode(sample_query).tolist()
+         query_search = index.query(vector=query_vector, top_k=5, include_metadata=True)
+
+         matched_contents = [match["metadata"]["text"] for match in query_search["matches"]]
+
+         # Rerank
+         rerank_model = "BAAI/bge-reranker-v2-m3"
+         if cache_dir:
+             tokenizer = AutoTokenizer.from_pretrained(rerank_model, cache_dir=cache_dir)
+             model = AutoModelForSequenceClassification.from_pretrained(rerank_model, cache_dir=cache_dir)
+         else:
+             tokenizer = AutoTokenizer.from_pretrained(rerank_model)
+             model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+         model.eval()
+
+         pairs = [[sample_query, content] for content in matched_contents]
+         with torch.no_grad():
+             inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=300)
+             scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
+         matched_contents = [content for _, content in sorted(zip(scores, matched_contents), key=lambda x: x[0], reverse=True)]
+         matched_contents = matched_contents[0]
+         del model
+         torch.cuda.empty_cache()
+
+         # Display matched contents after reranking
+         st.markdown("### Möglicherweise relevante Abschnitte aus dem PDF (Potentially relevant sections from the PDF):")
+         st.write(matched_contents)
+
+         # Get answer
+         query_model = "meta-llama/Meta-Llama-3-8B-Instruct"
+         llm_huggingface = HuggingFaceHub(repo_id=query_model, model_kwargs={"temperature": 0.7, "max_length": 500})
+
+         prompt_template = PromptTemplate(input_variables=['query', 'context'], template="{query}, Beim Beantworten der Frage bitte mit dem Wort 'Antwort:' beginnen,unter Berücksichtigung des folgenden Kontexts: \n\n{context}")
+
+         prompt = prompt_template.format(query=sample_query, context=matched_contents)
+         chain = LLMChain(llm=llm_huggingface, prompt=prompt_template)
+         result = chain.run(query=sample_query, context=matched_contents)
+
+         # Polish answer
+         result = result.replace(prompt, "")
+         special_start = "Antwort:"
+         start_index = result.find(special_start)
+         if start_index != -1:
+             result = result[start_index + len(special_start):].lstrip()
+         else:
+             result = result.lstrip()
+
+         # Display the final answer with a note about limitations
+         st.markdown("### Antwort (Answer):")
+         st.write(result)
+         st.markdown("**Hinweis:** Aufgrund begrenzter Rechenleistung kann das große Sprachmodell möglicherweise keine vollständige Antwort liefern. (Note: Due to limited computational power, the large language model might not be able to provide a complete response.)")