|
from modules.vectorstore.faiss import FaissVectorStore |
|
from modules.vectorstore.chroma import ChromaVectorStore |
|
from modules.vectorstore.colbert import ColbertVectorStore |
|
from modules.vectorstore.raptor import RAPTORVectoreStore |
|
from huggingface_hub import snapshot_download |
|
import os |
|
import shutil |
|
|
|
|
|
class VectorStore:
    """Facade that selects and drives one concrete vector-store backend.

    The backend is chosen by ``config["vectorstore"]["db_option"]`` and must be
    one of the keys of ``self.vectorstore_classes`` ("FAISS", "Chroma",
    "RAGatouille", "RAPTOR"). The backend instance is only created by
    ``_create_database`` or ``_load_database``; until then
    ``self.vectorstore`` is ``None``.
    """

    def __init__(self, config):
        """Store *config* and register the supported backends.

        Args:
            config: Subscriptable configuration. This class reads
                ``config["vectorstore"]["db_option"]`` and, in
                ``_load_from_HF``, ``config["vectorstore"]["db_path"]``. The
                whole object is also passed to the backend constructor.
        """
        self.config = config
        # Backend instance; populated by _create_database / _load_database.
        self.vectorstore = None
        # Maps the configured db_option string to its backend class.
        self.vectorstore_classes = {
            "FAISS": FaissVectorStore,
            "Chroma": ChromaVectorStore,
            "RAGatouille": ColbertVectorStore,
            "RAPTOR": RAPTORVectoreStore,
        }

    def _instantiate_backend(self):
        """Create the configured backend and store it on ``self.vectorstore``.

        Returns:
            The ``db_option`` string, so callers can route backend-specific
            arguments.

        Raises:
            ValueError: If ``db_option`` has no registered backend class.
        """
        db_option = self.config["vectorstore"]["db_option"]
        vectorstore_class = self.vectorstore_classes.get(db_option)
        if not vectorstore_class:
            raise ValueError(f"Invalid db_option: {db_option}")
        self.vectorstore = vectorstore_class(self.config)
        return db_option

    def _require_vectorstore(self):
        """Return the backend instance, failing clearly if none exists yet.

        Raises:
            RuntimeError: If neither ``_create_database`` nor
                ``_load_database`` has been called.
        """
        if self.vectorstore is None:
            raise RuntimeError(
                "Vector store is not initialised; call _create_database() "
                "or _load_database() first."
            )
        return self.vectorstore

    def _create_database(
        self,
        document_chunks,
        document_names,
        documents,
        document_metadata,
        embedding_model,
    ):
        """Instantiate the configured backend and build a new database.

        Args:
            document_chunks: Passed to non-RAGatouille backends.
            document_names: Passed to the RAGatouille backend only.
            documents: Passed to the RAGatouille backend only.
            document_metadata: Passed to the RAGatouille backend only.
            embedding_model: Passed to non-RAGatouille backends.

        Raises:
            ValueError: If the configured ``db_option`` is not registered.
        """
        db_option = self._instantiate_backend()
        if db_option == "RAGatouille":
            # This backend takes the raw documents with their names and
            # metadata rather than pre-built chunks plus an embedding model.
            self.vectorstore.create_database(
                documents, document_names, document_metadata
            )
        else:
            self.vectorstore.create_database(document_chunks, embedding_model)

    def _load_database(self, embedding_model):
        """Instantiate the configured backend and load an existing database.

        Args:
            embedding_model: Forwarded to every backend except RAGatouille,
                whose ``load_database`` is called with no arguments.

        Returns:
            Whatever the backend's ``load_database`` returns.

        Raises:
            ValueError: If the configured ``db_option`` is not registered.
        """
        db_option = self._instantiate_backend()
        if db_option == "RAGatouille":
            return self.vectorstore.load_database()
        return self.vectorstore.load_database(embedding_model)

    def _load_from_HF(self, HF_PATH):
        """Download a Hugging Face dataset repo into the local database folder.

        The snapshot is always re-downloaded (``force_download=True``). Its
        top-level entries are then copied into
        ``<db_path>/db_<db_option>``, which is created if missing. Existing
        directories are merged and same-named files are overwritten. This
        method does not create or load ``self.vectorstore``.

        Args:
            HF_PATH: Repo id of the Hugging Face *dataset* repository.
        """
        snapshot_path = snapshot_download(
            repo_id=HF_PATH,
            repo_type="dataset",
            force_download=True,
        )

        target_path = os.path.join(
            self.config["vectorstore"]["db_path"],
            "db_" + self.config["vectorstore"]["db_option"],
        )
        os.makedirs(target_path, exist_ok=True)

        # Copy entry by entry so an existing target directory is merged into
        # rather than replaced.
        for item in os.listdir(snapshot_path):
            source = os.path.join(snapshot_path, item)
            destination = os.path.join(target_path, item)
            if os.path.isdir(source):
                shutil.copytree(source, destination, dirs_exist_ok=True)
            else:
                shutil.copy2(source, destination)

    def _as_retriever(self):
        """Return the backend's retriever (requires an initialised backend)."""
        return self._require_vectorstore().as_retriever()

    def _get_vectorstore(self):
        """Return the backend instance, or ``None`` if not created yet."""
        return self.vectorstore

    def __len__(self):
        """Return the backend's length (requires an initialised backend)."""
        return len(self._require_vectorstore())
|
|