# Spaces: Running
import argparse | |
from itertools import islice | |
from pathlib import Path | |
from tqdm import tqdm | |
import torch | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import Qdrant | |
from loaders import get_loader, LOADER_NAMES | |
from config import DB_CONFIG | |
# Value handed to the text splitter as ``chunk_size``.
CHUNK_SIZE = 500


def get_text_chunk(docs):
    """Split ``docs`` into chunks with a ``RecursiveCharacterTextSplitter``.

    The splitter is configured with ``chunk_size=CHUNK_SIZE`` and no overlap
    between consecutive chunks.

    Args:
        docs: Documents to split; forwarded unchanged to
            ``split_documents``.

    Returns:
        Whatever ``split_documents`` returns for ``docs``.
    """
    splitter_options = {"chunk_size": CHUNK_SIZE, "chunk_overlap": 0}
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    return splitter.split_documents(docs)
def batched(iterable, *, size=100):
    """Yield successive tuples of up to ``size`` items from ``iterable``.

    Every batch holds exactly ``size`` items except possibly the last one,
    which holds whatever is left over. An empty iterable yields nothing.

    Example:
        ``batched('ABCDEFG', size=3)`` --> ``ABC DEF G``

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        size: Maximum number of items per batch (keyword-only).

    Returns:
        A generator of tuples.

    Raises:
        ValueError: If ``size`` is smaller than one. The check runs when
            ``batched`` is called rather than on first iteration, so a bad
            size is reported at the call site.
    """
    if size < 1:
        raise ValueError("size must be at least one")

    def _generate():
        it = iter(iterable)
        # islice produces an empty tuple once the iterator is exhausted,
        # which is falsy and therefore ends the loop.
        while batch := tuple(islice(it, size)):
            yield batch

    return _generate()
def store(texts):
    """Embed ``texts`` and send them to Qdrant in batches of 100.

    Embeddings come from the ``intfloat/multilingual-e5-large`` model, run on
    ``cuda:0`` when CUDA is available and on the CPU otherwise, with
    ``normalize_embeddings`` turned off. Connection settings are unpacked
    from ``DB_CONFIG`` as ``(url, api_key, collection_name)``.

    Args:
        texts: Iterable of documents to store.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    embeddings = HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": False},
    )

    url, api_key, collection_name = DB_CONFIG
    qdrant_options = {
        "url": url,
        "api_key": api_key,
        "collection_name": collection_name,
    }
    # NOTE(review): every batch goes through Qdrant.from_documents; whether
    # that appends to or recreates an existing collection depends on the
    # installed langchain version -- confirm before relying on it.
    for chunk in tqdm(batched(texts, size=100)):
        Qdrant.from_documents(chunk, embeddings, **qdrant_options)
def get_parser():
    """Build the command-line parser for this script.

    The parser takes two positionals, ``index`` and ``inputfile`` (shown as
    ``INPUTFILE``), plus a required ``-l/--loader`` option restricted to the
    names in ``LOADER_NAMES``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # Positionals are matched in the order they are registered.
    parser.add_argument("index", type=str)
    parser.add_argument("inputfile", type=str, metavar="INPUTFILE")
    parser.add_argument(
        "-l",
        "--loader",
        type=str,
        required=True,
        choices=LOADER_NAMES,
    )
    return parser
def index_annotated_docs(docs, index):
    """Lazily tag each document with ``index`` and pass it through.

    Each document's ``metadata`` mapping is modified in place: its
    ``"index"`` key is set to ``index``, replacing any previous value. The
    tagging happens only as the generator is consumed.

    Args:
        docs: Iterable of objects exposing a mutable ``metadata`` mapping.
        index: Value to store under ``metadata["index"]``.

    Yields:
        The same document objects, in their original order.
    """
    for document in docs:
        meta = document.metadata
        meta["index"] = index
        yield document
def main():
    """Command-line entry point: load one input, chunk it and store it.

    Usage::

        $ python store.py --loader wikipage "index" "FILE_PATH"
        $ python store.py -l wikipage wiki data/wiki.json
        $ python store.py -l rtdhtmlpage django ./docs.djangoproject.com/
    """
    options = get_parser().parse_args()
    loader = get_loader(options.loader, inputfile=Path(options.inputfile))

    # Tag every lazily loaded document with the index name before chunking.
    annotated = index_annotated_docs(loader.lazy_load(), options.index)
    store(get_text_chunk(annotated))


if __name__ == "__main__":
    main()