PoliticsToYou / src /vectordatabase.py
TomData's picture
added flexible vectorstore
85df319
raw
history blame
5.46 kB
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from faiss import IndexFlatL2
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.embeddings import SentenceTransformerEmbeddings
import functools
import pandas as pd
import os
#from dotenv import load_dotenv
#Load environmental variables from .env-file
#load_dotenv()
# Load documents to create a vectorstore later
def load_documents(df, *, page_content_column='speech_content', chunk_size=1024, chunk_overlap=32):
    """
    Load rows of a DataFrame as LangChain documents and split them into chunks.

    Parameters:
    ----------
    df : pandas.DataFrame
        Source data. The column named by `page_content_column` becomes the document
        text; the other columns are handed to DataFrameLoader as well (presumably
        kept as document metadata -- confirm against the LangChain docs).
    page_content_column : str, keyword-only
        Column holding the text to embed. Defaults to 'speech_content'.
    chunk_size : int, keyword-only
        Maximum chunk length in characters (length is measured with `len`). Defaults to 1024.
    chunk_overlap : int, keyword-only
        Number of characters shared by consecutive chunks. Defaults to 32.

    Returns:
    -------
    list of Document
        The chunks produced by RecursiveCharacterTextSplitter.
    """
    # TODO: build one initial vector store containing all documents with this function.
    loader = DataFrameLoader(data_frame=df, page_content_column=page_content_column)
    data = loader.load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,  # chunk sizes are counted in characters
        is_separator_regex=False,
    )
    return splitter.split_documents(documents=data)
# NOTE(review): a plain @functools.lru_cache() cannot be used here because `inputs`
# is a list (unhashable); cache on tuple(inputs) instead if loading proves slow.
def get_vectorstore(inputs, embeddings):
    """
    Combine one or more local FAISS vector stores into a single vector store.

    Parameters:
    ----------
    inputs : list of str
        Labels selecting which stores to combine, e.g. "20. Legislaturperiode".
        The text before the first "." is used as the index prefix, so
        "20. Legislaturperiode" loads the index "20_legislature". If the keyword
        "All" appears anywhere in the list, the comprehensive store
        "speeches_1949_09_12" is loaded and returned on its own.
    embeddings : Embeddings
        LangChain embeddings object used to load the stores. Its `embed_query`
        output length also sets the dimensionality of the empty store that the
        selected indices are merged into.

    Returns:
    -------
    FAISS
        A FAISS vector store containing the selected indices. An empty `inputs`
        list yields an empty store.

    Notes:
    -----
    - Index files are read from "./src/FAISS", relative to the working directory.
    - `FAISS.load_local` is called with `allow_dangerous_deserialization=True`,
      which permits loading potentially unsafe serialized objects. Only point
      this at index files from a trusted source.
    """
    folder_path = "./src/FAISS"

    # "All" short-circuits to the pre-built store covering every speech.
    if "All" in inputs:
        return FAISS.load_local(folder_path=folder_path, index_name="speeches_1949_09_12",
                                embeddings=embeddings, allow_dangerous_deserialization=True)

    # Empty store whose vector size matches the embedding model, so the selected
    # legislature stores can be merged into it.
    dimensions: int = len(embeddings.embed_query("dummy"))
    db = FAISS(
        embedding_function=embeddings,
        index=IndexFlatL2(dimensions),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )

    # Labels look like "20. Legislaturperiode", "19. Legislaturperiode", ...
    for label in inputs:
        prefix = label.split(".")[0]
        local_db = FAISS.load_local(folder_path=folder_path, index_name=f"{prefix}_legislature",
                                    embeddings=embeddings, allow_dangerous_deserialization=True)
        db.merge_from(local_db)
    return db
def RAG(llm, prompt, db, question):
    """
    Answer `question` with retrieval-augmented generation.

    Documents retrieved from `db` (through its default retriever) are stuffed
    into `prompt` together with the question and sent to `llm`.

    Parameters:
    ----------
    llm : language model passed to `create_stuff_documents_chain`.
    prompt : ChatPromptTemplate used to combine context and question.
    db : vector store providing `as_retriever()`.
    question : str, sent to the chain under the "input" key.

    Returns:
    -------
    The raw result of the retrieval chain's `invoke` (presumably a dict with
    "answer" and "context" entries -- confirm against the LangChain docs).
    """
    combine_docs = create_stuff_documents_chain(llm=llm, prompt=prompt)
    chain = create_retrieval_chain(db.as_retriever(), combine_docs)
    return chain.invoke({"input": question})
#########
# Dynamically loading vector_db
##########
def get_similar_vectorstore(start_date, end_date, party, base_path=os.path.join('src', 'FAISS')):
    """
    Find stored FAISS indices that are candidates for a requested period and party.

    Work in progress: this currently returns a table of candidate index files
    rather than a loaded vector store.

    Only files named like ``<prefix>_<start>_<end>_<party>.faiss`` are considered.
    Other entries are skipped, including non-.faiss files, directories and names
    with fewer than four "_"-separated parts (such as ``20_legislature.faiss``).

    Parameters:
    ----------
    start_date : str
        Requested start of the period. Dates are compared as plain strings, so
        they must share the file names' zero-padded, most-significant-first
        format (e.g. "YYYY-MM-DD") -- TODO confirm the naming scheme.
    end_date : str
        Requested end of the period (not used yet).
    party : str
        Party label. It must equal the fourth "_"-separated part of the file name.
    base_path : str
        Directory containing the index files. The default joins "src" and "FAISS"
        with the OS path separator.

    Returns:
    -------
    pandas.DataFrame
        Columns "file_name", "start_date", "end_date", "date_diff", with one row per
        file whose party matches and whose start date is <= `start_date`, ordered
        by file name. "date_diff" is not computed yet and is left as None.
    """
    # TODO: take end_date into account and compute date_diff.
    rows = []
    for entry in sorted(os.listdir(base_path)):
        stem, ext = os.path.splitext(entry)
        if ext != ".faiss":
            continue
        parts = stem.split("_")
        if len(parts) < 4:
            # Not a <prefix>_<start>_<end>_<party> file; cannot extract metadata.
            continue
        file_start_date, file_end_date, file_party = parts[1], parts[2], parts[3]
        if file_party == party and file_start_date <= start_date:
            rows.append((stem, file_start_date, file_end_date, None))
    return pd.DataFrame(rows, columns=["file_name", "start_date", "end_date", "date_diff"])