from typing import Any, Dict, List, Union, Tuple, Optional
from langchain_core.messages import (
BaseMessage,
AIMessage,
FunctionMessage,
HumanMessage,
)
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
_validate_prompt,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.runnables.config import RunnableConfig

# A single chat turn: either a (human, ai) string pair or a BaseMessage.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

class CustomRunnableWithHistory(RunnableWithMessageHistory):
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
"""
Get the last k conversations from the message history.
Args:
input (Any): The input data.
config (RunnableConfig): The runnable configuration.
Returns:
List[BaseMessage]: The last k conversations.
"""
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy()
if not self.history_messages_key:
# return all messages
messages += self._get_input_messages(input)
# return last k conversations
if config["configurable"]["memory_window"] == 0: # if k is 0, return empty list
messages = []
else:
messages = messages[-2 * config["configurable"]["memory_window"] :]
return messages
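
# Usage sketch (illustrative only; `rag_chain` and the exact constructor
# arguments are assumptions, not defined in this module). `_enter_history`
# reads `message_history` and `memory_window` from the run config, so both
# would be declared as configurable fields:
#
#   chain_with_history = CustomRunnableWithHistory(
#       rag_chain,  # hypothetical runnable to wrap
#       get_session_history=lambda *_: InMemoryHistory(),
#       input_messages_key="input",
#       history_messages_key="chat_history",
#       history_factory_config=[
#           ConfigurableFieldSpec(id="message_history", annotation=BaseChatMessageHistory),
#           ConfigurableFieldSpec(id="memory_window", annotation=int, default=3),
#       ],
#   )
#   chain_with_history.invoke(
#       {"input": "What is a stack?"},
#       config={
#           "configurable": {
#               "message_history": InMemoryHistory(),
#               "memory_window": 3,  # keep the last 3 (human, AI) pairs
#           }
#       },
#   )
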
def _get_chat_history(chat_history: List[CHAT_TURN_TYPE], n: Optional[int] = None) -> str:
    """Convert chat history to a formatted string.

    Args:
        chat_history (List[CHAT_TURN_TYPE]): The chat history.
        n (Optional[int]): If given, include only the last n conversation
            pairs; otherwise include the full history.

    Returns:
        str: The formatted chat history.
    """
_ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
buffer = ""
if n is not None:
# Calculate the number of turns to take (2 turns per pair)
turns_to_take = n * 2
chat_history = chat_history[-turns_to_take:]
for dialogue_turn in chat_history:
if isinstance(dialogue_turn, BaseMessage):
role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
buffer += f"\n{role_prefix}{dialogue_turn.content}"
elif isinstance(dialogue_turn, tuple):
human = "Student: " + dialogue_turn[0]
ai = "AI Tutor: " + dialogue_turn[1]
buffer += "\n" + "\n".join([human, ai])
else:
raise ValueError(
f"Unsupported chat history format: {type(dialogue_turn)}."
f" Full chat history: {chat_history} "
)
return buffer
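
# Example (illustrative): a single (human, ai) tuple is rendered with the
# "Student:" / "AI Tutor:" role prefixes:
#
#   _get_chat_history([("What is recursion?", "Recursion is ...")])
#   -> "\nStudent: What is recursion?\nAI Tutor: Recursion is ..."
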
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
"""In-memory implementation of chat message history."""
messages: List[BaseMessage] = Field(default_factory=list)
def add_messages(self, messages: List[BaseMessage]) -> None:
"""Add a list of messages to the store."""
self.messages.extend(messages)
def clear(self) -> None:
"""Clear the message history."""
self.messages = []
def __len__(self) -> int:
"""Return the number of messages."""
return len(self.messages)
def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
"""Return a new InMemoryHistory object with the last n conversations from the message history.
Args:
n (int): The number of last conversations to return. If 0, return an empty history.
Returns:
InMemoryHistory: A new InMemoryHistory object containing the last n conversations.
"""
if n == 0:
return InMemoryHistory()
# Each conversation consists of a pair of messages (human + AI)
num_messages = n * 2
last_messages = self.messages[-num_messages:]
return InMemoryHistory(messages=last_messages)
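
# Usage sketch (illustrative), using the message classes imported above:
#
#   history = InMemoryHistory()
#   history.add_messages([HumanMessage(content="hi"), AIMessage(content="hello")])
#   assert len(history) == 2
#   trimmed = history.get_last_n_conversations(1)  # keeps the last human/AI pair
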
def create_history_aware_retriever(
llm: LanguageModelLike,
retriever: BaseRetriever,
prompt: BasePromptTemplate,
) -> Runnable[Dict[str, Any], RetrieverOutput]:
"""Create a chain that takes conversation history and returns documents."""
if "input" not in prompt.input_variables:
raise ValueError(
"Expected `input` to be a prompt variable, "
f"but got {prompt.input_variables}"
)
retrieve_documents = RunnableBranch(
(
lambda x: not x["chat_history"],
(lambda x: x["input"]) | retriever,
),
prompt | llm | StrOutputParser() | retriever,
).with_config(run_name="chat_retriever_chain")
return retrieve_documents
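
# Usage sketch (illustrative; the prompt text, `llm`, and `vectorstore` are
# assumptions, not defined here):
#
#   from langchain_core.prompts import ChatPromptTemplate
#
#   rephrase_prompt = ChatPromptTemplate.from_messages([
#       ("system", "Rephrase the follow-up question as a standalone query."),
#       MessagesPlaceholder("chat_history"),
#       ("human", "{input}"),
#   ])
#   history_aware_retriever = create_history_aware_retriever(
#       llm, vectorstore.as_retriever(), rephrase_prompt
#   )
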
def create_stuff_documents_chain(
llm: LanguageModelLike,
prompt: BasePromptTemplate,
output_parser: Optional[BaseOutputParser] = None,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
"""Create a chain for passing a list of Documents to a model."""
_validate_prompt(prompt)
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
_output_parser = output_parser or StrOutputParser()
def format_docs(inputs: dict) -> str:
return document_separator.join(
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
)
return (
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
run_name="format_inputs"
)
| prompt
| llm
| _output_parser
).with_config(run_name="stuff_documents_chain")
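
# Usage sketch (illustrative): the prompt must expose a "context" variable,
# which is what DOCUMENTS_KEY resolves to in langchain:
#
#   qa_prompt = ChatPromptTemplate.from_messages([
#       ("system", "Answer using the context:\n\n{context}"),
#       MessagesPlaceholder("chat_history"),
#       ("human", "{input}"),
#   ])
#   question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
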
def create_retrieval_chain(
retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
"""Create retrieval chain that retrieves documents and then passes them on."""
if not isinstance(retriever, BaseRetriever):
retrieval_docs = retriever
else:
retrieval_docs = (lambda x: x["input"]) | retriever
retrieval_chain = (
RunnablePassthrough.assign(
context=retrieval_docs.with_config(run_name="retrieve_documents"),
).assign(answer=combine_docs_chain)
).with_config(run_name="retrieval_chain")
return retrieval_chain
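
# End-to-end sketch (illustrative), wiring the pieces from the sketches above:
#
#   rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
#   result = rag_chain.invoke({"input": "What is a stack?", "chat_history": []})
#   result["answer"]   # the generated response
#   result["context"]  # the retrieved documents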