# server.py remains the same as before
# Updated client.py
import asyncio
import argparse
import base64
import time
from typing import List

import av
import numpy as np
import requests
import streamlit as st
import torch
import torchaudio
import websockets
from streamlit_webrtc import WebRtcMode, webrtc_streamer
class AudioClient:
    def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
        # Convert ws:// to http:// for the REST base URL
        self.base_url = server_url.replace("ws://", "http://")
        self.server_url = f"{server_url}/audio"
        self.sound_check = False
        # Send temperatures to the server if any were provided
        if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
            response_message = self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
            print(response_message)
        self.downsampler = torchaudio.transforms.Resample(STREAMING_SAMPLE_RATE, SAMPLE_RATE)
        self.upsampler = torchaudio.transforms.Resample(SAMPLE_RATE, STREAMING_SAMPLE_RATE)
        self.ws = None
        self.in_buffer = None
        self.out_buffer = None
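
    # Note on the two buffers (a summary of the callback logic further below, not
    # new behavior): `in_buffer` accumulates downsampled 16 kHz microphone audio
    # until a full BLOCK_SIZE chunk can be sent to the server, while `out_buffer`
    # holds audio returned by the model until the WebRTC track asks for it.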
    def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing=False):
        """Send temperature settings to the server. (`echo_testing` is accepted but not used here.)"""
        params = {}
        if token_temp is not None:
            params['token_temp'] = token_temp
        if categorical_temp is not None:
            params['categorical_temp'] = categorical_temp
        if gaussian_temp is not None:
            params['gaussian_temp'] = gaussian_temp
        response = requests.post(f"{self.base_url}/set_temperature", params=params)
        response_message = response.json()['message']
        return response_message
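
    # For reference, the request above is plain HTTP with query parameters, so an
    # equivalent manual call (assuming the default server address and the endpoint
    # defined in server.py) would be:
    #   curl -X POST "http://localhost:8000/set_temperature?token_temp=0.8&categorical_temp=0.4&gaussian_temp=0.1"
    # The server is expected to reply with JSON of the form {"message": "..."}.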
    def _resample(self, audio_data: np.ndarray, resampler: torchaudio.transforms.Resample) -> np.ndarray:
        # int16 -> float32 in [-1, 1], resample, then convert back to int16
        audio_data = audio_data.astype(np.float32) / 32767.0
        audio_data = resampler(torch.tensor(audio_data)).numpy()
        audio_data = (audio_data * 32767.0).astype(np.int16)
        return audio_data

    def upsample(self, audio_data: np.ndarray) -> np.ndarray:
        return self._resample(audio_data, self.upsampler)

    def downsample(self, audio_data: np.ndarray) -> np.ndarray:
        return self._resample(audio_data, self.downsampler)
    def from_s16_format(self, audio_data: np.ndarray, channels: int) -> np.ndarray:
        if channels == 2:
            audio_data = audio_data.reshape(-1, 2).T
        else:
            audio_data = audio_data.reshape(-1)
        return audio_data

    def to_s16_format(self, audio_data: np.ndarray):
        if len(audio_data.shape) == 2 and audio_data.shape[0] == 2:
            audio_data = audio_data.T.reshape(1, -1)
        elif len(audio_data.shape) == 1:
            audio_data = audio_data.reshape(1, -1)
        return audio_data
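
    # Background (my reading of PyAV's packed s16 layout, stated as an assumption
    # rather than something this file documents): AudioFrame.to_ndarray() returns
    # shape (1, samples * channels) with stereo samples interleaved L R L R ...,
    # so reshape(-1, 2).T above yields (channels, samples), and to_s16_format
    # inverts that transform before frames are handed back to av.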
    def to_channels(self, audio_data: np.ndarray, channels: int) -> np.ndarray:
        current_channels = audio_data.shape[0] if len(audio_data.shape) == 2 else 1
        if current_channels == channels:
            return audio_data
        elif current_channels == 1 and channels == 2:
            # Duplicate the mono signal into both stereo channels
            audio_data = np.tile(audio_data, 2).reshape(2, -1)
        elif current_channels == 2 and channels == 1:
            # Average the channels in float32 so the int16 sum cannot overflow
            audio_data = audio_data.astype(np.float32) / 32767.0
            audio_data = audio_data.mean(axis=0)
            audio_data = (audio_data * 32767.0).astype(np.int16)
        return audio_data
    async def process_audio(self, audio_data: np.ndarray) -> np.ndarray:
        if self.ws is None:
            self.ws = await websockets.connect(self.server_url)
        audio_data = audio_data.reshape(-1, CHANNELS)
        print(f'Data from microphone: {audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')
        # Encode the raw PCM as base64 and send it as a data URI
        audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')
        time_sent = time.time()
        await self.ws.send(f"data:audio/raw;base64,{audio_b64}")
        # Receive the processed audio (same data-URI format) and strip the header
        response = await self.ws.recv()
        response = response.split(",")[1]
        time_received = time.time()
        print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Round trip: {(time_received - time_sent) * 1000:.2f} ms")
        processed_audio = np.frombuffer(
            base64.b64decode(response),
            dtype=np.int16
        ).reshape(-1, CHANNELS)
        print(f'Data from model: {processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')
        if CHANNELS == 1:
            processed_audio = processed_audio.reshape(-1)
        return processed_audio
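
    # Wire protocol, as implied by the code above: each websocket message in either
    # direction is the string "data:audio/raw;base64,<payload>", where <payload> is
    # raw int16 PCM (native byte order, little-endian on typical hosts) at
    # SAMPLE_RATE (16 kHz), BLOCK_SIZE samples per send.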
    async def queued_audio_frames_callback(self, frames: List[av.AudioFrame]) -> List[av.AudioFrame]:
        out_frames = []
        for frame in frames:
            # Read in audio
            audio_data = frame.to_ndarray()
            # Convert input audio from s16 format, convert to `CHANNELS` number of channels, and downsample
            audio_data = self.from_s16_format(audio_data, len(frame.layout.channels))
            audio_data = self.to_channels(audio_data, CHANNELS)
            audio_data = self.downsample(audio_data)
            # Add audio to input buffer
            if self.in_buffer is None:
                self.in_buffer = audio_data
            else:
                self.in_buffer = np.concatenate((self.in_buffer, audio_data), axis=-1)
            # Take BLOCK_SIZE samples from the input buffer if available for processing
            if self.in_buffer.shape[0] >= BLOCK_SIZE:
                audio_data = self.in_buffer[:BLOCK_SIZE]
                self.in_buffer = self.in_buffer[BLOCK_SIZE:]
            else:
                audio_data = None
            # Process audio if available and add the resulting audio to the output buffer
            if audio_data is not None:
                if not self.sound_check:
                    audio_data = await self.process_audio(audio_data)
                if self.out_buffer is None:
                    self.out_buffer = audio_data
                else:
                    self.out_buffer = np.concatenate((self.out_buffer, audio_data), axis=-1)
            # Take `out_samples` samples from the output buffer if available for output
            out_samples = int(frame.samples * SAMPLE_RATE / STREAMING_SAMPLE_RATE)
            if self.out_buffer is not None and self.out_buffer.shape[0] >= out_samples:
                audio_data = self.out_buffer[:out_samples]
                self.out_buffer = self.out_buffer[out_samples:]
            else:
                audio_data = None
            # Output silence if no audio data is available
            if audio_data is None:
                audio_data = np.zeros(out_samples, dtype=np.int16)
            # Upsample output audio, convert to the original number of channels, and convert to s16 format
            audio_data = self.upsample(audio_data)
            audio_data = self.to_channels(audio_data, len(frame.layout.channels))
            audio_data = self.to_s16_format(audio_data)
            # Return the audio data as a new AudioFrame
            new_frame = av.AudioFrame.from_ndarray(audio_data, format=frame.format.name, layout=frame.layout.name)
            new_frame.sample_rate = frame.sample_rate
            out_frames.append(new_frame)
        return out_frames
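
    # Rough latency math (assuming the usual 20 ms WebRTC frames, i.e. 960 samples
    # at 48 kHz): each frame downsamples to 320 samples at 16 kHz, so a full
    # BLOCK_SIZE = 2000 block is sent roughly every 2000 / 320 = 6.25 frames
    # (125 ms of audio), and out_samples above works out to 960 * 16000/48000 = 320.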
    def stop(self):
        if self.ws is not None:
            # TODO: closing the websocket here hangs. Figure out why.
            # asyncio.get_event_loop().run_until_complete(self.ws.close())
            print("Websocket reference dropped (close() currently hangs; see TODO)")
        self.ws = None
        self.in_buffer = None
        self.out_buffer = None

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
    parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
    parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
    parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
    parser.add_argument('--server', '-s', default="ws://localhost:8000",
                        help='Server URL (default: ws://localhost:8000)')
    parser.add_argument("--use_ice_servers", action="store_true", help="Use public STUN servers")
    args = parser.parse_args()

    # Audio settings
    STREAMING_SAMPLE_RATE = 48000
    SAMPLE_RATE = 16000
    BLOCK_SIZE = 2000
    CHANNELS = 1
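    # Sizing note (arithmetic only, no new behavior): BLOCK_SIZE is in model-rate
    # samples, so 2000 samples / 16000 Hz = 125 ms of audio per websocket exchange;
    # at the 48 kHz streaming rate the same chunk corresponds to 6000 samples.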

    st.title("hertz-dev webrtc demo!")
    st.markdown("""
    Welcome to the audio processing interface! Here you can talk live with hertz.

    - Process audio in real-time through your microphone
    - Adjust various temperature parameters for inference
    - Test your microphone with sound check mode
    - Enable/disable echo cancellation and noise suppression

    To begin, click the START button below and allow microphone access.
    """)

    audio_client = st.session_state.get("audio_client")
    if audio_client is None:
        audio_client = AudioClient(
            server_url=args.server,
            token_temp=args.token_temp,
            categorical_temp=args.categorical_temp,
            gaussian_temp=args.gaussian_temp,
        )
        st.session_state.audio_client = audio_client
    with st.sidebar:
        st.markdown("## Inference Settings")
        token_temp_default = args.token_temp if args.token_temp is not None else 0.8
        token_temp = st.slider("Token Temperature", 0.05, 2.0, token_temp_default, step=0.05)
        categorical_temp_default = args.categorical_temp if args.categorical_temp is not None else 0.4
        categorical_temp = st.slider("Categorical Temperature", 0.01, 1.0, categorical_temp_default, step=0.01)
        gaussian_temp_default = args.gaussian_temp if args.gaussian_temp is not None else 0.1
        gaussian_temp = st.slider("Gaussian Temperature", 0.01, 1.0, gaussian_temp_default, step=0.01)
        if st.button("Set Temperatures"):
            response_message = audio_client.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
            st.write(response_message)

        st.markdown("## Microphone Settings")
        audio_client.sound_check = st.toggle("Sound Check (Echo)", value=False)
        echo_cancellation = st.toggle("Echo Cancellation*‡", value=False)
        noise_suppression = st.toggle("Noise Suppression*", value=False)
        st.markdown(r"\* *Restart stream to take effect*")
        st.markdown("‡ *May cause audio to cut out*")
    # Use a free STUN server from Google if --use_ice_servers is given
    # (found in get_ice_servers() at https://github.com/whitphx/streamlit-webrtc/blob/main/sample_utils/turn.py)
    rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]} if args.use_ice_servers else None
    audio_config = {"echoCancellation": echo_cancellation, "noiseSuppression": noise_suppression}
    webrtc_streamer(
        key="streamer",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=rtc_configuration,
        media_stream_constraints={"audio": audio_config, "video": False},
        queued_audio_frames_callback=audio_client.queued_audio_frames_callback,
        on_audio_ended=audio_client.stop,
        async_processing=True,
    )
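
    # To run (assuming this file is saved as client.py; Streamlit forwards
    # arguments after `--` to the script rather than consuming them itself):
    #   streamlit run client.py -- --server ws://localhost:8000 --use_ice_servers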