hanAlex committed on
Commit
5dfaf9b
·
verified ·
1 Parent(s): 543f710

Upload 2 files

Files changed (2)
  1. audio_process.py +93 -0
  2. web_demo.py +267 -0
audio_process.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import librosa
+ import soundfile as sf
+ import numpy as np
+ from pathlib import Path
+ import io
+
+
+ # Split the audio stream at silence points to prevent the playback stuttering
+ # caused by AAC encoder frame padding when streaming audio through Gradio audio components.
+ class AudioStreamProcessor:
+     def __init__(self, sr=22050, min_silence_duration=0.1, threshold_db=-40):
+         self.sr = sr
+         self.min_silence_duration = min_silence_duration
+         self.threshold_db = threshold_db
+         self.buffer = np.array([])
+
+     def process(self, audio_data, last=False):
+         """
+         Append audio data to the buffer and try to emit a complete segment.
+         params:
+             audio_data: audio samples as a numpy array
+             last: whether this is the last chunk of data
+         returns:
+             WAV-encoded bytes of the processed segment, or None if no split point is found
+         """
+         # Append the new data to the buffer
+         self.buffer = np.concatenate([self.buffer, audio_data]) if len(self.buffer) > 0 else audio_data
+
+         if last:
+             result = self.buffer
+             self.buffer = np.array([])
+             return self._to_wav_bytes(result)
+
+         # Look for a silence boundary to split on
+         split_point = self._find_silence_boundary(self.buffer)
+
+         if split_point is not None:
+             # Extend the split point to the end of the silence segment
+             silence_end = self._find_silence_end(split_point)
+             result = self.buffer[:silence_end]
+             self.buffer = self.buffer[silence_end:]
+             return self._to_wav_bytes(result)
+
+         return None
+
+     def _find_silence_boundary(self, audio):
+         """
+         Find the starting point of a silence segment in the audio.
+         """
+         # Convert the amplitude envelope to decibels
+         db = librosa.amplitude_to_db(np.abs(audio), ref=np.max)
+
+         # Samples below the silence threshold
+         silence_points = np.where(db < self.threshold_db)[0]
+
+         if len(silence_points) == 0:
+             return None
+
+         # Minimum number of consecutive silent samples
+         min_silence_samples = int(self.min_silence_duration * self.sr)
+
+         # Search backwards for the start of a continuous silence segment
+         for i in range(len(silence_points) - min_silence_samples, -1, -1):
+             if i < 0:
+                 break
+             if np.all(np.diff(silence_points[i:i + min_silence_samples]) == 1):
+                 return silence_points[i]
+
+         return None
+
+     def _find_silence_end(self, start_point):
+         """
+         Find the end point of the silence segment that starts at start_point.
+         """
+         db = librosa.amplitude_to_db(np.abs(self.buffer[start_point:]), ref=np.max)
+         # Indices where the signal rises above the threshold again
+         non_silence_points = np.where(db >= self.threshold_db)[0]
+
+         if len(non_silence_points) == 0:
+             return len(self.buffer)
+
+         return start_point + non_silence_points[0]
+
+     def _to_wav_bytes(self, audio_data):
+         """
+         Encode an audio segment as WAV bytes.
+         """
+         wav_buffer = io.BytesIO()
+         sf.write(wav_buffer, audio_data, self.sr, format='WAV')
+         return wav_buffer.getvalue()
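For readers trying the class outside the demo, here is a minimal usage sketch (not part of this commit): it pushes synthetic chunks through AudioStreamProcessor and collects the WAV-encoded segments it emits. The 440 Hz tone, the chunk layout, and the 0.3 s silence are made-up illustration values; AudioStreamProcessor, its process() method, and the 22050 Hz default sample rate come from audio_process.py above.

```python
import numpy as np

from audio_process import AudioStreamProcessor

sr = 22050
processor = AudioStreamProcessor(sr=sr)

# Made-up input: 1 s of a 440 Hz tone followed by 0.3 s of silence, twice.
tone = (0.5 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr)).astype(np.float32)
silence = np.zeros(int(0.3 * sr), dtype=np.float32)
stream = [tone, silence, tone, silence]

segments = []
for i, chunk in enumerate(stream):
    wav_bytes = processor.process(chunk, last=(i == len(stream) - 1))
    if wav_bytes:  # None until a long-enough silence boundary has been buffered
        segments.append(wav_bytes)

print(f"emitted {len(segments)} WAV segment(s)")
```

Each element of segments is a complete WAV file as bytes, which is the form the streaming Gradio audio component in web_demo.py consumes.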
web_demo.py ADDED
@@ -0,0 +1,267 @@
+ import json
+ import os.path
+ import tempfile
+ import sys
+ import re
+ import uuid
+ import requests
+ from argparse import ArgumentParser
+
+ import torchaudio
+ from transformers import WhisperFeatureExtractor, AutoTokenizer
+ from speech_tokenizer.modeling_whisper import WhisperVQEncoder
+
+ sys.path.insert(0, "./cosyvoice")
+ sys.path.insert(0, "./third_party/Matcha-TTS")
+
+ from speech_tokenizer.utils import extract_speech_token
+
+ import gradio as gr
+ import torch
+
+ audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
+
+ from flow_inference import AudioDecoder
+ from audio_process import AudioStreamProcessor
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--host", type=str, default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=8888)
+     parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
+     parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
+     parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
+     args = parser.parse_args()
+
+     flow_config = os.path.join(args.flow_path, "config.yaml")
+     flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
+     hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
+     glm_tokenizer = None
+     device = "cuda"
+     audio_decoder: AudioDecoder = None
+     whisper_model, feature_extractor = None, None
+
+     def initialize_fn():
+         global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
+         if audio_decoder is not None:
+             return
+
+         # GLM tokenizer
+         glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+
+         # Flow & HiFT decoder
+         audio_decoder = AudioDecoder(config_path=flow_config, flow_ckpt_path=flow_checkpoint,
+                                      hift_ckpt_path=hift_checkpoint,
+                                      device=device)
+
+         # Speech tokenizer
+         whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
+         feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
+
+     def clear_fn():
+         return [], [], '', '', '', None, None
+
+     def inference_fn(
+             temperature: float,
+             top_p: float,
+             max_new_token: int,
+             input_mode,
+             audio_path: str | None,
+             input_text: str | None,
+             history: list[dict],
+             previous_input_tokens: str,
+             previous_completion_tokens: str,
+     ):
+         if input_mode == "audio":
+             assert audio_path is not None
+             history.append({"role": "user", "content": {"path": audio_path}})
+             audio_tokens = extract_speech_token(
+                 whisper_model, feature_extractor, [audio_path]
+             )[0]
+             if len(audio_tokens) == 0:
+                 raise gr.Error("No audio tokens extracted")
+             audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
+             audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
+             user_input = audio_tokens
+             system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
+         else:
+             assert input_text is not None
+             history.append({"role": "user", "content": input_text})
+             user_input = input_text
+             system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."
+
+         # Gather history into the prompt
+         inputs = previous_input_tokens + previous_completion_tokens
+         inputs = inputs.strip()
+         if "<|system|>" not in inputs:
+             inputs += f"<|system|>\n{system_prompt}"
+         inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
+
+         with torch.no_grad():
+             response = requests.post(
+                 "http://localhost:10000/generate_stream",
+                 data=json.dumps({
+                     "prompt": inputs,
+                     "temperature": temperature,
+                     "top_p": top_p,
+                     "max_new_tokens": max_new_token,
+                 }),
+                 stream=True
+             )
+             text_tokens, audio_tokens = [], []
+             audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
+             end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
+             complete_tokens = []
+             prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
+             flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
+             this_uuid = str(uuid.uuid4())
+             tts_speechs = []
+             tts_mels = []
+             prev_mel = None
+             is_finalize = False
+             block_size_list = [25, 50, 100, 150, 200]
+             block_size_idx = 0
+             block_size = block_size_list[block_size_idx]
+             audio_processor = AudioStreamProcessor()
+             for chunk in response.iter_lines():
+                 token_id = json.loads(chunk)["token_id"]
+                 if token_id == end_token_id:
+                     is_finalize = True
+                 if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
+                     if block_size_idx < len(block_size_list) - 1:
+                         block_size_idx += 1
+                         block_size = block_size_list[block_size_idx]
+                     tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
+
+                     if prev_mel is not None:
+                         prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
+
+                     tts_speech, tts_mel = audio_decoder.token2wav(tts_token, uuid=this_uuid,
+                                                                   prompt_token=flow_prompt_speech_token.to(device),
+                                                                   prompt_feat=prompt_speech_feat.to(device),
+                                                                   finalize=is_finalize)
+                     prev_mel = tts_mel
+
+                     audio_bytes = audio_processor.process(tts_speech.clone().cpu().numpy()[0], last=is_finalize)
+
+                     tts_speechs.append(tts_speech.squeeze())
+                     tts_mels.append(tts_mel)
+                     if audio_bytes:
+                         yield history, inputs, '', '', audio_bytes, None
+                     flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
+                     audio_tokens = []
+                 if not is_finalize:
+                     complete_tokens.append(token_id)
+                     if token_id >= audio_offset:
+                         audio_tokens.append(token_id - audio_offset)
+                     else:
+                         text_tokens.append(token_id)
+         tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
+         complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+             torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
+             history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
+             history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, ignore_special_tokens=False)})
+         yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
+
+     def update_input_interface(input_mode):
+         if input_mode == "audio":
+             return [gr.update(visible=True), gr.update(visible=False)]
+         else:
+             return [gr.update(visible=False), gr.update(visible=True)]
+
+     # Create the Gradio interface
+     with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
+         with gr.Row():
+             temperature = gr.Number(
+                 label="Temperature",
+                 value=0.2
+             )
+
+             top_p = gr.Number(
+                 label="Top p",
+                 value=0.8
+             )
+
+             max_new_token = gr.Number(
+                 label="Max new tokens",
+                 value=2000,
+             )
+
+         chatbot = gr.Chatbot(
+             elem_id="chatbot",
+             bubble_full_width=False,
+             type="messages",
+             scale=1,
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
+                 audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
+                 text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)
+
+             with gr.Column():
+                 submit_btn = gr.Button("Submit")
+                 reset_btn = gr.Button("Clear")
+                 output_audio = gr.Audio(label="Play", streaming=True,
+                                         autoplay=True, show_download_button=False)
+                 complete_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True)
+
+         gr.Markdown("""## Debug Info""")
+         with gr.Row():
+             input_tokens = gr.Textbox(
+                 label="Input Tokens",
+                 interactive=False,
+             )
+
+             completion_tokens = gr.Textbox(
+                 label="Completion Tokens",
+                 interactive=False,
+             )
+
+             detailed_error = gr.Textbox(
+                 label="Detailed Error",
+                 interactive=False,
+             )
+
+         history_state = gr.State([])
+
+         respond = submit_btn.click(
+             inference_fn,
+             inputs=[
+                 temperature,
+                 top_p,
+                 max_new_token,
+                 input_mode,
+                 audio,
+                 text_input,
+                 history_state,
+                 input_tokens,
+                 completion_tokens,
+             ],
+             outputs=[history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]
+         )
+
+         respond.then(lambda s: s, [history_state], chatbot)
+
+         reset_btn.click(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio])
+         input_mode.input(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]).then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])
+
+     initialize_fn()
+     # Launch the interface
+     demo.launch(
+         server_port=args.port,
+         server_name=args.host
+     )
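Note that web_demo.py does not run the language model itself: it POSTs the assembled prompt to http://localhost:10000/generate_stream with a JSON body containing prompt, temperature, top_p, and max_new_tokens, and then reads back one JSON object per line, each carrying a token_id. The stub below is a hypothetical stand-in that only illustrates that wire format; Flask and the placeholder token ids are assumptions of this sketch, and the real backend is the GLM-4-Voice model server, which is not part of this commit.

```python
# Hypothetical stand-in for the token server expected on localhost:10000.
# It mimics only the wire format used by web_demo.py (newline-delimited JSON,
# one {"token_id": ...} object per line); it does not produce real model output.
import json

from flask import Flask, Response, request

app = Flask(__name__)


@app.route("/generate_stream", methods=["POST"])
def generate_stream():
    # web_demo.py sends: {"prompt": ..., "temperature": ..., "top_p": ..., "max_new_tokens": ...}
    params = json.loads(request.data or b"{}")
    n = min(int(params.get("max_new_tokens", 3)), 3)

    def fake_tokens():
        for token_id in range(n):  # arbitrary placeholder ids, not real GLM-4-Voice tokens
            yield json.dumps({"token_id": token_id}) + "\n"

    return Response(fake_tokens(), mimetype="application/json")


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=10000)
```

With a real backend listening on that port, the demo itself is started with `python web_demo.py`, optionally overriding --host, --port, --flow-path, --model-path, and --tokenizer-path.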