# Source captured from a Hugging Face Space page (Space status at capture time: Sleeping).
import spaces | |
import os | |
import jieba | |
import pandas as pd | |
import gradio as gr | |
import torch | |
import numpy as np | |
import cv2 | |
import base64 | |
import time | |
from PIL import Image | |
from openai import OpenAI | |
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, EulerAncestralDiscreteScheduler | |
from peft import PeftModel | |
from transformers import AutoModel | |
import asyncio | |
# OpenAI API client; the key is taken from the OPENAI_API_KEY environment variable.
api_key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)
# Text-to-image setup: SDXL base pipeline in fp16, moved to the GPU, with an
# Euler-ancestral scheduler built from the pipeline's own scheduler config.
access_token = os.getenv('HF_TOKEN')
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
    token=access_token,
).to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Load the two LoRA adapters (one local file, one from the Hub) and activate
# both with weights 1.0 and 0.5 respectively.
lora_path = "./tbh368-sdxl.safetensors"
pipe.load_lora_weights(lora_path, adapter_name="milton-glaser")
pipe.load_lora_weights(
    "e-n-v-y/envyimpressionismxl01",
    weight_name="EnvyImpressionismXL01.safetensors",
    adapter_name="impressionism",
)
pipe.set_adapters(
    ["milton-glaser", "impressionism"],
    adapter_weights=[1.0, 0.5],
)
# Load the CVAW corpus (tab-separated file) and index it as
# word -> (Valence_Mean, Arousal_Mean).
cvaw_data = pd.read_csv('./CVAW_all_SD.csv', delimiter='\t')
cvaw_dict = {
    word: (valence_mean, arousal_mean)
    for word, valence_mean, arousal_mean in zip(
        cvaw_data['Word'],
        cvaw_data['Valence_Mean'],
        cvaw_data['Arousal_Mean'],
    )
}
def analyze_sentiment_corpus(text, conversation_times, valence_scores, arousal_scores):
    """Score one user message against the CVAW lexicon and update the running state.

    The message is segmented with jieba; every segment found in ``cvaw_dict``
    contributes its (valence, arousal) pair to the running score lists, which
    are then trimmed to their three most recent entries (word scores, not
    sentences).

    Args:
        text: The user's message.
        conversation_times: Number of messages analysed so far; incremented here.
        valence_scores: Running list of valence values. Not modified in place.
        arousal_scores: Running list of arousal values. Not modified in place.

    Returns:
        A tuple ``(valence, arousal, conversation_times, valence_scores,
        arousal_scores)``. ``valence`` and ``arousal`` are the means of the
        trimmed lists, or the sentinel ``10`` (meaning "do not enter relaxation
        mode") while fewer than six messages have been analysed or while no
        lexicon word has been seen yet.
    """
    conversation_times += 1

    # Work on copies so the caller's lists are only replaced via the return value.
    valence_scores = list(valence_scores)
    arousal_scores = list(arousal_scores)

    for word in jieba.cut(text):
        scores = cvaw_dict.get(word)
        if scores is None:
            continue
        valence, arousal = scores
        valence_scores.append(valence)
        arousal_scores.append(arousal)

    # Keep only the three most recent scores.
    valence_scores = valence_scores[-3:]
    arousal_scores = arousal_scores[-3:]

    # Below six messages the sentinel 10 keeps relaxation mode off. The same
    # sentinel is used when nothing has been scored, because the mean of an
    # empty list would be NaN and emit a RuntimeWarning.
    if conversation_times < 6 or not valence_scores:
        return 10, 10, conversation_times, valence_scores, arousal_scores

    avg_valence = np.mean(valence_scores)
    avg_arousal = np.mean(arousal_scores)
    return avg_valence, avg_arousal, conversation_times, valence_scores, arousal_scores
def call_gpt(input_text, history):
    """Ask the chat model for a reply to ``input_text`` given the prior turns.

    Args:
        input_text: The new user message.
        history: Sequence of earlier turns; element 0 of each turn is sent as
            the user message and element 1 as the assistant message.

    Returns:
        The content of the first choice returned by the chat completion.
    """
    system_prompt = "對話請以繁體中文進行:你是一位熟悉現象學的諮商實習生,擅長引導使用者描述他當下的所知覺到的事物。回答問題的時候必須有同理心,請同理使用者說的內容,再繼續回答,且不要超過20個字。"
    messages = [{"role": "system", "content": system_prompt}]

    # Replay the earlier turns in order, then add the new user message.
    for turn in history:
        messages.extend((
            {"role": "user", "content": turn[0]},
            {"role": "assistant", "content": turn[1]},
        ))
    messages.append({"role": "user", "content": input_text})

    completion = client.chat.completions.create(
        model="chatgpt-4o-latest",
        messages=messages,
        temperature=0.8,
    )
    return completion.choices[0].message.content
def chat_with_bot(input_text, history, conversation_times, valence_scores, arousal_scores, meditation_flag):
    """Produce the bot's next turn and decide which relaxation controls to show.

    Runs the sentiment analysis on ``input_text``. If the resulting arousal
    lies in [4.7, 5.4] and ``meditation_flag`` is True, a fixed message
    proposing a breathing exercise is used as the reply; otherwise the reply
    comes from ``call_gpt``. The (input, reply) pair is appended to ``history``.

    Returns:
        ``(history, med_confirm_layout, jump2med_btn, conversation_times,
        valence_scores, arousal_scores, meditation_flag)`` where
        ``med_confirm_layout`` tells whether the relaxation yes/no choice
        should be shown and ``jump2med_btn`` whether the jump-to-relaxation
        button stays available (False once ``meditation_flag`` is False).
    """
    # Sentiment analysis; only the arousal value drives the decision below.
    _valence, arousal, conversation_times, valence_scores, arousal_scores = analyze_sentiment_corpus(
        input_text, conversation_times, valence_scores, arousal_scores
    )

    if 4.7 <= arousal <= 5.4 and meditation_flag is True:
        # Propose the relaxation exercise after a short pause.
        time.sleep(1.5)
        response = "我知道你的狀況了\n我有一個建議,我們來進行一個可以讓自己放鬆的呼吸練習好嗎?"
        med_confirm_layout = True
        jump2med_btn = True
    else:
        # Ordinary conversation; the jump button is hidden once relaxation
        # has already happened (flag is False).
        response = call_gpt(input_text, history)
        med_confirm_layout = False
        jump2med_btn = meditation_flag is not False

    history.append((input_text, response))
    return (
        history,
        med_confirm_layout,
        jump2med_btn,
        conversation_times,
        valence_scores,
        arousal_scores,
        meditation_flag,
    )
def translate_to_english(text):
    """Turn ``text`` into a short English text-to-image prompt via the chat model.

    Args:
        text: The text to base the prompt on.

    Returns:
        The content of the first choice of the completion (capped at 60
        tokens by the request).
    """
    instructions = "You are a professional text-to-image prompt generator, please use the following text to generate prompt in English. Make sure it has only 60 tokens. Details are not necessary."
    completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": instructions},
            {"role": "user", "content": text},
        ],
        model="chatgpt-4o-latest",
        temperature=0.3,
        max_tokens=60,
    )
    return completion.choices[0].message.content
def generate_images(history, conversation_times, last_genimg_times, generated_images):
    """Generate four images from the user's side of the conversation.

    If ``generated_images`` is set and ``last_genimg_times`` equals
    ``conversation_times``, the stored images are returned unchanged.
    Otherwise the user messages in ``history`` are joined, converted to an
    English prompt, and rendered four times with independent random seeds.

    Returns:
        ``(conversation_times, last_genimg_times, img1, img2, img3, img4)``
        with ``last_genimg_times`` set to ``conversation_times`` after a
        fresh generation.
    """
    # NOTE(review): this function never returns the new image list for the
    # `generated_images` input, so whether the shortcut below can ever trigger
    # depends on how the caller wires that state - confirm.
    cache_is_current = generated_images is not None and last_genimg_times == conversation_times
    if cache_is_current:
        return conversation_times, last_genimg_times, *generated_images

    user_story = " ".join(turn[0] for turn in history)
    prompt = translate_to_english(user_story)
    neg_prompt = "dark, realistic, words, sentence, text, extra, nude, duplicate, ugly"
    full_prompt = "style of Milton Glaser, modern digital impressionism, hard to understood, abstract, " + prompt

    # One seeded generator per image so the four results differ.
    seeds = np.random.randint(0, 100000, 4)
    generators = [torch.Generator().manual_seed(int(seed)) for seed in seeds]

    images = [
        pipe(
            full_prompt,
            negative_prompt=neg_prompt,
            height=720,
            width=512,
            generator=image_generator,
            num_inference_steps=40,
            guidance_scale=10,
        ).images[0]
        for image_generator in generators
    ]
    return conversation_times, conversation_times, *images
def select_image(choice, img1, img2, img3, img4):
    """Return the image matching a radio label such as ``"圖像 2"``.

    The image number is the last whitespace-separated token of ``choice``
    and is 1-based.

    Args:
        choice: The selected label, or ``None``/empty when nothing is selected.
        img1, img2, img3, img4: The four candidate images.

    Returns:
        The selected image, or ``None`` when ``choice`` is ``None`` or empty.

    Raises:
        ValueError: If the last token of ``choice`` is not an integer.
        IndexError: If the number is outside 1-4.
    """
    # A cleared selection carries no label; there is nothing to show.
    if not choice:
        return None

    images = [img1, img2, img3, img4]
    number = int(choice.split()[-1])
    # Reject 0 and negatives explicitly: plain list indexing would wrap
    # around and silently return an image from the end of the list.
    if not 1 <= number <= len(images):
        raise IndexError(f"image number {number} is outside 1-{len(images)}")
    return images[number - 1]
def chat_about_image(input_text, history, selected_image):
    """Discuss the selected image with the chat model and extend the history.

    The image is PNG-encoded, base64-embedded in a data URL and sent together
    with the user's association text. The earlier turns in ``history`` are
    sent as well, because the system prompt asks the model to relate the
    image to the previous conversation.

    Args:
        input_text: What the image made the user think of.
        history: Earlier turns (element 0 = user, element 1 = assistant); the
            new (input, reply) pair is appended to it.
        selected_image: The chosen image as a NumPy array, as passed to
            ``cv2.imencode``. Presumably RGB/RGBA from the Gradio image
            component - confirm against the component's ``type`` setting.

    Returns:
        ``(history, history, update)`` where ``update`` makes the chatbot
        interface visible.

    Raises:
        gr.Error: If no image has been selected or PNG encoding fails.
    """
    if selected_image is None:
        raise gr.Error("請先選擇一張圖像,再與聊天機器人分享。")
    if history is None:
        history = []

    # OpenCV encodes assuming BGR(A) channel order; convert from RGB(A) so the
    # PNG sent to the model has the colours the user actually saw.
    if selected_image.ndim == 3 and selected_image.shape[2] == 3:
        encodable = cv2.cvtColor(selected_image, cv2.COLOR_RGB2BGR)
    elif selected_image.ndim == 3 and selected_image.shape[2] == 4:
        encodable = cv2.cvtColor(selected_image, cv2.COLOR_RGBA2BGRA)
    else:
        encodable = selected_image  # single-channel: no reordering needed

    encoded_ok, buffer = cv2.imencode('.png', encodable)
    if not encoded_ok:
        raise gr.Error("圖像編碼失敗,請重新選擇圖像。")
    img_str = base64.b64encode(buffer).decode()

    messages = [
        {"role": "system", "content": "對話請以繁體中文進行:你是一位熟悉現象學的諮商實習生,請根據使用者對他所選出的圖像描述進行引導,指出這張圖像與先前對話的關聯,幫助使用者探索他們的分享與該圖像間的連結,並繼續對話"},
    ]
    # Give the model the earlier conversation it is asked to connect to.
    for turn in history:
        messages.append({"role": "user", "content": turn[0]})
        messages.append({"role": "assistant", "content": turn[1]})
    messages.append(
        {"role": "user", "content": [
            {"type": "text", "text": f"看到這張圖像,讓我想到 {input_text}"},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
        ]}
    )

    chat_reply = client.chat.completions.create(
        model="chatgpt-4o-latest",
        messages=messages,
        max_tokens=300,
    )
    reply = chat_reply.choices[0].message.content
    history.append((input_text, reply))
    return (
        history,
        history,
        gr.update(visible=True),  # chatbot_interface
    )
# Audio track handed to the audio player when the relaxation guidance starts.
audio_file = "meditation_v2.m4a"
# UI handler functions
def handle_chat(input_text, history, conversation_times, valence_scores, arousal_scores, meditation_flag):
    """Run one chat turn and translate the outcome into UI updates.

    Returns, in order: the updated history, an update for the message box,
    visibility updates for the submit button, the jump-to-relaxation button
    and the relaxation yes/no buttons, followed by ``conversation_times``,
    ``valence_scores``, ``arousal_scores`` and ``meditation_flag``.
    """
    (updated_history, meditation, jump2med_btn, conversation_times,
     valence_scores, arousal_scores, meditation_flag) = chat_with_bot(
        input_text, history, conversation_times, valence_scores, arousal_scores, meditation_flag
    )

    if meditation:
        # Relaxation was proposed: lock the text box and show only the yes/no buttons.
        msg_update = gr.update(value="", placeholder="現在開始引導放鬆吧", interactive=False)
        show_submit, show_jump, show_choice = False, False, True
    else:
        # Normal turn: keep chatting; the jump button is hidden only when it
        # was explicitly disabled.
        msg_update = gr.update(value="", placeholder="今天想跟我分享什麼呢?", interactive=True)
        show_submit, show_choice = True, False
        show_jump = jump2med_btn is not False

    return (
        updated_history,
        msg_update,                      # msg
        gr.update(visible=show_submit),  # submit
        gr.update(visible=show_jump),    # jump_to_med
        gr.update(visible=show_choice),  # meditation_buttons
        conversation_times, valence_scores, arousal_scores, meditation_flag,
    )
def start_meditation(meditation_flag):
    """Return the guidance audio file and a cleared meditation flag.

    The incoming ``meditation_flag`` value is ignored; the returned flag is
    always ``False``.
    """
    return audio_file, False
def continue_chat():
    """Build the UI updates that resume normal chatting.

    Returns updates for, in order: the message box (cleared, default
    placeholder, editable), the submit button (shown), the
    jump-to-relaxation button (shown) and the relaxation yes/no buttons
    (hidden).
    """
    msg_update = gr.update(value="", placeholder="今天想跟我分享什麼呢?", interactive=True)
    submit_update = gr.update(visible=True)
    jump_update = gr.update(visible=True)
    choice_update = gr.update(visible=False)
    return msg_update, submit_update, jump_update, choice_update
def return_to_chat():
    """Build the UI updates that leave the audio view and go back to the chat.

    Returns ten values, in order: ``None`` for the audio player (the original
    comment notes that setting it to None stops the audio), then updates
    for main_interface, chatbot_interface, selected_image_interface,
    audio_interface, msg, submit, meditation_buttons, gen_other_img and
    jump_to_med.
    """
    stop_audio = None
    main_view = gr.update(visible=True)
    chat_view = gr.update(visible=True)
    selected_image_view = gr.update(visible=True)
    audio_view = gr.update(visible=False)
    msg_update = gr.update(interactive=True, placeholder="今天想跟我分享什麼呢?")
    submit_update = gr.update(visible=True)
    choice_update = gr.update(visible=False)
    more_images_update = gr.update(visible=True)
    jump_update = gr.update(visible=False)
    return (
        stop_audio,
        main_view,
        chat_view,
        selected_image_view,
        audio_view,
        msg_update,
        submit_update,
        choice_update,
        more_images_update,
        jump_update,
    )
async def show_loading():
    """Show a loading notice, wait 12 seconds, then reveal the audio view.

    Async generator yielding two 4-tuples of
    ``(loading_message visibility, loading_message text, main_interface
    visibility, audio_interface visibility)``: first with the notice shown
    and both interfaces hidden, then with the notice cleared and hidden and
    the audio interface shown.
    """
    notice = "載入時間約需十秒,建議戴上耳機體驗。\n在接下來的畫面,你會看到播放介面,按下播放鈕後就會開始播放放鬆指導語"

    # Phase 1: display the notice and hide everything else.
    yield (
        gr.update(visible=True),   # loading_message
        notice,
        gr.update(visible=False),  # main_interface
        gr.update(visible=False),  # audio_interface
    )

    await asyncio.sleep(12)

    # Phase 2: clear the notice and reveal the audio player.
    yield (
        gr.update(visible=False),  # loading_message
        "",
        gr.update(visible=False),  # main_interface
        gr.update(visible=True),   # audio_interface
    )
# Colour and border overrides applied on top of the base theme.
theme_overrides = {
    "checkbox_background_color_selected_dark": '*secondary_400',
    "button_border_width": '*checkbox_border_width',
    "button_primary_background_fill_hover": '*primary_400',
    "button_primary_background_fill_hover_dark": '*primary_900',
    "button_secondary_background_fill": '*secondary_100',
    "button_secondary_background_fill_dark": '*secondary_700',
    "button_secondary_background_fill_hover": '*secondary_300',
    "button_secondary_background_fill_hover_dark": '*secondary_900',
    "button_secondary_border_color": '*secondary_100',
    "button_secondary_border_color_dark": '*secondary_600',
    "button_secondary_text_color": '*secondary_700',
}

# Base theme: amber primary, sky secondary, Noto Sans TC with generic fallbacks.
theme = gr.themes.Base(
    primary_hue="amber",
    secondary_hue="sky",
    font=[gr.themes.GoogleFont('Noto Sans TC'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(**theme_overrides)
# Extra CSS: `.graphic_parent` is a wrapping flex row and each `.graphic`
# child takes 48% of its width.
css = """
.graphic_parent {
display: flex;
flex-direction: row;
flex-wrap: wrap;
gap: 8px;
max-height: 100vh;
max-width: 100vw;
}
.graphic {
width: 48%;
}
"""
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the `with` statements alone (the captured source had lost its
# indentation) - confirm against the deployed layout.
with gr.Blocks(theme=theme, css=css, title='療癒對話機器人') as demo:
    # Per-session state.
    generated_images = gr.State(value=None)
    meditation_flag = gr.State(value=True)
    last_genimg_times = gr.State(value=0)
    conversation_times = gr.State(value=0)
    valence_scores = gr.State(value=[])
    arousal_scores = gr.State(value=[])

    gr.Markdown("# 療癒對話機器人")
    login = gr.LoginButton(value="登入 Hugging Face", logout_value="登出 Hugging Face", variant="primary", size="sm")
    loading_message = gr.Textbox(visible=False, show_label=False)

    # Audio view, hidden until the relaxation guidance starts.
    with gr.Column(visible=False) as audio_interface:
        audio_player = gr.Audio(label="放鬆引導指導語", show_download_button=False, show_share_button=False, interactive=False)
        back_to_chat = gr.Button("返回聊天")

    with gr.Row() as main_interface:
        # Chat column.
        with gr.Column() as chatbot_interface:
            chatbot = gr.Chatbot(label="聊天機器人", show_share_button=False, bubble_full_width=False, layout='bubble', scale=6)
            msg = gr.Textbox(show_label=False, placeholder="今天想要跟我分享什麼呢?", autofocus=True, scale=2)
            with gr.Row():
                submit = gr.Button("送出", variant="primary", scale=2)
                jump_to_med = gr.Button("跳過對話進行放鬆引導", variant="secondary", scale=2)
                gen_other_img = gr.Button("結合聯想生成更多圖像", variant="secondary", scale=2, visible=False)
            with gr.Row(visible=False) as meditation_buttons:
                relax_yes = gr.Button("好", variant="primary")
                relax_no = gr.Button("我想再多分享一點")

        # Image picker: one radio plus the four generated images.
        with gr.Column(elem_classes="graphic_parent") as image_selector_interface:
            image_selector = gr.Radio(choices=["圖像 1", "圖像 2", "圖像 3", "圖像 4"], label="選擇一張圖像")
            image_outputs = [
                gr.Image(label=f"圖像 {number}", interactive=False, show_share_button=False, elem_classes="graphic")
                for number in range(1, 5)
            ]

        # Selected image and the association input, hidden initially.
        with gr.Column(visible=False) as selected_image_interface:
            selected_image = gr.Image(interactive=False, show_share_button=False, label="你選擇的圖像")
            image_chat_input = gr.Textbox(label="這張圖像讓你產生了什麼樣的聯想?")
            image_chat_button = gr.Button("與聊天機器人分享", variant="primary")

    login.activate()

    # Chat events: the button click and Enter in the text box do the same thing.
    chat_inputs = [msg, chatbot, conversation_times, valence_scores, arousal_scores, meditation_flag]
    chat_outputs = [chatbot, msg, submit, jump_to_med, meditation_buttons, conversation_times, valence_scores, arousal_scores, meditation_flag]
    for chat_trigger in (submit.click, msg.submit):
        chat_trigger(handle_chat, chat_inputs, chat_outputs)

    # Entering the relaxation guidance (directly or after the suggestion).
    for start_trigger in (jump_to_med.click, relax_yes.click):
        start_trigger(start_meditation, meditation_flag, [audio_player, meditation_flag])
    relax_no.click(continue_chat, None, [msg, submit, jump_to_med, meditation_buttons])

    # Relaxation view: loading notice, image generation during playback, and return.
    loading_outputs = [loading_message, loading_message, main_interface, audio_interface]
    for loading_trigger in (jump_to_med.click, relax_yes.click):
        loading_trigger(show_loading, None, loading_outputs)

    generation_inputs = [chatbot, conversation_times, last_genimg_times, generated_images]
    generation_outputs = [conversation_times, last_genimg_times] + image_outputs
    audio_player.play(generate_images, generation_inputs, generation_outputs)

    return_outputs = [audio_player, main_interface, chatbot_interface, selected_image_interface, audio_interface, msg, submit, meditation_buttons, gen_other_img, jump_to_med]
    for return_trigger in (audio_player.stop, back_to_chat.click):
        return_trigger(return_to_chat, None, return_outputs)

    # Image selection and image-based chat; the input box is cleared afterwards.
    image_selector.change(select_image, [image_selector] + image_outputs, selected_image)
    image_chat_inputs = [image_chat_input, chatbot, selected_image]
    image_chat_outputs = [chatbot, chatbot, chatbot_interface]
    for image_chat_trigger in (image_chat_input.submit, image_chat_button.click):
        image_chat_trigger(chat_about_image, image_chat_inputs, image_chat_outputs).then(
            lambda: None, None, image_chat_input, queue=False
        )

    # Generating a further set of images on demand.
    gen_other_img.click(generate_images, generation_inputs, generation_outputs)
if __name__ == "__main__":
    # Enable the request queue (up to 30 waiting, default concurrency 20),
    # then start the server without exposing the API page.
    demo.queue(max_size=30, default_concurrency_limit=20).launch(show_api=False, max_threads=40)