File size: 3,084 Bytes
a2ac5f7
 
 
f2daaee
f2beb6a
 
 
a2ac5f7
 
 
 
 
 
ea7b686
 
 
 
f2daaee
ea7b686
a2ac5f7
 
 
 
 
 
 
 
 
ea7b686
 
 
 
 
 
 
 
a2ac5f7
 
 
 
ea7b686
a2ac5f7
 
ea7b686
 
 
 
 
 
 
 
a2ac5f7
ea7b686
 
a2ac5f7
9b7a7cf
f2beb6a
 
 
9b7a7cf
33e5fa6
 
f2beb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ac5f7
 
 
 
 
8f6647c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from modules.vectorstore.faiss import FaissVectorStore
from modules.vectorstore.chroma import ChromaVectorStore
from modules.vectorstore.colbert import ColbertVectorStore
from modules.vectorstore.raptor import RAPTORVectoreStore
from huggingface_hub import snapshot_download
import os
import shutil


class VectorStore:
    """Facade over the vector store backend selected in the configuration.

    The backend is chosen by ``config["vectorstore"]["db_option"]``, which
    must be one of the keys of ``self.vectorstore_classes``. No backend is
    instantiated until ``_create_database`` or ``_load_database`` is called;
    until then ``self.vectorstore`` is ``None``.
    """

    def __init__(self, config):
        """Store the configuration and register the available backends.

        Args:
            config: Nested mapping; this class reads
                ``config["vectorstore"]["db_option"]`` and
                ``config["vectorstore"]["db_path"]``.
        """
        self.config = config
        # Populated by _create_database / _load_database.
        self.vectorstore = None
        # Maps the configured db_option string to its backend class.
        self.vectorstore_classes = {
            "FAISS": FaissVectorStore,
            "Chroma": ChromaVectorStore,
            "RAGatouille": ColbertVectorStore,
            "RAPTOR": RAPTORVectoreStore,
        }

    def _init_backend(self):
        """Instantiate the configured backend and return its option name.

        The new backend instance is stored in ``self.vectorstore``.

        Returns:
            The ``db_option`` string read from the configuration.

        Raises:
            ValueError: If ``db_option`` is not a registered backend name.
        """
        db_option = self.config["vectorstore"]["db_option"]
        backend_class = self.vectorstore_classes.get(db_option)
        if backend_class is None:
            raise ValueError(f"Invalid db_option: {db_option}")

        self.vectorstore = backend_class(self.config)
        return db_option

    def _create_database(
        self,
        document_chunks,
        document_names,
        documents,
        document_metadata,
        embedding_model,
    ):
        """Instantiate the configured backend and build its database.

        The "RAGatouille" backend is given the full documents, their names
        and their metadata; every other backend is given the document chunks
        and the embedding model.

        Raises:
            ValueError: If the configured ``db_option`` is not registered.
        """
        db_option = self._init_backend()

        if db_option == "RAGatouille":
            self.vectorstore.create_database(
                documents, document_names, document_metadata
            )
        else:
            self.vectorstore.create_database(document_chunks, embedding_model)

    def _load_database(self, embedding_model):
        """Instantiate the configured backend and load its existing database.

        ``embedding_model`` is forwarded to every backend except
        "RAGatouille", whose ``load_database`` is called without arguments.

        Returns:
            Whatever the backend's ``load_database`` returns.

        Raises:
            ValueError: If the configured ``db_option`` is not registered.
        """
        db_option = self._init_backend()

        if db_option == "RAGatouille":
            return self.vectorstore.load_database()
        return self.vectorstore.load_database(embedding_model)

    def _load_from_HF(self, HF_PATH):
        """Download a dataset snapshot from the Hugging Face Hub into db_path.

        The snapshot contents are copied into
        ``<db_path>/db_<db_option>``, overwriting files that already exist
        there.

        Args:
            HF_PATH: Repository id passed to ``snapshot_download`` as
                ``repo_id`` (with ``repo_type="dataset"``).
        """
        # Download the snapshot from Hugging Face Hub.
        # Note: the download goes to the Hub cache directory.
        snapshot_path = snapshot_download(
            repo_id=HF_PATH,
            repo_type="dataset",
            force_download=True,
        )

        # Directory the snapshot contents are copied into.
        target_path = os.path.join(
            self.config["vectorstore"]["db_path"],
            "db_" + self.config["vectorstore"]["db_option"],
        )

        # Create target path if it doesn't exist
        os.makedirs(target_path, exist_ok=True)

        # Copy (not move) every file and directory from the snapshot into
        # target_path, so the cached snapshot is left in place.
        for item in os.listdir(snapshot_path):
            source_item = os.path.join(snapshot_path, item)
            target_item = os.path.join(target_path, item)
            if os.path.isdir(source_item):
                shutil.copytree(source_item, target_item, dirs_exist_ok=True)
            else:
                shutil.copy2(source_item, target_item)

    def _as_retriever(self):
        """Return the active backend's retriever via its ``as_retriever()``."""
        return self.vectorstore.as_retriever()

    def _get_vectorstore(self):
        """Return the active backend instance, or ``None`` if none exists."""
        return self.vectorstore

    def __len__(self):
        """Return the backend's length, or 0 when no backend exists yet.

        Python falls back to ``__len__`` for truth testing, so this must not
        raise on a freshly constructed instance whose ``vectorstore`` is
        still ``None``.
        """
        if self.vectorstore is None:
            return 0
        return len(self.vectorstore)