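"""Gradio demo: a streaming chat UI for an OpenAI-compatible endpoint.

The model is prompted to reply with a JSON object containing a "thoughts"
field plus either an "answer" or a "continue" field; the stream is parsed
incrementally so the chain-of-thought and answer textboxes update live.
"""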
import gradio as gr
import time
from openai import OpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
import prompts
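# The local prompts module provides the system prompt and the few-shot
# example message lists spliced into every request below.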


previous_thought = ""
previous_answer = ""
is_clearing = False


def ai_response(api_key, base_url, input_text, shared_text, temperature):
    """Stream the model's JSON reply, yielding (answer, thought) pairs for the shared and CoT textboxes."""

    in_context_learning = [
        *prompts.continue_skill_example,
        *prompts.boilerplate_example,
        *prompts.continue_complete_skill_example,
        *prompts.fresh_start_example,
    ]
    context = [
        {"role": "system",
         "content": "test finished. the following case is real. be cautious. Your answer must contain 'thoughts', and 'answer' or 'continue' fields."},
        {"role": "system",
         "content": str({"shared_context": shared_text, "previous_thought": previous_thought})},
        {"role": "user", "content": input_text},
    ]
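    # Final message list: system prompt, few-shot examples, then live context.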
    messages = [prompts.system_prompt, *in_context_learning, *context]
    print(messages)

    # Reuse shared_text as the accumulation buffer for the streamed JSON reply.
    shared_text = ""

    # Initialize OpenAI client
    client = OpenAI(
        base_url=base_url,
        api_key=api_key
    )

    # Retry transient API failures with randomized exponential backoff,
    # giving up after six attempts.
    @retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
    def completion_with_backoff(**kwargs):
        return client.chat.completions.create(**kwargs)

    # Request a JSON-mode response and stream it so the UI updates incrementally.
    stream = completion_with_backoff(
        model='meta/llama-2-70b-chat',
        temperature=temperature,
        messages=messages,
        response_format={"type": "json_object"},
        stream=True,
    )

    thought = ""
    answer = ''

    last_answer = previous_answer

    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta is not None and delta.content is not None:
            shared_text += delta.content
            # The JSON object arrives incrementally, so the fields are parsed
            # by substring rather than json.loads: the payload stays malformed
            # until the stream finishes.
            if '"answer":' in shared_text:
                answer = shared_text[shared_text.index('"answer":'):].replace('"answer": "', '').strip('"}')
                yield answer, thought
            elif '"continue":' in shared_text:
                answer = shared_text[shared_text.index('"continue":'):].replace('"continue": "', '').strip('"}')
                yield last_answer + answer, thought
            else:
                # Only the "thoughts" field has streamed so far.
                thought = shared_text.replace('{"thoughts": "', '').replace(', "answer', '').replace(', "continue', '').strip('"')
                yield answer if answer else last_answer, thought

    print(shared_text)


with gr.Blocks() as demo:
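    # Connection settings; the default base URL points at Replicate's
    # OpenAI-compatible proxy.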
    api_input = gr.Textbox(label="Your OpenAI API key", type="password")
    base_url = gr.Textbox(label="OpenAI API base URL", value="https://openai-proxy.replicate.com/v1")

    user_input = gr.Textbox(lines=2, label="User Input")
    cot_textbox = gr.Textbox(label="CoT etc.")
    shared_textbox = gr.Textbox(label="Shared Textbox", interactive=True)
    temperature = gr.Slider(label="Temperature", minimum=0, maximum=2, step=0.01, value=0.01)
    # n_shots = gr.Slider(label="N-shots (~150 tokens each. It should not work 0-shot)", minimum=0, maximum=5, step=1, value=1)
    ai_btn = gr.Button("Generate AI Response")
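    # Keep a handle on the click event so the Stop button can cancel it.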
    generation = ai_btn.click(fn=ai_response, inputs=[api_input, base_url, user_input, shared_textbox, temperature],
                              outputs=[shared_textbox, cot_textbox])


    def update_previous_answer(shared, cot):
        # Mirror the textboxes into module state, except while the clear
        # animation is rewinding them.
        global previous_answer, previous_thought
        if not is_clearing:
            previous_answer = shared
            previous_thought = cot


    # Persist streamed output and manual edits back into module state.
    shared_textbox.change(fn=update_previous_answer, inputs=[shared_textbox, cot_textbox])

    clear_btn = gr.Button("Clear")


    def clear_memory():
        global previous_answer, previous_thought, is_clearing

        is_clearing = True

        # Rewind the thought first, then the answer, two characters per tick,
        # yielding each intermediate state so the textboxes visibly empty.
        while previous_thought:
            previous_thought = previous_thought[:-2]
            time.sleep(0.005)
            yield previous_answer, previous_thought

        while previous_answer:
            previous_answer = previous_answer[:-2]
            time.sleep(0.005)
            yield previous_answer, previous_thought

        is_clearing = False


    clear_outputs = clear_btn.click(fn=clear_memory, outputs=[shared_textbox, cot_textbox])

    stop_btn = gr.Button("Stop")
    # Cancel whichever event is running: the generation stream or the clear animation.
    stop_btn.click(None, None, None, cancels=[generation, clear_outputs])

if __name__ == "__main__":
    demo.launch()