xyfcc commited on
Commit
1f0ed9b
Β·
verified Β·
1 Parent(s): f631117

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +293 -0
app.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os.path
3
+ import tempfile
4
+ import sys
5
+ import re
6
+ import uuid
7
+ import requests
8
+ from argparse import ArgumentParser
9
+
10
+ import torchaudio
11
+ from transformers import WhisperFeatureExtractor, AutoTokenizer
12
+ from speech_tokenizer.modeling_whisper import WhisperVQEncoder
13
+
14
+
15
+ sys.path.insert(0, "./cosyvoice")
16
+ sys.path.insert(0, "./third_party/Matcha-TTS")
17
+
18
+ from speech_tokenizer.utils import extract_speech_token
19
+
20
+ import gradio as gr
21
+ import torch
22
+
23
+
24
+ audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
25
+
26
+ from flow_inference import AudioDecoder
27
+
28
+ use_local_interface = True
29
+ if use_local_interface :
30
+ from model_server import ModelWorker
31
+
32
+ if __name__ == "__main__":
33
+ parser = ArgumentParser()
34
+ parser.add_argument("--host", type=str, default="0.0.0.0")
35
+ parser.add_argument("--port", type=int, default="8888")
36
+ parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
37
+ parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
38
+ parser.add_argument("--tokenizer-path", type= str, default="THUDM/glm-4-voice-tokenizer")
39
+ args = parser.parse_args()
40
+ # --tokenizer-path /home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-tokenizer --model-path /home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-9b --flow-path /home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-decoder
41
+ # args.tokenizer_path = '/home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-tokenizer'
42
+ # args.model_path = '/home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-9b'
43
+ # args.flow_path = '/home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-decoder'
44
+
45
+ flow_config = os.path.join(args.flow_path, "config.yaml")
46
+ flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
47
+ hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
48
+ glm_tokenizer = None
49
+ device = "cuda"
50
+ audio_decoder: AudioDecoder = None
51
+ whisper_model, feature_extractor = None, None
52
+ worker = None
53
+
54
+ def initialize_fn():
55
+ global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
56
+ if audio_decoder is not None:
57
+ return
58
+
59
+ # GLM
60
+ glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
61
+
62
+ # Flow & Hift
63
+ audio_decoder = AudioDecoder(config_path=flow_config, flow_ckpt_path=flow_checkpoint,
64
+ hift_ckpt_path=hift_checkpoint,
65
+ device=device)
66
+
67
+ # Speech tokenizer
68
+ whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
69
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
70
+
71
+ global use_local_interface, worker
72
+ if use_local_interface :
73
+ model_path0 = 'THUDM/glm-4-voice-9b '
74
+ # dtype = 'bfloat16'
75
+ device0 = 'cuda:0'
76
+ worker = ModelWorker(model_path0,device0)
77
+
78
+ def clear_fn():
79
+ return [], [], '', '', '', None, None
80
+
81
+
82
+ def inference_fn(
83
+ temperature: float,
84
+ top_p: float,
85
+ max_new_token: int,
86
+ input_mode,
87
+ audio_path: str | None,
88
+ input_text: str | None,
89
+ history: list[dict],
90
+ previous_input_tokens: str,
91
+ previous_completion_tokens: str,
92
+ ):
93
+
94
+ if input_mode == "audio":
95
+ assert audio_path is not None
96
+ history.append({"role": "user", "content": {"path": audio_path}})
97
+ audio_tokens = extract_speech_token(
98
+ whisper_model, feature_extractor, [audio_path]
99
+ )[0]
100
+ if len(audio_tokens) == 0:
101
+ raise gr.Error("No audio tokens extracted")
102
+ audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
103
+ audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
104
+ user_input = audio_tokens
105
+ system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
106
+
107
+ else:
108
+ assert input_text is not None
109
+ history.append({"role": "user", "content": input_text})
110
+ user_input = input_text
111
+ system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."
112
+
113
+
114
+ # Gather history
115
+ inputs = previous_input_tokens + previous_completion_tokens
116
+ inputs = inputs.strip()
117
+ if "<|system|>" not in inputs:
118
+ inputs += f"<|system|>\n{system_prompt}"
119
+ inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
120
+
121
+ global use_local_interface , worker
122
+ with torch.no_grad():
123
+ if use_local_interface :
124
+ params = { "prompt": inputs,
125
+ "temperature": temperature,
126
+ "top_p": top_p,
127
+ "max_new_tokens": max_new_token, }
128
+ response = worker.generate_stream( params )
129
+
130
+ else :
131
+ response = requests.post(
132
+ "http://localhost:10000/generate_stream",
133
+ data=json.dumps({
134
+ "prompt": inputs,
135
+ "temperature": temperature,
136
+ "top_p": top_p,
137
+ "max_new_tokens": max_new_token,
138
+ }),
139
+ stream=True
140
+ )
141
+ text_tokens, audio_tokens = [], []
142
+ audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
143
+ end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
144
+ complete_tokens = []
145
+ prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
146
+ flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
147
+ this_uuid = str(uuid.uuid4())
148
+ tts_speechs = []
149
+ tts_mels = []
150
+ prev_mel = None
151
+ is_finalize = False
152
+ block_size = 10
153
+ # for chunk in response.iter_lines():
154
+ for chunk in response :
155
+ token_id = json.loads(chunk)["token_id"]
156
+ if token_id == end_token_id:
157
+ is_finalize = True
158
+ if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
159
+ block_size = 20
160
+ tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
161
+
162
+ if prev_mel is not None:
163
+ prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
164
+
165
+ tts_speech, tts_mel = audio_decoder.token2wav(tts_token, uuid=this_uuid,
166
+ prompt_token=flow_prompt_speech_token.to(device),
167
+ prompt_feat=prompt_speech_feat.to(device),
168
+ finalize=is_finalize)
169
+ prev_mel = tts_mel
170
+
171
+ tts_speechs.append(tts_speech.squeeze())
172
+ tts_mels.append(tts_mel)
173
+ yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
174
+ flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
175
+ audio_tokens = []
176
+ if not is_finalize:
177
+ complete_tokens.append(token_id)
178
+ if token_id >= audio_offset:
179
+ audio_tokens.append(token_id - audio_offset)
180
+ else:
181
+ text_tokens.append(token_id)
182
+ tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
183
+ complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
184
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
185
+ torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
186
+ history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
187
+ history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, ignore_special_tokens=False)})
188
+ yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
189
+
190
+
191
+ def update_input_interface(input_mode):
192
+ if input_mode == "audio":
193
+ return [gr.update(visible=True), gr.update(visible=False)]
194
+ else:
195
+ return [gr.update(visible=False), gr.update(visible=True)]
196
+
197
+
198
+ # Create the Gradio interface
199
+ with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
200
+ with gr.Row():
201
+ temperature = gr.Number(
202
+ label="Temperature",
203
+ value=0.2
204
+ )
205
+
206
+ top_p = gr.Number(
207
+ label="Top p",
208
+ value=0.8
209
+ )
210
+
211
+ max_new_token = gr.Number(
212
+ label="Max new tokens",
213
+ value=2000,
214
+ )
215
+
216
+ chatbot = gr.Chatbot(
217
+ elem_id="chatbot",
218
+ bubble_full_width=False,
219
+ type="messages",
220
+ scale=1,
221
+ )
222
+
223
+ with gr.Row():
224
+ with gr.Column():
225
+ input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
226
+ # audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
227
+ audio = gr.Audio(sources=["upload","microphone"], label="Input audio", type='filepath', show_download_button=True, visible=True)
228
+ # audio = gr.Audio(source="microphone", label="Input audio", type='filepath', show_download_button=True, visible=True)
229
+ text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)
230
+
231
+ with gr.Column():
232
+ submit_btn = gr.Button("Submit")
233
+ reset_btn = gr.Button("Clear")
234
+ output_audio = gr.Audio(label="Play", streaming=True,
235
+ autoplay=True, show_download_button=False)
236
+ complete_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True)
237
+
238
+
239
+
240
+ gr.Markdown("""## Debug Info""")
241
+ with gr.Row():
242
+ input_tokens = gr.Textbox(
243
+ label=f"Input Tokens",
244
+ interactive=False,
245
+ )
246
+
247
+ completion_tokens = gr.Textbox(
248
+ label=f"Completion Tokens",
249
+ interactive=False,
250
+ )
251
+
252
+ detailed_error = gr.Textbox(
253
+ label=f"Detailed Error",
254
+ interactive=False,
255
+ )
256
+
257
+ history_state = gr.State([])
258
+
259
+ respond = submit_btn.click(
260
+ inference_fn,
261
+ inputs=[
262
+ temperature,
263
+ top_p,
264
+ max_new_token,
265
+ input_mode,
266
+ audio,
267
+ text_input,
268
+ history_state,
269
+ input_tokens,
270
+ completion_tokens,
271
+ ],
272
+ outputs=[history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]
273
+ )
274
+
275
+ respond.then(lambda s: s, [history_state], chatbot)
276
+
277
+ reset_btn.click(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio])
278
+ input_mode.input(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]).then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])
279
+
280
+ initialize_fn()
281
+ # Launch the interface
282
+ demo.launch(
283
+ server_port=args.port,
284
+ server_name=args.host,
285
+ ssl_verify=False,
286
+ share=True
287
+ )
288
+
289
+ '''
290
+ server.launch(share=True)
291
+ https://1a9b77cb89ac33f546.gradio.live
292
+
293
+ '''