import gradio as gr
import numpy as np
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
import time
import spaces
from nemo.collections.asr.models import ASRModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
### Variables ###
# Read the Hugging Face access token from the environment
HF_TOKEN = os.environ.get("HF_TOKEN", None)
SAMPLE_RATE = 16000 # Hz
MAX_AUDIO_SECONDS = 40 # won't try to transcribe if longer than this
DESCRIPTION = '''
<div>
<h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
<p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response. </p>
<p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Instruct</b></a> as the large language model (LLM), and <a href="https://huggingface.co/kakao-enterprise/vits-ljs"><b>VITS-ljs by Kakao Enterprise</b></a> for text-to-speech (TTS).</p>
<p>This demo accepts audio inputs up to 40 seconds long. Transcription and responses are limited to the English language.</p>
<p>The LLM's max_new_tokens, temperature, and top_p are set to 512, 0.6, and 0.9, respectively.</p>
</div>
'''
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://i.ibb.co/S35q17Q/My-Alexa-Logo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
</div>
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
### ASR model ###
canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
canary_model.eval()
# Make sure the beam size is always 1 for consistency
canary_model.change_decoding_strategy(None)
decoding_cfg = canary_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decoding_cfg)
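# Note: beam size 1 amounts to greedy decoding, which keeps transcripts
# deterministic across runs; larger beams may trade speed for a small accuracy gain.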
### LLM model ###
# Load the tokenizer and model
llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")
if llm_tokenizer.pad_token is None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token
# Stop generation at either the EOS token or Llama 3's end-of-turn token
terminators = [
    llm_tokenizer.eos_token_id,
    llm_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
### TTS model ###
pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)
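# The TTS pipeline returns a dict holding an "audio" numpy array and its
# "sampling_rate" (22.05 kHz for this LJSpeech-trained VITS model); both are
# unpacked in voice_player() below.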
### Start of functions ###
def convert_audio(audio_filepath, tmpdir, utt_id):
    """
    Converts the input audio to a mono-channel 16 kHz wav file.
    Raises an error, without converting, if the audio is too long.
    Returns the output filename and the duration in seconds.
    """
    data, sr = librosa.load(audio_filepath, sr=None, mono=True)
    duration = librosa.get_duration(y=data, sr=sr)
    if duration > MAX_AUDIO_SECONDS:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )
    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
    out_filename = os.path.join(tmpdir, utt_id + '.wav')
    # Save the converted audio
    sf.write(out_filename, data, SAMPLE_RATE)
    return out_filename, duration
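# Example (hypothetical file): convert_audio("question.mp3", "/tmp/xyz", "utt-1")
# would return ("/tmp/xyz/utt-1.wav", 3.2) for a 3.2-second clip.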
def transcribe(audio_filepath):
    """
    Transcribes a converted audio file using the ASR model.
    The task is set to English transcription with punctuation and capitalization.
    Returns the transcribed text as a string.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
        # Make a manifest file and save it
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",
            "target_lang": "en",
            "taskname": "asr",
            "pnc": "yes",  # punctuation and capitalization
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
        with open(manifest_filepath, 'w') as fout:
            line = json.dumps(manifest_data)
            fout.write(line + '\n')
        # Call transcribe, passing in the manifest filepath
        output_text = canary_model.transcribe(manifest_filepath)[0]
    return output_text
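# Note: Canary is multitask. Per its model card, setting a different "target_lang"
# (with taskname "s2t_translation") should yield speech translation instead,
# though this demo only uses English ASR.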
def add_message(history, message):
    """
    Adds the input message to the chat history.
    Returns the updated chat history.
    """
    history.append((message, None))
    return history
def bot(history, message):
    """
    Gets the bot's response and streams it into the chatbot.
    Yields the updated chat history as the answer is typed out.
    """
    response = bot_response(message, history)
    # Drop the first two lines of the raw output (header/blank line left by the chat template)
    lines = response.split("\n")
    complete_lines = '\n'.join(lines[2:])
    answer = ""
    # Reveal the answer character by character to simulate typing
    for character in complete_lines:
        answer += character
        new_tuple = list(history[-1])
        new_tuple[1] = answer
        history[-1] = tuple(new_tuple)
        time.sleep(0.01)
        yield history
@spaces.GPU()
def bot_response(message, history):
    """
    Generates a response using the LLM.
    Uses max_new_tokens=512, temperature=0.6, and top_p=0.9.
    Returns the generated response as a string.
    """
    conversation = []
    # The last history entry is the current message with no reply yet, so skip it here
    for user, assistant in history[:-1]:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    input_ids = llm_tokenizer.apply_chat_template(conversation, return_tensors="pt").to(llama3_model.device)
    outputs = llama3_model.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        pad_token_id=llm_tokenizer.pad_token_id,
    )
    # Keep only the newly generated tokens, dropping the echoed prompt
    out = outputs[0][input_ids.shape[-1]:]
    return llm_tokenizer.decode(out, skip_special_tokens=True)
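# For true token-by-token streaming (bot() above replays the finished string
# character by character), transformers' TextIteratorStreamer could be used
# instead. A minimal sketch, assuming the same model and tokenizer as above:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=llama3_model.generate,
#          kwargs=dict(input_ids=input_ids, max_new_tokens=512, streamer=streamer)).start()
#   for new_text in streamer:
#       ...  # forward each chunk to the UI as it arrives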
@spaces.GPU()
def voice_player(history):
    """
    Reads the latest bot response aloud using the TTS model.
    Returns an audio player loaded with the generated speech.
    """
    _, text = history[-1]
    text = text.replace("*", "")  # Temporary fix so the TTS does not read out the asterisks of bold text
    voice = pipe(text)
    voice = gr.Audio(
        value=(voice["sampling_rate"], voice["audio"].squeeze()),
        type="numpy",
        autoplay=True,
        label="MyAlexa Response",
        show_label=True,
        visible=True
    )
    return voice
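# Note: gr.Audio with type="numpy" expects its value as a (sample_rate, data)
# tuple, which is why the pipeline output is unpacked that way above.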
### End of functions ###
### Interface using Blocks ###
with gr.Blocks(
    title="MyAlexa",
    css="""
    textarea { font-size: 18px;}
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
) as demo:
    gr.HTML(DESCRIPTION)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        placeholder=PLACEHOLDER,
        label='MyAlexa'
    )
    with gr.Row():
        with gr.Column():
            gr.HTML(
                "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
            )
            audio_file = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath"
            )
        with gr.Column():
            gr.HTML("<p><b>Step 2:</b> Submit your recorded or uploaded audio as input and wait for MyAlexa's response.</p>")
            submit_button = gr.Button(
                value="Submit audio",
                variant="primary"
            )
    chat_input = gr.Textbox(  # Shows the transcribed text
        label="Transcribed text:",
        interactive=False,
        placeholder="Transcribed text will appear here.",
        elem_id="chat_input",
        visible=False  # set to True to see the processing time of the ASR transcription
    )
    gr.HTML("<p><b>[Optional]:</b> Replay MyAlexa's voice response.</p>")
    out_audio = gr.Audio(  # Audio player for the generated response
        value=None,
        label="Response Audio Player",
        show_label=True,
        visible=False  # set to True to see the processing time of the first TTS audio generation
    )
    # Event chain: new transcription -> message added to chat -> bot replies -> reply is spoken
    chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot")
    bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response_in_chatbot")
    voice_msg = bot_msg.then(voice_player, chatbot, out_audio, api_name="bot_response_voice_player")
    submit_button.click(
        fn=transcribe,
        inputs=[audio_file],
        outputs=[chat_input]
    )
### Queue and launch the demo ###
demo.queue()
if __name__ == "__main__":
    demo.launch()