File size: 3,832 Bytes
6158da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import logging
import os
import yaml

from modules.embedding_model_loader import EmbeddingModelLoader
from langchain.vectorstores import FAISS
from modules.data_loader import DataLoader
from modules.constants import *


class VectorDB:
    """Build, persist and reload a vector store over the configured documents.

    All settings are read from ``config["embedding_options"]``:

    * ``db_option`` -- vector store backend; only ``"FAISS"`` is handled.
    * ``data_path`` -- directory whose entries are fed to the data loader.
    * ``db_path`` -- directory under which the index is saved / loaded.
    * ``model`` -- embedding model name, used in the on-disk index name.
    """

    def __init__(self, config, logger=None):
        """Store the configuration and set up logging.

        Args:
            config: Mapping containing an ``"embedding_options"`` section.
            logger: Logger to use as-is. When omitted, the module logger is
                used and is given a console handler plus a ``vector_db.log``
                file handler.
        """
        self.config = config
        self.db_option = config["embedding_options"]["db_option"]
        self.document_names = None

        # Set up logging to both console and a file
        if logger is None:
            self.logger = self._build_default_logger()
        else:
            self.logger = logger

        self.logger.info("VectorDB instance instantiated")

    @staticmethod
    def _build_default_logger():
        """Return the module logger, attaching console/file handlers only once."""
        default_logger = logging.getLogger(__name__)
        default_logger.setLevel(logging.INFO)

        # logging.getLogger() hands back the same object on every call, so the
        # handlers must be attached only the first time. Otherwise each new
        # VectorDB would duplicate every log record and reopen the log file in
        # "w" mode, truncating what was already written.
        if default_logger.handlers:
            return default_logger

        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

        # Console Handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        default_logger.addHandler(console_handler)

        # File Handler
        log_file_path = "vector_db.log"  # Change this to your desired log file path
        file_handler = logging.FileHandler(log_file_path, mode="w")
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        default_logger.addHandler(file_handler)

        return default_logger

    def _db_dir(self):
        """Return the directory the index is saved to and loaded from.

        The path is ``<db_path>/db_<db_option>_<model>``, built from the
        current contents of ``config["embedding_options"]``.
        """
        options = self.config["embedding_options"]
        return os.path.join(
            options["db_path"],
            "db_" + options["db_option"] + "_" + options["model"],
        )

    def load_files(self):
        """Return every entry of the configured data directory, joined to it.

        Returns:
            A list of paths, one per name returned by ``os.listdir`` for
            ``data_path`` (order is whatever ``os.listdir`` yields).
        """
        data_path = self.config["embedding_options"]["data_path"]
        return [os.path.join(data_path, name) for name in os.listdir(data_path)]

    def create_embedding_model(self):
        """Load the embedding model and keep it on ``self.embedding_model``."""
        self.logger.info("Creating embedding function")
        self.embedding_model_loader = EmbeddingModelLoader(self.config)
        self.embedding_model = self.embedding_model_loader.load_embedding_model()

    def initialize_database(self, document_chunks: list, document_names: list):
        """Build the vector store from already-chunked documents.

        ``create_embedding_model`` must have been called first, because the
        store is built with ``self.embedding_model``.

        Args:
            document_chunks: Documents passed to ``FAISS.from_documents``.
            document_names: Names accompanying the chunks; kept on
                ``self.document_names``.

        Raises:
            ValueError: If ``db_option`` is anything other than ``"FAISS"``.
        """
        self.logger.info("Initializing vector_db")
        self.logger.info("\tUsing %s as db_option", self.db_option)

        # Fail here rather than report success and let a later save_database()
        # die on a missing ``vector_db`` attribute.
        if self.db_option != "FAISS":
            raise ValueError(
                "Unsupported db_option {!r}; only 'FAISS' is supported".format(
                    self.db_option
                )
            )

        self.document_names = document_names
        self.vector_db = FAISS.from_documents(
            documents=document_chunks, embedding=self.embedding_model
        )
        self.logger.info("Completed initializing vector_db")

    def create_database(self):
        """Load and chunk the data files, then build the vector store."""
        data_loader = DataLoader(self.config)
        self.logger.info("Loading data")
        files = self.load_files()
        document_chunks, document_names = data_loader.get_chunks(files, [""])
        self.logger.info("Completed loading data")

        self.create_embedding_model()
        self.initialize_database(document_chunks, document_names)

    def save_database(self):
        """Persist the vector store built by ``initialize_database`` to disk."""
        self.vector_db.save_local(self._db_dir())
        self.logger.info("Saved database")

    def load_database(self):
        """Load a previously saved FAISS index and return it.

        The embedding model is (re)created first, since ``FAISS.load_local``
        needs it.

        Returns:
            The loaded store, also kept on ``self.vector_db``.
        """
        self.create_embedding_model()
        # NOTE(review): per the LangChain documentation, FAISS.load_local reads
        # a pickle file, so only load indexes from a trusted ``db_path``.
        # Newer LangChain releases reportedly require an explicit
        # ``allow_dangerous_deserialization=True`` here -- confirm against the
        # pinned version before upgrading.
        self.vector_db = FAISS.load_local(self._db_dir(), self.embedding_model)
        self.logger.info("Loaded database")
        return self.vector_db


if __name__ == "__main__":
    # Script entry point: read config.yml from the current working directory,
    # build the vector store from it and save the result to disk.
    # The encoding is given explicitly so a UTF-8 config file is decoded the
    # same way regardless of the platform's locale default.
    with open("config.yml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    print(config)
    vector_db = VectorDB(config)
    vector_db.create_database()
    vector_db.save_database()