from typing import Any, Dict, List, Optional, Tuple, Union

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain.chains.combine_documents.base import (
    DEFAULT_DOCUMENT_PROMPT,
    DEFAULT_DOCUMENT_SEPARATOR,
    DOCUMENTS_KEY,
    BaseCombineDocumentsChain,
    _validate_prompt,
)
from langchain.chains.llm import LLMChain

# A chat turn is either a (human, ai) string pair or a single BaseMessage.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

class CustomRunnableWithHistory(RunnableWithMessageHistory):
    """RunnableWithMessageHistory that truncates history to the last k turns."""

    def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
        """Get the last k conversation turns from the message history.

        Args:
            input (Any): The input data.
            config (RunnableConfig): The runnable configuration. Its
                `configurable` dict must carry `message_history` (the
                BaseChatMessageHistory to read from) and `memory_window`
                (k, the number of turns to keep).

        Returns:
            List[BaseMessage]: The messages for the last k conversation turns.
        """
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = hist.messages.copy()

        if not self.history_messages_key:
            # The current input is part of the conversation, so append it too.
            messages += self._get_input_messages(input)

        memory_window = config["configurable"]["memory_window"]
        if memory_window == 0:
            messages = []
        else:
            # Each turn is a (human, ai) pair, i.e. two messages.
            messages = messages[-2 * memory_window :]
        return messages
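
# A minimal wiring sketch (illustrative, not part of this module's API): it
# assumes an LCEL runnable `chain` that reads "input" and "chat_history". The
# parent class stores the factory result under
# config["configurable"]["message_history"], which is the key `_enter_history`
# reads above; `memory_window` is threaded through as a configurable field.
#
#     chain_with_history = CustomRunnableWithHistory(
#         chain,
#         get_session_history=lambda session_id, memory_window: InMemoryHistory(),
#         input_messages_key="input",
#         history_messages_key="chat_history",
#         history_factory_config=[
#             ConfigurableFieldSpec(id="session_id", annotation=str, default=""),
#             ConfigurableFieldSpec(id="memory_window", annotation=int, default=3),
#         ],
#     )
#     chain_with_history.invoke(
#         {"input": "What is a tensor?"},
#         config={"configurable": {"session_id": "demo", "memory_window": 3}},
#     )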


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE], n: Optional[int] = None) -> str:
    """Convert chat history to a formatted string.

    Args:
        chat_history (List[CHAT_TURN_TYPE]): The chat history.
        n (Optional[int]): If given, keep only the most recent n turns,
            assuming each turn spans two list entries (one human message and
            one AI message). If None, keep the full history.

    Returns:
        str: The formatted chat history.
    """
    _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
    buffer = ""
    if n is not None:
        # Two messages (human + ai) per turn.
        turns_to_take = n * 2
        chat_history = chat_history[-turns_to_take:]
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
            buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Student: " + dialogue_turn[0]
            ai = "AI Tutor: " + dialogue_turn[1]
            buffer += "\n" + "\n".join([human, ai])
        else:
            raise ValueError(
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
    return buffer
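
# Worked example (a sketch): with message-based history, n=1 keeps only the
# most recent Student/AI Tutor exchange.
#
#     hist = [
#         HumanMessage(content="What is a tensor?"),
#         AIMessage(content="A tensor is a multi-dimensional array."),
#         HumanMessage(content="And a matrix?"),
#         AIMessage(content="A matrix is a 2-D tensor."),
#     ]
#     _get_chat_history(hist, n=1)
#     # "\nStudent: And a matrix?\nAI Tutor: A matrix is a 2-D tensor."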


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """In-memory implementation of chat message history."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Add a list of messages to the store."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Clear the message history."""
        self.messages = []

    def __len__(self) -> int:
        """Return the number of messages."""
        return len(self.messages)

    def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
        """Return a new InMemoryHistory with the last n conversations.

        Args:
            n (int): The number of most recent conversations to keep. If 0,
                return an empty history.

        Returns:
            InMemoryHistory: A new InMemoryHistory object containing the last
                n conversations.
        """
        if n == 0:
            return InMemoryHistory()

        # Each conversation turn is a (human, ai) pair, i.e. two messages.
        num_messages = n * 2
        last_messages = self.messages[-num_messages:]
        return InMemoryHistory(messages=last_messages)
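
# Quick usage check (runnable as-is with the imports above):
#
#     history = InMemoryHistory()
#     history.add_messages([HumanMessage(content="hi"), AIMessage(content="hello")])
#     assert len(history) == 2
#     history.get_last_n_conversations(1)  # keeps the single most recent turn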


def create_history_aware_retriever(
    llm: LanguageModelLike,
    retriever: BaseRetriever,
    prompt: BasePromptTemplate,
) -> Runnable[Dict[str, Any], RetrieverOutput]:
    """Create a chain that takes conversation history and returns documents."""
    if "input" not in prompt.input_variables:
        raise ValueError(
            "Expected `input` to be a prompt variable, "
            f"but got {prompt.input_variables}"
        )

    retrieve_documents = RunnableBranch(
        (
            # No chat history: pass the raw input straight to the retriever.
            lambda x: not x["chat_history"],
            (lambda x: x["input"]) | retriever,
        ),
        # Otherwise, rephrase the input into a standalone query first.
        prompt | llm | StrOutputParser() | retriever,
    ).with_config(run_name="chat_retriever_chain")

    return retrieve_documents
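
# A sketch of the rephrasing prompt this chain expects (the wording is
# illustrative; `llm` and `retriever` are assumed to exist):
#
#     from langchain_core.prompts import ChatPromptTemplate
#
#     rephrase_prompt = ChatPromptTemplate.from_messages([
#         ("system", "Rewrite the question as a standalone query given the chat history."),
#         MessagesPlaceholder("chat_history"),
#         ("human", "{input}"),
#     ])
#     history_aware_retriever = create_history_aware_retriever(llm, retriever, rephrase_prompt)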


def create_stuff_documents_chain(
    llm: LanguageModelLike,
    prompt: BasePromptTemplate,
    output_parser: Optional[BaseOutputParser] = None,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
    """Create a chain for passing a list of Documents to a model."""
    _validate_prompt(prompt)
    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
    _output_parser = output_parser or StrOutputParser()

    def format_docs(inputs: dict) -> str:
        # Render each retrieved Document and join them into one context string.
        return document_separator.join(
            format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
        )

    return (
        RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
            run_name="format_inputs"
        )
        | prompt
        | llm
        | _output_parser
    ).with_config(run_name="stuff_documents_chain")
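
# A sketch of a compatible QA prompt: it must expose the documents variable
# (DOCUMENTS_KEY, i.e. "context") alongside the user input.
#
#     qa_prompt = ChatPromptTemplate.from_messages([
#         ("system", "Answer using only the following context:\n\n{context}"),
#         ("human", "{input}"),
#     ])
#     combine_docs_chain = create_stuff_documents_chain(llm, qa_prompt)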


def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create a retrieval chain that retrieves documents and then passes them on."""
    if not isinstance(retriever, BaseRetriever):
        # Already a runnable (e.g. a history-aware retriever chain); use as-is.
        retrieval_docs = retriever
    else:
        # A plain retriever expects the query string, so extract "input" first.
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
        ).assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain
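
# End-to-end sketch tying the pieces together (`llm`, `retriever`,
# `rephrase_prompt`, and `qa_prompt` as assumed in the sketches above):
#
#     rag_chain = create_retrieval_chain(
#         create_history_aware_retriever(llm, retriever, rephrase_prompt),
#         create_stuff_documents_chain(llm, qa_prompt),
#     )
#     result = rag_chain.invoke({"input": "What is a tensor?", "chat_history": []})
#     # `result` holds "input", "context" (the documents), and "answer".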