File size: 6,870 Bytes
291bc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80f2e42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291bc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80f2e42
 
 
 
 
 
291bc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c15e390
291bc70
 
305aec3
 
 
291bc70
 
80f2e42
291bc70
 
305aec3
291bc70
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
from src.extractor import create_extractor
from src.sql_chain import create_agent
from dotenv import load_dotenv
import chainlit as cl
import json
# Pull settings (OPENAI_MODEL, INTERACTIVE_OPENAI_KEY, ...) from the local .env file.
load_dotenv(".env")

# Chat model name handed to the SQL agent; None when OPENAI_MODEL is unset.
# NOTE(review): a fallback default ("gpt-3.5-turbo-0125") used to be applied
# here and is currently disabled — confirm create_agent copes with None.
model = os.getenv('OPENAI_MODEL')

# When INTERACTIVE_OPENAI_KEY is set to a non-empty value, the OpenAI key is
# collected from the user in on_chat_start, so the extractor and agent are
# built there. Otherwise they are built immediately from the environment.
interactive_key_done = not os.getenv('INTERACTIVE_OPENAI_KEY', None)

# Module-level extractor / SQL agent shared by every handler below.
ex = create_extractor() if interactive_key_done else None
ag = create_agent(llm_model=model) if interactive_key_done else None

@cl.on_chat_start
async def on_chat_start():
    """Collect the user's OpenAI API key when running in interactive-key mode.

    Does nothing when the extractor/agent were already built (i.e.
    ``interactive_key_done`` is True). Otherwise it asks the user for a key
    and exports it as ``OPENAI_API_KEY``. It then builds the module-level
    extractor (``ex``) and SQL agent (``ag``).

    On success ``interactive_key_done`` becomes True for the whole process.
    Because the key is written to ``os.environ`` and the objects are
    globals, later chat sessions reuse the same key and objects.

    On failure the globals are left unchanged and the user is told to
    start a new chat.
    """
    global ex, ag, interactive_key_done
    if interactive_key_done:
        return

    res = await cl.AskUserMessage(content=" 🔑 Input your OPENAI_API_KEY from https://platform.openai.com/account/api-keys", timeout=10).send()
    if not res:
        # No answer came back (e.g. the 10 s timeout elapsed); tell the user
        # instead of silently leaving the session without a working backend.
        await cl.Message(
            content="⌛ No API key received in time. Start a new chat to try again.",
        ).send()
        return

    # Pasted keys frequently carry stray spaces/newlines; strip them so the
    # key is not rejected for formatting alone.
    api_key = (res.get("output") or "").strip()
    if not api_key:
        await cl.Message(
            content="Error: empty API key. Start new chat with correct key.",
        ).send()
        return

    await cl.Message(
        content="⌛ Checking if provided OpenAI API key works. Please wait...",
    ).send()
    cl.user_session.set("openai_api_key", api_key)
    try:
        os.environ["OPENAI_API_KEY"] = api_key
        # Build into locals first so a failure in create_agent does not leave
        # a freshly built extractor in the globals next to a stale agent.
        new_ex = create_extractor()
        new_ag = create_agent(llm_model=model)
    except Exception as e:  # report any construction failure to the user
        await cl.Message(
            content=f"Error: {e}. Start new chat with correct key.",
        ).send()
        return

    ex, ag = new_ex, new_ag
    interactive_key_done = True
    await cl.Message(author="Socccer-RAG", content="✅ Voila! ⚽ Socccer-RAG warmed up and ready to go! You can start a fresh chat session from New Chat").send()



def extract_func(user_prompt: str):
    """Run the module-level extractor over a user prompt.

    Parameters
    ----------
    user_prompt: str
        The raw text the user typed into the chat.

    Returns
    -------
    The result of ``ex.extract_chainlit`` — a dictionary of extracted
    properties (it is JSON-dumped by ``main``).
    """
    return ex.extract_chainlit(user_prompt)
def validate_func(properties: dict):  # Auto validate as much as possible
    """Validate extracted properties automatically where possible.

    Parameters
    ----------
    properties: dict
        The properties returned by ``extract_func``.

    Returns
    -------
    A ``(validated, need_input)`` pair from ``ex.validate_chainlit``:

    1. validated: the properties that passed automatic validation.
    2. need_input: the properties that still need a human decision.
       ``main`` iterates it as a list of dicts that each carry a
       ``"top_matches"`` list.
    """
    auto_ok, pending = ex.validate_chainlit(properties)
    return auto_ok, pending

def human_validate_func(human, validated, user_prompt):
    """Merge the user's choices into the validated properties and rebuild the prompt.

    Parameters
    ----------
    human
        Human-validated properties as a list of dictionaries, each mapping
        a property name to the value the user chose. Empty-string values
        mean "no choice" and are ignored.
    validated
        Automatically validated properties as a dictionary. Any key that
        also appears in *human* must map to a list, because the chosen
        value is added to it. This dictionary and its lists are NOT modified.
    user_prompt
        The user's original prompt.

    Returns
    -------
    The cleaned prompt with updated values, as produced by
    ``ex.build_prompt_chainlit``.
    """
    # Shallow copy: keys are added/rebound on the copy only, and lists are
    # rebuilt instead of appended to, so the caller's data stays untouched.
    merged = dict(validated)
    for item in human:
        for key, value in item.items():
            if value == "":
                continue  # the user made no selection for this property
            if key in merged:
                merged[key] = merged[key] + [value]
            else:
                merged[key] = [value]

    return ex.build_prompt_chainlit([merged], user_prompt)

def no_human(validated, user_prompt):
    """Build the cleaned prompt when no human validation was required.

    Parameters
    ----------
    validated
        The automatically validated properties (dictionary).
    user_prompt
        The user's original prompt.

    Returns
    -------
    The updated prompt from ``ex.build_prompt_chainlit``.
    """
    prompt_inputs = [validated]  # build_prompt_chainlit expects a list
    return ex.build_prompt_chainlit(prompt_inputs, user_prompt)


def ask(text):
    """Send *text* to the SQL agent and return its final answer.

    Parameters
    ----------
    text
        The cleaned prompt to query the agent with.

    Returns
    -------
    A pair ``({"output": <answer>}, 12)``. Only the ``"output"`` entry of
    the agent's first return value is kept.
    """
    result, _discarded = ag.ask(text)
    answer = {"output": result["output"]}
    # NOTE(review): the second element is the literal 12, not the value
    # ag.ask returned; callers in this file ignore it — confirm intent.
    return answer, 12


@cl.step
async def Cleaner(text):  # just for printing
    """Chainlit step that echoes *text* back unchanged (for display only)."""
    echoed = text
    return echoed


@cl.step
async def LLM(cleaned_prompt):  # just for printing
    """Chainlit step wrapping the SQL-agent call so its result is shown.

    Returns the two-element result of ``ask(cleaned_prompt)``.
    """
    outcome = ask(cleaned_prompt)
    answer, extra = outcome
    return answer, extra


@cl.step
async def Choice(text):
    """Chainlit step that echoes *text* (used to display offered options)."""
    shown = text
    return shown

@cl.step
async def Extractor(user_prompt):
    """Chainlit step that runs property extraction on *user_prompt*."""
    return extract_func(user_prompt)


async def _resolve_with_user(need_input):
    """Ask the user to pick the correct value for each ambiguous property.

    Each element of *need_input* is presumably a dict holding one property
    name plus a ``"top_matches"`` list of candidate values — confirm against
    the extractor. Elements are rewritten in place to ``{property: choice}``.

    A choice of ``""`` means the user did not answer or there were no
    candidates. ``human_validate_func`` skips such values.
    """
    for element in need_input:
        # Remove the candidates first so the property key picked below can
        # never be "top_matches" itself.
        options = element.pop("top_matches", [])
        if not element:
            continue  # no property name to resolve
        key = next(iter(element))  # the property being resolved

        if not options:
            # A prompt with no buttons could only end by timing out.
            element[key] = ""
            await cl.Message(author="Resolver", content=f"No candidate values found for {key}; skipping it.").send()
            continue

        # Present the user with the candidate values as buttons.
        actions = [
            cl.Action(name="option", value=value, label=value)
            for value in options
        ]
        await cl.Message(author="Resolver", content=f"Need to identify the correct value for {key}: ").send()
        res = await cl.AskActionMessage(author="Resolver",
            content=f"Which one do you mean for {key}?",
            actions=actions
        ).send()
        element[key] = res.get("value") if res else ""
        await Choice("Options were " + ", ".join(action.label for action in actions))


@cl.on_message  # this function will be called every time a user inputs a message in the UI
async def main(message: cl.Message):
    """Handle one user message end to end.

    The pipeline has four steps:

    1. Extract properties from the prompt.
    2. Validate them automatically.
    3. Ask the user to resolve anything ambiguous.
    4. Rebuild the prompt and query the SQL agent, posting each stage
       to the chat.

    Refuses to run until an OpenAI key has been configured.
    """
    if not interactive_key_done:
        await cl.Message(
            content="Please set the OpenAI API key first by starting a new chat.",
        ).send()
        return

    user_prompt = message.content
    extracted_values = await Extractor(user_prompt)
    json_formatted = json.dumps(extracted_values, indent=4)
    # Show the extracted values as JSON.
    await cl.Message(author="Extractor", content=f"Extracted properties:\n```json\n{json_formatted}\n```").send()

    # Announce validation before running it so the message is accurate.
    await cl.Message(author="Validator", content="Extracted properties will now be validated against the database.").send()
    validated, need_input = validate_func(extracted_values)

    if need_input:
        await _resolve_with_user(need_input)
        cleaned_prompt = human_validate_func(need_input, validated, user_prompt)
    else:
        cleaned_prompt = no_human(validated, user_prompt)

    await cl.Message(author="Cleaner", content=f"New prompt is as follows:\n{cleaned_prompt}").send()

    # Call the SQL agent to get the final answer.
    await cl.Message(content="I will now query the database for information.").send()
    ans, _ = await LLM(cleaned_prompt)
    await cl.Message(content=f"This is the final answer: \n\n{ans['output']}").send()