Chatting_with_Aira / chat_app_remote.py
Respair's picture
Update chat_app_remote.py
b42c71b verified
raw
history blame
7.01 kB
import gradio as gr
from gradio_client import Client
import uuid
import warnings
import numpy as np
import json
import os
from gradio_client import Client, FileData, handle_file
import tempfile
import scipy.io.wavfile as wavfile
# Suppress warnings
# NOTE: this is a blanket filter with no category or module restriction, so it
# is installed before the client is created and stays active afterwards.
warnings.filterwarnings("ignore")
# Initialize client
# The target handed to gradio_client.Client is read from the `src` environment
# variable; if the variable is not set this raises KeyError at import time.
client = Client(os.environ['src'])
# Custom CSS for container alignment
# Passed to gr.Blocks(css=...) in create_frontend_demo().
custom_css = """
.gradio-container {
justify-content: flex-start !important;
}
"""
def chat_function(message, history, session_id):
    """
    Forward a text message to the backend and repackage its reply.

    The backend endpoint with ``fn_index=0`` is called with the message, the
    history and the session ID, and must answer with exactly four values.

    Returns a 5-tuple: an empty string, the history from the backend, the
    audio path from the backend, the ``session_id`` that was passed in, and
    the display text from the backend.
    """
    # The first value of the backend reply is discarded.
    _unused, updated_history, reply_audio, status_text = client.predict(
        message, history, session_id, fn_index=0
    )
    return "", updated_history, reply_audio, session_id, status_text
def set_session(user_id):
    """
    Ask the backend (``fn_index=1``) for the session ID to use.

    ``user_id`` is forwarded as-is. The backend must answer with exactly two
    values; only the first one (the session ID) is used, the second is
    replaced by a locally formatted label.

    Returns a 3-tuple: the session ID, an empty string, and the text
    ``"Current Session ID: <id>"``.
    """
    backend_reply = client.predict(user_id, fn_index=1)
    new_id, _backend_text = backend_reply
    return new_id, "", f"Current Session ID: {new_id}"
def handle_audio(audio_data, history, session_id):
    """
    Process audio input and send to backend.

    The recording is written as a WAV file into a private temporary
    directory, handed to the backend's ``/handle_audio`` endpoint, and the
    directory is removed again as soon as the call has returned.

    Args:
        audio_data: ``(sample_rate, samples)`` pair, or ``None`` when there
            is nothing to process.
        history: Conversation history; forwarded to the backend unchanged.
        session_id: Session identifier; forwarded to the backend unchanged.

    Returns:
        A 4-tuple ``(audio_path, history, session_id, display_text)``.
        With ``audio_data is None`` the inputs are echoed back with the
        current session label. On any failure the audio path is ``None``,
        the incoming history and session ID are returned, and the display
        text carries an error message.
    """
    if audio_data is None:
        return None, history, session_id, f"Current Session ID: {session_id}"

    try:
        sample_rate, audio_array = audio_data

        # Write to a regular file inside a temporary directory instead of
        # re-opening a still-open NamedTemporaryFile by name: reopening by
        # name is platform dependent (it fails on Windows), whereas a plain
        # path can be written here and read again by the client everywhere.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, "recording.wav")
            wavfile.write(wav_path, sample_rate, audio_array)

            audio = {"path": wav_path, "meta": {"_type": "gradio.FileData"}}
            # The request must complete before the directory is cleaned up.
            result = client.predict(
                audio,
                history,
                session_id,
                api_name="/handle_audio"
            )

        audio_path, new_history, new_session_id = result
        display_text = f"Current Session ID: {new_session_id}"
        return audio_path, new_history, new_session_id, display_text
    except Exception as e:
        # UI boundary: report the problem and keep the app responsive
        # rather than propagating the exception into the event handler.
        print(f"Error processing audio: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, history, session_id, f"Error processing audio. Session ID: {session_id}"
def respond(message, chat_history, session_id):
    """
    Route a text message to the backend once a session ID is available.

    With a truthy ``session_id`` the call is delegated to
    ``chat_function``. Otherwise nothing is sent: the history and session ID
    are returned untouched together with an empty message, no audio, and a
    prompt asking the user to set a session ID.
    """
    if session_id:
        return chat_function(message, chat_history, session_id)

    return "", chat_history, None, session_id, "Please set a session ID first"
def create_frontend_demo():
    """
    Create and configure the Gradio interface.

    Builds a two-tab layout ("Chat" and "Options"), wires the text, audio,
    session and clear handlers, and registers a page-load handler that
    requests a fresh session ID for every visitor.

    Returns:
        The configured ``gr.Blocks`` application (not yet launched).
    """
    # NOTE(review): the theme string below looks like an e-mail-obfuscation
    # artifact ("[email protected]") rather than a real "name@version"
    # theme id -- confirm the intended value against the deployed file.
    with gr.Blocks(css=custom_css, theme="Respair/[email protected]") as demo:
        # Build-time session ID. It is only a fallback: the same value would
        # otherwise be handed to every visitor, because it is computed once
        # here and reused as the initial State value. The load handler
        # registered at the end of this block replaces it per page load.
        initial_session_id, _, initial_display_text = set_session("")
        session_id_state = gr.State(value=initial_session_id)

        with gr.Tabs() as tabs:
            with gr.Tab("Chat"):
                # Update initial session display with the auto-generated ID
                session_display = gr.Markdown(initial_display_text, label="Session ID")
                chatbot = gr.Chatbot(
                    label="Conversation History",
                    height=400,
                    avatar_images=["photo_2024-03-01_22-30-42.jpg", "colored_blured.png"],
                    placeholder="Start chatting with Aira..."
                )
                # Update the instruction text since session ID is now automatic
                gr.Markdown("""Start chatting with Aira! You can use text or voice input.
<br> アイラとチャットを始めましょう!テキストまたは音声入力が使えます。""")

                with gr.Column():
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press enter",
                        container=True
                    )
                    audio_output = gr.Audio(
                        label="Aira's Response",
                        type="filepath",
                        streaming=False,
                        autoplay=True
                    )
                    with gr.Row():
                        audio_input = gr.Audio(
                            sources=["microphone"],
                            type="numpy",
                            label="Audio Input",
                            streaming=False
                        )

            with gr.Tab("Options"):
                with gr.Column():
                    session_input = gr.Textbox(
                        value="",
                        label="Session ID (leave blank for new session)"
                    )
                    gen_id_btn = gr.Button("Set Session ID")
                    session_msg = gr.Markdown("")
                    clear_btn = gr.Button("Clear Conversation")
                    gr.Markdown("""
This is a personal project I wanted to do for a while.
Aira's voice was designed to be unique; it doesn't belong to any real person out there.
Her design is also based on a vtuber project I did a few years ago, though I didn't put
a lot of effort into it this time (you can see the lazy brush strokes).
You can talk to her in English or Japanese, but she will only respond in Japanese
(Subs over dubs, bros) ask her to give you a Subtitle if you can't talk in Japanese.
The majority of the latency depends on the HF's inference api.
The language modelling part is not fine-tuned, it's an off-the-shelf one, please beware of that.
1. Enter your Session ID above or leave blank for a new one
2. Click 'Set Session ID' to confirm
3. Use 'Clear Conversation' to reset the chat
4. Your conversation history is saved based on your Session ID
I'll try to keep this demo up for as long as I can afford.
""")

        # Event handlers
        msg.submit(
            respond,
            inputs=[msg, chatbot, session_id_state],
            outputs=[msg, chatbot, audio_output, session_id_state, session_display]
        )
        gen_id_btn.click(
            set_session,
            inputs=[session_input],
            outputs=[session_id_state, session_msg, session_display]
        )
        audio_input.stop_recording(
            handle_audio,
            inputs=[audio_input, chatbot, session_id_state],
            outputs=[audio_output, chatbot, session_id_state, session_display]
        )
        # Only the on-screen conversation is cleared here; no backend call
        # is made.
        clear_btn.click(
            lambda: [],
            None,
            [chatbot]
        )
        # Give each visitor their own session: on every page load ask for a
        # new ID (blank input) and push it into the state and the label, the
        # same three outputs the "Set Session ID" button updates.
        demo.load(
            fn=lambda: set_session(""),
            inputs=None,
            outputs=[session_id_state, session_msg, session_display]
        )
    return demo
# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    # The Blocks app is kept in a module-level name `demo` before launching.
    demo = create_frontend_demo()
    demo.launch(show_error=True)