"""OpenAI GPT-3 Chatbot with Streamlit"""
import openai
import streamlit as st
from streamlit_chat import message
from transformers import pipeline


summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")
sentiment_task = pipeline("sentiment-analysis",
                            model='cardiffnlp/twitter-roberta-base-sentiment-latest',
                            tokenizer='cardiffnlp/twitter-roberta-base-sentiment-latest')

openai.api_key = st.secrets["openai_api_key"]
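# NOTE: the API key above is read from Streamlit's secrets store, typically a
# .streamlit/secrets.toml file containing a line like (placeholder value):
#   openai_api_key = "sk-..."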

completion = openai.Completion()

START_PROMPT = '[Instruction] Act as a friendly, compassionate, insightful, and empathetic AI \
                therapist named Joy. Joy listens and offers advice. \
                End the conversation when the patient wishes to.'
START_MESSAGE = 'I am Joy, your AI therapist. How are you feeling today?'
START_SEQUENCE = "\nJoy:"
RESTART_SEQUENCE = "\n\nPatient:"
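
# Illustrative sketch of the prompt that ask() assembles from these pieces
# (hypothetical dialogue, shown here only as a comment):
#
#   [Instruction] Act as a friendly, compassionate, ... therapist named Joy. ...
#   Joy: I am Joy, your AI therapist. How are you feeling today?
#
#   Patient: I had a rough week.
#   Joy:
#
# The model then completes the text after the final "Joy:".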

def ask(question: str, chat_log: list, model='text-davinci-003', temp=0.9) -> tuple[str, str]:
    ''' takes the user's input string and the previous chat_log,
        returns the model's response/answer and the new dialogue entry.
        Passing the previous chat_log prevents the goldfish-memory effect;
        the chat_log is also used to summarize and analyze the sentiment of the user's input '''

    # each chat_log entry already contains its own separators, so join them directly
    prompt = f"{''.join(chat_log)}{RESTART_SEQUENCE} {question}{START_SEQUENCE}"
    response = completion.create(
        prompt = prompt,
        model = model,
        stop = ["Patient:",'Joy:'],
        temperature = temp, #the higher the more creative
        frequency_penalty = 0.9, #prevents word repetition, larger -> higher penalty
        presence_penalty = 1, #prevents topic repetition, larger -> higher penalty
        top_p =1,
        best_of=1,
        max_tokens=170,
        )

    answer = response.choices[0].text.strip()
    log = f'{RESTART_SEQUENCE}{question}{START_SEQUENCE}{answer}'
    return str(answer), str(log)
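
# Illustrative usage (not executed; the real call happens in main() with the
# Streamlit session state as chat_log):
#   seed_log = [START_PROMPT + START_SEQUENCE + START_MESSAGE]
#   answer, entry = ask("I had a rough week.", seed_log)
#   seed_log.append(entry)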

def clean_chat_log(chat_log: list) -> str:
    ''' cleans the chat log by joining the list items,
    removing everything before the first newline and replacing
    the remaining newlines with spaces. '''

    chat_log = ' '.join(chat_log)
    # find the first newline and drop the instruction prompt that precedes it
    # (guard against find() returning -1, which would slice off all but the last character)
    first_newline = chat_log.find('\n')
    if first_newline != -1:
        chat_log = chat_log[first_newline:]
    # replace the remaining newlines with spaces
    chat_log = chat_log.replace('\n', ' ')
    return chat_log
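
# Illustrative example (hypothetical values): the instruction text before the
# first newline is dropped so only the dialogue reaches the summarizer, e.g.
#   clean_chat_log(['[Instruction] Act as ...\nJoy: Hi', '\n\nPatient: Hello'])
#   returns roughly ' Joy: Hi   Patient: Hello'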

def summarize(chat_log: list) -> str:
    ''' returns a summary of the chat log '''

    chat_log = clean_chat_log(chat_log)
    summary = summarizer(chat_log, do_sample=False)[0]['summary_text']
    return summary

def analyze_sentiment(user_input: list) -> list:
    ''' returns the user's sentiment, computed on a summary of the user's input '''

    user_input = clean_chat_log(user_input)
    summary = summarizer(user_input, do_sample=False)[0]['summary_text']
    sentiment = sentiment_task(summary)
    return sentiment
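
# Note: the Hugging Face sentiment pipeline returns a list of dicts such as
# [{'label': 'positive', 'score': 0.93}] (the score here is made up); the
# cardiffnlp "-latest" model uses the labels 'negative', 'neutral' and 'positive'.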

def remove_backslash(chat_log: list) -> list:
    ''' replaces the newline characters in each chat log entry with spaces '''

    chat_log = [i.replace('\n', ' ') for i in chat_log]
    return chat_log



def main():
    ''' main function '''

    st.title("Chat with Joy - the AI therapist!")
    col1, col2 = st.columns(2)
    temp = col1.slider("Bot creativity (temperature)", 0.0, 1.0, 0.9, 0.1)
    model = col2.selectbox("Model", ["text-davinci-003",
                                     "text-curie-001",
                                     "curie:ft-personal-2023-02-03-17-06-53"])

    if 'generated' not in st.session_state:
        st.session_state['generated'] = [START_MESSAGE]

    if 'past' not in st.session_state:
        st.session_state['past'] = []

    if 'summary' not in st.session_state:
        st.session_state['summary'] = []

    if 'chat_log' not in st.session_state:
        st.session_state['chat_log'] = [START_PROMPT+START_SEQUENCE+START_MESSAGE]


    if len(st.session_state['generated']) > 2:
        if st.button("Clear and summerize", key='clear'):
            chat_log = clean_chat_log(st.session_state['chat_log'])
            summary = summarizer(chat_log, max_length=100, min_length=30, do_sample=False)
            st.write(summary)
            user_sentiment = st.session_state['past']
            user_sentiment = remove_backslash(user_sentiment)
            st.write(analyze_sentiment(user_sentiment))
            st.session_state['generated'] = [START_MESSAGE]
            st.session_state['past'] = []
            st.session_state['chat_log'] = [START_PROMPT+START_SEQUENCE+START_MESSAGE]
            st.session_state['summary'] = []

    user_input = st.text_input("You:", key='input')

    if user_input:
        output, chat_log = ask(user_input, st.session_state['chat_log'], model=model, temp=temp)
        st.session_state['chat_log'].append(chat_log)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
    if st.session_state['generated']:
        for i in range(len(st.session_state['generated'])-1, -1, -1):
            if i < len(st.session_state['past']):
                message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
            message(st.session_state["generated"][i], key=str(i))



if __name__ == "__main__":
    main()
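
# To launch the app locally (assuming this file is saved as, e.g., app.py):
#   streamlit run app.py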