import json
from typing import Any, Dict, no_type_check

import chainlit as cl
import yaml
from chainlit.input_widget import Select, Slider, Switch

from modules.chat.helpers import get_sources
from modules.chat.llm_tutor import LLMTutor
from modules.chat_processor.chat_processor import ChatProcessor

# MISTRAL_PATH is assumed to be defined alongside LLAMA_PATH in the constants
# module; _configure_llm below references it for the "mistral" profile.
from modules.config.constants import LLAMA_PATH, MISTRAL_PATH

USER_TIMEOUT = 60_000
SYSTEM = "System 🖥️"
LLM = "LLM 🧠"
AGENT = "Agent 🤖"
YOU = "You 😃"
ERROR = "Error 🚫"


class Chatbot:
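    """Chainlit AI tutor: wires an LLM, a retriever-backed QA chain, and a
    chat processor into the Chainlit session lifecycle."""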

    def __init__(self):
        self.llm_tutor = None
        self.chain = None
        self.chat_processor = None
        self.config = self._load_config()

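    # Minimal sketch of the config.yml shape this module relies on (inferred
    # from the keys read below; the real file likely carries more settings):
    #
    #   vectorstore:
    #     db_option: FAISS
    #   llm_params:
    #     llm_loader: openai
    #     openai_params:
    #       model: gpt-3.5-turbo-1106
    #     local_llm_params:
    #       model: <path to local weights>
    #       model_type: llama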
    def _load_config(self):
        with open("modules/config/config.yml", "r") as f:
            config = yaml.safe_load(f)
        return config

    async def ask_helper(self, func, **kwargs):
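        """Send a Chainlit prompt and re-send it until the user responds."""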
        res = await func(**kwargs).send()
        while not res:
            res = await func(**kwargs).send()
        return res

    @no_type_check
    async def setup_llm(self) -> None:
        """From the session `llm_settings`, create new LLMConfig and LLM
        objects, and save them in session state."""
        llm_settings = cl.user_session.get("llm_settings", {})
        chat_profile = llm_settings.get("chat_model")
        retriever_method = llm_settings.get("retriever_method")
        memory_window = llm_settings.get("memory_window")

        self._configure_llm(chat_profile)

        chain = cl.user_session.get("chain")
        memory = chain.memory
        self.config["vectorstore"]["db_option"] = retriever_method
        memory.k = memory_window

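        # Rebuild the tutor and QA chain with the updated config, carrying
        # over the existing conversation memory.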
        self.llm_tutor = LLMTutor(self.config)
        self.chain = self.llm_tutor.qa_bot(memory=memory)

        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
        self.chat_processor = ChatProcessor(self.config, tags=tags)

        cl.user_session.set("chain", self.chain)
        cl.user_session.set("llm_tutor", self.llm_tutor)
        cl.user_session.set("chat_processor", self.chat_processor)

    @no_type_check
    async def update_llm(self, new_settings: Dict[str, Any]) -> None:
        """Update LLMConfig and LLM from settings, and save in session state."""
        cl.user_session.set("llm_settings", new_settings)
        await self.inform_llm_settings()
        await self.setup_llm()

    async def make_llm_settings_widgets(self, config=None):
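        """Render the chat settings panel (model, retriever, memory window,
        and a sources toggle)."""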
        config = config or self.config
        await cl.ChatSettings(
            [
                Select(
                    id="chat_model",
                    label="Model Name (Default llama)",
                    values=["llama", "gpt-3.5-turbo-1106", "gpt-4"],
                    initial_index=0,
                ),
                Select(
                    id="retriever_method",
                    label="Retriever (Default FAISS)",
                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
                    initial_index=0,
                ),
                Slider(
                    id="memory_window",
                    label="Memory Window (Default 3)",
                    initial=3,
                    min=0,
                    max=10,
                    step=1,
                ),
                Switch(id="view_sources", label="View Sources", initial=False),
            ]
        ).send()

    @no_type_check
    async def inform_llm_settings(self) -> None:
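        """Post a side message summarizing the current LLM settings."""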
        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
        llm_tutor = cl.user_session.get("llm_tutor")
        settings_dict = dict(
            model=llm_settings.get("chat_model"),
            retriever=llm_settings.get("retriever_method"),
            memory_window=llm_settings.get("memory_window"),
            num_docs_in_db=len(llm_tutor.vector_db),
            view_sources=llm_settings.get("view_sources"),
        )
        await cl.Message(
            author=SYSTEM,
            content="LLM settings have been updated. You can continue with your query!",
            elements=[
                cl.Text(
                    name="settings",
                    display="side",
                    content=json.dumps(settings_dict, indent=4),
                    language="json",
                )
            ],
        ).send()

    async def set_starters(self):
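        """Starter prompts shown on the welcome screen."""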
        return [
            cl.Starter(
                label="recording on Transformers?",
                message="Where can I find the recording for the lecture on Transformers?",
                icon="/public/adv-screen-recorder-svgrepo-com.svg",
            ),
            cl.Starter(
                label="where's the schedule?",
                message="When are the lectures? I can't find the schedule.",
                icon="/public/alarmy-svgrepo-com.svg",
            ),
            cl.Starter(
                label="Due date?",
                message="When is the final project due?",
                icon="/public/calendar-samsung-17-svgrepo-com.svg",
            ),
            cl.Starter(
                label="Explain backprop.",
                message="I didn't understand the math behind backprop, could you explain it?",
                icon="/public/acastusphoton-svgrepo-com.svg",
            ),
        ]

    async def chat_profile(self):
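        """Chat profiles selectable from the Chainlit UI."""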
        return [
            cl.ChatProfile(
                name="Llama",
                markdown_description="Use the local LLM: **Tiny Llama**.",
            ),
        ]

    def rename(self, orig_author: str):
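        """Map Chainlit's default author name to the tutor persona."""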
        rename_dict = {"Chatbot": "AI Tutor"}
        return rename_dict.get(orig_author, orig_author)

    async def start(self):
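        """Initialize the tutor, chain, and processor for a new chat session."""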
        await self.make_llm_settings_widgets(self.config)

        chat_profile = cl.user_session.get("chat_profile")
        if chat_profile:
            self._configure_llm(chat_profile)

        self.llm_tutor = LLMTutor(self.config)
        self.chain = self.llm_tutor.qa_bot()
        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
        self.chat_processor = ChatProcessor(self.config, tags=tags)

        cl.user_session.set("llm_tutor", self.llm_tutor)
        cl.user_session.set("chain", self.chain)
        cl.user_session.set("counter", 0)
        cl.user_session.set("chat_processor", self.chat_processor)

    async def on_chat_end(self):
        await cl.Message(content="Sorry, I have to go now. Goodbye!").send()

    async def main(self, message):
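        """Answer a user message via the RAG chain and send it with sources."""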
        user = cl.user_session.get("user")
        chain = cl.user_session.get("chain")
        counter = cl.user_session.get("counter")
        llm_settings = cl.user_session.get("llm_settings", {})

        counter += 1
        cl.user_session.set("counter", counter)

        cb = cl.AsyncLangchainCallbackHandler()
        cb.answer_reached = True

        processor = cl.user_session.get("chat_processor")
        res = await processor.rag(message.content, chain, cb)
        answer = res.get("answer", res.get("result"))

        answer_with_sources, source_elements, sources_dict = get_sources(
            res, answer, view_sources=llm_settings.get("view_sources")
        )
        processor._process(message.content, answer, sources_dict)

        await cl.Message(content=answer_with_sources, elements=source_elements).send()

    def _configure_llm(self, chat_profile):
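        """Point the config at the OpenAI or local model matching the profile."""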
        chat_profile = chat_profile.lower()
        if chat_profile in ["gpt-3.5-turbo-1106", "gpt-4"]:
            self.config["llm_params"]["llm_loader"] = "openai"
            self.config["llm_params"]["openai_params"]["model"] = chat_profile
        elif chat_profile == "llama":
            self.config["llm_params"]["llm_loader"] = "local_llm"
            self.config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
            self.config["llm_params"]["local_llm_params"]["model_type"] = "llama"
        elif chat_profile == "mistral":
            self.config["llm_params"]["llm_loader"] = "local_llm"
            self.config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
            self.config["llm_params"]["local_llm_params"]["model_type"] = "mistral"


chatbot = Chatbot()

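# Register the Chatbot's methods as Chainlit lifecycle hooks.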
cl.set_starters(chatbot.set_starters)
cl.set_chat_profiles(chatbot.chat_profile)
cl.author_rename(chatbot.rename)
cl.on_chat_start(chatbot.start)
cl.on_chat_end(chatbot.on_chat_end)
cl.on_message(chatbot.main)
cl.on_settings_update(chatbot.update_llm)