from modules.vectorstore.faiss import FaissVectorStore from modules.vectorstore.chroma import ChromaVectorStore from modules.vectorstore.colbert import ColbertVectorStore from modules.vectorstore.raptor import RAPTORVectoreStore from huggingface_hub import snapshot_download import os import shutil class VectorStore: def __init__(self, config): self.config = config self.vectorstore = None self.vectorstore_classes = { "FAISS": FaissVectorStore, "Chroma": ChromaVectorStore, "RAGatouille": ColbertVectorStore, "RAPTOR": RAPTORVectoreStore, } def _create_database( self, document_chunks, document_names, documents, document_metadata, embedding_model, ): db_option = self.config["vectorstore"]["db_option"] vectorstore_class = self.vectorstore_classes.get(db_option) if not vectorstore_class: raise ValueError(f"Invalid db_option: {db_option}") self.vectorstore = vectorstore_class(self.config) if db_option == "RAGatouille": self.vectorstore.create_database( documents, document_names, document_metadata ) else: self.vectorstore.create_database(document_chunks, embedding_model) def _load_database(self, embedding_model): db_option = self.config["vectorstore"]["db_option"] vectorstore_class = self.vectorstore_classes.get(db_option) if not vectorstore_class: raise ValueError(f"Invalid db_option: {db_option}") self.vectorstore = vectorstore_class(self.config) if db_option == "RAGatouille": return self.vectorstore.load_database() else: return self.vectorstore.load_database(embedding_model) def _load_from_HF(self, HF_PATH): # Download the snapshot from Hugging Face Hub # Note: Download goes to the cache directory snapshot_path = snapshot_download( repo_id=HF_PATH, repo_type="dataset", force_download=True, ) # Move the downloaded files to the desired directory target_path = os.path.join( self.config["vectorstore"]["db_path"], "db_" + self.config["vectorstore"]["db_option"], ) # Create target path if it doesn't exist os.makedirs(target_path, exist_ok=True) # move all files and directories from snapshot_path to target_path # target path is used while loading the database for item in os.listdir(snapshot_path): s = os.path.join(snapshot_path, item) d = os.path.join(target_path, item) if os.path.isdir(s): shutil.copytree(s, d, dirs_exist_ok=True) else: shutil.copy2(s, d) def _as_retriever(self): return self.vectorstore.as_retriever() def _get_vectorstore(self): return self.vectorstore def __len__(self): return self.vectorstore.__len__()