File size: 3,475 Bytes
e5cd1d3
 
8f6647c
 
 
 
e5cd1d3
8f6647c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5cd1d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f6647c
e5cd1d3
 
 
d697aa5
e5cd1d3
 
 
 
 
 
8f6647c
 
 
 
 
 
 
e5cd1d3
 
 
 
8f6647c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import os
from typing import Any, List, Optional, Sequence

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    Callbacks,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from ragatouille import RAGPretrainedModel

from modules.vectorstore.base import VectorStoreBase


def _search_results_to_documents(results: List[dict]) -> List[Document]:
    """Convert RAGatouille search hits into LangChain ``Document`` objects.

    Each hit must provide ``"content"`` and ``"score"`` keys and may provide a
    ``"document_metadata"`` mapping. The document's metadata is that mapping
    (copied) with an added/overridden ``"score"`` entry.
    """
    return [
        Document(
            page_content=hit["content"],
            # ``document_metadata`` may be absent or explicitly None when no
            # metadata was indexed; treat both as "no metadata".
            metadata={**(hit.get("document_metadata") or {}), "score": hit["score"]},
        )
        for hit in results
    ]


class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a RAGatouille model that preserves hit scores.

    Every query is answered with ``model.search(query, **kwargs)``. Each hit is
    returned as a ``Document`` whose metadata carries the hit's score under the
    ``"score"`` key.
    """

    # Object exposing ``search(query, **kwargs)``; typed ``Any`` so the model
    # instance is not validated as a pydantic field type.
    model: Any
    # Extra keyword arguments forwarded verbatim to ``model.search``.
    kwargs: dict = {}

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,  # noqa
    ) -> List[Document]:
        """Return documents relevant to ``query``, each annotated with its score."""
        return _search_results_to_documents(self.model.search(query, **self.kwargs))

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,  # noqa
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents``.

        NOTE: ``model.search`` is called synchronously here (not awaited), so the
        event loop is blocked for the duration of the search.
        """
        return _search_results_to_documents(self.model.search(query, **self.kwargs))


class RAGPretrainedModel(RAGPretrainedModel):
    """RAGatouille ``RAGPretrainedModel`` extended with a recorded document count.

    The subclass deliberately reuses the base class's name, so later references
    to ``RAGPretrainedModel`` in this module resolve to the extended version.
    Presumably the base class's ``from_pretrained``/``from_index`` build
    instances via ``cls``, so they return this subclass; verify against the
    installed ragatouille version.

    ``len(model)`` returns the count last passed to ``set_document_count``, or
    0 before that. NOTE: because ``__len__`` is defined, an instance whose
    count is 0 is falsy in a boolean context.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The count is not derived automatically; callers record it explicitly.
        self._document_count = 0

    def __len__(self):
        """Return the recorded number of indexed documents."""
        return self._document_count

    def set_document_count(self, count):
        """Record ``count`` as the number of documents in the index."""
        self._document_count = count

    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Wrap this model in a LangChain retriever that keeps search scores.

        The keyword arguments are stored on the retriever and passed to
        ``search`` for every query.
        """
        scored_retriever = RAGatouilleLangChainRetrieverWithScore(
            model=self, kwargs=kwargs
        )
        return scored_retriever


class ColbertVectorStore(VectorStoreBase):
    """ColBERT vector store backed by a RAGatouille index on disk.

    ``create_database`` builds the index named ``INDEX_NAME`` under the
    directory ``<db_path>/db_<db_option>``; ``load_database`` reads it back from
    ``<cwd>/<db_path>/db_<db_option>/colbert/indexes/<INDEX_NAME>``. Both
    ``db_path`` and ``db_option`` come from ``config["vectorstore"]``.
    """

    # Name of the index written by ``create_database`` and read back by
    # ``load_database``; kept in one place so the two cannot drift apart.
    INDEX_NAME = "new_idx"

    def __init__(self, config):
        """Store ``config`` and load the pretrained ColBERT checkpoint."""
        self.config = config
        self._init_vector_db()

    def _db_dir(self):
        """Return the (possibly relative) index-root directory from the config."""
        return os.path.join(
            self.config["vectorstore"]["db_path"],
            "db_" + self.config["vectorstore"]["db_option"],
        )

    def _init_vector_db(self):
        """Load the ``colbert-ir/colbertv2.0`` checkpoint rooted at ``_db_dir()``."""
        self.colbert = RAGPretrainedModel.from_pretrained(
            "colbert-ir/colbertv2.0",
            index_root=self._db_dir(),
        )

    def create_database(self, documents, document_names, document_metadata):
        """Build the ColBERT index from ``documents``.

        Args:
            documents: passages to index (passed as ``collection``).
            document_names: one id per document (passed as ``document_ids``).
            document_metadata: per-document metadata (passed as
                ``document_metadatas``).
        """
        self.colbert.index(
            index_name=self.INDEX_NAME,
            collection=documents,
            document_ids=document_names,
            document_metadatas=document_metadata,
        )
        # Record the count so ``len()`` on the model reflects the indexed docs.
        self.colbert.set_document_count(len(document_names))

    def load_database(self):
        """Load the previously built index and return it.

        The loaded model is also stored as ``self.vectorstore``, and its
        document count is set from ``num_passages`` in the index's
        ``0.metadata.json``.
        """
        path = os.path.join(os.getcwd(), self._db_dir())
        index_dir = os.path.join(path, "colbert", "indexes", self.INDEX_NAME)
        self.vectorstore = RAGPretrainedModel.from_index(index_dir)

        # Context manager so the metadata file handle is closed promptly.
        metadata_path = os.path.join(index_dir, "0.metadata.json")
        with open(metadata_path, encoding="utf-8") as metadata_file:
            index_metadata = json.load(metadata_file)
        self.vectorstore.set_document_count(index_metadata["num_passages"])

        return self.vectorstore

    def as_retriever(self, **kwargs):
        """Return a LangChain retriever over the loaded index.

        Keyword arguments are forwarded to ``as_langchain_retriever`` (and from
        there to ``search``). Requires ``load_database`` to have run first.
        """
        # RAGatouille models expose ``as_langchain_retriever``, not
        # ``as_retriever``; the subclass override keeps scores in metadata.
        return self.vectorstore.as_langchain_retriever(**kwargs)

    def __len__(self):
        """Return the document count recorded on the loaded index."""
        return len(self.vectorstore)