|
from modules.retriever.faiss_retriever import FaissRetriever |
|
from modules.retriever.chroma_retriever import ChromaRetriever |
|
from modules.retriever.colbert_retriever import ColbertRetriever |
|
from modules.retriever.raptor_retriever import RaptorRetriever |
|
|
|
|
|
class Retriever: |
|
def __init__(self, config): |
|
self.config = config |
|
self.retriever_classes = { |
|
"FAISS": FaissRetriever, |
|
"Chroma": ChromaRetriever, |
|
"RAGatouille": ColbertRetriever, |
|
"RAPTOR": RaptorRetriever, |
|
} |
|
self._create_retriever() |
|
|
|
def _create_retriever(self): |
|
db_option = self.config["vectorstore"]["db_option"] |
|
retriever_class = self.retriever_classes.get(db_option) |
|
if not retriever_class: |
|
raise ValueError(f"Invalid db_option: {db_option}") |
|
self.retriever = retriever_class() |
|
|
|
def _return_retriever(self, db): |
|
return self.retriever.return_retriever(db, self.config) |
|
|