tutor_dev / modules /vectorstore /vectorstore.py
XThomasBU
initial commit
d92c997
raw
history blame
3.08 kB
from modules.vectorstore.faiss import FaissVectorStore
from modules.vectorstore.chroma import ChromaVectorStore
from modules.vectorstore.colbert import ColbertVectorStore
from modules.vectorstore.raptor import RAPTORVectoreStore
from huggingface_hub import snapshot_download
import os
import shutil
class VectorStore:
    """Thin dispatcher around one configured vector-store backend.

    The backend class is picked from ``self.vectorstore_classes`` using
    ``config["vectorstore"]["db_option"]``. No backend is instantiated in
    ``__init__``; ``self.vectorstore`` stays ``None`` until
    ``_create_database`` or ``_load_database`` is called.
    """

    # db_option value whose backend takes raw documents (no embedding model).
    _DOCUMENT_LEVEL_OPTION = "RAGatouille"

    def __init__(self, config):
        """Store ``config`` and register the available backend classes.

        Args:
            config: Mapping read as ``config["vectorstore"]["db_option"]``
                and ``config["vectorstore"]["db_path"]`` by the methods below.
        """
        self.config = config
        self.vectorstore = None
        self.vectorstore_classes = {
            "FAISS": FaissVectorStore,
            "Chroma": ChromaVectorStore,
            "RAGatouille": ColbertVectorStore,
            "RAPTOR": RAPTORVectoreStore,
        }

    def _init_backend(self):
        """Instantiate the configured backend into ``self.vectorstore``.

        Shared by ``_create_database`` and ``_load_database`` so the
        lookup/validation logic lives in exactly one place.

        Returns:
            The ``db_option`` string that was resolved.

        Raises:
            ValueError: If ``db_option`` is not a key of
                ``self.vectorstore_classes``.
        """
        db_option = self.config["vectorstore"]["db_option"]
        vectorstore_class = self.vectorstore_classes.get(db_option)
        if vectorstore_class is None:
            raise ValueError(f"Invalid db_option: {db_option}")
        self.vectorstore = vectorstore_class(self.config)
        return db_option

    def _create_database(
        self,
        document_chunks,
        document_names,
        documents,
        document_metadata,
        embedding_model,
    ):
        """Instantiate the configured backend and build its database.

        The "RAGatouille" backend is called with
        ``(documents, document_names, document_metadata)``; every other
        backend is called with ``(document_chunks, embedding_model)``.

        Raises:
            ValueError: If the configured ``db_option`` is unknown.
        """
        db_option = self._init_backend()
        if db_option == self._DOCUMENT_LEVEL_OPTION:
            self.vectorstore.create_database(
                documents, document_names, document_metadata
            )
        else:
            self.vectorstore.create_database(document_chunks, embedding_model)

    def _load_database(self, embedding_model):
        """Instantiate the configured backend and load its database.

        ``embedding_model`` is forwarded to every backend except
        "RAGatouille", whose ``load_database`` is called without arguments.

        Returns:
            Whatever the backend's ``load_database`` returns.

        Raises:
            ValueError: If the configured ``db_option`` is unknown.
        """
        db_option = self._init_backend()
        if db_option == self._DOCUMENT_LEVEL_OPTION:
            return self.vectorstore.load_database()
        return self.vectorstore.load_database(embedding_model)

    def _load_from_HF(self, HF_PATH):
        """Download a Hugging Face dataset repo and copy it into the db folder.

        Args:
            HF_PATH: Repository id passed to ``snapshot_download`` as
                ``repo_id`` (with ``repo_type="dataset"``).
        """
        # Download the snapshot from Hugging Face Hub.
        # force_download=True requests a fresh download on every call.
        snapshot_path = snapshot_download(
            repo_id=HF_PATH,
            repo_type="dataset",
            force_download=True,
        )

        # Destination is <db_path>/db_<db_option>.
        # NOTE(review): presumably the directory the backends read from when
        # loading -- confirm against the backend implementations.
        target_path = os.path.join(
            self.config["vectorstore"]["db_path"],
            "db_" + self.config["vectorstore"]["db_option"],
        )
        os.makedirs(target_path, exist_ok=True)

        # Copy (not move) every top-level entry of the snapshot; existing
        # directories at the destination are merged into.
        for item in os.listdir(snapshot_path):
            source = os.path.join(snapshot_path, item)
            destination = os.path.join(target_path, item)
            if os.path.isdir(source):
                shutil.copytree(source, destination, dirs_exist_ok=True)
            else:
                shutil.copy2(source, destination)

    def _as_retriever(self):
        """Return the active backend's ``as_retriever()`` result."""
        return self.vectorstore.as_retriever()

    def _get_vectorstore(self):
        """Return the active backend instance, or ``None`` if none was built."""
        return self.vectorstore

    def __len__(self):
        """Return ``len()`` of the active backend."""
        return len(self.vectorstore)