# File: tutor_dev/modules/chat/langchain/langchain_rag.py
# Last change: "initial commit" (d92c997) by XThomasBU; file size 9.38 kB.
from langchain_core.prompts import ChatPromptTemplate
# from modules.chat.langchain.utils import
from langchain_community.chat_message_histories import ChatMessageHistory
from modules.chat.base import BaseRAG
from langchain_core.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from .utils import (
CustomConversationalRetrievalChain,
create_history_aware_retriever,
create_stuff_documents_chain,
create_retrieval_chain,
return_questions,
CustomRunnableWithHistory,
BaseChatMessageHistory,
InMemoryHistory,
)
class Langchain_RAG_V1(BaseRAG):
    """Conversational RAG built on ``CustomConversationalRetrievalChain``.

    The chain keeps its own windowed conversation memory and answers with a
    "stuff" documents chain driven by the supplied QA prompt.
    """

    def __init__(
        self,
        llm,
        memory,
        retriever,
        qa_prompt: str,
        rephrase_prompt: str,
        config: dict,
        callbacks=None,
    ):
        """
        Build the memory, the QA prompt template and the retrieval chain.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory: Accepted for interface compatibility; it is not read here.
                A fresh ``ConversationBufferWindowMemory`` is created instead.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): Template text for the QA prompt; it is wrapped in
                a ``PromptTemplate`` and stored as ``self.qa_prompt``.
            rephrase_prompt (str): The rephrase prompt string (stored only).
            config (dict): Configuration; ``config["llm_params"]["memory_window"]``
                sets the size ``k`` of the memory window.
            callbacks: Accepted for interface compatibility; not used here.
        """
        self.llm = llm
        self.config = config

        # The incoming ``memory`` is not loaded into the chain: seeding from it
        # is still a TODO (see ``add_history_from_list``).
        window_size = self.config["llm_params"]["memory_window"]
        self.memory = ConversationBufferWindowMemory(
            k=window_size,
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
            # NOTE(review): confirm that ConversationBufferWindowMemory
            # actually honours ``max_token_limit``.
            max_token_limit=128,
        )

        self.retriever = retriever
        self.rephrase_prompt = rephrase_prompt
        self.store = {}

        # The raw template string is replaced by its PromptTemplate wrapper.
        self.qa_prompt = PromptTemplate(
            template=qa_prompt,
            input_variables=["context", "chat_history", "input"],
        )

        chain_options = {
            "llm": llm,
            "chain_type": "stuff",
            "retriever": retriever,
            "return_source_documents": True,
            "memory": self.memory,
            "combine_docs_chain_kwargs": {"prompt": self.qa_prompt},
            "response_if_no_docs_found": "No context found",
        }
        self.rag_chain = CustomConversationalRetrievalChain.from_llm(**chain_options)

    def add_history_from_list(self, history_list):
        """
        TODO: Add messages from a list to the chat history.

        Currently a placeholder: ``history_list`` is ignored and an empty
        list is returned.
        """
        return []

    async def invoke(self, user_query, config):
        """
        Run the retrieval chain on the user's question.

        Args:
            user_query: Mapping whose ``"input"`` entry is passed to the chain.
            config: Accepted for interface compatibility; not used here.

        Returns:
            The result of ``self.rag_chain.acall``.
        """
        question = user_query["input"]
        return await self.rag_chain.acall(question)
class QuestionGenerator:
    """
    Produce questions from the LLM's response, the user's input and the
    past conversation.
    """

    def __init__(self):
        """No state is needed; the generator is a stateless wrapper."""

    def generate_questions(self, query, response, chat_history, context, config):
        """
        Delegate to ``return_questions`` and hand back its result unchanged.

        Args:
            query: The user's input.
            response: The LLM's response.
            chat_history: The past conversation.
            context: Context forwarded to ``return_questions``.
            config: Configuration forwarded to ``return_questions``.

        Returns:
            Whatever ``return_questions`` returns for these arguments.
        """
        return return_questions(query, response, chat_history, context, config)
class Langchain_RAG_V2(BaseRAG):
    """History-aware RAG pipeline assembled from runnable building blocks.

    The pipeline is: contextualize-question prompt -> history-aware retriever
    -> "stuff documents" QA chain -> retrieval chain, wrapped in a
    ``CustomRunnableWithHistory`` that looks up per-conversation history via
    ``get_session_history``.
    """

    def __init__(
        self,
        llm,
        memory,
        retriever,
        qa_prompt: str,
        rephrase_prompt: str,
        config: dict,
        callbacks=None,
    ):
        """
        Initialize the Langchain_RAG class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory (list | None): Previous messages (dicts or objects exposing
                ``type`` and ``content``) used to seed new session histories.
                ``None`` means "no previous messages".
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt string; a built-in default is used
                when it is empty or ``None``.
            rephrase_prompt (str): The rephrase prompt string; a built-in
                default is used when it is empty or ``None``.
            config (dict): Configuration, stored as ``self.config``.
            callbacks: Optional callbacks attached to the final chain.
        """
        self.llm = llm
        # Stored so the instance exposes the same ``config`` attribute as
        # Langchain_RAG_V1 instead of silently dropping the argument.
        self.config = config
        self.memory = self.add_history_from_list(memory)
        self.retriever = retriever
        self.qa_prompt = qa_prompt
        self.rephrase_prompt = rephrase_prompt
        # In-memory map: (user_id, conversation_id) -> chat history.
        self.store = {}

        # Contextualize question prompt
        # NOTE(review): the fallback text below has no `{input}` or
        # `{chat_history}` placeholder, so a template built from it will not
        # include the question or the history -- confirm callers always pass
        # `rephrase_prompt`.
        contextualize_q_system_prompt = rephrase_prompt or (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, just "
            "reformulate it if needed and otherwise return it as is."
        )
        self.contextualize_q_prompt = ChatPromptTemplate.from_template(
            contextualize_q_system_prompt
        )

        # History-aware retriever
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )

        # Answer question prompt
        # NOTE(review): the fallback text only contains `{context}`; it has no
        # `{input}` placeholder -- confirm callers always pass `qa_prompt`.
        qa_system_prompt = qa_prompt or (
            "You are an assistant for question-answering tasks. Use "
            "the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise."
            "\n\n"
            "{context}"
        )
        self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)

        # Question-answer chain
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt_template
        )

        # Final retrieval chain
        self.rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )

        # Wrap the chain so that history is looked up per call from the
        # configurable fields declared below.
        self.rag_chain = CustomRunnableWithHistory(
            self.rag_chain,
            get_session_history=self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
            history_factory_config=[
                ConfigurableFieldSpec(
                    id="user_id",
                    annotation=str,
                    name="User ID",
                    description="Unique identifier for the user.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="conversation_id",
                    annotation=str,
                    name="Conversation ID",
                    description="Unique identifier for the conversation.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="memory_window",
                    annotation=int,
                    name="Number of Conversations",
                    description="Number of conversations to consider for context.",
                    default=1,
                    is_shared=True,
                ),
            ],
        ).with_config(run_name="Langchain_RAG_V2")

        if callbacks is not None:
            self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)

    def get_session_history(
        self, user_id: str, conversation_id: str, memory_window: int
    ) -> BaseChatMessageHistory:
        """
        Get the session history for a user and conversation.

        On first access for a ``(user_id, conversation_id)`` pair a new
        ``InMemoryHistory`` is created and seeded with ``self.memory.messages``.

        Args:
            user_id (str): The user identifier.
            conversation_id (str): The conversation identifier.
            memory_window (int): The number of conversations to consider for
                context. Not used inside this method; it is accepted because it
                is declared in ``history_factory_config`` (presumably consumed
                by ``CustomRunnableWithHistory`` -- confirm).

        Returns:
            BaseChatMessageHistory: The chat message history.
        """
        session_key = (user_id, conversation_id)
        if session_key not in self.store:
            self.store[session_key] = InMemoryHistory()
            # Add previous messages to the store. Note: the store is in-memory.
            self.store[session_key].add_messages(self.memory.messages)
        return self.store[session_key]

    async def invoke(self, user_query, config, **kwargs):
        """
        Invoke the chain asynchronously.

        Args:
            user_query: The chain input, forwarded to ``rag_chain.ainvoke``.
            config: The runnable config, forwarded to ``rag_chain.ainvoke``.
            **kwargs: Extra keyword arguments forwarded to ``rag_chain.ainvoke``.

        Returns:
            dict: The chain output, with ``"rephrase_prompt"`` and
            ``"qa_prompt"`` entries set to the prompt strings this instance
            was constructed with.
        """
        res = await self.rag_chain.ainvoke(user_query, config, **kwargs)
        res["rephrase_prompt"] = self.rephrase_prompt
        res["qa_prompt"] = self.qa_prompt
        return res

    def stream(self, user_query, config):
        """
        Stream the chain output.

        Args:
            user_query: The chain input, forwarded to ``rag_chain.stream``.
            config: The runnable config, forwarded to ``rag_chain.stream``.

        Returns:
            Whatever ``rag_chain.stream`` returns.
        """
        return self.rag_chain.stream(user_query, config)

    def add_history_from_list(self, conversation_list):
        """
        Build a chat history from a list of messages.

        Each message may be a dict or an object; its ``type`` and ``content``
        are read accordingly. Types ``"human"`` / ``"user_message"`` are added
        as user messages, ``"ai"`` / ``"ai_message"`` as AI messages, and any
        other type is skipped.

        Args:
            conversation_list (list | None): The messages to add. ``None`` is
                treated as an empty list.

        Returns:
            ChatMessageHistory: The populated history (empty if there were no
            usable messages).
        """
        history = ChatMessageHistory()
        if conversation_list is None:
            # No prior conversation: return an empty history instead of
            # failing on iteration over None.
            return history

        for message in conversation_list:
            if isinstance(message, dict):
                message_type = message.get("type", None)
                message_content = message.get("content", None)
            else:
                message_type = getattr(message, "type", None)
                message_content = getattr(message, "content", None)

            if message_type in ("human", "user_message"):
                history.add_user_message(message_content)
            elif message_type in ("ai", "ai_message"):
                history.add_ai_message(message_content)
        return history