xl2533's picture
initial
6c945f2
raw
history blame
1.93 kB
# -*-coding:utf-8 -*-
"""
Build Faiss indexes for all documents in ./data
"""
import argparse
from build_index.base import FileReader, Document
from build_index.process import split_documents
from build_index.pricing import check_price
from build_index.doc2vec import doc2vec
def process_directory(directory, folder_name, model, max_tokens):
    """Load, chunk and embed the documents selected by ``directory``/``folder_name``.

    Pipeline: ``FileReader(directory, folder_name).load_data()`` ->
    ``split_documents(..., max_tokens=max_tokens)`` ->
    ``Document.to_langchain_format`` on every chunk ->
    ``doc2vec(chunks, model, folder_name)``.

    Args:
        directory: first argument forwarded to ``FileReader``.
        folder_name: second argument forwarded to ``FileReader``; also passed
            to ``doc2vec`` (presumably to name the resulting index -- confirm
            against ``build_index.doc2vec``).
        model: embedding model identifier forwarded to ``doc2vec``.
        max_tokens: chunk-size limit forwarded to ``split_documents``.

    Returns:
        The LangChain-format chunks that were handed to ``doc2vec``.
    """
    raw_docs = FileReader(directory, folder_name).load_data()
    print(f'Before chunking {len(raw_docs)} docs')

    chunks = split_documents(documents=raw_docs, max_tokens=max_tokens)
    print(f'After chunking {len(chunks)} docs')

    langchain_docs = list(map(Document.to_langchain_format, chunks))
    doc2vec(langchain_docs, model, folder_name)
    return langchain_docs
def process_files(files, model, max_tokens):
    """Load, chunk and embed an explicit list of input files.

    Same pipeline as the directory variant, but documents come from
    ``FileReader(input_files=files)`` and ``doc2vec`` is called without a
    folder name.

    Args:
        files: file paths forwarded to ``FileReader`` as ``input_files``.
        model: embedding model identifier forwarded to ``doc2vec``.
        max_tokens: chunk-size limit forwarded to ``split_documents``.

    Returns:
        The LangChain-format chunks that were handed to ``doc2vec``.
    """
    loaded = FileReader(input_files=files).load_data()
    print(f'Before chunking {len(loaded)} docs')

    pieces = split_documents(documents=loaded, max_tokens=max_tokens)
    print(f'After chunking {len(pieces)} docs')

    converted = []
    for piece in pieces:
        converted.append(Document.to_langchain_format(piece))

    doc2vec(converted, model)
    return converted
def build_parser():
    """Return the command-line parser for a single directory-indexing job.

    All arguments are positional and optional (``nargs='?'``). Without
    ``nargs='?'``, argparse ignores ``default=`` on positionals and makes
    them required.
    """
    parser = argparse.ArgumentParser(
        description='Build a Faiss index for the documents in one data folder.'
    )
    parser.add_argument('directory', type=str, nargs='?', default='./data',
                        help='root data directory (default: ./data)')
    parser.add_argument('folder', type=str, nargs='?', default='半导体',
                        help='folder name passed to FileReader and doc2vec')
    parser.add_argument('model', type=str, nargs='?', default='mpnet',
                        help='embedding model passed to doc2vec')
    parser.add_argument('max_tokens', type=int, nargs='?', default=514,
                        help='chunk-size limit passed to split_documents')
    return parser


def main(argv=None):
    """Run the CLI-selected job, then the hard-coded batch jobs.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    from glob import glob

    args = build_parser().parse_args(argv)
    process_directory(args.directory, args.folder, args.model, args.max_tokens)

    # NOTE(review): these batch jobs run on every invocation regardless of the
    # CLI arguments above; consider gating them behind a flag.
    batch_jobs = [
        ('./data', '半导体', 'mpnet', 514),
        ('./data', '数字经济', 'cohere', 514),
        ('./data', '两会', 'cohere', 514),
    ]
    for directory, folder, model, max_tokens in batch_jobs:
        process_directory(directory, folder, model, max_tokens)

    # glob() order depends on the filesystem; sort so the "first 3 files"
    # selection is reproducible across machines.
    files = sorted(glob('./data/两会/*.pdf'))[:3]
    process_files(files, 'cohere', 514)


if __name__ == '__main__':
    main()