VanYsa committed on
Commit 0af41c6 · 1 Parent(s): 15a2b26

added tts to whole app.py

Files changed (1)
  1. app.py +251 -32
app.py CHANGED
@@ -1,49 +1,268 @@
  import torch

  from transformers import pipeline

- import numpy as np
- import gradio as gr

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)

- # Inference
- def generate_audio(text):

-     output = pipe(text)
-     output = gr.Audio(value = (output["sampling_rate"], output["audio"].squeeze()), type="numpy", autoplay=True, label="Response Voice Player", show_label=True,
  visible=True)
-
-     ###############language = "english"
-     return output
-
- css = """
- #container{
-     margin: 0 auto;
-     max-width: 80rem;
- }
- #intro{
-     max-width: 100%;
-     text-align: center;
-     margin: 0 auto;
- }
- """

- # Gradio blocks demo
- with gr.Blocks(css=css) as demo_blocks:

-     with gr.Row():
-         with gr.Column():
-             inp_text = gr.Textbox(label="Input Text", info="What sentence would you like to synthesise?")
-             btn = gr.Button("Generate Audio!")

-         #"Enter the text you would like to synthesise into speech. Amazing! One plus one is equal to two. \n The quick brown fox jumps over the lazy dog. \n 1. Mangoes \n 2. Fruits"
-         with gr.Column():
-             out_audio = gr.Audio(value = None, label="Response Voice Player", show_label=True, visible=False)

-     btn.click(generate_audio, [inp_text], out_audio)


- demo_blocks.queue().launch()
+ import gradio as gr
+ import numpy as np
+ import json
+ import librosa
+ import os
+ import soundfile as sf
+ import tempfile
+ import uuid
+ import transformers
  import torch
+ import time
+ import spaces
+
+ from nemo.collections.asr.models import ASRModel
+
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer

  from transformers import pipeline

+ # Read the Hugging Face token from the environment
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+
+ SAMPLE_RATE = 16000 # Hz
+ MAX_AUDIO_SECONDS = 40 # won't try to transcribe if longer than this
+ DESCRIPTION = '''
+ <div>
+ <h1 style='text-align: center'>MyAlexa: Voice Chat Assistant</h1>
+ <p style='text-align: center'>MyAlexa is a demo of a voice chat assistant with chat logs that accepts audio input and outputs an AI response.</p>
+ <p>This space uses <a href="https://huggingface.co/nvidia/canary-1b"><b>NVIDIA Canary 1B</b></a> for Automatic Speech Recognition (ASR), <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama 3 8B Instruct</b></a> for the large language model (LLM) and <a href="https://huggingface.co/kakao-enterprise/vits-ljs"><b>VITS-ljs by Kakao Enterprise</b></a> for text-to-speech (TTS).</p>
+ <p>This demo accepts audio inputs no more than 40 seconds long.</p>
+ <p>Transcription and responses are limited to the English language.</p>
+ </div>
+ '''
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+ <img src="./MyAlexaLogo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
+ <p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
+ </div>
+ """
+ # PLACEHOLDER = """
+ # <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+ # <img src="https://i.ibb.co/S35q17Q/My-Alexa-Logo.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
+ # <p style="font-size: 28px; margin-bottom: 2px; opacity: 0.65;">What's on your mind?</p>
+ # </div>
+ # """

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ ### ASR model
+ canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
+ canary_model.eval()
+ # make sure beam size always 1 for consistency
+ canary_model.change_decoding_strategy(None)
+ decoding_cfg = canary_model.cfg.decoding
+ decoding_cfg.beam.beam_size = 1
+ canary_model.change_decoding_strategy(decoding_cfg)
+
+ ### LLM model
+ # Load the tokenizer and model
+ llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+ llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto") # to("cuda:0")
+
+ if llm_tokenizer.pad_token is None:
+     llm_tokenizer.pad_token = llm_tokenizer.eos_token
+
+ terminators = [
+     llm_tokenizer.eos_token_id,
+     llm_tokenizer.convert_tokens_to_ids("<|eot_id|>")
+ ]
+
+ ### TTS model
  pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)

+ def convert_audio(audio_filepath, tmpdir, utt_id):
+     """
+     Convert all files to mono-channel 16 kHz WAV files.
+     Do not convert, and raise an error, if the audio is too long.
+     Returns the output filename and duration.
+     """
+
+     data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+
+     duration = librosa.get_duration(y=data, sr=sr)
+
+     if duration > MAX_AUDIO_SECONDS:
+         raise gr.Error(
+             f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio. "
+             "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+             "(click on the scissors icon to start trimming audio)."
+         )
+
+     if sr != SAMPLE_RATE:
+         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+
+     out_filename = os.path.join(tmpdir, utt_id + '.wav')
+
+     # save output audio
+     sf.write(out_filename, data, SAMPLE_RATE)
+
+     return out_filename, duration
+
+ def transcribe(audio_filepath):
+     """
+     Transcribes a converted audio file.
+     Set to the English language with punctuation.
+     Returns the transcribed text as a string.
+     """

+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
+
+         # make manifest file and save
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "source_lang": "en",
+             "target_lang": "en",
+             "taskname": "asr",
+             "pnc": "yes",
+             "answer": "predict",
+             "duration": str(duration),
+         }
+
+         manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
+
+         with open(manifest_filepath, 'w') as fout:
+             line = json.dumps(manifest_data)
+             fout.write(line + '\n')
+
+         # call transcribe, passing in manifest filepath
+         output_text = canary_model.transcribe(manifest_filepath)[0]
+
+     return output_text
+
+ def add_message(history, message):
+     """
+     Adds the input message to the chatbot.
+     Returns the updated chatbot history.
+     """
+     history.append((message, None))
+     return history
+
+ def bot(history, message):
+     """
+     Gets the bot's response and places the user and bot messages in the chatbot.
+     Returns the appended chatbot history.
+     """
+     response = bot_response(message, history)
+     lines = response.split("\n")
+     complete_lines = '\n'.join(lines[2:])
+     answer = ""
+     for character in complete_lines:
+         answer += character
+         new_tuple = list(history[-1])
+         new_tuple[1] = answer
+         history[-1] = tuple(new_tuple)
+         time.sleep(0.05)
+         yield history
+     #return history
+
+ @spaces.GPU()
+ def bot_response(message, history):
+     """
+     Generates a response using the llama3-8b model.
+     Sets max_new_tokens = 100, temperature = 0.6, and top_p = 0.9.
+     Returns the generated response as a string.
+     """
+     conversation = []
+     for user, assistant in history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = llm_tokenizer.apply_chat_template(conversation, return_tensors="pt").to(llama3_model.device)
+
+     outputs = llama3_model.generate(
+         input_ids,
+         max_new_tokens = 100,
+         eos_token_id = terminators,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+         pad_token_id=llm_tokenizer.pad_token_id,
+     )
+
+     out = outputs[0][input_ids.shape[-1]:]
+
+     return llm_tokenizer.decode(out, skip_special_tokens=True)
+
+
+ def voice_player(history):
+     """
+     Plays the generated response using the VITS-ljs model.
+     Returns the audio player with the generated response.
+     """
+     _, text = history[-1]
+     voice = pipe(text)
+     voice = gr.Audio(value = (voice["sampling_rate"], voice["audio"].squeeze()), type="numpy", autoplay=True, label="MyAlexa Response", show_label=True,
  visible=True)
+     return voice
+
+
+ with gr.Blocks(
+     title="MyAlexa",
+     css="""
+     textarea { font-size: 18px;}
+     """,
+     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg) # make text slightly bigger (default is text_md)
+ ) as demo:
+
+     gr.HTML(DESCRIPTION)
+     chatbot = gr.Chatbot(
+         [],
+         elem_id="chatbot",
+         bubble_full_width=False,
+         placeholder=PLACEHOLDER,
+         label='MyAlexa'
+     )
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(
+                 "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
+             )
+
+             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+
+         with gr.Column():

+             gr.HTML("<p><b>Step 2:</b> Submit your recorded or uploaded audio as input and wait for MyAlexa's response.</p>")

+             submit_button = gr.Button(
+                 value="Submit audio",
+                 variant="primary"
+             )

+             chat_input = gr.Textbox(
+                 label="Transcribed text:",
+                 interactive=False,
+                 placeholder="Transcribed text will appear here.",
+                 elem_id="chat_input",
+                 visible=True # set to True to see processing time of asr transcription
+             )

+             out_audio = gr.Audio(
+                 value = None,
+                 label="Response Voice Player",
+                 show_label=True,
+                 visible=True # set to True to see processing time of tts audio generation
+             )

+     chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot], api_name="add_message_in_chatbot")
+     bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response_in_chatbot")
+     voice_msg = bot_msg.then(voice_player, chatbot, out_audio, api_name="bot_response_voice_player")
+
+     submit_button.click(
+         fn=transcribe,
+         inputs = [audio_file],
+         outputs = [chat_input]
+     )

+ demo.queue()
+ if __name__ == "__main__":
+     demo.launch()
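
For reference, the TTS step this commit adds can be exercised on its own. The sketch below is a hypothetical standalone check (not part of the commit): it loads the same kakao-enterprise/vits-ljs pipeline that app.py uses and writes the synthesized speech to a WAV file with soundfile, mirroring what voice_player() hands to the gr.Audio player.

```python
# tts_smoke_test.py: hypothetical standalone check, not part of this commit.
# Mirrors the TTS step added in app.py: synthesize text with VITS-ljs and save it.
import soundfile as sf
import torch
from transformers import pipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs", device=device)

output = pipe("Hello, this is MyAlexa speaking.")
# The pipeline returns a dict with a float waveform array and its sampling rate;
# squeeze() drops the leading channel dimension, as app.py does before gr.Audio.
sf.write("response.wav", output["audio"].squeeze(), output["sampling_rate"])
print("Wrote response.wav at", output["sampling_rate"], "Hz")
```

Inside the Space itself, the same pipeline output is passed to gr.Audio as a (sampling_rate, waveform) tuple with autoplay=True, which is what makes the response play back automatically after voice_player runs.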