Spaces:
Running
Running
import chainlit.data as cl_data | |
import asyncio | |
from typing import Any, Dict, no_type_check | |
import chainlit as cl | |
from edubotics_core.chat.llm_tutor import LLMTutor | |
from edubotics_core.chat.helpers import ( | |
get_sources, | |
get_history_setup_llm, | |
) | |
import copy | |
from langchain_community.callbacks import get_openai_callback | |
from config.config_manager import config_manager | |
USER_TIMEOUT = 60_000 | |
SYSTEM = "System" | |
LLM = "AI Tutor" | |
AGENT = "Agent" | |
YOU = "User" | |
ERROR = "Error" | |
config = config_manager.get_config().dict() | |
class Chatbot: | |
def __init__(self, config): | |
""" | |
Initialize the Chatbot class. | |
""" | |
self.config = config | |
async def setup_llm(self): | |
""" | |
Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor. | |
#TODO: Clean this up. | |
""" | |
llm_settings = cl.user_session.get("llm_settings", {}) | |
( | |
chat_profile, | |
retriever_method, | |
memory_window, | |
llm_style, | |
generate_follow_up, | |
chunking_mode, | |
) = ( | |
llm_settings.get("chat_model"), | |
llm_settings.get("retriever_method"), | |
llm_settings.get("memory_window"), | |
llm_settings.get("llm_style"), | |
llm_settings.get("follow_up_questions"), | |
llm_settings.get("chunking_mode"), | |
) | |
chain = cl.user_session.get("chain") | |
memory_list = cl.user_session.get( | |
"memory", | |
( | |
list(chain.store.values())[0].messages | |
if len(chain.store.values()) > 0 | |
else [] | |
), | |
) | |
conversation_list = get_history_setup_llm(memory_list) | |
old_config = copy.deepcopy(self.config) | |
self.config["vectorstore"]["db_option"] = retriever_method | |
self.config["llm_params"]["memory_window"] = memory_window | |
self.config["llm_params"]["llm_style"] = llm_style | |
self.config["llm_params"]["llm_loader"] = chat_profile | |
self.config["llm_params"]["generate_follow_up"] = generate_follow_up | |
self.config["splitter_options"]["chunking_mode"] = chunking_mode | |
self.llm_tutor.update_llm( | |
old_config, self.config | |
) # update only llm attributes that are changed | |
self.chain = self.llm_tutor.qa_bot( | |
memory=conversation_list, | |
) | |
cl.user_session.set("chain", self.chain) | |
cl.user_session.set("llm_tutor", self.llm_tutor) | |
async def update_llm(self, new_settings: Dict[str, Any]): | |
""" | |
Update the LLM settings and reinitialize the LLM with the new settings. | |
Args: | |
new_settings (Dict[str, Any]): The new settings to update. | |
""" | |
cl.user_session.set("llm_settings", new_settings) | |
await self.inform_llm_settings() | |
await self.setup_llm() | |
async def make_llm_settings_widgets(self, config=None): | |
""" | |
Create and send the widgets for LLM settings configuration. | |
Args: | |
config: The configuration to use for setting up the widgets. | |
""" | |
config = config or self.config | |
await cl.ChatSettings( | |
[ | |
cl.input_widget.Select( | |
id="chat_model", | |
label="Model Name (Default GPT-3)", | |
values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"], | |
initial_index=[ | |
"local_llm", | |
"gpt-3.5-turbo-1106", | |
"gpt-4", | |
"gpt-4o-mini", | |
].index(config["llm_params"]["llm_loader"]), | |
), | |
cl.input_widget.Select( | |
id="retriever_method", | |
label="Retriever (Default FAISS)", | |
values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"], | |
initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index( | |
config["vectorstore"]["db_option"] | |
), | |
), | |
cl.input_widget.Slider( | |
id="memory_window", | |
label="Memory Window (Default 3)", | |
initial=3, | |
min=0, | |
max=10, | |
step=1, | |
), | |
cl.input_widget.Switch( | |
id="view_sources", label="View Sources", initial=False | |
), | |
cl.input_widget.Switch( | |
id="stream_response", | |
label="Stream response", | |
initial=config["llm_params"]["stream"], | |
), | |
cl.input_widget.Select( | |
id="chunking_mode", | |
label="Chunking mode", | |
values=["fixed", "semantic"], | |
initial_index=1, | |
), | |
cl.input_widget.Switch( | |
id="follow_up_questions", | |
label="Generate follow up questions", | |
initial=False, | |
), | |
cl.input_widget.Select( | |
id="llm_style", | |
label="Type of Conversation (Default Normal)", | |
values=["Normal", "ELI5"], | |
initial_index=0, | |
), | |
] | |
).send() | |
async def inform_llm_settings(self): | |
""" | |
Inform the user about the updated LLM settings and display them as a message. | |
""" | |
await cl.Message( | |
author=SYSTEM, | |
content="LLM settings have been updated. You can continue with your Query!", | |
).send() | |
async def set_starters(self): | |
""" | |
Set starter messages for the chatbot. | |
""" | |
return [ | |
cl.Starter( | |
label="recording on Transformers?", | |
message="Where can I find the recording for the lecture on Transformers?", | |
icon="/public/assets/images/starter_icons/adv-screen-recorder-svgrepo-com.svg", | |
), | |
cl.Starter( | |
label="where's the slides?", | |
message="When are the lectures? I can't find the schedule.", | |
icon="/public/assets/images/starter_icons/alarmy-svgrepo-com.svg", | |
), | |
cl.Starter( | |
label="Due Date?", | |
message="When is the final project due?", | |
icon="/public/assets/images/starter_icons/calendar-samsung-17-svgrepo-com.svg", | |
), | |
cl.Starter( | |
label="Explain backprop.", | |
message="I didn't understand the math behind backprop, could you explain it?", | |
icon="/public/assets/images/starter_icons/acastusphoton-svgrepo-com.svg", | |
), | |
] | |
def rename(self, orig_author: str): | |
""" | |
Rename the original author to a more user-friendly name. | |
Args: | |
orig_author (str): The original author's name. | |
Returns: | |
str: The renamed author. | |
""" | |
rename_dict = {"Chatbot": LLM} | |
return rename_dict.get(orig_author, orig_author) | |
async def start(self): | |
""" | |
Start the chatbot, initialize settings widgets, | |
and display and load previous conversation if chat logging is enabled. | |
""" | |
await self.make_llm_settings_widgets(self.config) # Reload the settings widgets | |
# TODO: remove self.user with cl.user_session.get("user") | |
self.user = { | |
"user_id": "guest", | |
"session_id": cl.context.session.thread_id, | |
} | |
memory = cl.user_session.get("memory", []) | |
self.llm_tutor = LLMTutor(self.config, user=self.user) | |
self.chain = self.llm_tutor.qa_bot( | |
memory=memory, | |
) | |
self.question_generator = self.llm_tutor.question_generator | |
cl.user_session.set("llm_tutor", self.llm_tutor) | |
cl.user_session.set("chain", self.chain) | |
async def stream_response(self, response): | |
""" | |
Stream the response from the LLM. | |
Args: | |
response: The response from the LLM. | |
""" | |
msg = cl.Message(content="") | |
await msg.send() | |
output = {} | |
for chunk in response: | |
if "answer" in chunk: | |
await msg.stream_token(chunk["answer"]) | |
for key in chunk: | |
if key not in output: | |
output[key] = chunk[key] | |
else: | |
output[key] += chunk[key] | |
return output | |
async def main(self, message): | |
""" | |
Process and Display the Conversation. | |
Args: | |
message: The incoming chat message. | |
""" | |
chain = cl.user_session.get("chain") | |
token_count = 0 # initialize token count | |
if not chain: | |
await self.start() # start the chatbot if the chain is not present | |
chain = cl.user_session.get("chain") | |
# update user info with last message time | |
llm_settings = cl.user_session.get("llm_settings", {}) | |
view_sources = llm_settings.get("view_sources", False) | |
stream = llm_settings.get("stream_response", False) | |
stream = False # Fix streaming | |
user_query_dict = {"input": message.content} | |
# Define the base configuration | |
cb = cl.AsyncLangchainCallbackHandler() | |
chain_config = { | |
"configurable": { | |
"user_id": self.user["user_id"], | |
"conversation_id": self.user["session_id"], | |
"memory_window": self.config["llm_params"]["memory_window"], | |
}, | |
"callbacks": ( | |
[cb] | |
if cl_data._data_layer and self.config["chat_logging"]["callbacks"] | |
else None | |
), | |
} | |
with get_openai_callback() as token_count_cb: | |
if stream: | |
res = chain.stream(user_query=user_query_dict, config=chain_config) | |
res = await self.stream_response(res) | |
else: | |
res = await chain.invoke( | |
user_query=user_query_dict, | |
config=chain_config, | |
) | |
token_count += token_count_cb.total_tokens | |
answer = res.get("answer", res.get("result")) | |
answer_with_sources, source_elements, sources_dict = get_sources( | |
res, answer, stream=stream, view_sources=view_sources | |
) | |
answer_with_sources = answer_with_sources.replace("$$", "$") | |
actions = [] | |
if self.config["llm_params"]["generate_follow_up"]: | |
cb_follow_up = cl.AsyncLangchainCallbackHandler() | |
config = { | |
"callbacks": ( | |
[cb_follow_up] | |
if cl_data._data_layer and self.config["chat_logging"]["callbacks"] | |
else None | |
) | |
} | |
with get_openai_callback() as token_count_cb: | |
list_of_questions = await self.question_generator.generate_questions( | |
query=user_query_dict["input"], | |
response=answer, | |
chat_history=res.get("chat_history"), | |
context=res.get("context"), | |
config=config, | |
) | |
token_count += token_count_cb.total_tokens | |
for question in list_of_questions: | |
actions.append( | |
cl.Action( | |
name="follow up question", | |
value="example_value", | |
description=question, | |
label=question, | |
) | |
) | |
await cl.Message( | |
content=answer_with_sources, | |
elements=source_elements, | |
author=LLM, | |
actions=actions, | |
).send() | |
async def on_follow_up(self, action: cl.Action): | |
user = cl.user_session.get("user") | |
message = await cl.Message( | |
content=action.description, | |
type="user_message", | |
author=user.identifier, | |
).send() | |
async with cl.Step( | |
name="on_follow_up", type="run", parent_id=message.id | |
) as step: | |
await self.main(message) | |
step.output = message.content | |
chatbot = Chatbot(config=config) | |
async def start_app(): | |
cl.set_starters(chatbot.set_starters) | |
cl.author_rename(chatbot.rename) | |
cl.on_chat_start(chatbot.start) | |
cl.on_message(chatbot.main) | |
cl.on_settings_update(chatbot.update_llm) | |
cl.action_callback("follow up question")(chatbot.on_follow_up) | |
loop = asyncio.get_event_loop() | |
if loop.is_running(): | |
asyncio.ensure_future(start_app()) | |
else: | |
asyncio.run(start_app()) | |