from ragatouille import RAGPretrainedModel
from modules.vectorstore.base import VectorStoreBase
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    Callbacks,
)
from langchain_core.documents import Document
from typing import Any, List, Optional, Sequence
import os
import json

# Name of the ColBERT index written by ``create_database`` and read back by
# ``load_database``; both must use the same value for a round-trip to work.
_INDEX_NAME = "new_idx"


def _to_documents(results: List[dict]) -> List[Document]:
    """Convert RAGatouille search hits into LangChain ``Document`` objects.

    Each hit's ``"content"`` becomes ``page_content``; its
    ``"document_metadata"`` (if present) is copied into the Document's
    metadata, and its ``"score"`` is added under the ``"score"`` key.
    """
    return [
        Document(
            page_content=hit["content"],
            metadata={**hit.get("document_metadata", {}), "score": hit["score"]},
        )
        for hit in results
    ]


class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a RAGatouille model that keeps search scores.

    Every returned ``Document`` carries the relevance score reported by
    ``model.search`` in ``metadata["score"]``.
    """

    # Object exposing ``search(query, **kwargs)``; in this module it is the
    # ``RAGPretrainedModel`` subclass defined below.
    model: Any
    # Extra keyword arguments forwarded verbatim to ``model.search``.
    kwargs: dict = {}

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,  # noqa
    ) -> List[Document]:
        """Search the model for *query* and return scored Documents."""
        return _to_documents(self.model.search(query, **self.kwargs))

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,  # noqa
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents``.

        NOTE: ``model.search`` is a synchronous call, so this coroutine blocks
        the event loop for the duration of the search.
        """
        return _to_documents(self.model.search(query, **self.kwargs))


# The subclass deliberately reuses the imported base-class name, so every
# later reference in this module (``from_pretrained`` / ``from_index`` below)
# resolves to the extended class.
class RAGPretrainedModel(RAGPretrainedModel):
    """``ragatouille.RAGPretrainedModel`` extended with ``len()`` support.

    The count reported by ``len()`` is whatever was last passed to
    ``set_document_count`` (0 until then); it is not derived from the index.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): assumes ragatouille's ``from_pretrained``/``from_index``
        # construct instances via ``cls(...)`` so this runs — confirm.
        self._document_count = 0

    def set_document_count(self, count):
        """Record the number reported by ``len(self)``."""
        self._document_count = count

    def __len__(self):
        return self._document_count

    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever that keeps scores; ``kwargs`` go to ``search``."""
        return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)


class ColbertVectorStore(VectorStoreBase):
    """Vector store backed by a ColBERTv2 index managed through RAGatouille.

    ``config["vectorstore"]["db_path"]`` and ``config["vectorstore"]["db_option"]``
    together determine the on-disk directory holding the index.
    """

    def __init__(self, config):
        self.config = config
        self._init_vector_db()

    def _init_vector_db(self):
        """Load the ``colbert-ir/colbertv2.0`` checkpoint used for indexing."""
        self.colbert = RAGPretrainedModel.from_pretrained(
            "colbert-ir/colbertv2.0",
            index_root=os.path.join(
                self.config["vectorstore"]["db_path"],
                "db_" + self.config["vectorstore"]["db_option"],
            ),
        )

    def create_database(self, documents, document_names, document_metadata):
        """Build the ColBERT index from *documents*.

        Args:
            documents: Texts to index (passed as ``collection``).
            document_names: IDs for each document (passed as ``document_ids``).
            document_metadata: Per-document metadata (``document_metadatas``).
        """
        self.colbert.index(
            index_name=_INDEX_NAME,
            collection=documents,
            document_ids=document_names,
            document_metadatas=document_metadata,
        )
        # NOTE(review): this records the number of input documents, while
        # ``load_database`` records ``num_passages`` from the index metadata;
        # the two may differ if ragatouille splits documents — confirm intent.
        self.colbert.set_document_count(len(document_names))

    def load_database(self):
        """Load the previously built index and return the model wrapping it.

        Also sets ``self.vectorstore``, which ``as_retriever`` and ``len()``
        require.
        """
        path = os.path.join(
            os.getcwd(),
            self.config["vectorstore"]["db_path"],
            "db_" + self.config["vectorstore"]["db_option"],
        )
        index_path = os.path.join(path, "colbert", "indexes", _INDEX_NAME)
        self.vectorstore = RAGPretrainedModel.from_index(index_path)
        # Use a context manager so the metadata file handle is always closed.
        metadata_file = os.path.join(index_path, "0.metadata.json")
        with open(metadata_file, encoding="utf-8") as fh:
            index_metadata = json.load(fh)
        self.vectorstore.set_document_count(index_metadata["num_passages"])
        return self.vectorstore

    def as_retriever(self, **kwargs):
        """Return a score-preserving LangChain retriever over the loaded index.

        ``kwargs`` (e.g. ``k``) are forwarded to the model's ``search`` call.
        Requires ``load_database`` to have been called first.
        """
        # RAGPretrainedModel exposes ``as_langchain_retriever`` (overridden
        # above); it has no ``as_retriever`` method.
        return self.vectorstore.as_langchain_retriever(**kwargs)

    def __len__(self):
        """Return the count recorded by ``load_database``."""
        return len(self.vectorstore)