File size: 3,912 Bytes
f2daaee e029e22 f2daaee e029e22 f2daaee e029e22 f2daaee e029e22 f2daaee fc2cb23 f2daaee e029e22 f2daaee fc2cb23 e029e22 f2daaee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from literalai import LiteralClient
from literalai.api import LiteralAPI
from literalai.filter import Filter as ThreadFilter
import os
from .base import ChatProcessorBase
class LiteralaiChatProcessor(ChatProcessorBase):
    """Chat processor that logs conversations to a Literal AI thread.

    On construction the processor resolves the Literal AI user for
    ``user["user_id"]``, reuses the first thread returned for that
    participant (creating one when none is returned), and caches the most
    recent user/assistant message pairs of that thread in ``prev_conv``.
    """

    def __init__(self, user=None, tags=None):
        """Create the Literal AI clients and load (or create) the user thread.

        Args:
            user: Mapping with a ``"user_id"`` key. Required in practice even
                though it defaults to ``None``: the user lookup subscripts it.
            tags: Tags attached to the retrieval step recorded by ``rag``.

        Raises:
            TypeError: If ``user`` is ``None``.
        """
        super().__init__()
        self.user = user
        self.tags = tags

        # Both clients are configured from the environment; read the key once.
        api_key = os.getenv("LITERAL_API_KEY")
        self.literal_client = LiteralClient(api_key=api_key)
        self.literal_api = LiteralAPI(
            api_key=api_key, url=os.getenv("LITERAL_API_URL")
        )
        self.literal_client.reset_context()

        self.user_info = self._fetch_userinfo()
        self.user_thread = self._fetch_user_threads()
        if self.user_thread["data"]:
            self.thread = self._get_user_thread()
        else:
            self.thread = self._create_user_thread()
        self.thread_id = self.thread["id"]
        self.prev_conv = self._get_prev_k_conversations()

    def _get_user_thread(self):
        """Fetch the first thread listed in ``self.user_thread`` as a dict."""
        thread = self.literal_api.get_thread(id=self.user_thread["data"][0]["id"])
        return thread.to_dict()

    def _create_user_thread(self):
        """Create a thread named after the user's identifier; return it as a dict."""
        thread = self.literal_api.create_thread(
            name=f"{self.user_info['identifier']}",
            participant_id=self.user_info["metadata"]["id"],
            environment="dev",
        )
        return thread.to_dict()

    def _get_prev_k_conversations(self, k=3):
        """Return up to ``k`` most recent (user, assistant) message pairs.

        A pair is a ``"user_message"`` step immediately followed by an
        ``"assistant_message"`` step in ``self.thread["steps"]``.

        Args:
            k: Maximum number of pairs to return. Values ``<= 0`` yield ``[]``.

        Returns:
            List of ``(user_content, assistant_content)`` tuples, oldest first.
        """
        if k <= 0:
            return []

        # Tolerate a thread whose "steps" entry is missing or None.
        steps = self.thread.get("steps") or []

        conversation_pairs = []
        # Walk adjacent (previous, current) steps from newest to oldest.
        for prev_step, step in zip(reversed(steps[:-1]), reversed(steps[1:])):
            if (
                prev_step["type"] == "user_message"
                and step["type"] == "assistant_message"
            ):
                conversation_pairs.append(
                    (prev_step["output"]["content"], step["output"]["content"])
                )
                if len(conversation_pairs) >= k:
                    break

        # Pairs were collected newest-first; reverse to chronological order.
        return conversation_pairs[::-1]

    def _fetch_user_threads(self):
        """Query threads whose ``participantId`` equals this user's id.

        Returns:
            The ``get_threads`` response converted to a dict; ``__init__``
            reads its ``"data"`` list.
        """
        filters = [
            {
                "operator": "eq",
                "field": "participantId",
                "value": self.user_info["metadata"]["id"],
            }
        ]
        user_threads = self.literal_api.get_threads(filters=filters)
        return user_threads.to_dict()

    def _fetch_userinfo(self):
        """Get or create the Literal AI user and mirror its id into metadata.

        Returns:
            The updated user as a dict, with ``metadata["id"]`` set.

        Raises:
            TypeError: If ``self.user`` is ``None``.
        """
        if self.user is None:
            raise TypeError("user must be a mapping with a 'user_id' key, got None")

        user_info = self.literal_api.get_or_create_user(
            identifier=self.user["user_id"]
        ).to_dict()

        # TODO: Have to do this more elegantly.
        # The unique id is copied into the metadata because the thread
        # helpers in this class read it from ``user_info["metadata"]["id"]``.
        # Merge into the existing metadata so other keys already stored on
        # the user are not dropped by the update.
        metadata = {**(user_info.get("metadata") or {}), "id": user_info["id"]}
        return self.literal_api.update_user(
            id=user_info["id"],
            metadata=metadata,
        ).to_dict()

    def process(self, user_message, assistant_message, source_dict):
        """Record one user/assistant exchange in the user's thread.

        Args:
            user_message: Content logged as a ``"user_message"`` named "User".
            assistant_message: Content logged as an ``"assistant_message"``
                named "AI_Tutor".
            source_dict: Accepted as part of the signature; not used here.
        """
        with self.literal_client.thread(thread_id=self.thread_id):
            self.literal_client.message(
                content=user_message,
                type="user_message",
                name="User",
            )
            self.literal_client.message(
                content=assistant_message,
                type="assistant_message",
                name="AI_Tutor",
            )

    async def rag(self, user_query: dict, config: dict, chain):
        """Invoke ``chain`` inside a "RAG" retrieval step and return its result.

        Args:
            user_query: Dict passed to the chain; its ``"input"`` value is
                recorded as the step's question.
            config: Passed through to ``chain.invoke`` unchanged.
            chain: Object exposing ``invoke(user_query, config)``.

        Returns:
            Whatever ``chain.invoke`` returns; also stored as the step output.
        """
        with self.literal_client.step(
            type="retrieval", name="RAG", thread_id=self.thread_id, tags=self.tags
        ) as step:
            step.input = {"question": user_query["input"]}
            # NOTE(review): invoke() is a synchronous call inside a coroutine;
            # if the chain offers an async variant, awaiting it would avoid
            # blocking the event loop -- confirm against the chain's API.
            res = chain.invoke(user_query, config)
            step.output = res
            return res
|