Merge pull request #35 from DL4DS/hf_index_load
Browse files
code/modules/config/config.yml
CHANGED
@@ -3,11 +3,13 @@ log_chunk_dir: '../storage/logs/chunks' # str
|
|
3 |
device: 'cpu' # str [cuda, cpu]
|
4 |
|
5 |
vectorstore:
|
|
|
|
|
6 |
embedd_files: False # bool
|
7 |
data_path: '../storage/data' # str
|
8 |
url_file_path: '../storage/data/urls.txt' # str
|
9 |
expand_urls: True # bool
|
10 |
-
db_option : '
|
11 |
db_path : '../vectorstores' # str
|
12 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002]
|
13 |
search_top_k : 3 # int
|
|
|
3 |
device: 'cpu' # str [cuda, cpu]
|
4 |
|
5 |
vectorstore:
|
6 |
+
load_from_HF: True # bool
|
7 |
+
HF_path: "XThomasBU/Colbert_Index" # str
|
8 |
embedd_files: False # bool
|
9 |
data_path: '../storage/data' # str
|
10 |
url_file_path: '../storage/data/urls.txt' # str
|
11 |
expand_urls: True # bool
|
12 |
+
db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille, RAPTOR]
|
13 |
db_path : '../vectorstores' # str
|
14 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002]
|
15 |
search_top_k : 3 # int
|
code/modules/vectorstore/store_manager.py
CHANGED
@@ -143,6 +143,14 @@ class VectorStoreManager:
|
|
143 |
self.logger.info("Loaded database")
|
144 |
return self.loaded_vector_db
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
if __name__ == "__main__":
|
148 |
import yaml
|
@@ -152,7 +160,10 @@ if __name__ == "__main__":
|
|
152 |
print(config)
|
153 |
print(f"Trying to create database with config: {config}")
|
154 |
vector_db = VectorStoreManager(config)
|
155 |
-
|
|
|
|
|
|
|
156 |
print("Created database")
|
157 |
|
158 |
print(f"Trying to load the database")
|
|
|
143 |
self.logger.info("Loaded database")
|
144 |
return self.loaded_vector_db
|
145 |
|
146 |
+
def load_from_HF(self):
    """Fetch the vector database from Hugging Face and log how long it took.

    Delegates the actual work to ``self.vector_db._load_from_HF()`` and
    reports the elapsed time through ``self.logger.info``.

    Returns:
        None
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) if the system clock is adjusted during the download.
    started = time.perf_counter()
    self.vector_db._load_from_HF()
    elapsed = time.perf_counter() - started
    self.logger.info(
        f"Time taken to load database from Hugging Face: {elapsed} seconds"
    )
|
153 |
+
|
154 |
|
155 |
if __name__ == "__main__":
|
156 |
import yaml
|
|
|
160 |
print(config)
|
161 |
print(f"Trying to create database with config: {config}")
|
162 |
vector_db = VectorStoreManager(config)
|
163 |
+
if config["vectorstore"]["load_from_HF"] and "HF_path" in config["vectorstore"]:
|
164 |
+
vector_db.load_from_HF()
|
165 |
+
else:
|
166 |
+
vector_db.create_database()
|
167 |
print("Created database")
|
168 |
|
169 |
print(f"Trying to load the database")
|
code/modules/vectorstore/vectorstore.py
CHANGED
@@ -2,6 +2,9 @@ from modules.vectorstore.faiss import FaissVectorStore
|
|
2 |
from modules.vectorstore.chroma import ChromaVectorStore
|
3 |
from modules.vectorstore.colbert import ColbertVectorStore
|
4 |
from modules.vectorstore.raptor import RAPTORVectoreStore
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
class VectorStore:
|
@@ -50,6 +53,34 @@ class VectorStore:
|
|
50 |
else:
|
51 |
return self.vectorstore.load_database(embedding_model)
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
# Return whatever the wrapped vectorstore's as_retriever() produces (pure delegation).
def _as_retriever(self):
|
54 |
return self.vectorstore.as_retriever()
|
55 |
|
|
|
2 |
from modules.vectorstore.chroma import ChromaVectorStore
|
3 |
from modules.vectorstore.colbert import ColbertVectorStore
|
4 |
from modules.vectorstore.raptor import RAPTORVectoreStore
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
|
9 |
|
10 |
class VectorStore:
|
|
|
53 |
else:
|
54 |
return self.vectorstore.load_database(embedding_model)
|
55 |
|
56 |
+
def _load_from_HF(self):
    """Download the index snapshot from the Hugging Face Hub into the local db directory.

    Reads ``HF_path``, ``db_path`` and ``db_option`` from
    ``self.config["vectorstore"]``. The snapshot is fetched as a dataset
    repository, then every top-level entry is copied into
    ``<db_path>/db_<db_option>``, the directory used when loading the database.

    Returns:
        None
    """
    store_cfg = self.config["vectorstore"]

    # The download lands in the Hugging Face cache directory;
    # force_download=True requests a fresh download on every call.
    hub_snapshot_dir = snapshot_download(
        repo_id=store_cfg["HF_path"],
        repo_type="dataset",
        force_download=True,
    )

    # Destination directory for the copied snapshot contents.
    local_db_dir = os.path.join(
        store_cfg["db_path"],
        "db_" + store_cfg["db_option"],
    )
    os.makedirs(local_db_dir, exist_ok=True)

    # Copy (not move) each top-level entry, so the cached snapshot stays intact.
    for entry_name in os.listdir(hub_snapshot_dir):
        source_entry = os.path.join(hub_snapshot_dir, entry_name)
        destination_entry = os.path.join(local_db_dir, entry_name)
        if not os.path.isdir(source_entry):
            shutil.copy2(source_entry, destination_entry)
            continue
        # Merge into an existing directory instead of failing if it is already there.
        shutil.copytree(source_entry, destination_entry, dirs_exist_ok=True)
|
83 |
+
|
84 |
# Return whatever the wrapped vectorstore's as_retriever() produces (pure delegation).
def _as_retriever(self):
|
85 |
return self.vectorstore.as_retriever()
|
86 |
|