from modules.chat.helpers import get_prompt
from modules.chat.chat_model_loader import ChatModelLoader
from modules.vectorstore.store_manager import VectorStoreManager
from modules.retriever.retriever import Retriever
from modules.chat.langchain.langchain_rag import (
    Langchain_RAG_V2,
    QuestionGenerator,
)


class LLMTutor:
    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Args:
            config (dict): Configuration dictionary.
            user (str): User identifier.
            logger (Logger, optional): Logger instance. Defaults to None.
        """
        self.config = config
        self.llm = self.load_llm()
        self.user = user
        self.logger = logger
        self.vector_db = VectorStoreManager(config, logger=self.logger).load_database()
        self.qa_prompt = get_prompt(config, "qa")
        self.rephrase_prompt = get_prompt(config, "rephrase")
|
    def update_llm(self, old_config, new_config):
        """
        Update the LLM, vector store, and QA prompt when the relevant
        configuration keys change.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.
        """
        changes = self.get_config_changes(old_config, new_config)
        # Adopt the new configuration so the reloads below pick it up.
        self.config = new_config

        if "llm_params.llm_loader" in changes:
            self.llm = self.load_llm()

        if "vectorstore.db_option" in changes:
            self.vector_db = VectorStoreManager(
                self.config, logger=self.logger
            ).load_database()

        if "llm_params.llm_style" in changes:
            self.qa_prompt = get_prompt(self.config, "qa")
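
    # Illustrative update call (hypothetical keys and values): if only
    # "llm_params.llm_loader" changed, just the LLM is reloaded, e.g.
    #   new_config = copy.deepcopy(tutor.config)
    #   new_config["llm_params"]["llm_loader"] = "<other-model>"
    #   tutor.update_llm(old_config=tutor.config, new_config=new_config)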
|
    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Mapping of dot-separated key paths to
                (old_value, new_value) tuples.
        """
        changes = {}

        def compare_dicts(old, new, parent_key=""):
            # Walk the new config, recursing into nested dicts and recording
            # any leaf whose value differs from the old config.
            for key in new:
                full_key = f"{parent_key}.{key}" if parent_key else key
                if isinstance(new[key], dict) and isinstance(old.get(key), dict):
                    compare_dicts(old.get(key, {}), new[key], full_key)
                elif old.get(key) != new[key]:
                    changes[full_key] = (old.get(key), new[key])

            # Keys present in the old config but removed from the new one.
            for key in old:
                if key not in new:
                    full_key = f"{parent_key}.{key}" if parent_key else key
                    changes[full_key] = (old[key], None)

        compare_dicts(old_config, new_config)
        return changes
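
    # Illustrative result (hypothetical values): changed leaves are reported
    # under dot-separated key paths as (old, new) tuples, e.g.
    #   get_config_changes({"llm_params": {"llm_loader": "a"}},
    #                      {"llm_params": {"llm_loader": "b"}})
    #   == {"llm_params.llm_loader": ("a", "b")}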
|
    def retrieval_qa_chain(
        self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
    ):
        """
        Create a Retrieval QA Chain.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (list, optional): Callback handlers for the chain.
                Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.
        """
        # Build the retriever for the given vector store.
        retriever = Retriever(self.config)._return_retriever(db)

        if self.config["llm_params"]["llm_arch"] == "langchain":
            self.qa_chain = Langchain_RAG_V2(
                llm=llm,
                memory=memory,
                retriever=retriever,
                qa_prompt=qa_prompt,
                rephrase_prompt=rephrase_prompt,
                config=self.config,
                callbacks=callbacks,
            )
            self.question_generator = QuestionGenerator()
        else:
            raise ValueError(
                f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
            )
        return self.qa_chain
|
    def load_llm(self):
        """
        Load the language model.

        Returns:
            LLM: The loaded language model instance.
        """
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm
|
    def qa_bot(self, memory=None, callbacks=None):
        """
        Create a QA bot instance.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (list, optional): Callback handlers for the chain.
                Defaults to None.

        Returns:
            Chain: The QA bot chain instance.
        """
        if len(self.vector_db) == 0:
            raise ValueError(
                "No documents in the database. Populate the database first."
            )

        qa = self.retrieval_qa_chain(
            self.llm,
            self.qa_prompt,
            self.rephrase_prompt,
            self.vector_db,
            memory,
            callbacks=callbacks,
        )
        return qa
|
|
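if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. The config
    # values below are hypothetical placeholders; real values must match what
    # ChatModelLoader, VectorStoreManager, and get_prompt expect, and the
    # vector store must already be populated before qa_bot() is called.
    import logging

    logging.basicConfig(level=logging.INFO)
    config = {
        "llm_params": {
            "llm_arch": "langchain",  # the only architecture qa_bot supports
            "llm_loader": "<model-name>",  # hypothetical placeholder
            "llm_style": "<style>",  # hypothetical placeholder
        },
        "vectorstore": {"db_option": "<db-backend>"},  # hypothetical placeholder
    }

    tutor = LLMTutor(config, user="demo_user", logger=logging.getLogger(__name__))
    qa_chain = tutor.qa_bot()  # raises ValueError if the vector store is empty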