# tutor_dev/apps/chainlit_base/config/config_manager.py
# Commit ae33464 ("updates") by XThomasBU; file size 5.61 kB.
from pydantic import BaseModel, conint, confloat, HttpUrl
from typing import Optional, List
import yaml
class FaissParams(BaseModel):
    """FAISS index settings, nested under ``VectorStoreConfig.faiss_params``."""

    index_path: str = "vectorstores/faiss.index"
    index_type: str = "Flat"  # Options: [Flat, HNSW, IVF]
    # Validated > 0. NOTE(review): presumably must equal the output size of the
    # embedding model in VectorStoreConfig.model (384 for all-MiniLM-L6-v2) —
    # confirm the two are kept in sync.
    index_dimension: conint(gt=0) = 384
    # Validated > 0; presumably only relevant for the IVF index type — confirm.
    index_nlist: conint(gt=0) = 100
    # Validated > 0; presumably only relevant for the IVF index type — confirm.
    index_nprobe: conint(gt=0) = 10
class ColbertParams(BaseModel):
    """ColBERT index settings, nested under ``VectorStoreConfig.colbert_params``."""

    index_name: str = "new_idx"
class VectorStoreConfig(BaseModel):
    """Vector store settings (the ``vectorstore`` section of ``Config``)."""

    # Presumably fetches a prebuilt index from Hugging Face using
    # RetrieverConfig.retriever_hf_paths — confirm against the loader.
    load_from_HF: bool = True
    reparse_files: bool = True
    data_path: str = "storage/data"
    url_file_path: str = "storage/data/urls.txt"
    expand_urls: bool = True
    db_option: str = "RAGatouille"  # Options: [FAISS, Chroma, RAGatouille, RAPTOR]
    db_path: str = "vectorstores"
    # Embedding model identifier (both listed options are embedding models).
    model: str = (
        "sentence-transformers/all-MiniLM-L6-v2"  # Options: [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002]
    )
    search_top_k: conint(gt=0) = 3  # validated > 0
    score_threshold: confloat(ge=0.0, le=1.0) = 0.2  # validated to [0.0, 1.0]
    # Backend-specific sub-settings; None unless supplied in the YAML.
    # Presumably only the one matching db_option is read — confirm.
    faiss_params: Optional[FaissParams] = None
    colbert_params: Optional[ColbertParams] = None
class OpenAIParams(BaseModel):
    """Sampling settings for OpenAI chat models (``LLMParams.openai_params``)."""

    # The OpenAI API accepts sampling temperatures from 0 to 2, so the upper
    # bound allows the full documented range (every value valid before still is).
    temperature: confloat(ge=0.0, le=2.0) = 0.7
class LocalLLMParams(BaseModel):
    """Settings for a local GGUF model (``LLMParams.local_llm_params``).

    Presumably used when ``LLMParams.llm_loader`` is ``"local_llm"`` — confirm.
    """

    temperature: confloat(ge=0.0, le=1.0) = 0.7  # validated to [0.0, 1.0]
    repo_id: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"  # HuggingFace repo id
    filename: str = (
        "tinyllama-1.1b-chat-v1.0.Q5_0.gguf"  # Specific name of gguf file in the repo
    )
    # NOTE(review): under pydantic v2 a field named ``model_path`` collides with
    # the protected ``model_`` namespace and emits a UserWarning at class
    # creation; ``protected_namespaces=()`` in the model config would silence
    # it — confirm the pydantic major version before changing anything.
    model_path: str = (
        "storage/models/tinyllama-1.1b-chat-v1.0.Q5_0.gguf"  # Path to the model file
    )
class LLMParams(BaseModel):
    """Language-model settings (the ``llm_params`` section of ``Config``)."""

    llm_arch: str = "langchain"  # Options: [langchain]
    use_history: bool = True
    generate_follow_up: bool = False
    # Validated >= 1; presumably the number of past turns kept — confirm.
    memory_window: conint(ge=1) = 3
    llm_style: str = "Normal"  # Options: [Normal, ELI5]
    llm_loader: str = (
        "gpt-4o-mini"  # Options: [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini]
    )
    # Model-family-specific settings; None unless supplied in the YAML.
    # Presumably only the one matching llm_loader is read — confirm.
    openai_params: Optional[OpenAIParams] = None
    local_llm_params: Optional[LocalLLMParams] = None
    stream: bool = False
    pdf_reader: str = "gpt"  # Options: [llama, pymupdf, gpt]
class ChatLoggingConfig(BaseModel):
    """Chat logging settings (the ``chat_logging`` section of ``Config``)."""

    log_chat: bool = True
    # Presumably names the logging backend (Literal AI by default) — confirm.
    platform: str = "literalai"
    callbacks: bool = True
class SplitterOptions(BaseModel):
    """Document chunking settings (the ``splitter_options`` section of ``Config``)."""

    use_splitter: bool = True
    split_by_token: bool = True
    remove_leftover_delimiters: bool = True
    remove_chunks: bool = False
    chunking_mode: str = "semantic"  # Options: [fixed, semantic]
    chunk_size: conint(gt=0) = 300  # validated > 0
    # Validated >= 0. NOTE(review): nothing here enforces chunk_overlap < chunk_size.
    chunk_overlap: conint(ge=0) = 30
    chunk_separators: List[str] = ["\n\n", "\n", " ", ""]
    # None means "not set"; presumably only read when remove_chunks is True — confirm.
    front_chunks_to_remove: Optional[conint(ge=0)] = None
    last_chunks_to_remove: Optional[conint(ge=0)] = None
    # NOTE(review): " " is listed twice; the intent was possibly distinct runs of
    # spaces (e.g. "  "), lost in a whitespace-collapsing copy — confirm.
    delimiters_to_remove: List[str] = ["\t", "\n", " ", " "]
class RetrieverConfig(BaseModel):
    """Retriever settings (the ``retriever`` section of ``Config``)."""

    # Maps a db_option name to a Hugging Face repo id; presumably the source of
    # the prebuilt index when VectorStoreConfig.load_from_HF is True — confirm.
    retriever_hf_paths: dict[str, str] = {"RAGatouille": "XThomasBU/Colbert_Index"}
class MetadataConfig(BaseModel):
    """Course-site links (the ``metadata`` section of ``Config``)."""

    # NOTE(review): pydantic does not validate default values unless asked to,
    # so these literals presumably stay plain ``str`` while YAML-supplied values
    # become ``HttpUrl`` objects — confirm consumers handle both. The defaults
    # also hard-code the sp2024 semester.
    metadata_links: List[HttpUrl] = [
        "https://dl4ds.github.io/sp2024/lectures/",
        "https://dl4ds.github.io/sp2024/schedule/",
    ]
    slide_base_link: HttpUrl = "https://dl4ds.github.io"
class APIConfig(BaseModel):
    """API client settings (the ``api_config`` section of ``Config``)."""

    timeout: conint(gt=0) = 60  # validated > 0; presumably seconds — confirm
class Config(BaseModel):
    """Root configuration model, built by ``ConfigManager.load_config`` from merged YAML.

    The nested sections below have no defaults, so each must be present in the
    merged YAML or pydantic raises a validation error. Every field inside those
    sections has a default, so an empty mapping (e.g. ``vectorstore: {}``) is
    enough to satisfy a section.
    """

    log_dir: str = "storage/logs"
    log_chunk_dir: str = "storage/logs/chunks"
    device: str = "cpu"  # Options: ['cuda', 'cpu']
    vectorstore: VectorStoreConfig
    llm_params: LLMParams
    chat_logging: ChatLoggingConfig
    splitter_options: SplitterOptions
    retriever: RetrieverConfig
    metadata: MetadataConfig
    api_config: APIConfig
class ConfigManager:
    """Load two YAML files, merge them, and validate the result as a ``Config``.

    Args:
        config_path: Path to the base YAML configuration file.
        project_config_path: Path to the project YAML file; its top-level keys
            override the base file's (see ``load_config``).

    Raises:
        OSError: If either file cannot be opened.
        yaml.YAMLError: If either file is not valid YAML.
        TypeError: If a file's top-level YAML value is not a mapping.
        pydantic.ValidationError: If the merged data does not match ``Config``.
    """

    def __init__(self, config_path: str, project_config_path: str):
        self.config_path = config_path
        self.project_config_path = project_config_path
        self.config = self.load_config()
        self.validate_config()

    @staticmethod
    def _read_yaml(path: str) -> dict:
        """Return the top-level mapping of the YAML file at *path*.

        An empty (or comments-only) file makes ``yaml.safe_load`` return
        ``None``; that is normalized to ``{}`` so the file can still be merged.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise TypeError(
                f"{path}: expected a mapping at the top level of the YAML file, "
                f"got {type(data).__name__}"
            )
        return data

    def load_config(self) -> Config:
        """Read both files, merge them, and return the validated ``Config``.

        The merge is shallow: a top-level section defined in the project file
        replaces the base file's section entirely; nested keys are not combined.
        """
        config_data = self._read_yaml(self.config_path)
        project_config_data = self._read_yaml(self.project_config_path)
        merged_config = {**config_data, **project_config_data}
        return Config(**merged_config)

    def get_config(self) -> "ConfigWrapper":
        """Return the loaded config wrapped for ``cfg.key`` / ``cfg["key"]`` access."""
        return ConfigWrapper(self.config)

    def validate_config(self):
        """Run extra checks on ``self.config``; currently a no-op.

        Required sections and field types/constraints are already enforced by
        pydantic when ``load_config`` constructs ``Config``.
        """
class ConfigWrapper:
    """Thin facade over a ``Config`` allowing both ``cfg.key`` and ``cfg["key"]``.

    Both access styles look the name up as an attribute of the wrapped object,
    so a missing key raises ``AttributeError`` either way.
    """

    def __init__(self, config: "Config"):
        self._config = config

    def __getitem__(self, key):
        """Return attribute *key* of the wrapped config (e.g. ``cfg["llm_params"]``)."""
        return getattr(self._config, key)

    def __getattr__(self, name):
        """Delegate attribute lookups that fail on the wrapper to the wrapped config.

        Python calls this only after normal lookup fails. ``_config`` is read
        straight from the instance ``__dict__`` so that an instance lacking it
        (e.g. one created via ``__new__`` by ``copy``/``pickle``) raises
        ``AttributeError`` instead of recursing until ``RecursionError``.
        """
        try:
            config = self.__dict__["_config"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(config, name)

    def dict(self):
        """Return the wrapped config as a plain ``dict``.

        Prefers pydantic v2's ``model_dump`` (``.dict()`` is deprecated there and
        delegates to it) and falls back to ``.dict()`` for objects without it.
        """
        model_dump = getattr(self._config, "model_dump", None)
        if model_dump is not None:
            return model_dump()
        return self._config.dict()
# Module-level manager, built as soon as this module is imported. The paths are
# relative, so they resolve against the current working directory.
_CONFIG_PATH = "config/config.yml"
_PROJECT_CONFIG_PATH = "config/project_config.yml"

config_manager = ConfigManager(
    config_path=_CONFIG_PATH,
    project_config_path=_PROJECT_CONFIG_PATH,
)

# Example: config = config_manager.get_config().dict()