from typing import Any, Dict, List, Optional, Tuple, Union

from langchain_core.callbacks import Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain.chains.combine_documents.base import (
    DEFAULT_DOCUMENT_PROMPT,
    DEFAULT_DOCUMENT_SEPARATOR,
    DOCUMENTS_KEY,
    BaseCombineDocumentsChain,
    _validate_prompt,
)
from langchain.chains.llm import LLMChain

# A chat turn is either a (human, ai) string pair or a full message object.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]


class CustomRunnableWithHistory(RunnableWithMessageHistory):
    def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
        """Get the last k conversations from the message history.

        Args:
            input (Any): The input data.
            config (RunnableConfig): The runnable configuration.

        Returns:
            List[BaseMessage]: The last k conversations.
        """
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = hist.messages.copy()
        if not self.history_messages_key:
            # No separate history key: fold the current input messages in too.
            messages += self._get_input_messages(input)
        if config["configurable"]["memory_window"] == 0:
            # A window of 0 means the model sees no history at all.
            messages = []
        else:
            # Keep the last k conversations; each is a human/AI message pair.
            messages = messages[-2 * config["configurable"]["memory_window"] :]
        return messages
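

# Usage sketch (illustrative, not executed): `_enter_history` reads two keys
# from `config["configurable"]`, so callers are expected to invoke the wrapped
# chain along these lines; the chain itself and its inputs are assumptions.
#
#     chain.invoke(
#         {"input": "What did we discuss earlier?"},
#         config={
#             "configurable": {
#                 "message_history": InMemoryHistory(),
#                 "memory_window": 3,  # keep at most the last 3 human/AI pairs
#             }
#         },
#     )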


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """In-memory implementation of chat message history."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Add a list of messages to the store."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Clear the message history."""
        self.messages = []

    def __len__(self) -> int:
        """Return the number of messages."""
        return len(self.messages)

    def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
        """Return a new InMemoryHistory with the last n conversations.

        Args:
            n (int): The number of most recent conversations to keep.
                If 0, return an empty history.

        Returns:
            InMemoryHistory: A new InMemoryHistory containing the last n
                conversations.
        """
        if n == 0:
            return InMemoryHistory()
        # Each conversation consists of a pair of messages (human + AI).
        num_messages = n * 2
        last_messages = self.messages[-num_messages:]
        return InMemoryHistory(messages=last_messages)
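

# Worked example (comment only, so importing this module stays side-effect
# free): with three stored turns, `get_last_n_conversations(2)` keeps the
# final two pairs.
#
#     h = InMemoryHistory()
#     h.add_messages([HumanMessage(content="q1"), AIMessage(content="a1")])
#     h.add_messages([HumanMessage(content="q2"), AIMessage(content="a2")])
#     h.add_messages([HumanMessage(content="q3"), AIMessage(content="a3")])
#     len(h.get_last_n_conversations(2))  # -> 4 (q2/a2 and q3/a3)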


def create_history_aware_retriever(
    llm: LanguageModelLike,
    retriever: BaseRetriever,
    prompt: BasePromptTemplate,
) -> Runnable[Dict[str, Any], RetrieverOutput]:
    """Create a chain that takes conversation history and returns documents."""
    if "input" not in prompt.input_variables:
        raise ValueError(
            "Expected `input` to be a prompt variable, "
            f"but got {prompt.input_variables}"
        )
    retrieve_documents = RunnableBranch(
        (
            # No chat history: send the raw input straight to the retriever.
            lambda x: not x["chat_history"],
            (lambda x: x["input"]) | retriever,
        ),
        # Otherwise have the LLM rephrase the input into a standalone query first.
        prompt | llm | StrOutputParser() | retriever,
    ).with_config(run_name="chat_retriever_chain")
    return retrieve_documents
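

# Prompt sketch (an assumption about the expected shape): the rephrase prompt
# must expose an `input` variable, and the branch above also looks for a
# `chat_history` key in the chain input, e.g.
#
#     from langchain_core.prompts import ChatPromptTemplate
#     rephrase_prompt = ChatPromptTemplate.from_messages([
#         MessagesPlaceholder("chat_history"),
#         ("human", "{input}"),
#         ("human", "Rewrite the question above as a standalone search query."),
#     ])
#     history_aware = create_history_aware_retriever(llm, retriever, rephrase_prompt)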


def create_stuff_documents_chain(
    llm: LanguageModelLike,
    prompt: BasePromptTemplate,
    output_parser: Optional[BaseOutputParser] = None,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
    """Create a chain for passing a list of Documents to a model."""
    _validate_prompt(prompt)
    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
    _output_parser = output_parser or StrOutputParser()

    def format_docs(inputs: dict) -> str:
        # Render each Document with the document prompt, then join them into a
        # single string that replaces the raw list under DOCUMENTS_KEY.
        return document_separator.join(
            format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
        )

    return (
        RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
            run_name="format_inputs"
        )
        | prompt
        | llm
        | _output_parser
    ).with_config(run_name="stuff_documents_chain")
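

# Invocation sketch: the chain expects the documents under DOCUMENTS_KEY
# ("context" in upstream langchain) plus whatever other variables the prompt
# declares; the keys below beyond DOCUMENTS_KEY are assumptions.
#
#     stuff_chain = create_stuff_documents_chain(llm, qa_prompt)
#     stuff_chain.invoke({
#         "context": [Document(page_content="...")],
#         "input": "What does the context say?",
#     })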


def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create a retrieval chain that retrieves documents and then passes them on."""
    if not isinstance(retriever, BaseRetriever):
        # Already a runnable that maps the full input dict to documents.
        retrieval_docs = retriever
    else:
        # A plain retriever expects just the query string, kept under "input".
        retrieval_docs = (lambda x: x["input"]) | retriever
    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
        ).assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")
    return retrieval_chain
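

if __name__ == "__main__":
    # Minimal smoke test: a sketch wiring the pieces above with fake
    # components, so nothing here calls a real model. It assumes FakeListLLM
    # is exported by langchain_core.language_models (true for recent
    # langchain_core releases); the document text and canned answer are
    # illustrative.
    from langchain_core.language_models import FakeListLLM
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import RunnableLambda

    docs = [Document(page_content="LangChain is a framework for LLM apps.")]
    # Any Runnable mapping the input dict to documents works as a retriever here.
    fake_retriever = RunnableLambda(lambda x: docs)
    llm = FakeListLLM(responses=["It is a framework for building LLM apps."])
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Answer from this context:\n{context}"),
            ("human", "{input}"),
        ]
    )
    combine = create_stuff_documents_chain(llm, qa_prompt)
    chain = create_retrieval_chain(fake_retriever, combine)
    result = chain.invoke({"input": "What is LangChain?"})
    print(result["answer"])  # -> the canned FakeListLLM response

    # The windowing helpers are just as easy to exercise:
    history = InMemoryHistory()
    history.add_messages([HumanMessage(content="q1"), AIMessage(content="a1")])
    history.add_messages([HumanMessage(content="q2"), AIMessage(content="a2")])
    assert len(history.get_last_n_conversations(1)) == 2
    assert len(history.get_last_n_conversations(0)) == 0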