File size: 6,035 Bytes
ae33464
60929fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b69151
 
60929fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5aac56
60929fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae33464
60929fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from pydantic import BaseModel, conint, confloat, HttpUrl
from typing import Optional, List
import yaml


class FaissParams(BaseModel):
    """FAISS index settings, nested under ``VectorStoreConfig.faiss_params``."""

    index_path: str = "vectorstores/faiss.index"  # path of the saved index file
    index_type: str = "Flat"  # Options: [Flat, HNSW, IVF]
    # Presumably the embedding vector size; 384 appears chosen to match the
    # default "sentence-transformers/all-MiniLM-L6-v2" model — confirm if the
    # embedding model changes.
    index_dimension: conint(gt=0) = 384
    # Presumably FAISS IVF settings (nlist = number of clusters, nprobe =
    # clusters searched per query), only relevant when index_type is IVF — confirm.
    index_nlist: conint(gt=0) = 100
    index_nprobe: conint(gt=0) = 10


class ColbertParams(BaseModel):
    """ColBERT index settings, nested under ``VectorStoreConfig.colbert_params``."""

    index_name: str = "new_idx"  # identifier of the ColBERT index


class VectorStoreConfig(BaseModel):
    """Document sources plus the vector-store backend and its retrieval knobs.

    Backend-specific sub-settings (``faiss_params`` / ``colbert_params``) are
    optional and stay ``None`` unless the YAML provides them.
    """

    # Data sources. NOTE(review): "HF" presumably means Hugging Face — confirm.
    load_from_HF: bool = True
    reparse_files: bool = True
    data_path: str = "storage/data"
    url_file_path: str = "storage/data/urls.txt"
    expand_urls: bool = True

    # Backend selection.
    # Options: [FAISS, Chroma, RAGatouille, RAPTOR]
    db_option: str = "RAGatouille"
    db_path: str = "vectorstores"
    # Options: [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002]
    model: str = "sentence-transformers/all-MiniLM-L6-v2"

    # Retrieval: top_k must be > 0; the threshold is validated to lie in [0, 1].
    search_top_k: conint(gt=0) = 3
    score_threshold: confloat(ge=0.0, le=1.0) = 0.2

    # Optional backend-specific settings.
    faiss_params: Optional[FaissParams] = None
    colbert_params: Optional[ColbertParams] = None


class OpenAIParams(BaseModel):
    """OpenAI model settings, nested under ``LLMParams.openai_params``."""

    temperature: confloat(ge=0.0, le=1.0) = 0.7  # validated to lie in [0, 1]


class LocalLLMParams(BaseModel):
    """Local GGUF model settings: its Hugging Face location and on-disk path."""

    temperature: confloat(ge=0.0, le=1.0) = 0.7
    # HuggingFace repo id
    repo_id: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    # Specific name of the gguf file inside ``repo_id``
    filename: str = "tinyllama-1.1b-chat-v1.0.Q5_0.gguf"
    # Path to the model file on local disk
    model_path: str = "storage/models/tinyllama-1.1b-chat-v1.0.Q5_0.gguf"


class LLMParams(BaseModel):
    """Which language model to use and how the chat pipeline drives it.

    Model-specific settings live in ``openai_params`` / ``local_llm_params``;
    both default to ``None``.
    """

    llm_arch: str = "langchain"  # Options: [langchain]
    use_history: bool = True
    generate_follow_up: bool = False
    memory_window: conint(ge=1) = 3  # validated to be at least 1
    llm_style: str = "Normal"  # Options: [Normal, ELI5]
    # Options: [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini]
    llm_loader: str = "gpt-4o-mini"
    openai_params: Optional[OpenAIParams] = None
    local_llm_params: Optional[LocalLLMParams] = None
    stream: bool = False
    pdf_reader: str = "gpt"  # Options: [llama, pymupdf, gpt]


class ChatLoggingConfig(BaseModel):
    """Chat-logging settings."""

    log_chat: bool = True
    platform: str = "literalai"  # presumably the logging backend name — confirm
    callbacks: bool = True


class SplitterOptions(BaseModel):
    """How source documents are cut into chunks before indexing.

    ``chunk_size`` must be positive; ``chunk_overlap`` and the optional
    chunk-removal counts must be non-negative.
    """

    use_splitter: bool = True
    split_by_token: bool = True
    remove_leftover_delimiters: bool = True
    remove_chunks: bool = False
    chunking_mode: str = "semantic"  # Options: [fixed, semantic]
    chunk_size: conint(gt=0) = 300
    chunk_overlap: conint(ge=0) = 30
    chunk_separators: list[str] = ["\n\n", "\n", " ", ""]
    # Presumably only consulted when remove_chunks is True — confirm.
    front_chunks_to_remove: Optional[conint(ge=0)] = None
    last_chunks_to_remove: Optional[conint(ge=0)] = None
    delimiters_to_remove: list[str] = ["\t", "\n", "   ", "  "]


class RetrieverConfig(BaseModel):
    """Retriever settings."""

    # Keyed by a backend name (the default key matches a VectorStoreConfig.db_option
    # value); values look like Hugging Face repo ids — presumably where a prebuilt
    # index is stored. Confirm against the loader.
    retriever_hf_paths: dict[str, str] = {"RAGatouille": "XThomasBU/Colbert_Index"}


class MetadataConfig(BaseModel):
    """Web pages and base URL associated with course metadata.

    NOTE(review): pydantic does not validate defaults unless configured to, so the
    defaults below presumably remain plain ``str`` while YAML-supplied values are
    parsed as URL objects — confirm callers handle both.
    """

    metadata_links: list[HttpUrl] = [
        "https://dl4ds.github.io/sp2024/lectures/",
        "https://dl4ds.github.io/sp2024/schedule/",
    ]
    slide_base_link: HttpUrl = "https://dl4ds.github.io"


class TokenConfig(BaseModel):
    """Token-budget settings; every value is validated to be strictly positive."""

    cooldown_time: conint(gt=0) = 60  # presumably seconds — confirm
    regen_time: conint(gt=0) = 180  # presumably seconds — confirm
    tokens_left: conint(gt=0) = 2000
    all_time_tokens_allocated: conint(gt=0) = 1000000


class MiscConfig(BaseModel):
    """Project links (source repository and documentation site)."""

    github_repo: HttpUrl = "https://github.com/edubotics-ai/edubot-core"
    docs_website: HttpUrl = "https://dl4ds.github.io/dl4ds_tutor/"


class APIConfig(BaseModel):
    """API settings."""

    timeout: conint(gt=0) = 60  # must be > 0; presumably seconds — confirm


class Config(BaseModel):
    """Root configuration model built from the merged YAML files.

    The three scalar fields have defaults. Every nested section below has no
    default, so pydantic rejects a merged config that omits any of them.
    """

    log_dir: str = "storage/logs"
    log_chunk_dir: str = "storage/logs/chunks"
    device: str = "cpu"  # Options: ['cuda', 'cpu']

    # Required sections.
    vectorstore: VectorStoreConfig
    llm_params: LLMParams
    chat_logging: ChatLoggingConfig
    splitter_options: SplitterOptions
    retriever: RetrieverConfig
    metadata: MetadataConfig
    token_config: TokenConfig
    misc: MiscConfig
    api_config: APIConfig


class ConfigManager:
    """Load two YAML files, merge them and validate the result as a ``Config``.

    The base file (``config_path``) is overridden by the project file
    (``project_config_path``). Nested sections are merged key by key, so the
    project file only needs to list the keys it changes.
    """

    def __init__(self, config_path: str, project_config_path: str):
        """Read, merge and validate both files immediately.

        Raises:
            FileNotFoundError: if either file does not exist.
            TypeError: if a file's top level is not a mapping.
            pydantic validation error: if the merged data does not fit ``Config``.
        """
        self.config_path = config_path
        self.project_config_path = project_config_path
        self.config = self.load_config()
        self.validate_config()

    @staticmethod
    def _read_yaml(path: str) -> dict:
        """Return the top-level mapping stored in *path*; an empty file yields ``{}``."""
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document; treat that as "no overrides".
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path!r}, "
                f"got {type(data).__name__}"
            )
        return data

    @staticmethod
    def _deep_merge(base: dict, override: dict) -> dict:
        """Return a new dict with *override* merged recursively onto *base*.

        When both sides hold a dict for the same key, the two are merged key by
        key. Any other value in *override* (including lists and None) replaces
        the base value. Neither input is mutated.
        """
        merged = dict(base)
        for key, value in override.items():
            current = merged.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                merged[key] = ConfigManager._deep_merge(current, value)
            else:
                merged[key] = value
        return merged

    def load_config(self) -> "Config":
        """Read both YAML files, deep-merge them (project wins) and build a ``Config``."""
        config_data = self._read_yaml(self.config_path)
        project_config_data = self._read_yaml(self.project_config_path)

        # A shallow {**a, **b} would drop every base key of any section the
        # project file also defines; merge nested sections instead.
        merged_config = self._deep_merge(config_data, project_config_data)

        return Config(**merged_config)

    def get_config(self) -> "ConfigWrapper":
        """Return the loaded config wrapped for attribute and ``[]`` access."""
        return ConfigWrapper(self.config)

    def validate_config(self):
        """Hook for cross-field checks beyond what the pydantic models enforce.

        Required sections and field types are already validated when ``Config``
        is constructed in ``load_config`` (the nested sections have no defaults),
        so there is currently nothing extra to check here.
        """


class ConfigWrapper:
    """Thin wrapper around a ``Config`` allowing both attribute and item access.

    ``wrapper.vectorstore`` and ``wrapper["vectorstore"]`` both return
    ``config.vectorstore``. Any other attribute lookup is forwarded to the
    wrapped config as well.
    """

    def __init__(self, config: "Config"):
        self._config = config

    def __getitem__(self, key):
        """Return attribute *key* of the wrapped config (``wrapper["misc"]``)."""
        return getattr(self._config, key)

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the wrapped config."""
        # __getattr__ only runs when normal lookup fails. If ``_config`` itself is
        # missing (e.g. an instance built by copy/pickle before its __dict__ is
        # restored), looking it up here would recurse forever; fail cleanly instead.
        if name == "_config":
            raise AttributeError(name)
        return getattr(self._config, name)

    def dict(self):
        """Return the wrapped config as a plain nested dictionary."""
        # pydantic v2 deprecates ``dict()`` in favour of ``model_dump()``; use the
        # latter when available and fall back to ``dict()`` for v1 models.
        model_dump = getattr(self._config, "model_dump", None)
        if callable(model_dump):
            return model_dump()
        return self._config.dict()


# Module-level instance, created on import. ConfigManager reads both YAML files
# in its constructor, so a missing file raises as soon as this module is imported.
config_manager = ConfigManager(
    "config/config.yml",
    "config/project_config.yml",
)
# Example: config = config_manager.get_config().dict()