XThomasBU committed · fc2cb23
Parent(s): 8f6647c

improvements, refactored chat

Files changed:
- code/main.py  +41 -22
- code/modules/chat/helpers.py  +68 -54
- code/modules/chat/langchain/langchain_rag.py  +137 -0
- code/modules/chat/langchain/utils.py  +197 -0
- code/modules/chat/llm_tutor.py  +93 -169
- code/modules/chat_processor/chat_processor.py  +32 -7
- code/modules/chat_processor/literal_ai.py  +3 -3
- code/modules/config/constants.py  +63 -64
code/main.py
CHANGED
@@ -45,6 +45,11 @@ class Chatbot:
         """From the session `llm_settings`, create new LLMConfig and LLM objects,
         save them in session state."""
 
+        old_config = self.config.copy()  # create a copy of the previous config
+        new_config = (
+            self.config.copy()
+        )  # create the new config as a copy of the previous config
+
         llm_settings = cl.user_session.get("llm_settings", {})
         chat_profile = llm_settings.get("chat_model")
         retriever_method = llm_settings.get("retriever_method")
@@ -54,16 +59,18 @@ class Chatbot:
 
         chain = cl.user_session.get("chain")
         memory = chain.memory
+        new_config["vectorstore"][
             "db_option"
         ] = retriever_method  # update the retriever method in the config
+        new_config["llm_params"][
+            "memory_window"
+        ] = memory_window  # update the memory window in the config
 
-        self.llm_tutor
+        self.llm_tutor.update_llm(new_config)
         self.chain = self.llm_tutor.qa_bot(memory=memory)
 
         tags = [chat_profile, self.config["vectorstore"]["db_option"]]
-        self.chat_processor = ChatProcessor(self.
+        self.chat_processor = ChatProcessor(self.llm_tutor, tags=tags)
 
         cl.user_session.set("chain", self.chain)
         cl.user_session.set("llm_tutor", self.llm_tutor)
@@ -103,6 +110,11 @@ class Chatbot:
                 cl.input_widget.Switch(
                     id="view_sources", label="View Sources", initial=False
                 ),
+                # cl.input_widget.TextInput(
+                #     id="vectorstore",
+                #     label="temp",
+                #     initial="None",
+                # ),
             ]
         ).send()  # type: ignore
 
@@ -156,14 +168,14 @@ class Chatbot:
 
     async def chat_profile(self):
         return [
+            cl.ChatProfile(
+                name="gpt-3.5-turbo-1106",
+                markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
+            ),
+            cl.ChatProfile(
+                name="gpt-4",
+                markdown_description="Use OpenAI API for **gpt-4**.",
+            ),
             cl.ChatProfile(
                 name="Llama",
                 markdown_description="Use the local LLM: **Tiny Llama**.",
@@ -181,37 +193,44 @@ class Chatbot:
         if chat_profile:
             self._configure_llm(chat_profile)
 
-            self.llm_tutor = LLMTutor(
+            self.llm_tutor = LLMTutor(
+                self.config, user={"user_id": "abc123", "session_id": "789"}
+            )
             self.chain = self.llm_tutor.qa_bot()
             tags = [chat_profile, self.config["vectorstore"]["db_option"]]
-            self.chat_processor = ChatProcessor(self.
+            self.chat_processor = ChatProcessor(self.llm_tutor, tags=tags)
 
             cl.user_session.set("llm_tutor", self.llm_tutor)
             cl.user_session.set("chain", self.chain)
-            cl.user_session.set("counter",
+            cl.user_session.set("counter", 20)
             cl.user_session.set("chat_processor", self.chat_processor)
 
     async def on_chat_end(self):
         await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
 
     async def main(self, message):
-        user = cl.user_session.get("user")
         chain = cl.user_session.get("chain")
         counter = cl.user_session.get("counter")
-        llm_settings = cl.user_session.get("llm_settings")
+        llm_settings = cl.user_session.get("llm_settings", {})
+        view_sources = llm_settings.get("view_sources", False)
+
+        print("HERE")
+        print(llm_settings)
+        print(view_sources)
+        print("\n\n")
 
         counter += 1
         cl.user_session.set("counter", counter)
 
-        cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
-        cb.answer_reached = True
-
         processor = cl.user_session.get("chat_processor")
-        res = await processor.rag(message.content, chain
+        res = await processor.rag(message.content, chain)
+
+        print(res)
+
         answer = res.get("answer", res.get("result"))
 
         answer_with_sources, source_elements, sources_dict = get_sources(
-            res, answer, view_sources=
+            res, answer, view_sources=view_sources
         )
         processor._process(message.content, answer, sources_dict)
 
code/modules/chat/helpers.py
CHANGED
@@ -1,13 +1,14 @@
 from modules.config.constants import *
 import chainlit as cl
 from langchain_core.prompts import PromptTemplate
+from langchain_core.prompts import ChatPromptTemplate
 
 
 def get_sources(res, answer, view_sources=False):
     source_elements = []
     source_dict = {}  # Dictionary to store URL elements
 
-    for idx, source in enumerate(res["
+    for idx, source in enumerate(res["context"]):
         source_metadata = source.metadata
         url = source_metadata.get("source", "N/A")
         score = source_metadata.get("score", "N/A")
@@ -43,64 +44,77 @@ def get_sources(res, answer, view_sources=False):
     if view_sources:
 
         # Then, display the sources
-        if source_data["url"].lower().endswith(".pdf"):
-            name = f"Source {idx + 1} PDF\n"
-        full_answer += name
+        # check if the answer has sources
+        if len(source_dict) == 0:
+            full_answer += "\n\n**No sources found.**"
+            return full_answer, source_elements, source_dict
+        else:
+            full_answer += "\n\n**Sources:**\n"
+            for idx, (url_name, source_data) in enumerate(source_dict.items()):
+
+                full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+                name = f"Source {idx + 1} Text\n"
+                full_answer += name
+                source_elements.append(
+                    cl.Text(name=name, content=source_data["text"], display="side")
+                )
+
+                # Add a PDF element if the source is a PDF file
+                if source_data["url"].lower().endswith(".pdf"):
+                    name = f"Source {idx + 1} PDF\n"
+                    full_answer += name
+                    pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+                    source_elements.append(
+                        cl.Pdf(name=name, url=pdf_url, display="side")
+                    )
+
+            full_answer += "\n**Metadata:**\n"
+            for idx, (url_name, source_data) in enumerate(source_dict.items()):
+                full_answer += f"\nSource {idx + 1} Metadata:\n"
+                source_elements.append(
+                    cl.Text(
+                        name=f"Source {idx + 1} Metadata",
+                        content=f"Source: {source_data['url']}\n"
+                        f"Page: {source_data['page']}\n"
+                        f"Type: {source_data['source_type']}\n"
+                        f"Date: {source_data['date']}\n"
+                        f"TL;DR: {source_data['lecture_tldr']}\n"
+                        f"Lecture Recording: {source_data['lecture_recording']}\n"
+                        f"Suggested Readings: {source_data['suggested_readings']}\n",
+                        display="side",
+                    )
+                )
 
     return full_answer, source_elements, source_dict
 
 
-def get_prompt(config):
+def get_prompt(config, prompt_type):
+    llm_params = config["llm_params"]
+    llm_loader = llm_params["llm_loader"]
+    use_history = llm_params["use_history"]
+
+    if prompt_type == "qa":
+        if llm_loader == "openai":
+            return (
+                OPENAI_PROMPT_WITH_HISTORY if use_history else OPENAI_PROMPT_NO_HISTORY
+            )
+        elif (
+            llm_loader == "local_llm"
+            and llm_params.get("local_llm_params") == "tiny-llama"
+        ):
+            return (
+                TINYLLAMA_PROMPT_TEMPLATE_WITH_HISTORY
+                if use_history
+                else TINYLLAMA_PROMPT_TEMPLATE_NO_HISTORY
+            )
+    elif prompt_type == "rephrase":
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", OPENAI_REPHRASE_PROMPT),
+                ("human", "{question}, {chat_history}"),
+            ]
+        )
+        return OPENAI_REPHRASE_PROMPT
+
+    return None
code/modules/chat/langchain/langchain_rag.py
ADDED
@@ -0,0 +1,137 @@
from langchain_core.prompts import ChatPromptTemplate

from modules.chat.langchain.utils import *


class CustomConversationalRetrievalChain:
    def __init__(self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str):
        """
        Initialize the CustomConversationalRetrievalChain class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory (BaseChatMessageHistory): The chat message history instance.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
        """
        self.llm = llm
        self.memory = memory
        self.retriever = retriever
        self.qa_prompt = qa_prompt
        self.rephrase_prompt = rephrase_prompt
        self.store = {}

        # Contextualize question prompt
        contextualize_q_system_prompt = rephrase_prompt or (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, just "
            "reformulate it if needed and otherwise return it as is."
        )
        self.contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # History-aware retriever
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )

        # Answer question prompt
        qa_system_prompt = qa_prompt or (
            "You are an assistant for question-answering tasks. Use "
            "the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise."
            "\n\n"
            "{context}"
        )
        self.qa_prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", qa_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # Question-answer chain
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt_template
        )

        # Final retrieval chain
        self.rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )

        self.rag_chain = CustomRunnableWithHistory(
            self.rag_chain,
            get_session_history=self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
            history_factory_config=[
                ConfigurableFieldSpec(
                    id="user_id",
                    annotation=str,
                    name="User ID",
                    description="Unique identifier for the user.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="conversation_id",
                    annotation=str,
                    name="Conversation ID",
                    description="Unique identifier for the conversation.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="memory_window",
                    annotation=int,
                    name="Number of Conversations",
                    description="Number of conversations to consider for context.",
                    default=1,
                    is_shared=True,
                ),
            ],
        )

    def get_session_history(
        self, user_id: str, conversation_id: str, memory_window: int
    ) -> BaseChatMessageHistory:
        """
        Get the session history for a user and conversation.

        Args:
            user_id (str): The user identifier.
            conversation_id (str): The conversation identifier.
            memory_window (int): The number of conversations to consider for context.

        Returns:
            BaseChatMessageHistory: The chat message history.
        """
        if (user_id, conversation_id) not in self.store:
            self.store[(user_id, conversation_id)] = InMemoryHistory()
        return self.store[(user_id, conversation_id)]

    def invoke(self, user_query, config):
        """
        Invoke the chain.

        Args:
            kwargs: The input variables.

        Returns:
            dict: The output variables.
        """
        print(user_query, config)
        return self.rag_chain.invoke(user_query, config)
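
Note: a minimal sketch of how this new chain ends up being driven by the rest of this commit (assumes a loaded `config` dict and a populated vector database; the IDs are the placeholder values used in `code/main.py`):

    llm_tutor = LLMTutor(config, user={"user_id": "abc123", "session_id": "789"})
    chain = llm_tutor.qa_bot()  # a CustomConversationalRetrievalChain when use_history is set
    res = chain.invoke(
        {"input": "What is backpropagation?"},
        config={
            "configurable": {
                "user_id": "abc123",
                "conversation_id": "789",
                "memory_window": config["llm_params"]["memory_window"],
            }
        },
    )
    print(res["answer"])  # answer text; res["context"] holds the retrieved source documents
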
code/modules/chat/langchain/utils.py
ADDED
@@ -0,0 +1,197 @@
from typing import Any, Dict, List, Union, Tuple, Optional
from langchain_core.messages import (
    BaseMessage,
    AIMessage,
    FunctionMessage,
    HumanMessage,
)

from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chains.combine_documents.base import (
    DEFAULT_DOCUMENT_PROMPT,
    DEFAULT_DOCUMENT_SEPARATOR,
    DOCUMENTS_KEY,
    BaseCombineDocumentsChain,
    _validate_prompt,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document


CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import BaseMessage


class CustomRunnableWithHistory(RunnableWithMessageHistory):
    def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
        """
        Get the last k conversations from the message history.

        Args:
            input (Any): The input data.
            config (RunnableConfig): The runnable configuration.

        Returns:
            List[BaseMessage]: The last k conversations.
        """
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = hist.messages.copy()

        if not self.history_messages_key:
            # return all messages
            messages += self._get_input_messages(input)

        # return last k conversations
        if config["configurable"]["memory_window"] == 0:  # if k is 0, return empty list
            messages = []
        else:
            messages = messages[-2 * config["configurable"]["memory_window"] :]
        return messages


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE], n: int = None) -> str:
    """
    Convert chat history to a formatted string.

    Args:
        chat_history (List[CHAT_TURN_TYPE]): The chat history.

    Returns:
        str: The formatted chat history.
    """
    _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
    buffer = ""
    if n is not None:
        # Calculate the number of turns to take (2 turns per pair)
        turns_to_take = n * 2
        chat_history = chat_history[-turns_to_take:]
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
            buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Student: " + dialogue_turn[0]
            ai = "AI Tutor: " + dialogue_turn[1]
            buffer += "\n" + "\n".join([human, ai])
        else:
            raise ValueError(
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
    return buffer


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """In-memory implementation of chat message history."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Add a list of messages to the store."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Clear the message history."""
        self.messages = []

    def __len__(self) -> int:
        """Return the number of messages."""
        return len(self.messages)

    def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
        """Return a new InMemoryHistory object with the last n conversations from the message history.

        Args:
            n (int): The number of last conversations to return. If 0, return an empty history.

        Returns:
            InMemoryHistory: A new InMemoryHistory object containing the last n conversations.
        """
        if n == 0:
            return InMemoryHistory()
        # Each conversation consists of a pair of messages (human + AI)
        num_messages = n * 2
        last_messages = self.messages[-num_messages:]
        return InMemoryHistory(messages=last_messages)


def create_history_aware_retriever(
    llm: LanguageModelLike,
    retriever: BaseRetriever,
    prompt: BasePromptTemplate,
) -> Runnable[Dict[str, Any], RetrieverOutput]:
    """Create a chain that takes conversation history and returns documents."""
    if "input" not in prompt.input_variables:
        raise ValueError(
            "Expected `input` to be a prompt variable, "
            f"but got {prompt.input_variables}"
        )

    retrieve_documents = RunnableBranch(
        (
            lambda x: not x["chat_history"],
            (lambda x: x["input"]) | retriever,
        ),
        prompt | llm | StrOutputParser() | retriever,
    ).with_config(run_name="chat_retriever_chain")

    return retrieve_documents


def create_stuff_documents_chain(
    llm: LanguageModelLike,
    prompt: BasePromptTemplate,
    output_parser: Optional[BaseOutputParser] = None,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
    """Create a chain for passing a list of Documents to a model."""
    _validate_prompt(prompt)
    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
    _output_parser = output_parser or StrOutputParser()

    def format_docs(inputs: dict) -> str:
        return document_separator.join(
            format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
        )

    return (
        RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
            run_name="format_inputs"
        )
        | prompt
        | llm
        | _output_parser
    ).with_config(run_name="stuff_documents_chain")


def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create retrieval chain that retrieves documents and then passes them on."""
    if not isinstance(retriever, BaseRetriever):
        retrieval_docs = retriever
    else:
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
        ).assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain
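
Note: a standalone sketch (not part of the commit) of the memory-window behaviour these helpers implement; `InMemoryHistory` stores human/AI pairs and `get_last_n_conversations` keeps only the last n pairs:

    from langchain_core.messages import AIMessage, HumanMessage

    history = InMemoryHistory()
    history.add_messages([HumanMessage(content="What is a CNN?"), AIMessage(content="A convolutional network ...")])
    history.add_messages([HumanMessage(content="And an RNN?"), AIMessage(content="A recurrent network ...")])

    assert len(history) == 4
    trimmed = history.get_last_n_conversations(1)  # keep only the last human/AI pair
    assert [m.content for m in trimmed.messages] == ["And an RNN?", "A recurrent network ..."]
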
code/modules/chat/llm_tutor.py
CHANGED
@@ -1,216 +1,140 @@
-from langchain.chains import RetrievalQA, ConversationalRetrievalChain
-from langchain.memory import (
-    ConversationBufferWindowMemory,
-    ConversationSummaryBufferMemory,
-)
-from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
-import os
-from modules.config.constants import *
 from modules.chat.helpers import get_prompt
 from modules.chat.chat_model_loader import ChatModelLoader
 from modules.vectorstore.store_manager import VectorStoreManager
-
 from modules.retriever.retriever import Retriever
-
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
-from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
-import inspect
-from langchain.chains.conversational_retrieval.base import _get_chat_history
-from langchain_core.messages import BaseMessage
-
-CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
-
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.chat_models import ChatOpenAI
-
-
-class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
-
-    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
-        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
-        buffer = ""
-        for dialogue_turn in chat_history:
-            if isinstance(dialogue_turn, BaseMessage):
-                role_prefix = _ROLE_MAP.get(
-                    dialogue_turn.type, f"{dialogue_turn.type}: "
-                )
-                buffer += f"\n{role_prefix}{dialogue_turn.content}"
-            elif isinstance(dialogue_turn, tuple):
-                human = "Student: " + dialogue_turn[0]
-                ai = "AI Tutor: " + dialogue_turn[1]
-                buffer += "\n" + "\n".join([human, ai])
-            else:
-                raise ValueError(
-                    f"Unsupported chat history format: {type(dialogue_turn)}."
-                    f" Full chat history: {chat_history} "
-                )
-        return buffer
-
-    async def _acall(
-        self,
-        inputs: Dict[str, Any],
-        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
-        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
-        question = inputs["question"]
-        get_chat_history = self._get_chat_history
-        chat_history_str = get_chat_history(inputs["chat_history"])
-        if chat_history_str:
-            # callbacks = _run_manager.get_child()
-            # new_question = await self.question_generator.arun(
-            #     question=question, chat_history=chat_history_str, callbacks=callbacks
-            # )
-            system = (
-                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
-                "Incorporate relevant details from the chat history to make the question clearer and more specific."
-                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
-                "If the question is conversational and doesn't require context, do not rephrase it. "
-                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
-                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
-                "Chat history: \n{chat_history_str}\n"
-                "Rephrase the following question only if necessary: '{question}'"
-            )
-
-            prompt = ChatPromptTemplate.from_messages(
-                [
-                    ("system", system),
-                    ("human", "{question}, {chat_history_str}"),
-                ]
-            )
-            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-            step_back = prompt | llm | StrOutputParser()
-            new_question = step_back.invoke(
-                {"question": question, "chat_history_str": chat_history_str}
-            )
-        else:
-            new_question = question
-        accepts_run_manager = (
-            "run_manager" in inspect.signature(self._aget_docs).parameters
-        )
-        if accepts_run_manager:
-            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
-        else:
-            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]
-
-        output: Dict[str, Any] = {}
-        output["original_question"] = question
-        if self.response_if_no_docs_found is not None and len(docs) == 0:
-            output[self.output_key] = self.response_if_no_docs_found
-        else:
-            new_inputs = inputs.copy()
-            if self.rephrase_question:
-                new_inputs["question"] = new_question
-            new_inputs["chat_history"] = chat_history_str
-
-            # Prepare the final prompt with metadata
-            context = "\n\n".join(
-                [
-                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
-                    for idx, doc in enumerate(docs)
-                ]
-            )
-            final_prompt = (
-                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
-                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
-                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
-                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
-                f"Chat History:\n{chat_history_str}\n\n"
-                f"Context:\n{context}\n\n"
-                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
-                f"Student: {question}\n"
-                "AI Tutor:"
-            )
-
-            # new_inputs["input"] = final_prompt
-            new_inputs["question"] = final_prompt
-            # output["final_prompt"] = final_prompt
-
-            answer = await self.combine_docs_chain.arun(
-                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
-            )
-            output[self.output_key] = answer
-
-            if self.return_source_documents:
-                output["source_documents"] = docs
-        output["rephrased_question"] = new_question
-        return output
+from modules.chat.langchain.langchain_rag import CustomConversationalRetrievalChain
 
 
 class LLMTutor:
-    def __init__(self, config, logger=None):
+    def __init__(self, config, user, logger=None):
+        """
+        Initialize the LLMTutor class.
+
+        Args:
+            config (dict): Configuration dictionary.
+            user (str): User identifier.
+            logger (Logger, optional): Logger instance. Defaults to None.
+        """
         self.config = config
         self.llm = self.load_llm()
+        self.user = user
         self.logger = logger
         self.vector_db = VectorStoreManager(config, logger=self.logger)
         if self.config["vectorstore"]["embedd_files"]:
             self.vector_db.create_database()
             self.vector_db.save_database()
 
-    def 
+    def update_llm(self, new_config):
+        """
+        Update the LLM and VectorStoreManager based on new configuration.
+
+        Args:
+            new_config (dict): New configuration dictionary.
+        """
+        changes = self.get_config_changes(self.config, new_config)
+        self.config = new_config
+
+        if "chat_model" in changes:
+            self.llm = self.load_llm()  # Reinitialize LLM if chat_model changes
+
+        if "vectorstore" in changes:
+            self.vector_db = VectorStoreManager(
+                self.config, logger=self.logger
+            )  # Reinitialize VectorStoreManager if vectorstore changes
+            if self.config["vectorstore"]["embedd_files"]:
+                self.vector_db.create_database()
+                self.vector_db.save_database()
+
+    def get_config_changes(self, old_config, new_config):
+        """
+        Get the changes between the old and new configuration.
+
+        Args:
+            old_config (dict): Old configuration dictionary.
+            new_config (dict): New configuration dictionary.
+
+        Returns:
+            dict: Dictionary containing the changes.
+        """
+        changes = {}
+        for key in new_config:
+            if old_config.get(key) != new_config[key]:
+                changes[key] = (old_config.get(key), new_config[key])
+        return changes
+
+    def retrieval_qa_chain(self, llm, qa_prompt, rephrase_prompt, db, memory=None):
+        """
+        Create a Retrieval QA Chain.
+
+        Args:
+            llm (LLM): The language model instance.
+            qa_prompt (str): The QA prompt string.
+            rephrase_prompt (str): The rephrase prompt string.
+            db (VectorStore): The vector store instance.
+            memory (Memory, optional): Memory instance. Defaults to None.
+
+        Returns:
+            Chain: The retrieval QA chain instance.
+        """
         retriever = Retriever(self.config)._return_retriever(db)
 
         if self.config["llm_params"]["use_history"]:
-            memory = ConversationBufferWindowMemory(
-                k=self.config["llm_params"]["memory_window"],
-                memory_key="chat_history",
-                return_messages=True,
-                output_key="answer",
-            )
-            qa_chain = CustomConversationalRetrievalChain.from_llm(
+            qa_chain = CustomConversationalRetrievalChain(
                 llm=llm,
-                chain_type="stuff",
-                retriever=retriever,
-                return_source_documents=True,
                 memory=memory,
-                combine_docs_chain_kwargs={"prompt": prompt},
-                response_if_no_docs_found="No context found",
-            )
-        else:
-            qa_chain = RetrievalQA.from_chain_type(
-                llm=llm,
-                chain_type="stuff",
                 retriever=retriever,
+                qa_prompt=qa_prompt,
+                rephrase_prompt=rephrase_prompt,
             )
         return qa_chain
 
-    # Loading the model
     def load_llm(self):
+        """
+        Load the language model.
+
+        Returns:
+            LLM: The loaded language model instance.
+        """
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm
 
+    def qa_bot(self, memory=None, qa_prompt=None, rephrase_prompt=None):
+        """
+        Create a QA bot instance.
+
+        Args:
+            memory (Memory, optional): Memory instance. Defaults to None.
+            qa_prompt (str, optional): QA prompt string. Defaults to None.
+            rephrase_prompt (str, optional): Rephrase prompt string. Defaults to None.
+
+        Returns:
+            Chain: The QA bot chain instance.
+        """
+        if qa_prompt is None:
+            qa_prompt = get_prompt(self.config, "qa")
+        if rephrase_prompt is None:
+            rephrase_prompt = get_prompt(self.config, "rephrase")
         db = self.vector_db.load_database()
         # sanity check to see if there are any documents in the database
         if len(db) == 0:
             raise ValueError(
                 "No documents in the database. Populate the database first."
             )
-        qa = self.retrieval_qa_chain(
-            self.llm, qa_prompt, db, memory
-        )  # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
+        qa = self.retrieval_qa_chain(self.llm, qa_prompt, rephrase_prompt, db, memory)
 
         return qa
 
-    # output function
     def final_result(query):
+        """
+        Get the final result for a given query.
+
+        Args:
+            query (str): The query string.
+
+        Returns:
+            str: The response string.
+        """
         qa_result = qa_bot()
         response = qa_result({"query": query})
         return response
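
Note: a sketch of the new update path that `code/main.py` wires up (a deep copy is used here so `get_config_changes` can diff against the untouched old config; the setting values are placeholders):

    import copy

    new_config = copy.deepcopy(llm_tutor.config)
    new_config["vectorstore"]["db_option"] = "FAISS"  # placeholder: retriever switched in the settings panel
    new_config["llm_params"]["memory_window"] = 3

    llm_tutor.update_llm(new_config)  # reinitializes the LLM / vector store only for changed top-level keys
    chain = llm_tutor.qa_bot(memory=None)  # rebuild the chain; main.py passes the previous chain's memory here
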
code/modules/chat_processor/chat_processor.py
CHANGED
@@ -2,13 +2,25 @@ from modules.chat_processor.literal_ai import LiteralaiChatProcessor
 
 
 class ChatProcessor:
-    def __init__(self,
-        self.
-        self.
-        self.
+    def __init__(self, llm_tutor, tags=None):
+        self.llm_tutor = llm_tutor
+        self.config = self.llm_tutor.config
+        self.chat_processor_type = self.config["chat_logging"]["platform"]
+        self.logging = self.config["chat_logging"]["log_chat"]
+        self.user = self.llm_tutor.user
+        if tags is None:
+            self.tags = self._create_tags()
+        else:
+            self.tags = tags
         if self.logging:
             self._init_processor()
 
+    def _create_tags(self):
+        tags = []
+        tags.append(self.config["vectorstore"]["db_option"])
+        tags.append(self.config["llm_params"]["chat_profile"])
+        return tags
+
     def _init_processor(self):
         if self.chat_processor_type == "literalai":
             self.processor = LiteralaiChatProcessor(self.tags)
@@ -23,8 +35,21 @@ class ChatProcessor:
         else:
             pass
 
-    async def rag(self, user_query: str, chain
+    async def rag(self, user_query: str, chain):
+        user_query_dict = {"input": user_query}
+        # Define the base configuration
+        config = {
+            "configurable": {
+                "user_id": self.user["user_id"],
+                "conversation_id": self.user["session_id"],
+                "memory_window": self.llm_tutor.config["llm_params"]["memory_window"],
+            }
+        }
+
+        # Process the user query using the appropriate method
         if self.logging:
-            return await self.processor.rag(
+            return await self.processor.rag(
+                user_query=user_query_dict, config=config, chain=chain
+            )
         else:
-            return
+            return chain.invoke(user_query=user_query_dict, config=config)
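
Note: for orientation, the call path from `code/main.py` through this class, assuming chat logging is disabled (the tags shown are placeholders):

    processor = ChatProcessor(llm_tutor, tags=["gpt-3.5-turbo-1106", "FAISS"])
    res = await processor.rag(message.content, chain)  # inside an async Chainlit handler
    # with logging off this reduces to:
    # chain.invoke(user_query={"input": message.content},
    #              config={"configurable": {"user_id": ..., "conversation_id": ..., "memory_window": ...}})
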
code/modules/chat_processor/literal_ai.py
CHANGED
@@ -27,11 +27,11 @@ class LiteralaiChatProcessor(ChatProcessorBase):
             name="AI_Tutor",
         )
 
-    async def rag(self, user_query:
+    async def rag(self, user_query: dict, config: dict, chain):
         with self.literal_client.step(
             type="retrieval", name="RAG", thread_id=self.thread_id
         ) as step:
-            step.input = {"question": user_query}
-            res = await chain.
+            step.input = {"question": user_query["input"]}
+            res = await chain.invoke(user_query, config)
             step.output = res
             return res
code/modules/config/constants.py
CHANGED
@@ -13,70 +13,69 @@ opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me question
 
 # Prompt Templates
 
+OPENAI_REPHRASE_PROMPT = (
+    "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
+    "Incorporate relevant details from the chat history to make the question clearer and more specific. "
+    "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
+    "If the question is conversational and doesn't require context, do not rephrase it. "
+    "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation.'. "
+    "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project' "
+    "Chat history: \n{chat_history}\n"
+    "Rephrase the following question only if necessary: '{input}'"
+)
+
+OPENAI_PROMPT_WITH_HISTORY = (
+    "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+    "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+    "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+    "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+    "Chat History:\n{chat_history}\n\n"
+    "Context:\n{context}\n\n"
+    "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+    "Student: {input}\n"
+    "AI Tutor:"
+)
+
+OPENAAI_PROMPT_NO_HISTORY = (
+    "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+    "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+    "Provide links from the source_file metadata. Use the source context that is most relevant. "
+    "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+    "Context:\n{context}\n\n"
+    "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+    "Student: {input}\n"
+    "AI Tutor:"
+)
+
+
+TINYLLAMA_PROMPT_TEMPLATE_NO_HISTORY = (
+    "<|im_start|>system\n"
+    "Assistant is an intelligent chatbot designed to help students with questions regarding the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance.\n"
+    "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally.\n"
+    "Provide links from the source_file metadata. Use the source context that is most relevant.\n"
+    "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
+    "<|im_end|>\n\n"
+    "<|im_start|>user\n"
+    "Context:\n{context}\n\n"
+    "Question: {input}\n"
+    "<|im_end|>\n\n"
+    "<|im_start|>assistant"
+)
+
+TINYLLAMA_PROMPT_TEMPLATE_WITH_HISTORY = (
+    "<|im_start|>system\n"
+    "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+    "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+    "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+    "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
+    "<|im_end|>\n\n"
+    "<|im_start|>user\n"
+    "Chat History:\n{chat_history}\n\n"
+    "Context:\n{context}\n\n"
+    "Question: {input}\n"
+    "<|im_end|>\n\n"
+    "<|im_start|>assistant"
+)
 # Model Paths
 
 LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
-MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
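
Note: these templates are plain Python format strings; `get_prompt` in `helpers.py` selects one and `langchain_rag.py` wraps it in a `ChatPromptTemplate`, but a direct `str.format` call (a sketch with abbreviated placeholder values) shows which variables each one expects:

    filled = OPENAI_PROMPT_WITH_HISTORY.format(
        chat_history="Student: What is a CNN?\nAI Tutor: ...",
        context="Context 1: (Document content: ..., source_file: lecture_03.pdf)",
        input="How is it trained?",
    )
    print(filled)
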