|
from langchain_community.vectorstores import Chroma |
|
from modules.vectorstore.base import VectorStoreBase |
|
import os |
|
|
|
|
|
class ChromaVectorStore(VectorStoreBase):
    """Vector store implementation backed by LangChain's ``Chroma`` wrapper.

    The on-disk location of the database is derived from the
    ``"vectorstore"`` section of ``config`` (keys ``db_path``, ``db_option``
    and ``model``). ``self.vectorstore`` is only assigned once
    ``create_database`` or ``load_database`` has been called.
    """

    def __init__(self, config):
        """Keep ``config`` and set up the Chroma handle.

        Args:
            config: Nested mapping; this class reads
                ``config["vectorstore"]["db_path"]``,
                ``config["vectorstore"]["db_option"]`` and
                ``config["vectorstore"]["model"]``.
        """
        self.config = config
        self._init_vector_db()

    def _init_vector_db(self):
        """Create the ``Chroma`` object that ``create_database`` builds from."""
        self.chroma = Chroma()

    def _persist_directory(self):
        """Return the directory the database is persisted to.

        The path is ``<db_path>/db_<db_option>_<model>``. Both
        ``create_database`` and ``load_database`` go through this helper so
        the location written and the location read can never drift apart.

        Raises:
            KeyError: If a required config key is missing.
            TypeError: If ``db_option`` or ``model`` is not a string.
        """
        settings = self.config["vectorstore"]
        root = settings["db_path"]
        # str.join (rather than an f-string) keeps the original strictness:
        # a non-string option/model raises instead of being silently
        # stringified into the directory name.
        folder = "_".join(("db", settings["db_option"], settings["model"]))
        return os.path.join(root, folder)

    def create_database(self, document_chunks, embedding_model):
        """Build a new store from documents and keep it on ``self.vectorstore``.

        Args:
            document_chunks: Documents passed to ``from_documents`` as
                ``documents``.
            embedding_model: Embedding object passed to ``from_documents``
                as ``embedding``.

        Returns:
            The newly created store (also available as ``self.vectorstore``),
            mirroring what ``load_database`` returns.
        """
        # The populated store is the *return value* of from_documents, so
        # that is what gets kept, not ``self.chroma`` itself.
        self.vectorstore = self.chroma.from_documents(
            documents=document_chunks,
            embedding=embedding_model,
            persist_directory=self._persist_directory(),
        )
        return self.vectorstore

    def load_database(self, embedding_model):
        """Open the store at the configured persist directory.

        Args:
            embedding_model: Embedding object passed to ``Chroma`` as
                ``embedding_function``.

        Returns:
            The opened store (also available as ``self.vectorstore``).
        """
        self.vectorstore = Chroma(
            persist_directory=self._persist_directory(),
            embedding_function=embedding_model,
        )
        return self.vectorstore

    def as_retriever(self):
        """Return ``self.vectorstore.as_retriever()`` with default arguments.

        Requires ``create_database`` or ``load_database`` to have been
        called first; otherwise ``self.vectorstore`` does not exist yet.
        """
        return self.vectorstore.as_retriever()

    def __len__(self):
        """Return ``len()`` of the underlying store.

        Requires ``create_database`` or ``load_database`` to have been
        called first; otherwise ``self.vectorstore`` does not exist yet.
        """
        return len(self.vectorstore)
|
|