|
from langchain_core.prompts import ChatPromptTemplate |
|
|
|
from modules.chat.langchain.utils import * |
|
from langchain.memory import ChatMessageHistory |
|
from modules.chat.base import BaseRAG |
|
|
|
|
|
class Langchain_RAG(BaseRAG):
    """
    Conversational RAG chain with per-session chat history.

    The constructor wires together three stages: a history-aware retriever
    (rephrases the question, then retrieves), a stuff-documents QA chain, and
    a history wrapper that looks up chat history through
    ``get_session_history``.
    """

    def __init__(self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str):
        """
        Initialize the Langchain_RAG class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory (list): Prior conversation turns. Each item is indexed as
                ``item[0]`` (user message) and ``item[1]`` (AI message); the
                list is converted into a chat history via
                ``add_history_from_list``. ``None`` means no prior turns.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt string. A built-in default is used
                when this is empty or None.
            rephrase_prompt (str): The rephrase prompt string. A built-in
                default is used when this is empty or None.
        """
        self.llm = llm
        self.memory = self.add_history_from_list(memory)
        self.retriever = retriever
        # Keep the caller-supplied strings as-is (possibly empty/None);
        # `invoke` echoes them back in its result.
        self.qa_prompt = qa_prompt
        self.rephrase_prompt = rephrase_prompt
        # Session histories keyed by (user_id, conversation_id).
        self.store = {}

        # --- Stage 1: rephrase the question using history, then retrieve ---
        # NOTE(review): this fallback text has no `{input}` placeholder.
        # LangChain's stock `create_history_aware_retriever` requires an
        # `input` prompt variable -- confirm how the helper imported from
        # modules.chat.langchain.utils behaves when `rephrase_prompt` is empty.
        contextualize_q_system_prompt = rephrase_prompt or (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, just "
            "reformulate it if needed and otherwise return it as is."
        )
        self.contextualize_q_prompt = ChatPromptTemplate.from_template(
            contextualize_q_system_prompt
        )
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )

        # --- Stage 2: answer from the retrieved documents ---
        # NOTE(review): the fallback only exposes `{context}`; the user's
        # question is not referenced by this template -- verify this is
        # intended.
        qa_system_prompt = qa_prompt or (
            "You are an assistant for question-answering tasks. Use "
            "the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise."
            "\n\n"
            "{context}"
        )
        self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt_template
        )

        # --- Stage 3: retrieval + QA, wrapped with session history ---
        self.rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )

        # The field ids below match the parameter names of
        # `get_session_history`; presumably the wrapper forwards them from the
        # run config by name -- TODO confirm in CustomRunnableWithHistory.
        self.rag_chain = CustomRunnableWithHistory(
            self.rag_chain,
            get_session_history=self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
            history_factory_config=[
                ConfigurableFieldSpec(
                    id="user_id",
                    annotation=str,
                    name="User ID",
                    description="Unique identifier for the user.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="conversation_id",
                    annotation=str,
                    name="Conversation ID",
                    description="Unique identifier for the conversation.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="memory_window",
                    annotation=int,
                    name="Number of Conversations",
                    description="Number of conversations to consider for context.",
                    default=1,
                    is_shared=True,
                ),
            ],
        )

    def get_session_history(
        self, user_id: str, conversation_id: str, memory_window: int
    ) -> BaseChatMessageHistory:
        """
        Get the session history for a user and conversation.

        On first access for a given (user_id, conversation_id) pair, a new
        in-memory history is created and seeded with the messages passed to
        the constructor. Later calls return the same stored object.

        Args:
            user_id (str): The user identifier.
            conversation_id (str): The conversation identifier.
            memory_window (int): The number of conversations to consider for
                context. Accepted for interface compatibility; it is not used
                inside this method.

        Returns:
            BaseChatMessageHistory: The chat message history.
        """
        session_key = (user_id, conversation_id)
        if session_key not in self.store:
            self.store[session_key] = InMemoryHistory()
            self.store[session_key].add_messages(self.memory.messages)
        return self.store[session_key]

    def invoke(self, user_query, config):
        """
        Invoke the chain.

        Args:
            user_query: The input passed through to the underlying chain.
            config: The run configuration passed through to the underlying
                chain.

        Returns:
            dict: The chain output, with the constructor's ``rephrase_prompt``
            and ``qa_prompt`` strings added under the keys of the same names.
        """
        res = self.rag_chain.invoke(user_query, config)
        res["rephrase_prompt"] = self.rephrase_prompt
        res["qa_prompt"] = self.qa_prompt
        return res

    def stream(self, user_query, config):
        """
        Stream the chain output.

        Args:
            user_query: The input passed through to the underlying chain.
            config: The run configuration passed through to the underlying
                chain.

        Returns:
            Whatever the underlying chain's ``stream`` returns, unmodified.
        """
        return self.rag_chain.stream(user_query, config)

    def add_history_from_list(self, history_list):
        """
        Build a chat history from a list of message pairs.

        Args:
            history_list (list): Conversation turns, oldest first. Each item
                is indexed as ``item[0]`` (user message) and ``item[1]``
                (AI message). ``None`` is treated as no prior turns.

        Returns:
            ChatMessageHistory: A new history containing the messages in
            order (user message followed by AI message for each pair).
        """
        history = ChatMessageHistory()

        # No prior conversation supplied: return an empty history instead of
        # failing on iteration.
        if history_list is None:
            return history

        for message_pair in history_list:
            history.add_user_message(message_pair[0])
            history.add_ai_message(message_pair[1])

        return history
|
|