File size: 3,912 Bytes
f2daaee
e029e22
 
 
f2daaee
 
 
 
 
e029e22
 
 
 
f2daaee
e029e22
 
 
f2daaee
e029e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2daaee
 
 
 
 
 
 
 
 
 
 
 
 
 
fc2cb23
f2daaee
e029e22
f2daaee
fc2cb23
e029e22
f2daaee
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from literalai import LiteralClient
from literalai.api import LiteralAPI
from literalai.filter import Filter as ThreadFilter

import os
from .base import ChatProcessorBase


class LiteralaiChatProcessor(ChatProcessorBase):
    """Persist chat turns to Literal AI and expose recent conversation history.

    On construction the processor resolves (or creates) the Literal AI user
    identified by ``user["user_id"]``, reuses the first thread returned for
    that user (or creates a new one), and caches the most recent
    user/assistant message pairs of that thread in ``self.prev_conv``.

    Credentials are read from the ``LITERAL_API_KEY`` and ``LITERAL_API_URL``
    environment variables.
    """

    def __init__(self, user=None, tags=None):
        """Connect to Literal AI and load (or create) the user's thread.

        Args:
            user: Mapping that must contain a ``"user_id"`` key; it is used
                as the Literal AI user identifier.
            tags: Optional tags attached to the retrieval steps logged by
                :meth:`rag`.
        """
        super().__init__()
        self.user = user
        self.tags = tags

        # Read the key once so the client and the API wrapper share it.
        api_key = os.getenv("LITERAL_API_KEY")
        self.literal_client = LiteralClient(api_key=api_key)
        self.literal_api = LiteralAPI(api_key=api_key, url=os.getenv("LITERAL_API_URL"))
        self.literal_client.reset_context()

        self.user_info = self._fetch_userinfo()
        self.user_thread = self._fetch_user_threads()
        # Reuse the user's existing thread when there is one; otherwise start one.
        if not self.user_thread["data"]:
            self.thread = self._create_user_thread()
        else:
            self.thread = self._get_user_thread()
        self.thread_id = self.thread["id"]

        self.prev_conv = self._get_prev_k_conversations()

    def _get_user_thread(self):
        """Fetch the first thread listed in ``self.user_thread`` as a dict."""
        thread = self.literal_api.get_thread(id=self.user_thread["data"][0]["id"])
        return thread.to_dict()

    def _create_user_thread(self):
        """Create a new thread for the current user and return it as a dict.

        The thread is named after the user's identifier and linked to the
        user via the id stored in ``user_info["metadata"]["id"]``.
        """
        thread = self.literal_api.create_thread(
            name=f"{self.user_info['identifier']}",
            participant_id=self.user_info["metadata"]["id"],
            # NOTE(review): environment is hard-coded; confirm this is intended
            # outside of development.
            environment="dev",
        )

        return thread.to_dict()

    def _get_prev_k_conversations(self, k=3):
        """Return up to ``k`` of the most recent (user, assistant) message pairs.

        A pair is a ``user_message`` step immediately followed by an
        ``assistant_message`` step in ``self.thread["steps"]``.

        Args:
            k: Maximum number of pairs to return; ``k <= 0`` yields ``[]``.

        Returns:
            list[tuple]: ``(user_content, assistant_content)`` tuples in
            chronological (oldest-first) order.
        """
        if k <= 0:
            return []

        # Guard against a thread dict with no (or a None) step list.
        steps = self.thread.get("steps") or []
        conversation_pairs = []
        # Walk backwards so the newest pairs are collected first.
        for i in range(len(steps) - 1, 0, -1):
            prev_step, step = steps[i - 1], steps[i]
            if (
                prev_step["type"] == "user_message"
                and step["type"] == "assistant_message"
            ):
                conversation_pairs.append(
                    (prev_step["output"]["content"], step["output"]["content"])
                )
                if len(conversation_pairs) >= k:
                    break

        # Collected newest-first; reverse to restore chronological order.
        return conversation_pairs[::-1]

    def _fetch_user_threads(self):
        """Return the Literal AI threads whose participant is the current user.

        Returns:
            dict: ``to_dict()`` of the ``get_threads`` response; the list of
            threads is read from its ``"data"`` key by ``__init__``.
        """
        filters = [
            {
                "operator": "eq",
                "field": "participantId",
                "value": self.user_info["metadata"]["id"],
            }
        ]
        return self.literal_api.get_threads(filters=filters).to_dict()

    def _fetch_userinfo(self):
        """Get or create the Literal AI user and return it as a dict.

        The user's own id is copied into its metadata so later calls can read
        it from ``user_info["metadata"]["id"]``.
        """
        user_info = self.literal_api.get_or_create_user(
            identifier=self.user["user_id"]
        ).to_dict()
        # TODO: Have to do this more elegantly
        # update metadata with unique id for now
        # (literalai seems to not return the unique id as of now,
        # so have to explicitly update it in the metadata)
        user_info = self.literal_api.update_user(
            id=user_info["id"],
            metadata={
                "id": user_info["id"],
            },
        ).to_dict()
        return user_info

    def process(self, user_message, assistant_message, source_dict):
        """Log one user/assistant exchange as two messages on the user's thread.

        Args:
            user_message: Content of the user's message.
            assistant_message: Content of the assistant's reply.
            source_dict: Accepted for interface compatibility; not used here.
        """
        with self.literal_client.thread(thread_id=self.thread_id):
            self.literal_client.message(
                content=user_message,
                type="user_message",
                name="User",
            )
            self.literal_client.message(
                content=assistant_message,
                type="assistant_message",
                name="AI_Tutor",
            )

    async def rag(self, user_query: dict, config: dict, chain):
        """Run ``chain`` on ``user_query`` inside a logged ``retrieval`` step.

        Args:
            user_query: Chain input; its ``"input"`` value is recorded as the
                step's question.
            config: Config passed through to ``chain.invoke``.
            chain: Object exposing ``invoke(user_query, config)``.

        Returns:
            Whatever ``chain.invoke`` returns; it is also stored as the step
            output.
        """
        with self.literal_client.step(
            type="retrieval", name="RAG", thread_id=self.thread_id, tags=self.tags
        ) as step:
            step.input = {"question": user_query["input"]}
            # NOTE(review): chain.invoke is called synchronously inside this
            # coroutine, so it blocks the event loop while it runs; consider
            # `await chain.ainvoke(...)` if the chain supports it.
            res = chain.invoke(user_query, config)
            step.output = res
        return res