import json import textwrap from typing import Any, Callable, Dict, List, Literal, Optional, no_type_check import chainlit as cl from chainlit import run_sync from chainlit.config import config import yaml import os from modules.chat.llm_tutor import LLMTutor from modules.chat_processor.chat_processor import ChatProcessor from modules.config.constants import LLAMA_PATH from modules.chat.helpers import get_sources from chainlit.input_widget import Select, Switch, Slider USER_TIMEOUT = 60_000 SYSTEM = "System 🖥️" LLM = "LLM 🧠" AGENT = "Agent <>" YOU = "You 😃" ERROR = "Error 🚫" class Chatbot: def __init__(self): self.llm_tutor = None self.chain = None self.chat_processor = None self.config = self._load_config() def _load_config(self): with open("modules/config/config.yml", "r") as f: config = yaml.safe_load(f) return config async def ask_helper(func, **kwargs): res = await func(**kwargs).send() while not res: res = await func(**kwargs).send() return res @no_type_check async def setup_llm(self) -> None: """From the session `llm_settings`, create new LLMConfig and LLM objects, save them in session state.""" old_config = self.config.copy() # create a copy of the previous config new_config = ( self.config.copy() ) # create the new config as a copy of the previous config llm_settings = cl.user_session.get("llm_settings", {}) chat_profile = llm_settings.get("chat_model") retriever_method = llm_settings.get("retriever_method") memory_window = llm_settings.get("memory_window") self._configure_llm(chat_profile) chain = cl.user_session.get("chain") memory = chain.memory new_config["vectorstore"][ "db_option" ] = retriever_method # update the retriever method in the config new_config["llm_params"][ "memory_window" ] = memory_window # update the memory window in the config self.llm_tutor.update_llm(new_config) self.chain = self.llm_tutor.qa_bot(memory=memory) tags = [chat_profile, self.config["vectorstore"]["db_option"]] self.chat_processor = ChatProcessor(self.llm_tutor, tags=tags) cl.user_session.set("chain", self.chain) cl.user_session.set("llm_tutor", self.llm_tutor) cl.user_session.set("chat_processor", self.chat_processor) @no_type_check async def update_llm(self, new_settings: Dict[str, Any]) -> None: """Update LLMConfig and LLM from settings, and save in session state.""" cl.user_session.set("llm_settings", new_settings) await self.inform_llm_settings() await self.setup_llm() async def make_llm_settings_widgets(self, config=None): config = config or self.config await cl.ChatSettings( [ cl.input_widget.Select( id="chat_model", label="Model Name (Default GPT-3)", values=["llama", "gpt-3.5-turbo-1106", "gpt-4"], initial_index=0, ), cl.input_widget.Select( id="retriever_method", label="Retriever (Default FAISS)", values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"], initial_index=0, ), cl.input_widget.Slider( id="memory_window", label="Memory Window (Default 3)", initial=3, min=0, max=10, step=1, ), cl.input_widget.Switch( id="view_sources", label="View Sources", initial=False ), # cl.input_widget.TextInput( # id="vectorstore", # label="temp", # initial="None", # ), ] ).send() # type: ignore @no_type_check async def inform_llm_settings(self) -> None: llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {}) llm_tutor = cl.user_session.get("llm_tutor") settings_dict = dict( model=llm_settings.get("chat_model"), retriever=llm_settings.get("retriever_method"), memory_window=llm_settings.get("memory_window"), num_docs_in_db=len(llm_tutor.vector_db), view_sources=llm_settings.get("view_sources"), ) await cl.Message( author=SYSTEM, content="LLM settings have been updated. You can continue with your Query!", elements=[ cl.Text( name="settings", display="side", content=json.dumps(settings_dict, indent=4), language="json", ) ], ).send() async def set_starters(self): return [ cl.Starter( label="recording on CNNs?", message="Where can I find the recording for the lecture on Transformers?", icon="/public/adv-screen-recorder-svgrepo-com.svg", ), cl.Starter( label="where's the slides?", message="When are the lectures? I can't find the schedule.", icon="/public/alarmy-svgrepo-com.svg", ), cl.Starter( label="Due Date?", message="When is the final project due?", icon="/public/calendar-samsung-17-svgrepo-com.svg", ), cl.Starter( label="Explain backprop.", message="I didn't understand the math behind backprop, could you explain it?", icon="/public/acastusphoton-svgrepo-com.svg", ), ] async def chat_profile(self): return [ cl.ChatProfile( name="gpt-3.5-turbo-1106", markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.", ), cl.ChatProfile( name="gpt-4", markdown_description="Use OpenAI API for **gpt-4**.", ), cl.ChatProfile( name="Llama", markdown_description="Use the local LLM: **Tiny Llama**.", ), ] def rename(self, orig_author: str): rename_dict = {"Chatbot": "AI Tutor"} return rename_dict.get(orig_author, orig_author) async def start(self): await self.make_llm_settings_widgets(self.config) chat_profile = cl.user_session.get("chat_profile") if chat_profile: self._configure_llm(chat_profile) self.llm_tutor = LLMTutor( self.config, user={"user_id": "abc123", "session_id": "789"} ) self.chain = self.llm_tutor.qa_bot() tags = [chat_profile, self.config["vectorstore"]["db_option"]] self.chat_processor = ChatProcessor(self.llm_tutor, tags=tags) cl.user_session.set("llm_tutor", self.llm_tutor) cl.user_session.set("chain", self.chain) cl.user_session.set("counter", 20) cl.user_session.set("chat_processor", self.chat_processor) async def on_chat_end(self): await cl.Message(content="Sorry, I have to go now. Goodbye!").send() async def main(self, message): chain = cl.user_session.get("chain") counter = cl.user_session.get("counter") llm_settings = cl.user_session.get("llm_settings", {}) view_sources = llm_settings.get("view_sources", False) print("HERE") print(llm_settings) print(view_sources) print("\n\n") counter += 1 cl.user_session.set("counter", counter) processor = cl.user_session.get("chat_processor") res = await processor.rag(message.content, chain) print(res) answer = res.get("answer", res.get("result")) answer_with_sources, source_elements, sources_dict = get_sources( res, answer, view_sources=view_sources ) processor._process(message.content, answer, sources_dict) await cl.Message(content=answer_with_sources, elements=source_elements).send() def _configure_llm(self, chat_profile): chat_profile = chat_profile.lower() if chat_profile in ["gpt-3.5-turbo-1106", "gpt-4"]: self.config["llm_params"]["llm_loader"] = "openai" self.config["llm_params"]["openai_params"]["model"] = chat_profile elif chat_profile == "llama": self.config["llm_params"]["llm_loader"] = "local_llm" self.config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH self.config["llm_params"]["local_llm_params"]["model_type"] = "llama" elif chat_profile == "mistral": self.config["llm_params"]["llm_loader"] = "local_llm" self.config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH self.config["llm_params"]["local_llm_params"]["model_type"] = "mistral" chatbot = Chatbot() # Register functions to Chainlit events cl.set_starters(chatbot.set_starters) cl.set_chat_profiles(chatbot.chat_profile) cl.author_rename(chatbot.rename) cl.on_chat_start(chatbot.start) cl.on_chat_end(chatbot.on_chat_end) cl.on_message(chatbot.main) cl.on_settings_update(chatbot.update_llm)