File size: 3,084 Bytes
a2ac5f7 f2daaee f2beb6a a2ac5f7 ea7b686 f2daaee ea7b686 a2ac5f7 ea7b686 a2ac5f7 ea7b686 a2ac5f7 ea7b686 a2ac5f7 ea7b686 a2ac5f7 9b7a7cf f2beb6a 9b7a7cf 33e5fa6 f2beb6a a2ac5f7 8f6647c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from modules.vectorstore.faiss import FaissVectorStore
from modules.vectorstore.chroma import ChromaVectorStore
from modules.vectorstore.colbert import ColbertVectorStore
from modules.vectorstore.raptor import RAPTORVectoreStore
from huggingface_hub import snapshot_download
import os
import shutil
class VectorStore:
    """Facade that dispatches to a concrete vector-store backend.

    The backend is selected by ``config["vectorstore"]["db_option"]``, which
    must be one of the keys of ``self.vectorstore_classes``.
    """

    def __init__(self, config):
        self.config = config
        # Concrete backend instance; stays None until _create_database or
        # _load_database has been called.
        self.vectorstore = None
        # Maps the ``db_option`` config value to the backend class to build.
        self.vectorstore_classes = {
            "FAISS": FaissVectorStore,
            "Chroma": ChromaVectorStore,
            "RAGatouille": ColbertVectorStore,
            "RAPTOR": RAPTORVectoreStore,
        }

    def _init_vectorstore(self):
        """Instantiate the configured backend into ``self.vectorstore``.

        Returns:
            The ``db_option`` string that was used, so callers can dispatch on it.

        Raises:
            ValueError: if ``db_option`` does not name a known backend.
        """
        db_option = self.config["vectorstore"]["db_option"]
        vectorstore_class = self.vectorstore_classes.get(db_option)
        if not vectorstore_class:
            raise ValueError(f"Invalid db_option: {db_option}")
        self.vectorstore = vectorstore_class(self.config)
        return db_option

    def _create_database(
        self,
        document_chunks,
        document_names,
        documents,
        document_metadata,
        embedding_model,
    ):
        """Build a new database with the configured backend.

        The RAGatouille backend is given ``documents``, ``document_names`` and
        ``document_metadata``. Every other backend is given
        ``document_chunks`` and ``embedding_model``.

        Raises:
            ValueError: if ``db_option`` does not name a known backend.
        """
        db_option = self._init_vectorstore()
        if db_option == "RAGatouille":
            self.vectorstore.create_database(
                documents, document_names, document_metadata
            )
        else:
            self.vectorstore.create_database(document_chunks, embedding_model)

    def _load_database(self, embedding_model):
        """Instantiate the configured backend and return its ``load_database()``.

        RAGatouille's ``load_database`` is called without arguments, because it
        does not take the embedding model. Other backends receive
        ``embedding_model``.

        Raises:
            ValueError: if ``db_option`` does not name a known backend.
        """
        db_option = self._init_vectorstore()
        if db_option == "RAGatouille":
            return self.vectorstore.load_database()
        return self.vectorstore.load_database(embedding_model)

    def _load_from_HF(self, HF_PATH):
        """Fetch a dataset repo from the Hugging Face Hub into the local db dir.

        The snapshot is downloaded into the Hub cache; ``force_download=True``
        re-downloads it even if a cached copy exists. Its contents are then
        copied into ``<db_path>/db_<db_option>``, which is the directory used
        when loading the database.

        Args:
            HF_PATH: Hub repo id of the dataset holding the prebuilt database.
        """
        snapshot_path = snapshot_download(
            repo_id=HF_PATH,
            repo_type="dataset",
            force_download=True,
        )

        vectorstore_config = self.config["vectorstore"]
        target_path = os.path.join(
            vectorstore_config["db_path"],
            "db_" + vectorstore_config["db_option"],
        )
        os.makedirs(target_path, exist_ok=True)

        # Copy (not move) every top-level entry so the Hub cache stays intact.
        # Existing files in target_path with the same name are overwritten.
        for item in os.listdir(snapshot_path):
            source = os.path.join(snapshot_path, item)
            destination = os.path.join(target_path, item)
            if os.path.isdir(source):
                shutil.copytree(source, destination, dirs_exist_ok=True)
            else:
                shutil.copy2(source, destination)

    def _as_retriever(self):
        """Return the backend's retriever.

        Requires a prior _create_database or _load_database call; otherwise
        ``self.vectorstore`` is None and an AttributeError is raised.
        """
        return self.vectorstore.as_retriever()

    def _get_vectorstore(self):
        """Return the underlying backend instance (None before create/load)."""
        return self.vectorstore

    def __len__(self):
        """Delegate ``len()`` to the backend (requires a prior create/load)."""
        return self.vectorstore.__len__()
|