import asyncio
import base64
import copy
import json
from datetime import datetime, timezone
from typing import Any, Dict, Optional, no_type_check

import chainlit as cl
import chainlit.data as cl_data
from chainlit.types import ThreadDict
from langchain_community.callbacks import get_openai_callback

from config.config_manager import config_manager
from config.constants import (
    LITERAL_API_KEY_LOGGING,
    LITERAL_API_URL,
)
from edubotics_core.chat.helpers import (
    get_history_chat_resume,
    get_history_setup_llm,
    get_sources,
)
from edubotics_core.chat.llm_tutor import LLMTutor
from edubotics_core.chat_processor.helpers import (
    get_user_details,
    update_user_info,
)
from edubotics_core.chat_processor.literal_ai import CustomLiteralDataLayer
from helpers import check_user_cooldown, get_time, reset_tokens_for_user

USER_TIMEOUT = 60_000

# Author display names used for messages in the chat UI.
SYSTEM = "System"
LLM = "AI Tutor"
AGENT = "Agent"
YOU = "User"
ERROR = "Error"

config = config_manager.get_config().dict()


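# Chat logging to Literal AI is optional; when it is disabled the app runs
# without a persistent data layer.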
async def setup_data_layer():
    """
    Set up the data layer for chat logging.
    """
    if config["chat_logging"]["log_chat"]:
        data_layer = CustomLiteralDataLayer(
            api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
        )
    else:
        data_layer = None

    return data_layer


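# Per-user token bookkeeping. Admins are exempt from token accounting, and the
# returned balance is clamped at zero for display purposes.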
async def update_user_from_chainlit(user, token_count=0):
    if "admin" not in user.metadata["role"]:
        user.metadata["tokens_left"] = user.metadata["tokens_left"] - token_count
        user.metadata["all_time_tokens_allocated"] = (
            user.metadata["all_time_tokens_allocated"] - token_count
        )
        user.metadata["tokens_left_at_last_message"] = user.metadata["tokens_left"]
    user.metadata["last_message_time"] = get_time()
    await update_user_info(user)

    tokens_left = user.metadata["tokens_left"]
    if tokens_left < 0:
        tokens_left = 0
    return tokens_left


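# All Chainlit lifecycle hooks (chat start/resume, messages, settings updates,
# follow-up actions) are implemented as methods here and registered in
# start_app() below.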
class Chatbot:
    def __init__(self, config):
        """
        Initialize the Chatbot class.
        """
        self.config = config

    @no_type_check
    async def setup_llm(self):
        """
        Set up the LLM with the provided settings. Update the configuration
        and initialize the LLM tutor.

        # TODO: Clean this up.
        """
        llm_settings = cl.user_session.get("llm_settings", {})
        (
            chat_profile,
            retriever_method,
            memory_window,
            llm_style,
            generate_follow_up,
            chunking_mode,
        ) = (
            llm_settings.get("chat_model"),
            llm_settings.get("retriever_method"),
            llm_settings.get("memory_window"),
            llm_settings.get("llm_style"),
            llm_settings.get("follow_up_questions"),
            llm_settings.get("chunking_mode"),
        )

        chain = cl.user_session.get("chain")
        memory_list = cl.user_session.get(
            "memory",
            (
                list(chain.store.values())[0].messages
                if len(chain.store.values()) > 0
                else []
            ),
        )
        conversation_list = get_history_setup_llm(memory_list)

        old_config = copy.deepcopy(self.config)
        self.config["vectorstore"]["db_option"] = retriever_method
        self.config["llm_params"]["memory_window"] = memory_window
        self.config["llm_params"]["llm_style"] = llm_style
        self.config["llm_params"]["llm_loader"] = chat_profile
        self.config["llm_params"]["generate_follow_up"] = generate_follow_up
        self.config["splitter_options"]["chunking_mode"] = chunking_mode

        self.llm_tutor.update_llm(old_config, self.config)
        self.chain = self.llm_tutor.qa_bot(
            memory=conversation_list,
        )

        cl.user_session.set("chain", self.chain)
        cl.user_session.set("llm_tutor", self.llm_tutor)

    @no_type_check
    async def update_llm(self, new_settings: Dict[str, Any]):
        """
        Update the LLM settings and reinitialize the LLM with the new settings.

        Args:
            new_settings (Dict[str, Any]): The new settings to update.
        """
        cl.user_session.set("llm_settings", new_settings)
        await self.inform_llm_settings()
        await self.setup_llm()

    async def make_llm_settings_widgets(self, config=None):
        """
        Create and send the widgets for LLM settings configuration.

        Args:
            config: The configuration to use for setting up the widgets.
        """
        config = config or self.config
        await cl.ChatSettings(
            [
                cl.input_widget.Select(
                    id="chat_model",
                    label="Model Name (Default GPT-3)",
                    values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
                    initial_index=[
                        "local_llm",
                        "gpt-3.5-turbo-1106",
                        "gpt-4",
                        "gpt-4o-mini",
                    ].index(config["llm_params"]["llm_loader"]),
                ),
                cl.input_widget.Select(
                    id="retriever_method",
                    label="Retriever (Default FAISS)",
                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
                    initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
                        config["vectorstore"]["db_option"]
                    ),
                ),
                cl.input_widget.Slider(
                    id="memory_window",
                    label="Memory Window (Default 3)",
                    initial=3,
                    min=0,
                    max=10,
                    step=1,
                ),
                cl.input_widget.Switch(
                    id="view_sources", label="View Sources", initial=False
                ),
                cl.input_widget.Switch(
                    id="stream_response",
                    label="Stream response",
                    initial=config["llm_params"]["stream"],
                ),
                cl.input_widget.Select(
                    id="chunking_mode",
                    label="Chunking mode",
                    values=["fixed", "semantic"],
                    initial_index=1,
                ),
                cl.input_widget.Switch(
                    id="follow_up_questions",
                    label="Generate follow up questions",
                    initial=False,
                ),
                cl.input_widget.Select(
                    id="llm_style",
                    label="Type of Conversation (Default Normal)",
                    values=["Normal", "ELI5"],
                    initial_index=0,
                ),
            ]
        ).send()

    @no_type_check
    async def inform_llm_settings(self):
        """
        Inform the user that the LLM settings have been updated.
        """
        await cl.Message(
            author=SYSTEM,
            content="LLM settings have been updated. You can continue with your query!",
        ).send()

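    # Starters are only offered on a fresh thread; when resuming a thread that
    # already has steps, none are shown.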
    async def set_starters(self):
        """
        Set starter messages for the chatbot.
        """
        try:
            thread = await cl_data._data_layer.get_thread(
                cl.context.session.thread_id
            )
            if thread.steps and len(thread.steps) > 0:
                # Resumed thread with existing steps: skip the starters.
                return None
        except Exception as e:
            print(e)
        return [
            cl.Starter(
                label="recording on Transformers?",
                message="Where can I find the recording for the lecture on Transformers?",
                icon="/public/assets/images/starter_icons/adv-screen-recorder-svgrepo-com.svg",
            ),
            cl.Starter(
                label="where's the schedule?",
                message="When are the lectures? I can't find the schedule.",
                icon="/public/assets/images/starter_icons/alarmy-svgrepo-com.svg",
            ),
            cl.Starter(
                label="Due Date?",
                message="When is the final project due?",
                icon="/public/assets/images/starter_icons/calendar-samsung-17-svgrepo-com.svg",
            ),
            cl.Starter(
                label="Explain backprop.",
                message="I didn't understand the math behind backprop, could you explain it?",
                icon="/public/assets/images/starter_icons/acastusphoton-svgrepo-com.svg",
            ),
        ]

    def rename(self, orig_author: str):
        """
        Rename the original author to a more user-friendly name.

        Args:
            orig_author (str): The original author's name.

        Returns:
            str: The renamed author.
        """
        rename_dict = {"Chatbot": LLM}
        return rename_dict.get(orig_author, orig_author)

    async def start(self):
        """
        Start the chatbot: initialize the settings widgets and load the
        previous conversation if chat logging is enabled.
        """
        await self.make_llm_settings_widgets(self.config)

        user = cl.user_session.get("user")
        try:
            self.user = {
                "user_id": user.identifier,
                "session_id": cl.context.session.thread_id,
            }
        except Exception as e:
            print(e)
            self.user = {
                "user_id": "guest",
                "session_id": cl.context.session.thread_id,
            }

        memory = cl.user_session.get("memory", [])
        self.llm_tutor = LLMTutor(self.config, user=self.user)

        self.chain = self.llm_tutor.qa_bot(
            memory=memory,
        )
        self.question_generator = self.llm_tutor.question_generator
        cl.user_session.set("llm_tutor", self.llm_tutor)
        cl.user_session.set("chain", self.chain)

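    # Streaming responses arrive as dict chunks; answer tokens are relayed to
    # the UI as they come in, and all chunk fields are merged into a single
    # response dict for downstream processing.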
    async def stream_response(self, response):
        """
        Stream the response from the LLM.

        Args:
            response: The response from the LLM.
        """
        msg = cl.Message(content="")
        await msg.send()

        output = {}
        for chunk in response:
            if "answer" in chunk:
                await msg.stream_token(chunk["answer"])

            for key in chunk:
                if key not in output:
                    output[key] = chunk[key]
                else:
                    output[key] += chunk[key]
        return output

    async def main(self, message):
        """
        Process and display the conversation.

        Args:
            message: The incoming chat message.
        """
        chain = cl.user_session.get("chain")
        token_count = 0
        if not chain:
            await self.start()
            chain = cl.user_session.get("chain")

        # Refresh the user's token balance before processing the message.
        user = cl.user_session.get("user")
        await reset_tokens_for_user(
            user,
            self.config["token_config"]["tokens_left"],
            self.config["token_config"]["regen_time"],
        )
        updated_user = await get_user_details(user.identifier)
        user.metadata = updated_user.metadata
        cl.user_session.set("user", user)

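        # Non-admin users with no tokens left are placed in a cooldown and told
        # how long to wait before their tokens regenerate.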
if user.metadata["tokens_left"] <= 0 and "admin" not in user.metadata["role"]: |
|
current_datetime = get_time() |
|
cooldown, cooldown_end_time = await check_user_cooldown( |
|
user, |
|
current_datetime, |
|
self.config["token_config"]["cooldown_time"], |
|
self.config["token_config"]["tokens_left"], |
|
self.config["token_config"]["regen_time"], |
|
) |
|
if cooldown: |
|
|
|
|
|
cooldown_end_time = datetime.fromisoformat(cooldown_end_time).replace( |
|
tzinfo=timezone.utc |
|
) |
|
current_datetime = datetime.fromisoformat(current_datetime).replace( |
|
tzinfo=timezone.utc |
|
) |
|
cooldown_time_left = cooldown_end_time - current_datetime |
|
|
|
total_seconds = int(cooldown_time_left.total_seconds()) |
|
|
|
hours, remainder = divmod(total_seconds, 3600) |
|
minutes, seconds = divmod(remainder, 60) |
|
|
|
formatted_time = f"{hours:02} hrs {minutes:02} mins {seconds:02} secs" |
|
await cl.Message( |
|
content=( |
|
"Ah, seems like you have run out of tokens...Click " |
|
'<a href="/cooldown" style="color: #0000CD; text-decoration: none;" target="_self">here</a> for more info. Please come back after {}'.format( |
|
formatted_time |
|
) |
|
), |
|
author=SYSTEM, |
|
).send() |
|
user.metadata["in_cooldown"] = True |
|
await update_user_info(user) |
|
return |
|
else: |
|
await cl.Message( |
|
content=( |
|
"Ah, seems like you don't have any tokens left...Please wait while we regenerate your tokens. Click " |
|
'<a href="/cooldown" style="color: #0000CD; text-decoration: none;" target="_self">here</a> to view your token credits.' |
|
), |
|
author=SYSTEM, |
|
).send() |
|
return |
|
|
|
user.metadata["in_cooldown"] = False |
|
|
|
llm_settings = cl.user_session.get("llm_settings", {}) |
|
view_sources = llm_settings.get("view_sources", False) |
|
stream = llm_settings.get("stream_response", False) |
|
stream = False |
|
user_query_dict = {"input": message.content} |
|
|
|
cb = cl.AsyncLangchainCallbackHandler() |
|
chain_config = { |
|
"configurable": { |
|
"user_id": self.user["user_id"], |
|
"conversation_id": self.user["session_id"], |
|
"memory_window": self.config["llm_params"]["memory_window"], |
|
}, |
|
"callbacks": ( |
|
[cb] |
|
if cl_data._data_layer and self.config["chat_logging"]["callbacks"] |
|
else None |
|
), |
|
} |
|
|
|
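        # get_openai_callback tracks token usage for the chain call so it can
        # be charged against the user's balance.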
        with get_openai_callback() as token_count_cb:
            if stream:
                res = chain.stream(user_query=user_query_dict, config=chain_config)
                res = await self.stream_response(res)
            else:
                res = await chain.invoke(
                    user_query=user_query_dict,
                    config=chain_config,
                )
        token_count += token_count_cb.total_tokens

        answer = res.get("answer", res.get("result"))

        answer_with_sources, source_elements, sources_dict = get_sources(
            res, answer, stream=stream, view_sources=view_sources
        )
        answer_with_sources = answer_with_sources.replace("$$", "$")

        actions = []

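        # Optionally generate follow-up questions and attach them as clickable
        # actions on the response message.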
if self.config["llm_params"]["generate_follow_up"]: |
|
cb_follow_up = cl.AsyncLangchainCallbackHandler() |
|
config = { |
|
"callbacks": ( |
|
[cb_follow_up] |
|
if cl_data._data_layer and self.config["chat_logging"]["callbacks"] |
|
else None |
|
) |
|
} |
|
with get_openai_callback() as token_count_cb: |
|
list_of_questions = await self.question_generator.generate_questions( |
|
query=user_query_dict["input"], |
|
response=answer, |
|
chat_history=res.get("chat_history"), |
|
context=res.get("context"), |
|
config=config, |
|
) |
|
|
|
token_count += token_count_cb.total_tokens |
|
|
|
for question in list_of_questions: |
|
actions.append( |
|
cl.Action( |
|
name="follow up question", |
|
value="example_value", |
|
description=question, |
|
label=question, |
|
) |
|
) |
|
|
|
|
|
        tokens_left = await update_user_from_chainlit(user, token_count)

        answer_with_sources += (
            '\n\n<footer><span style="font-size: 0.8em; text-align: right; display: block;">Tokens Left: '
            + str(tokens_left)
            + "</span></footer>\n"
        )

        await cl.Message(
            content=answer_with_sources,
            elements=source_elements,
            author=LLM,
            actions=actions,
        ).send()

    async def on_chat_resume(self, thread: ThreadDict):
        """
        Restore conversation memory from a resumed thread, then restart the chat.

        Args:
            thread (ThreadDict): The thread being resumed.
        """
        steps = thread["steps"]
        k = self.config["llm_params"]["memory_window"]
        conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
        cl.user_session.set("memory", conversation_list)
        await self.start()

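    # Header-based auth: an X-User-Info cookie (set upstream by the login
    # flow) carries base64-encoded JSON with the Literal AI user record.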
    @cl.header_auth_callback
    def header_auth_callback(headers: dict) -> Optional[cl.User]:
        cookie = headers.get("cookie")
        if not cookie:
            return None

        cookie_dict = {}
        for pair in cookie.split("; "):
            key, value = pair.split("=", 1)
            # Strip surrounding quotes from the value, if present.
            cookie_dict[key] = value.strip('"')

        decoded_user_info = base64.b64decode(
            cookie_dict.get("X-User-Info", "")
        ).decode()
        decoded_user_info = json.loads(decoded_user_info)

        return cl.User(
            id=decoded_user_info["literalai_info"]["id"],
            identifier=decoded_user_info["literalai_info"]["identifier"],
            metadata=decoded_user_info["literalai_info"]["metadata"],
        )

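    # Clicking a follow-up action re-enters main() with the question text as a
    # new user message.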
    async def on_follow_up(self, action: cl.Action):
        """
        Handle a clicked follow-up question action.
        """
        user = cl.user_session.get("user")
        message = await cl.Message(
            content=action.description,
            type="user_message",
            author=user.identifier,
        ).send()
        async with cl.Step(
            name="on_follow_up", type="run", parent_id=message.id
        ) as step:
            await self.main(message)
            step.output = message.content


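# Module-level wiring: build the chatbot, set up the (optional) data layer,
# and register all Chainlit callbacks.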
chatbot = Chatbot(config=config)


async def start_app():
    cl_data._data_layer = await setup_data_layer()
    chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
    cl.set_starters(chatbot.set_starters)
    cl.author_rename(chatbot.rename)
    cl.on_chat_start(chatbot.start)
    cl.on_chat_resume(chatbot.on_chat_resume)
    cl.on_message(chatbot.main)
    cl.on_settings_update(chatbot.update_llm)
    cl.action_callback("follow up question")(chatbot.on_follow_up)


# Schedule start_app() on the running event loop if one exists (e.g. when
# launched by Chainlit); otherwise run it to completion.
loop = asyncio.get_event_loop()
if loop.is_running():
    asyncio.ensure_future(start_app())
else:
    asyncio.run(start_app())