File size: 1,933 Bytes
6c945f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*-coding:utf-8 -*-
"""
    Build Faiss Index for all document in ./data
"""

import argparse
from build_index.base import FileReader, Document
from build_index.process import split_documents
from build_index.pricing import check_price
from build_index.doc2vec import doc2vec


def process_directory(directory, folder_name, model, max_tokens):
    """Load, chunk, and embed every document under directory/folder_name.

    Reads all documents via FileReader, splits them into chunks of at most
    *max_tokens* tokens, converts the chunks to langchain format, and feeds
    them to doc2vec under the given *model* and *folder_name* index.
    Returns the list of langchain-format documents.
    """
    docs = FileReader(directory, folder_name).load_data()
    print(f'Before chunking {len(docs)} docs')
    new_docs = split_documents(documents=docs, max_tokens=max_tokens)
    print(f'After chunking {len(new_docs)} docs')

    # Convert project Documents to langchain format before vectorizing.
    langchain_docs = list(map(Document.to_langchain_format, new_docs))
    doc2vec(langchain_docs, model, folder_name)
    return langchain_docs


def process_files(files, model, max_tokens):
    """Load, chunk, and embed an explicit list of input files.

    Same pipeline as process_directory but driven by *files* (a list of
    paths) instead of a directory scan, and doc2vec is called without a
    folder name. Returns the list of langchain-format documents.
    """
    docs = FileReader(input_files=files).load_data()
    print(f'Before chunking {len(docs)} docs')
    new_docs = split_documents(documents=docs, max_tokens=max_tokens)
    print(f'After chunking {len(new_docs)} docs')

    # Convert project Documents to langchain format before vectorizing.
    langchain_docs = list(map(Document.to_langchain_format, new_docs))
    doc2vec(langchain_docs, model)
    return langchain_docs


if __name__ == '__main__':
    # argparse is already imported at the top of the file; the previous
    # duplicate local import has been dropped.
    parser = argparse.ArgumentParser(
        description='Build a Faiss index for documents under a data directory.')
    # nargs='?' makes each positional optional so its default actually applies.
    # (The original used a misspelled `deafult=` keyword — a TypeError — and
    # positional args with `default` but no nargs, which argparse treats as
    # required, so the defaults were never used.)
    parser.add_argument('directory', type=str, nargs='?', default='./data')
    parser.add_argument('folder', type=str, nargs='?', default='半导体')
    parser.add_argument('model', type=str, nargs='?', default='mpnet')
    parser.add_argument('max_tokens', type=int, nargs='?', default=514)
    args = parser.parse_args()

    # Fixed: the original read `args.max_token` (AttributeError); the parser
    # stores the value under `max_tokens`.
    process_directory(args.directory, args.folder, args.model, args.max_tokens)

    # NOTE(review): these hard-coded runs execute unconditionally after the
    # CLI-driven run above — presumably leftover batch jobs; confirm they are
    # still wanted. The exact-duplicate '数字经济' call has been removed.
    process_directory('./data', '半导体', 'mpnet', 514)
    process_directory('./data', '数字经济', 'cohere', 514)
    process_directory('./data', '两会', 'cohere', 514)

    # Index the first three PDFs of the 两会 folder as individual files.
    from glob import glob
    pdf_files = glob('./data/两会/*.pdf')[:3]
    process_files(pdf_files, 'cohere', 514)