File size: 2,705 Bytes
6c945f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*-coding:utf-8 -*-
"""
    Base Reader and Document
"""
import os
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from typing import Any, Dict, List, Optional
from glob import glob
from build_index.parser import ParserFactory
from langchain.docstore.document import Document as LCDocument


@dataclass_json
@dataclass
class Document:
    """Container for a piece of text with optional id, embedding and metadata.

    The class is a ``dataclass`` additionally wrapped by ``dataclass_json``.

    Raises:
        AssertionError: On construction, if ``text`` is ``None``.
    """

    # Text content. The ``None`` default only exists so the field can be
    # declared with a default; ``__post_init__`` rejects ``None``.
    text: str = None
    # Optional identifier of the document.
    doc_id: Optional[str] = None
    # Optional list of floats associated with the document.
    embedding: Optional[List[float]] = None
    # Optional metadata mapping; rendered by ``extra_info_str`` and used as
    # the metadata of the LangChain document.
    extra_info: Optional[Dict[str, Any]] = None

    def get_text(self) -> str:
        """Return the document text."""
        return self.text

    def get_doc_id(self) -> Optional[str]:
        """Return the document id, or ``None`` if it was not set."""
        return self.doc_id

    def get_embedding(self) -> Optional[List[float]]:
        """Return the embedding, or ``None`` if it was not set."""
        return self.embedding

    @property
    def extra_info_str(self) -> Optional[str]:
        """Return ``extra_info`` as ``"key: value"`` lines, or ``None`` if unset."""
        if self.extra_info is None:
            return None

        return "\n".join(f"{k}: {v!s}" for k, v in self.extra_info.items())

    def __post_init__(self):
        """Validate the fields right after the generated ``__init__`` ran."""
        # Field check. Raised explicitly instead of using ``assert`` so the
        # validation still runs under ``python -O``; the exception type and
        # message are the same ones the assert produced.
        if self.text is None:
            raise AssertionError('Text Field can not be None')

    def to_langchain_format(self):
        """Convert struct to LangChain document format.

        Returns:
            An ``LCDocument`` whose ``page_content`` is ``text`` and whose
            ``metadata`` is ``extra_info`` (an empty dict when ``extra_info``
            is ``None`` or empty).
        """
        metadata = self.extra_info or {}
        return LCDocument(page_content=self.text, metadata=metadata)

class FileReader(object):
    """
    Collect a list of input files and parse them into ``Document`` objects.

    The files are either passed explicitly via ``input_files`` or discovered
    by globbing ``<data_dir>/<folder_name>/*`` (one level, not recursive).
    """

    def __init__(self, data_dir=None, folder_name=None, input_files=None, has_meta=True):
        """
        Store the configuration and resolve the list of files to read.

        Args:
            data_dir: Directory to search; required when ``input_files`` is
                empty or ``None``.
            folder_name: Optional sub-folder of ``data_dir``. When ``None``,
                the entries directly inside ``data_dir`` are listed.
            input_files: Explicit list of file paths. When truthy it is used
                as-is and no directory listing happens.
            has_meta: Whether ``load_data`` attaches the parser metadata to
                each document as ``extra_info``.

        Raises:
            TypeError: If neither ``input_files`` nor ``data_dir`` is given.
        """
        self.data_dir = data_dir
        self.has_meta = has_meta

        if input_files:
            self.input_files = input_files
            return

        if data_dir is None:
            # Fail with a readable message instead of the opaque TypeError
            # that os.path.join raises for a None path component.
            raise TypeError('data_dir is required when input_files is not given')

        # get all file in data_dir
        # TODO: recursive directories under data_dir are not supported yet
        parts = [data_dir] if folder_name is None else [data_dir, folder_name]
        pattern = os.path.join(*parts, '*')
        self.input_files = glob(pattern)
        print(f'{len(self.input_files)} files in {pattern}')
        print(self.input_files)

    def load_data(self, concatenate=False) -> List["Document"]:
        """
        Parse every input file and wrap the results as documents.

        A file whose parsing raises is reported on stdout and skipped.

        Args:
            concatenate: When True, the parsed texts are joined with newlines
                into one single document and the metadata is not attached.

        Returns:
            A list of ``Document``: a single joined document when
            ``concatenate`` is True; otherwise one document per successfully
            parsed file, carrying the parser metadata as ``extra_info`` when
            ``has_meta`` is True.

        Raises:
            ValueError: If the parser lookup yields ``None``.
        """
        data_list = []
        metadata_list = []
        for file in self.input_files:
            # NOTE(review): the parser is always looked up under 'pdf',
            # whatever the file's suffix is - confirm whether suffix-based
            # dispatch was intended here.
            parser = ParserFactory['pdf']
            if parser is None:
                raise ValueError(f"{file} format doesn't match any sufix supported")
            try:
                data, meta = parser.parse_file(file)
            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                print(f'{file} parse failed. error = {e}')
                continue
            data_list.append(data)
            if self.has_meta:
                metadata_list.append(meta)

        if concatenate:
            return [Document("\n".join(data_list))]
        if self.has_meta:
            return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
        return [Document(d) for d in data_list]