XThomasBU commited on
Commit
a8421b2
·
unverified ·
2 Parent(s): e934b90 eb62139

Merge pull request #87 from DL4DS/chainlit_base_code

Browse files
Files changed (1) hide show
  1. code/chainlit_base.py +484 -0
code/chainlit_base.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chainlit.data as cl_data
2
+ import asyncio
3
+ import yaml
4
+ from typing import Any, Dict, no_type_check
5
+ import chainlit as cl
6
+ from modules.chat.llm_tutor import LLMTutor
7
+ from modules.chat.helpers import (
8
+ get_sources,
9
+ get_history_chat_resume,
10
+ get_history_setup_llm,
11
+ get_last_config,
12
+ )
13
+ import copy
14
+ from chainlit.types import ThreadDict
15
+ import time
16
+ from langchain_community.callbacks import get_openai_callback
17
+
18
+ USER_TIMEOUT = 60_000
19
+ SYSTEM = "System"
20
+ LLM = "AI Tutor"
21
+ AGENT = "Agent"
22
+ YOU = "User"
23
+ ERROR = "Error"
24
+
25
+ with open("modules/config/config.yml", "r") as f:
26
+ config = yaml.safe_load(f)
27
+
28
+
29
+ # async def setup_data_layer():
30
+ # """
31
+ # Set up the data layer for chat logging.
32
+ # """
33
+ # if config["chat_logging"]["log_chat"]:
34
+ # data_layer = CustomLiteralDataLayer(
35
+ # api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
36
+ # )
37
+ # else:
38
+ # data_layer = None
39
+
40
+ # return data_layer
41
+
42
+
43
+ class Chatbot:
44
+ def __init__(self, config):
45
+ """
46
+ Initialize the Chatbot class.
47
+ """
48
+ self.config = config
49
+
50
+ async def _load_config(self):
51
+ """
52
+ Load the configuration from a YAML file.
53
+ """
54
+ with open("modules/config/config.yml", "r") as f:
55
+ return yaml.safe_load(f)
56
+
57
+ @no_type_check
58
+ async def setup_llm(self):
59
+ """
60
+ Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor.
61
+
62
+ #TODO: Clean this up.
63
+ """
64
+ start_time = time.time()
65
+
66
+ llm_settings = cl.user_session.get("llm_settings", {})
67
+ (
68
+ chat_profile,
69
+ retriever_method,
70
+ memory_window,
71
+ llm_style,
72
+ generate_follow_up,
73
+ chunking_mode,
74
+ ) = (
75
+ llm_settings.get("chat_model"),
76
+ llm_settings.get("retriever_method"),
77
+ llm_settings.get("memory_window"),
78
+ llm_settings.get("llm_style"),
79
+ llm_settings.get("follow_up_questions"),
80
+ llm_settings.get("chunking_mode"),
81
+ )
82
+
83
+ chain = cl.user_session.get("chain")
84
+ memory_list = cl.user_session.get(
85
+ "memory",
86
+ (
87
+ list(chain.store.values())[0].messages
88
+ if len(chain.store.values()) > 0
89
+ else []
90
+ ),
91
+ )
92
+ conversation_list = get_history_setup_llm(memory_list)
93
+
94
+ old_config = copy.deepcopy(self.config)
95
+ self.config["vectorstore"]["db_option"] = retriever_method
96
+ self.config["llm_params"]["memory_window"] = memory_window
97
+ self.config["llm_params"]["llm_style"] = llm_style
98
+ self.config["llm_params"]["llm_loader"] = chat_profile
99
+ self.config["llm_params"]["generate_follow_up"] = generate_follow_up
100
+ self.config["splitter_options"]["chunking_mode"] = chunking_mode
101
+
102
+ self.llm_tutor.update_llm(
103
+ old_config, self.config
104
+ ) # update only llm attributes that are changed
105
+ self.chain = self.llm_tutor.qa_bot(
106
+ memory=conversation_list,
107
+ )
108
+
109
+ cl.user_session.set("chain", self.chain)
110
+ cl.user_session.set("llm_tutor", self.llm_tutor)
111
+
112
+ print("Time taken to setup LLM: ", time.time() - start_time)
113
+
114
+ @no_type_check
115
+ async def update_llm(self, new_settings: Dict[str, Any]):
116
+ """
117
+ Update the LLM settings and reinitialize the LLM with the new settings.
118
+
119
+ Args:
120
+ new_settings (Dict[str, Any]): The new settings to update.
121
+ """
122
+ cl.user_session.set("llm_settings", new_settings)
123
+ await self.inform_llm_settings()
124
+ await self.setup_llm()
125
+
126
+ async def make_llm_settings_widgets(self, config=None):
127
+ """
128
+ Create and send the widgets for LLM settings configuration.
129
+
130
+ Args:
131
+ config: The configuration to use for setting up the widgets.
132
+ """
133
+ config = config or self.config
134
+ await cl.ChatSettings(
135
+ [
136
+ cl.input_widget.Select(
137
+ id="chat_model",
138
+ label="Model Name (Default GPT-3)",
139
+ values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
140
+ initial_index=[
141
+ "local_llm",
142
+ "gpt-3.5-turbo-1106",
143
+ "gpt-4",
144
+ "gpt-4o-mini",
145
+ ].index(config["llm_params"]["llm_loader"]),
146
+ ),
147
+ cl.input_widget.Select(
148
+ id="retriever_method",
149
+ label="Retriever (Default FAISS)",
150
+ values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
151
+ initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
152
+ config["vectorstore"]["db_option"]
153
+ ),
154
+ ),
155
+ cl.input_widget.Slider(
156
+ id="memory_window",
157
+ label="Memory Window (Default 3)",
158
+ initial=3,
159
+ min=0,
160
+ max=10,
161
+ step=1,
162
+ ),
163
+ cl.input_widget.Switch(
164
+ id="view_sources", label="View Sources", initial=False
165
+ ),
166
+ cl.input_widget.Switch(
167
+ id="stream_response",
168
+ label="Stream response",
169
+ initial=config["llm_params"]["stream"],
170
+ ),
171
+ cl.input_widget.Select(
172
+ id="chunking_mode",
173
+ label="Chunking mode",
174
+ values=["fixed", "semantic"],
175
+ initial_index=1,
176
+ ),
177
+ cl.input_widget.Switch(
178
+ id="follow_up_questions",
179
+ label="Generate follow up questions",
180
+ initial=False,
181
+ ),
182
+ cl.input_widget.Select(
183
+ id="llm_style",
184
+ label="Type of Conversation (Default Normal)",
185
+ values=["Normal", "ELI5"],
186
+ initial_index=0,
187
+ ),
188
+ ]
189
+ ).send()
190
+
191
+ @no_type_check
192
+ async def inform_llm_settings(self):
193
+ """
194
+ Inform the user about the updated LLM settings and display them as a message.
195
+ """
196
+ llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
197
+ llm_tutor = cl.user_session.get("llm_tutor")
198
+ settings_dict = {
199
+ "model": llm_settings.get("chat_model"),
200
+ "retriever": llm_settings.get("retriever_method"),
201
+ "memory_window": llm_settings.get("memory_window"),
202
+ "num_docs_in_db": (
203
+ len(llm_tutor.vector_db)
204
+ if llm_tutor and hasattr(llm_tutor, "vector_db")
205
+ else 0
206
+ ),
207
+ "view_sources": llm_settings.get("view_sources"),
208
+ "follow_up_questions": llm_settings.get("follow_up_questions"),
209
+ }
210
+ print("Settings Dict: ", settings_dict)
211
+ await cl.Message(
212
+ author=SYSTEM,
213
+ content="LLM settings have been updated. You can continue with your Query!",
214
+ # elements=[
215
+ # cl.Text(
216
+ # name="settings",
217
+ # display="side",
218
+ # content=json.dumps(settings_dict, indent=4),
219
+ # language="json",
220
+ # ),
221
+ # ],
222
+ ).send()
223
+
224
+ async def set_starters(self):
225
+ """
226
+ Set starter messages for the chatbot.
227
+ """
228
+ # Return Starters only if the chat is new
229
+
230
+ try:
231
+ thread = cl_data._data_layer.get_thread(
232
+ cl.context.session.thread_id
233
+ ) # see if the thread has any steps
234
+ if thread.steps or len(thread.steps) > 0:
235
+ return None
236
+ except Exception as e:
237
+ print(e)
238
+ return [
239
+ cl.Starter(
240
+ label="recording on CNNs?",
241
+ message="Where can I find the recording for the lecture on Transformers?",
242
+ icon="/public/adv-screen-recorder-svgrepo-com.svg",
243
+ ),
244
+ cl.Starter(
245
+ label="where's the slides?",
246
+ message="When are the lectures? I can't find the schedule.",
247
+ icon="/public/alarmy-svgrepo-com.svg",
248
+ ),
249
+ cl.Starter(
250
+ label="Due Date?",
251
+ message="When is the final project due?",
252
+ icon="/public/calendar-samsung-17-svgrepo-com.svg",
253
+ ),
254
+ cl.Starter(
255
+ label="Explain backprop.",
256
+ message="I didn't understand the math behind backprop, could you explain it?",
257
+ icon="/public/acastusphoton-svgrepo-com.svg",
258
+ ),
259
+ ]
260
+
261
+ def rename(self, orig_author: str):
262
+ """
263
+ Rename the original author to a more user-friendly name.
264
+
265
+ Args:
266
+ orig_author (str): The original author's name.
267
+
268
+ Returns:
269
+ str: The renamed author.
270
+ """
271
+ rename_dict = {"Chatbot": LLM}
272
+ return rename_dict.get(orig_author, orig_author)
273
+
274
+ async def start(self, config=None):
275
+ """
276
+ Start the chatbot, initialize settings widgets,
277
+ and display and load previous conversation if chat logging is enabled.
278
+ """
279
+
280
+ start_time = time.time()
281
+
282
+ self.config = (
283
+ await self._load_config() if config is None else config
284
+ ) # Reload the configuration on chat resume
285
+
286
+ await self.make_llm_settings_widgets(self.config) # Reload the settings widgets
287
+
288
+ user = cl.user_session.get("user")
289
+
290
+ # TODO: remove self.user with cl.user_session.get("user")
291
+ try:
292
+ self.user = {
293
+ "user_id": user.identifier,
294
+ "session_id": cl.context.session.thread_id,
295
+ }
296
+ except Exception as e:
297
+ print(e)
298
+ self.user = {
299
+ "user_id": "guest",
300
+ "session_id": cl.context.session.thread_id,
301
+ }
302
+
303
+ memory = cl.user_session.get("memory", [])
304
+ self.llm_tutor = LLMTutor(self.config, user=self.user)
305
+
306
+ self.chain = self.llm_tutor.qa_bot(
307
+ memory=memory,
308
+ )
309
+ self.question_generator = self.llm_tutor.question_generator
310
+ cl.user_session.set("llm_tutor", self.llm_tutor)
311
+ cl.user_session.set("chain", self.chain)
312
+
313
+ print("Time taken to start LLM: ", time.time() - start_time)
314
+
315
+ async def stream_response(self, response):
316
+ """
317
+ Stream the response from the LLM.
318
+
319
+ Args:
320
+ response: The response from the LLM.
321
+ """
322
+ msg = cl.Message(content="")
323
+ await msg.send()
324
+
325
+ output = {}
326
+ for chunk in response:
327
+ if "answer" in chunk:
328
+ await msg.stream_token(chunk["answer"])
329
+
330
+ for key in chunk:
331
+ if key not in output:
332
+ output[key] = chunk[key]
333
+ else:
334
+ output[key] += chunk[key]
335
+ return output
336
+
337
+ async def main(self, message):
338
+ """
339
+ Process and Display the Conversation.
340
+
341
+ Args:
342
+ message: The incoming chat message.
343
+ """
344
+
345
+ start_time = time.time()
346
+
347
+ chain = cl.user_session.get("chain")
348
+ token_count = 0 # initialize token count
349
+ if not chain:
350
+ await self.start() # start the chatbot if the chain is not present
351
+ chain = cl.user_session.get("chain")
352
+
353
+ # update user info with last message time
354
+ llm_settings = cl.user_session.get("llm_settings", {})
355
+ view_sources = llm_settings.get("view_sources", False)
356
+ stream = llm_settings.get("stream_response", False)
357
+ stream = False # Fix streaming
358
+ user_query_dict = {"input": message.content}
359
+ # Define the base configuration
360
+ cb = cl.AsyncLangchainCallbackHandler()
361
+ chain_config = {
362
+ "configurable": {
363
+ "user_id": self.user["user_id"],
364
+ "conversation_id": self.user["session_id"],
365
+ "memory_window": self.config["llm_params"]["memory_window"],
366
+ },
367
+ "callbacks": (
368
+ [cb]
369
+ if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
370
+ else None
371
+ ),
372
+ }
373
+
374
+ with get_openai_callback() as token_count_cb:
375
+ if stream:
376
+ res = chain.stream(user_query=user_query_dict, config=chain_config)
377
+ res = await self.stream_response(res)
378
+ else:
379
+ res = await chain.invoke(
380
+ user_query=user_query_dict,
381
+ config=chain_config,
382
+ )
383
+ token_count += token_count_cb.total_tokens
384
+
385
+ answer = res.get("answer", res.get("result"))
386
+
387
+ answer_with_sources, source_elements, sources_dict = get_sources(
388
+ res, answer, stream=stream, view_sources=view_sources
389
+ )
390
+ answer_with_sources = answer_with_sources.replace("$$", "$")
391
+
392
+ print("Time taken to process the message: ", time.time() - start_time)
393
+
394
+ actions = []
395
+
396
+ if self.config["llm_params"]["generate_follow_up"]:
397
+ start_time = time.time()
398
+ cb_follow_up = cl.AsyncLangchainCallbackHandler()
399
+ config = {
400
+ "callbacks": (
401
+ [cb_follow_up]
402
+ if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
403
+ else None
404
+ )
405
+ }
406
+ with get_openai_callback() as token_count_cb:
407
+ list_of_questions = await self.question_generator.generate_questions(
408
+ query=user_query_dict["input"],
409
+ response=answer,
410
+ chat_history=res.get("chat_history"),
411
+ context=res.get("context"),
412
+ config=config,
413
+ )
414
+
415
+ token_count += token_count_cb.total_tokens
416
+
417
+ for question in list_of_questions:
418
+ actions.append(
419
+ cl.Action(
420
+ name="follow up question",
421
+ value="example_value",
422
+ description=question,
423
+ label=question,
424
+ )
425
+ )
426
+
427
+ print("Time taken to generate questions: ", time.time() - start_time)
428
+ print("Total Tokens Used: ", token_count)
429
+
430
+ await cl.Message(
431
+ content=answer_with_sources,
432
+ elements=source_elements,
433
+ author=LLM,
434
+ actions=actions,
435
+ metadata=self.config,
436
+ ).send()
437
+
438
+ async def on_chat_resume(self, thread: ThreadDict):
439
+ thread_config = None
440
+ steps = thread["steps"]
441
+ k = self.config["llm_params"][
442
+ "memory_window"
443
+ ] # on resume, alwyas use the default memory window
444
+ conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
445
+ thread_config = get_last_config(
446
+ steps
447
+ ) # TODO: Returns None for now - which causes config to be reloaded with default values
448
+ cl.user_session.set("memory", conversation_list)
449
+ await self.start(config=thread_config)
450
+
451
+ async def on_follow_up(self, action: cl.Action):
452
+ user = cl.user_session.get("user")
453
+ message = await cl.Message(
454
+ content=action.description,
455
+ type="user_message",
456
+ author=user.identifier,
457
+ ).send()
458
+ async with cl.Step(
459
+ name="on_follow_up", type="run", parent_id=message.id
460
+ ) as step:
461
+ await self.main(message)
462
+ step.output = message.content
463
+
464
+
465
+ chatbot = Chatbot(config=config)
466
+
467
+
468
+ async def start_app():
469
+ # cl_data._data_layer = await setup_data_layer()
470
+ # chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
471
+ cl.set_starters(chatbot.set_starters)
472
+ cl.author_rename(chatbot.rename)
473
+ cl.on_chat_start(chatbot.start)
474
+ cl.on_chat_resume(chatbot.on_chat_resume)
475
+ cl.on_message(chatbot.main)
476
+ cl.on_settings_update(chatbot.update_llm)
477
+ cl.action_callback("follow up question")(chatbot.on_follow_up)
478
+
479
+
480
+ loop = asyncio.get_event_loop()
481
+ if loop.is_running():
482
+ asyncio.ensure_future(start_app())
483
+ else:
484
+ asyncio.run(start_app())