uto1125 commited on
Commit
28c311d
·
verified ·
1 Parent(s): 26911ee

Delete tools

Browse files
tools/__pycache__/api.cpython-310.pyc DELETED
Binary file (10.1 kB)
 
tools/__pycache__/commons.cpython-310.pyc DELETED
Binary file (1.49 kB)
 
tools/__pycache__/file.cpython-310.pyc DELETED
Binary file (2.99 kB)
 
tools/__pycache__/webui.cpython-310.pyc DELETED
Binary file (9.5 kB)
 
tools/api.py DELETED
@@ -1,440 +0,0 @@
1
- import base64
2
- import io
3
- import json
4
- import queue
5
- import random
6
- import sys
7
- import traceback
8
- import wave
9
- from argparse import ArgumentParser
10
- from http import HTTPStatus
11
- from pathlib import Path
12
- from typing import Annotated, Any, Literal, Optional
13
-
14
- import numpy as np
15
- import ormsgpack
16
- import pyrootutils
17
- import soundfile as sf
18
- import torch
19
- import torchaudio
20
- from baize.datastructures import ContentType
21
- from kui.asgi import (
22
- Body,
23
- FactoryClass,
24
- HTTPException,
25
- HttpRequest,
26
- HttpView,
27
- JSONResponse,
28
- Kui,
29
- OpenAPI,
30
- StreamResponse,
31
- )
32
- from kui.asgi.routing import MultimethodRoutes
33
- from loguru import logger
34
- from pydantic import BaseModel, Field, conint
35
-
36
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
37
-
38
- # from fish_speech.models.vqgan.lit_module import VQGAN
39
- from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
40
- from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
41
- from fish_speech.utils import autocast_exclude_mps
42
- from tools.commons import ServeReferenceAudio, ServeTTSRequest
43
- from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
44
- from tools.llama.generate import (
45
- GenerateRequest,
46
- GenerateResponse,
47
- WrappedGenerateResponse,
48
- launch_thread_safe_queue,
49
- )
50
- from tools.vqgan.inference import load_model as load_decoder_model
51
-
52
-
53
- def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
54
- buffer = io.BytesIO()
55
-
56
- with wave.open(buffer, "wb") as wav_file:
57
- wav_file.setnchannels(channels)
58
- wav_file.setsampwidth(bit_depth // 8)
59
- wav_file.setframerate(sample_rate)
60
-
61
- wav_header_bytes = buffer.getvalue()
62
- buffer.close()
63
- return wav_header_bytes
64
-
65
-
66
- # Define utils for web server
67
- async def http_execption_handler(exc: HTTPException):
68
- return JSONResponse(
69
- dict(
70
- statusCode=exc.status_code,
71
- message=exc.content,
72
- error=HTTPStatus(exc.status_code).phrase,
73
- ),
74
- exc.status_code,
75
- exc.headers,
76
- )
77
-
78
-
79
- async def other_exception_handler(exc: "Exception"):
80
- traceback.print_exc()
81
-
82
- status = HTTPStatus.INTERNAL_SERVER_ERROR
83
- return JSONResponse(
84
- dict(statusCode=status, message=str(exc), error=status.phrase),
85
- status,
86
- )
87
-
88
-
89
- def load_audio(reference_audio, sr):
90
- if len(reference_audio) > 255 or not Path(reference_audio).exists():
91
- audio_data = reference_audio
92
- reference_audio = io.BytesIO(audio_data)
93
-
94
- waveform, original_sr = torchaudio.load(
95
- reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
96
- )
97
-
98
- if waveform.shape[0] > 1:
99
- waveform = torch.mean(waveform, dim=0, keepdim=True)
100
-
101
- if original_sr != sr:
102
- resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
103
- waveform = resampler(waveform)
104
-
105
- audio = waveform.squeeze().numpy()
106
- return audio
107
-
108
-
109
- def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
110
- if enable_reference_audio and reference_audio is not None:
111
- # Load audios, and prepare basic info here
112
- reference_audio_content = load_audio(
113
- reference_audio, decoder_model.spec_transform.sample_rate
114
- )
115
-
116
- audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
117
- None, None, :
118
- ]
119
- audio_lengths = torch.tensor(
120
- [audios.shape[2]], device=decoder_model.device, dtype=torch.long
121
- )
122
- logger.info(
123
- f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
124
- )
125
-
126
- # VQ Encoder
127
- if isinstance(decoder_model, FireflyArchitecture):
128
- prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
129
-
130
- logger.info(f"Encoded prompt: {prompt_tokens.shape}")
131
- else:
132
- prompt_tokens = None
133
- logger.info("No reference audio provided")
134
-
135
- return prompt_tokens
136
-
137
-
138
- def decode_vq_tokens(
139
- *,
140
- decoder_model,
141
- codes,
142
- ):
143
- feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
144
- logger.info(f"VQ features: {codes.shape}")
145
-
146
- if isinstance(decoder_model, FireflyArchitecture):
147
- # VQGAN Inference
148
- return decoder_model.decode(
149
- indices=codes[None],
150
- feature_lengths=feature_lengths,
151
- )[0].squeeze()
152
-
153
- raise ValueError(f"Unknown model type: {type(decoder_model)}")
154
-
155
-
156
- routes = MultimethodRoutes(base_class=HttpView)
157
-
158
-
159
- def get_content_type(audio_format):
160
- if audio_format == "wav":
161
- return "audio/wav"
162
- elif audio_format == "flac":
163
- return "audio/flac"
164
- elif audio_format == "mp3":
165
- return "audio/mpeg"
166
- else:
167
- return "application/octet-stream"
168
-
169
-
170
- @torch.inference_mode()
171
- def inference(req: ServeTTSRequest):
172
-
173
- idstr: str | None = req.reference_id
174
- if idstr is not None:
175
- ref_folder = Path("references") / idstr
176
- ref_folder.mkdir(parents=True, exist_ok=True)
177
- ref_audios = list_files(
178
- ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
179
- )
180
- prompt_tokens = [
181
- encode_reference(
182
- decoder_model=decoder_model,
183
- reference_audio=audio_to_bytes(str(ref_audio)),
184
- enable_reference_audio=True,
185
- )
186
- for ref_audio in ref_audios
187
- ]
188
- prompt_texts = [
189
- read_ref_text(str(ref_audio.with_suffix(".lab")))
190
- for ref_audio in ref_audios
191
- ]
192
-
193
- else:
194
- # Parse reference audio aka prompt
195
- refs = req.references
196
- if refs is None:
197
- refs = []
198
- prompt_tokens = [
199
- encode_reference(
200
- decoder_model=decoder_model,
201
- reference_audio=ref.audio,
202
- enable_reference_audio=True,
203
- )
204
- for ref in refs
205
- ]
206
- prompt_texts = [ref.text for ref in refs]
207
-
208
- # LLAMA Inference
209
- request = dict(
210
- device=decoder_model.device,
211
- max_new_tokens=req.max_new_tokens,
212
- text=(
213
- req.text
214
- if not req.normalize
215
- else ChnNormedText(raw_text=req.text).normalize()
216
- ),
217
- top_p=req.top_p,
218
- repetition_penalty=req.repetition_penalty,
219
- temperature=req.temperature,
220
- compile=args.compile,
221
- iterative_prompt=req.chunk_length > 0,
222
- chunk_length=req.chunk_length,
223
- max_length=2048,
224
- prompt_tokens=prompt_tokens,
225
- prompt_text=prompt_texts,
226
- )
227
-
228
- response_queue = queue.Queue()
229
- llama_queue.put(
230
- GenerateRequest(
231
- request=request,
232
- response_queue=response_queue,
233
- )
234
- )
235
-
236
- if req.streaming:
237
- yield wav_chunk_header()
238
-
239
- segments = []
240
- while True:
241
- result: WrappedGenerateResponse = response_queue.get()
242
- if result.status == "error":
243
- raise result.response
244
- break
245
-
246
- result: GenerateResponse = result.response
247
- if result.action == "next":
248
- break
249
-
250
- with autocast_exclude_mps(
251
- device_type=decoder_model.device.type, dtype=args.precision
252
- ):
253
- fake_audios = decode_vq_tokens(
254
- decoder_model=decoder_model,
255
- codes=result.codes,
256
- )
257
-
258
- fake_audios = fake_audios.float().cpu().numpy()
259
-
260
- if req.streaming:
261
- yield (fake_audios * 32768).astype(np.int16).tobytes()
262
- else:
263
- segments.append(fake_audios)
264
-
265
- if req.streaming:
266
- return
267
-
268
- if len(segments) == 0:
269
- raise HTTPException(
270
- HTTPStatus.INTERNAL_SERVER_ERROR,
271
- content="No audio generated, please check the input text.",
272
- )
273
-
274
- fake_audios = np.concatenate(segments, axis=0)
275
- yield fake_audios
276
-
277
-
278
- async def inference_async(req: ServeTTSRequest):
279
- for chunk in inference(req):
280
- yield chunk
281
-
282
-
283
- async def buffer_to_async_generator(buffer):
284
- yield buffer
285
-
286
-
287
- @routes.http.post("/v1/tts")
288
- async def api_invoke_model(
289
- req: Annotated[ServeTTSRequest, Body(exclusive=True)],
290
- ):
291
- """
292
- Invoke model and generate audio
293
- """
294
-
295
- if args.max_text_length > 0 and len(req.text) > args.max_text_length:
296
- raise HTTPException(
297
- HTTPStatus.BAD_REQUEST,
298
- content=f"Text is too long, max length is {args.max_text_length}",
299
- )
300
-
301
- if req.streaming and req.format != "wav":
302
- raise HTTPException(
303
- HTTPStatus.BAD_REQUEST,
304
- content="Streaming only supports WAV format",
305
- )
306
-
307
- if req.streaming:
308
- return StreamResponse(
309
- iterable=inference_async(req),
310
- headers={
311
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
312
- },
313
- content_type=get_content_type(req.format),
314
- )
315
- else:
316
- fake_audios = next(inference(req))
317
- buffer = io.BytesIO()
318
- sf.write(
319
- buffer,
320
- fake_audios,
321
- decoder_model.spec_transform.sample_rate,
322
- format=req.format,
323
- )
324
-
325
- return StreamResponse(
326
- iterable=buffer_to_async_generator(buffer.getvalue()),
327
- headers={
328
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
329
- },
330
- content_type=get_content_type(req.format),
331
- )
332
-
333
-
334
- @routes.http.post("/v1/health")
335
- async def api_health():
336
- """
337
- Health check
338
- """
339
-
340
- return JSONResponse({"status": "ok"})
341
-
342
-
343
- def parse_args():
344
- parser = ArgumentParser()
345
- parser.add_argument(
346
- "--llama-checkpoint-path",
347
- type=str,
348
- default="checkpoints/fish-speech-1.4",
349
- )
350
- parser.add_argument(
351
- "--decoder-checkpoint-path",
352
- type=str,
353
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
354
- )
355
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
356
- parser.add_argument("--device", type=str, default="cuda")
357
- parser.add_argument("--half", action="store_true")
358
- parser.add_argument("--compile", action="store_true")
359
- parser.add_argument("--max-text-length", type=int, default=0)
360
- parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
361
- parser.add_argument("--workers", type=int, default=1)
362
-
363
- return parser.parse_args()
364
-
365
-
366
- # Define Kui app
367
- openapi = OpenAPI(
368
- {
369
- "title": "Fish Speech API",
370
- },
371
- ).routes
372
-
373
-
374
- class MsgPackRequest(HttpRequest):
375
- async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
376
- if self.content_type == "application/msgpack":
377
- return ormsgpack.unpackb(await self.body)
378
-
379
- raise HTTPException(
380
- HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
381
- headers={"Accept": "application/msgpack"},
382
- )
383
-
384
-
385
- app = Kui(
386
- routes=routes + openapi[1:], # Remove the default route
387
- exception_handlers={
388
- HTTPException: http_execption_handler,
389
- Exception: other_exception_handler,
390
- },
391
- factory_class=FactoryClass(http=MsgPackRequest),
392
- cors_config={},
393
- )
394
-
395
-
396
- if __name__ == "__main__":
397
-
398
- import uvicorn
399
-
400
- args = parse_args()
401
- args.precision = torch.half if args.half else torch.bfloat16
402
-
403
- logger.info("Loading Llama model...")
404
- llama_queue = launch_thread_safe_queue(
405
- checkpoint_path=args.llama_checkpoint_path,
406
- device=args.device,
407
- precision=args.precision,
408
- compile=args.compile,
409
- )
410
- logger.info("Llama model loaded, loading VQ-GAN model...")
411
-
412
- decoder_model = load_decoder_model(
413
- config_name=args.decoder_config_name,
414
- checkpoint_path=args.decoder_checkpoint_path,
415
- device=args.device,
416
- )
417
-
418
- logger.info("VQ-GAN model loaded, warming up...")
419
-
420
- # Dry run to check if the model is loaded correctly and avoid the first-time latency
421
- list(
422
- inference(
423
- ServeTTSRequest(
424
- text="Hello world.",
425
- references=[],
426
- reference_id=None,
427
- max_new_tokens=1024,
428
- chunk_length=200,
429
- top_p=0.7,
430
- repetition_penalty=1.2,
431
- temperature=0.7,
432
- emotion=None,
433
- format="wav",
434
- )
435
- )
436
- )
437
-
438
- logger.info(f"Warming up done, starting server at http://{args.listen}")
439
- host, port = args.listen.split(":")
440
- uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/auto_rerank.py DELETED
@@ -1,159 +0,0 @@
1
- import os
2
-
3
- os.environ["MODELSCOPE_CACHE"] = ".cache/"
4
-
5
- import string
6
- import time
7
- from threading import Lock
8
-
9
- import librosa
10
- import numpy as np
11
- import opencc
12
- import torch
13
- from faster_whisper import WhisperModel
14
-
15
- t2s_converter = opencc.OpenCC("t2s")
16
-
17
-
18
- def load_model(*, device="cuda"):
19
- model = WhisperModel(
20
- "medium",
21
- device=device,
22
- compute_type="float16",
23
- download_root="faster_whisper",
24
- )
25
- print("faster_whisper loaded!")
26
- return model
27
-
28
-
29
- @torch.no_grad()
30
- def batch_asr_internal(model: WhisperModel, audios, sr):
31
- resampled_audios = []
32
- for audio in audios:
33
-
34
- if isinstance(audio, np.ndarray):
35
- audio = torch.from_numpy(audio).float()
36
-
37
- if audio.dim() > 1:
38
- audio = audio.squeeze()
39
-
40
- assert audio.dim() == 1
41
- audio_np = audio.numpy()
42
- resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
43
- resampled_audios.append(resampled_audio)
44
-
45
- trans_results = []
46
-
47
- for resampled_audio in resampled_audios:
48
- segments, info = model.transcribe(
49
- resampled_audio,
50
- language=None,
51
- beam_size=5,
52
- initial_prompt="Punctuation is needed in any language.",
53
- )
54
- trans_results.append(list(segments))
55
-
56
- results = []
57
- for trans_res, audio in zip(trans_results, audios):
58
-
59
- duration = len(audio) / sr * 1000
60
- huge_gap = False
61
- max_gap = 0.0
62
-
63
- text = None
64
- last_tr = None
65
-
66
- for tr in trans_res:
67
- delta = tr.text.strip()
68
- if tr.id > 1:
69
- max_gap = max(tr.start - last_tr.end, max_gap)
70
- text += delta
71
- else:
72
- text = delta
73
-
74
- last_tr = tr
75
- if max_gap > 3.0:
76
- huge_gap = True
77
- break
78
-
79
- sim_text = t2s_converter.convert(text)
80
- results.append(
81
- {
82
- "text": sim_text,
83
- "duration": duration,
84
- "huge_gap": huge_gap,
85
- }
86
- )
87
-
88
- return results
89
-
90
-
91
- global_lock = Lock()
92
-
93
-
94
- def batch_asr(model, audios, sr):
95
- return batch_asr_internal(model, audios, sr)
96
-
97
-
98
- def is_chinese(text):
99
- return True
100
-
101
-
102
- def calculate_wer(text1, text2, debug=False):
103
- chars1 = remove_punctuation(text1)
104
- chars2 = remove_punctuation(text2)
105
-
106
- m, n = len(chars1), len(chars2)
107
-
108
- if m > n:
109
- chars1, chars2 = chars2, chars1
110
- m, n = n, m
111
-
112
- prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
113
- curr = [0] * (m + 1)
114
-
115
- for j in range(1, n + 1):
116
- curr[0] = j
117
- for i in range(1, m + 1):
118
- if chars1[i - 1] == chars2[j - 1]:
119
- curr[i] = prev[i - 1]
120
- else:
121
- curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
122
- prev, curr = curr, prev
123
-
124
- edits = prev[m]
125
- tot = max(len(chars1), len(chars2))
126
- wer = edits / tot
127
-
128
- if debug:
129
- print(" gt: ", chars1)
130
- print(" pred: ", chars2)
131
- print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
132
-
133
- return wer
134
-
135
-
136
- def remove_punctuation(text):
137
- chinese_punctuation = (
138
- " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
139
- '‛""„‟…‧﹏'
140
- )
141
- all_punctuation = string.punctuation + chinese_punctuation
142
- translator = str.maketrans("", "", all_punctuation)
143
- text_without_punctuation = text.translate(translator)
144
- return text_without_punctuation
145
-
146
-
147
- if __name__ == "__main__":
148
- model = load_model()
149
- audios = [
150
- librosa.load("44100.wav", sr=44100)[0],
151
- librosa.load("lengyue.wav", sr=44100)[0],
152
- ]
153
- print(np.array(audios[0]))
154
- print(batch_asr(model, audios, 44100))
155
-
156
- start_time = time.time()
157
- for _ in range(10):
158
- print(batch_asr(model, audios, 44100))
159
- print("Time taken:", time.time() - start_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/commons.py DELETED
@@ -1,35 +0,0 @@
1
- from typing import Annotated, Literal, Optional
2
-
3
- from pydantic import BaseModel, Field, conint
4
-
5
-
6
- class ServeReferenceAudio(BaseModel):
7
- audio: bytes
8
- text: str
9
-
10
-
11
- class ServeTTSRequest(BaseModel):
12
- text: str
13
- chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
14
- # Audio format
15
- format: Literal["wav", "pcm", "mp3"] = "wav"
16
- mp3_bitrate: Literal[64, 128, 192] = 128
17
- # References audios for in-context learning
18
- references: list[ServeReferenceAudio] = []
19
- # Reference id
20
- # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
21
- # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
22
- reference_id: str | None = None
23
- # Normalize text for en & zh, this increase stability for numbers
24
- normalize: bool = True
25
- mp3_bitrate: Optional[int] = 64
26
- opus_bitrate: Optional[int] = -1000
27
- # Balance mode will reduce latency to 300ms, but may decrease stability
28
- latency: Literal["normal", "balanced"] = "normal"
29
- # not usually used below
30
- streaming: bool = False
31
- emotion: Optional[str] = None
32
- max_new_tokens: int = 1024
33
- top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
34
- repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
35
- temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/download_models.py DELETED
@@ -1,55 +0,0 @@
1
- import os
2
-
3
- from huggingface_hub import hf_hub_download
4
-
5
-
6
- # Download
7
- def check_and_download_files(repo_id, file_list, local_dir):
8
- os.makedirs(local_dir, exist_ok=True)
9
- for file in file_list:
10
- file_path = os.path.join(local_dir, file)
11
- if not os.path.exists(file_path):
12
- print(f"{file} 不存在,从 Hugging Face 仓库下载...")
13
- hf_hub_download(
14
- repo_id=repo_id,
15
- filename=file,
16
- resume_download=True,
17
- local_dir=local_dir,
18
- local_dir_use_symlinks=False,
19
- )
20
- else:
21
- print(f"{file} 已存在,跳过下载。")
22
-
23
-
24
- # 1st
25
- repo_id_1 = "fishaudio/fish-speech-1.4"
26
- local_dir_1 = "./checkpoints/fish-speech-1.4"
27
- files_1 = [
28
- "model.pth",
29
- "README.md",
30
- "special_tokens_map.json",
31
- "tokenizer_config.json",
32
- "tokenizer.json",
33
- "config.json",
34
- "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
35
- ]
36
-
37
- # 3rd
38
- repo_id_3 = "fishaudio/fish-speech-1"
39
- local_dir_3 = "./"
40
- files_3 = [
41
- "ffmpeg.exe",
42
- "ffprobe.exe",
43
- ]
44
-
45
- # 4th
46
- repo_id_4 = "SpicyqSama007/fish-speech-packed"
47
- local_dir_4 = "./"
48
- files_4 = [
49
- "asr-label-win-x64.exe",
50
- ]
51
-
52
- check_and_download_files(repo_id_1, files_1, local_dir_1)
53
-
54
- check_and_download_files(repo_id_3, files_3, local_dir_3)
55
- check_and_download_files(repo_id_4, files_4, local_dir_4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/extract_model.py DELETED
@@ -1,21 +0,0 @@
1
- import click
2
- import torch
3
- from loguru import logger
4
-
5
-
6
- @click.command()
7
- @click.argument("model_path")
8
- @click.argument("output_path")
9
- def main(model_path, output_path):
10
- if model_path == output_path:
11
- logger.error("Model path and output path are the same")
12
- return
13
-
14
- logger.info(f"Loading model from {model_path}")
15
- state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
16
- torch.save(state_dict, output_path)
17
- logger.info(f"Model saved to {output_path}")
18
-
19
-
20
- if __name__ == "__main__":
21
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/file.py DELETED
@@ -1,125 +0,0 @@
1
- import base64
2
- from pathlib import Path
3
- from typing import Union
4
-
5
- from loguru import logger
6
- from natsort import natsorted
7
-
8
- AUDIO_EXTENSIONS = {
9
- ".mp3",
10
- ".wav",
11
- ".flac",
12
- ".ogg",
13
- ".m4a",
14
- ".wma",
15
- ".aac",
16
- ".aiff",
17
- ".aif",
18
- ".aifc",
19
- }
20
-
21
- VIDEO_EXTENSIONS = {
22
- ".mp4",
23
- ".avi",
24
- }
25
-
26
-
27
- def audio_to_bytes(file_path):
28
- if not file_path or not Path(file_path).exists():
29
- return None
30
- with open(file_path, "rb") as wav_file:
31
- wav = wav_file.read()
32
- return wav
33
-
34
-
35
- def read_ref_text(ref_text):
36
- path = Path(ref_text)
37
- if path.exists() and path.is_file():
38
- with path.open("r", encoding="utf-8") as file:
39
- return file.read()
40
- return ref_text
41
-
42
-
43
- def list_files(
44
- path: Union[Path, str],
45
- extensions: set[str] = None,
46
- recursive: bool = False,
47
- sort: bool = True,
48
- ) -> list[Path]:
49
- """List files in a directory.
50
-
51
- Args:
52
- path (Path): Path to the directory.
53
- extensions (set, optional): Extensions to filter. Defaults to None.
54
- recursive (bool, optional): Whether to search recursively. Defaults to False.
55
- sort (bool, optional): Whether to sort the files. Defaults to True.
56
-
57
- Returns:
58
- list: List of files.
59
- """
60
-
61
- if isinstance(path, str):
62
- path = Path(path)
63
-
64
- if not path.exists():
65
- raise FileNotFoundError(f"Directory {path} does not exist.")
66
-
67
- files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
68
-
69
- if sort:
70
- files = natsorted(files)
71
-
72
- return files
73
-
74
-
75
- def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
76
- """
77
- Load a Bert-VITS2 style filelist.
78
- """
79
-
80
- files = set()
81
- results = []
82
- count_duplicated, count_not_found = 0, 0
83
-
84
- LANGUAGE_TO_LANGUAGES = {
85
- "zh": ["zh", "en"],
86
- "jp": ["jp", "en"],
87
- "en": ["en"],
88
- }
89
-
90
- with open(path, "r", encoding="utf-8") as f:
91
- for line in f.readlines():
92
- splits = line.strip().split("|", maxsplit=3)
93
- if len(splits) != 4:
94
- logger.warning(f"Invalid line: {line}")
95
- continue
96
-
97
- filename, speaker, language, text = splits
98
- file = Path(filename)
99
- language = language.strip().lower()
100
-
101
- if language == "ja":
102
- language = "jp"
103
-
104
- assert language in ["zh", "jp", "en"], f"Invalid language {language}"
105
- languages = LANGUAGE_TO_LANGUAGES[language]
106
-
107
- if file in files:
108
- logger.warning(f"Duplicated file: {file}")
109
- count_duplicated += 1
110
- continue
111
-
112
- if not file.exists():
113
- logger.warning(f"File not found: {file}")
114
- count_not_found += 1
115
- continue
116
-
117
- results.append((file, speaker, languages, text))
118
-
119
- if count_duplicated > 0:
120
- logger.warning(f"Total duplicated files: {count_duplicated}")
121
-
122
- if count_not_found > 0:
123
- logger.warning(f"Total files not found: {count_not_found}")
124
-
125
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/llama/__pycache__/generate.cpython-310.pyc DELETED
Binary file (15.1 kB)
 
tools/llama/build_dataset.py DELETED
@@ -1,169 +0,0 @@
1
- import itertools
2
- import os
3
- import re
4
- from collections import defaultdict
5
- from functools import partial
6
- from multiprocessing import Pool
7
- from pathlib import Path
8
-
9
- import click
10
- import numpy as np
11
- from loguru import logger
12
- from tqdm import tqdm
13
-
14
- from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
- from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
- from tools.file import load_filelist
17
-
18
- # To avoid CPU overload
19
- os.environ["MKL_NUM_THREADS"] = "1"
20
- os.environ["OMP_NUM_THREADS"] = "1"
21
-
22
-
23
- def task_generator_folder(root: Path, text_extension: str):
24
- files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
25
- files = sorted(files)
26
-
27
- grouped_files = defaultdict(list)
28
- for file in tqdm(files, desc=f"Grouping {root}"):
29
- p = str(file.parent)
30
- speaker = file.parent.name
31
-
32
- try:
33
- if isinstance(text_extension, str):
34
- texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
35
- else:
36
- texts = [
37
- file.with_suffix(ext).read_text(encoding="utf-8")
38
- for ext in text_extension
39
- ]
40
- except Exception as e:
41
- logger.error(f"Failed to read text {file}: {e}")
42
- continue
43
-
44
- grouped_files[p].append((speaker, file, texts))
45
-
46
- logger.info(
47
- f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
48
- )
49
-
50
- for i in grouped_files.values():
51
- subset = [(f, t) for _, f, t in i]
52
- yield i[0][0], subset, "folder"
53
-
54
-
55
- def task_generator_filelist(filelist):
56
- grouped_files = defaultdict(list)
57
- for filename, speaker, _, text in load_filelist(filelist):
58
- grouped_files[speaker].append((Path(filename), [text]))
59
-
60
- logger.info(f"Found {len(grouped_files)} groups in {filelist}")
61
- for speaker, values in grouped_files.items():
62
- yield speaker, values, "filelist"
63
-
64
-
65
- def run_task(task):
66
- name, subset, source = task
67
-
68
- # Parse the files
69
- sentences = []
70
- for file, texts in subset:
71
- np_file = file.with_suffix(".npy")
72
- if np_file.exists() is False:
73
- logger.warning(f"Can't find {np_file}")
74
- continue
75
-
76
- new_texts = []
77
-
78
- for text in texts:
79
- # Simple cleaning: replace { xxx } and < xxx > with space
80
- text = re.sub(r"\{.*?\}", " ", text)
81
- text = re.sub(r"<.*?>", " ", text)
82
- text = re.sub(r"\s+", " ", text)
83
- new_texts.append(text)
84
-
85
- try:
86
- semantics = np.load(np_file)
87
- except Exception as e:
88
- logger.error(f"Failed to parse {file}: {e}")
89
- continue
90
-
91
- if isinstance(semantics, np.ndarray):
92
- semantics = semantics.tolist()
93
-
94
- sentences.append(
95
- Sentence(
96
- texts=new_texts,
97
- semantics=[Semantics(values=s) for s in semantics],
98
- )
99
- )
100
-
101
- # Pack the sentences
102
- return pack_pb_stream(
103
- TextData(
104
- source=source,
105
- name=name,
106
- sentences=sentences,
107
- )
108
- )
109
-
110
-
111
- @click.command()
112
- @click.option(
113
- "--input",
114
- type=click.Path(path_type=Path),
115
- required=True,
116
- help="A folder containing the dataset or a filelist",
117
- multiple=True,
118
- )
119
- @click.option(
120
- "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
121
- )
122
- @click.option("--num-workers", type=int, default=16)
123
- @click.option("--text-extension", type=str, default=[".txt"], multiple=True)
124
- @click.option(
125
- "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
126
- )
127
- def main(input, output, num_workers, text_extension, shard_size):
128
- generator_fns = []
129
-
130
- for f in input:
131
- assert f.exists(), f"{f} not found"
132
-
133
- if f.is_dir():
134
- generator_fn = task_generator_folder(f, text_extension)
135
- else:
136
- generator_fn = task_generator_filelist(f)
137
-
138
- generator_fns.append(generator_fn)
139
-
140
- generator_fn = itertools.chain(*generator_fns)
141
- output.mkdir(parents=True, exist_ok=True)
142
-
143
- dataset_fp = None
144
- tar_idx = 0
145
- written_size = 0
146
-
147
- with Pool(num_workers) as p:
148
- for result in tqdm(p.imap_unordered(run_task, generator_fn)):
149
- if dataset_fp is None:
150
- dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
151
-
152
- dataset_fp.write(result)
153
- written_size += len(result)
154
-
155
- if written_size > shard_size * 1024 * 1024:
156
- logger.info(f"Finished writing {tar_idx} shards to {output}")
157
- dataset_fp.close()
158
- dataset_fp = None
159
- written_size = 0
160
- tar_idx += 1
161
-
162
- if dataset_fp is not None:
163
- dataset_fp.close()
164
-
165
- logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
166
-
167
-
168
- if __name__ == "__main__":
169
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/llama/eval_in_context.py DELETED
@@ -1,171 +0,0 @@
1
- import pyrootutils
2
- import torch
3
- import torch.nn.functional as F
4
- from matplotlib import pyplot as plt
5
- from transformers import AutoTokenizer
6
-
7
- # register eval resolver and root
8
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
9
-
10
- from torch.utils.data import DataLoader
11
-
12
- from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
13
- from tools.llama.generate import load_model
14
-
15
-
16
- def smooth(
17
- scalars: list[float], weight: float
18
- ) -> list[float]: # Weight between 0 and 1
19
- last = scalars[0] # First value in the plot (first timestep)
20
- smoothed = list()
21
- for point in scalars:
22
- smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
23
- smoothed.append(smoothed_val) # Save it
24
- last = smoothed_val # Anchor the last smoothed value
25
-
26
- return smoothed
27
-
28
-
29
- @torch.inference_mode()
30
- def analyze_one_model(loader, config, weight, max_length):
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
- model = load_model(
33
- config,
34
- weight,
35
- device,
36
- torch.bfloat16,
37
- max_length,
38
- compile=False,
39
- )[0]
40
-
41
- current_step = 0
42
- model.eval()
43
-
44
- semantic_loss_sum = torch.zeros(
45
- max_length,
46
- dtype=torch.float32,
47
- device=device,
48
- )
49
- counter = torch.zeros(
50
- max_length,
51
- dtype=torch.long,
52
- device=device,
53
- )
54
-
55
- for batch in loader:
56
- batch = {k: v.to(device) for k, v in batch.items()}
57
-
58
- labels = batch["labels"]
59
- outputs = model(
60
- inp=batch["inputs"],
61
- key_padding_mask=batch["attention_masks"],
62
- )
63
-
64
- token_logits = outputs.token_logits
65
- codebook_logits = outputs.codebook_logits
66
-
67
- # Generate labels
68
- base_loss = F.cross_entropy(
69
- token_logits.reshape(-1, token_logits.size(-1)),
70
- labels[:, 0].reshape(-1),
71
- ignore_index=-100,
72
- reduction="none",
73
- )
74
-
75
- codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
76
- semantic_loss = F.cross_entropy(
77
- codebook_logits.reshape(-1, codebook_logits.size(-1)),
78
- codebook_labels.reshape(-1),
79
- ignore_index=-100,
80
- reduction="none",
81
- )
82
-
83
- base_loss = base_loss.reshape(labels[:, 0].shape)
84
- semantic_loss = semantic_loss.reshape(codebook_labels.shape)
85
-
86
- semantic_loss_frame = semantic_loss.mean(-1)
87
- pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
88
-
89
- for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
90
- semantic_loss_sum[~pad] += loss_sample[~pad]
91
- counter[~pad] += 1
92
-
93
- current_step += 1
94
- if current_step == 10:
95
- break
96
-
97
- semantic_loss = semantic_loss.cpu()
98
- counter = counter.cpu()
99
- xs, ys = [], []
100
-
101
- for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
102
- if count > 0:
103
- xs.append(i)
104
- ys.append((loss / count).item()) # for better loss visualization
105
-
106
- smoothed_ys = smooth(ys, 0.95)
107
-
108
- # Unload model
109
- del model
110
- torch.cuda.empty_cache()
111
-
112
- return xs, ys, smoothed_ys
113
-
114
-
115
- def main():
116
- tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
117
- max_length = 4096
118
-
119
- ds = AutoAugTextDataset(
120
- ["data/protos/sft/云天河"],
121
- tokenizer=tokenizer,
122
- use_speaker=False,
123
- interactive_prob=1.0,
124
- max_length=max_length,
125
- )
126
-
127
- loader = DataLoader(
128
- ds,
129
- batch_size=8,
130
- collate_fn=TextDataCollator(tokenizer, max_length=max_length),
131
- num_workers=0,
132
- shuffle=False,
133
- )
134
-
135
- plt.figure(figsize=(10, 5), dpi=200)
136
-
137
- plt.xlabel("Frame")
138
- plt.ylabel("Loss")
139
- plt.yscale("log")
140
- plt.title("Semantic Loss")
141
- plt.grid(which="both", axis="both")
142
- plt.xlim(0, max_length)
143
-
144
- tests = [
145
- (
146
- "pertrain-medium",
147
- "dual_ar_2_codebook_medium",
148
- "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
149
- ),
150
- (
151
- "sft-medium",
152
- "dual_ar_2_codebook_medium",
153
- "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
154
- ),
155
- (
156
- "sft-large",
157
- "dual_ar_2_codebook_large",
158
- "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
159
- ),
160
- ]
161
-
162
- for name, config, weight in tests:
163
- xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
164
- plt.plot(xs, smoothed_ys, label=name)
165
-
166
- plt.legend()
167
- plt.savefig("semantic_loss.png")
168
-
169
-
170
- if __name__ == "__main__":
171
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/llama/generate.py DELETED
@@ -1,714 +0,0 @@
1
- import os
2
- import queue
3
- import threading
4
- import time
5
- from contextlib import nullcontext
6
- from dataclasses import dataclass
7
- from pathlib import Path
8
- from typing import Literal, Optional, Tuple, Union
9
-
10
- import click
11
- import hydra
12
- import numpy as np
13
- import torch
14
- import torch._dynamo.config
15
- import torch._inductor.config
16
- from loguru import logger
17
- from tqdm import tqdm
18
-
19
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
- from fish_speech.text import clean_text, split_text
21
- torch.cuda.is_available = lambda: False
22
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
- torch._inductor.config.coordinate_descent_tuning = True
24
- torch._inductor.config.triton.unique_kernel_names = True
25
-
26
-
27
- if hasattr(torch._inductor.config, "fx_graph_cache"):
28
- # Experimental feature to reduce compilation times, will be on by default in future
29
- torch._inductor.config.fx_graph_cache = True
30
-
31
-
32
- from fish_speech.models.text2semantic.llama import (
33
- BaseTransformer,
34
- DualARTransformer,
35
- NaiveTransformer,
36
- )
37
-
38
-
39
- def multinomial_sample_one_no_sync(
40
- probs_sort,
41
- ): # Does multinomial sampling without a cuda synchronization
42
- q = torch.empty_like(probs_sort).exponential_(1)
43
- return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
44
-
45
-
46
- def logits_to_probs(
47
- logits,
48
- previous_tokens: Optional[torch.Tensor] = None,
49
- temperature: torch.Tensor = 1.0,
50
- top_p: torch.Tensor = 1.0,
51
- repetition_penalty: torch.Tensor = 1.0,
52
- ) -> torch.Tensor:
53
- # Apply repetition penalty
54
- if previous_tokens is not None:
55
- previous_tokens = previous_tokens.long()
56
- score = torch.gather(logits, dim=0, index=previous_tokens)
57
- score = torch.where(
58
- score < 0, score * repetition_penalty, score / repetition_penalty
59
- )
60
- logits.scatter_(dim=0, index=previous_tokens, src=score)
61
-
62
- # Apply top-p sampling
63
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
64
- cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
65
- sorted_indices_to_remove = cum_probs > top_p
66
- sorted_indices_to_remove[0] = False # keep at least one option
67
- indices_to_remove = sorted_indices_to_remove.scatter(
68
- dim=0, index=sorted_indices, src=sorted_indices_to_remove
69
- )
70
- logits = logits.masked_fill(indices_to_remove, -float("Inf"))
71
-
72
- logits = logits / max(temperature, 1e-5)
73
-
74
- probs = torch.nn.functional.softmax(logits, dim=-1)
75
- return probs
76
-
77
-
78
- def sample(
79
- logits,
80
- previous_tokens: Optional[torch.Tensor] = None,
81
- **sampling_kwargs,
82
- ) -> Tuple[torch.Tensor, torch.Tensor]:
83
- probs = logits_to_probs(
84
- logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
85
- )
86
- idx_next = multinomial_sample_one_no_sync(probs)
87
- return idx_next, probs
88
-
89
-
90
- def decode_one_token_ar(
91
- model: DualARTransformer,
92
- x: torch.Tensor,
93
- input_pos: torch.Tensor,
94
- previous_tokens: torch.Tensor = None,
95
- **sampling_kwargs,
96
- ) -> torch.Tensor:
97
- x = model.forward_generate(x, input_pos)
98
-
99
- sampling_kwargs_main = sampling_kwargs.copy()
100
- sampling_kwargs_main["temperature"] = 0.1
101
- sampling_kwargs_main["top_p"] = 0.1
102
- sampling_kwargs_main["repetition_penalty"] = 1.0
103
-
104
- codebooks = [
105
- sample(
106
- x.logits,
107
- previous_tokens=None, # Disable repetition penalty for the token codebook
108
- **sampling_kwargs_main,
109
- )[0]
110
- ]
111
-
112
- x = x.hidden_states
113
-
114
- # Cleanup the cache
115
- for layer in model.fast_layers:
116
- layer.attention.kv_cache.k_cache.fill_(0)
117
- layer.attention.kv_cache.v_cache.fill_(0)
118
-
119
- for codebook_idx in range(model.config.num_codebooks):
120
- input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
121
- logits = model.forward_generate_fast(x, input_pos)
122
- a = sample(
123
- logits,
124
- previous_tokens=(
125
- previous_tokens[codebook_idx + 1]
126
- if previous_tokens is not None
127
- else None
128
- ),
129
- **sampling_kwargs,
130
- )[0]
131
- x = model.fast_embeddings(a)
132
- codebooks.append(a)
133
-
134
- return torch.stack(codebooks, dim=0)
135
-
136
-
137
- def decode_one_token_naive(
138
- model: NaiveTransformer,
139
- x: torch.Tensor,
140
- input_pos: torch.Tensor,
141
- previous_tokens: torch.Tensor = None,
142
- **sampling_kwargs,
143
- ) -> torch.Tensor:
144
- x = model.forward_generate(x, input_pos)
145
-
146
- sampling_kwargs_main = sampling_kwargs.copy()
147
- sampling_kwargs_main["temperature"] = 0.1
148
- sampling_kwargs_main["top_p"] = 0.1
149
- sampling_kwargs_main["repetition_penalty"] = 1.0
150
-
151
- codebooks = [
152
- sample(
153
- x.logits,
154
- previous_tokens=None, # Disable repetition penalty for the token codebook
155
- **sampling_kwargs_main,
156
- )[0]
157
- ]
158
-
159
- for i in range(model.config.num_codebooks):
160
- codebooks.append(
161
- sample(
162
- x.codebook_logits[:, :, i],
163
- previous_tokens=(
164
- previous_tokens[i + 1] if previous_tokens is not None else None
165
- ),
166
- **sampling_kwargs,
167
- )[0]
168
- )
169
-
170
- return torch.stack(codebooks, dim=0)
171
-
172
-
173
- def decode_n_tokens(
174
- model: NaiveTransformer,
175
- cur_token: torch.Tensor,
176
- input_pos: torch.Tensor,
177
- num_new_tokens: int,
178
- im_end_id: int = 4,
179
- decode_one_token=decode_one_token_naive,
180
- **sampling_kwargs,
181
- ):
182
- previous_tokens = torch.zeros(
183
- (model.config.num_codebooks + 1, model.config.max_seq_len),
184
- dtype=torch.int,
185
- device=cur_token.device,
186
- )
187
-
188
- for i in tqdm(range(num_new_tokens)):
189
- # We need to get windowed repeat penalty
190
- win_size = 16
191
- if i < win_size:
192
- window = previous_tokens[:, :win_size]
193
- else:
194
- window = previous_tokens[:, i - win_size : i]
195
-
196
- with (
197
- torch.backends.cuda.sdp_kernel(
198
- enable_flash=False, enable_mem_efficient=False, enable_math=True
199
- )
200
- if torch.cuda.is_available()
201
- else nullcontext()
202
- ): # Actually better for Inductor to codegen attention here
203
- next_token = decode_one_token(
204
- model=model,
205
- x=cur_token,
206
- input_pos=input_pos,
207
- previous_tokens=window,
208
- **sampling_kwargs,
209
- )
210
-
211
- input_pos += 1
212
- cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
213
- previous_tokens[:, i : i + 1] = next_token.view(
214
- model.config.num_codebooks + 1, -1
215
- )
216
-
217
- if cur_token[0, 0, -1] == im_end_id:
218
- break
219
-
220
- return previous_tokens[:, : i + 1]
221
-
222
-
223
- @torch.no_grad()
224
- @torch.inference_mode()
225
- def generate(
226
- *,
227
- model: NaiveTransformer,
228
- prompt: torch.Tensor,
229
- max_new_tokens: int,
230
- im_end_id: int = 4,
231
- decode_one_token=decode_one_token_naive,
232
- **sampling_kwargs,
233
- ) -> torch.Tensor:
234
- """
235
- Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
236
- """
237
-
238
- # create an empty tensor of the expected final shape and fill in the current tokens
239
- T = prompt.size(1)
240
-
241
- if max_new_tokens:
242
- if T + max_new_tokens > model.config.max_seq_len:
243
- max_new_tokens = model.config.max_seq_len - T
244
- logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
245
-
246
- T_new = T + max_new_tokens
247
- else:
248
- T_new = model.config.max_seq_len
249
- max_new_tokens = T_new - T
250
-
251
- device, dtype = prompt.device, prompt.dtype
252
- with torch.device(device):
253
- model.setup_caches(
254
- max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
255
- )
256
-
257
- codebook_dim = 1 + model.config.num_codebooks
258
- # create an empty tensor of the expected final shape and fill in the current tokens
259
- empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
260
- empty[:, :T] = prompt
261
- seq = empty
262
- input_pos = torch.arange(0, T, device=device)
263
-
264
- # Use non-accelerated version for now, to avoid compilation overhead
265
- prefill_decode = (
266
- decode_one_token_naive
267
- if isinstance(model, NaiveTransformer)
268
- else decode_one_token_ar
269
- )
270
-
271
- next_token = prefill_decode(
272
- model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
273
- )
274
- seq[:, T : T + 1] = next_token
275
-
276
- input_pos = torch.tensor([T], device=device, dtype=torch.int)
277
- x = decode_n_tokens(
278
- model,
279
- next_token.view(1, codebook_dim, -1),
280
- input_pos,
281
- max_new_tokens - 1,
282
- im_end_id=im_end_id,
283
- decode_one_token=decode_one_token,
284
- **sampling_kwargs,
285
- )
286
- # x = torch.cat(generated_tokens, dim=1)
287
- seq = seq[:, : T + 1 + x.size(1)]
288
- seq[:, T + 1 :] = x
289
-
290
- return seq
291
-
292
-
293
- def encode_tokens(
294
- tokenizer,
295
- string,
296
- device="cuda",
297
- prompt_tokens=None,
298
- num_codebooks=4,
299
- ):
300
- string = clean_text(string)
301
- string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
302
-
303
- new_tokens = tokenizer.encode(
304
- string,
305
- add_special_tokens=False,
306
- max_length=10**6,
307
- truncation=False,
308
- )
309
- tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
310
-
311
- # Codebooks
312
- zeros = (
313
- torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
314
- * CODEBOOK_PAD_TOKEN_ID
315
- )
316
- prompt = torch.cat((tokens, zeros), dim=0)
317
-
318
- if prompt_tokens is None:
319
- return prompt
320
-
321
- # Get prompt tokens
322
- if prompt_tokens.ndim == 3:
323
- assert (
324
- prompt_tokens.shape[0] == 1
325
- ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
326
- prompt_tokens = prompt_tokens[0]
327
-
328
- assert prompt_tokens.ndim == 2
329
- data = prompt_tokens + 1
330
-
331
- if prompt_tokens.shape[0] > num_codebooks:
332
- logger.warning(
333
- f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
334
- )
335
- data = data[:num_codebooks]
336
-
337
- # Add pad token for each codebook
338
- data = torch.cat(
339
- (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
340
- dim=1,
341
- )
342
-
343
- # Since 1.0, we use <|semantic|>
344
- s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
345
- end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
346
- main_token_ids = (
347
- torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
348
- )
349
- main_token_ids[0, -1] = end_token_id
350
-
351
- data = torch.cat((main_token_ids, data), dim=0)
352
- prompt = torch.cat((prompt, data), dim=1)
353
-
354
- return prompt
355
-
356
-
357
- def load_model(checkpoint_path, device, precision, compile=False):
358
- model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
359
- checkpoint_path, load_weights=True
360
- )
361
-
362
- model = model.to(device=device, dtype=precision)
363
- logger.info(f"Restored model from checkpoint")
364
-
365
- if isinstance(model, DualARTransformer):
366
- decode_one_token = decode_one_token_ar
367
- logger.info("Using DualARTransformer")
368
- else:
369
- decode_one_token = decode_one_token_naive
370
- logger.info("Using NaiveTransformer")
371
-
372
- if compile:
373
- logger.info("Compiling function...")
374
- decode_one_token = torch.compile(
375
- decode_one_token,
376
- fullgraph=True,
377
- backend="inductor" if torch.cuda.is_available() else "aot_eager",
378
- mode="reduce-overhead" if torch.cuda.is_available() else None,
379
- )
380
-
381
- return model.eval(), decode_one_token
382
-
383
-
384
- @dataclass
385
- class GenerateResponse:
386
- action: Literal["sample", "next"]
387
- codes: Optional[torch.Tensor] = None
388
- text: Optional[str] = None
389
-
390
-
391
- def generate_long(
392
- *,
393
- model,
394
- device: str | torch.device,
395
- decode_one_token: callable,
396
- text: str,
397
- num_samples: int = 1,
398
- max_new_tokens: int = 0,
399
- top_p: int = 0.7,
400
- repetition_penalty: float = 1.5,
401
- temperature: float = 0.7,
402
- compile: bool = False,
403
- iterative_prompt: bool = True,
404
- max_length: int = 2048,
405
- chunk_length: int = 150,
406
- prompt_text: Optional[str | list[str]] = None,
407
- prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
408
- ):
409
- assert 0 < top_p <= 1, "top_p must be in (0, 1]"
410
- assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
411
- assert 0 < temperature < 2, "temperature must be in (0, 2)"
412
-
413
- use_prompt = prompt_text is not None and prompt_tokens is not None
414
- if use_prompt and isinstance(prompt_text, str):
415
- prompt_text = [prompt_text]
416
- prompt_tokens = [prompt_tokens]
417
-
418
- assert use_prompt is False or len(prompt_text) == len(
419
- prompt_tokens
420
- ), "Prompt text and tokens must have the same length"
421
-
422
- model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
423
- tokenizer = model.tokenizer
424
- im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
425
-
426
- encoded = []
427
- texts = split_text(text, chunk_length) if iterative_prompt else [text]
428
- encoded_prompts = []
429
-
430
- if use_prompt:
431
- for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
432
- encoded_prompts.append(
433
- encode_tokens(
434
- tokenizer,
435
- string=t,
436
- device=device,
437
- prompt_tokens=c,
438
- num_codebooks=model.config.num_codebooks,
439
- )
440
- )
441
-
442
- for idx, text in enumerate(texts):
443
- encoded.append(
444
- encode_tokens(
445
- tokenizer,
446
- string=text,
447
- device=device,
448
- num_codebooks=model.config.num_codebooks,
449
- )
450
- )
451
- logger.info(f"Encoded text: {text}")
452
-
453
- # Move temperature, top_p, repetition_penalty to device
454
- # This is important so that changing params doesn't trigger recompile
455
- temperature = torch.tensor(temperature, device=device, dtype=torch.float)
456
- top_p = torch.tensor(top_p, device=device, dtype=torch.float)
457
- repetition_penalty = torch.tensor(
458
- repetition_penalty, device=device, dtype=torch.float
459
- )
460
-
461
- for sample_idx in range(num_samples):
462
- if torch.cuda.is_available():
463
- torch.cuda.synchronize()
464
-
465
- global_encoded = []
466
- seg_idx = 0
467
-
468
- while seg_idx < len(encoded):
469
- logger.info(
470
- f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
471
- )
472
-
473
- seg = encoded[seg_idx]
474
- global_encoded.append(seg)
475
-
476
- lengths = reversed([seg.size(1) for seg in global_encoded])
477
-
478
- # Pick last 2000 tokens
479
- count = 0
480
- for i, length in enumerate(lengths):
481
- count += length
482
- if count + length > max_length - 1024 - sum(
483
- t.shape[1] for t in encoded_prompts
484
- ):
485
- break
486
-
487
- if i != 0 and i % 2 == 0:
488
- i -= 1
489
-
490
- # Rotate the list, always make sure first segment is included to avoid drift
491
- if i < len(global_encoded) - 2:
492
- partial_encoded = global_encoded[:2] + global_encoded[-i:]
493
- else:
494
- partial_encoded = global_encoded
495
-
496
- if use_prompt:
497
- partial_encoded = encoded_prompts + partial_encoded
498
-
499
- cat_encoded = torch.cat(partial_encoded, dim=1)
500
- prompt_length = cat_encoded.size(1)
501
-
502
- t0 = time.perf_counter()
503
- y = generate(
504
- model=model,
505
- prompt=cat_encoded,
506
- max_new_tokens=max_new_tokens,
507
- im_end_id=im_end_id,
508
- decode_one_token=decode_one_token,
509
- temperature=temperature,
510
- top_p=top_p,
511
- repetition_penalty=repetition_penalty,
512
- )
513
-
514
- if sample_idx == 0 and seg_idx == 0 and compile:
515
- logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
516
-
517
- if torch.cuda.is_available():
518
- torch.cuda.synchronize()
519
-
520
- t = time.perf_counter() - t0
521
-
522
- tokens_generated = y.size(1) - prompt_length
523
- tokens_sec = tokens_generated / t
524
- logger.info(
525
- f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
526
- )
527
- logger.info(
528
- f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
529
- )
530
-
531
- if torch.cuda.is_available():
532
- logger.info(
533
- f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
534
- )
535
-
536
- # Put the generated tokens
537
- # since there is <im_end> and <eos> tokens, we remove last 2 tokens
538
- codes = y[1:, prompt_length:-1].clone()
539
- codes = codes - 1
540
- assert (codes >= 0).all(), f"Negative code found"
541
-
542
- decoded = y[:, prompt_length:-1].clone()
543
- # But for global encoding, we should keep the <im_end> token
544
-
545
- global_encoded.append(decoded)
546
- assert (codes >= 0).all(), f"Negative code found: {codes}"
547
- yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
548
- seg_idx += 1
549
-
550
- # This indicates the end of the current sample
551
- yield GenerateResponse(action="next")
552
-
553
-
554
- @dataclass
555
- class WrappedGenerateResponse:
556
- status: Literal["success", "error"]
557
- response: Optional[GenerateResponse | Exception] = None
558
-
559
-
560
- @dataclass
561
- class GenerateRequest:
562
- request: dict
563
- response_queue: queue.Queue
564
-
565
-
566
- def launch_thread_safe_queue(
567
- checkpoint_path,
568
- device,
569
- precision,
570
- compile: bool = False,
571
- ):
572
- input_queue = queue.Queue()
573
- init_event = threading.Event()
574
-
575
- def worker():
576
- model, decode_one_token = load_model(
577
- checkpoint_path, device, precision, compile=compile
578
- )
579
- init_event.set()
580
-
581
- while True:
582
- item: GenerateRequest | None = input_queue.get()
583
- if item is None:
584
- break
585
-
586
- kwargs = item.request
587
- response_queue = item.response_queue
588
-
589
- try:
590
- for chunk in generate_long(
591
- model=model, decode_one_token=decode_one_token, **kwargs
592
- ):
593
- response_queue.put(
594
- WrappedGenerateResponse(status="success", response=chunk)
595
- )
596
- except Exception as e:
597
- response_queue.put(WrappedGenerateResponse(status="error", response=e))
598
-
599
- threading.Thread(target=worker, daemon=True).start()
600
- init_event.wait()
601
-
602
- return input_queue
603
-
604
-
605
- @click.command()
606
- @click.option(
607
- "--text",
608
- type=str,
609
- default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
610
- )
611
- @click.option("--prompt-text", type=str, default=None, multiple=True)
612
- @click.option(
613
- "--prompt-tokens",
614
- type=click.Path(path_type=Path, exists=True),
615
- default=None,
616
- multiple=True,
617
- )
618
- @click.option("--num-samples", type=int, default=1)
619
- @click.option("--max-new-tokens", type=int, default=0)
620
- @click.option("--top-p", type=float, default=0.7)
621
- @click.option("--repetition-penalty", type=float, default=1.2)
622
- @click.option("--temperature", type=float, default=0.7)
623
- @click.option(
624
- "--checkpoint-path",
625
- type=click.Path(path_type=Path, exists=True),
626
- default="checkpoints/fish-speech-1.4",
627
- )
628
- @click.option("--device", type=str, default="cuda")
629
- @click.option("--compile/--no-compile", default=False)
630
- @click.option("--seed", type=int, default=42)
631
- @click.option("--half/--no-half", default=False)
632
- @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
633
- @click.option("--chunk-length", type=int, default=100)
634
- def main(
635
- text: str,
636
- prompt_text: Optional[list[str]],
637
- prompt_tokens: Optional[list[Path]],
638
- num_samples: int,
639
- max_new_tokens: int,
640
- top_p: int,
641
- repetition_penalty: float,
642
- temperature: float,
643
- checkpoint_path: Path,
644
- device: str,
645
- compile: bool,
646
- seed: int,
647
- half: bool,
648
- iterative_prompt: bool,
649
- chunk_length: int,
650
- ) -> None:
651
-
652
- precision = torch.half if half else torch.bfloat16
653
-
654
- if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
655
- raise ValueError(
656
- f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
657
- )
658
-
659
- logger.info("Loading model ...")
660
- t0 = time.time()
661
- model, decode_one_token = load_model(
662
- checkpoint_path, device, precision, compile=compile
663
- )
664
-
665
- if torch.cuda.is_available():
666
- torch.cuda.synchronize()
667
-
668
- logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
669
-
670
- if prompt_tokens is not None:
671
- prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
672
-
673
- torch.manual_seed(seed)
674
-
675
- if torch.cuda.is_available():
676
- torch.cuda.manual_seed(seed)
677
-
678
- generator = generate_long(
679
- model=model,
680
- device=device,
681
- decode_one_token=decode_one_token,
682
- text=text,
683
- num_samples=num_samples,
684
- max_new_tokens=max_new_tokens,
685
- top_p=top_p,
686
- repetition_penalty=repetition_penalty,
687
- temperature=temperature,
688
- compile=compile,
689
- iterative_prompt=iterative_prompt,
690
- chunk_length=chunk_length,
691
- prompt_text=prompt_text,
692
- prompt_tokens=prompt_tokens,
693
- )
694
-
695
- idx = 0
696
- codes = []
697
-
698
- for response in generator:
699
- if response.action == "sample":
700
- codes.append(response.codes)
701
- logger.info(f"Sampled text: {response.text}")
702
- elif response.action == "next":
703
- if codes:
704
- np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
705
- logger.info(f"Saved codes to codes_{idx}.npy")
706
- logger.info(f"Next sample")
707
- codes = []
708
- idx += 1
709
- else:
710
- logger.error(f"Error: {response}")
711
-
712
-
713
- if __name__ == "__main__":
714
- main()
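A minimal consumer sketch for the worker queue set up above, assuming the launcher is this module's `launch_thread_safe_queue` and that `GenerateRequest` is a plain container with `request` and `response_queue` fields; the checkpoint path, text, and sampling values are placeholders, not values taken from this commit:

```python
# Illustrative only: push one GenerateRequest onto the worker queue and drain
# the per-request response queue until the worker signals the next sample.
import queue

import torch

llama_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/fish-speech-1.4",  # placeholder checkpoint
    device="cuda",
    precision=torch.bfloat16,
    compile=False,
)

response_queue = queue.Queue()
llama_queue.put(
    GenerateRequest(
        request=dict(
            device="cuda",
            text="Hello, world.",  # placeholder text
            max_new_tokens=0,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
            compile=False,
            iterative_prompt=True,
            chunk_length=100,
        ),
        response_queue=response_queue,
    )
)

codes = []
while True:
    wrapped = response_queue.get()
    if wrapped.status == "error":
        raise wrapped.response  # the worker forwards the original exception
    chunk = wrapped.response  # a GenerateResponse with .action / .codes / .text
    if chunk.action == "sample":
        codes.append(chunk.codes)
    elif chunk.action == "next":
        break  # one sample finished
```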
 
 
tools/llama/merge_lora.py DELETED
@@ -1,95 +0,0 @@
1
- import shutil
2
- from copy import deepcopy
3
- from pathlib import Path
4
-
5
- import click
6
- import hydra
7
- import torch
8
- from hydra import compose, initialize
9
- from hydra.utils import instantiate
10
- from loguru import logger
11
-
12
- from fish_speech.models.text2semantic.llama import BaseTransformer
13
- from fish_speech.models.text2semantic.lora import get_merged_state_dict
14
-
15
-
16
- @click.command()
17
- @click.option("--lora-config", type=str, default="r_8_alpha_16")
18
- @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
19
- @click.option("--lora-weight", type=str, required=True)
20
- @click.option("--output", type=str, required=True)
21
- def merge(lora_config, base_weight, lora_weight, output):
22
- output = Path(output)
23
- logger.info(
24
- f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
25
- )
26
-
27
- with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
28
- cfg = compose(config_name=lora_config)
29
-
30
- lora_config = instantiate(cfg)
31
- logger.info(f"Loaded lora model with config {lora_config}")
32
-
33
- llama_model = BaseTransformer.from_pretrained(
34
- path=base_weight,
35
- load_weights=True,
36
- lora_config=lora_config,
37
- )
38
- logger.info(f"Loaded llama model")
39
-
40
- llama_state_dict = llama_model.state_dict()
41
- llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
42
- llama_state_dict_copy = deepcopy(llama_state_dict)
43
- lora_state_dict = torch.load(lora_weight, map_location="cpu")
44
-
45
- if "state_dict" in llama_state_dict:
46
- llama_state_dict = llama_state_dict["state_dict"]
47
-
48
- if "state_dict" in lora_state_dict:
49
- lora_state_dict = lora_state_dict["state_dict"]
50
-
51
- # remove prefix model.
52
- if any(k.startswith("model.") for k in llama_state_dict.keys()):
53
- llama_state_dict = {
54
- k.replace("model.", ""): v
55
- for k, v in llama_state_dict.items()
56
- if k.startswith("model.")
57
- }
58
- if any(k.startswith("model.") for k in lora_state_dict.keys()):
59
- lora_state_dict = {
60
- k.replace("model.", ""): v
61
- for k, v in lora_state_dict.items()
62
- if k.startswith("model.")
63
- }
64
-
65
- logger.info(f"Found {len(llama_state_dict)} keys in llama model")
66
- logger.info(f"Found {len(lora_state_dict)} keys in lora model")
67
-
68
- merged_state_dict = llama_state_dict | lora_state_dict
69
- llama_model.load_state_dict(merged_state_dict, strict=True)
70
- logger.info(f"Merged model loaded")
71
-
72
- # Trigger eval mode to merge lora
73
- llama_model.eval()
74
- llama_model.save_pretrained(output, drop_lora=True)
75
- logger.info(f"Saved merged model to {output}, validating")
76
-
77
- new_state_dict = torch.load(output / "model.pth", map_location="cpu")
78
- original_keys = set(llama_state_dict_copy.keys())
79
- merged_keys = set(new_state_dict.keys())
80
-
81
- assert original_keys == merged_keys, "Keys should be same"
82
-
83
- for key in original_keys:
84
- diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
85
- if diff_l1 != 0:
86
- break
87
- else:
88
- logger.error("Merged model is same as the original model")
89
- exit(1)
90
-
91
- logger.info("Merged model is different from the original model, check passed")
92
-
93
-
94
- if __name__ == "__main__":
95
- merge()
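As a usage illustration only, the click command above can be exercised programmatically through click's test runner; the LoRA checkpoint path below is a made-up placeholder, not a file from this repository:

```python
# Sketch: drive the merge command programmatically instead of from the CLI.
from click.testing import CliRunner

runner = CliRunner()
result = runner.invoke(
    merge,
    [
        "--lora-config", "r_8_alpha_16",
        "--base-weight", "checkpoints/fish-speech-1.4",
        "--lora-weight", "results/lora/step_000005000.ckpt",  # placeholder path
        "--output", "checkpoints/fish-speech-1.4-merged",
    ],
)
assert result.exit_code == 0, result.output
```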
 
 
tools/llama/quantize.py DELETED
@@ -1,497 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- import datetime
4
- import shutil
5
-
6
- # This source code is licensed under the license found in the
7
- # LICENSE file in the root directory of this source tree.
8
- import time
9
- from pathlib import Path
10
-
11
- import click
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
-
16
- from fish_speech.models.text2semantic.llama import find_multiple
17
- from tools.llama.generate import load_model
18
-
19
- ##### Quantization Primitives ######
20
-
21
-
22
- def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
23
- # assumes symmetric quantization
24
- # assumes axis == 0
25
- # assumes dense memory format
26
- # TODO(future): relax ^ as needed
27
-
28
- # default setup for affine quantization of activations
29
- eps = torch.finfo(torch.float32).eps
30
-
31
- # get min and max
32
- min_val, max_val = torch.aminmax(x, dim=1)
33
-
34
- # calculate scales and zero_points based on min and max
35
- # reference: https://fburl.com/code/srbiybme
36
- min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
37
- max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
38
- device = min_val_neg.device
39
-
40
- # reference: https://fburl.com/code/4wll53rk
41
- max_val_pos = torch.max(-min_val_neg, max_val_pos)
42
- scales = max_val_pos / (float(quant_max - quant_min) / 2)
43
- # ensure scales is the same dtype as the original tensor
44
- scales = torch.clamp(scales, min=eps).to(x.dtype)
45
- zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
46
-
47
- # quantize based on qmin/qmax/scales/zp
48
- # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
49
- x_div = x / scales.unsqueeze(-1)
50
- x_round = torch.round(x_div)
51
- x_zp = x_round + zero_points.unsqueeze(-1)
52
- quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
53
-
54
- return quant, scales, zero_points
55
-
56
-
57
- def get_group_qparams(w, n_bit=4, groupsize=128):
58
- # needed for GPTQ with padding
59
- if groupsize > w.shape[-1]:
60
- groupsize = w.shape[-1]
61
- assert groupsize > 1
62
- assert w.shape[-1] % groupsize == 0
63
- assert w.dim() == 2
64
-
65
- to_quant = w.reshape(-1, groupsize)
66
- assert torch.isnan(to_quant).sum() == 0
67
-
68
- max_val = to_quant.amax(dim=1, keepdim=True)
69
- min_val = to_quant.amin(dim=1, keepdim=True)
70
- max_int = 2**n_bit - 1
71
- scales = (max_val - min_val).clamp(min=1e-6) / max_int
72
- zeros = min_val + scales * (2 ** (n_bit - 1))
73
- return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
74
- torch.bfloat16
75
- ).reshape(w.shape[0], -1)
76
-
77
-
78
- def pack_scales_and_zeros(scales, zeros):
79
- assert scales.shape == zeros.shape
80
- assert scales.dtype == torch.bfloat16
81
- assert zeros.dtype == torch.bfloat16
82
- return (
83
- torch.cat(
84
- [
85
- scales.reshape(scales.size(0), scales.size(1), 1),
86
- zeros.reshape(zeros.size(0), zeros.size(1), 1),
87
- ],
88
- 2,
89
- )
90
- .transpose(0, 1)
91
- .contiguous()
92
- )
93
-
94
-
95
- def unpack_scales_and_zeros(scales_and_zeros):
96
- assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
97
- assert scales_and_zeros.dtype == torch.float
98
- return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
99
-
100
-
101
- def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
102
- assert groupsize > 1
103
- # needed for GPTQ single column quantize
104
- if groupsize > w.shape[-1] and scales.shape[-1] == 1:
105
- groupsize = w.shape[-1]
106
-
107
- assert w.shape[-1] % groupsize == 0
108
- assert w.dim() == 2
109
-
110
- to_quant = w.reshape(-1, groupsize)
111
- assert torch.isnan(to_quant).sum() == 0
112
-
113
- scales = scales.reshape(-1, 1)
114
- zeros = zeros.reshape(-1, 1)
115
- min_val = zeros - scales * (2 ** (n_bit - 1))
116
- max_int = 2**n_bit - 1
117
- min_int = 0
118
- w_int32 = (
119
- to_quant.sub(min_val)
120
- .div(scales)
121
- .round()
122
- .clamp_(min_int, max_int)
123
- .to(torch.int32)
124
- .reshape_as(w)
125
- )
126
-
127
- return w_int32
128
-
129
-
130
- def group_quantize_tensor(w, n_bit=4, groupsize=128):
131
- scales, zeros = get_group_qparams(w, n_bit, groupsize)
132
- w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
133
- scales_and_zeros = pack_scales_and_zeros(scales, zeros)
134
- return w_int32, scales_and_zeros
135
-
136
-
137
- def group_dequantize_tensor_from_qparams(
138
- w_int32, scales, zeros, n_bit=4, groupsize=128
139
- ):
140
- assert groupsize > 1
141
- # needed for GPTQ single column dequantize
142
- if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
143
- groupsize = w_int32.shape[-1]
144
- assert w_int32.shape[-1] % groupsize == 0
145
- assert w_int32.dim() == 2
146
-
147
- w_int32_grouped = w_int32.reshape(-1, groupsize)
148
- scales = scales.reshape(-1, 1)
149
- zeros = zeros.reshape(-1, 1)
150
-
151
- w_dq = (
152
- w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
153
- )
154
- return w_dq
155
-
156
-
157
- def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
158
- scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
159
- return group_dequantize_tensor_from_qparams(
160
- w_int32, scales, zeros, n_bit, groupsize
161
- )
162
-
163
-
164
- class QuantHandler:
165
- def __init__(self, mod):
166
- self.mod = mod
167
-
168
- def create_quantized_state_dict(self) -> "StateDict":
169
- pass
170
-
171
- def convert_for_runtime(self) -> "nn.Module":
172
- pass
173
-
174
-
175
- ##### Weight-only int8 per-channel quantized code ######
176
-
177
-
178
- def replace_linear_weight_only_int8_per_channel(module):
179
- for name, child in module.named_children():
180
- if isinstance(child, nn.Linear):
181
- setattr(
182
- module,
183
- name,
184
- WeightOnlyInt8Linear(child.in_features, child.out_features),
185
- )
186
- else:
187
- replace_linear_weight_only_int8_per_channel(child)
188
-
189
-
190
- class WeightOnlyInt8QuantHandler:
191
- def __init__(self, mod):
192
- self.mod = mod
193
-
194
- @torch.no_grad()
195
- def create_quantized_state_dict(self):
196
- cur_state_dict = self.mod.state_dict()
197
- for fqn, mod in self.mod.named_modules():
198
- if isinstance(mod, torch.nn.Linear):
199
- int8_weight, scales, _ = dynamically_quantize_per_channel(
200
- mod.weight.float(), -128, 127, torch.int8
201
- )
202
- cur_state_dict[f"{fqn}.weight"] = int8_weight
203
- cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
204
-
205
- return cur_state_dict
206
-
207
- def convert_for_runtime(self):
208
- replace_linear_weight_only_int8_per_channel(self.mod)
209
- return self.mod
210
-
211
-
212
- class WeightOnlyInt8Linear(torch.nn.Module):
213
- __constants__ = ["in_features", "out_features"]
214
- in_features: int
215
- out_features: int
216
- weight: torch.Tensor
217
-
218
- def __init__(
219
- self,
220
- in_features: int,
221
- out_features: int,
222
- bias: bool = True,
223
- device=None,
224
- dtype=None,
225
- ) -> None:
226
- factory_kwargs = {"device": device, "dtype": dtype}
227
- super().__init__()
228
- self.in_features = in_features
229
- self.out_features = out_features
230
- self.register_buffer(
231
- "weight", torch.empty((out_features, in_features), dtype=torch.int8)
232
- )
233
- self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
234
-
235
- def forward(self, input: torch.Tensor) -> torch.Tensor:
236
- return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
237
-
238
-
239
- ##### weight only int4 per channel groupwise quantized code ######
240
-
241
-
242
- def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
243
- weight_int32, scales_and_zeros = group_quantize_tensor(
244
- weight_bf16, n_bit=4, groupsize=groupsize
245
- )
246
- weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
247
- weight_int32, inner_k_tiles
248
- )
249
- return weight_int4pack, scales_and_zeros
250
-
251
-
252
- def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
253
- origin_x_size = x.size()
254
- x = x.reshape(-1, origin_x_size[-1])
255
- c = torch.ops.aten._weight_int4pack_mm(
256
- x, weight_int4pack, groupsize, scales_and_zeros
257
- )
258
- new_shape = origin_x_size[:-1] + (out_features,)
259
- c = c.reshape(new_shape)
260
- return c
261
-
262
-
263
- def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
264
- return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
265
-
266
-
267
- def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
268
- for name, child in module.named_children():
269
- if isinstance(child, nn.Linear):
270
- if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
271
- setattr(
272
- module,
273
- name,
274
- WeightOnlyInt4Linear(
275
- child.in_features,
276
- child.out_features,
277
- bias=False,
278
- groupsize=groupsize,
279
- inner_k_tiles=inner_k_tiles,
280
- padding=False,
281
- ),
282
- )
283
- elif padding:
284
- setattr(
285
- module,
286
- name,
287
- WeightOnlyInt4Linear(
288
- child.in_features,
289
- child.out_features,
290
- bias=False,
291
- groupsize=groupsize,
292
- inner_k_tiles=inner_k_tiles,
293
- padding=True,
294
- ),
295
- )
296
- else:
297
- replace_linear_int4(child, groupsize, inner_k_tiles, padding)
298
-
299
-
300
- class WeightOnlyInt4QuantHandler:
301
- def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
302
- self.mod = mod
303
- self.groupsize = groupsize
304
- self.inner_k_tiles = inner_k_tiles
305
- self.padding = padding
306
- assert groupsize in [32, 64, 128, 256]
307
- assert inner_k_tiles in [2, 4, 8]
308
-
309
- @torch.no_grad()
310
- def create_quantized_state_dict(self):
311
- cur_state_dict = self.mod.state_dict()
312
- for fqn, mod in self.mod.named_modules():
313
- if isinstance(mod, torch.nn.Linear):
314
- assert not mod.bias
315
- out_features = mod.out_features
316
- in_features = mod.in_features
317
- assert out_features % 8 == 0, "require out_features % 8 == 0"
318
- print(f"linear: {fqn}, in={in_features}, out={out_features}")
319
-
320
- weight = mod.weight.data
321
- if not _check_linear_int4_k(
322
- in_features, self.groupsize, self.inner_k_tiles
323
- ):
324
- if self.padding:
325
- import torch.nn.functional as F
326
-
327
- print(
328
- f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
329
- )
330
- padded_in_features = find_multiple(in_features, 1024)
331
- weight = F.pad(
332
- weight, pad=(0, padded_in_features - in_features)
333
- )
334
- else:
335
- print(
336
- f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
337
- + "and that groupsize and inner_k_tiles*16 evenly divide into it"
338
- )
339
- continue
340
- (
341
- weight_int4pack,
342
- scales_and_zeros,
343
- ) = prepare_int4_weight_and_scales_and_zeros(
344
- weight.to(torch.bfloat16).to("cuda"),
345
- self.groupsize,
346
- self.inner_k_tiles,
347
- )
348
- cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
349
- cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
350
-
351
- return cur_state_dict
352
-
353
- def convert_for_runtime(self):
354
- replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
355
- return self.mod
356
-
357
-
358
- class WeightOnlyInt4Linear(torch.nn.Module):
359
- __constants__ = ["in_features", "out_features"]
360
- in_features: int
361
- out_features: int
362
- weight: torch.Tensor
363
-
364
- def __init__(
365
- self,
366
- in_features: int,
367
- out_features: int,
368
- bias=True,
369
- device=None,
370
- dtype=None,
371
- groupsize: int = 128,
372
- inner_k_tiles: int = 8,
373
- padding: bool = True,
374
- ) -> None:
375
- super().__init__()
376
- self.padding = padding
377
- if padding:
378
- self.origin_in_features = in_features
379
- in_features = find_multiple(in_features, 1024)
380
-
381
- self.in_features = in_features
382
- self.out_features = out_features
383
- assert not bias, "require bias=False"
384
- self.groupsize = groupsize
385
- self.inner_k_tiles = inner_k_tiles
386
-
387
- assert out_features % 8 == 0, "require out_features % 8 == 0"
388
- assert (
389
- in_features % (inner_k_tiles * 16) == 0
390
- ), "require in_features % (innerKTiles * 16) == 0"
391
- self.register_buffer(
392
- "weight",
393
- torch.empty(
394
- (
395
- out_features // 8,
396
- in_features // (inner_k_tiles * 16),
397
- 32,
398
- inner_k_tiles // 2,
399
- ),
400
- dtype=torch.int32,
401
- ),
402
- )
403
- self.register_buffer(
404
- "scales_and_zeros",
405
- torch.empty(
406
- (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
407
- ),
408
- )
409
-
410
- def forward(self, input: torch.Tensor) -> torch.Tensor:
411
- input = input.to(torch.bfloat16)
412
- if self.padding:
413
- import torch.nn.functional as F
414
-
415
- input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
416
- return linear_forward_int4(
417
- input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
418
- )
419
-
420
-
421
- def generate_folder_name():
422
- now = datetime.datetime.now()
423
- folder_name = now.strftime("%Y%m%d_%H%M%S")
424
- return folder_name
425
-
426
-
427
- @click.command()
428
- @click.option(
429
- "--checkpoint-path",
430
- type=click.Path(path_type=Path, exists=True),
431
- default="checkpoints/fish-speech-1.4",
432
- )
433
- @click.option(
434
- "--mode", type=str, default="int8", help="type of quantization to perform"
435
- )
436
- @click.option(
437
- "--groupsize", type=int, default=128, help="Group size for int4 quantization."
438
- )
439
- @click.option("--timestamp", type=str, default="None", help="When to do quantization")
440
- def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
441
-
442
- device = "cpu"
443
- precision = torch.bfloat16
444
-
445
- print("Loading model ...")
446
- t0 = time.time()
447
-
448
- model, _ = load_model(
449
- checkpoint_path=checkpoint_path,
450
- device=device,
451
- precision=precision,
452
- compile=False,
453
- )
454
- vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
455
- now = timestamp if timestamp != "None" else generate_folder_name()
456
-
457
- if mode == "int8":
458
- print(
459
- "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
460
- )
461
- quant_handler = WeightOnlyInt8QuantHandler(model)
462
- quantized_state_dict = quant_handler.create_quantized_state_dict()
463
-
464
- dir_name = checkpoint_path
465
- dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
466
- shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
467
- if (dst_name / vq_model).exists():
468
- (dst_name / vq_model).unlink()
469
- quantize_path = dst_name / "model.pth"
470
-
471
- elif mode == "int4":
472
- print(
473
- "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
474
- )
475
- quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
476
- quantized_state_dict = quant_handler.create_quantized_state_dict()
477
-
478
- dir_name = checkpoint_path
479
- dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
480
- shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
481
- if (dst_name / vq_model).exists():
482
- (dst_name / vq_model).unlink()
483
- quantize_path = dst_name / "model.pth"
484
-
485
- else:
486
- raise ValueError(
487
- f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
488
- )
489
-
490
- print(f"Writing quantized weights to {quantize_path}")
491
- quantize_path.unlink(missing_ok=True) # remove existing file if one already there
492
- torch.save(quantized_state_dict, quantize_path)
493
- print(f"Quantization complete took {time.time() - t0:.02f} seconds")
494
-
495
-
496
- if __name__ == "__main__":
497
- quantize()
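To make the handler flow concrete, here is a self-contained toy run of the int8 weight-only path; `quantize()` above applies the same three steps (quantize the state dict, swap the Linear layers, reload) to the full text2semantic model. The toy module is purely illustrative:

```python
# Toy sketch of the int8 weight-only flow on a throwaway module.
import torch
import torch.nn as nn

toy = nn.Sequential(
    nn.Linear(256, 512, bias=False),
    nn.ReLU(),
    nn.Linear(512, 128, bias=False),
)

handler = WeightOnlyInt8QuantHandler(toy)
int8_state_dict = handler.create_quantized_state_dict()  # int8 weights + per-channel scales
toy = handler.convert_for_runtime()  # swaps nn.Linear for WeightOnlyInt8Linear
toy.load_state_dict(int8_state_dict)

with torch.no_grad():
    out = toy(torch.randn(4, 256))
print(out.shape)  # torch.Size([4, 128])
```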
 
 
tools/llama/rebuild_tokenizer.py DELETED
@@ -1,57 +0,0 @@
1
- from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
2
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
3
-
4
- # Initialize a tokenizer
5
- tokenizer = Tokenizer(models.BPE())
6
-
7
- # Customize pre-tokenization and decoding
8
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
9
- tokenizer.decoder = decoders.ByteLevel()
10
- tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
11
-
12
- # Don't train the tokenizer
13
- trainer = trainers.BpeTrainer(
14
- vocab_size=0,
15
- min_frequency=2,
16
- initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
17
- special_tokens=[
18
- "<|begin_of_sequence|>",
19
- "<|end_of_sequence|>",
20
- "<|im_start|>",
21
- "<|im_sep|>", # system, user, assistant, etc.
22
- "<|im_end|>",
23
- "<|semantic|>", # audio features
24
- "<|pad|>",
25
- ],
26
- )
27
-
28
- # <|im_start|>user<|im_sep|>...<|im_end|>
29
- # <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
30
- tokenizer.train_from_iterator([], trainer=trainer)
31
-
32
- print(len(tokenizer.get_vocab()))
33
- x = tokenizer.encode(
34
- "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
35
- ).ids
36
- print(x, len(x))
37
- print(tokenizer.decode(x, skip_special_tokens=True))
38
-
39
-
40
- tokenizer = PreTrainedTokenizerFast(
41
- tokenizer_object=tokenizer,
42
- pad_token="<|pad|>",
43
- bos_token="<|begin_of_sequence|>",
44
- eos_token="<|end_of_sequence|>",
45
- )
46
-
47
- # Try tokenizing a new sequence
48
- sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
49
- encoded = tokenizer(sequence).input_ids
50
-
51
- print("Test encoding....")
52
- print(f"\tSentence: {sequence}")
53
- print(f"\tEncoded: {encoded}")
54
- print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
55
- print(f"\tDecoded: {tokenizer.decode(encoded)}")
56
-
57
- tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
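The pushed tokenizer can then be pulled back down with `transformers`; since the repo is created as private above, this sketch assumes an authenticated Hugging Face login:

```python
# Sketch: reload the tokenizer that rebuild_tokenizer.py pushes to the Hub.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
ids = tok.encode("<|im_start|>user<|im_sep|>Hello<|im_end|>")
print(ids)
print(tok.decode(ids))
```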
 
 
tools/msgpack_api.py DELETED
@@ -1,34 +0,0 @@
1
- import httpx
2
- import ormsgpack
3
-
4
- from tools.commons import ServeReferenceAudio, ServeTTSRequest
5
-
6
- # priority: ref_id > references
7
- request = ServeTTSRequest(
8
- text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
9
- # reference_id="114514",
10
- references=[
11
- ServeReferenceAudio(
12
- audio=open("lengyue.wav", "rb").read(),
13
- text=open("lengyue.lab", "r", encoding="utf-8").read(),
14
- )
15
- ],
16
- streaming=True,
17
- )
18
-
19
- with (
20
- httpx.Client() as client,
21
- open("hello.wav", "wb") as f,
22
- ):
23
- with client.stream(
24
- "POST",
25
- "http://127.0.0.1:8080/v1/tts",
26
- content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
27
- headers={
28
- "authorization": "Bearer YOUR_API_KEY",
29
- "content-type": "application/msgpack",
30
- },
31
- timeout=None,
32
- ) as response:
33
- for chunk in response.iter_bytes():
34
- f.write(chunk)
 
 
tools/post_api.py DELETED
@@ -1,205 +0,0 @@
1
- import argparse
2
- import base64
3
- import wave
4
-
5
- import ormsgpack
6
- import pyaudio
7
- import requests
8
- from pydub import AudioSegment
9
- from pydub.playback import play
10
-
11
- from tools.commons import ServeReferenceAudio, ServeTTSRequest
12
- from tools.file import audio_to_bytes, read_ref_text
13
-
14
-
15
- def parse_args():
16
-
17
- parser = argparse.ArgumentParser(
18
- description="Send a WAV file and text to a server and receive synthesized audio."
19
- )
20
-
21
- parser.add_argument(
22
- "--url",
23
- "-u",
24
- type=str,
25
- default="http://127.0.0.1:8080/v1/tts",
26
- help="URL of the server",
27
- )
28
- parser.add_argument(
29
- "--text", "-t", type=str, required=True, help="Text to be synthesized"
30
- )
31
- parser.add_argument(
32
- "--reference_id",
33
- "-id",
34
- type=str,
35
- default=None,
36
- help="ID of the reference model o be used for the speech",
37
- )
38
- parser.add_argument(
39
- "--reference_audio",
40
- "-ra",
41
- type=str,
42
- nargs="+",
43
- default=None,
44
- help="Path to the WAV file",
45
- )
46
- parser.add_argument(
47
- "--reference_text",
48
- "-rt",
49
- type=str,
50
- nargs="+",
51
- default=None,
52
- help="Reference text for voice synthesis",
53
- )
54
- parser.add_argument(
55
- "--output",
56
- "-o",
57
- type=str,
58
- default="generated_audio",
59
- help="Output audio file name",
60
- )
61
- parser.add_argument(
62
- "--play",
63
- type=bool,
64
- default=True,
65
- help="Whether to play audio after receiving data",
66
- )
67
- parser.add_argument("--normalize", type=bool, default=True)
68
- parser.add_argument(
69
- "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
70
- )
71
- parser.add_argument("--mp3_bitrate", type=int, default=64)
72
- parser.add_argument("--opus_bitrate", type=int, default=-1000)
73
- parser.add_argument("--latency", type=str, default="normal", help="延迟选项")
74
- parser.add_argument(
75
- "--max_new_tokens",
76
- type=int,
77
- default=1024,
78
- help="Maximum new tokens to generate",
79
- )
80
- parser.add_argument(
81
- "--chunk_length", type=int, default=100, help="Chunk length for synthesis"
82
- )
83
- parser.add_argument(
84
- "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
85
- )
86
- parser.add_argument(
87
- "--repetition_penalty",
88
- type=float,
89
- default=1.2,
90
- help="Repetition penalty for synthesis",
91
- )
92
- parser.add_argument(
93
- "--temperature", type=float, default=0.7, help="Temperature for sampling"
94
- )
95
- parser.add_argument(
96
- "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
97
- )
98
- parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
99
- parser.add_argument(
100
- "--streaming", type=bool, default=False, help="Enable streaming response"
101
- )
102
- parser.add_argument(
103
- "--channels", type=int, default=1, help="Number of audio channels"
104
- )
105
- parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
106
-
107
- return parser.parse_args()
108
-
109
-
110
- if __name__ == "__main__":
111
-
112
- args = parse_args()
113
-
114
- idstr: str | None = args.reference_id
115
- # priority: ref_id > [{text, audio},...]
116
- if idstr is None:
117
- ref_audios = args.reference_audio
118
- ref_texts = args.reference_text
119
- if ref_audios is None:
120
- byte_audios = []
121
- else:
122
- byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
123
- if ref_texts is None:
124
- ref_texts = []
125
- else:
126
- ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
127
- else:
128
- byte_audios = []
129
- ref_texts = []
130
- pass # in api.py
131
-
132
- data = {
133
- "text": args.text,
134
- "references": [
135
- ServeReferenceAudio(audio=ref_audio, text=ref_text)
136
- for ref_text, ref_audio in zip(ref_texts, byte_audios)
137
- ],
138
- "reference_id": idstr,
139
- "normalize": args.normalize,
140
- "format": args.format,
141
- "mp3_bitrate": args.mp3_bitrate,
142
- "opus_bitrate": args.opus_bitrate,
143
- "max_new_tokens": args.max_new_tokens,
144
- "chunk_length": args.chunk_length,
145
- "top_p": args.top_p,
146
- "repetition_penalty": args.repetition_penalty,
147
- "temperature": args.temperature,
148
- "speaker": args.speaker,
149
- "emotion": args.emotion,
150
- "streaming": args.streaming,
151
- }
152
-
153
- pydantic_data = ServeTTSRequest(**data)
154
-
155
- response = requests.post(
156
- args.url,
157
- data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
158
- stream=args.streaming,
159
- headers={
160
- "authorization": "Bearer YOUR_API_KEY",
161
- "content-type": "application/msgpack",
162
- },
163
- )
164
-
165
- if response.status_code == 200:
166
- if args.streaming:
167
- p = pyaudio.PyAudio()
168
- audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
169
- stream = p.open(
170
- format=audio_format, channels=args.channels, rate=args.rate, output=True
171
- )
172
-
173
- wf = wave.open(f"{args.output}.wav", "wb")
174
- wf.setnchannels(args.channels)
175
- wf.setsampwidth(p.get_sample_size(audio_format))
176
- wf.setframerate(args.rate)
177
-
178
- stream_stopped_flag = False
179
-
180
- try:
181
- for chunk in response.iter_content(chunk_size=1024):
182
- if chunk:
183
- stream.write(chunk)
184
- wf.writeframesraw(chunk)
185
- else:
186
- if not stream_stopped_flag:
187
- stream.stop_stream()
188
- stream_stopped_flag = True
189
- finally:
190
- stream.close()
191
- p.terminate()
192
- wf.close()
193
- else:
194
- audio_content = response.content
195
- audio_path = f"{args.output}.{args.format}"
196
- with open(audio_path, "wb") as audio_file:
197
- audio_file.write(audio_content)
198
-
199
- audio = AudioSegment.from_file(audio_path, format=args.format)
200
- if args.play:
201
- play(audio)
202
- print(f"Audio has been saved to '{audio_path}'.")
203
- else:
204
- print(f"Request failed with status code {response.status_code}")
205
- print(response.json())
 
 
tools/sensevoice/README.md DELETED
@@ -1,59 +0,0 @@
1
- # FunASR Command Line Interface
2
-
3
- This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
4
-
5
- ## Requirements
6
-
7
- - Python >= 3.10
8
- - PyTorch <= 2.3.1
9
- - ffmpeg, pydub, audio-separator[gpu]
10
-
11
- ## Installation
12
-
13
- Install the required packages:
14
-
15
- ```bash
16
- pip install -e .[stable]
17
- ```
18
-
19
- Make sure you have `ffmpeg` installed and available in your `PATH`.
20
-
21
- ## Usage
22
-
23
- ### Basic Usage
24
-
25
- To run the tool with default settings:
26
-
27
- ```bash
28
- python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
29
- ```
30
-
31
- ## Options
32
-
33
- | Option | Description |
34
- | :-----------------------: | :---------------------------------------------------------------------------: |
35
- | --audio-dir | Directory containing audio or video files. |
36
- | --save-dir | Directory to save processed audio files. |
37
- | --device | Device to use for processing. Options: cuda (default) or cpu. |
38
- | --language | Language of the transcription. Default is auto. |
39
- | --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
40
- | --punc | Enable punctuation prediction. |
41
- | --denoise | Enable noise reduction (vocal separation). |
42
-
43
- ## Example
44
-
45
- To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
46
-
47
- ```bash
48
- python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
49
- ```
50
-
51
- ## Additional Notes
52
-
53
- - The tool supports both audio and video files; videos will be converted to audio automatically.
54
- - If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
55
- - The script will automatically create necessary directories in the `--save-dir`.
56
-
57
- ## Troubleshooting
58
-
59
- If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
 
 
tools/sensevoice/__init__.py DELETED
File without changes
tools/sensevoice/auto_model.py DELETED
@@ -1,573 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- encoding: utf-8 -*-
3
- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
4
- # MIT License (https://opensource.org/licenses/MIT)
5
-
6
- import copy
7
- import json
8
- import logging
9
- import os.path
10
- import random
11
- import re
12
- import string
13
- import time
14
-
15
- import numpy as np
16
- import torch
17
- from funasr.download.download_model_from_hub import download_model
18
- from funasr.download.file import download_from_url
19
- from funasr.register import tables
20
- from funasr.train_utils.load_pretrained_model import load_pretrained_model
21
- from funasr.train_utils.set_all_random_seed import set_all_random_seed
22
- from funasr.utils import export_utils, misc
23
- from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
24
- from funasr.utils.misc import deep_update
25
- from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
26
- from tqdm import tqdm
27
-
28
- from .vad_utils import merge_vad, slice_padding_audio_samples
29
-
30
- try:
31
- from funasr.models.campplus.cluster_backend import ClusterBackend
32
- from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
33
- except:
34
- pass
35
-
36
-
37
- def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
38
- """ """
39
- data_list = []
40
- key_list = []
41
- filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
42
-
43
- chars = string.ascii_letters + string.digits
44
- if isinstance(data_in, str):
45
- if data_in.startswith("http://") or data_in.startswith("https://"): # url
46
- data_in = download_from_url(data_in)
47
-
48
- if isinstance(data_in, str) and os.path.exists(
49
- data_in
50
- ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
51
- _, file_extension = os.path.splitext(data_in)
52
- file_extension = file_extension.lower()
53
- if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
54
- with open(data_in, encoding="utf-8") as fin:
55
- for line in fin:
56
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
57
- if data_in.endswith(
58
- ".jsonl"
59
- ): # file.jsonl: json.dumps({"source": data})
60
- lines = json.loads(line.strip())
61
- data = lines["source"]
62
- key = data["key"] if "key" in data else key
63
- else: # filelist, wav.scp, text.txt: id \t data or data
64
- lines = line.strip().split(maxsplit=1)
65
- data = lines[1] if len(lines) > 1 else lines[0]
66
- key = lines[0] if len(lines) > 1 else key
67
-
68
- data_list.append(data)
69
- key_list.append(key)
70
- else:
71
- if key is None:
72
- # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
73
- key = misc.extract_filename_without_extension(data_in)
74
- data_list = [data_in]
75
- key_list = [key]
76
- elif isinstance(data_in, (list, tuple)):
77
- if data_type is not None and isinstance(
78
- data_type, (list, tuple)
79
- ): # mutiple inputs
80
- data_list_tmp = []
81
- for data_in_i, data_type_i in zip(data_in, data_type):
82
- key_list, data_list_i = prepare_data_iterator(
83
- data_in=data_in_i, data_type=data_type_i
84
- )
85
- data_list_tmp.append(data_list_i)
86
- data_list = []
87
- for item in zip(*data_list_tmp):
88
- data_list.append(item)
89
- else:
90
- # [audio sample point, fbank, text]
91
- data_list = data_in
92
- key_list = []
93
- for data_i in data_in:
94
- if isinstance(data_i, str) and os.path.exists(data_i):
95
- key = misc.extract_filename_without_extension(data_i)
96
- else:
97
- if key is None:
98
- key = "rand_key_" + "".join(
99
- random.choice(chars) for _ in range(13)
100
- )
101
- key_list.append(key)
102
-
103
- else: # raw text; audio sample point, fbank; bytes
104
- if isinstance(data_in, bytes): # audio bytes
105
- data_in = load_bytes(data_in)
106
- if key is None:
107
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
108
- data_list = [data_in]
109
- key_list = [key]
110
-
111
- return key_list, data_list
112
-
113
-
114
- class AutoModel:
115
-
116
- def __init__(self, **kwargs):
117
-
118
- try:
119
- from funasr.utils.version_checker import check_for_update
120
-
121
- print(
122
- "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
123
- )
124
- check_for_update(disable=kwargs.get("disable_update", False))
125
- except:
126
- pass
127
-
128
- log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
129
- logging.basicConfig(level=log_level)
130
-
131
- model, kwargs = self.build_model(**kwargs)
132
-
133
- # if vad_model is not None, build vad model else None
134
- vad_model = kwargs.get("vad_model", None)
135
- vad_kwargs = (
136
- {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
137
- )
138
- if vad_model is not None:
139
- logging.info("Building VAD model.")
140
- vad_kwargs["model"] = vad_model
141
- vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
142
- vad_kwargs["device"] = kwargs["device"]
143
- vad_model, vad_kwargs = self.build_model(**vad_kwargs)
144
-
145
- # if punc_model is not None, build punc model else None
146
- punc_model = kwargs.get("punc_model", None)
147
- punc_kwargs = (
148
- {}
149
- if kwargs.get("punc_kwargs", {}) is None
150
- else kwargs.get("punc_kwargs", {})
151
- )
152
- if punc_model is not None:
153
- logging.info("Building punc model.")
154
- punc_kwargs["model"] = punc_model
155
- punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
156
- punc_kwargs["device"] = kwargs["device"]
157
- punc_model, punc_kwargs = self.build_model(**punc_kwargs)
158
-
159
- # if spk_model is not None, build spk model else None
160
- spk_model = kwargs.get("spk_model", None)
161
- spk_kwargs = (
162
- {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
163
- )
164
- if spk_model is not None:
165
- logging.info("Building SPK model.")
166
- spk_kwargs["model"] = spk_model
167
- spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
168
- spk_kwargs["device"] = kwargs["device"]
169
- spk_model, spk_kwargs = self.build_model(**spk_kwargs)
170
- self.cb_model = ClusterBackend().to(kwargs["device"])
171
- spk_mode = kwargs.get("spk_mode", "punc_segment")
172
- if spk_mode not in ["default", "vad_segment", "punc_segment"]:
173
- logging.error(
174
- "spk_mode should be one of default, vad_segment and punc_segment."
175
- )
176
- self.spk_mode = spk_mode
177
-
178
- self.kwargs = kwargs
179
- self.model = model
180
- self.vad_model = vad_model
181
- self.vad_kwargs = vad_kwargs
182
- self.punc_model = punc_model
183
- self.punc_kwargs = punc_kwargs
184
- self.spk_model = spk_model
185
- self.spk_kwargs = spk_kwargs
186
- self.model_path = kwargs.get("model_path")
187
-
188
- @staticmethod
189
- def build_model(**kwargs):
190
- assert "model" in kwargs
191
- if "model_conf" not in kwargs:
192
- logging.info(
193
- "download models from model hub: {}".format(kwargs.get("hub", "ms"))
194
- )
195
- kwargs = download_model(**kwargs)
196
-
197
- set_all_random_seed(kwargs.get("seed", 0))
198
-
199
- device = kwargs.get("device", "cuda")
200
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
201
- device = "cpu"
202
- kwargs["batch_size"] = 1
203
- kwargs["device"] = device
204
-
205
- torch.set_num_threads(kwargs.get("ncpu", 4))
206
-
207
- # build tokenizer
208
- tokenizer = kwargs.get("tokenizer", None)
209
- if tokenizer is not None:
210
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
211
- tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
212
- kwargs["token_list"] = (
213
- tokenizer.token_list if hasattr(tokenizer, "token_list") else None
214
- )
215
- kwargs["token_list"] = (
216
- tokenizer.get_vocab()
217
- if hasattr(tokenizer, "get_vocab")
218
- else kwargs["token_list"]
219
- )
220
- vocab_size = (
221
- len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
222
- )
223
- if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
224
- vocab_size = tokenizer.get_vocab_size()
225
- else:
226
- vocab_size = -1
227
- kwargs["tokenizer"] = tokenizer
228
-
229
- # build frontend
230
- frontend = kwargs.get("frontend", None)
231
- kwargs["input_size"] = None
232
- if frontend is not None:
233
- frontend_class = tables.frontend_classes.get(frontend)
234
- frontend = frontend_class(**kwargs.get("frontend_conf", {}))
235
- kwargs["input_size"] = (
236
- frontend.output_size() if hasattr(frontend, "output_size") else None
237
- )
238
- kwargs["frontend"] = frontend
239
- # build model
240
- model_class = tables.model_classes.get(kwargs["model"])
241
- assert model_class is not None, f'{kwargs["model"]} is not registered'
242
- model_conf = {}
243
- deep_update(model_conf, kwargs.get("model_conf", {}))
244
- deep_update(model_conf, kwargs)
245
- model = model_class(**model_conf, vocab_size=vocab_size)
246
-
247
- # init_param
248
- init_param = kwargs.get("init_param", None)
249
- if init_param is not None:
250
- if os.path.exists(init_param):
251
- logging.info(f"Loading pretrained params from {init_param}")
252
- load_pretrained_model(
253
- model=model,
254
- path=init_param,
255
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
256
- oss_bucket=kwargs.get("oss_bucket", None),
257
- scope_map=kwargs.get("scope_map", []),
258
- excludes=kwargs.get("excludes", None),
259
- )
260
- else:
261
- print(f"error, init_param does not exist!: {init_param}")
262
-
263
- # fp16
264
- if kwargs.get("fp16", False):
265
- model.to(torch.float16)
266
- elif kwargs.get("bf16", False):
267
- model.to(torch.bfloat16)
268
- model.to(device)
269
-
270
- if not kwargs.get("disable_log", True):
271
- tables.print()
272
-
273
- return model, kwargs
274
-
275
- def __call__(self, *args, **cfg):
276
- kwargs = self.kwargs
277
- deep_update(kwargs, cfg)
278
- res = self.model(*args, kwargs)
279
- return res
280
-
281
- def generate(self, input, input_len=None, **cfg):
282
- if self.vad_model is None:
283
- return self.inference(input, input_len=input_len, **cfg)
284
-
285
- else:
286
- return self.inference_with_vad(input, input_len=input_len, **cfg)
287
-
288
- def inference(
289
- self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
290
- ):
291
- kwargs = self.kwargs if kwargs is None else kwargs
292
- if "cache" in kwargs:
293
- kwargs.pop("cache")
294
- deep_update(kwargs, cfg)
295
- model = self.model if model is None else model
296
- model.eval()
297
-
298
- batch_size = kwargs.get("batch_size", 1)
299
- # if kwargs.get("device", "cpu") == "cpu":
300
- # batch_size = 1
301
-
302
- key_list, data_list = prepare_data_iterator(
303
- input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
304
- )
305
-
306
- speed_stats = {}
307
- asr_result_list = []
308
- num_samples = len(data_list)
309
- disable_pbar = self.kwargs.get("disable_pbar", False)
310
- pbar = (
311
- tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
312
- if not disable_pbar
313
- else None
314
- )
315
- time_speech_total = 0.0
316
- time_escape_total = 0.0
317
- for beg_idx in range(0, num_samples, batch_size):
318
- end_idx = min(num_samples, beg_idx + batch_size)
319
- data_batch = data_list[beg_idx:end_idx]
320
- key_batch = key_list[beg_idx:end_idx]
321
- batch = {"data_in": data_batch, "key": key_batch}
322
-
323
- if (end_idx - beg_idx) == 1 and kwargs.get(
324
- "data_type", None
325
- ) == "fbank": # fbank
326
- batch["data_in"] = data_batch[0]
327
- batch["data_lengths"] = input_len
328
-
329
- time1 = time.perf_counter()
330
- with torch.no_grad():
331
- res = model.inference(**batch, **kwargs)
332
- if isinstance(res, (list, tuple)):
333
- results = res[0] if len(res) > 0 else [{"text": ""}]
334
- meta_data = res[1] if len(res) > 1 else {}
335
- time2 = time.perf_counter()
336
-
337
- asr_result_list.extend(results)
338
-
339
- # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
340
- batch_data_time = meta_data.get("batch_data_time", -1)
341
- time_escape = time2 - time1
342
- speed_stats["load_data"] = meta_data.get("load_data", 0.0)
343
- speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
344
- speed_stats["forward"] = f"{time_escape:0.3f}"
345
- speed_stats["batch_size"] = f"{len(results)}"
346
- speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
347
- description = f"{speed_stats}, "
348
- if pbar:
349
- pbar.update(end_idx - beg_idx)
350
- pbar.set_description(description)
351
- time_speech_total += batch_data_time
352
- time_escape_total += time_escape
353
-
354
- if pbar:
355
- # pbar.update(1)
356
- pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
357
- torch.cuda.empty_cache()
358
- return asr_result_list
359
-
360
- def vad(self, input, input_len=None, **cfg):
361
- kwargs = self.kwargs
362
- # step.1: compute the vad model
363
- deep_update(self.vad_kwargs, cfg)
364
- beg_vad = time.time()
365
- res = self.inference(
366
- input,
367
- input_len=input_len,
368
- model=self.vad_model,
369
- kwargs=self.vad_kwargs,
370
- **cfg,
371
- )
372
- end_vad = time.time()
373
- # FIX(gcf): concat the vad clips for the SenseVoice model for better AED
374
- if cfg.get("merge_vad", False):
375
- for i in range(len(res)):
376
- res[i]["value"] = merge_vad(
377
- res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
378
- )
379
- elapsed = end_vad - beg_vad
380
- return elapsed, res
381
-
382
- def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
383
-
384
- kwargs = self.kwargs
385
-
386
- # step.2 compute asr model
387
- model = self.model
388
- deep_update(kwargs, cfg)
389
- batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
390
- batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
391
- kwargs["batch_size"] = batch_size
392
-
393
- key_list, data_list = prepare_data_iterator(
394
- input, input_len=input_len, data_type=kwargs.get("data_type", None)
395
- )
396
- results_ret_list = []
397
- time_speech_total_all_samples = 1e-6
398
-
399
- beg_total = time.time()
400
- pbar_total = (
401
- tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
402
- if not kwargs.get("disable_pbar", False)
403
- else None
404
- )
405
-
406
- for i in range(len(vad_res)):
407
- key = vad_res[i]["key"]
408
- vadsegments = vad_res[i]["value"]
409
- input_i = data_list[i]
410
- fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
411
- speech = load_audio_text_image_video(
412
- input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
413
- )
414
- speech_lengths = len(speech)
415
- n = len(vadsegments)
416
- data_with_index = [(vadsegments[i], i) for i in range(n)]
417
- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
418
- results_sorted = []
419
-
420
- if not len(sorted_data):
421
- results_ret_list.append({"key": key, "text": "", "timestamp": []})
422
- logging.info("decoding, utt: {}, empty speech".format(key))
423
- continue
424
-
425
- if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
426
- batch_size = max(
427
- batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
428
- )
429
-
430
- if kwargs["device"] == "cpu":
431
- batch_size = 0
432
-
433
- beg_idx = 0
434
- beg_asr_total = time.time()
435
- time_speech_total_per_sample = speech_lengths / 16000
436
- time_speech_total_all_samples += time_speech_total_per_sample
437
-
438
- # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
439
-
440
- all_segments = []
441
- max_len_in_batch = 0
442
- end_idx = 1
443
-
444
- for j, _ in enumerate(range(0, n)):
445
- # pbar_sample.update(1)
446
- sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
447
- potential_batch_length = max(max_len_in_batch, sample_length) * (
448
- j + 1 - beg_idx
449
- )
450
- # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
451
- if (
452
- j < n - 1
453
- and sample_length < batch_size_threshold_ms
454
- and potential_batch_length < batch_size
455
- ):
456
- max_len_in_batch = max(max_len_in_batch, sample_length)
457
- end_idx += 1
458
- continue
459
-
460
- speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
461
- speech, speech_lengths, sorted_data[beg_idx:end_idx]
462
- )
463
- results = self.inference(
464
- speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
465
- )
466
-
467
- for _b in range(len(speech_j)):
468
- results[_b]["interval"] = intervals[_b]
469
-
470
- if self.spk_model is not None:
471
- # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
472
- for _b in range(len(speech_j)):
473
- vad_segments = [
474
- [
475
- sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
476
- sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
477
- np.array(speech_j[_b]),
478
- ]
479
- ]
480
- segments = sv_chunk(vad_segments)
481
- all_segments.extend(segments)
482
- speech_b = [i[2] for i in segments]
483
- spk_res = self.inference(
484
- speech_b,
485
- input_len=None,
486
- model=self.spk_model,
487
- kwargs=kwargs,
488
- **cfg,
489
- )
490
- results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
491
-
492
- beg_idx = end_idx
493
- end_idx += 1
494
- max_len_in_batch = sample_length
495
- if len(results) < 1:
496
- continue
497
- results_sorted.extend(results)
498
-
499
- # end_asr_total = time.time()
500
- # time_escape_total_per_sample = end_asr_total - beg_asr_total
501
- # pbar_sample.update(1)
502
- # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
503
- # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
504
- # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
505
-
506
- restored_data = [0] * n
507
- for j in range(n):
508
- index = sorted_data[j][1]
509
- cur = results_sorted[j]
510
- pattern = r"<\|([^|]+)\|>"
511
- emotion_string = re.findall(pattern, cur["text"])
512
- cur["text"] = re.sub(pattern, "", cur["text"])
513
- cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
514
- if self.punc_model is not None and len(cur["text"].strip()) > 0:
515
- deep_update(self.punc_kwargs, cfg)
516
- punc_res = self.inference(
517
- cur["text"],
518
- model=self.punc_model,
519
- kwargs=self.punc_kwargs,
520
- **cfg,
521
- )
522
- cur["text"] = punc_res[0]["text"]
523
-
524
- restored_data[index] = cur
525
-
526
- end_asr_total = time.time()
527
- time_escape_total_per_sample = end_asr_total - beg_asr_total
528
- if pbar_total:
529
- pbar_total.update(1)
530
- pbar_total.set_description(
531
- f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
532
- f"time_speech: {time_speech_total_per_sample: 0.3f}, "
533
- f"time_escape: {time_escape_total_per_sample:0.3f}"
534
- )
535
-
536
- # end_total = time.time()
537
- # time_escape_total_all_samples = end_total - beg_total
538
- # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
539
- # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
540
- # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
541
- return restored_data
542
-
543
- def export(self, input=None, **cfg):
544
- """
545
-
546
- :param input:
547
- :param type:
548
- :param quantize:
549
- :param fallback_num:
550
- :param calib_num:
551
- :param opset_version:
552
- :param cfg:
553
- :return:
554
- """
555
-
556
- device = cfg.get("device", "cpu")
557
- model = self.model.to(device=device)
558
- kwargs = self.kwargs
559
- deep_update(kwargs, cfg)
560
- kwargs["device"] = device
561
- del kwargs["model"]
562
- model.eval()
563
-
564
- type = kwargs.get("type", "onnx")
565
-
566
- key_list, data_list = prepare_data_iterator(
567
- input, input_len=None, data_type=kwargs.get("data_type", None), key=None
568
- )
569
-
570
- with torch.no_grad():
571
- export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
572
-
573
- return export_dir
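A rough usage sketch for this trimmed-down `AutoModel`; the hub model id and the audio path are assumptions (in this repo the class is normally driven by `tools/sensevoice/fun_asr.py` rather than called directly):

```python
# Sketch: plain batched inference without a VAD/punc/spk model attached.
from tools.sensevoice.auto_model import AutoModel

asr = AutoModel(
    model="iic/SenseVoiceSmall",  # assumed FunASR hub id
    device="cuda",
    disable_update=True,
    disable_pbar=True,
)

# With no vad_model configured, generate() falls through to self.inference().
results = asr.generate("path/to/audio.wav", language="auto")
for item in results:
    print(item["key"], item["text"])
```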
 
 
tools/sensevoice/fun_asr.py DELETED
@@ -1,332 +0,0 @@
1
- import gc
2
- import os
3
- import re
4
-
5
- from audio_separator.separator import Separator
6
-
7
- os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
8
- os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
9
- import json
10
- import subprocess
11
- from pathlib import Path
12
-
13
- import click
14
- import torch
15
- from loguru import logger
16
- from pydub import AudioSegment
17
- from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
18
- from tqdm import tqdm
19
-
20
- from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
21
- from tools.sensevoice.auto_model import AutoModel
22
-
23
-
24
- def uvr5_cli(
25
- audio_dir: Path,
26
- output_folder: Path,
27
- audio_files: list[Path] | None = None,
28
- output_format: str = "flac",
29
- model: str = "BS-Roformer-Viperx-1297.ckpt",
30
- ):
31
- # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
32
- sepr = Separator(
33
- model_file_dir=os.environ["UVR5_CACHE"],
34
- output_dir=output_folder,
35
- output_format=output_format,
36
- )
37
- dictmodel = {
38
- "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
39
- "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
40
- "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
41
- "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
42
- }
43
- roformer_model = dictmodel[model]
44
- sepr.load_model(roformer_model)
45
- if audio_files is None:
46
- audio_files = list_files(
47
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
48
- )
49
- total_files = len(audio_files)
50
-
51
- print(f"{total_files} audio files found")
52
-
53
- res = []
54
- for audio in tqdm(audio_files, desc="Denoising: "):
55
- file_path = str(audio_dir / audio)
56
- sep_out = sepr.separate(file_path)
57
- if isinstance(sep_out, str):
58
- res.append(sep_out)
59
- elif isinstance(sep_out, list):
60
- res.extend(sep_out)
61
- del sepr
62
- gc.collect()
63
- if torch.cuda.is_available():
64
- torch.cuda.empty_cache()
65
-
66
- return res, roformer_model
67
-
68
-
69
- def get_sample_rate(media_path: Path):
70
- result = subprocess.run(
71
- [
72
- "ffprobe",
73
- "-v",
74
- "quiet",
75
- "-print_format",
76
- "json",
77
- "-show_streams",
78
- str(media_path),
79
- ],
80
- capture_output=True,
81
- text=True,
82
- check=True,
83
- )
84
- media_info = json.loads(result.stdout)
85
- for stream in media_info.get("streams", []):
86
- if stream.get("codec_type") == "audio":
87
- return stream.get("sample_rate")
88
- return "44100" # Default sample rate if not found
89
-
90
-
91
- def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
92
- sr = get_sample_rate(src_path)
93
- out_path.parent.mkdir(parents=True, exist_ok=True)
94
- if src_path.resolve() == out_path.resolve():
95
- output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
96
- else:
97
- output = str(out_path)
98
- subprocess.run(
99
- [
100
- "ffmpeg",
101
- "-loglevel",
102
- "error",
103
- "-i",
104
- str(src_path),
105
- "-acodec",
106
- "pcm_s16le" if out_fmt == "wav" else "flac",
107
- "-ar",
108
- sr,
109
- "-ac",
110
- "1",
111
- "-y",
112
- output,
113
- ],
114
- check=True,
115
- )
116
- return out_path
117
-
118
-
119
- def convert_video_to_audio(video_path: Path, audio_dir: Path):
120
- cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
121
- vocals = [
122
- p
123
- for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
124
- if p.suffix in AUDIO_EXTENSIONS
125
- ]
126
- if len(vocals) > 0:
127
- return vocals[0]
128
- audio_path = cur_dir / f"{video_path.stem}.wav"
129
- convert_to_mono(video_path, audio_path)
130
- return audio_path
131
-
132
-
133
- @click.command()
134
- @click.option("--audio-dir", required=True, help="Directory containing audio files")
135
- @click.option(
136
- "--save-dir", required=True, help="Directory to save processed audio files"
137
- )
138
- @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
139
- @click.option("--language", default="auto", help="Language of the transcription")
140
- @click.option(
141
- "--max_single_segment_time",
142
- default=20000,
143
- type=int,
144
- help="Maximum of Output single audio duration(ms)",
145
- )
146
- @click.option("--fsmn-vad/--silero-vad", default=False)
147
- @click.option("--punc/--no-punc", default=False)
148
- @click.option("--denoise/--no-denoise", default=False)
149
- @click.option("--save_emo/--no_save_emo", default=False)
150
- def main(
151
- audio_dir: str,
152
- save_dir: str,
153
- device: str,
154
- language: str,
155
- max_single_segment_time: int,
156
- fsmn_vad: bool,
157
- punc: bool,
158
- denoise: bool,
159
- save_emo: bool,
160
- ):
161
-
162
- audios_path = Path(audio_dir)
163
- save_path = Path(save_dir)
164
- save_path.mkdir(parents=True, exist_ok=True)
165
-
166
- video_files = list_files(
167
- path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
168
- )
169
- v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
170
-
171
- if denoise:
172
- VOCAL = "_(Vocals)"
173
- original_files = [
174
- p
175
- for p in audios_path.glob("**/*")
176
- if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
177
- ]
178
-
179
- _, cur_model = uvr5_cli(
180
- audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
181
- )
182
- need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
183
- need_remove.extend(original_files)
184
- for _ in need_remove:
185
- _.unlink()
186
- vocal_files = [
187
- p
188
- for p in audios_path.glob("**/*")
189
- if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
190
- ]
191
- for f in vocal_files:
192
- fn, ext = f.stem, f.suffix
193
-
194
- v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
195
- if v_pos != -1:
196
- new_fn = fn[: v_pos + len(VOCAL)]
197
- new_f = f.with_name(new_fn + ext)
198
- f = f.rename(new_f)
199
- convert_to_mono(f, f, "flac")
200
- f.unlink()
201
-
202
- audio_files = list_files(
203
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
204
- )
205
-
206
- logger.info("Loading / Downloading Funasr model...")
207
-
208
- model_dir = "iic/SenseVoiceSmall"
209
-
210
- vad_model = "fsmn-vad" if fsmn_vad else None
211
- vad_kwargs = {"max_single_segment_time": max_single_segment_time}
212
- punc_model = "ct-punc" if punc else None
213
-
214
- manager = AutoModel(
215
- model=model_dir,
216
- trust_remote_code=False,
217
- vad_model=vad_model,
218
- vad_kwargs=vad_kwargs,
219
- punc_model=punc_model,
220
- device=device,
221
- )
222
-
223
- if not fsmn_vad and vad_model is None:
224
- vad_model = load_silero_vad()
225
-
226
- logger.info("Model loaded.")
227
-
228
- pattern = re.compile(r"_\d{3}\.")
229
-
230
- for file_path in tqdm(audio_files, desc="Processing audio file"):
231
-
232
- if pattern.search(file_path.name):
233
- # logger.info(f"Skipping {file_path} as it has already been processed.")
234
- continue
235
-
236
- file_stem = file_path.stem
237
- file_suffix = file_path.suffix
238
-
239
- rel_path = Path(file_path).relative_to(audio_dir)
240
- (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
241
-
242
- audio = AudioSegment.from_file(file_path)
243
-
244
- cfg = dict(
245
- cache={},
246
- language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
247
- use_itn=False,
248
- batch_size_s=60,
249
- )
250
-
251
- if fsmn_vad:
252
- elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
253
- else:
254
- wav = read_audio(
255
- str(file_path)
256
- ) # backend (sox, soundfile, or ffmpeg) required!
257
- audio_key = file_path.stem
258
- audio_val = []
259
- speech_timestamps = get_speech_timestamps(
260
- wav,
261
- vad_model,
262
- max_speech_duration_s=max_single_segment_time // 1000,
263
- return_seconds=True,
264
- )
265
-
266
- audio_val = [
267
- [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
268
- for timestamp in speech_timestamps
269
- ]
270
- vad_res = []
271
- vad_res.append(dict(key=audio_key, value=audio_val))
272
-
273
- res = manager.inference_with_vadres(
274
- input=str(file_path), vad_res=vad_res, **cfg
275
- )
276
-
277
- for i, info in enumerate(res):
278
- [start_ms, end_ms] = info["interval"]
279
- text = info["text"]
280
- emo = info["emo"]
281
- sliced_audio = audio[start_ms:end_ms]
282
- audio_save_path = (
283
- save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
284
- )
285
- sliced_audio.export(audio_save_path, format=file_suffix[1:])
286
- print(f"Exported {audio_save_path}: {text}")
287
-
288
- transcript_save_path = (
289
- save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
290
- )
291
- with open(
292
- transcript_save_path,
293
- "w",
294
- encoding="utf-8",
295
- ) as f:
296
- f.write(text)
297
-
298
- if save_emo:
299
- emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
300
- with open(
301
- emo_save_path,
302
- "w",
303
- encoding="utf-8",
304
- ) as f:
305
- f.write(emo)
306
-
307
- if audios_path.resolve() == save_path.resolve():
308
- file_path.unlink()
309
-
310
-
311
- if __name__ == "__main__":
312
- main()
313
- exit(0)
314
- from funasr.utils.postprocess_utils import rich_transcription_postprocess
315
-
316
- # Load the audio file
317
- audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
318
- model_dir = "iic/SenseVoiceSmall"
319
- m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
320
- m.eval()
321
-
322
- res = m.inference(
323
- data_in=f"{kwargs['model_path']}/example/zh.mp3",
324
- language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
325
- use_itn=False,
326
- ban_emo_unk=False,
327
- **kwargs,
328
- )
329
-
330
- print(res)
331
- text = rich_transcription_postprocess(res[0][0]["text"])
332
- print(text)
 
tools/sensevoice/vad_utils.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
- from torch.nn.utils.rnn import pad_sequence
3
-
4
-
5
- def slice_padding_fbank(speech, speech_lengths, vad_segments):
6
- speech_list = []
7
- speech_lengths_list = []
8
- for i, segment in enumerate(vad_segments):
9
-
10
- bed_idx = int(segment[0][0] * 16)
11
- end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
12
- speech_i = speech[0, bed_idx:end_idx]
13
- speech_lengths_i = end_idx - bed_idx
14
- speech_list.append(speech_i)
15
- speech_lengths_list.append(speech_lengths_i)
16
- feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
17
- speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
18
- return feats_pad, speech_lengths_pad
19
-
20
-
21
- def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
22
- speech_list = []
23
- speech_lengths_list = []
24
- intervals = []
25
- for i, segment in enumerate(vad_segments):
26
- bed_idx = int(segment[0][0] * 16)
27
- end_idx = min(int(segment[0][1] * 16), speech_lengths)
28
- speech_i = speech[bed_idx:end_idx]
29
- speech_lengths_i = end_idx - bed_idx
30
- speech_list.append(speech_i)
31
- speech_lengths_list.append(speech_lengths_i)
32
- intervals.append([bed_idx // 16, end_idx // 16])
33
-
34
- return speech_list, speech_lengths_list, intervals
35
-
36
-
37
- def merge_vad(vad_result, max_length=15000, min_length=0):
38
- new_result = []
39
- if len(vad_result) <= 1:
40
- return vad_result
41
- time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
42
- time_step = sorted(list(set(time_step)))
43
- if len(time_step) == 0:
44
- return []
45
- bg = 0
46
- for i in range(len(time_step) - 1):
47
- time = time_step[i]
48
- if time_step[i + 1] - bg < max_length:
49
- continue
50
- if time - bg > min_length:
51
- new_result.append([bg, time])
52
- # if time - bg < max_length * 1.5:
53
- # new_result.append([bg, time])
54
- # else:
55
- # split_num = int(time - bg) // max_length + 1
56
- # spl_l = int(time - bg) // split_num
57
- # for j in range(split_num):
58
- # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
59
- bg = time
60
- new_result.append([bg, time_step[-1]])
61
- return new_result
 
tools/smart_pad.py DELETED
@@ -1,60 +0,0 @@
1
- import random
2
- from multiprocessing import Pool
3
- from pathlib import Path
4
-
5
- import click
6
- import librosa
7
- import torch.nn.functional as F
8
- import torchaudio
9
- from tqdm import tqdm
10
-
11
- from tools.file import AUDIO_EXTENSIONS, list_files
12
-
13
- threshold = 10 ** (-50 / 20.0)
14
-
15
-
16
- def process(file):
17
- waveform, sample_rate = torchaudio.load(str(file), backend="sox")
18
- if waveform.size(0) > 1:
19
- waveform = waveform.mean(dim=0, keepdim=True)
20
-
21
- loudness = librosa.feature.rms(
22
- y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
23
- )[0]
24
-
25
- for i in range(len(loudness) - 1, 0, -1):
26
- if loudness[i] > threshold:
27
- break
28
-
29
- end_silent_time = (len(loudness) - i) * 512 / sample_rate
30
-
31
- if end_silent_time <= 0.3:
32
- random_time = random.uniform(0.3, 0.7) - end_silent_time
33
- waveform = F.pad(
34
- waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
35
- )
36
-
37
- for i in range(len(loudness)):
38
- if loudness[i] > threshold:
39
- break
40
-
41
- start_silent_time = i * 512 / sample_rate
42
-
43
- if start_silent_time > 0.02:
44
- waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
45
-
46
- torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
47
-
48
-
49
- @click.command()
50
- @click.argument("source", type=Path)
51
- @click.option("--num-workers", type=int, default=12)
52
- def main(source, num_workers):
53
- files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
54
-
55
- with Pool(num_workers) as p:
56
- list(tqdm(p.imap_unordered(process, files), total=len(files)))
57
-
58
-
59
- if __name__ == "__main__":
60
- main()
 
tools/vqgan/__pycache__/inference.cpython-310.pyc DELETED
Binary file (3.5 kB)
 
tools/vqgan/create_train_split.py DELETED
@@ -1,83 +0,0 @@
1
- import math
2
- from pathlib import Path
3
- from random import Random
4
-
5
- import click
6
- from loguru import logger
7
- from pydub import AudioSegment
8
- from tqdm import tqdm
9
-
10
- from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
11
-
12
-
13
- @click.command()
14
- @click.argument("root", type=click.Path(exists=True, path_type=Path))
15
- @click.option("--val-ratio", type=float, default=None)
16
- @click.option("--val-count", type=int, default=None)
17
- @click.option("--filelist", default=None, type=Path)
18
- @click.option("--min-duration", default=None, type=float)
19
- @click.option("--max-duration", default=None, type=float)
20
- def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
21
- if filelist:
22
- files = [i[0] for i in load_filelist(filelist)]
23
- else:
24
- files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
25
-
26
- if min_duration is None and max_duration is None:
27
- filtered_files = list(map(str, [file.relative_to(root) for file in files]))
28
- else:
29
- filtered_files = []
30
- for file in tqdm(files):
31
- try:
32
- audio = AudioSegment.from_file(str(file))
33
- duration = len(audio) / 1000.0
34
-
35
- if min_duration is not None and duration < min_duration:
36
- logger.info(
37
- f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
38
- )
39
- continue
40
-
41
- if max_duration is not None and duration > max_duration:
42
- logger.info(
43
- f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
44
- )
45
- continue
46
-
47
- filtered_files.append(str(file.relative_to(root)))
48
- except Exception as e:
49
- logger.info(f"Error processing {file}: {e}")
50
-
51
- logger.info(
52
- f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
53
- )
54
-
55
- Random(42).shuffle(filtered_files)
56
-
57
- if val_count is None and val_ratio is None:
58
- logger.info("Validation ratio and count not specified, using min(20%, 100)")
59
- val_size = min(100, math.ceil(len(filtered_files) * 0.2))
60
- elif val_count is not None and val_ratio is not None:
61
- logger.error("Cannot specify both val_count and val_ratio")
62
- return
63
- elif val_count is not None:
64
- if val_count < 1 or val_count > len(filtered_files):
65
- logger.error("val_count must be between 1 and number of files")
66
- return
67
- val_size = val_count
68
- else:
69
- val_size = math.ceil(len(filtered_files) * val_ratio)
70
-
71
- logger.info(f"Using {val_size} files for validation")
72
-
73
- with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
74
- f.write("\n".join(filtered_files[val_size:]))
75
-
76
- with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
77
- f.write("\n".join(filtered_files[:val_size]))
78
-
79
- logger.info("Done")
80
-
81
-
82
- if __name__ == "__main__":
83
- main()
 
tools/vqgan/extract_vq.py DELETED
@@ -1,227 +0,0 @@
1
- import os
2
- import subprocess as sp
3
- import sys
4
- import time
5
- from datetime import timedelta
6
- from functools import lru_cache
7
- from pathlib import Path
8
- from random import Random
9
-
10
- import click
11
- import numpy as np
12
- import torch
13
- import torchaudio
14
- from hydra import compose, initialize
15
- from hydra.utils import instantiate
16
- from lightning import LightningModule
17
- from loguru import logger
18
- from omegaconf import OmegaConf
19
-
20
- from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
21
-
22
- # register eval resolver
23
- OmegaConf.register_new_resolver("eval", eval)
24
- # This file is used to convert the audio files to text files using the Whisper model.
25
- # It's mainly used to generate the training data for the VQ model.
26
-
27
-
28
- RANK = int(os.environ.get("SLURM_PROCID", 0))
29
- WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
30
-
31
- logger_format = (
32
- "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
33
- "<level>{level: <8}</level> | "
34
- "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
35
- "{extra[rank]} - <level>{message}</level>"
36
- )
37
- logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
38
- logger.remove()
39
- logger.add(sys.stderr, format=logger_format)
40
-
41
-
42
- @lru_cache(maxsize=1)
43
- def get_model(
44
- config_name: str = "firefly_gan_vq",
45
- checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
46
- device: str | torch.device = "cuda",
47
- ):
48
- with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
49
- cfg = compose(config_name=config_name)
50
-
51
- model = instantiate(cfg)
52
- state_dict = torch.load(
53
- checkpoint_path,
54
- map_location=device,
55
- )
56
- if "state_dict" in state_dict:
57
- state_dict = state_dict["state_dict"]
58
-
59
- if any("generator" in k for k in state_dict):
60
- state_dict = {
61
- k.replace("generator.", ""): v
62
- for k, v in state_dict.items()
63
- if "generator." in k
64
- }
65
-
66
- model.load_state_dict(state_dict, strict=False)
67
- model.eval()
68
- model.to(device)
69
-
70
- logger.info(f"Loaded model")
71
- return model
72
-
73
-
74
- @torch.inference_mode()
75
- def process_batch(files: list[Path], model) -> float:
76
- wavs = []
77
- audio_lengths = []
78
- new_files = []
79
- max_length = total_time = 0
80
-
81
- for file in files:
82
- try:
83
- wav, sr = torchaudio.load(
84
- str(file), backend="sox" if sys.platform == "linux" else "soundfile"
85
- ) # Need to install libsox-dev
86
- except Exception as e:
87
- logger.error(f"Error reading {file}: {e}")
88
- continue
89
-
90
- if wav.shape[0] > 1:
91
- wav = wav.mean(dim=0, keepdim=True)
92
-
93
- wav = torchaudio.functional.resample(
94
- wav.cuda(), sr, model.spec_transform.sample_rate
95
- )[0]
96
- total_time += len(wav) / model.spec_transform.sample_rate
97
- max_length = max(max_length, len(wav))
98
-
99
- wavs.append(wav)
100
- audio_lengths.append(len(wav))
101
- new_files.append(file)
102
-
103
- files = new_files
104
-
105
- # Pad to max length
106
- for i, wav in enumerate(wavs):
107
- wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
108
-
109
- audios = torch.stack(wavs, dim=0)[:, None]
110
- audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
111
-
112
- # Calculate lengths
113
- indices, feature_lengths = model.encode(audios, audio_lengths)
114
-
115
- # Save to disk
116
- outputs = indices.cpu().numpy()
117
-
118
- for file, length, feature, audio_length in zip(
119
- files, feature_lengths, outputs, audio_lengths
120
- ):
121
- feature = feature[:, :length]
122
-
123
- # (T,)
124
- with open(file.with_suffix(".npy"), "wb") as f:
125
- np.save(f, feature)
126
-
127
- return total_time
128
-
129
-
130
- @click.command()
131
- @click.argument("folder")
132
- @click.option("--num-workers", default=1)
133
- @click.option("--config-name", default="firefly_gan_vq")
134
- @click.option(
135
- "--checkpoint-path",
136
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
137
- )
138
- @click.option("--batch-size", default=64)
139
- @click.option("--filelist", default=None, type=Path)
140
- def main(
141
- folder: str,
142
- num_workers: int,
143
- config_name: str,
144
- checkpoint_path: str,
145
- batch_size: int,
146
- filelist: Path,
147
- ):
148
- if num_workers > 1 and WORLD_SIZE != num_workers:
149
- assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
150
-
151
- logger.info(f"Spawning {num_workers} workers")
152
-
153
- if torch.cuda.is_available():
154
- visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
155
- if visible_devices is None:
156
- visible_devices = list(range(torch.cuda.device_count()))
157
- else:
158
- visible_devices = visible_devices.split(",")
159
- else:
160
- # Set to empty string to avoid using GPU
161
- visible_devices = [""]
162
-
163
- processes = []
164
- for i in range(num_workers):
165
- env = os.environ.copy()
166
- env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
167
- env["SLURM_PROCID"] = str(i)
168
- env["SLURM_NTASKS"] = str(num_workers)
169
-
170
- processes.append(
171
- sp.Popen(
172
- [sys.executable] + sys.argv.copy(),
173
- env=env,
174
- )
175
- )
176
-
177
- for p in processes:
178
- p.wait()
179
-
180
- logger.info(f"All workers finished")
181
- return
182
-
183
- # This is a worker
184
- logger.info(f"Starting worker")
185
- if filelist:
186
- files = [i[0] for i in load_filelist(filelist)]
187
- else:
188
- files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
189
-
190
- print(f"Found {len(files)} files")
191
- files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
192
-
193
- total_files = len(files)
194
- files = files[RANK::WORLD_SIZE]
195
- logger.info(f"Processing {len(files)}/{total_files} files")
196
-
197
- # Batch processing
198
- total_time = 0
199
- begin_time = time.time()
200
- processed_files = 0
201
- model = get_model(config_name, checkpoint_path)
202
-
203
- for n_batch, idx in enumerate(range(0, len(files), batch_size)):
204
- batch = files[idx : idx + batch_size]
205
- batch_time = process_batch(batch, model)
206
-
207
- total_time += batch_time
208
- processed_files += len(batch)
209
-
210
- if (n_batch + 1) % 10 == 0:
211
- eta = (
212
- (time.time() - begin_time)
213
- / processed_files
214
- * (len(files) - processed_files)
215
- )
216
- logger.info(
217
- f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
218
- + f"ETA: {timedelta(seconds=round(eta))}s"
219
- )
220
-
221
- logger.info(
222
- f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
223
- )
224
-
225
-
226
- if __name__ == "__main__":
227
- main()
 
tools/vqgan/inference.py DELETED
@@ -1,122 +0,0 @@
1
- from pathlib import Path
2
-
3
- import click
4
- import hydra
5
- import numpy as np
6
- import soundfile as sf
7
- import torch
8
- import torchaudio
9
- from hydra import compose, initialize
10
- from hydra.utils import instantiate
11
- from loguru import logger
12
- from omegaconf import OmegaConf
13
-
14
- from tools.file import AUDIO_EXTENSIONS
15
-
16
- # register eval resolver
17
- OmegaConf.register_new_resolver("eval", eval)
18
-
19
-
20
- def load_model(config_name, checkpoint_path, device="cuda"):
21
- hydra.core.global_hydra.GlobalHydra.instance().clear()
22
- with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
23
- cfg = compose(config_name=config_name)
24
-
25
- model = instantiate(cfg)
26
- state_dict = torch.load(
27
- checkpoint_path,
28
- map_location=device,
29
- )
30
- if "state_dict" in state_dict:
31
- state_dict = state_dict["state_dict"]
32
-
33
- if any("generator" in k for k in state_dict):
34
- state_dict = {
35
- k.replace("generator.", ""): v
36
- for k, v in state_dict.items()
37
- if "generator." in k
38
- }
39
-
40
- result = model.load_state_dict(state_dict, strict=False)
41
- model.eval()
42
- model.to(device)
43
-
44
- logger.info(f"Loaded model: {result}")
45
- return model
46
-
47
-
48
- @torch.no_grad()
49
- @click.command()
50
- @click.option(
51
- "--input-path",
52
- "-i",
53
- default="test.wav",
54
- type=click.Path(exists=True, path_type=Path),
55
- )
56
- @click.option(
57
- "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
58
- )
59
- @click.option("--config-name", default="firefly_gan_vq")
60
- @click.option(
61
- "--checkpoint-path",
62
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
63
- )
64
- @click.option(
65
- "--device",
66
- "-d",
67
- default="cuda",
68
- )
69
- def main(input_path, output_path, config_name, checkpoint_path, device):
70
- model = load_model(config_name, checkpoint_path, device=device)
71
-
72
- if input_path.suffix in AUDIO_EXTENSIONS:
73
- logger.info(f"Processing in-place reconstruction of {input_path}")
74
-
75
- # Load audio
76
- audio, sr = torchaudio.load(str(input_path))
77
- if audio.shape[0] > 1:
78
- audio = audio.mean(0, keepdim=True)
79
- audio = torchaudio.functional.resample(
80
- audio, sr, model.spec_transform.sample_rate
81
- )
82
-
83
- audios = audio[None].to(device)
84
- logger.info(
85
- f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
86
- )
87
-
88
- # VQ Encoder
89
- audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
90
- indices = model.encode(audios, audio_lengths)[0][0]
91
-
92
- logger.info(f"Generated indices of shape {indices.shape}")
93
-
94
- # Save indices
95
- np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
96
- elif input_path.suffix == ".npy":
97
- logger.info(f"Processing precomputed indices from {input_path}")
98
- indices = np.load(input_path)
99
- indices = torch.from_numpy(indices).to(device).long()
100
- assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
101
- else:
102
- raise ValueError(f"Unknown input type: {input_path}")
103
-
104
- # Restore
105
- feature_lengths = torch.tensor([indices.shape[1]], device=device)
106
- fake_audios, _ = model.decode(
107
- indices=indices[None], feature_lengths=feature_lengths
108
- )
109
- audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
110
-
111
- logger.info(
112
- f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
113
- )
114
-
115
- # Save audio
116
- fake_audio = fake_audios[0, 0].float().cpu().numpy()
117
- sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
118
- logger.info(f"Saved audio to {output_path}")
119
-
120
-
121
- if __name__ == "__main__":
122
- main()
 
tools/webui.py DELETED
@@ -1,485 +0,0 @@
1
- import gc
2
- import html
3
- import io
4
- import os
5
- import queue
6
- import wave
7
- from argparse import ArgumentParser
8
- from functools import partial
9
- from pathlib import Path
10
-
11
- import gradio as gr
12
- import librosa
13
- import numpy as np
14
- import pyrootutils
15
- import torch
16
- from loguru import logger
17
- from transformers import AutoTokenizer
18
-
19
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
-
21
-
22
- from fish_speech.i18n import i18n
23
- from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
24
- from fish_speech.utils import autocast_exclude_mps
25
- from tools.api import decode_vq_tokens, encode_reference
26
- from tools.llama.generate import (
27
- GenerateRequest,
28
- GenerateResponse,
29
- WrappedGenerateResponse,
30
- launch_thread_safe_queue,
31
- )
32
- from tools.vqgan.inference import load_model as load_decoder_model
33
-
34
- # Make einx happy
35
- os.environ["EINX_FILTER_TRACEBACK"] = "false"
36
-
37
-
38
- HEADER_MD = f"""# Fish Speech
39
-
40
- {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
41
-
42
- {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
43
-
44
- {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
45
-
46
- {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
47
- """
48
-
49
- TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
50
- SPACE_IMPORTED = False
51
-
52
-
53
- def build_html_error_message(error):
54
- return f"""
55
- <div style="color: red;
56
- font-weight: bold;">
57
- {html.escape(str(error))}
58
- </div>
59
- """
60
-
61
-
62
- @torch.inference_mode()
63
- def inference(
64
- text,
65
- enable_reference_audio,
66
- reference_audio,
67
- reference_text,
68
- max_new_tokens,
69
- chunk_length,
70
- top_p,
71
- repetition_penalty,
72
- temperature,
73
- streaming=False,
74
- ):
75
- if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
76
- return (
77
- None,
78
- None,
79
- i18n("Text is too long, please keep it under {} characters.").format(
80
- args.max_gradio_length
81
- ),
82
- )
83
-
84
- # Parse reference audio aka prompt
85
- prompt_tokens = encode_reference(
86
- decoder_model=decoder_model,
87
- reference_audio=reference_audio,
88
- enable_reference_audio=enable_reference_audio,
89
- )
90
-
91
- # LLAMA Inference
92
- request = dict(
93
- device=decoder_model.device,
94
- max_new_tokens=max_new_tokens,
95
- text=text,
96
- top_p=top_p,
97
- repetition_penalty=repetition_penalty,
98
- temperature=temperature,
99
- compile=args.compile,
100
- iterative_prompt=chunk_length > 0,
101
- chunk_length=chunk_length,
102
- max_length=2048,
103
- prompt_tokens=prompt_tokens if enable_reference_audio else None,
104
- prompt_text=reference_text if enable_reference_audio else None,
105
- )
106
-
107
- response_queue = queue.Queue()
108
- llama_queue.put(
109
- GenerateRequest(
110
- request=request,
111
- response_queue=response_queue,
112
- )
113
- )
114
-
115
- if streaming:
116
- yield wav_chunk_header(), None, None
117
-
118
- segments = []
119
-
120
- while True:
121
- result: WrappedGenerateResponse = response_queue.get()
122
- if result.status == "error":
123
- yield None, None, build_html_error_message(result.response)
124
- break
125
-
126
- result: GenerateResponse = result.response
127
- if result.action == "next":
128
- break
129
-
130
- with autocast_exclude_mps(
131
- device_type=decoder_model.device.type, dtype=args.precision
132
- ):
133
- fake_audios = decode_vq_tokens(
134
- decoder_model=decoder_model,
135
- codes=result.codes,
136
- )
137
-
138
- fake_audios = fake_audios.float().cpu().numpy()
139
- segments.append(fake_audios)
140
-
141
- if streaming:
142
- yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
143
-
144
- if len(segments) == 0:
145
- return (
146
- None,
147
- None,
148
- build_html_error_message(
149
- i18n("No audio generated, please check the input text.")
150
- ),
151
- )
152
-
153
- # No matter streaming or not, we need to return the final audio
154
- audio = np.concatenate(segments, axis=0)
155
- yield None, (decoder_model.spec_transform.sample_rate, audio), None
156
-
157
- if torch.cuda.is_available():
158
- torch.cuda.empty_cache()
159
- gc.collect()
160
-
161
-
162
- inference_stream = partial(inference, streaming=True)
163
-
164
- n_audios = 4
165
-
166
- global_audio_list = []
167
- global_error_list = []
168
-
169
-
170
- def inference_wrapper(
171
- text,
172
- enable_reference_audio,
173
- reference_audio,
174
- reference_text,
175
- max_new_tokens,
176
- chunk_length,
177
- top_p,
178
- repetition_penalty,
179
- temperature,
180
- batch_infer_num,
181
- ):
182
- audios = []
183
- errors = []
184
-
185
- for _ in range(batch_infer_num):
186
- result = inference(
187
- text,
188
- enable_reference_audio,
189
- reference_audio,
190
- reference_text,
191
- max_new_tokens,
192
- chunk_length,
193
- top_p,
194
- repetition_penalty,
195
- temperature,
196
- )
197
-
198
- _, audio_data, error_message = next(result)
199
-
200
- audios.append(
201
- gr.Audio(value=audio_data if audio_data else None, visible=True),
202
- )
203
- errors.append(
204
- gr.HTML(value=error_message if error_message else None, visible=True),
205
- )
206
-
207
- for _ in range(batch_infer_num, n_audios):
208
- audios.append(
209
- gr.Audio(value=None, visible=False),
210
- )
211
- errors.append(
212
- gr.HTML(value=None, visible=False),
213
- )
214
-
215
- return None, *audios, *errors
216
-
217
-
218
- def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
219
- buffer = io.BytesIO()
220
-
221
- with wave.open(buffer, "wb") as wav_file:
222
- wav_file.setnchannels(channels)
223
- wav_file.setsampwidth(bit_depth // 8)
224
- wav_file.setframerate(sample_rate)
225
-
226
- wav_header_bytes = buffer.getvalue()
227
- buffer.close()
228
- return wav_header_bytes
229
-
230
-
231
- def normalize_text(user_input, use_normalization):
232
- if use_normalization:
233
- return ChnNormedText(raw_text=user_input).normalize()
234
- else:
235
- return user_input
236
-
237
-
238
- asr_model = None
239
-
240
-
241
- def build_app():
242
- with gr.Blocks(theme=gr.themes.Base()) as app:
243
- gr.Markdown(HEADER_MD)
244
-
245
- # Use light theme by default
246
- app.load(
247
- None,
248
- None,
249
- js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
250
- % args.theme,
251
- )
252
-
253
- # Inference
254
- with gr.Row():
255
- with gr.Column(scale=3):
256
- text = gr.Textbox(
257
- label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
258
- )
259
- refined_text = gr.Textbox(
260
- label=i18n("Realtime Transform Text"),
261
- placeholder=i18n(
262
- "Normalization Result Preview (Currently Only Chinese)"
263
- ),
264
- lines=5,
265
- interactive=False,
266
- )
267
-
268
- with gr.Row():
269
- if_refine_text = gr.Checkbox(
270
- label=i18n("Text Normalization"),
271
- value=False,
272
- scale=1,
273
- )
274
-
275
- with gr.Row():
276
- with gr.Tab(label=i18n("Advanced Config")):
277
- chunk_length = gr.Slider(
278
- label=i18n("Iterative Prompt Length, 0 means off"),
279
- minimum=50,
280
- maximum=300,
281
- value=200,
282
- step=8,
283
- )
284
-
285
- max_new_tokens = gr.Slider(
286
- label=i18n("Maximum tokens per batch, 0 means no limit"),
287
- minimum=0,
288
- maximum=2048,
289
- value=1024, # 0 means no limit
290
- step=8,
291
- )
292
-
293
- top_p = gr.Slider(
294
- label="Top-P",
295
- minimum=0.6,
296
- maximum=0.9,
297
- value=0.7,
298
- step=0.01,
299
- )
300
-
301
- repetition_penalty = gr.Slider(
302
- label=i18n("Repetition Penalty"),
303
- minimum=1,
304
- maximum=1.5,
305
- value=1.2,
306
- step=0.01,
307
- )
308
-
309
- temperature = gr.Slider(
310
- label="Temperature",
311
- minimum=0.6,
312
- maximum=0.9,
313
- value=0.7,
314
- step=0.01,
315
- )
316
-
317
- with gr.Tab(label=i18n("Reference Audio")):
318
- gr.Markdown(
319
- i18n(
320
- "5 to 10 seconds of reference audio, useful for specifying speaker."
321
- )
322
- )
323
-
324
- enable_reference_audio = gr.Checkbox(
325
- label=i18n("Enable Reference Audio"),
326
- )
327
- reference_audio = gr.Audio(
328
- label=i18n("Reference Audio"),
329
- type="filepath",
330
- )
331
- with gr.Row():
332
- reference_text = gr.Textbox(
333
- label=i18n("Reference Text"),
334
- lines=1,
335
- placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
336
- value="",
337
- )
338
- with gr.Tab(label=i18n("Batch Inference")):
339
- batch_infer_num = gr.Slider(
340
- label="Batch infer nums",
341
- minimum=1,
342
- maximum=n_audios,
343
- step=1,
344
- value=1,
345
- )
346
-
347
- with gr.Column(scale=3):
348
- for _ in range(n_audios):
349
- with gr.Row():
350
- error = gr.HTML(
351
- label=i18n("Error Message"),
352
- visible=True if _ == 0 else False,
353
- )
354
- global_error_list.append(error)
355
- with gr.Row():
356
- audio = gr.Audio(
357
- label=i18n("Generated Audio"),
358
- type="numpy",
359
- interactive=False,
360
- visible=True if _ == 0 else False,
361
- )
362
- global_audio_list.append(audio)
363
-
364
- with gr.Row():
365
- stream_audio = gr.Audio(
366
- label=i18n("Streaming Audio"),
367
- streaming=True,
368
- autoplay=True,
369
- interactive=False,
370
- show_download_button=True,
371
- )
372
- with gr.Row():
373
- with gr.Column(scale=3):
374
- generate = gr.Button(
375
- value="\U0001F3A7 " + i18n("Generate"), variant="primary"
376
- )
377
- generate_stream = gr.Button(
378
- value="\U0001F3A7 " + i18n("Streaming Generate"),
379
- variant="primary",
380
- )
381
-
382
- text.input(
383
- fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
384
- )
385
-
386
- # # Submit
387
- generate.click(
388
- inference_wrapper,
389
- [
390
- refined_text,
391
- enable_reference_audio,
392
- reference_audio,
393
- reference_text,
394
- max_new_tokens,
395
- chunk_length,
396
- top_p,
397
- repetition_penalty,
398
- temperature,
399
- batch_infer_num,
400
- ],
401
- [stream_audio, *global_audio_list, *global_error_list],
402
- concurrency_limit=1,
403
- )
404
-
405
- generate_stream.click(
406
- inference_stream,
407
- [
408
- refined_text,
409
- enable_reference_audio,
410
- reference_audio,
411
- reference_text,
412
- max_new_tokens,
413
- chunk_length,
414
- top_p,
415
- repetition_penalty,
416
- temperature,
417
- ],
418
- [stream_audio, global_audio_list[0], global_error_list[0]],
419
- concurrency_limit=10,
420
- )
421
- return app
422
-
423
-
424
- def parse_args():
425
- parser = ArgumentParser()
426
- parser.add_argument(
427
- "--llama-checkpoint-path",
428
- type=Path,
429
- default="checkpoints/fish-speech-1.4",
430
- )
431
- parser.add_argument(
432
- "--decoder-checkpoint-path",
433
- type=Path,
434
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
435
- )
436
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
437
- parser.add_argument("--device", type=str, default="cuda")
438
- parser.add_argument("--half", action="store_true")
439
- parser.add_argument("--compile", action="store_true")
440
- parser.add_argument("--max-gradio-length", type=int, default=0)
441
- parser.add_argument("--theme", type=str, default="light")
442
-
443
- return parser.parse_args()
444
-
445
-
446
- if __name__ == "__main__":
447
- args = parse_args()
448
- args.precision = torch.half if args.half else torch.bfloat16
449
-
450
- logger.info("Loading Llama model...")
451
- llama_queue = launch_thread_safe_queue(
452
- checkpoint_path=args.llama_checkpoint_path,
453
- device=args.device,
454
- precision=args.precision,
455
- compile=args.compile,
456
- )
457
- logger.info("Llama model loaded, loading VQ-GAN model...")
458
-
459
- decoder_model = load_decoder_model(
460
- config_name=args.decoder_config_name,
461
- checkpoint_path=args.decoder_checkpoint_path,
462
- device=args.device,
463
- )
464
-
465
- logger.info("Decoder model loaded, warming up...")
466
-
467
- # Dry run to check if the model is loaded correctly and avoid the first-time latency
468
- list(
469
- inference(
470
- text="Hello, world!",
471
- enable_reference_audio=False,
472
- reference_audio=None,
473
- reference_text="",
474
- max_new_tokens=1024,
475
- chunk_length=200,
476
- top_p=0.7,
477
- repetition_penalty=1.2,
478
- temperature=0.7,
479
- )
480
- )
481
-
482
- logger.info("Warming up done, launching the web UI...")
483
-
484
- app = build_app()
485
- app.launch(show_api=True)
 
tools/whisper_asr.py DELETED
@@ -1,176 +0,0 @@
1
- """
2
- Used to transcribe all audio files in one folder into another folder.
3
- e.g.
4
- Directory structure:
5
- --pre_data_root
6
- ----SP_1
7
- ------01.wav
8
- ------02.wav
9
- ------......
10
- ----SP_2
11
- ------01.wav
12
- ------02.wav
13
- ------......
14
- Use
15
- python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
16
- to transcribe the first speaker.
17
-
18
- Use
19
- python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
20
- to transcribe the second speaker.
21
-
22
- Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
23
- """
24
-
25
- import re
26
- from pathlib import Path
27
-
28
- import click
29
- import soundfile as sf
30
- from faster_whisper import WhisperModel
31
- from loguru import logger
32
- from pydub import AudioSegment
33
- from tqdm import tqdm
34
-
35
- from tools.file import AUDIO_EXTENSIONS, list_files
36
-
37
-
38
- @click.command()
39
- @click.option("--model-size", default="large-v3", help="Size of the Whisper model")
40
- @click.option(
41
- "--compute-type",
42
- default="float16",
43
- help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
44
- )
45
- @click.option("--audio-dir", required=True, help="Directory containing audio files")
46
- @click.option(
47
- "--save-dir", required=True, help="Directory to save processed audio files"
48
- )
49
- @click.option(
50
- "--sample-rate",
51
- default=44100,
52
- type=int,
53
- help="Output sample rate, default to input sample rate",
54
- )
55
- @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
56
- @click.option("--language", default="auto", help="Language of the transcription")
57
- @click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
58
- def main(
59
- model_size,
60
- compute_type,
61
- audio_dir,
62
- save_dir,
63
- sample_rate,
64
- device,
65
- language,
66
- initial_prompt,
67
- ):
68
- logger.info("Loading / Downloading Faster Whisper model...")
69
-
70
- model = WhisperModel(
71
- model_size,
72
- device=device,
73
- compute_type=compute_type,
74
- download_root="faster_whisper",
75
- )
76
-
77
- logger.info("Model loaded.")
78
-
79
- save_path = Path(save_dir)
80
- save_path.mkdir(parents=True, exist_ok=True)
81
-
82
- audio_files = list_files(
83
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
84
- )
85
-
86
- for file_path in tqdm(audio_files, desc="Processing audio file"):
87
- file_stem = file_path.stem
88
- file_suffix = file_path.suffix
89
-
90
- rel_path = Path(file_path).relative_to(audio_dir)
91
- (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
92
-
93
- audio = AudioSegment.from_file(file_path)
94
-
95
- segments, info = model.transcribe(
96
- file_path,
97
- beam_size=5,
98
- language=None if language == "auto" else language,
99
- initial_prompt=initial_prompt,
100
- )
101
-
102
- print(
103
- "Detected language '%s' with probability %f"
104
- % (info.language, info.language_probability)
105
- )
106
- print("Total len(ms): ", len(audio))
107
-
108
- whole_text = None
109
- for segment in segments:
110
- id, start, end, text = (
111
- segment.id,
112
- segment.start,
113
- segment.end,
114
- segment.text,
115
- )
116
- print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
117
- if not whole_text:
118
- whole_text = text
119
- else:
120
- whole_text += ", " + text
121
-
122
- whole_text += "."
123
-
124
- audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
125
- audio.export(audio_save_path, format=file_suffix[1:])
126
- print(f"Exported {audio_save_path}")
127
-
128
- transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
129
- with open(
130
- transcript_save_path,
131
- "w",
132
- encoding="utf-8",
133
- ) as f:
134
- f.write(whole_text)
135
-
136
-
137
- if __name__ == "__main__":
138
- main()
139
- exit(0)
140
-
141
- audio = AudioSegment.from_wav(
142
- r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
143
- )
144
-
145
- model_size = "large-v3"
146
-
147
- model = WhisperModel(
148
- model_size,
149
- device="cuda",
150
- compute_type="float16",
151
- download_root="faster_whisper",
152
- )
153
-
154
- segments, info = model.transcribe(
155
- r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
156
- beam_size=5,
157
- )
158
-
159
- print(
160
- "Detected language '%s' with probability %f"
161
- % (info.language, info.language_probability)
162
- )
163
- print("Total len(ms): ", len(audio))
164
-
165
- for i, segment in enumerate(segments):
166
- print(
167
- "Segment %03d [%.2fs -> %.2fs] %s"
168
- % (i, segment.start, segment.end, segment.text)
169
- )
170
- start_ms = int(segment.start * 1000)
171
- end_ms = int(segment.end * 1000)
172
- segment_audio = audio[start_ms:end_ms]
173
- segment_audio.export(f"segment_{i:03d}.wav", format="wav")
174
- print(f"Exported segment_{i:03d}.wav")
175
-
176
- print("All segments have been exported.")