diff --git "a/playground/refs/test.ipynb" "b/playground/refs/test.ipynb"
new file mode 100644
--- /dev/null
+++ "b/playground/refs/test.ipynb"
@@ -0,0 +1,812 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/barrel/aai/.venv/lib/python3.10/site-packages/pyannote/audio/core/io.py:43: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n",
+      "  torchaudio.set_audio_backend(\"soundfile\")\n"
+     ]
+    }
+   ],
+   "source": [
+    "import gradio as gr\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torchaudio\n",
+    "from silero_vad import get_speech_timestamps, load_silero_vad\n",
+    "import whisperx\n",
+    "import openai\n",
+    "import asyncio\n",
+    "import edge_tts\n",
+    "import gc\n",
+    "import logging\n",
+    "import os\n",
+    "import threading\n",
+    "import time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2024-09-23 13:50:24,408 - INFO - Using device: cuda\n",
+      "2024-09-23 13:50:24,660 - INFO - Loaded Silero VAD model\n",
+      "Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../.cache/torch/whisperx-vad-segmentation.bin`\n",
+      "2024-09-23 13:50:24,994 - INFO - Loaded WhisperX model\n",
+      "2024-09-23 13:50:24,994 - INFO - Set OpenAI API key\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "No language specified, language will be first be detected for each audio file (increases inference time).\n",
+      "Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n",
+      "Model was trained with torch 1.10.0+cu102, yours is 2.3.1+cu121. Bad things might happen unless you revert torch to 1.x.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Configure logging\n",
+    "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
+    "\n",
+    "# Load Silero VAD model\n",
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+    "logging.info(f'Using device: {device}')\n",
+    "vad_model = load_silero_vad().to(device)  # Ensure the model is on the correct device\n",
+    "logging.info('Loaded Silero VAD model')\n",
+    "\n",
+    "# Load WhisperX model\n",
+    "whisper_model = whisperx.load_model(\"tiny\", device, compute_type=\"float16\")\n",
+    "logging.info('Loaded WhisperX model')\n",
+    "\n",
+    "# Read the OpenAI API key from the environment rather than hard-coding the secret\n",
+    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
+    "logging.info('Set OpenAI API key')\n",
+    "\n",
+    "# TTS Voice\n",
+    "TTS_VOICE = \"en-GB-SoniaNeural\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torchaudio\n",
+    "import logging\n",
+    "\n",
+    "def check_vad(audio_data, sample_rate):\n",
+    "    logging.info('Checking voice activity')\n",
+    "    # Resample to 16000 Hz if necessary\n",
+    "    target_sample_rate = 16000\n",
+    "    if sample_rate != target_sample_rate:\n",
+    "        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
+    "        audio_tensor = resampler(torch.from_numpy(audio_data))\n",
+    "    else:\n",
+    "        audio_tensor = torch.from_numpy(audio_data)\n",
+    "    audio_tensor = audio_tensor.to(device)\n",
+    "\n",
+    "    # Log audio data details\n",
+    "    logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')\n",
+    "\n",
+    "    # Get speech timestamps with optimized parameters\n",
+    "    speech_timestamps = get_speech_timestamps(\n",
+    "        audio=audio_tensor,\n",
+    "        model=vad_model,\n",
+    "        sampling_rate=target_sample_rate,\n",
+    "        min_speech_duration_ms=250,\n",
+    "        min_silence_duration_ms=80,\n",
+    "        speech_pad_ms=30\n",
+    "    )\n",
+    "    logging.info(f'Found {len(speech_timestamps)} speech timestamps')\n",
+    "    return len(speech_timestamps) > 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def transcript(audio_data, sample_rate):\n",
+    "    logging.info('Transcribing audio')\n",
+    "    # Resample to 16000 Hz if necessary\n",
+    "    target_sample_rate = 16000\n",
+    "    if sample_rate != target_sample_rate:\n",
+    "        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
+    "        audio_data = resampler(torch.from_numpy(audio_data)).numpy()\n",
+    "\n",
+    "    # Transcribe and join all segments (using only the first segment would drop text)\n",
+    "    batch_size = 16  # Adjust as needed\n",
+    "    result = whisper_model.transcribe(audio_data, batch_size=batch_size)\n",
+    "    text = ' '.join(segment['text'].strip() for segment in result['segments'])\n",
+    "    logging.info(f'Transcription result: {text}')\n",
+    "    # Clear GPU memory\n",
+    "    del result\n",
+    "    gc.collect()\n",
+    "    if device == 'cuda':\n",
+    "        torch.cuda.empty_cache()\n",
+    "    return text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from openai import OpenAI\n",
+    "\n",
+    "# The client reads OPENAI_API_KEY from the environment; never commit the key itself\n",
+    "openai_client = OpenAI()\n",
+    "\n",
+    "def llm(text):\n",
+    "    logging.info('Getting response from OpenAI API')\n",
+    "    response = openai_client.chat.completions.create(\n",
+    "        model=\"gpt-4o\",  # Updated to a more recent model\n",
+    "        messages=[\n",
+    "            {\"role\": \"system\", \"content\": \"You are having a spoken conversation with the user. Respond to the following transcript of what they just said.\"},\n",
+    "            {\"role\": \"user\", \"content\": text}\n",
+    "        ],\n",
+    "        stream=True,\n",
+    "        temperature=0.7,  # Optional: Adjust as needed\n",
+    "        top_p=0.9,  # Optional: Adjust as needed\n",
+    "    )\n",
+    "    for chunk in response:\n",
+    "        yield chunk.choices[0].delta.content"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def tts_streaming(text_stream):\n",
+    "    logging.info('Performing TTS')\n",
+    "    buffer = \"\"\n",
+    "    punctuation = {'.', '!', '?'}\n",
+    "    for text_chunk in text_stream:\n",
+    "        if text_chunk is not None:\n",
+    "            buffer += text_chunk\n",
+    "            # Check for sentence completion\n",
+    "            sentences = []\n",
+    "            start = 0\n",
+    "            for i, char in enumerate(buffer):\n",
+    "                if char in punctuation:\n",
+    "                    sentences.append(buffer[start:i+1].strip())\n",
+    "                    start = i + 1\n",
+    "            buffer = buffer[start:]\n",
+    "\n",
+    "            for sentence in sentences:\n",
+    "                if sentence:\n",
+    "                    communicate = edge_tts.Communicate(sentence, TTS_VOICE)\n",
+    "                    for chunk in communicate.stream_sync():\n",
+    "                        if chunk[\"type\"] == \"audio\":\n",
+    "                            yield chunk[\"data\"]\n",
+    "    # Process any remaining text once the stream ends\n",
+    "    if buffer.strip():\n",
+    "        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)\n",
+    "        for chunk in communicate.stream_sync():\n",
+    "            if chunk[\"type\"] == \"audio\":\n",
+    "                yield chunk[\"data\"]"
+   ]
+  },
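+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`llm_and_tts` is started in a background thread by `process_audio_chunk` further down, but it is never defined in this notebook. A minimal sketch (an assumption, not the original helper) that chains `llm()` into `tts_streaming()`, honours the `stop_signal` flag, and buffers the audio bytes in `state['tts_audio_chunks']` could look like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Assumed helper, not part of the original notebook: stream the LLM reply through TTS\n",
+    "# and collect the audio chunks for process_audio_chunk to play back.\n",
+    "def llm_and_tts(text, state):\n",
+    "    logging.info('Running LLM and TTS in a background thread')\n",
+    "    for audio_chunk in tts_streaming(llm(text)):\n",
+    "        # Let process_audio_chunk interrupt the response mid-stream\n",
+    "        if state.get('stop_signal'):\n",
+    "            logging.info('LLM/TTS task stopped early')\n",
+    "            break\n",
+    "        state['tts_audio_chunks'].append(audio_chunk)"
+   ]
+  },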
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load audio into a mono float32 numpy array\n",
+    "def load_audio(audio_path):\n",
+    "    audio_data, sample_rate = torchaudio.load(audio_path)  # shape: (channels, samples)\n",
+    "    audio_data = audio_data.numpy()\n",
+    "    # Average the channels to get mono audio\n",
+    "    if audio_data.ndim > 1:\n",
+    "        audio_data = np.mean(audio_data, axis=0)\n",
+    "    return audio_data, sample_rate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Testing the pipeline\n",
+    "\n",
+    "# 1. 
Load audio\n", + "audio_path = 'audio.mp3'\n", + "audio_data, sample_rate = load_audio(audio_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-09-23 13:50:49,248 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,253 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,494 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,495 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,498 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,506 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,507 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,511 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,518 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,519 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,523 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,531 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,532 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,535 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,543 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,543 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,546 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,557 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,558 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,561 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,569 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,570 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,573 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,581 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,582 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,585 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,593 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,593 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,595 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,604 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,605 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,607 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,616 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,617 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,619 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,628 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,629 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,632 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,640 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,641 - INFO - Checking voice activity\n", + "2024-09-23 
13:50:49,644 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,651 - INFO - Found 0 speech timestamps\n", + "2024-09-23 13:50:49,652 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,654 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,665 - INFO - Found 0 speech timestamps\n", + "2024-09-23 13:50:49,665 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,669 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,678 - INFO - Found 0 speech timestamps\n", + "2024-09-23 13:50:49,678 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,681 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,690 - INFO - Found 0 speech timestamps\n", + "2024-09-23 13:50:49,691 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,693 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,703 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,704 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,707 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,718 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,719 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,722 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,731 - INFO - Found 0 speech timestamps\n", + "2024-09-23 13:50:49,732 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,734 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,743 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,744 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,746 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,759 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,760 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,762 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,773 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,773 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,776 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,784 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,785 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,789 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,798 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,799 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,801 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,810 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,810 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,813 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,821 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,822 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,824 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: 
cuda:0\n", + "2024-09-23 13:50:49,833 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,834 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,836 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,844 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,845 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,847 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,856 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,857 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,860 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,871 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,872 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,875 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,883 - INFO - Found 1 speech timestamps\n", + "2024-09-23 13:50:49,884 - INFO - Checking voice activity\n", + "2024-09-23 13:50:49,887 - INFO - Audio tensor shape: torch.Size([644]), dtype: torch.float32, device: cuda:0\n", + "2024-09-23 13:50:49,889 - INFO - Found 0 speech timestamps\n" + ] + } + ], + "source": [ + "chunk_size = 500 # ms\n", + "chunk_size_samples = int(sample_rate * chunk_size / 1000)\n", + "chunks = [audio_data[i:i + chunk_size_samples] for i in range(0, len(audio_data), chunk_size_samples)]\n", + "\n", + "# 2. Check voice activity\n", + "voice_activity = [check_vad(chunk, sample_rate) for chunk in chunks]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-09-23 13:50:50,691 - INFO - Transcribing audio\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: audio is shorter than 30s, language detection may be inaccurate.\n", + "Detected language: en (0.99) in first 30s of audio...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-09-23 13:50:51,041 - INFO - Transcription result: What's this the reporter tried to make a hit piece about Wu Kong is not happy. I wonder why? What a shock. Well wait a second. 
Should we get to the bottom of this?\n" + ] + } + ], + "source": [ + "text = transcript(audio_data, sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "text = llm(text)\n", + "tts_audio = tts_streaming(text)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-09-23 13:50:53,979 - INFO - Performing TTS\n", + "2024-09-23 13:50:53,980 - INFO - Getting response from OpenAI API\n", + "2024-09-23 13:50:54,236 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import Audio\n", + "from pydub import AudioSegment\n", + "from io import BytesIO\n", + "import base64\n", + "\n", + "# Combine audio chunk bytes\n", + "audio_bytes = b''.join(tts_audio)\n", + "\n", + "# Play audio\n", + "audio_segment = AudioSegment.from_file(BytesIO(audio_bytes), format=\"raw\", frame_rate=16000, channels=1, sample_width=2)\n", + "\n", + "Audio(audio_bytes, rate=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "np_audio = np.frombuffer(audio_bytes, dtype=np.int16)\n", + "\n", + "# export audio with numpy\n", + "np_audio.tofile(\"output.wav\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# function to process audio input\n", + "def process_audio_old(audio, state):\n", + " \"\"\"\n", + " Flow:\n", + " 1. Sleep for 0.5 seconds to allow the audio buffer to accumulate\n", + " 2. Check for voice activity\n", + " 3. If voice activity is detected and mode is \"idle\":\n", + " - Set mode to \"listening\"\n", + " 4. If voice activity is detected and mode is \"speaking\":\n", + " - Stop the llm and tts tasks\n", + " - Set mode to \"listening\"\n", + " 5. If voice activity is detected and mode is \"listening\":\n", + " - If there's previous_no_vad_audio, add it to chunk_queue\n", + " - Start accumulating audio chunks in chunk_queue\n", + " - If the length of chunk_queue is greater than 3 seconds\n", + " - Get the first 2 seconds of audio from chunk_queue\n", + " - Run transcription on the first 2 seconds\n", + " - Store the transcription in the state\n", + " - Remove the first 2 seconds of audio from chunk_queue\n", + " 6. 
If voice activity is not detected:\n", + " - If mode is \"listening\" and there's audio in chunk_queue\n", + " - Add the chunk to chunk_queue\n", + " - Set mode to \"processing\"\n", + " - Run transcription on the leftover audio in chunk_queue\n", + " - Store the transcription in the state\n", + " - Set the mode to \"processing\"\n", + " - If mode is \"processing\"\n", + " - Check if there's any leftover audio in chunk_queue\n", + " - If there is, run transcription on the leftover audio\n", + " - Store the transcription in the state\n", + " - Start LLM and TTS in the background\n", + " - Set mode to \"responding\"\n", + " - If mode is \"responding\"\n", + " - Get the audio byte chunks from TTS\n", + " - Output the full audio\n", + " - Set mode to \"idle\"\n", + " - If mode is \"idle\"\n", + " - do nothing\n", + " \n", + " Ex: Gradio Streaming Audio Example:\n", + " import gradio as gr\n", + " import numpy as np\n", + " import time\n", + "\n", + " def add_to_stream(audio, instream):\n", + " time.sleep(1)\n", + " if audio is None:\n", + " return gr.update(), instream\n", + " if instream is None:\n", + " ret = audio\n", + " else:\n", + " ret = (audio[0], np.concatenate((instream[1], audio[1])))\n", + " return ret, ret\n", + "\n", + "\n", + " with gr.Blocks() as demo:\n", + " inp = gr.Audio(source=\"microphone\")\n", + " out = gr.Audio()\n", + " stream = gr.State()\n", + " clear = gr.Button(\"Clear\")\n", + "\n", + " inp.stream(add_to_stream, [inp, stream], [out, stream])\n", + " clear.click(lambda: [None, None, None], None, [inp, out, stream])\n", + "\n", + "\n", + " if __name__ == \"__main__\":\n", + " demo.launch()\n", + " \"\"\"\n", + " \"\"\"old code:\n", + " time.sleep(0.5)\n", + " if audio is None:\n", + " return None, state\n", + "\n", + " sample_rate, audio_data = audio\n", + " audio_data = np.array(audio_data, dtype=np.float32)\n", + "\n", + " # Convert to mono if stereo\n", + " if audio_data.ndim > 1:\n", + " audio_data = np.mean(audio_data, axis=1)\n", + "\n", + " # Check for voice activity\n", + " vad_result = check_vad(audio_data, sample_rate)\n", + " if vad_result:\n", + " logging.info('Voice activity detected')\n", + " # Voice activity detected\n", + " if state.get(\"previous_audio_chunk\") is not None:\n", + " state[\"audio_buffer\"].append(state[\"previous_audio_chunk\"])\n", + " state[\"audio_buffer\"].append(audio_data)\n", + " state[\"is_speaking\"] = True\n", + " state[\"previous_audio_chunk\"] = audio_data\n", + "\n", + " # Update total speaking time\n", + " chunk_duration = len(audio_data) / sample_rate\n", + " state[\"total_speaking_time\"] += chunk_duration\n", + "\n", + " # Start transcription after 3 seconds\n", + " if state[\"total_speaking_time\"] >= 3.0 and not state[\"transcription_started\"]:\n", + " logging.info('Starting transcription')\n", + " # Start transcribing the first 2 seconds\n", + " accumulated_audio = np.concatenate(state[\"audio_buffer\"])\n", + " first_two_seconds_samples = int(2.0 * sample_rate)\n", + " first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n", + "\n", + " # Transcribe asynchronously\n", + " transcribed_text = transcript(first_two_seconds_audio, sample_rate)\n", + " state[\"transcription\"] += transcribed_text\n", + " state[\"transcription_started\"] = True\n", + "\n", + " # Start LLM and TTS in the background\n", + " state[\"llm_task\"] = llm_and_tts(state[\"transcription\"], state)\n", + " else:\n", + " if state[\"is_speaking\"]:\n", + " logging.info('Voice activity ended')\n", + " # Voice activity just 
ended\n", + " # Process the accumulated audio\n", + " full_audio = np.concatenate(state[\"audio_buffer\"])\n", + " # Reset the state\n", + " state[\"audio_buffer\"] = []\n", + " state[\"is_speaking\"] = False\n", + " state[\"total_speaking_time\"] = 0.0\n", + " state[\"transcription_started\"] = False\n", + "\n", + " # Transcribe the remaining audio\n", + " transcribed_text = transcript(full_audio, sample_rate)\n", + " state[\"transcription\"] += transcribed_text\n", + "\n", + " # Start LLM and TTS if not already started\n", + " if not state.get(\"llm_task\"):\n", + " state[\"llm_task\"] = llm_and_tts(state[\"transcription\"], state)\n", + "\n", + " # Check if there's audio to output\n", + " if state.get(\"tts_audio_chunks\"):\n", + " logging.info('Outputting audio')\n", + " # Collect audio chunks\n", + " audio_chunks = state[\"tts_audio_chunks\"]\n", + " state[\"tts_audio_chunks\"] = []\n", + " response_audio = b\"\".join(audio_chunks)\n", + " np_response_audio = np.frombuffer(response_audio, dtype=np.int16)\n", + " return (sample_rate, np_response_audio), state\n", + "\n", + " # Collect the last chunk if it exists\n", + " if state.get(\"previous_audio_chunk\") is not None:\n", + " state[\"audio_buffer\"].append(state[\"previous_audio_chunk\"])\n", + "\n", + " return None, state\n", + " \"\"\"\n", + " ...\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to process audio input\n", + "def process_audio_chunk(audio, state):\n", + " if audio is None:\n", + " return None, state\n", + " if state is None:\n", + " state = {\n", + " 'mode': 'idle',\n", + " 'chunk_queue': [],\n", + " 'transcription': '',\n", + " 'previous_no_vad_audio': None,\n", + " 'tts_audio_chunks': [],\n", + " 'llm_task': None,\n", + " 'instream': None,\n", + " }\n", + "\n", + " sample_rate, audio_data = audio\n", + " audio_data = np.array(audio_data, dtype=np.float32)\n", + "\n", + " # Convert to mono if stereo\n", + " if audio_data.ndim > 1:\n", + " audio_data = np.mean(audio_data, axis=1)\n", + "\n", + " mode = state['mode']\n", + " chunk_queue = state['chunk_queue']\n", + " transcription = state['transcription']\n", + " previous_no_vad_audio = state['previous_no_vad_audio']\n", + " tts_audio_chunks = state['tts_audio_chunks']\n", + " llm_task = state['llm_task']\n", + " instream = state['instream']\n", + "\n", + " # Check for voice activity\n", + " vad_result = check_vad(audio_data, sample_rate)\n", + "\n", + " if vad_result:\n", + " logging.info(f'Voice activity detected in mode: {mode}')\n", + " if mode == 'idle':\n", + " mode = 'listening'\n", + " elif mode == 'speaking':\n", + " # Stop llm and tts tasks\n", + " if llm_task and llm_task.is_alive():\n", + " # Implement task cancellation logic if possible\n", + " logging.info('Stopping LLM and TTS tasks')\n", + " # Since we cannot kill threads directly, we need to handle this in the tasks\n", + " state['stop_signal'] = True\n", + " llm_task.join()\n", + " mode = 'listening'\n", + " \n", + " if vad_result:\n", + " if mode == 'listening':\n", + " if previous_no_vad_audio is not None:\n", + " chunk_queue.append(previous_no_vad_audio)\n", + " previous_no_vad_audio = None\n", + " # Accumulate audio chunks\n", + " chunk_queue.append(audio_data)\n", + " # Calculate the length of chunk_queue in seconds\n", + " total_samples = sum(len(chunk) for chunk in chunk_queue)\n", + " total_duration = total_samples / sample_rate\n", + " if total_duration > 3.0:\n", + " # Get the first 2 seconds of audio\n", + " 
first_two_seconds_samples = int(2.0 * sample_rate)\n", + " accumulated_audio = np.concatenate(chunk_queue)\n", + " first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n", + " # Run transcription on the first 2 seconds\n", + " transcribed_text = transcript(first_two_seconds_audio, sample_rate)\n", + " transcription += transcribed_text\n", + " # Remove the first 2 seconds from chunk_queue\n", + " remaining_audio = accumulated_audio[first_two_seconds_samples:]\n", + " chunk_queue = [remaining_audio] if len(remaining_audio) > 0 else []\n", + " elif mode == 'speaking':\n", + " # Continue accumulating audio chunks\n", + " chunk_queue.append(audio_data)\n", + " else:\n", + " logging.info(f'No voice activity detected in mode: {mode}')\n", + " if mode == 'listening' and chunk_queue:\n", + " # Add the chunk to chunk_queue\n", + " chunk_queue.append(audio_data)\n", + " # Run transcription on leftover audio in chunk_queue\n", + " accumulated_audio = np.concatenate(chunk_queue)\n", + " transcribed_text = transcript(accumulated_audio, sample_rate)\n", + " transcription += transcribed_text\n", + " # Clear chunk_queue\n", + " chunk_queue = []\n", + " mode = 'processing'\n", + " # Start LLM and TTS in the background\n", + " if not llm_task or not llm_task.is_alive():\n", + " state['stop_signal'] = False\n", + " llm_task = threading.Thread(target=llm_and_tts, args=(transcription, state))\n", + " llm_task.start()\n", + " elif mode == 'processing':\n", + " # Wait for LLM and TTS to finish\n", + " if llm_task and not llm_task.is_alive():\n", + " mode = 'responding'\n", + " elif mode == 'responding':\n", + " # Get the audio byte chunks from TTS\n", + " if tts_audio_chunks:\n", + " logging.info('Outputting audio response')\n", + " # Collect audio chunks\n", + " response_audio = b\"\".join(tts_audio_chunks)\n", + " np_response_audio = np.frombuffer(response_audio, dtype=np.int16)\n", + " \n", + " if instream is None:\n", + " instream = np_response_audio\n", + " else:\n", + " instream = np.concatenate((instream, np_response_audio))\n", + " \n", + " # Clear tts_audio_chunks\n", + " tts_audio_chunks.clear()\n", + " # Reset transcription for next interaction\n", + " transcription = ''\n", + " # Set mode to \"idle\"\n", + " mode = 'idle'\n", + " \n", + " # Update state\n", + " state.update({\n", + " 'mode': mode,\n", + " 'chunk_queue': chunk_queue,\n", + " 'transcription': transcription,\n", + " 'previous_no_vad_audio': previous_no_vad_audio,\n", + " 'tts_audio_chunks': tts_audio_chunks,\n", + " 'llm_task': None,\n", + " 'instream': instream\n", + " })\n", + " return (sample_rate, instream), state\n", + " elif mode == 'idle':\n", + " # Do nothing\n", + " pass\n", + " else:\n", + " # Store the audio when no VAD is detected\n", + " previous_no_vad_audio = audio_data\n", + "\n", + " # Update state\n", + " state.update({\n", + " 'mode': mode,\n", + " 'chunk_queue': chunk_queue,\n", + " 'transcription': transcription,\n", + " 'previous_no_vad_audio': previous_no_vad_audio,\n", + " 'tts_audio_chunks': tts_audio_chunks,\n", + " 'llm_task': llm_task,\n", + " 'instream': instream\n", + " })\n", + "\n", + " return None, state\n", + "\n", + "# Initialize the state\n", + "initial_state = {\n", + " 'mode': 'idle',\n", + " 'chunk_queue': [],\n", + " 'transcription': '',\n", + " 'previous_no_vad_audio': None,\n", + " 'tts_audio_chunks': [],\n", + " 'llm_task': None,\n", + " 'instream': None,\n", + "}\n", + "\n", + "# Create Gradio interface\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"## Voice-Activated 
Transcription and Response System\")\n",
+    "    audio_input = gr.Audio(sources=\"microphone\", type=\"numpy\", streaming=True)\n",
+    "    state = gr.State(initial_state)\n",
+    "    audio_output = gr.Audio(label=\"Response Audio\", autoplay=True)\n",
+    "    audio_input.stream(process_audio_chunk, [audio_input, state], [audio_output, state])\n",
+    "\n",
+    "if __name__ == \"__main__\":\n",
+    "    logging.info('Launching Gradio interface')\n",
+    "    demo.launch()\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}