# (removed web-scrape residue that preceded the source: file-size banner,
#  commit hashes, and a copied line-number gutter — not part of the program)
'''
Dialog System of PsyPlus (dvq)

reference:
  https://huggingface.co/spaces/bentrevett/emotion-prediction
  https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
  https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues

gradio vs streamlit 
  https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
  https://gradio.app/interface_state/ -> global and local variables affect the separation of sessions

TODO
  Add command to reset/jump to a function, e.g. >reset, >euc_100
  Add diagram in Gradio Interface showing sentiment analysis
  Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
  Personalize: create database, load and save data

Run command
  python app.py --run_on_own_server 1 --initial_chat_state free_chat
'''


import argparse
import re, time
import matplotlib.pyplot as plt
from threading import Timer
import gradio as gr

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline


def option(argv=None):
  '''Parse command-line options for the app.

  Args:
    argv: optional list of argument strings. Defaults to ``sys.argv[1:]``
      (the argparse default), so existing callers are unaffected; passing a
      list makes the parser usable programmatically (e.g. in tests).

  Returns:
    argparse.Namespace with the parsed options.
  '''
  parser = argparse.ArgumentParser()
  parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
  parser.add_argument('--dialog_model', type=str, default='tareknaous/dialogpt-empathetic-dialogues')
  parser.add_argument('--emotion_model', type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')
  parser.add_argument('--account', type=str, default=None)
  parser.add_argument('--initial_chat_state', type=str, default='euc_100', choices=['euc_100', 'euc_200', 'free_chat'])
  args = parser.parse_args(argv)
  return args

# Parse CLI options once at import time; shared by ChatHelper (model names)
# and the Gradio entry point at the bottom of the file.
args = option()


# store the list of messages that are showed in therapies and models as global variables
# let all chat-session-wise variables placed in TherapyChatBot
class ChatHelper:
  '''Shared (class-level) models and canned therapy texts.

  Everything here is global to the process and shared by all chat sessions;
  per-session state lives in TherapyChatBot instances.
  '''
  # chat and emotion-detection models
  ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
  ed_threshold = 0.3  # minimum score for a negative emotion to switch the dialog into euc_200
  dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
  dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
  eos = dialog_tokenizer.eos_token
  # tokenizer.__call__ -> input_ids, attention_mask
  # tokenizer.encode -> only input_ids, which is required by model.generate function

  # canned responses shown to the user (runtime strings — keep wording stable)
  invalid_input = 'Invalid input, my friend :) Plz input again'
  good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
  good_case = 'Nice to hear that!'
  bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
  not_answer = "It's okay, maybe you don't want to answer this question."
  fill_form = ('It has come to our attention that you may suffer from {}.\n'
              'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
              'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
              'you can fill out these scales again to see if you have improved.\n'
              'Do you want to fill in the form now? (Okay or Later)')
  display_form = '<Display the form>.\n'
  reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'

  emotion_types = ['Overall', 'Happiness', 'Anxiety'] # 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',
  # euc_100 questionnaire script: 'q' asks one rating per emotion type, then
  # branches into the good_mood or bad_mood follow-up prompts.
  euc_100 = {
    'q': emotion_types,
    'good_mood': [
      'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
      'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
    ],
    'bad_mood': [
      'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
      'I see. So when it is happening, what feelings or emotions have you got?',
      'And what do you think about those feelings or emotions at that time?',
      'Could you think of any evidence for your above-mentioned thought?',
      'Here are some reference articles about bad emotions. You can take a look :)',
    ],
  }

  # go-emotions labels treated as "negative enough" to interrupt the dialog
  negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
                       'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
  euc_200 = 'Now go back to the last chat. You said that "{}".\n'

  greeting_template = {
    'euc_100': 'How was your day? On the scale 1 to 10, '
      'how would you judge your emotion through the following categories:\nOverall',
    # euc_200 is only triggered when you say something more negative than a certain threshold;
    # thus the greeting here is only for debugging euc_200
    'euc_200': fill_form.format('anxiety'),
    'free_chat': 'Hi you! How is it going?',
  }

  @staticmethod
  def plot_emotion_distribution(predictions):
    '''Bar-plot the score of each predicted emotion label (debug helper).

    Args:
      predictions: list of dicts with 'label' and 'score' keys, as returned
        by the text-classification pipeline.
    '''
    fig, ax = plt.subplots()
    # bug fix: the body previously referenced the undefined name `prediction`
    # (parameter is `predictions`), raising NameError on every call
    ax.bar(x=[i for i, _ in enumerate(predictions)],
           height=[p['score'] for p in predictions],
           tick_label=[p['label'] for p in predictions])
    ax.tick_params(rotation=90)
    ax.set_ylim(0, 1)
    plt.show()

  @staticmethod
  def ed_rulebase(text):
    '''Rule-based screening for life-threatening keywords.

    NOTE(review): uses console print()/input(), so this is only usable in a
    terminal session, not through the Gradio interface — confirm intended use.
    '''
    keywords = {
      'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
      'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
      'manifestation': ['never stop', 'every moment', 'strong', 'very']
    }

    # if found dangerous kw/topics: a life_safety keyword plus at least one
    # immediacy or manifestation keyword
    if re.search(rf"{'|'.join(keywords['life_safety'])}", text) is not None and \
        sum(re.search(rf"{'|'.join(keywords[k])}", text) is not None for k in ['immediacy', 'manifestation']) >= 1:
      print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
        'The Hong Kong Lifeline number is (852) 2382 0000')
      x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
      if x == '1':
        print('Let you connect to the office')
      else:
        print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
          'Would you mind if we send this conversation to the cloud to finetune the model.')
        y = input('Yes or No: ')
        if y == 'Yes':
          pass # do smt here


class TherapyChatBot:
  '''One chat session: tracks dialog state and history for a single user.

  The session walks a small state machine named by self.chat_state:
  'euc_100.<subsection>.<entry>' (rating questionnaire + mood follow-ups),
  'euc_200' (negative-emotion interrupt), or 'free_chat' (DialoGPT).
  '''

  def __init__(self, args):
    # check state to control the dialog
    self.chat_state = args.initial_chat_state # name of the chat function/therapy segment the model is in
    self.message_prev = None
    self.chat_state_prev = None
    self.run_on_own_server = args.run_on_own_server
    self.account = args.account

    # additional attributes for euc_100: per-question timestamps and 1-10 ratings
    self.euc_100_input_time = []
    self.euc_100_emotion_degree = []
    self.already_trigger_euc_200 = False  # euc_200 fires at most once per session

    # chat history.
    # TODO: if we want to personalize and save the conversation,
    # we can load data from database
    self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
    # NOTE(review): the account branch yields a raw file handle, not the
    # {'input_ids', 'text'} dict the rest of the class expects — presumably a
    # pickle.load was intended; verify before enabling --account
    self.history = {'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
                    'text': self.greeting} if not self.account else open(f'database/{hash(self.account)}', 'rb')
    if 'euc_100' in self.chat_state:
      self.chat_state = 'euc_100.q.0'

  def __call__(self, message, prefix=''):
    '''Handle one user message: route by chat_state and append to history.

    Args:
      message: the user's input text.
      prefix: non-empty only when called back from euc_200, meaning the
        negative-emotion check was already done and `prefix` must be
        prepended to the response.

    Returns None; the exchange is recorded in self.history['text'].
    '''
    # if prefix != '', this call comes from euc_200, thus the negative emotion was already detected
    if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
      prediction = ChatHelper.ed_pipe(message)[0]
      prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
      if self.run_on_own_server:
        print(prediction)
      # plot_emotion_distribution(prediction)
      emotion = prediction[0]

    # if message is negative, change state immediately
    # (same guard as above, so `emotion` is only read when it was computed)
    if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
      (emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > ChatHelper.ed_threshold):
      self.chat_state_prev = self.chat_state
      self.chat_state = 'euc_200'
      self.message_prev = message
      self.already_trigger_euc_200 = True
      response = ChatHelper.fill_form.format(emotion['label'])

    # set up rule to update state inside each dialog function
    elif self.chat_state.startswith('euc_100'):
      response = self.euc_100(message)
      if self.chat_state == 'free_chat':
        # euc_100 just finished: seed the generation history with this turn
        last_two_turns_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
        self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)

    elif self.chat_state.startswith('euc_200'):
      return self.euc_200(message)

    else: # free_chat
      response = self.free_chat(message)

    if prefix:
      response = prefix + response
      self.history['text'].append((self.message_prev, response))
    else:
      self.history['text'].append((message, response))

  def euc_100(self, x):
    '''Run one step of the euc_100 questionnaire; returns the bot response.'''
    _, subsection, entry = self.chat_state.split('.')
    entry = int(entry)

    if subsection == 'q':
      # expect a 1-10 rating for the current emotion type
      if x.isnumeric() and (0 < int(x) < 11):
        self.euc_100_emotion_degree.append(int(x))
        self.euc_100_input_time.append(time.gmtime())
        if entry == len(ChatHelper.euc_100['q']) - 1:
          if self.run_on_own_server:
            print(self.euc_100_emotion_degree)
          # branch on the 'Overall' rating (first answer)
          mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
          self.chat_state = f'euc_100.{mood}.0'
          response = ChatHelper.euc_100[mood][0]
        else:
          self.chat_state = f'euc_100.q.{entry+1}'
          response = ChatHelper.euc_100['q'][entry+1]
      else:
        response = ChatHelper.invalid_input

    elif subsection == 'good_mood':
      if x == '':
        response = ChatHelper.good_mood_over
      else:
        response = ChatHelper.good_case
      response += '\n' + ChatHelper.euc_100['good_mood'][1]
      self.chat_state = 'free_chat'

    elif subsection == 'bad_mood':
      if entry == -1:
        # entry -1 means the last prompt was bad_mood_over
        if 'yes' in x.lower() or 'better' in x.lower():
          response = ChatHelper.good_case
          # assume the bad-mood episode is over and drop into free chat
          # TODO confirm the intended next state here
          self.chat_state = 'free_chat'
        else:
          # resume the questionnaire from the parked entry.
          # bug fix: rsplit returns a list — take its last element before int()
          entry = int(self.chat_state_prev.rsplit('.', 1)[-1])
          response = ChatHelper.not_answer + '\n' + ChatHelper.euc_100['bad_mood'][entry+1]
          if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
            self.chat_state = 'free_chat'
          else:
            self.chat_state = f'euc_100.bad_mood.{entry+1}'

      # bug fix: this block used to run even after the entry == -1 branch,
      # unconditionally clobbering its response and state — now mutually exclusive
      elif x == '':
        # empty input: park the current entry and ask whether the mood is over
        response = ChatHelper.bad_mood_over
        self.chat_state_prev = self.chat_state
        self.chat_state = 'euc_100.bad_mood.-1'
      else:
        response = ChatHelper.euc_100['bad_mood'][entry+1]
        if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
          self.chat_state = 'free_chat'
        else:
          self.chat_state = f'euc_100.bad_mood.{entry+1}'

    return response

  def euc_200(self, x):
    '''Handle the form offer (Okay/Later), then resume the previous state.'''
    # don't ask questions in euc_200, because they're similar to questions in euc_100
    if x.lower() == 'okay':
      response = ChatHelper.display_form
    else:
      response = ChatHelper.reference
    response += ChatHelper.euc_200.format(self.message_prev)

    # re-dispatch the interrupted message through the previous state,
    # prepending this segment's text to whatever that state answers
    message = self.message_prev
    self.message_prev = x
    self.chat_state = self.chat_state_prev
    return self.__call__(message, prefix=response)

  def free_chat(self, message):
    '''Generate a DialoGPT reply, retrying with a shortened context on empty output.'''
    message_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
    self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
    input_ids = self.history['input_ids'].clone()

    while True:
      bot_output_ids = ChatHelper.dialog_model.generate(input_ids, max_length=1000,
                                          do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
                                          pad_token_id=ChatHelper.dialog_tokenizer.eos_token_id)
      response = ChatHelper.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
                                       skip_special_tokens=True)
      if response.strip() != '':
        break
      elif input_ids[0].tolist().count(ChatHelper.dialog_tokenizer.eos_token_id) > 0:
        # empty reply: drop the oldest turn (up to the first EOS) and retry
        idx = input_ids[0].tolist().index(ChatHelper.dialog_tokenizer.eos_token_id)
        input_ids = input_ids[:, (idx+1):]
      else:
        # last resort: condition on the current message only
        input_ids = message_ids

      if self.run_on_own_server:
        print(input_ids)

    self.history['input_ids'] = torch.cat([self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]], dim=-1)
    if self.run_on_own_server == 1:
      print((message, response), '\n', self.history['input_ids'])

    return response


if __name__ == '__main__':
  def chat(message, bot):
    # Gradio session callback: the per-session bot instance travels through
    # the 'state' slot and is created lazily on the first message.
    if not bot:
      bot = TherapyChatBot(args)
    bot(message)
    return bot.history['text'], bot

  title = 'PsyPlus Empathetic Chatbot'
  description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
  greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
  iface = gr.Interface(
    chat, ['text', 'state'], [gr.Chatbot(value=greeting), 'state'],
    allow_flagging='never', title=title, description=description,
  )

  # share mode is needed when the app is served from a private machine
  if args.run_on_own_server == 0:
    iface.launch(debug=True)
  else:
    iface.launch(debug=True, share=True)