Spaces:
Build error
Build error
''' | |
Dialog System of PsyPlus (dvq) | |
reference: | |
https://huggingface.co/spaces/bentrevett/emotion-prediction | |
https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT | |
https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues | |
gradio vs streamlit | |
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/ | |
https://gradio.app/interface_state/ -> global and local variables affect the separation of sessions | |
TODO | |
Add command to reset/jump to a function, e.g >reset, >euc_100 | |
Add a diagram in the Gradio interface showing sentiment analysis | |
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement | |
Personalize: create database, load and save data | |
Run command | |
python app.py --run_on_own_server 1 --initial_chat_state free_chat | |
''' | |
import argparse | |
import re, time | |
import matplotlib.pyplot as plt | |
from threading import Timer | |
import gradio as gr | |
import torch | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline | |
def option(argv=None):
    """Parse the command-line options of the chatbot app.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the original zero-argument call did. Passing an explicit
            list lets other code build the options without touching the
            process arguments.

    Returns:
        argparse.Namespace with ``run_on_own_server``, ``dialog_model``,
        ``emotion_model``, ``account`` and ``initial_chat_state``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_on_own_server', type=int, default=0,
                        help='if test on own server, need to use share mode')
    parser.add_argument('--dialog_model', type=str,
                        default='tareknaous/dialogpt-empathetic-dialogues')
    parser.add_argument('--emotion_model', type=str,
                        default='joeddav/distilbert-base-uncased-go-emotions-student')
    parser.add_argument('--account', type=str, default=None)
    parser.add_argument('--initial_chat_state', type=str, default='euc_100',
                        choices=['euc_100', 'euc_200', 'free_chat'])
    return parser.parse_args(argv)
# Parsed once at import time: the ChatHelper class body reads args.emotion_model
# and args.dialog_model while it is being defined, and the __main__ section
# reads args.initial_chat_state and args.run_on_own_server.
args = option()
# The models and the list of messages shown in the therapies are stored once,
# as class-level (effectively global) attributes of ChatHelper.
# All chat-session-wise variables are placed in TherapyChatBot instead.
class ChatHelper:
    """Namespace of shared models and message templates.

    Everything here is a class attribute evaluated when the class body runs,
    so both models are loaded as soon as this module is imported. The class is
    only used through ``ChatHelper.<attribute>``; it is never instantiated in
    this file.
    """
    # chat and emotion-detection models
    ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
    # The top emotion must score above this value before a negative emotion
    # switches the conversation to the euc_200 state (see TherapyChatBot.__call__).
    ed_threshold = 0.3
    dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
    dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
    # End-of-sequence token text; appended to each user message before encoding.
    eos = dialog_tokenizer.eos_token
    # tokenizer.__call__ -> input_ids, attention_mask
    # tokenizer.encode -> only input_ids, which is what model.generate requires

    # Fixed replies used by the scripted (rule-based) parts of the dialog.
    invalid_input = 'Invalid input, my friend :) Plz input again'
    good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
    good_case = 'Nice to hear that!'
    bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
    not_answer = "It's okay, maybe you don't want to answer this question."
    # Template; the `{}` placeholder is filled with the detected emotion label.
    fill_form = ('It has come to our attention that you may suffer from {}.\n'
                 'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
                 'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
                 'you can fill out these scales again to see if you have improved.\n'
                 'Do you want to fill in the form now? (Okay or Later)')
    display_form = '<Display the form>.\n'
    reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'
    # Categories the user rates from 1 to 10 at the start of euc_100.
    emotion_types = ['Overall', 'Happiness', 'Anxiety']  # not used yet: 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',
    # Script of the euc_100 segment, keyed by subsection name. The chat state
    # string 'euc_100.<subsection>.<entry>' indexes into these lists.
    euc_100 = {
        'q': emotion_types,
        'good_mood': [
            'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
            'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
        ],
        'bad_mood': [
            'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
            'I see. So when it is happening, what feelings or emotions have you got?',
            'And what do you think about those feelings or emotions at that time?',
            'Could you think of any evidence for your above-mentioned thought?',
            'Here are some reference articles about bad emotions. You can take a look :)',
        ],
    }
    # Emotion labels that count as negative when deciding whether to enter euc_200.
    negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
                         'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
    # Template; the `{}` placeholder is filled with the user message that triggered euc_200.
    euc_200 = 'Now go back to the last chat. You said that "{}".\n'
    # First bot message of a session, keyed by the initial chat state.
    greeting_template = {
        'euc_100': 'How was your day? On the scale 1 to 10, '
                   'how would you judge your emotion through the following categories:\nOverall',
        # euc_200 is only triggered when the user says something more negative than a certain threshold,
        # thus the greeting here is only for debugging euc_200
        'euc_200': fill_form.format('anxiety'),
        'free_chat': 'Hi you! How is it going?',
    }
def plot_emotion_distribution(predictions):
    """Show a bar chart of the scores of the predicted emotions.

    Args:
        predictions: Iterable of dicts, each with a ``'label'`` (tick text)
            and a ``'score'`` (bar height) entry.

    The original body read an undefined name ``prediction`` instead of the
    ``predictions`` parameter, so every call raised NameError.
    """
    # Materialize once so a one-shot iterable can be read for both series.
    predictions = list(predictions)
    labels = [p['label'] for p in predictions]
    scores = [p['score'] for p in predictions]

    _, ax = plt.subplots()
    ax.bar(x=list(range(len(predictions))), height=scores, tick_label=labels)
    ax.tick_params(rotation=90)
    # Fixed y-range so charts from different messages are visually comparable.
    ax.set_ylim(0, 1)
    plt.show()
def ed_rulebase(text):
    """Rule-based check for messages that suggest an immediate safety risk.

    The check fires when ``text`` mentions a life-safety keyword together with
    at least one immediacy or manifestation keyword. When it fires, the
    function prints a hotline notice and asks follow-up questions on the
    console via ``input``; otherwise it does nothing.

    Args:
        text: The user message to scan.

    Returns:
        None in every case.
    """
    keywords = {
        'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
        'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
        'manifestation': ['never stop', 'every moment', 'strong', 'very'],
    }

    def mentions(group):
        # Keywords are escaped so they are always matched literally, and the
        # search ignores case so that e.g. "Suicide" at the start of a
        # sentence is not missed. Matching is still by substring, which keeps
        # the check on the cautious (high-recall) side.
        pattern = '|'.join(re.escape(keyword) for keyword in keywords[group])
        return re.search(pattern, text, flags=re.IGNORECASE) is not None

    # if found dangerous kw/topics
    dangerous = mentions('life_safety') and (mentions('immediacy') or mentions('manifestation'))
    if not dangerous:
        return

    print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
          'The Hong Kong Lifeline number is (852) 2382 0000')
    x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
    if x.strip() == '1':
        print('Let you connect to the office')
    else:
        print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
              'Would you mind if we send this conversation to the cloud to finetune the model.')
        y = input('Yes or No: ')
        if y.strip().lower() == 'yes':
            pass  # TODO: upload the conversation; not implemented yet
class TherapyChatBot:
    """Per-session dialog state machine that routes each user message.

    ``self.chat_state`` names the segment the conversation is in:

    * ``'euc_100.<subsection>.<entry>'`` - scripted check-in; ``subsection``
      is ``q`` (1-10 ratings), ``good_mood`` or ``bad_mood`` and ``entry``
      is the index of the prompt that was asked last (``-1`` in ``bad_mood``
      means "waiting for the answer to ChatHelper.bad_mood_over").
    * ``'euc_200'`` - the user was offered the assessment form after a
      negative emotion was detected.
    * ``'free_chat'`` - replies are generated by the dialog model.

    Calling the instance with a message appends ``(user_text, bot_text)`` to
    ``self.history['text']``.
    """

    def __init__(self, args):
        """Create a fresh session.

        Args:
            args: Namespace providing ``initial_chat_state``,
                ``run_on_own_server`` and ``account``.
        """
        # check state to control the dialog
        self.chat_state = args.initial_chat_state  # name of the chat function/therapy segment the model is in
        # message_prev / chat_state_prev are owned by the euc_200 detour: they
        # remember what to replay and where to return to afterwards.
        self.message_prev = None
        self.chat_state_prev = None
        self.run_on_own_server = args.run_on_own_server
        self.account = args.account

        # additional attributes for euc_100
        self.euc_100_input_time = []
        self.euc_100_emotion_degree = []
        # Index of the bad_mood prompt that was left unanswered when the bot
        # asked ChatHelper.bad_mood_over. Kept separately from chat_state_prev
        # so an euc_200 detour in between cannot overwrite it.
        self.euc_100_resume_entry = 0
        self.already_trigger_euc_200 = False

        # chat history.
        # TODO: if we want to personalize and save the conversation,
        # we can load data from database
        self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
        if not self.account:
            self.history = {
                'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
                'text': self.greeting,
            }
        else:
            # NOTE(review): this stores an open (never closed) file object, not
            # the {'input_ids', 'text'} dict the other methods index into, and
            # str hashes differ between interpreter runs unless PYTHONHASHSEED
            # is fixed. Left unchanged because the storage format is not
            # defined anywhere in this file - confirm before using --account.
            self.history = open(f'database/{hash(self.account)}', 'rb')

        if 'euc_100' in self.chat_state:
            self.chat_state = 'euc_100.q.0'

    def __call__(self, message, prefix=''):
        """Handle one user message and record the exchange in the history.

        Args:
            message: Text typed by the user (or, when ``prefix`` is set, the
                earlier message that euc_200 is replaying).
            prefix: Text to put in front of the response. A non-empty prefix
                means the call comes from ``euc_200``, i.e. the negative
                emotion was already detected, so detection is skipped.

        Returns:
            None. When the state is ``euc_200`` the (None) result of
            ``euc_200`` is returned.
        """
        # Emotion detection runs only for fresh messages and only until the
        # euc_200 detour has been triggered once.
        detect_emotion = (not prefix
                          and self.chat_state != 'euc_200'
                          and not self.already_trigger_euc_200)
        emotion = None
        if detect_emotion:
            prediction = ChatHelper.ed_pipe(message)[0]
            prediction = sorted(prediction, key=lambda p: p['score'], reverse=True)
            if self.run_on_own_server:
                print(prediction)
            # plot_emotion_distribution(prediction)
            emotion = prediction[0]

        is_negative = (emotion is not None
                       and emotion['label'] in ChatHelper.negative_emotions
                       and emotion['score'] > ChatHelper.ed_threshold)

        if is_negative:
            # if message is negative, change state immediately
            self.chat_state_prev = self.chat_state
            self.chat_state = 'euc_200'
            self.message_prev = message
            self.already_trigger_euc_200 = True
            response = ChatHelper.fill_form.format(emotion['label'])
        # set up rule to update state inside each dialog function
        elif self.chat_state.startswith('euc_100'):
            response = self.euc_100(message)
            if self.chat_state == 'free_chat':
                # euc_100 just handed over to free chat: add the user's last
                # message to the token history the dialog model will see.
                message_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
                self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        elif self.chat_state.startswith('euc_200'):
            # euc_200 re-enters __call__ itself, which records the history.
            return self.euc_200(message)
        else:  # free_chat
            response = self.free_chat(message)

        if prefix:
            # During a replay self.message_prev holds what the user typed last
            # (their answer to the form offer), so that is what gets displayed.
            response = prefix + response
            self.history['text'].append((self.message_prev, response))
        else:
            self.history['text'].append((message, response))

    def euc_100(self, x):
        """Advance the scripted euc_100 segment by one user answer.

        Args:
            x: The user's answer to the prompt identified by ``self.chat_state``.

        Returns:
            The next bot message. ``self.chat_state`` is updated as a side effect.

        Raises:
            ValueError: If the state names an unknown subsection.
        """
        _, subsection, entry = self.chat_state.split('.')
        entry = int(entry)

        if subsection == 'q':
            answer = x.strip()
            # isdecimal (not isnumeric) so that characters such as '²', which
            # int() cannot parse, are rejected instead of raising ValueError.
            if answer.isdecimal() and 0 < int(answer) < 11:
                self.euc_100_emotion_degree.append(int(answer))
                self.euc_100_input_time.append(time.gmtime())
                if entry == len(ChatHelper.euc_100['q']) - 1:
                    if self.run_on_own_server:
                        print(self.euc_100_emotion_degree)
                    # The first rating ('Overall') decides which script follows.
                    mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
                    self.chat_state = f'euc_100.{mood}.0'
                    response = ChatHelper.euc_100[mood][0]
                else:
                    self.chat_state = f'euc_100.q.{entry + 1}'
                    response = ChatHelper.euc_100['q'][entry + 1]
            else:
                response = ChatHelper.invalid_input

        elif subsection == 'good_mood':
            # NOTE(review): indentation was lost in the source; the closing
            # line and the hand-over to free chat are applied to both branches.
            if x == '':
                response = ChatHelper.good_mood_over
            else:
                response = ChatHelper.good_case
            response += '\n' + ChatHelper.euc_100['good_mood'][1]
            self.chat_state = 'free_chat'

        elif subsection == 'bad_mood':
            if entry == -1:
                # The user is answering ChatHelper.bad_mood_over.
                answer = x.lower()
                if 'yes' in answer or 'better' in answer:
                    response = ChatHelper.good_case
                    # NOTE(review): the bad mood is over, so the script ends
                    # here, mirroring how the good_mood branch ends.
                    self.chat_state = 'free_chat'
                else:
                    # Skip the unanswered prompt and continue with the next one.
                    response = (ChatHelper.not_answer + '\n'
                                + self._next_bad_mood_prompt(self.euc_100_resume_entry))
            # `elif`, so the reply chosen above is not overwritten by the
            # regular question handling below.
            elif x == '':
                response = ChatHelper.bad_mood_over
                self.euc_100_resume_entry = entry
                self.chat_state = 'euc_100.bad_mood.-1'
            else:
                response = self._next_bad_mood_prompt(entry)

        else:
            raise ValueError(f'unknown euc_100 subsection: {subsection!r}')

        return response

    def _next_bad_mood_prompt(self, entry):
        """Return bad_mood prompt ``entry + 1`` and move the chat state onto it."""
        prompts = ChatHelper.euc_100['bad_mood']
        if entry == len(prompts) - 2:
            # The final prompt is returned without waiting for an answer.
            self.chat_state = 'free_chat'
        else:
            self.chat_state = f'euc_100.bad_mood.{entry + 1}'
        return prompts[entry + 1]

    def euc_200(self, x):
        """Answer the form offer, then replay the message that triggered it.

        Args:
            x: The user's answer to ChatHelper.fill_form ("Okay" shows the
                form, anything else shows the references).

        Returns:
            The result of re-running ``__call__`` on the triggering message
            in the state that was active before the detour.
        """
        # don't ask question in euc_200, because they're similar to question in euc_100
        if x.strip().lower() == 'okay':
            response = ChatHelper.display_form
        else:
            response = ChatHelper.reference
        response += ChatHelper.euc_200.format(self.message_prev)

        # Swap: replay the triggering message, keep the answer for display.
        message = self.message_prev
        self.message_prev = x
        self.chat_state = self.chat_state_prev
        return self.__call__(message, prefix=response)

    def free_chat(self, message):
        """Generate a reply with the dialog model and extend the token history.

        Args:
            message: The user's message.

        Returns:
            The decoded, non-empty reply text.
        """
        tokenizer = ChatHelper.dialog_tokenizer
        eos_id = tokenizer.eos_token_id

        message_ids = tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
        self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        input_ids = self.history['input_ids'].clone()

        # Sampling can yield an empty reply; retry with a shorter context
        # until some text comes back.
        while True:
            bot_output_ids = ChatHelper.dialog_model.generate(
                input_ids, max_length=1000,
                do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
                pad_token_id=eos_id)
            # Everything after the prompt tokens is the newly generated reply.
            response = tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
                                        skip_special_tokens=True)
            if response.strip() != '':
                break
            input_ids = self._drop_oldest_turn(input_ids, message_ids, eos_id)
            if self.run_on_own_server:
                print(input_ids)

        self.history['input_ids'] = torch.cat(
            [self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]], dim=-1)
        if self.run_on_own_server == 1:
            print((message, response), '\n', self.history['input_ids'])
        return response

    @staticmethod
    def _drop_oldest_turn(input_ids, message_ids, eos_id):
        """Cut the oldest turn off the context, never the current message.

        The original loop cut at the first EOS even when that EOS was the one
        ending the current message, leaving an empty prompt for the next
        generation attempt.
        """
        context_len = input_ids.shape[-1] - message_ids.shape[-1]
        if context_len <= 0:
            # Only the current message is left; retry with it unchanged.
            return message_ids
        older_tokens = input_ids[0, :context_len].tolist()
        if eos_id in older_tokens:
            return input_ids[:, older_tokens.index(eos_id) + 1:]
        return message_ids
if __name__ == '__main__':
    def chat(message, bot):
        """Gradio callback: feed one message to this session's bot.

        ``bot`` is the per-session state object; it is empty on the first
        call, in which case a new TherapyChatBot is created. Returns the
        full chat transcript and the bot to keep as session state.
        """
        if not bot:
            bot = TherapyChatBot(args)
        bot(message)
        return bot.history['text'], bot

    title = 'PsyPlus Empathetic Chatbot'
    description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'

    # Show the greeting of the selected initial state before the first message.
    greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
    chatbot = gr.Chatbot(value=greeting)

    iface = gr.Interface(
        fn=chat,
        inputs=['text', 'state'],
        outputs=[chatbot, 'state'],
        allow_flagging='never',
        title=title,
        description=description,
    )

    # A share link is requested only when running on an own server.
    launch_options = {'debug': True}
    if args.run_on_own_server != 0:
        launch_options['share'] = True
    iface.launch(**launch_options)