|
from modules.chat.helpers import get_prompt |
|
from modules.chat.chat_model_loader import ChatModelLoader |
|
from modules.vectorstore.store_manager import VectorStoreManager |
|
from modules.retriever.retriever import Retriever |
|
from modules.chat.langchain.langchain_rag import CustomConversationalRetrievalChain |
|
|
|
|
|
class LLMTutor:
    """
    Bundle a chat model, a vector store manager and a retrieval chain.

    The instance keeps the active configuration dictionary, the chat model
    returned by ``ChatModelLoader`` and a ``VectorStoreManager``. ``qa_bot``
    combines them into a ``CustomConversationalRetrievalChain``.
    """

    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Args:
            config (dict): Configuration dictionary. This class reads
                ``config["vectorstore"]["embedd_files"]`` and
                ``config["llm_params"]["use_history"]``; the whole dictionary
                is also handed to the loader, store manager and retriever.
            user (str): User identifier. Stored on the instance as-is.
            logger (Logger, optional): Logger instance passed to the
                VectorStoreManager. Defaults to None.
        """
        self.config = config
        self.user = user
        self.logger = logger
        self.llm = self.load_llm()
        self._init_vector_db()

    def _init_vector_db(self):
        """
        (Re)create ``self.vector_db`` from the current configuration.

        The manager is assigned to ``self.vector_db`` before any database
        call, so a failure in ``create_database`` / ``save_database`` still
        leaves the new manager in place.
        """
        self.vector_db = VectorStoreManager(self.config, logger=self.logger)
        # Only (re)build and save the database when the config asks for it.
        if self.config["vectorstore"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def update_llm(self, new_config):
        """
        Update the LLM and VectorStoreManager based on new configuration.

        Only the parts whose top-level config section changed are rebuilt:
        ``"chat_model"`` triggers a model reload, ``"vectorstore"`` triggers
        a new VectorStoreManager.

        Args:
            new_config (dict): New configuration dictionary. It replaces
                ``self.config`` before anything is reloaded.
        """
        changes = self.get_config_changes(self.config, new_config)
        self.config = new_config

        if "chat_model" in changes:
            self.llm = self.load_llm()

        if "vectorstore" in changes:
            self._init_vector_db()

    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        Only keys present in ``new_config`` are inspected; a key that exists
        solely in ``old_config`` is not reported. A key missing from
        ``old_config`` is treated as having the old value ``None``.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Maps each changed key to an ``(old_value, new_value)`` tuple.
        """
        return {
            key: (old_config.get(key), new_value)
            for key, new_value in new_config.items()
            if old_config.get(key) != new_value
        }

    def retrieval_qa_chain(self, llm, qa_prompt, rephrase_prompt, db, memory=None):
        """
        Create a Retrieval QA Chain.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.

        Raises:
            NotImplementedError: If ``config["llm_params"]["use_history"]`` is
                falsy. No chain is defined for that mode.
        """
        # Check the flag before building the retriever so an unsupported
        # configuration fails immediately with a clear message.
        if not self.config["llm_params"]["use_history"]:
            raise NotImplementedError(
                "retrieval_qa_chain requires config['llm_params']['use_history'] "
                "to be enabled; no chain is defined for the history-less mode."
            )

        retriever = Retriever(self.config)._return_retriever(db)
        return CustomConversationalRetrievalChain(
            llm=llm,
            memory=memory,
            retriever=retriever,
            qa_prompt=qa_prompt,
            rephrase_prompt=rephrase_prompt,
        )

    def load_llm(self):
        """
        Load the language model.

        Returns:
            LLM: The chat model returned by
            ``ChatModelLoader(self.config).load_chat_model()``.
        """
        return ChatModelLoader(self.config).load_chat_model()

    def qa_bot(self, memory=None, qa_prompt=None, rephrase_prompt=None):
        """
        Create a QA bot instance.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            qa_prompt (str, optional): QA prompt string. When None, it is
                taken from ``get_prompt(self.config, "qa")``.
            rephrase_prompt (str, optional): Rephrase prompt string. When
                None, it is taken from ``get_prompt(self.config, "rephrase")``.

        Returns:
            Chain: The QA bot chain instance.

        Raises:
            ValueError: If the loaded database reports a length of zero.
        """
        if qa_prompt is None:
            qa_prompt = get_prompt(self.config, "qa")
        if rephrase_prompt is None:
            rephrase_prompt = get_prompt(self.config, "rephrase")

        db = self.vector_db.load_database()
        if len(db) == 0:
            raise ValueError(
                "No documents in the database. Populate the database first."
            )

        return self.retrieval_qa_chain(
            self.llm, qa_prompt, rephrase_prompt, db, memory
        )

    def final_result(self, query):
        """
        Get the final result for a given query.

        Builds a chain with ``qa_bot()`` using its defaults and calls it
        with ``{"query": query}``.

        Args:
            query (str): The query string.

        Returns:
            The value returned by the chain call (documented as a response
            string in the original; the exact type depends on the chain).
        """
        qa_result = self.qa_bot()
        return qa_result({"query": query})
|
|