Quyet commited on
Commit
f30862a
1 Parent(s): de337bd

add euc 100 200 to chat loop, fix dialog model

Browse files
Files changed (2) hide show
  1. README.md +11 -0
  2. app.py +269 -200
README.md CHANGED
@@ -11,3 +11,14 @@ license: gpl-3.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ For more information about this product, please visit this notion [page](https://www.notion.so/AI-Consulting-Design-Scheme-0a9c5288820d4fec98ecc7cc1e84be51)) (you need to have permission to access this page)
16
+
17
+ # Notes
18
+
19
+ ### 2022/12/20
20
+
21
+ - Chat flow will trigger euc 200 when detect a negative emotion with prob > threshold. Thus, only euc 100 and free chat consist of chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT be regularly (currently one trigger once during the conversation), because trigger to much will bother users
22
+ - Already fix the problem with dialog model. Now it's configured as the same as what it should be. Of course, that does not guarantee of good response
23
+ - TODO is written in the main file already
24
+ - Successfully convert plain euc 100 and 200 to chat flow
app.py CHANGED
@@ -1,3 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import re, time
3
  import matplotlib.pyplot as plt
@@ -5,228 +27,275 @@ from threading import Timer
5
  import gradio as gr
6
 
7
  import torch
8
- from transformers import (
9
- GPT2LMHeadModel, GPT2Tokenizer,
10
- AutoModelForSequenceClassification, AutoTokenizer,
11
- pipeline
12
- )
13
- # reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
14
- # and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
15
- # gradio vs streamlit https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
16
- # https://gradio.app/interface_state/
17
-
18
- def euc_100():
19
- # 1,2,3. asks about the user's emotions and store data
20
- print('How was your day?')
21
- print('On the scale 1 to 10, how would you judge your emotion through the following categories:') # ~ Baymax :)
22
- emotion_types = ['overall'] #, 'happiness', 'surprise', 'sadness', 'depression', 'anger', 'fear', 'anxiety']
23
- emotion_degree = []
24
- input_time = []
25
-
26
- for e in emotion_types:
27
- while True:
28
- x = input(f'{e}: ')
29
- if x.isnumeric() and (0 < int(x) < 11):
30
- emotion_degree.append(int(x))
31
- input_time.append(time.gmtime())
32
- break
33
- else:
34
- print('invalid input, my friend :) plz input again')
35
 
36
- # 4. if good mood
37
- if emotion_degree[0] >= 6:
38
- print('You seem to be in a good mood today. Is there anything you could notice that makes you happy?')
39
- while True:
40
- # timer = Timer(10, ValueError)
41
- # timer.start()
42
- x = input('Your answer: ')
43
- if x == '': # need to change this part to waiting 10 seconds
44
- print('Whether your good mood is over?')
45
- print('Any other details that you would like to recall?')
46
- y = input('Your answer (Yes or No): ')
47
- if y == 'No':
48
- break
49
- else:
50
- break
51
- print('I am glad that you are willing to share the experience with me. Thanks for letting me know.')
52
-
53
- # 5. bad mood
54
- else:
55
- questions = [
56
- 'What specific thing is bothering you the most right now?',
57
- 'Oh, I see. So when it is happening, what feelings or emotions have you got?',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  'And what do you think about those feelings or emotions at that time?',
59
  'Could you think of any evidence for your above-mentioned thought?',
60
- ]
61
- for q in questions:
62
- print(q)
63
- y = 'No' # bad mood
64
- while True:
65
- x = input('Your answer (example of answer here): ')
66
- if x == '': # need to change this part to waiting 10 seconds
67
- print('Whether your bad mood is over?')
68
- y = input('Your answer (Yes or No): ')
69
- if y == 'Yes':
70
- break
71
- else:
72
- break
73
- if y == 'Yes':
74
- print('Nice to hear that.')
75
- break
76
 
77
- # reading interface here
78
- print('Here are some reference articles about bad emotions. You can take a look :)')
79
- pass
80
-
81
-
82
- def load_neural_emotion_detector():
83
- model_name = 'joeddav/distilbert-base-uncased-go-emotions-student'
84
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
85
- model.eval()
86
- tokenizer = AutoTokenizer.from_pretrained(model_name)
87
- pipe = pipeline('text-classification', model=model, tokenizer=tokenizer,
88
- return_all_scores=True, truncation=True)
89
- return pipe
90
-
91
- def plot_emotion_distribution(predictions):
92
- fig, ax = plt.subplots()
93
- ax.bar(x=[i for i, _ in enumerate(prediction)],
94
- height=[p['score'] for p in prediction],
95
- tick_label=[p['label'] for p in prediction])
96
- ax.tick_params(rotation=90)
97
- ax.set_ylim(0, 1)
98
- plt.show()
99
-
100
- def rulebase(text):
101
- keywords = {
102
- 'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
103
- 'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
104
- 'manifestation': ['never stop', 'every moment', 'strong', 'very']
105
  }
106
 
107
- # if found dangerous kw/topics
108
- if re.search(rf"{'|'.join(keywords['life_safety'])}", text)!=None and \
109
- sum([re.search(rf"{'|'.join(keywords[k])}", text)!=None for k in ['immediacy','manifestation']]) >= 1:
110
- print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
111
- 'The Hong Kong Lifeline number is (852) 2382 0000')
112
- x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
113
- if x == '1':
114
- print('Let you connect to the office')
115
- else:
116
- print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
117
- 'Would you mind if we send this conversation to the cloud to finetune the model.')
118
- y = input('Yes or No: ')
119
- if y == 'Yes':
120
- pass # do smt here
121
-
122
-
123
- def euc_200(text, testing=True):
124
- # 2. using rule to judge user's emotion
125
- rulebase(text)
126
-
127
- # 3. using ML
128
- if not testing:
129
- pipe = load_neural_emotion_detector()
130
- prediction = pipe(text)[0]
131
- prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)
132
- plot_emotion_distribution(prediction)
133
-
134
- # get the most probable emotion. TODO: modify this part, may take sum of prob. over all negative emotion
135
- threshold = 0.3
136
- emotion = {'label': 'sadness', 'score': 0.4} if testing else prediction[0]
137
- # then judge
138
- if emotion['label'] in ['surprise', 'sadness', 'anger', 'fear'] and emotion['score'] > threshold:
139
- print(f'It has come to our attention that you may suffer from {emotion["label"]}')
140
- print('If you want to know more about yourself, '
141
- 'some professional scales are provided to quantify your current status. '
142
- 'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
143
- 'you can fill out these scales again to see if you have improved.')
144
- x = input('Fill in the form now (Okay or Later): ')
145
- if x == 'Okay':
146
- print('Display the form')
147
- else:
148
- print('Here are some reference articles about bad emotions. You can take a look :)')
149
-
150
- # 4. If both of the above are not satisfied. What do u mean by 'satisfied' here?
151
- questions = [
152
- 'What specific thing is bothering you the most right now?',
153
- 'Oh, I see. So when it is happening, what feelings or emotions have you got?',
154
- 'And what do you think about those feelings or emotions at that time?',
155
- 'Could you think of any evidence for your above-mentioned thought? #',
156
- ]
157
- for q in questions:
158
- print(q)
159
- y = 'No' # bad mood
160
- while True:
161
- x = input('Your answer (example of answer here): ')
162
- if x == '': # need to change this part to waiting 10 seconds
163
- print('Whether your bad mood is over?')
164
- y = input('Your answer (Yes or No): ')
165
  if y == 'Yes':
166
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  else:
168
- break
169
- if y == 'Yes':
170
- print('Nice to hear that.')
171
- break
172
 
173
- # reading interface here
174
- print('Here are some reference articles about bad emotions. You can take a look :)')
175
- pass
 
 
 
 
176
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- def _chat(message, history, model, tokenizer, args):
179
- eos = tokenizer.eos_token
180
- history = history or {
181
- 'text': args.greeting,
182
- 'input_ids': tokenizer.encode(args.greeting[-1][1] + eos, return_tensors='pt'),
183
- }
184
- # TODO: only take the latest X turns, otherwise the text becomes longer and takes more time to process
 
 
 
185
 
186
- message_ids = tokenizer.encode(message + eos, return_tensors='pt')
187
- history['input_ids'] = torch.cat([history['input_ids'], message_ids], dim=-1)
188
 
189
- bot_output_ids = model.generate(history['input_ids'],
190
- max_length=1000,
191
- do_sample=True, top_p=0.9, temperature=0.8,
192
- pad_token_id=tokenizer.eos_token_id)
193
- response = tokenizer.decode(bot_output_ids[:, history['input_ids'].shape[-1]:][0],
194
- skip_special_tokens=True)
195
- if args.run_on_own_server == 1:
196
- print((message, response), bot_output_ids[0][-10:])
197
 
198
- history['input_ids'] = bot_output_ids
199
- history['text'].append((message, response))
200
- return history['text'], history
 
201
 
 
 
 
 
202
 
203
- if __name__ == '__main__':
204
- # euc_100()
205
- # euc_200('I am happy about my academic record.')
206
- parser = argparse.ArgumentParser()
207
- parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
208
- args = parser.parse_args()
209
- args.greeting = [('','Hi you!')]
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- tokenizer = GPT2Tokenizer.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
212
- model = GPT2LMHeadModel.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
213
- model.eval()
214
- def chat(message, history):
215
- return _chat(message, history, model, tokenizer, args)
 
216
 
217
  title = 'PsyPlus Empathetic Chatbot'
218
  description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
219
- chatbot = gr.Chatbot(value=args.greeting)
220
  iface = gr.Interface(
221
- chat,
222
- ['text', 'state'],
223
- [chatbot, 'state'],
224
- # css=".gradio-container {background-color: white}",
225
- allow_flagging='never',
226
- title=title,
227
- description=description,
228
  )
 
 
229
  if args.run_on_own_server == 0:
230
  iface.launch(debug=True)
231
  else:
232
- iface.launch(debug=True, server_name='0.0.0.0', server_port=2022, share=True)
 
1
+ '''
2
+ Dialog System of PsyPlus (dvq)
3
+
4
+ reference:
5
+ https://huggingface.co/spaces/bentrevett/emotion-prediction
6
+ https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
7
+ https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues
8
+
9
+ gradio vs streamlit
10
+ https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
11
+ https://gradio.app/interface_state/
12
+
13
+ TODO
14
+ Add diagram in Gradio Interface showing sentimate analysis
15
+ Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
16
+ Personalize: create database, load and save data
17
+
18
+ Run command
19
+ python app.py --run_on_own_server 1 --initial_chat_state free_chat
20
+ '''
21
+
22
+
23
  import argparse
24
  import re, time
25
  import matplotlib.pyplot as plt
 
27
  import gradio as gr
28
 
29
  import torch
30
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+
33
+ def option():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
36
+ parser.add_argument('--dialog_model', type=str, default='tareknaous/dialogpt-empathetic-dialogues')
37
+ parser.add_argument('--emotion_model', type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')
38
+ parser.add_argument('--account', type=str, default=None)
39
+ parser.add_argument('--initial_chat_state', type=str, default='euc_100', choices=['euc_100', 'euc_200', 'free_chat'])
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ class ChatHelper: # store the list of messages that are showed in therapies
45
+ invalid_input = 'Invalid input, my friend :) Plz input again'
46
+ good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
47
+ good_case = 'Nice to hear that!'
48
+ bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
49
+ not_answer = "It's okay, maybe you don't want to answer this question."
50
+ fill_form = ('It has come to our attention that you may suffer from {}.\n'
51
+ 'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
52
+ 'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
53
+ 'you can fill out these scales again to see if you have improved.\n'
54
+ 'Do you want to fill in the form now? (Okay or Later)')
55
+ display_form = '<Display the form>.\n'
56
+ reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'
57
+
58
+ emotion_types = ['Overall', 'Happiness', 'Anxiety'] # 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',
59
+ euc_100 = {
60
+ 'q': emotion_types,
61
+ 'good_mood': [
62
+ 'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
63
+ 'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
64
+ ],
65
+ 'bad_mood': [
66
+ 'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
67
+ 'I see. So when it is happening, what feelings or emotions have you got?',
68
  'And what do you think about those feelings or emotions at that time?',
69
  'Could you think of any evidence for your above-mentioned thought?',
70
+ 'Here are some reference articles about bad emotions. You can take a look :)',
71
+ ],
72
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
75
+ 'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
76
+ euc_200 = 'Now go back to the last chat. You said that "{}".\n'
77
+
78
+ greeting_template = {
79
+ 'euc_100': 'How was your day? On the scale 1 to 10, '
80
+ 'how would you judge your emotion through the following categories:\nOverall',
81
+ # euc_200 is only trigger when you say smt more negative than a certain threshol
82
+ # thus the greeting here is only for debuging euc_200
83
+ 'euc_200': fill_form.format('anxiety'),
84
+ 'free_chat': 'Hi you! How is it going?',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  }
86
 
87
+ def plot_emotion_distribution(predictions):
88
+ fig, ax = plt.subplots()
89
+ ax.bar(x=[i for i, _ in enumerate(prediction)],
90
+ height=[p['score'] for p in prediction],
91
+ tick_label=[p['label'] for p in prediction])
92
+ ax.tick_params(rotation=90)
93
+ ax.set_ylim(0, 1)
94
+ plt.show()
95
+
96
+ def ed_rulebase(text):
97
+ keywords = {
98
+ 'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
99
+ 'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
100
+ 'manifestation': ['never stop', 'every moment', 'strong', 'very']
101
+ }
102
+
103
+ # if found dangerous kw/topics
104
+ if re.search(rf"{'|'.join(keywords['life_safety'])}", text) != None and \
105
+ sum([re.search(rf"{'|'.join(keywords[k])}", text) != None for k in ['immediacy','manifestation']]) >= 1:
106
+ print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
107
+ 'The Hong Kong Lifeline number is (852) 2382 0000')
108
+ x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
109
+ if x == '1':
110
+ print('Let you connect to the office')
111
+ else:
112
+ print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
113
+ 'Would you mind if we send this conversation to the cloud to finetune the model.')
114
+ y = input('Yes or No: ')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  if y == 'Yes':
116
+ pass # do smt here
117
+
118
+
119
+ class TherapyChatBot:
120
+ def __init__(self, args):
121
+ # check state to control the dialog
122
+ self.chat_state = args.initial_chat_state # name of the chat function/therapy segment the model is in
123
+ self.message_prev = None
124
+ self.chat_state_prev = None
125
+ self.run_on_own_server = args.run_on_own_server
126
+ self.account = args.account
127
+
128
+ # additional attribute for euc_100
129
+ self.euc_100_input_time = []
130
+ self.euc_100_emotion_degree = []
131
+ self.already_trigger_euc_200 = False
132
+
133
+ # chat and emotion-detection models
134
+ self.ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
135
+ self.ed_threshold = 0.3
136
+ self.dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
137
+ self.dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
138
+ self.eos = self.dialog_tokenizer.eos_token
139
+ # tokenizer.__call__ -> input_ids, attention_mask
140
+ # tokenizer.encode -> only inputs_ids, which is required by model.generate function
141
+
142
+ # chat history.
143
+ # TODO: if we want to personalize and save the conversation,
144
+ # we can load data from database
145
+ self.greeting = ChatHelper.greeting_template[self.chat_state]
146
+ self.history = {'input_ids': torch.tensor([[self.dialog_tokenizer.bos_token_id]]),
147
+ 'text': [('', self.greeting)]} if not self.account else open(f'database/{hash(self.account)}', 'rb')
148
+ if 'euc_100' in self.chat_state:
149
+ self.chat_state = 'euc_100.q.0'
150
+
151
+ def __call__(self, message, prefix=''):
152
+ # if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
153
+ if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
154
+ prediction = self.ed_pipe(message)[0]
155
+ prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
156
+ if self.run_on_own_server:
157
+ print(prediction)
158
+ # plot_emotion_distribution(prediction)
159
+ emotion = prediction[0]
160
+
161
+ # if message is negative, change state immediately
162
+ if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
163
+ (emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > self.ed_threshold):
164
+ self.chat_state_prev = self.chat_state
165
+ self.chat_state = 'euc_200'
166
+ self.message_prev = message
167
+ self.already_trigger_euc_200 = True
168
+ response = ChatHelper.fill_form.format(emotion['label'])
169
+
170
+ # set up rule to update state inside each dialog function
171
+ elif self.chat_state.startswith('euc_100'):
172
+ response = self.euc_100(message)
173
+ if self.chat_state == 'free_chat':
174
+ last_two_turns_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
175
+ self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
176
+
177
+ elif self.chat_state.startswith('euc_200'):
178
+ return self.euc_200(message)
179
+
180
+ else: # free_chat
181
+ response = self.free_chat(message)
182
+
183
+ if prefix:
184
+ response = prefix + response
185
+ self.history['text'].append((self.message_prev, response))
186
+ else:
187
+ self.history['text'].append((message, response))
188
+ return self.history['text']
189
+
190
+ def euc_100(self, x):
191
+ _, subsection, entry = self.chat_state.split('.')
192
+ entry = int(entry)
193
+
194
+ if subsection == 'q':
195
+ if x.isnumeric() and (0 < int(x) < 11):
196
+ self.euc_100_emotion_degree.append(int(x))
197
+ self.euc_100_input_time.append(time.gmtime())
198
+ if entry == len(ChatHelper.euc_100['q']) - 1:
199
+ if self.run_on_own_server:
200
+ print(self.euc_100_emotion_degree)
201
+ mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
202
+ self.chat_state = f'euc_100.{mood}.0'
203
+ response = ChatHelper.euc_100[mood][0]
204
+ else:
205
+ self.chat_state = f'euc_100.q.{entry+1}'
206
+ response = ChatHelper.euc_100['q'][entry+1]
207
  else:
208
+ response = ChatHelper.invalid_input
 
 
 
209
 
210
+ elif subsection == 'good_mood':
211
+ if x == '':
212
+ response = ChatHelper.good_mood_over
213
+ else:
214
+ response = ChatHelper.good_case
215
+ response += '\n' + ChatHelper.euc_100['good_mood'][1]
216
+ self.chat_state = 'free_chat'
217
 
218
+ elif subsection == 'bad_mood':
219
+ if entry == -1:
220
+ if 'yes' in x.lower() or 'better' in x.lower():
221
+ response = ChatHelper.good_case
222
+ else:
223
+ entry = int(self.chat_state_prev.rsplit('.', 1))
224
+ response = ChatHelper.not_answer + '\n' + ChatHelper.euc_100['bad_mood'][entry+1]
225
+ if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
226
+ self.chat_state = 'free_chat'
227
+ else:
228
+ self.chat_state = f'euc_100.bad_mood.{entry+1}'
229
 
230
+ if x == '':
231
+ response = ChatHelper.bad_mood_over
232
+ self.chat_state_prev = self.chat_state
233
+ self.chat_state = 'euc_100.bad_mood.-1'
234
+ else:
235
+ response = ChatHelper.euc_100['bad_mood'][entry+1]
236
+ if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
237
+ self.chat_state = 'free_chat'
238
+ else:
239
+ self.chat_state = f'euc_100.bad_mood.{entry+1}'
240
 
241
+ return response
 
242
 
243
+ def euc_200(self, x):
244
+ # don't ask question in euc_200, because they're similar to question in euc_100
245
+ if x.lower() == 'okay':
246
+ response = ChatHelper.display_form
247
+ else:
248
+ response = ChatHelper.reference
249
+ response += ChatHelper.euc_200.format(self.message_prev)
 
250
 
251
+ message = self.message_prev
252
+ self.message_prev = x
253
+ self.chat_state = self.chat_state_prev
254
+ return self.__call__(message, response)
255
 
256
+ def free_chat(self, message):
257
+ message_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
258
+ self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
259
+ input_ids = self.history['input_ids'].clone()
260
 
261
+ while True:
262
+ bot_output_ids = self.dialog_model.generate(input_ids, max_length=1000,
263
+ do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
264
+ pad_token_id=self.dialog_tokenizer.eos_token_id)
265
+ response = self.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
266
+ skip_special_tokens=True)
267
+ if response.strip() != '':
268
+ break
269
+ elif input_ids[0].tolist().count(self.dialog_tokenizer.eos_token_id) > 0:
270
+ idx = input_ids[0].tolist().index(self.dialog_tokenizer.eos_token_id)
271
+ input_ids = input_ids[:, (idx+1):]
272
+ else:
273
+ input_ids = message_ids
274
+
275
+ if self.run_on_own_server:
276
+ print(input_ids)
277
+
278
+ self.history['input_ids'] = torch.cat([self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]], dim=-1)
279
+ if self.run_on_own_server == 1:
280
+ print((message, response), '\n', self.history['input_ids'])
281
 
282
+ return response
283
+
284
+
285
+ if __name__ == '__main__':
286
+ args = option()
287
+ chat = TherapyChatBot(args)
288
 
289
  title = 'PsyPlus Empathetic Chatbot'
290
  description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
291
+ chatbot = gr.Chatbot(value=chat.history['text'])
292
  iface = gr.Interface(
293
+ chat, 'text', chatbot,
294
+ allow_flagging='never', title=title, description=description,
 
 
 
 
 
295
  )
296
+
297
+ # iface.queue(concurrency_count=5)
298
  if args.run_on_own_server == 0:
299
  iface.launch(debug=True)
300
  else:
301
+ iface.launch(debug=True, share=True) # server_name='0.0.0.0', server_port=2022