Delete tools
- tools/__pycache__/api.cpython-310.pyc +0 -0
- tools/__pycache__/commons.cpython-310.pyc +0 -0
- tools/__pycache__/file.cpython-310.pyc +0 -0
- tools/__pycache__/webui.cpython-310.pyc +0 -0
- tools/api.py +0 -440
- tools/auto_rerank.py +0 -159
- tools/commons.py +0 -35
- tools/download_models.py +0 -55
- tools/extract_model.py +0 -21
- tools/file.py +0 -125
- tools/llama/__pycache__/generate.cpython-310.pyc +0 -0
- tools/llama/build_dataset.py +0 -169
- tools/llama/eval_in_context.py +0 -171
- tools/llama/generate.py +0 -714
- tools/llama/merge_lora.py +0 -95
- tools/llama/quantize.py +0 -497
- tools/llama/rebuild_tokenizer.py +0 -57
- tools/msgpack_api.py +0 -34
- tools/post_api.py +0 -205
- tools/sensevoice/README.md +0 -59
- tools/sensevoice/__init__.py +0 -0
- tools/sensevoice/auto_model.py +0 -573
- tools/sensevoice/fun_asr.py +0 -332
- tools/sensevoice/vad_utils.py +0 -61
- tools/smart_pad.py +0 -60
- tools/vqgan/__pycache__/inference.cpython-310.pyc +0 -0
- tools/vqgan/create_train_split.py +0 -83
- tools/vqgan/extract_vq.py +0 -227
- tools/vqgan/inference.py +0 -122
- tools/webui.py +0 -485
- tools/whisper_asr.py +0 -176
tools/__pycache__/api.cpython-310.pyc
DELETED
Binary file (10.1 kB)

tools/__pycache__/commons.cpython-310.pyc
DELETED
Binary file (1.49 kB)

tools/__pycache__/file.cpython-310.pyc
DELETED
Binary file (2.99 kB)

tools/__pycache__/webui.cpython-310.pyc
DELETED
Binary file (9.5 kB)
tools/api.py
DELETED
@@ -1,440 +0,0 @@
import base64
import io
import json
import queue
import random
import sys
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Any, Literal, Optional

import numpy as np
import ormsgpack
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from baize.datastructures import ContentType
from kui.asgi import (
    Body,
    FactoryClass,
    HTTPException,
    HttpRequest,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger
from pydantic import BaseModel, Field, conint

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps
from tools.commons import ServeReferenceAudio, ServeTTSRequest
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from tools.vqgan.inference import load_model as load_decoder_model


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


# Define utils for web server
async def http_execption_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


async def other_exception_handler(exc: "Exception"):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )


def load_audio(reference_audio, sr):
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        audio_data = reference_audio
        reference_audio = io.BytesIO(audio_data)

    waveform, original_sr = torchaudio.load(
        reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
    )

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if original_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
        waveform = resampler(waveform)

    audio = waveform.squeeze().numpy()
    return audio


def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )

        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens


def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        )[0].squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")


routes = MultimethodRoutes(base_class=HttpView)


def get_content_type(audio_format):
    if audio_format == "wav":
        return "audio/wav"
    elif audio_format == "flac":
        return "audio/flac"
    elif audio_format == "mp3":
        return "audio/mpeg"
    else:
        return "application/octet-stream"


@torch.inference_mode()
def inference(req: ServeTTSRequest):

    idstr: str | None = req.reference_id
    if idstr is not None:
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )
        prompt_tokens = [
            encode_reference(
                decoder_model=decoder_model,
                reference_audio=audio_to_bytes(str(ref_audio)),
                enable_reference_audio=True,
            )
            for ref_audio in ref_audios
        ]
        prompt_texts = [
            read_ref_text(str(ref_audio.with_suffix(".lab")))
            for ref_audio in ref_audios
        ]

    else:
        # Parse reference audio aka prompt
        refs = req.references
        if refs is None:
            refs = []
        prompt_tokens = [
            encode_reference(
                decoder_model=decoder_model,
                reference_audio=ref.audio,
                enable_reference_audio=True,
            )
            for ref in refs
        ]
        prompt_texts = [ref.text for ref in refs]

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        yield wav_chunk_header()

    segments = []
    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response
            break

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios


async def inference_async(req: ServeTTSRequest):
    for chunk in inference(req):
        yield chunk


async def buffer_to_async_generator(buffer):
    yield buffer


@routes.http.post("/v1/tts")
async def api_invoke_model(
    req: Annotated[ServeTTSRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )


@routes.http.post("/v1/health")
async def api_health():
    """
    Health check
    """

    return JSONResponse({"status": "ok"})


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)

    return parser.parse_args()


# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes


class MsgPackRequest(HttpRequest):
    async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
        if self.content_type == "application/msgpack":
            return ormsgpack.unpackb(await self.body)

        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack"},
        )


app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    factory_class=FactoryClass(http=MsgPackRequest),
    cors_config={},
)


if __name__ == "__main__":

    import uvicorn

    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.2,
                temperature=0.7,
                emotion=None,
                format="wav",
            )
        )
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
    host, port = args.listen.split(":")
    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")

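Note: the server deleted above accepted msgpack-encoded ServeTTSRequest payloads on POST /v1/tts (see MsgPackRequest and api_invoke_model). For context, here is a minimal client sketch, assuming a server is still running at the default 127.0.0.1:8080 and that the requests and ormsgpack packages are installed; the field names mirror the ServeTTSRequest schema from the deleted tools/commons.py, and anything else here is illustrative only.

    import ormsgpack
    import requests

    # Only "text" is required; the other fields fall back to the schema defaults.
    payload = {
        "text": "Hello world.",
        "format": "wav",
        "streaming": False,
        "references": [],      # list of {"audio": <bytes>, "text": <str>} for in-context learning
        "reference_id": None,
    }

    resp = requests.post(
        "http://127.0.0.1:8080/v1/tts",
        data=ormsgpack.packb(payload),
        headers={"Content-Type": "application/msgpack"},
    )
    resp.raise_for_status()

    # Non-streaming responses carry the encoded audio directly in the body.
    with open("audio.wav", "wb") as f:
        f.write(resp.content)
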
tools/auto_rerank.py
DELETED
@@ -1,159 +0,0 @@
import os

os.environ["MODELSCOPE_CACHE"] = ".cache/"

import string
import time
from threading import Lock

import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel

t2s_converter = opencc.OpenCC("t2s")


def load_model(*, device="cuda"):
    model = WhisperModel(
        "medium",
        device=device,
        compute_type="float16",
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return model


@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    resampled_audios = []
    for audio in audios:

        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()

        if audio.dim() > 1:
            audio = audio.squeeze()

        assert audio.dim() == 1
        audio_np = audio.numpy()
        resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
        resampled_audios.append(resampled_audio)

    trans_results = []

    for resampled_audio in resampled_audios:
        segments, info = model.transcribe(
            resampled_audio,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        trans_results.append(list(segments))

    results = []
    for trans_res, audio in zip(trans_results, audios):

        duration = len(audio) / sr * 1000
        huge_gap = False
        max_gap = 0.0

        text = None
        last_tr = None

        for tr in trans_res:
            delta = tr.text.strip()
            if tr.id > 1:
                max_gap = max(tr.start - last_tr.end, max_gap)
                text += delta
            else:
                text = delta

            last_tr = tr
            if max_gap > 3.0:
                huge_gap = True
                break

        sim_text = t2s_converter.convert(text)
        results.append(
            {
                "text": sim_text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results


global_lock = Lock()


def batch_asr(model, audios, sr):
    return batch_asr_internal(model, audios, sr)


def is_chinese(text):
    return True


def calculate_wer(text1, text2, debug=False):
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)

    m, n = len(chars1), len(chars2)

    if m > n:
        chars1, chars2 = chars2, chars1
        m, n = n, m

    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)

    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if chars1[i - 1] == chars2[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev

    edits = prev[m]
    tot = max(len(chars1), len(chars2))
    wer = edits / tot

    if debug:
        print(" gt: ", chars1)
        print(" pred: ", chars2)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)

    return wer


def remove_punctuation(text):
    chinese_punctuation = (
        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
        '‛""„‟…‧﹏'
    )
    all_punctuation = string.punctuation + chinese_punctuation
    translator = str.maketrans("", "", all_punctuation)
    text_without_punctuation = text.translate(translator)
    return text_without_punctuation


if __name__ == "__main__":
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))

    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    print("Time taken:", time.time() - start_time)

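For reference, the deleted calculate_wer above is a character-level edit distance (Levenshtein) normalized by the longer string after punctuation is stripped. A hypothetical call, assuming the function is still importable from a copy of this module, would look like this; the example strings are illustrative only.

    # Hypothetical usage of the deleted calculate_wer / remove_punctuation helpers.
    gt = "今天天气不错。"
    pred = "今天天气很好!"
    cer = calculate_wer(gt, pred, debug=True)  # edits / max(len) over punctuation-stripped strings
    print(f"CER: {cer:.2f}")
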
tools/commons.py
DELETED
@@ -1,35 +0,0 @@
from typing import Annotated, Literal, Optional

from pydantic import BaseModel, Field, conint


class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str


class ServeTTSRequest(BaseModel):
    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    mp3_bitrate: Literal[64, 128, 192] = 128
    # References audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Normalize text for en & zh, this increase stability for numbers
    normalize: bool = True
    mp3_bitrate: Optional[int] = 64
    opus_bitrate: Optional[int] = -1000
    # Balance mode will reduce latency to 300ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
    # not usually used below
    streaming: bool = False
    emotion: Optional[str] = None
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

tools/download_models.py
DELETED
@@ -1,55 +0,0 @@
import os

from huggingface_hub import hf_hub_download


# Download
def check_and_download_files(repo_id, file_list, local_dir):
    os.makedirs(local_dir, exist_ok=True)
    for file in file_list:
        file_path = os.path.join(local_dir, file)
        if not os.path.exists(file_path):
            print(f"{file} 不存在,从 Hugging Face 仓库下载...")
            hf_hub_download(
                repo_id=repo_id,
                filename=file,
                resume_download=True,
                local_dir=local_dir,
                local_dir_use_symlinks=False,
            )
        else:
            print(f"{file} 已存在,跳过下载。")


# 1st
repo_id_1 = "fishaudio/fish-speech-1.4"
local_dir_1 = "./checkpoints/fish-speech-1.4"
files_1 = [
    "model.pth",
    "README.md",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "config.json",
    "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
]

# 3rd
repo_id_3 = "fishaudio/fish-speech-1"
local_dir_3 = "./"
files_3 = [
    "ffmpeg.exe",
    "ffprobe.exe",
]

# 4th
repo_id_4 = "SpicyqSama007/fish-speech-packed"
local_dir_4 = "./"
files_4 = [
    "asr-label-win-x64.exe",
]

check_and_download_files(repo_id_1, files_1, local_dir_1)

check_and_download_files(repo_id_3, files_3, local_dir_3)
check_and_download_files(repo_id_4, files_4, local_dir_4)

tools/extract_model.py
DELETED
@@ -1,21 +0,0 @@
import click
import torch
from loguru import logger


@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    if model_path == output_path:
        logger.error("Model path and output path are the same")
        return

    logger.info(f"Loading model from {model_path}")
    state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
    torch.save(state_dict, output_path)
    logger.info(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()

tools/file.py
DELETED
@@ -1,125 +0,0 @@
import base64
from pathlib import Path
from typing import Union

from loguru import logger
from natsort import natsorted

AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}

VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
}


def audio_to_bytes(file_path):
    if not file_path or not Path(file_path).exists():
        return None
    with open(file_path, "rb") as wav_file:
        wav = wav_file.read()
    return wav


def read_ref_text(ref_text):
    path = Path(ref_text)
    if path.exists() and path.is_file():
        with path.open("r", encoding="utf-8") as file:
            return file.read()
    return ref_text


def list_files(
    path: Union[Path, str],
    extensions: set[str] = None,
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): Extensions to filter. Defaults to None.
        recursive (bool, optional): Whether to search recursively. Defaults to False.
        sort (bool, optional): Whether to sort the files. Defaults to True.

    Returns:
        list: List of files.
    """

    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    files = [file for ext in extensions for file in path.rglob(f"*{ext}")]

    if sort:
        files = natsorted(files)

    return files


def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
    """
    Load a Bert-VITS2 style filelist.
    """

    files = set()
    results = []
    count_duplicated, count_not_found = 0, 0

    LANGUAGE_TO_LANGUAGES = {
        "zh": ["zh", "en"],
        "jp": ["jp", "en"],
        "en": ["en"],
    }

    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            splits = line.strip().split("|", maxsplit=3)
            if len(splits) != 4:
                logger.warning(f"Invalid line: {line}")
                continue

            filename, speaker, language, text = splits
            file = Path(filename)
            language = language.strip().lower()

            if language == "ja":
                language = "jp"

            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
            languages = LANGUAGE_TO_LANGUAGES[language]

            if file in files:
                logger.warning(f"Duplicated file: {file}")
                count_duplicated += 1
                continue

            if not file.exists():
                logger.warning(f"File not found: {file}")
                count_not_found += 1
                continue

            results.append((file, speaker, languages, text))

    if count_duplicated > 0:
        logger.warning(f"Total duplicated files: {count_duplicated}")

    if count_not_found > 0:
        logger.warning(f"Total files not found: {count_not_found}")

    return results

tools/llama/__pycache__/generate.cpython-310.pyc
DELETED
Binary file (15.1 kB)
tools/llama/build_dataset.py
DELETED
@@ -1,169 +0,0 @@
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from tools.file import load_filelist

# To avoid CPU overload
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


def task_generator_folder(root: Path, text_extension: str):
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)
        speaker = file.parent.name

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for i in grouped_files.values():
        subset = [(f, t) for _, f, t in i]
        yield i[0][0], subset, "folder"


def task_generator_filelist(filelist):
    grouped_files = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, values in grouped_files.items():
        yield speaker, values, "filelist"


def run_task(task):
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if np_file.exists() is False:
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []

        for text in texts:
            # Simple cleaning: replace { xxx } and < xxx > with space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )


@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing {tar_idx} shards to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()

    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")


if __name__ == "__main__":
    main()

tools/llama/eval_in_context.py
DELETED
@@ -1,171 +0,0 @@
import pyrootutils
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer

# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from torch.utils.data import DataLoader

from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
from tools.llama.generate import load_model


def smooth(
    scalars: list[float], weight: float
) -> list[float]:  # Weight between 0 and 1
    last = scalars[0]  # First value in the plot (first timestep)
    smoothed = list()
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
        smoothed.append(smoothed_val)  # Save it
        last = smoothed_val  # Anchor the last smoothed value

    return smoothed


@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]

    current_step = 0
    model.eval()

    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        base_loss = base_loss.reshape(labels[:, 0].shape)
        semantic_loss = semantic_loss.reshape(codebook_labels.shape)

        semantic_loss_frame = semantic_loss.mean(-1)
        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            semantic_loss_sum[~pad] += loss_sample[~pad]
            counter[~pad] += 1

        current_step += 1
        if current_step == 10:
            break

    semantic_loss = semantic_loss.cpu()
    counter = counter.cpu()
    xs, ys = [], []

    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
        if count > 0:
            xs.append(i)
            ys.append((loss / count).item())  # for better loss visualization

    smoothed_ys = smooth(ys, 0.95)

    # Unload model
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys


def main():
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    plt.figure(figsize=(10, 5), dpi=200)

    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    tests = [
        (
            "pertrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")


if __name__ == "__main__":
    main()

tools/llama/generate.py
DELETED
@@ -1,714 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import queue
|
3 |
-
import threading
|
4 |
-
import time
|
5 |
-
from contextlib import nullcontext
|
6 |
-
from dataclasses import dataclass
|
7 |
-
from pathlib import Path
|
8 |
-
from typing import Literal, Optional, Tuple, Union
|
9 |
-
|
10 |
-
import click
|
11 |
-
import hydra
|
12 |
-
import numpy as np
|
13 |
-
import torch
|
14 |
-
import torch._dynamo.config
|
15 |
-
import torch._inductor.config
|
16 |
-
from loguru import logger
|
17 |
-
from tqdm import tqdm
|
18 |
-
|
19 |
-
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
20 |
-
from fish_speech.text import clean_text, split_text
|
21 |
-
torch.cuda.is_available = lambda: False
|
22 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
23 |
-
torch._inductor.config.coordinate_descent_tuning = True
|
24 |
-
torch._inductor.config.triton.unique_kernel_names = True
|
25 |
-
|
26 |
-
|
27 |
-
if hasattr(torch._inductor.config, "fx_graph_cache"):
|
28 |
-
# Experimental feature to reduce compilation times, will be on by default in future
|
29 |
-
torch._inductor.config.fx_graph_cache = True
|
30 |
-
|
31 |
-
|
32 |
-
from fish_speech.models.text2semantic.llama import (
|
33 |
-
BaseTransformer,
|
34 |
-
DualARTransformer,
|
35 |
-
NaiveTransformer,
|
36 |
-
)
|
37 |
-
|
38 |
-
|
39 |
-
def multinomial_sample_one_no_sync(
|
40 |
-
probs_sort,
|
41 |
-
): # Does multinomial sampling without a cuda synchronization
|
42 |
-
q = torch.empty_like(probs_sort).exponential_(1)
|
43 |
-
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
44 |
-
|
45 |
-
|
46 |
-
def logits_to_probs(
|
47 |
-
logits,
|
48 |
-
previous_tokens: Optional[torch.Tensor] = None,
|
49 |
-
temperature: torch.Tensor = 1.0,
|
50 |
-
top_p: torch.Tensor = 1.0,
|
51 |
-
repetition_penalty: torch.Tensor = 1.0,
|
52 |
-
) -> torch.Tensor:
|
53 |
-
# Apply repetition penalty
|
54 |
-
if previous_tokens is not None:
|
55 |
-
previous_tokens = previous_tokens.long()
|
56 |
-
score = torch.gather(logits, dim=0, index=previous_tokens)
|
57 |
-
score = torch.where(
|
58 |
-
score < 0, score * repetition_penalty, score / repetition_penalty
|
59 |
-
)
|
60 |
-
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
61 |
-
|
62 |
-
# Apply top-p sampling
|
63 |
-
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
64 |
-
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
65 |
-
sorted_indices_to_remove = cum_probs > top_p
|
66 |
-
sorted_indices_to_remove[0] = False # keep at least one option
|
67 |
-
indices_to_remove = sorted_indices_to_remove.scatter(
|
68 |
-
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
69 |
-
)
|
70 |
-
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
71 |
-
|
72 |
-
logits = logits / max(temperature, 1e-5)
|
73 |
-
|
74 |
-
probs = torch.nn.functional.softmax(logits, dim=-1)
|
75 |
-
return probs
|
76 |
-
|
77 |
-
|
78 |
-
def sample(
|
79 |
-
logits,
|
80 |
-
previous_tokens: Optional[torch.Tensor] = None,
|
81 |
-
**sampling_kwargs,
|
82 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
83 |
-
probs = logits_to_probs(
|
84 |
-
logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
|
85 |
-
)
|
86 |
-
idx_next = multinomial_sample_one_no_sync(probs)
|
87 |
-
return idx_next, probs
|
88 |
-
|
89 |
-
|
90 |
-
def decode_one_token_ar(
|
91 |
-
model: DualARTransformer,
|
92 |
-
x: torch.Tensor,
|
93 |
-
input_pos: torch.Tensor,
|
94 |
-
previous_tokens: torch.Tensor = None,
|
95 |
-
**sampling_kwargs,
|
96 |
-
) -> torch.Tensor:
|
97 |
-
x = model.forward_generate(x, input_pos)
|
98 |
-
|
99 |
-
sampling_kwargs_main = sampling_kwargs.copy()
|
100 |
-
sampling_kwargs_main["temperature"] = 0.1
|
101 |
-
sampling_kwargs_main["top_p"] = 0.1
|
102 |
-
sampling_kwargs_main["repetition_penalty"] = 1.0
|
103 |
-
|
104 |
-
codebooks = [
|
105 |
-
sample(
|
106 |
-
x.logits,
|
107 |
-
previous_tokens=None, # Disable repetition penalty for the token codebook
|
108 |
-
**sampling_kwargs_main,
|
109 |
-
)[0]
|
110 |
-
]
|
111 |
-
|
112 |
-
x = x.hidden_states
|
113 |
-
|
114 |
-
# Cleanup the cache
|
115 |
-
for layer in model.fast_layers:
|
116 |
-
layer.attention.kv_cache.k_cache.fill_(0)
|
117 |
-
layer.attention.kv_cache.v_cache.fill_(0)
|
118 |
-
|
119 |
-
for codebook_idx in range(model.config.num_codebooks):
|
120 |
-
input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
|
121 |
-
logits = model.forward_generate_fast(x, input_pos)
|
122 |
-
a = sample(
|
123 |
-
logits,
|
124 |
-
previous_tokens=(
|
125 |
-
previous_tokens[codebook_idx + 1]
|
126 |
-
if previous_tokens is not None
|
127 |
-
else None
|
128 |
-
),
|
129 |
-
**sampling_kwargs,
|
130 |
-
)[0]
|
131 |
-
x = model.fast_embeddings(a)
|
132 |
-
codebooks.append(a)
|
133 |
-
|
134 |
-
return torch.stack(codebooks, dim=0)
|
135 |
-
|
136 |
-
|
137 |
-
def decode_one_token_naive(
|
138 |
-
model: NaiveTransformer,
|
139 |
-
x: torch.Tensor,
|
140 |
-
input_pos: torch.Tensor,
|
141 |
-
previous_tokens: torch.Tensor = None,
|
142 |
-
**sampling_kwargs,
|
143 |
-
) -> torch.Tensor:
|
144 |
-
x = model.forward_generate(x, input_pos)
|
145 |
-
|
146 |
-
sampling_kwargs_main = sampling_kwargs.copy()
|
147 |
-
sampling_kwargs_main["temperature"] = 0.1
|
148 |
-
sampling_kwargs_main["top_p"] = 0.1
|
149 |
-
sampling_kwargs_main["repetition_penalty"] = 1.0
|
150 |
-
|
151 |
-
codebooks = [
|
152 |
-
sample(
|
153 |
-
x.logits,
|
154 |
-
previous_tokens=None, # Disable repetition penalty for the token codebook
|
155 |
-
**sampling_kwargs_main,
|
156 |
-
)[0]
|
157 |
-
]
|
158 |
-
|
159 |
-
for i in range(model.config.num_codebooks):
|
160 |
-
codebooks.append(
|
161 |
-
sample(
|
162 |
-
x.codebook_logits[:, :, i],
|
163 |
-
previous_tokens=(
|
164 |
-
previous_tokens[i + 1] if previous_tokens is not None else None
|
165 |
-
),
|
166 |
-
**sampling_kwargs,
|
167 |
-
)[0]
|
168 |
-
)
|
169 |
-
|
170 |
-
return torch.stack(codebooks, dim=0)
|
171 |
-
|
172 |
-
|
173 |
-
def decode_n_tokens(
|
174 |
-
model: NaiveTransformer,
|
175 |
-
cur_token: torch.Tensor,
|
176 |
-
input_pos: torch.Tensor,
|
177 |
-
num_new_tokens: int,
|
178 |
-
im_end_id: int = 4,
|
179 |
-
decode_one_token=decode_one_token_naive,
|
180 |
-
**sampling_kwargs,
|
181 |
-
):
|
182 |
-
previous_tokens = torch.zeros(
|
183 |
-
(model.config.num_codebooks + 1, model.config.max_seq_len),
|
184 |
-
dtype=torch.int,
|
185 |
-
device=cur_token.device,
|
186 |
-
)
|
187 |
-
|
188 |
-
for i in tqdm(range(num_new_tokens)):
|
189 |
-
# We need to get windowed repeat penalty
|
190 |
-
win_size = 16
|
191 |
-
if i < win_size:
|
192 |
-
window = previous_tokens[:, :win_size]
|
193 |
-
else:
|
194 |
-
window = previous_tokens[:, i - win_size : i]
|
195 |
-
|
196 |
-
with (
|
197 |
-
torch.backends.cuda.sdp_kernel(
|
198 |
-
enable_flash=False, enable_mem_efficient=False, enable_math=True
|
199 |
-
)
|
200 |
-
if torch.cuda.is_available()
|
201 |
-
else nullcontext()
|
202 |
-
): # Actually better for Inductor to codegen attention here
|
203 |
-
next_token = decode_one_token(
|
204 |
-
model=model,
|
205 |
-
x=cur_token,
|
206 |
-
input_pos=input_pos,
|
207 |
-
previous_tokens=window,
|
208 |
-
**sampling_kwargs,
|
209 |
-
)
|
210 |
-
|
211 |
-
input_pos += 1
|
212 |
-
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
213 |
-
previous_tokens[:, i : i + 1] = next_token.view(
|
214 |
-
model.config.num_codebooks + 1, -1
|
215 |
-
)
|
216 |
-
|
217 |
-
if cur_token[0, 0, -1] == im_end_id:
|
218 |
-
break
|
219 |
-
|
220 |
-
return previous_tokens[:, : i + 1]
|
221 |
-
|
222 |
-
|
223 |
-
@torch.no_grad()
|
224 |
-
@torch.inference_mode()
|
225 |
-
def generate(
|
226 |
-
*,
|
227 |
-
model: NaiveTransformer,
|
228 |
-
prompt: torch.Tensor,
|
229 |
-
max_new_tokens: int,
|
230 |
-
im_end_id: int = 4,
|
231 |
-
decode_one_token=decode_one_token_naive,
|
232 |
-
**sampling_kwargs,
|
233 |
-
) -> torch.Tensor:
|
234 |
-
"""
|
235 |
-
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
236 |
-
"""
|
237 |
-
|
238 |
-
# create an empty tensor of the expected final shape and fill in the current tokens
|
239 |
-
T = prompt.size(1)
|
240 |
-
|
241 |
-
if max_new_tokens:
|
242 |
-
if T + max_new_tokens > model.config.max_seq_len:
|
243 |
-
max_new_tokens = model.config.max_seq_len - T
|
244 |
-
logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
|
245 |
-
|
246 |
-
T_new = T + max_new_tokens
|
247 |
-
else:
|
248 |
-
T_new = model.config.max_seq_len
|
249 |
-
max_new_tokens = T_new - T
|
250 |
-
|
251 |
-
device, dtype = prompt.device, prompt.dtype
|
252 |
-
with torch.device(device):
|
253 |
-
model.setup_caches(
|
254 |
-
max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
|
255 |
-
)
|
256 |
-
|
257 |
-
codebook_dim = 1 + model.config.num_codebooks
|
258 |
-
# create an empty tensor of the expected final shape and fill in the current tokens
|
259 |
-
empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
|
260 |
-
empty[:, :T] = prompt
|
261 |
-
seq = empty
|
262 |
-
input_pos = torch.arange(0, T, device=device)
|
263 |
-
|
264 |
-
# Use non-accelerated version for now, to avoid compilation overhead
|
265 |
-
prefill_decode = (
|
266 |
-
decode_one_token_naive
|
267 |
-
if isinstance(model, NaiveTransformer)
|
268 |
-
else decode_one_token_ar
|
269 |
-
)
|
270 |
-
|
271 |
-
next_token = prefill_decode(
|
272 |
-
model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
|
273 |
-
)
|
274 |
-
seq[:, T : T + 1] = next_token
|
275 |
-
|
276 |
-
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
277 |
-
x = decode_n_tokens(
|
278 |
-
model,
|
279 |
-
next_token.view(1, codebook_dim, -1),
|
280 |
-
input_pos,
|
281 |
-
max_new_tokens - 1,
|
282 |
-
im_end_id=im_end_id,
|
283 |
-
decode_one_token=decode_one_token,
|
284 |
-
**sampling_kwargs,
|
285 |
-
)
|
286 |
-
# x = torch.cat(generated_tokens, dim=1)
|
287 |
-
seq = seq[:, : T + 1 + x.size(1)]
|
288 |
-
seq[:, T + 1 :] = x
|
289 |
-
|
290 |
-
return seq
|
291 |
-
|
292 |
-
|
293 |
-
def encode_tokens(
|
294 |
-
tokenizer,
|
295 |
-
string,
|
296 |
-
device="cuda",
|
297 |
-
prompt_tokens=None,
|
298 |
-
num_codebooks=4,
|
299 |
-
):
|
300 |
-
string = clean_text(string)
|
301 |
-
string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
|
302 |
-
|
303 |
-
new_tokens = tokenizer.encode(
|
304 |
-
string,
|
305 |
-
add_special_tokens=False,
|
306 |
-
max_length=10**6,
|
307 |
-
truncation=False,
|
308 |
-
)
|
309 |
-
tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
|
310 |
-
|
311 |
-
# Codebooks
|
312 |
-
zeros = (
|
313 |
-
torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
|
314 |
-
* CODEBOOK_PAD_TOKEN_ID
|
315 |
-
)
|
316 |
-
prompt = torch.cat((tokens, zeros), dim=0)
|
317 |
-
|
318 |
-
if prompt_tokens is None:
|
319 |
-
return prompt
|
320 |
-
|
321 |
-
# Get prompt tokens
|
322 |
-
if prompt_tokens.ndim == 3:
|
323 |
-
assert (
|
324 |
-
prompt_tokens.shape[0] == 1
|
325 |
-
), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
|
326 |
-
prompt_tokens = prompt_tokens[0]
|
327 |
-
|
328 |
-
assert prompt_tokens.ndim == 2
|
329 |
-
data = prompt_tokens + 1
|
330 |
-
|
331 |
-
if prompt_tokens.shape[0] > num_codebooks:
|
332 |
-
logger.warning(
|
333 |
-
f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
|
334 |
-
)
|
335 |
-
data = data[:num_codebooks]
|
336 |
-
|
337 |
-
# Add pad token for each codebook
|
338 |
-
data = torch.cat(
|
339 |
-
(data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
|
340 |
-
dim=1,
|
341 |
-
)
|
342 |
-
|
343 |
-
# Since 1.0, we use <|semantic|>
|
344 |
-
s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
|
345 |
-
end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
346 |
-
main_token_ids = (
|
347 |
-
torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
|
348 |
-
)
|
349 |
-
main_token_ids[0, -1] = end_token_id
|
350 |
-
|
351 |
-
data = torch.cat((main_token_ids, data), dim=0)
|
352 |
-
prompt = torch.cat((prompt, data), dim=1)
|
353 |
-
|
354 |
-
return prompt
|
355 |
-
|
356 |
-
|
357 |
-
def load_model(checkpoint_path, device, precision, compile=False):
|
358 |
-
model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
|
359 |
-
checkpoint_path, load_weights=True
|
360 |
-
)
|
361 |
-
|
362 |
-
model = model.to(device=device, dtype=precision)
|
363 |
-
logger.info(f"Restored model from checkpoint")
|
364 |
-
|
365 |
-
if isinstance(model, DualARTransformer):
|
366 |
-
decode_one_token = decode_one_token_ar
|
367 |
-
logger.info("Using DualARTransformer")
|
368 |
-
else:
|
369 |
-
decode_one_token = decode_one_token_naive
|
370 |
-
logger.info("Using NaiveTransformer")
|
371 |
-
|
372 |
-
if compile:
|
373 |
-
logger.info("Compiling function...")
|
374 |
-
decode_one_token = torch.compile(
|
375 |
-
decode_one_token,
|
376 |
-
fullgraph=True,
|
377 |
-
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
378 |
-
mode="reduce-overhead" if torch.cuda.is_available() else None,
|
379 |
-
)
|
380 |
-
|
381 |
-
return model.eval(), decode_one_token
|
382 |
-
|
383 |
-
|
384 |
-
@dataclass
|
385 |
-
class GenerateResponse:
|
386 |
-
action: Literal["sample", "next"]
|
387 |
-
codes: Optional[torch.Tensor] = None
|
388 |
-
text: Optional[str] = None
|
389 |
-
|
390 |
-
|
391 |
-
def generate_long(
|
392 |
-
*,
|
393 |
-
model,
|
394 |
-
device: str | torch.device,
|
395 |
-
decode_one_token: callable,
|
396 |
-
text: str,
|
397 |
-
num_samples: int = 1,
|
398 |
-
max_new_tokens: int = 0,
|
399 |
-
top_p: float = 0.7,
|
400 |
-
repetition_penalty: float = 1.5,
|
401 |
-
temperature: float = 0.7,
|
402 |
-
compile: bool = False,
|
403 |
-
iterative_prompt: bool = True,
|
404 |
-
max_length: int = 2048,
|
405 |
-
chunk_length: int = 150,
|
406 |
-
prompt_text: Optional[str | list[str]] = None,
|
407 |
-
prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
|
408 |
-
):
|
409 |
-
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
410 |
-
assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
411 |
-
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
412 |
-
|
413 |
-
use_prompt = prompt_text is not None and prompt_tokens is not None
|
414 |
-
if use_prompt and isinstance(prompt_text, str):
|
415 |
-
prompt_text = [prompt_text]
|
416 |
-
prompt_tokens = [prompt_tokens]
|
417 |
-
|
418 |
-
assert use_prompt is False or len(prompt_text) == len(
|
419 |
-
prompt_tokens
|
420 |
-
), "Prompt text and tokens must have the same length"
|
421 |
-
|
422 |
-
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
423 |
-
tokenizer = model.tokenizer
|
424 |
-
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
425 |
-
|
426 |
-
encoded = []
|
427 |
-
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
428 |
-
encoded_prompts = []
|
429 |
-
|
430 |
-
if use_prompt:
|
431 |
-
for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
|
432 |
-
encoded_prompts.append(
|
433 |
-
encode_tokens(
|
434 |
-
tokenizer,
|
435 |
-
string=t,
|
436 |
-
device=device,
|
437 |
-
prompt_tokens=c,
|
438 |
-
num_codebooks=model.config.num_codebooks,
|
439 |
-
)
|
440 |
-
)
|
441 |
-
|
442 |
-
for idx, text in enumerate(texts):
|
443 |
-
encoded.append(
|
444 |
-
encode_tokens(
|
445 |
-
tokenizer,
|
446 |
-
string=text,
|
447 |
-
device=device,
|
448 |
-
num_codebooks=model.config.num_codebooks,
|
449 |
-
)
|
450 |
-
)
|
451 |
-
logger.info(f"Encoded text: {text}")
|
452 |
-
|
453 |
-
# Move temperature, top_p, repetition_penalty to device
|
454 |
-
# This is important so that changing params doesn't trigger recompile
|
455 |
-
temperature = torch.tensor(temperature, device=device, dtype=torch.float)
|
456 |
-
top_p = torch.tensor(top_p, device=device, dtype=torch.float)
|
457 |
-
repetition_penalty = torch.tensor(
|
458 |
-
repetition_penalty, device=device, dtype=torch.float
|
459 |
-
)
|
460 |
-
|
461 |
-
for sample_idx in range(num_samples):
|
462 |
-
if torch.cuda.is_available():
|
463 |
-
torch.cuda.synchronize()
|
464 |
-
|
465 |
-
global_encoded = []
|
466 |
-
seg_idx = 0
|
467 |
-
|
468 |
-
while seg_idx < len(encoded):
|
469 |
-
logger.info(
|
470 |
-
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
471 |
-
)
|
472 |
-
|
473 |
-
seg = encoded[seg_idx]
|
474 |
-
global_encoded.append(seg)
|
475 |
-
|
476 |
-
lengths = reversed([seg.size(1) for seg in global_encoded])
|
477 |
-
|
478 |
-
# Pick last 2000 tokens
|
479 |
-
count = 0
|
480 |
-
for i, length in enumerate(lengths):
|
481 |
-
count += length
|
482 |
-
if count + length > max_length - 1024 - sum(
|
483 |
-
t.shape[1] for t in encoded_prompts
|
484 |
-
):
|
485 |
-
break
|
486 |
-
|
487 |
-
if i != 0 and i % 2 == 0:
|
488 |
-
i -= 1
|
489 |
-
|
490 |
-
# Rotate the list, always make sure first segment is included to avoid drift
|
491 |
-
if i < len(global_encoded) - 2:
|
492 |
-
partial_encoded = global_encoded[:2] + global_encoded[-i:]
|
493 |
-
else:
|
494 |
-
partial_encoded = global_encoded
|
495 |
-
|
496 |
-
if use_prompt:
|
497 |
-
partial_encoded = encoded_prompts + partial_encoded
|
498 |
-
|
499 |
-
cat_encoded = torch.cat(partial_encoded, dim=1)
|
500 |
-
prompt_length = cat_encoded.size(1)
|
501 |
-
|
502 |
-
t0 = time.perf_counter()
|
503 |
-
y = generate(
|
504 |
-
model=model,
|
505 |
-
prompt=cat_encoded,
|
506 |
-
max_new_tokens=max_new_tokens,
|
507 |
-
im_end_id=im_end_id,
|
508 |
-
decode_one_token=decode_one_token,
|
509 |
-
temperature=temperature,
|
510 |
-
top_p=top_p,
|
511 |
-
repetition_penalty=repetition_penalty,
|
512 |
-
)
|
513 |
-
|
514 |
-
if sample_idx == 0 and seg_idx == 0 and compile:
|
515 |
-
logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
516 |
-
|
517 |
-
if torch.cuda.is_available():
|
518 |
-
torch.cuda.synchronize()
|
519 |
-
|
520 |
-
t = time.perf_counter() - t0
|
521 |
-
|
522 |
-
tokens_generated = y.size(1) - prompt_length
|
523 |
-
tokens_sec = tokens_generated / t
|
524 |
-
logger.info(
|
525 |
-
f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
|
526 |
-
)
|
527 |
-
logger.info(
|
528 |
-
f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
|
529 |
-
)
|
530 |
-
|
531 |
-
if torch.cuda.is_available():
|
532 |
-
logger.info(
|
533 |
-
f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
534 |
-
)
|
535 |
-
|
536 |
-
# Put the generated tokens
|
537 |
-
# since there is <im_end> and <eos> tokens, we remove last 2 tokens
|
538 |
-
codes = y[1:, prompt_length:-1].clone()
|
539 |
-
codes = codes - 1
|
540 |
-
assert (codes >= 0).all(), f"Negative code found"
|
541 |
-
|
542 |
-
decoded = y[:, prompt_length:-1].clone()
|
543 |
-
# But for global encoding, we should keep the <im_end> token
|
544 |
-
|
545 |
-
global_encoded.append(decoded)
|
546 |
-
assert (codes >= 0).all(), f"Negative code found: {codes}"
|
547 |
-
yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
|
548 |
-
seg_idx += 1
|
549 |
-
|
550 |
-
# This indicates the end of the current sample
|
551 |
-
yield GenerateResponse(action="next")
|
552 |
-
|
553 |
-
|
554 |
-
@dataclass
|
555 |
-
class WrappedGenerateResponse:
|
556 |
-
status: Literal["success", "error"]
|
557 |
-
response: Optional[GenerateResponse | Exception] = None
|
558 |
-
|
559 |
-
|
560 |
-
@dataclass
|
561 |
-
class GenerateRequest:
|
562 |
-
request: dict
|
563 |
-
response_queue: queue.Queue
|
564 |
-
|
565 |
-
|
566 |
-
def launch_thread_safe_queue(
|
567 |
-
checkpoint_path,
|
568 |
-
device,
|
569 |
-
precision,
|
570 |
-
compile: bool = False,
|
571 |
-
):
|
572 |
-
input_queue = queue.Queue()
|
573 |
-
init_event = threading.Event()
|
574 |
-
|
575 |
-
def worker():
|
576 |
-
model, decode_one_token = load_model(
|
577 |
-
checkpoint_path, device, precision, compile=compile
|
578 |
-
)
|
579 |
-
init_event.set()
|
580 |
-
|
581 |
-
while True:
|
582 |
-
item: GenerateRequest | None = input_queue.get()
|
583 |
-
if item is None:
|
584 |
-
break
|
585 |
-
|
586 |
-
kwargs = item.request
|
587 |
-
response_queue = item.response_queue
|
588 |
-
|
589 |
-
try:
|
590 |
-
for chunk in generate_long(
|
591 |
-
model=model, decode_one_token=decode_one_token, **kwargs
|
592 |
-
):
|
593 |
-
response_queue.put(
|
594 |
-
WrappedGenerateResponse(status="success", response=chunk)
|
595 |
-
)
|
596 |
-
except Exception as e:
|
597 |
-
response_queue.put(WrappedGenerateResponse(status="error", response=e))
|
598 |
-
|
599 |
-
threading.Thread(target=worker, daemon=True).start()
|
600 |
-
init_event.wait()
|
601 |
-
|
602 |
-
return input_queue
|
603 |
-
|
604 |
-
|
605 |
-
@click.command()
|
606 |
-
@click.option(
|
607 |
-
"--text",
|
608 |
-
type=str,
|
609 |
-
default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
610 |
-
)
|
611 |
-
@click.option("--prompt-text", type=str, default=None, multiple=True)
|
612 |
-
@click.option(
|
613 |
-
"--prompt-tokens",
|
614 |
-
type=click.Path(path_type=Path, exists=True),
|
615 |
-
default=None,
|
616 |
-
multiple=True,
|
617 |
-
)
|
618 |
-
@click.option("--num-samples", type=int, default=1)
|
619 |
-
@click.option("--max-new-tokens", type=int, default=0)
|
620 |
-
@click.option("--top-p", type=float, default=0.7)
|
621 |
-
@click.option("--repetition-penalty", type=float, default=1.2)
|
622 |
-
@click.option("--temperature", type=float, default=0.7)
|
623 |
-
@click.option(
|
624 |
-
"--checkpoint-path",
|
625 |
-
type=click.Path(path_type=Path, exists=True),
|
626 |
-
default="checkpoints/fish-speech-1.4",
|
627 |
-
)
|
628 |
-
@click.option("--device", type=str, default="cuda")
|
629 |
-
@click.option("--compile/--no-compile", default=False)
|
630 |
-
@click.option("--seed", type=int, default=42)
|
631 |
-
@click.option("--half/--no-half", default=False)
|
632 |
-
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
633 |
-
@click.option("--chunk-length", type=int, default=100)
|
634 |
-
def main(
|
635 |
-
text: str,
|
636 |
-
prompt_text: Optional[list[str]],
|
637 |
-
prompt_tokens: Optional[list[Path]],
|
638 |
-
num_samples: int,
|
639 |
-
max_new_tokens: int,
|
640 |
-
top_p: float,
|
641 |
-
repetition_penalty: float,
|
642 |
-
temperature: float,
|
643 |
-
checkpoint_path: Path,
|
644 |
-
device: str,
|
645 |
-
compile: bool,
|
646 |
-
seed: int,
|
647 |
-
half: bool,
|
648 |
-
iterative_prompt: bool,
|
649 |
-
chunk_length: int,
|
650 |
-
) -> None:
|
651 |
-
|
652 |
-
precision = torch.half if half else torch.bfloat16
|
653 |
-
|
654 |
-
if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
|
655 |
-
raise ValueError(
|
656 |
-
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
657 |
-
)
|
658 |
-
|
659 |
-
logger.info("Loading model ...")
|
660 |
-
t0 = time.time()
|
661 |
-
model, decode_one_token = load_model(
|
662 |
-
checkpoint_path, device, precision, compile=compile
|
663 |
-
)
|
664 |
-
|
665 |
-
if torch.cuda.is_available():
|
666 |
-
torch.cuda.synchronize()
|
667 |
-
|
668 |
-
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
669 |
-
|
670 |
-
if prompt_tokens is not None:
|
671 |
-
prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
|
672 |
-
|
673 |
-
torch.manual_seed(seed)
|
674 |
-
|
675 |
-
if torch.cuda.is_available():
|
676 |
-
torch.cuda.manual_seed(seed)
|
677 |
-
|
678 |
-
generator = generate_long(
|
679 |
-
model=model,
|
680 |
-
device=device,
|
681 |
-
decode_one_token=decode_one_token,
|
682 |
-
text=text,
|
683 |
-
num_samples=num_samples,
|
684 |
-
max_new_tokens=max_new_tokens,
|
685 |
-
top_p=top_p,
|
686 |
-
repetition_penalty=repetition_penalty,
|
687 |
-
temperature=temperature,
|
688 |
-
compile=compile,
|
689 |
-
iterative_prompt=iterative_prompt,
|
690 |
-
chunk_length=chunk_length,
|
691 |
-
prompt_text=prompt_text,
|
692 |
-
prompt_tokens=prompt_tokens,
|
693 |
-
)
|
694 |
-
|
695 |
-
idx = 0
|
696 |
-
codes = []
|
697 |
-
|
698 |
-
for response in generator:
|
699 |
-
if response.action == "sample":
|
700 |
-
codes.append(response.codes)
|
701 |
-
logger.info(f"Sampled text: {response.text}")
|
702 |
-
elif response.action == "next":
|
703 |
-
if codes:
|
704 |
-
np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
|
705 |
-
logger.info(f"Saved codes to codes_{idx}.npy")
|
706 |
-
logger.info(f"Next sample")
|
707 |
-
codes = []
|
708 |
-
idx += 1
|
709 |
-
else:
|
710 |
-
logger.error(f"Error: {response}")
|
711 |
-
|
712 |
-
|
713 |
-
if __name__ == "__main__":
|
714 |
-
main()
|
tools/llama/merge_lora.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
import shutil
|
2 |
-
from copy import deepcopy
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
import click
|
6 |
-
import hydra
|
7 |
-
import torch
|
8 |
-
from hydra import compose, initialize
|
9 |
-
from hydra.utils import instantiate
|
10 |
-
from loguru import logger
|
11 |
-
|
12 |
-
from fish_speech.models.text2semantic.llama import BaseTransformer
|
13 |
-
from fish_speech.models.text2semantic.lora import get_merged_state_dict
|
14 |
-
|
15 |
-
|
16 |
-
@click.command()
|
17 |
-
@click.option("--lora-config", type=str, default="r_8_alpha_16")
|
18 |
-
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
|
19 |
-
@click.option("--lora-weight", type=str, required=True)
|
20 |
-
@click.option("--output", type=str, required=True)
|
21 |
-
def merge(lora_config, base_weight, lora_weight, output):
|
22 |
-
output = Path(output)
|
23 |
-
logger.info(
|
24 |
-
f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
|
25 |
-
)
|
26 |
-
|
27 |
-
with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
|
28 |
-
cfg = compose(config_name=lora_config)
|
29 |
-
|
30 |
-
lora_config = instantiate(cfg)
|
31 |
-
logger.info(f"Loaded lora model with config {lora_config}")
|
32 |
-
|
33 |
-
llama_model = BaseTransformer.from_pretrained(
|
34 |
-
path=base_weight,
|
35 |
-
load_weights=True,
|
36 |
-
lora_config=lora_config,
|
37 |
-
)
|
38 |
-
logger.info(f"Loaded llama model")
|
39 |
-
|
40 |
-
llama_state_dict = llama_model.state_dict()
|
41 |
-
llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
|
42 |
-
llama_state_dict_copy = deepcopy(llama_state_dict)
|
43 |
-
lora_state_dict = torch.load(lora_weight, map_location="cpu")
|
44 |
-
|
45 |
-
if "state_dict" in llama_state_dict:
|
46 |
-
llama_state_dict = llama_state_dict["state_dict"]
|
47 |
-
|
48 |
-
if "state_dict" in lora_state_dict:
|
49 |
-
lora_state_dict = lora_state_dict["state_dict"]
|
50 |
-
|
51 |
-
# remove prefix model.
|
52 |
-
if any(k.startswith("model.") for k in llama_state_dict.keys()):
|
53 |
-
llama_state_dict = {
|
54 |
-
k.replace("model.", ""): v
|
55 |
-
for k, v in llama_state_dict.items()
|
56 |
-
if k.startswith("model.")
|
57 |
-
}
|
58 |
-
if any(k.startswith("model.") for k in lora_state_dict.keys()):
|
59 |
-
lora_state_dict = {
|
60 |
-
k.replace("model.", ""): v
|
61 |
-
for k, v in lora_state_dict.items()
|
62 |
-
if k.startswith("model.")
|
63 |
-
}
|
64 |
-
|
65 |
-
logger.info(f"Found {len(llama_state_dict)} keys in llama model")
|
66 |
-
logger.info(f"Found {len(lora_state_dict)} keys in lora model")
|
67 |
-
|
68 |
-
merged_state_dict = llama_state_dict | lora_state_dict
|
69 |
-
llama_model.load_state_dict(merged_state_dict, strict=True)
|
70 |
-
logger.info(f"Merged model loaded")
|
71 |
-
|
72 |
-
# Trigger eval mode to merge lora
|
73 |
-
llama_model.eval()
|
74 |
-
llama_model.save_pretrained(output, drop_lora=True)
|
75 |
-
logger.info(f"Saved merged model to {output}, validating")
|
76 |
-
|
77 |
-
new_state_dict = torch.load(output / "model.pth", map_location="cpu")
|
78 |
-
original_keys = set(llama_state_dict_copy.keys())
|
79 |
-
merged_keys = set(new_state_dict.keys())
|
80 |
-
|
81 |
-
assert original_keys == merged_keys, "Keys should be same"
|
82 |
-
|
83 |
-
for key in original_keys:
|
84 |
-
diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
|
85 |
-
if diff_l1 != 0:
|
86 |
-
break
|
87 |
-
else:
|
88 |
-
logger.error("Merged model is same as the original model")
|
89 |
-
exit(1)
|
90 |
-
|
91 |
-
logger.info("Merged model is different from the original model, check passed")
|
92 |
-
|
93 |
-
|
94 |
-
if __name__ == "__main__":
|
95 |
-
merge()
|
tools/llama/quantize.py
DELETED
@@ -1,497 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
import datetime
|
4 |
-
import shutil
|
5 |
-
|
6 |
-
# This source code is licensed under the license found in the
|
7 |
-
# LICENSE file in the root directory of this source tree.
|
8 |
-
import time
|
9 |
-
from pathlib import Path
|
10 |
-
|
11 |
-
import click
|
12 |
-
import torch
|
13 |
-
import torch.nn as nn
|
14 |
-
import torch.nn.functional as F
|
15 |
-
|
16 |
-
from fish_speech.models.text2semantic.llama import find_multiple
|
17 |
-
from tools.llama.generate import load_model
|
18 |
-
|
19 |
-
##### Quantization Primitives ######
|
20 |
-
|
21 |
-
|
22 |
-
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
|
23 |
-
# assumes symmetric quantization
|
24 |
-
# assumes axis == 0
|
25 |
-
# assumes dense memory format
|
26 |
-
# TODO(future): relax ^ as needed
|
27 |
-
|
28 |
-
# default setup for affine quantization of activations
|
29 |
-
eps = torch.finfo(torch.float32).eps
|
30 |
-
|
31 |
-
# get min and max
|
32 |
-
min_val, max_val = torch.aminmax(x, dim=1)
|
33 |
-
|
34 |
-
# calculate scales and zero_points based on min and max
|
35 |
-
# reference: https://fburl.com/code/srbiybme
|
36 |
-
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
|
37 |
-
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
|
38 |
-
device = min_val_neg.device
|
39 |
-
|
40 |
-
# reference: https://fburl.com/code/4wll53rk
|
41 |
-
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
42 |
-
scales = max_val_pos / (float(quant_max - quant_min) / 2)
|
43 |
-
# ensure scales is the same dtype as the original tensor
|
44 |
-
scales = torch.clamp(scales, min=eps).to(x.dtype)
|
45 |
-
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
|
46 |
-
|
47 |
-
# quantize based on qmin/qmax/scales/zp
|
48 |
-
# reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
|
49 |
-
x_div = x / scales.unsqueeze(-1)
|
50 |
-
x_round = torch.round(x_div)
|
51 |
-
x_zp = x_round + zero_points.unsqueeze(-1)
|
52 |
-
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
|
53 |
-
|
54 |
-
return quant, scales, zero_points
|
55 |
-
|
56 |
-
|
57 |
-
def get_group_qparams(w, n_bit=4, groupsize=128):
|
58 |
-
# needed for GPTQ with padding
|
59 |
-
if groupsize > w.shape[-1]:
|
60 |
-
groupsize = w.shape[-1]
|
61 |
-
assert groupsize > 1
|
62 |
-
assert w.shape[-1] % groupsize == 0
|
63 |
-
assert w.dim() == 2
|
64 |
-
|
65 |
-
to_quant = w.reshape(-1, groupsize)
|
66 |
-
assert torch.isnan(to_quant).sum() == 0
|
67 |
-
|
68 |
-
max_val = to_quant.amax(dim=1, keepdim=True)
|
69 |
-
min_val = to_quant.amin(dim=1, keepdim=True)
|
70 |
-
max_int = 2**n_bit - 1
|
71 |
-
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
72 |
-
zeros = min_val + scales * (2 ** (n_bit - 1))
|
73 |
-
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
74 |
-
torch.bfloat16
|
75 |
-
).reshape(w.shape[0], -1)
|
76 |
-
|
77 |
-
|
78 |
-
def pack_scales_and_zeros(scales, zeros):
|
79 |
-
assert scales.shape == zeros.shape
|
80 |
-
assert scales.dtype == torch.bfloat16
|
81 |
-
assert zeros.dtype == torch.bfloat16
|
82 |
-
return (
|
83 |
-
torch.cat(
|
84 |
-
[
|
85 |
-
scales.reshape(scales.size(0), scales.size(1), 1),
|
86 |
-
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
87 |
-
],
|
88 |
-
2,
|
89 |
-
)
|
90 |
-
.transpose(0, 1)
|
91 |
-
.contiguous()
|
92 |
-
)
|
93 |
-
|
94 |
-
|
95 |
-
def unpack_scales_and_zeros(scales_and_zeros):
|
96 |
-
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
|
97 |
-
assert scales_and_zeros.dtype == torch.float
|
98 |
-
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
|
99 |
-
|
100 |
-
|
101 |
-
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
102 |
-
assert groupsize > 1
|
103 |
-
# needed for GPTQ single column quantize
|
104 |
-
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
105 |
-
groupsize = w.shape[-1]
|
106 |
-
|
107 |
-
assert w.shape[-1] % groupsize == 0
|
108 |
-
assert w.dim() == 2
|
109 |
-
|
110 |
-
to_quant = w.reshape(-1, groupsize)
|
111 |
-
assert torch.isnan(to_quant).sum() == 0
|
112 |
-
|
113 |
-
scales = scales.reshape(-1, 1)
|
114 |
-
zeros = zeros.reshape(-1, 1)
|
115 |
-
min_val = zeros - scales * (2 ** (n_bit - 1))
|
116 |
-
max_int = 2**n_bit - 1
|
117 |
-
min_int = 0
|
118 |
-
w_int32 = (
|
119 |
-
to_quant.sub(min_val)
|
120 |
-
.div(scales)
|
121 |
-
.round()
|
122 |
-
.clamp_(min_int, max_int)
|
123 |
-
.to(torch.int32)
|
124 |
-
.reshape_as(w)
|
125 |
-
)
|
126 |
-
|
127 |
-
return w_int32
|
128 |
-
|
129 |
-
|
130 |
-
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
131 |
-
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
132 |
-
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
133 |
-
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
134 |
-
return w_int32, scales_and_zeros
|
135 |
-
|
136 |
-
|
137 |
-
def group_dequantize_tensor_from_qparams(
|
138 |
-
w_int32, scales, zeros, n_bit=4, groupsize=128
|
139 |
-
):
|
140 |
-
assert groupsize > 1
|
141 |
-
# needed for GPTQ single column dequantize
|
142 |
-
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
|
143 |
-
groupsize = w_int32.shape[-1]
|
144 |
-
assert w_int32.shape[-1] % groupsize == 0
|
145 |
-
assert w_int32.dim() == 2
|
146 |
-
|
147 |
-
w_int32_grouped = w_int32.reshape(-1, groupsize)
|
148 |
-
scales = scales.reshape(-1, 1)
|
149 |
-
zeros = zeros.reshape(-1, 1)
|
150 |
-
|
151 |
-
w_dq = (
|
152 |
-
w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
|
153 |
-
)
|
154 |
-
return w_dq
|
155 |
-
|
156 |
-
|
157 |
-
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
|
158 |
-
scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
|
159 |
-
return group_dequantize_tensor_from_qparams(
|
160 |
-
w_int32, scales, zeros, n_bit, groupsize
|
161 |
-
)
|
162 |
-
|
163 |
-
|
164 |
-
class QuantHandler:
|
165 |
-
def __init__(self, mod):
|
166 |
-
self.mod = mod
|
167 |
-
|
168 |
-
def create_quantized_state_dict(self) -> "StateDict":
|
169 |
-
pass
|
170 |
-
|
171 |
-
def convert_for_runtime(self) -> "nn.Module":
|
172 |
-
pass
|
173 |
-
|
174 |
-
|
175 |
-
##### Weight-only int8 per-channel quantized code ######
|
176 |
-
|
177 |
-
|
178 |
-
def replace_linear_weight_only_int8_per_channel(module):
|
179 |
-
for name, child in module.named_children():
|
180 |
-
if isinstance(child, nn.Linear):
|
181 |
-
setattr(
|
182 |
-
module,
|
183 |
-
name,
|
184 |
-
WeightOnlyInt8Linear(child.in_features, child.out_features),
|
185 |
-
)
|
186 |
-
else:
|
187 |
-
replace_linear_weight_only_int8_per_channel(child)
|
188 |
-
|
189 |
-
|
190 |
-
class WeightOnlyInt8QuantHandler:
|
191 |
-
def __init__(self, mod):
|
192 |
-
self.mod = mod
|
193 |
-
|
194 |
-
@torch.no_grad()
|
195 |
-
def create_quantized_state_dict(self):
|
196 |
-
cur_state_dict = self.mod.state_dict()
|
197 |
-
for fqn, mod in self.mod.named_modules():
|
198 |
-
if isinstance(mod, torch.nn.Linear):
|
199 |
-
int8_weight, scales, _ = dynamically_quantize_per_channel(
|
200 |
-
mod.weight.float(), -128, 127, torch.int8
|
201 |
-
)
|
202 |
-
cur_state_dict[f"{fqn}.weight"] = int8_weight
|
203 |
-
cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
|
204 |
-
|
205 |
-
return cur_state_dict
|
206 |
-
|
207 |
-
def convert_for_runtime(self):
|
208 |
-
replace_linear_weight_only_int8_per_channel(self.mod)
|
209 |
-
return self.mod
|
210 |
-
|
211 |
-
|
212 |
-
class WeightOnlyInt8Linear(torch.nn.Module):
|
213 |
-
__constants__ = ["in_features", "out_features"]
|
214 |
-
in_features: int
|
215 |
-
out_features: int
|
216 |
-
weight: torch.Tensor
|
217 |
-
|
218 |
-
def __init__(
|
219 |
-
self,
|
220 |
-
in_features: int,
|
221 |
-
out_features: int,
|
222 |
-
bias: bool = True,
|
223 |
-
device=None,
|
224 |
-
dtype=None,
|
225 |
-
) -> None:
|
226 |
-
factory_kwargs = {"device": device, "dtype": dtype}
|
227 |
-
super().__init__()
|
228 |
-
self.in_features = in_features
|
229 |
-
self.out_features = out_features
|
230 |
-
self.register_buffer(
|
231 |
-
"weight", torch.empty((out_features, in_features), dtype=torch.int8)
|
232 |
-
)
|
233 |
-
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
|
234 |
-
|
235 |
-
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
236 |
-
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
|
237 |
-
|
238 |
-
|
239 |
-
##### weight only int4 per channel groupwise quantized code ######
|
240 |
-
|
241 |
-
|
242 |
-
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
243 |
-
weight_int32, scales_and_zeros = group_quantize_tensor(
|
244 |
-
weight_bf16, n_bit=4, groupsize=groupsize
|
245 |
-
)
|
246 |
-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
|
247 |
-
weight_int32, inner_k_tiles
|
248 |
-
)
|
249 |
-
return weight_int4pack, scales_and_zeros
|
250 |
-
|
251 |
-
|
252 |
-
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
253 |
-
origin_x_size = x.size()
|
254 |
-
x = x.reshape(-1, origin_x_size[-1])
|
255 |
-
c = torch.ops.aten._weight_int4pack_mm(
|
256 |
-
x, weight_int4pack, groupsize, scales_and_zeros
|
257 |
-
)
|
258 |
-
new_shape = origin_x_size[:-1] + (out_features,)
|
259 |
-
c = c.reshape(new_shape)
|
260 |
-
return c
|
261 |
-
|
262 |
-
|
263 |
-
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
|
264 |
-
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
|
265 |
-
|
266 |
-
|
267 |
-
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
|
268 |
-
for name, child in module.named_children():
|
269 |
-
if isinstance(child, nn.Linear):
|
270 |
-
if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
|
271 |
-
setattr(
|
272 |
-
module,
|
273 |
-
name,
|
274 |
-
WeightOnlyInt4Linear(
|
275 |
-
child.in_features,
|
276 |
-
child.out_features,
|
277 |
-
bias=False,
|
278 |
-
groupsize=groupsize,
|
279 |
-
inner_k_tiles=inner_k_tiles,
|
280 |
-
padding=False,
|
281 |
-
),
|
282 |
-
)
|
283 |
-
elif padding:
|
284 |
-
setattr(
|
285 |
-
module,
|
286 |
-
name,
|
287 |
-
WeightOnlyInt4Linear(
|
288 |
-
child.in_features,
|
289 |
-
child.out_features,
|
290 |
-
bias=False,
|
291 |
-
groupsize=groupsize,
|
292 |
-
inner_k_tiles=inner_k_tiles,
|
293 |
-
padding=True,
|
294 |
-
),
|
295 |
-
)
|
296 |
-
else:
|
297 |
-
replace_linear_int4(child, groupsize, inner_k_tiles, padding)
|
298 |
-
|
299 |
-
|
300 |
-
class WeightOnlyInt4QuantHandler:
|
301 |
-
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
|
302 |
-
self.mod = mod
|
303 |
-
self.groupsize = groupsize
|
304 |
-
self.inner_k_tiles = inner_k_tiles
|
305 |
-
self.padding = padding
|
306 |
-
assert groupsize in [32, 64, 128, 256]
|
307 |
-
assert inner_k_tiles in [2, 4, 8]
|
308 |
-
|
309 |
-
@torch.no_grad()
|
310 |
-
def create_quantized_state_dict(self):
|
311 |
-
cur_state_dict = self.mod.state_dict()
|
312 |
-
for fqn, mod in self.mod.named_modules():
|
313 |
-
if isinstance(mod, torch.nn.Linear):
|
314 |
-
assert not mod.bias
|
315 |
-
out_features = mod.out_features
|
316 |
-
in_features = mod.in_features
|
317 |
-
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
318 |
-
print(f"linear: {fqn}, in={in_features}, out={out_features}")
|
319 |
-
|
320 |
-
weight = mod.weight.data
|
321 |
-
if not _check_linear_int4_k(
|
322 |
-
in_features, self.groupsize, self.inner_k_tiles
|
323 |
-
):
|
324 |
-
if self.padding:
|
325 |
-
import torch.nn.functional as F
|
326 |
-
|
327 |
-
print(
|
328 |
-
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
|
329 |
-
)
|
330 |
-
padded_in_features = find_multiple(in_features, 1024)
|
331 |
-
weight = F.pad(
|
332 |
-
weight, pad=(0, padded_in_features - in_features)
|
333 |
-
)
|
334 |
-
else:
|
335 |
-
print(
|
336 |
-
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
|
337 |
-
+ "and that groupsize and inner_k_tiles*16 evenly divide into it"
|
338 |
-
)
|
339 |
-
continue
|
340 |
-
(
|
341 |
-
weight_int4pack,
|
342 |
-
scales_and_zeros,
|
343 |
-
) = prepare_int4_weight_and_scales_and_zeros(
|
344 |
-
weight.to(torch.bfloat16).to("cuda"),
|
345 |
-
self.groupsize,
|
346 |
-
self.inner_k_tiles,
|
347 |
-
)
|
348 |
-
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
|
349 |
-
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
|
350 |
-
|
351 |
-
return cur_state_dict
|
352 |
-
|
353 |
-
def convert_for_runtime(self):
|
354 |
-
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
|
355 |
-
return self.mod
|
356 |
-
|
357 |
-
|
358 |
-
class WeightOnlyInt4Linear(torch.nn.Module):
|
359 |
-
__constants__ = ["in_features", "out_features"]
|
360 |
-
in_features: int
|
361 |
-
out_features: int
|
362 |
-
weight: torch.Tensor
|
363 |
-
|
364 |
-
def __init__(
|
365 |
-
self,
|
366 |
-
in_features: int,
|
367 |
-
out_features: int,
|
368 |
-
bias=True,
|
369 |
-
device=None,
|
370 |
-
dtype=None,
|
371 |
-
groupsize: int = 128,
|
372 |
-
inner_k_tiles: int = 8,
|
373 |
-
padding: bool = True,
|
374 |
-
) -> None:
|
375 |
-
super().__init__()
|
376 |
-
self.padding = padding
|
377 |
-
if padding:
|
378 |
-
self.origin_in_features = in_features
|
379 |
-
in_features = find_multiple(in_features, 1024)
|
380 |
-
|
381 |
-
self.in_features = in_features
|
382 |
-
self.out_features = out_features
|
383 |
-
assert not bias, "require bias=False"
|
384 |
-
self.groupsize = groupsize
|
385 |
-
self.inner_k_tiles = inner_k_tiles
|
386 |
-
|
387 |
-
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
388 |
-
assert (
|
389 |
-
in_features % (inner_k_tiles * 16) == 0
|
390 |
-
), "require in_features % (innerKTiles * 16) == 0"
|
391 |
-
self.register_buffer(
|
392 |
-
"weight",
|
393 |
-
torch.empty(
|
394 |
-
(
|
395 |
-
out_features // 8,
|
396 |
-
in_features // (inner_k_tiles * 16),
|
397 |
-
32,
|
398 |
-
inner_k_tiles // 2,
|
399 |
-
),
|
400 |
-
dtype=torch.int32,
|
401 |
-
),
|
402 |
-
)
|
403 |
-
self.register_buffer(
|
404 |
-
"scales_and_zeros",
|
405 |
-
torch.empty(
|
406 |
-
(in_features // groupsize, out_features, 2), dtype=torch.bfloat16
|
407 |
-
),
|
408 |
-
)
|
409 |
-
|
410 |
-
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
411 |
-
input = input.to(torch.bfloat16)
|
412 |
-
if self.padding:
|
413 |
-
import torch.nn.functional as F
|
414 |
-
|
415 |
-
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
|
416 |
-
return linear_forward_int4(
|
417 |
-
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
418 |
-
)
|
419 |
-
|
420 |
-
|
421 |
-
def generate_folder_name():
|
422 |
-
now = datetime.datetime.now()
|
423 |
-
folder_name = now.strftime("%Y%m%d_%H%M%S")
|
424 |
-
return folder_name
|
425 |
-
|
426 |
-
|
427 |
-
@click.command()
|
428 |
-
@click.option(
|
429 |
-
"--checkpoint-path",
|
430 |
-
type=click.Path(path_type=Path, exists=True),
|
431 |
-
default="checkpoints/fish-speech-1.4",
|
432 |
-
)
|
433 |
-
@click.option(
|
434 |
-
"--mode", type=str, default="int8", help="type of quantization to perform"
|
435 |
-
)
|
436 |
-
@click.option(
|
437 |
-
"--groupsize", type=int, default=128, help="Group size for int4 quantization."
|
438 |
-
)
|
439 |
-
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
|
440 |
-
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
|
441 |
-
|
442 |
-
device = "cpu"
|
443 |
-
precision = torch.bfloat16
|
444 |
-
|
445 |
-
print("Loading model ...")
|
446 |
-
t0 = time.time()
|
447 |
-
|
448 |
-
model, _ = load_model(
|
449 |
-
checkpoint_path=checkpoint_path,
|
450 |
-
device=device,
|
451 |
-
precision=precision,
|
452 |
-
compile=False,
|
453 |
-
)
|
454 |
-
vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
|
455 |
-
now = timestamp if timestamp != "None" else generate_folder_name()
|
456 |
-
|
457 |
-
if mode == "int8":
|
458 |
-
print(
|
459 |
-
"Quantizing model weights for int8 weight-only symmetric per-channel quantization"
|
460 |
-
)
|
461 |
-
quant_handler = WeightOnlyInt8QuantHandler(model)
|
462 |
-
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
463 |
-
|
464 |
-
dir_name = checkpoint_path
|
465 |
-
dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
|
466 |
-
shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
|
467 |
-
if (dst_name / vq_model).exists():
|
468 |
-
(dst_name / vq_model).unlink()
|
469 |
-
quantize_path = dst_name / "model.pth"
|
470 |
-
|
471 |
-
elif mode == "int4":
|
472 |
-
print(
|
473 |
-
"Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
|
474 |
-
)
|
475 |
-
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
|
476 |
-
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
477 |
-
|
478 |
-
dir_name = checkpoint_path
|
479 |
-
dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
|
480 |
-
shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
|
481 |
-
if (dst_name / vq_model).exists():
|
482 |
-
(dst_name / vq_model).unlink()
|
483 |
-
quantize_path = dst_name / "model.pth"
|
484 |
-
|
485 |
-
else:
|
486 |
-
raise ValueError(
|
487 |
-
f"Invalid quantization mode {mode}, needs to be one of [int8, int4, int4-gptq]"
|
488 |
-
)
|
489 |
-
|
490 |
-
print(f"Writing quantized weights to {quantize_path}")
|
491 |
-
quantize_path.unlink(missing_ok=True) # remove existing file if one already there
|
492 |
-
torch.save(quantized_state_dict, quantize_path)
|
493 |
-
print(f"Quantization complete took {time.time() - t0:.02f} seconds")
|
494 |
-
|
495 |
-
|
496 |
-
if __name__ == "__main__":
|
497 |
-
quantize()
|
tools/llama/rebuild_tokenizer.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
|
2 |
-
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
3 |
-
|
4 |
-
# Initialize a tokenizer
|
5 |
-
tokenizer = Tokenizer(models.BPE())
|
6 |
-
|
7 |
-
# Customize pre-tokenization and decoding
|
8 |
-
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
9 |
-
tokenizer.decoder = decoders.ByteLevel()
|
10 |
-
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
11 |
-
|
12 |
-
# Don't train the tokenizer
|
13 |
-
trainer = trainers.BpeTrainer(
|
14 |
-
vocab_size=0,
|
15 |
-
min_frequency=2,
|
16 |
-
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
17 |
-
special_tokens=[
|
18 |
-
"<|begin_of_sequence|>",
|
19 |
-
"<|end_of_sequence|>",
|
20 |
-
"<|im_start|>",
|
21 |
-
"<|im_sep|>", # system, user, assistant, etc.
|
22 |
-
"<|im_end|>",
|
23 |
-
"<|semantic|>", # audio features
|
24 |
-
"<|pad|>",
|
25 |
-
],
|
26 |
-
)
|
27 |
-
|
28 |
-
# <|im_start|>user<|im_sep|>...<|im_end|>
|
29 |
-
# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
|
30 |
-
tokenizer.train_from_iterator([], trainer=trainer)
|
31 |
-
|
32 |
-
print(len(tokenizer.get_vocab()))
|
33 |
-
x = tokenizer.encode(
|
34 |
-
"Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
|
35 |
-
).ids
|
36 |
-
print(x, len(x))
|
37 |
-
print(tokenizer.decode(x, skip_special_tokens=True))
|
38 |
-
|
39 |
-
|
40 |
-
tokenizer = PreTrainedTokenizerFast(
|
41 |
-
tokenizer_object=tokenizer,
|
42 |
-
pad_token="<|pad|>",
|
43 |
-
bos_token="<|begin_of_sequence|>",
|
44 |
-
eos_token="<|end_of_sequence|>",
|
45 |
-
)
|
46 |
-
|
47 |
-
# Try tokenizing a new sequence
|
48 |
-
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
|
49 |
-
encoded = tokenizer(sequence).input_ids
|
50 |
-
|
51 |
-
print("Test encoding....")
|
52 |
-
print(f"\tSentence: {sequence}")
|
53 |
-
print(f"\tEncoded: {encoded}")
|
54 |
-
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
|
55 |
-
print(f"\tDecoded: {tokenizer.decode(encoded)}")
|
56 |
-
|
57 |
-
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
|
tools/msgpack_api.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import httpx
|
2 |
-
import ormsgpack
|
3 |
-
|
4 |
-
from tools.commons import ServeReferenceAudio, ServeTTSRequest
|
5 |
-
|
6 |
-
# priority: ref_id > references
|
7 |
-
request = ServeTTSRequest(
|
8 |
-
text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
9 |
-
# reference_id="114514",
|
10 |
-
references=[
|
11 |
-
ServeReferenceAudio(
|
12 |
-
audio=open("lengyue.wav", "rb").read(),
|
13 |
-
text=open("lengyue.lab", "r", encoding="utf-8").read(),
|
14 |
-
)
|
15 |
-
],
|
16 |
-
streaming=True,
|
17 |
-
)
|
18 |
-
|
19 |
-
with (
|
20 |
-
httpx.Client() as client,
|
21 |
-
open("hello.wav", "wb") as f,
|
22 |
-
):
|
23 |
-
with client.stream(
|
24 |
-
"POST",
|
25 |
-
"http://127.0.0.1:8080/v1/tts",
|
26 |
-
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
27 |
-
headers={
|
28 |
-
"authorization": "Bearer YOUR_API_KEY",
|
29 |
-
"content-type": "application/msgpack",
|
30 |
-
},
|
31 |
-
timeout=None,
|
32 |
-
) as response:
|
33 |
-
for chunk in response.iter_bytes():
|
34 |
-
f.write(chunk)
|
tools/post_api.py
DELETED
@@ -1,205 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import base64
|
3 |
-
import wave
|
4 |
-
|
5 |
-
import ormsgpack
|
6 |
-
import pyaudio
|
7 |
-
import requests
|
8 |
-
from pydub import AudioSegment
|
9 |
-
from pydub.playback import play
|
10 |
-
|
11 |
-
from tools.commons import ServeReferenceAudio, ServeTTSRequest
|
12 |
-
from tools.file import audio_to_bytes, read_ref_text
|
13 |
-
|
14 |
-
|
15 |
-
def parse_args():
|
16 |
-
|
17 |
-
parser = argparse.ArgumentParser(
|
18 |
-
description="Send a WAV file and text to a server and receive synthesized audio."
|
19 |
-
)
|
20 |
-
|
21 |
-
parser.add_argument(
|
22 |
-
"--url",
|
23 |
-
"-u",
|
24 |
-
type=str,
|
25 |
-
default="http://127.0.0.1:8080/v1/tts",
|
26 |
-
help="URL of the server",
|
27 |
-
)
|
28 |
-
parser.add_argument(
|
29 |
-
"--text", "-t", type=str, required=True, help="Text to be synthesized"
|
30 |
-
)
|
31 |
-
parser.add_argument(
|
32 |
-
"--reference_id",
|
33 |
-
"-id",
|
34 |
-
type=str,
|
35 |
-
default=None,
|
36 |
-
help="ID of the reference model to be used for the speech",
|
37 |
-
)
|
38 |
-
parser.add_argument(
|
39 |
-
"--reference_audio",
|
40 |
-
"-ra",
|
41 |
-
type=str,
|
42 |
-
nargs="+",
|
43 |
-
default=None,
|
44 |
-
help="Path to the WAV file",
|
45 |
-
)
|
46 |
-
parser.add_argument(
|
47 |
-
"--reference_text",
|
48 |
-
"-rt",
|
49 |
-
type=str,
|
50 |
-
nargs="+",
|
51 |
-
default=None,
|
52 |
-
help="Reference text for voice synthesis",
|
53 |
-
)
|
54 |
-
parser.add_argument(
|
55 |
-
"--output",
|
56 |
-
"-o",
|
57 |
-
type=str,
|
58 |
-
default="generated_audio",
|
59 |
-
help="Output audio file name",
|
60 |
-
)
|
61 |
-
parser.add_argument(
|
62 |
-
"--play",
|
63 |
-
type=bool,
|
64 |
-
default=True,
|
65 |
-
help="Whether to play audio after receiving data",
|
66 |
-
)
|
67 |
-
parser.add_argument("--normalize", type=bool, default=True)
|
68 |
-
parser.add_argument(
|
69 |
-
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
|
70 |
-
)
|
71 |
-
parser.add_argument("--mp3_bitrate", type=int, default=64)
|
72 |
-
parser.add_argument("--opus_bitrate", type=int, default=-1000)
|
73 |
-
parser.add_argument("--latency", type=str, default="normal", help="Latency option")
|
74 |
-
parser.add_argument(
|
75 |
-
"--max_new_tokens",
|
76 |
-
type=int,
|
77 |
-
default=1024,
|
78 |
-
help="Maximum new tokens to generate",
|
79 |
-
)
|
80 |
-
parser.add_argument(
|
81 |
-
"--chunk_length", type=int, default=100, help="Chunk length for synthesis"
|
82 |
-
)
|
83 |
-
parser.add_argument(
|
84 |
-
"--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
|
85 |
-
)
|
86 |
-
parser.add_argument(
|
87 |
-
"--repetition_penalty",
|
88 |
-
type=float,
|
89 |
-
default=1.2,
|
90 |
-
help="Repetition penalty for synthesis",
|
91 |
-
)
|
92 |
-
parser.add_argument(
|
93 |
-
"--temperature", type=float, default=0.7, help="Temperature for sampling"
|
94 |
-
)
|
95 |
-
parser.add_argument(
|
96 |
-
"--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
|
97 |
-
)
|
98 |
-
parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
|
99 |
-
parser.add_argument(
|
100 |
-
"--streaming", type=bool, default=False, help="Enable streaming response"
|
101 |
-
)
|
102 |
-
parser.add_argument(
|
103 |
-
"--channels", type=int, default=1, help="Number of audio channels"
|
104 |
-
)
|
105 |
-
parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
|
106 |
-
|
107 |
-
return parser.parse_args()
|
108 |
-
|
109 |
-
|
110 |
-
if __name__ == "__main__":
|
111 |
-
|
112 |
-
args = parse_args()
|
113 |
-
|
114 |
-
idstr: str | None = args.reference_id
|
115 |
-
# priority: ref_id > [{text, audio},...]
|
116 |
-
if idstr is None:
|
117 |
-
ref_audios = args.reference_audio
|
118 |
-
ref_texts = args.reference_text
|
119 |
-
if ref_audios is None:
|
120 |
-
byte_audios = []
|
121 |
-
else:
|
122 |
-
byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
|
123 |
-
if ref_texts is None:
|
124 |
-
ref_texts = []
|
125 |
-
else:
|
126 |
-
ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
|
127 |
-
else:
|
128 |
-
byte_audios = []
|
129 |
-
ref_texts = []
|
130 |
-
pass # in api.py
|
131 |
-
|
132 |
-
data = {
|
133 |
-
"text": args.text,
|
134 |
-
"references": [
|
135 |
-
ServeReferenceAudio(audio=ref_audio, text=ref_text)
|
136 |
-
for ref_text, ref_audio in zip(ref_texts, byte_audios)
|
137 |
-
],
|
138 |
-
"reference_id": idstr,
|
139 |
-
"normalize": args.normalize,
|
140 |
-
"format": args.format,
|
141 |
-
"mp3_bitrate": args.mp3_bitrate,
|
142 |
-
"opus_bitrate": args.opus_bitrate,
|
143 |
-
"max_new_tokens": args.max_new_tokens,
|
144 |
-
"chunk_length": args.chunk_length,
|
145 |
-
"top_p": args.top_p,
|
146 |
-
"repetition_penalty": args.repetition_penalty,
|
147 |
-
"temperature": args.temperature,
|
148 |
-
"speaker": args.speaker,
|
149 |
-
"emotion": args.emotion,
|
150 |
-
"streaming": args.streaming,
|
151 |
-
}
|
152 |
-
|
153 |
-
pydantic_data = ServeTTSRequest(**data)
|
154 |
-
|
155 |
-
response = requests.post(
|
156 |
-
args.url,
|
157 |
-
data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
158 |
-
stream=args.streaming,
|
159 |
-
headers={
|
160 |
-
"authorization": "Bearer YOUR_API_KEY",
|
161 |
-
"content-type": "application/msgpack",
|
162 |
-
},
|
163 |
-
)
|
164 |
-
|
165 |
-
if response.status_code == 200:
|
166 |
-
if args.streaming:
|
167 |
-
p = pyaudio.PyAudio()
|
168 |
-
audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
|
169 |
-
stream = p.open(
|
170 |
-
format=audio_format, channels=args.channels, rate=args.rate, output=True
|
171 |
-
)
|
172 |
-
|
173 |
-
wf = wave.open(f"{args.output}.wav", "wb")
|
174 |
-
wf.setnchannels(args.channels)
|
175 |
-
wf.setsampwidth(p.get_sample_size(audio_format))
|
176 |
-
wf.setframerate(args.rate)
|
177 |
-
|
178 |
-
stream_stopped_flag = False
|
179 |
-
|
180 |
-
try:
|
181 |
-
for chunk in response.iter_content(chunk_size=1024):
|
182 |
-
if chunk:
|
183 |
-
stream.write(chunk)
|
184 |
-
wf.writeframesraw(chunk)
|
185 |
-
else:
|
186 |
-
if not stream_stopped_flag:
|
187 |
-
stream.stop_stream()
|
188 |
-
stream_stopped_flag = True
|
189 |
-
finally:
|
190 |
-
stream.close()
|
191 |
-
p.terminate()
|
192 |
-
wf.close()
|
193 |
-
else:
|
194 |
-
audio_content = response.content
|
195 |
-
audio_path = f"{args.output}.{args.format}"
|
196 |
-
with open(audio_path, "wb") as audio_file:
|
197 |
-
audio_file.write(audio_content)
|
198 |
-
|
199 |
-
audio = AudioSegment.from_file(audio_path, format=args.format)
|
200 |
-
if args.play:
|
201 |
-
play(audio)
|
202 |
-
print(f"Audio has been saved to '{audio_path}'.")
|
203 |
-
else:
|
204 |
-
print(f"Request failed with status code {response.status_code}")
|
205 |
-
print(response.json())
|
tools/sensevoice/README.md
DELETED
@@ -1,59 +0,0 @@
|
|
1 |
-
# FunASR Command Line Interface
|
2 |
-
|
3 |
-
This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
|
4 |
-
|
5 |
-
## Requirements
|
6 |
-
|
7 |
-
- Python >= 3.10
|
8 |
-
- PyTorch <= 2.3.1
|
9 |
-
- ffmpeg, pydub, audio-separator[gpu].
|
10 |
-
|
11 |
-
## Installation
|
12 |
-
|
13 |
-
Install the required packages:
|
14 |
-
|
15 |
-
```bash
|
16 |
-
pip install -e .[stable]
|
17 |
-
```
|
18 |
-
|
19 |
-
Make sure you have `ffmpeg` installed and available in your `PATH`.
|
20 |
-
|
21 |
-
## Usage
|
22 |
-
|
23 |
-
### Basic Usage
|
24 |
-
|
25 |
-
To run the tool with default settings:
|
26 |
-
|
27 |
-
```bash
|
28 |
-
python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
|
29 |
-
```
|
30 |
-
|
31 |
-
## Options
|
32 |
-
|
33 |
-
| Option | Description |
|
34 |
-
| :-----------------------: | :---------------------------------------------------------------------------: |
|
35 |
-
| --audio-dir | Directory containing audio or video files. |
|
36 |
-
| --save-dir | Directory to save processed audio files. |
|
37 |
-
| --device | Device to use for processing. Options: cuda (default) or cpu. |
|
38 |
-
| --language | Language of the transcription. Default is auto. |
|
39 |
-
| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
|
40 |
-
| --punc | Enable punctuation prediction. |
|
41 |
-
| --denoise | Enable noise reduction (vocal separation). |
|
42 |
-
|
43 |
-
## Example
|
44 |
-
|
45 |
-
To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
|
46 |
-
|
47 |
-
```bash
|
48 |
-
python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
|
49 |
-
```
|
50 |
-
|
51 |
-
## Additional Notes
|
52 |
-
|
53 |
-
- The tool supports `both audio and video files`. Videos will be converted to audio automatically.
|
54 |
-
- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
|
55 |
-
- The script will automatically create necessary directories in the `--save-dir`.
|
56 |
-
|
57 |
-
## Troubleshooting
|
58 |
-
|
59 |
-
If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
|
tools/sensevoice/__init__.py
DELETED
File without changes
|
tools/sensevoice/auto_model.py
DELETED
@@ -1,573 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
# -*- encoding: utf-8 -*-
|
3 |
-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
4 |
-
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
-
|
6 |
-
import copy
|
7 |
-
import json
|
8 |
-
import logging
|
9 |
-
import os.path
|
10 |
-
import random
|
11 |
-
import re
|
12 |
-
import string
|
13 |
-
import time
|
14 |
-
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
from funasr.download.download_model_from_hub import download_model
|
18 |
-
from funasr.download.file import download_from_url
|
19 |
-
from funasr.register import tables
|
20 |
-
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
21 |
-
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
22 |
-
from funasr.utils import export_utils, misc
|
23 |
-
from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
|
24 |
-
from funasr.utils.misc import deep_update
|
25 |
-
from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
|
26 |
-
from tqdm import tqdm
|
27 |
-
|
28 |
-
from .vad_utils import merge_vad, slice_padding_audio_samples
|
29 |
-
|
30 |
-
try:
|
31 |
-
from funasr.models.campplus.cluster_backend import ClusterBackend
|
32 |
-
from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
|
33 |
-
except:
|
34 |
-
pass
|
35 |
-
|
36 |
-
|
37 |
-
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
|
38 |
-
""" """
|
39 |
-
data_list = []
|
40 |
-
key_list = []
|
41 |
-
filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
|
42 |
-
|
43 |
-
chars = string.ascii_letters + string.digits
|
44 |
-
if isinstance(data_in, str):
|
45 |
-
if data_in.startswith("http://") or data_in.startswith("https://"): # url
|
46 |
-
data_in = download_from_url(data_in)
|
47 |
-
|
48 |
-
if isinstance(data_in, str) and os.path.exists(
|
49 |
-
data_in
|
50 |
-
): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
|
51 |
-
_, file_extension = os.path.splitext(data_in)
|
52 |
-
file_extension = file_extension.lower()
|
53 |
-
if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
|
54 |
-
with open(data_in, encoding="utf-8") as fin:
|
55 |
-
for line in fin:
|
56 |
-
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
57 |
-
if data_in.endswith(
|
58 |
-
".jsonl"
|
59 |
-
): # file.jsonl: json.dumps({"source": data})
|
60 |
-
lines = json.loads(line.strip())
|
61 |
-
data = lines["source"]
|
62 |
-
key = data["key"] if "key" in data else key
|
63 |
-
else: # filelist, wav.scp, text.txt: id \t data or data
|
64 |
-
lines = line.strip().split(maxsplit=1)
|
65 |
-
data = lines[1] if len(lines) > 1 else lines[0]
|
66 |
-
key = lines[0] if len(lines) > 1 else key
|
67 |
-
|
68 |
-
data_list.append(data)
|
69 |
-
key_list.append(key)
|
70 |
-
else:
|
71 |
-
if key is None:
|
72 |
-
# key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
73 |
-
key = misc.extract_filename_without_extension(data_in)
|
74 |
-
data_list = [data_in]
|
75 |
-
key_list = [key]
|
76 |
-
elif isinstance(data_in, (list, tuple)):
|
77 |
-
if data_type is not None and isinstance(
|
78 |
-
data_type, (list, tuple)
|
79 |
-
): # mutiple inputs
|
80 |
-
data_list_tmp = []
|
81 |
-
for data_in_i, data_type_i in zip(data_in, data_type):
|
82 |
-
key_list, data_list_i = prepare_data_iterator(
|
83 |
-
data_in=data_in_i, data_type=data_type_i
|
84 |
-
)
|
85 |
-
data_list_tmp.append(data_list_i)
|
86 |
-
data_list = []
|
87 |
-
for item in zip(*data_list_tmp):
|
88 |
-
data_list.append(item)
|
89 |
-
else:
|
90 |
-
# [audio sample point, fbank, text]
|
91 |
-
data_list = data_in
|
92 |
-
key_list = []
|
93 |
-
for data_i in data_in:
|
94 |
-
if isinstance(data_i, str) and os.path.exists(data_i):
|
95 |
-
key = misc.extract_filename_without_extension(data_i)
|
96 |
-
else:
|
97 |
-
if key is None:
|
98 |
-
key = "rand_key_" + "".join(
|
99 |
-
random.choice(chars) for _ in range(13)
|
100 |
-
)
|
101 |
-
key_list.append(key)
|
102 |
-
|
103 |
-
else: # raw text; audio sample point, fbank; bytes
|
104 |
-
if isinstance(data_in, bytes): # audio bytes
|
105 |
-
data_in = load_bytes(data_in)
|
106 |
-
if key is None:
|
107 |
-
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
108 |
-
data_list = [data_in]
|
109 |
-
key_list = [key]
|
110 |
-
|
111 |
-
return key_list, data_list
|
112 |
-
|
113 |
-
|
114 |
-
class AutoModel:
|
115 |
-
|
116 |
-
def __init__(self, **kwargs):
|
117 |
-
|
118 |
-
try:
|
119 |
-
from funasr.utils.version_checker import check_for_update
|
120 |
-
|
121 |
-
print(
|
122 |
-
"Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
|
123 |
-
)
|
124 |
-
check_for_update(disable=kwargs.get("disable_update", False))
|
125 |
-
except:
|
126 |
-
pass
|
127 |
-
|
128 |
-
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
|
129 |
-
logging.basicConfig(level=log_level)
|
130 |
-
|
131 |
-
model, kwargs = self.build_model(**kwargs)
|
132 |
-
|
133 |
-
# if vad_model is not None, build vad model else None
|
134 |
-
vad_model = kwargs.get("vad_model", None)
|
135 |
-
vad_kwargs = (
|
136 |
-
{} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
|
137 |
-
)
|
138 |
-
if vad_model is not None:
|
139 |
-
logging.info("Building VAD model.")
|
140 |
-
vad_kwargs["model"] = vad_model
|
141 |
-
vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
|
142 |
-
vad_kwargs["device"] = kwargs["device"]
|
143 |
-
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
|
144 |
-
|
145 |
-
# if punc_model is not None, build punc model else None
|
146 |
-
punc_model = kwargs.get("punc_model", None)
|
147 |
-
punc_kwargs = (
|
148 |
-
{}
|
149 |
-
if kwargs.get("punc_kwargs", {}) is None
|
150 |
-
else kwargs.get("punc_kwargs", {})
|
151 |
-
)
|
152 |
-
if punc_model is not None:
|
153 |
-
logging.info("Building punc model.")
|
154 |
-
punc_kwargs["model"] = punc_model
|
155 |
-
punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
|
156 |
-
punc_kwargs["device"] = kwargs["device"]
|
157 |
-
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
|
158 |
-
|
159 |
-
# if spk_model is not None, build spk model else None
|
160 |
-
spk_model = kwargs.get("spk_model", None)
|
161 |
-
spk_kwargs = (
|
162 |
-
{} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
|
163 |
-
)
|
164 |
-
if spk_model is not None:
|
165 |
-
logging.info("Building SPK model.")
|
166 |
-
spk_kwargs["model"] = spk_model
|
167 |
-
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
|
168 |
-
spk_kwargs["device"] = kwargs["device"]
|
169 |
-
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
|
170 |
-
self.cb_model = ClusterBackend().to(kwargs["device"])
|
171 |
-
spk_mode = kwargs.get("spk_mode", "punc_segment")
|
172 |
-
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
|
173 |
-
logging.error(
|
174 |
-
"spk_mode should be one of default, vad_segment and punc_segment."
|
175 |
-
)
|
176 |
-
self.spk_mode = spk_mode
|
177 |
-
|
178 |
-
self.kwargs = kwargs
|
179 |
-
self.model = model
|
180 |
-
self.vad_model = vad_model
|
181 |
-
self.vad_kwargs = vad_kwargs
|
182 |
-
self.punc_model = punc_model
|
183 |
-
self.punc_kwargs = punc_kwargs
|
184 |
-
self.spk_model = spk_model
|
185 |
-
self.spk_kwargs = spk_kwargs
|
186 |
-
self.model_path = kwargs.get("model_path")
|
187 |
-
|
188 |
-
@staticmethod
|
189 |
-
def build_model(**kwargs):
|
190 |
-
assert "model" in kwargs
|
191 |
-
if "model_conf" not in kwargs:
|
192 |
-
logging.info(
|
193 |
-
"download models from model hub: {}".format(kwargs.get("hub", "ms"))
|
194 |
-
)
|
195 |
-
kwargs = download_model(**kwargs)
|
196 |
-
|
197 |
-
set_all_random_seed(kwargs.get("seed", 0))
|
198 |
-
|
199 |
-
device = kwargs.get("device", "cuda")
|
200 |
-
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
|
201 |
-
device = "cpu"
|
202 |
-
kwargs["batch_size"] = 1
|
203 |
-
kwargs["device"] = device
|
204 |
-
|
205 |
-
torch.set_num_threads(kwargs.get("ncpu", 4))
|
206 |
-
|
207 |
-
# build tokenizer
|
208 |
-
tokenizer = kwargs.get("tokenizer", None)
|
209 |
-
if tokenizer is not None:
|
210 |
-
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
|
211 |
-
tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
|
212 |
-
kwargs["token_list"] = (
|
213 |
-
tokenizer.token_list if hasattr(tokenizer, "token_list") else None
|
214 |
-
)
|
215 |
-
kwargs["token_list"] = (
|
216 |
-
tokenizer.get_vocab()
|
217 |
-
if hasattr(tokenizer, "get_vocab")
|
218 |
-
else kwargs["token_list"]
|
219 |
-
)
|
220 |
-
vocab_size = (
|
221 |
-
len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
|
222 |
-
)
|
223 |
-
if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
|
224 |
-
vocab_size = tokenizer.get_vocab_size()
|
225 |
-
else:
|
226 |
-
vocab_size = -1
|
227 |
-
kwargs["tokenizer"] = tokenizer
|
228 |
-
|
229 |
-
# build frontend
|
230 |
-
frontend = kwargs.get("frontend", None)
|
231 |
-
kwargs["input_size"] = None
|
232 |
-
if frontend is not None:
|
233 |
-
frontend_class = tables.frontend_classes.get(frontend)
|
234 |
-
frontend = frontend_class(**kwargs.get("frontend_conf", {}))
|
235 |
-
kwargs["input_size"] = (
|
236 |
-
frontend.output_size() if hasattr(frontend, "output_size") else None
|
237 |
-
)
|
238 |
-
kwargs["frontend"] = frontend
|
239 |
-
# build model
|
240 |
-
model_class = tables.model_classes.get(kwargs["model"])
|
241 |
-
assert model_class is not None, f'{kwargs["model"]} is not registered'
|
242 |
-
model_conf = {}
|
243 |
-
deep_update(model_conf, kwargs.get("model_conf", {}))
|
244 |
-
deep_update(model_conf, kwargs)
|
245 |
-
model = model_class(**model_conf, vocab_size=vocab_size)
|
246 |
-
|
247 |
-
# init_param
|
248 |
-
init_param = kwargs.get("init_param", None)
|
249 |
-
if init_param is not None:
|
250 |
-
if os.path.exists(init_param):
|
251 |
-
logging.info(f"Loading pretrained params from {init_param}")
|
252 |
-
load_pretrained_model(
|
253 |
-
model=model,
|
254 |
-
path=init_param,
|
255 |
-
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
256 |
-
oss_bucket=kwargs.get("oss_bucket", None),
|
257 |
-
scope_map=kwargs.get("scope_map", []),
|
258 |
-
excludes=kwargs.get("excludes", None),
|
259 |
-
)
|
260 |
-
else:
|
261 |
-
print(f"error, init_param does not exist!: {init_param}")
|
262 |
-
|
263 |
-
# fp16
|
264 |
-
if kwargs.get("fp16", False):
|
265 |
-
model.to(torch.float16)
|
266 |
-
elif kwargs.get("bf16", False):
|
267 |
-
model.to(torch.bfloat16)
|
268 |
-
model.to(device)
|
269 |
-
|
270 |
-
if not kwargs.get("disable_log", True):
|
271 |
-
tables.print()
|
272 |
-
|
273 |
-
return model, kwargs
|
274 |
-
|
275 |
-
def __call__(self, *args, **cfg):
|
276 |
-
kwargs = self.kwargs
|
277 |
-
deep_update(kwargs, cfg)
|
278 |
-
res = self.model(*args, kwargs)
|
279 |
-
return res
|
280 |
-
|
281 |
-
def generate(self, input, input_len=None, **cfg):
|
282 |
-
if self.vad_model is None:
|
283 |
-
return self.inference(input, input_len=input_len, **cfg)
|
284 |
-
|
285 |
-
else:
|
286 |
-
return self.inference_with_vad(input, input_len=input_len, **cfg)
|
287 |
-
|
288 |
-
def inference(
|
289 |
-
self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
|
290 |
-
):
|
291 |
-
kwargs = self.kwargs if kwargs is None else kwargs
|
292 |
-
if "cache" in kwargs:
|
293 |
-
kwargs.pop("cache")
|
294 |
-
deep_update(kwargs, cfg)
|
295 |
-
model = self.model if model is None else model
|
296 |
-
model.eval()
|
297 |
-
|
298 |
-
batch_size = kwargs.get("batch_size", 1)
|
299 |
-
# if kwargs.get("device", "cpu") == "cpu":
|
300 |
-
# batch_size = 1
|
301 |
-
|
302 |
-
key_list, data_list = prepare_data_iterator(
|
303 |
-
input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
|
304 |
-
)
|
305 |
-
|
306 |
-
speed_stats = {}
|
307 |
-
asr_result_list = []
|
308 |
-
num_samples = len(data_list)
|
309 |
-
disable_pbar = self.kwargs.get("disable_pbar", False)
|
310 |
-
pbar = (
|
311 |
-
tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
|
312 |
-
if not disable_pbar
|
313 |
-
else None
|
314 |
-
)
|
315 |
-
time_speech_total = 0.0
|
316 |
-
time_escape_total = 0.0
|
317 |
-
for beg_idx in range(0, num_samples, batch_size):
|
318 |
-
end_idx = min(num_samples, beg_idx + batch_size)
|
319 |
-
data_batch = data_list[beg_idx:end_idx]
|
320 |
-
key_batch = key_list[beg_idx:end_idx]
|
321 |
-
batch = {"data_in": data_batch, "key": key_batch}
|
322 |
-
|
323 |
-
if (end_idx - beg_idx) == 1 and kwargs.get(
|
324 |
-
"data_type", None
|
325 |
-
) == "fbank": # fbank
|
326 |
-
batch["data_in"] = data_batch[0]
|
327 |
-
batch["data_lengths"] = input_len
|
328 |
-
|
329 |
-
time1 = time.perf_counter()
|
330 |
-
with torch.no_grad():
|
331 |
-
res = model.inference(**batch, **kwargs)
|
332 |
-
if isinstance(res, (list, tuple)):
|
333 |
-
results = res[0] if len(res) > 0 else [{"text": ""}]
|
334 |
-
meta_data = res[1] if len(res) > 1 else {}
|
335 |
-
time2 = time.perf_counter()
|
336 |
-
|
337 |
-
asr_result_list.extend(results)
|
338 |
-
|
339 |
-
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
|
340 |
-
batch_data_time = meta_data.get("batch_data_time", -1)
|
341 |
-
time_escape = time2 - time1
|
342 |
-
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
|
343 |
-
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
|
344 |
-
speed_stats["forward"] = f"{time_escape:0.3f}"
|
345 |
-
speed_stats["batch_size"] = f"{len(results)}"
|
346 |
-
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
|
347 |
-
description = f"{speed_stats}, "
|
348 |
-
if pbar:
|
349 |
-
pbar.update(end_idx - beg_idx)
|
350 |
-
pbar.set_description(description)
|
351 |
-
time_speech_total += batch_data_time
|
352 |
-
time_escape_total += time_escape
|
353 |
-
|
354 |
-
if pbar:
|
355 |
-
# pbar.update(1)
|
356 |
-
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
|
357 |
-
torch.cuda.empty_cache()
|
358 |
-
return asr_result_list
|
359 |
-
|
360 |
-
def vad(self, input, input_len=None, **cfg):
|
361 |
-
kwargs = self.kwargs
|
362 |
-
# step.1: compute the vad model
|
363 |
-
deep_update(self.vad_kwargs, cfg)
|
364 |
-
beg_vad = time.time()
|
365 |
-
res = self.inference(
|
366 |
-
input,
|
367 |
-
input_len=input_len,
|
368 |
-
model=self.vad_model,
|
369 |
-
kwargs=self.vad_kwargs,
|
370 |
-
**cfg,
|
371 |
-
)
|
372 |
-
end_vad = time.time()
|
373 |
-
# FIX(gcf): concat the vad clips for sense vocie model for better aed
|
374 |
-
if cfg.get("merge_vad", False):
|
375 |
-
for i in range(len(res)):
|
376 |
-
res[i]["value"] = merge_vad(
|
377 |
-
res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
|
378 |
-
)
|
379 |
-
elapsed = end_vad - beg_vad
|
380 |
-
return elapsed, res
|
381 |
-
|
382 |
-
def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
|
383 |
-
|
384 |
-
kwargs = self.kwargs
|
385 |
-
|
386 |
-
# step.2 compute asr model
|
387 |
-
model = self.model
|
388 |
-
deep_update(kwargs, cfg)
|
389 |
-
batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
|
390 |
-
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
|
391 |
-
kwargs["batch_size"] = batch_size
|
392 |
-
|
393 |
-
key_list, data_list = prepare_data_iterator(
|
394 |
-
input, input_len=input_len, data_type=kwargs.get("data_type", None)
|
395 |
-
)
|
396 |
-
results_ret_list = []
|
397 |
-
time_speech_total_all_samples = 1e-6
|
398 |
-
|
399 |
-
beg_total = time.time()
|
400 |
-
pbar_total = (
|
401 |
-
tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
|
402 |
-
if not kwargs.get("disable_pbar", False)
|
403 |
-
else None
|
404 |
-
)
|
405 |
-
|
406 |
-
for i in range(len(vad_res)):
|
407 |
-
key = vad_res[i]["key"]
|
408 |
-
vadsegments = vad_res[i]["value"]
|
409 |
-
input_i = data_list[i]
|
410 |
-
fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
|
411 |
-
speech = load_audio_text_image_video(
|
412 |
-
input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
|
413 |
-
)
|
414 |
-
speech_lengths = len(speech)
|
415 |
-
n = len(vadsegments)
|
416 |
-
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
417 |
-
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
418 |
-
results_sorted = []
|
419 |
-
|
420 |
-
if not len(sorted_data):
|
421 |
-
results_ret_list.append({"key": key, "text": "", "timestamp": []})
|
422 |
-
logging.info("decoding, utt: {}, empty speech".format(key))
|
423 |
-
continue
|
424 |
-
|
425 |
-
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
426 |
-
batch_size = max(
|
427 |
-
batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
|
428 |
-
)
|
429 |
-
|
430 |
-
if kwargs["device"] == "cpu":
|
431 |
-
batch_size = 0
|
432 |
-
|
433 |
-
beg_idx = 0
|
434 |
-
beg_asr_total = time.time()
|
435 |
-
time_speech_total_per_sample = speech_lengths / 16000
|
436 |
-
time_speech_total_all_samples += time_speech_total_per_sample
|
437 |
-
|
438 |
-
# pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
|
439 |
-
|
440 |
-
all_segments = []
|
441 |
-
max_len_in_batch = 0
|
442 |
-
end_idx = 1
|
443 |
-
|
444 |
-
for j, _ in enumerate(range(0, n)):
|
445 |
-
# pbar_sample.update(1)
|
446 |
-
sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
|
447 |
-
potential_batch_length = max(max_len_in_batch, sample_length) * (
|
448 |
-
j + 1 - beg_idx
|
449 |
-
)
|
450 |
-
# batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
|
451 |
-
if (
|
452 |
-
j < n - 1
|
453 |
-
and sample_length < batch_size_threshold_ms
|
454 |
-
and potential_batch_length < batch_size
|
455 |
-
):
|
456 |
-
max_len_in_batch = max(max_len_in_batch, sample_length)
|
457 |
-
end_idx += 1
|
458 |
-
continue
|
459 |
-
|
460 |
-
speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
|
461 |
-
speech, speech_lengths, sorted_data[beg_idx:end_idx]
|
462 |
-
)
|
463 |
-
results = self.inference(
|
464 |
-
speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
|
465 |
-
)
|
466 |
-
|
467 |
-
for _b in range(len(speech_j)):
|
468 |
-
results[_b]["interval"] = intervals[_b]
|
469 |
-
|
470 |
-
if self.spk_model is not None:
|
471 |
-
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
|
472 |
-
for _b in range(len(speech_j)):
|
473 |
-
vad_segments = [
|
474 |
-
[
|
475 |
-
sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
|
476 |
-
sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
|
477 |
-
np.array(speech_j[_b]),
|
478 |
-
]
|
479 |
-
]
|
480 |
-
segments = sv_chunk(vad_segments)
|
481 |
-
all_segments.extend(segments)
|
482 |
-
speech_b = [i[2] for i in segments]
|
483 |
-
spk_res = self.inference(
|
484 |
-
speech_b,
|
485 |
-
input_len=None,
|
486 |
-
model=self.spk_model,
|
487 |
-
kwargs=kwargs,
|
488 |
-
**cfg,
|
489 |
-
)
|
490 |
-
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
|
491 |
-
|
492 |
-
beg_idx = end_idx
|
493 |
-
end_idx += 1
|
494 |
-
max_len_in_batch = sample_length
|
495 |
-
if len(results) < 1:
|
496 |
-
continue
|
497 |
-
results_sorted.extend(results)
|
498 |
-
|
499 |
-
# end_asr_total = time.time()
|
500 |
-
# time_escape_total_per_sample = end_asr_total - beg_asr_total
|
501 |
-
# pbar_sample.update(1)
|
502 |
-
# pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
503 |
-
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
|
504 |
-
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
|
505 |
-
|
506 |
-
restored_data = [0] * n
|
507 |
-
for j in range(n):
|
508 |
-
index = sorted_data[j][1]
|
509 |
-
cur = results_sorted[j]
|
510 |
-
pattern = r"<\|([^|]+)\|>"
|
511 |
-
emotion_string = re.findall(pattern, cur["text"])
|
512 |
-
cur["text"] = re.sub(pattern, "", cur["text"])
|
513 |
-
cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
|
514 |
-
if self.punc_model is not None and len(cur["text"].strip()) > 0:
|
515 |
-
deep_update(self.punc_kwargs, cfg)
|
516 |
-
punc_res = self.inference(
|
517 |
-
cur["text"],
|
518 |
-
model=self.punc_model,
|
519 |
-
kwargs=self.punc_kwargs,
|
520 |
-
**cfg,
|
521 |
-
)
|
522 |
-
cur["text"] = punc_res[0]["text"]
|
523 |
-
|
524 |
-
restored_data[index] = cur
|
525 |
-
|
526 |
-
end_asr_total = time.time()
|
527 |
-
time_escape_total_per_sample = end_asr_total - beg_asr_total
|
528 |
-
if pbar_total:
|
529 |
-
pbar_total.update(1)
|
530 |
-
pbar_total.set_description(
|
531 |
-
f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
532 |
-
f"time_speech: {time_speech_total_per_sample: 0.3f}, "
|
533 |
-
f"time_escape: {time_escape_total_per_sample:0.3f}"
|
534 |
-
)
|
535 |
-
|
536 |
-
# end_total = time.time()
|
537 |
-
# time_escape_total_all_samples = end_total - beg_total
|
538 |
-
# print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
|
539 |
-
# f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
|
540 |
-
# f"time_escape_all: {time_escape_total_all_samples:0.3f}")
|
541 |
-
return restored_data
|
542 |
-
|
543 |
-
def export(self, input=None, **cfg):
|
544 |
-
"""
|
545 |
-
|
546 |
-
:param input:
|
547 |
-
:param type:
|
548 |
-
:param quantize:
|
549 |
-
:param fallback_num:
|
550 |
-
:param calib_num:
|
551 |
-
:param opset_version:
|
552 |
-
:param cfg:
|
553 |
-
:return:
|
554 |
-
"""
|
555 |
-
|
556 |
-
device = cfg.get("device", "cpu")
|
557 |
-
model = self.model.to(device=device)
|
558 |
-
kwargs = self.kwargs
|
559 |
-
deep_update(kwargs, cfg)
|
560 |
-
kwargs["device"] = device
|
561 |
-
del kwargs["model"]
|
562 |
-
model.eval()
|
563 |
-
|
564 |
-
type = kwargs.get("type", "onnx")
|
565 |
-
|
566 |
-
key_list, data_list = prepare_data_iterator(
|
567 |
-
input, input_len=None, data_type=kwargs.get("data_type", None), key=None
|
568 |
-
)
|
569 |
-
|
570 |
-
with torch.no_grad():
|
571 |
-
export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
|
572 |
-
|
573 |
-
return export_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/sensevoice/fun_asr.py
DELETED
@@ -1,332 +0,0 @@
|
|
1 |
-
import gc
|
2 |
-
import os
|
3 |
-
import re
|
4 |
-
|
5 |
-
from audio_separator.separator import Separator
|
6 |
-
|
7 |
-
os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
|
8 |
-
os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
|
9 |
-
import json
|
10 |
-
import subprocess
|
11 |
-
from pathlib import Path
|
12 |
-
|
13 |
-
import click
|
14 |
-
import torch
|
15 |
-
from loguru import logger
|
16 |
-
from pydub import AudioSegment
|
17 |
-
from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
|
18 |
-
from tqdm import tqdm
|
19 |
-
|
20 |
-
from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
|
21 |
-
from tools.sensevoice.auto_model import AutoModel
|
22 |
-
|
23 |
-
|
24 |
-
def uvr5_cli(
|
25 |
-
audio_dir: Path,
|
26 |
-
output_folder: Path,
|
27 |
-
audio_files: list[Path] | None = None,
|
28 |
-
output_format: str = "flac",
|
29 |
-
model: str = "BS-Roformer-Viperx-1297.ckpt",
|
30 |
-
):
|
31 |
-
# ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
|
32 |
-
sepr = Separator(
|
33 |
-
model_file_dir=os.environ["UVR5_CACHE"],
|
34 |
-
output_dir=output_folder,
|
35 |
-
output_format=output_format,
|
36 |
-
)
|
37 |
-
dictmodel = {
|
38 |
-
"BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
|
39 |
-
"BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
|
40 |
-
"BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
|
41 |
-
"Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
|
42 |
-
}
|
43 |
-
roformer_model = dictmodel[model]
|
44 |
-
sepr.load_model(roformer_model)
|
45 |
-
if audio_files is None:
|
46 |
-
audio_files = list_files(
|
47 |
-
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
48 |
-
)
|
49 |
-
total_files = len(audio_files)
|
50 |
-
|
51 |
-
print(f"{total_files} audio files found")
|
52 |
-
|
53 |
-
res = []
|
54 |
-
for audio in tqdm(audio_files, desc="Denoising: "):
|
55 |
-
file_path = str(audio_dir / audio)
|
56 |
-
sep_out = sepr.separate(file_path)
|
57 |
-
if isinstance(sep_out, str):
|
58 |
-
res.append(sep_out)
|
59 |
-
elif isinstance(sep_out, list):
|
60 |
-
res.extend(sep_out)
|
61 |
-
del sepr
|
62 |
-
gc.collect()
|
63 |
-
if torch.cuda.is_available():
|
64 |
-
torch.cuda.empty_cache()
|
65 |
-
|
66 |
-
return res, roformer_model
|
67 |
-
|
68 |
-
|
69 |
-
def get_sample_rate(media_path: Path):
|
70 |
-
result = subprocess.run(
|
71 |
-
[
|
72 |
-
"ffprobe",
|
73 |
-
"-v",
|
74 |
-
"quiet",
|
75 |
-
"-print_format",
|
76 |
-
"json",
|
77 |
-
"-show_streams",
|
78 |
-
str(media_path),
|
79 |
-
],
|
80 |
-
capture_output=True,
|
81 |
-
text=True,
|
82 |
-
check=True,
|
83 |
-
)
|
84 |
-
media_info = json.loads(result.stdout)
|
85 |
-
for stream in media_info.get("streams", []):
|
86 |
-
if stream.get("codec_type") == "audio":
|
87 |
-
return stream.get("sample_rate")
|
88 |
-
return "44100" # Default sample rate if not found
|
89 |
-
|
90 |
-
|
91 |
-
def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
|
92 |
-
sr = get_sample_rate(src_path)
|
93 |
-
out_path.parent.mkdir(parents=True, exist_ok=True)
|
94 |
-
if src_path.resolve() == out_path.resolve():
|
95 |
-
output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
|
96 |
-
else:
|
97 |
-
output = str(out_path)
|
98 |
-
subprocess.run(
|
99 |
-
[
|
100 |
-
"ffmpeg",
|
101 |
-
"-loglevel",
|
102 |
-
"error",
|
103 |
-
"-i",
|
104 |
-
str(src_path),
|
105 |
-
"-acodec",
|
106 |
-
"pcm_s16le" if out_fmt == "wav" else "flac",
|
107 |
-
"-ar",
|
108 |
-
sr,
|
109 |
-
"-ac",
|
110 |
-
"1",
|
111 |
-
"-y",
|
112 |
-
output,
|
113 |
-
],
|
114 |
-
check=True,
|
115 |
-
)
|
116 |
-
return out_path
|
117 |
-
|
118 |
-
|
119 |
-
def convert_video_to_audio(video_path: Path, audio_dir: Path):
|
120 |
-
cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
|
121 |
-
vocals = [
|
122 |
-
p
|
123 |
-
for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
|
124 |
-
if p.suffix in AUDIO_EXTENSIONS
|
125 |
-
]
|
126 |
-
if len(vocals) > 0:
|
127 |
-
return vocals[0]
|
128 |
-
audio_path = cur_dir / f"{video_path.stem}.wav"
|
129 |
-
convert_to_mono(video_path, audio_path)
|
130 |
-
return audio_path
|
131 |
-
|
132 |
-
|
133 |
-
@click.command()
|
134 |
-
@click.option("--audio-dir", required=True, help="Directory containing audio files")
|
135 |
-
@click.option(
|
136 |
-
"--save-dir", required=True, help="Directory to save processed audio files"
|
137 |
-
)
|
138 |
-
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
|
139 |
-
@click.option("--language", default="auto", help="Language of the transcription")
|
140 |
-
@click.option(
|
141 |
-
"--max_single_segment_time",
|
142 |
-
default=20000,
|
143 |
-
type=int,
|
144 |
-
help="Maximum of Output single audio duration(ms)",
|
145 |
-
)
|
146 |
-
@click.option("--fsmn-vad/--silero-vad", default=False)
|
147 |
-
@click.option("--punc/--no-punc", default=False)
|
148 |
-
@click.option("--denoise/--no-denoise", default=False)
|
149 |
-
@click.option("--save_emo/--no_save_emo", default=False)
|
150 |
-
def main(
|
151 |
-
audio_dir: str,
|
152 |
-
save_dir: str,
|
153 |
-
device: str,
|
154 |
-
language: str,
|
155 |
-
max_single_segment_time: int,
|
156 |
-
fsmn_vad: bool,
|
157 |
-
punc: bool,
|
158 |
-
denoise: bool,
|
159 |
-
save_emo: bool,
|
160 |
-
):
|
161 |
-
|
162 |
-
audios_path = Path(audio_dir)
|
163 |
-
save_path = Path(save_dir)
|
164 |
-
save_path.mkdir(parents=True, exist_ok=True)
|
165 |
-
|
166 |
-
video_files = list_files(
|
167 |
-
path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
|
168 |
-
)
|
169 |
-
v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
|
170 |
-
|
171 |
-
if denoise:
|
172 |
-
VOCAL = "_(Vocals)"
|
173 |
-
original_files = [
|
174 |
-
p
|
175 |
-
for p in audios_path.glob("**/*")
|
176 |
-
if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
|
177 |
-
]
|
178 |
-
|
179 |
-
_, cur_model = uvr5_cli(
|
180 |
-
audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
|
181 |
-
)
|
182 |
-
need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
|
183 |
-
need_remove.extend(original_files)
|
184 |
-
for _ in need_remove:
|
185 |
-
_.unlink()
|
186 |
-
vocal_files = [
|
187 |
-
p
|
188 |
-
for p in audios_path.glob("**/*")
|
189 |
-
if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
|
190 |
-
]
|
191 |
-
for f in vocal_files:
|
192 |
-
fn, ext = f.stem, f.suffix
|
193 |
-
|
194 |
-
v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
|
195 |
-
if v_pos != -1:
|
196 |
-
new_fn = fn[: v_pos + len(VOCAL)]
|
197 |
-
new_f = f.with_name(new_fn + ext)
|
198 |
-
f = f.rename(new_f)
|
199 |
-
convert_to_mono(f, f, "flac")
|
200 |
-
f.unlink()
|
201 |
-
|
202 |
-
audio_files = list_files(
|
203 |
-
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
204 |
-
)
|
205 |
-
|
206 |
-
logger.info("Loading / Downloading Funasr model...")
|
207 |
-
|
208 |
-
model_dir = "iic/SenseVoiceSmall"
|
209 |
-
|
210 |
-
vad_model = "fsmn-vad" if fsmn_vad else None
|
211 |
-
vad_kwargs = {"max_single_segment_time": max_single_segment_time}
|
212 |
-
punc_model = "ct-punc" if punc else None
|
213 |
-
|
214 |
-
manager = AutoModel(
|
215 |
-
model=model_dir,
|
216 |
-
trust_remote_code=False,
|
217 |
-
vad_model=vad_model,
|
218 |
-
vad_kwargs=vad_kwargs,
|
219 |
-
punc_model=punc_model,
|
220 |
-
device=device,
|
221 |
-
)
|
222 |
-
|
223 |
-
if not fsmn_vad and vad_model is None:
|
224 |
-
vad_model = load_silero_vad()
|
225 |
-
|
226 |
-
logger.info("Model loaded.")
|
227 |
-
|
228 |
-
pattern = re.compile(r"_\d{3}\.")
|
229 |
-
|
230 |
-
for file_path in tqdm(audio_files, desc="Processing audio file"):
|
231 |
-
|
232 |
-
if pattern.search(file_path.name):
|
233 |
-
# logger.info(f"Skipping {file_path} as it has already been processed.")
|
234 |
-
continue
|
235 |
-
|
236 |
-
file_stem = file_path.stem
|
237 |
-
file_suffix = file_path.suffix
|
238 |
-
|
239 |
-
rel_path = Path(file_path).relative_to(audio_dir)
|
240 |
-
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
241 |
-
|
242 |
-
audio = AudioSegment.from_file(file_path)
|
243 |
-
|
244 |
-
cfg = dict(
|
245 |
-
cache={},
|
246 |
-
language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
|
247 |
-
use_itn=False,
|
248 |
-
batch_size_s=60,
|
249 |
-
)
|
250 |
-
|
251 |
-
if fsmn_vad:
|
252 |
-
elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
|
253 |
-
else:
|
254 |
-
wav = read_audio(
|
255 |
-
str(file_path)
|
256 |
-
) # backend (sox, soundfile, or ffmpeg) required!
|
257 |
-
audio_key = file_path.stem
|
258 |
-
audio_val = []
|
259 |
-
speech_timestamps = get_speech_timestamps(
|
260 |
-
wav,
|
261 |
-
vad_model,
|
262 |
-
max_speech_duration_s=max_single_segment_time // 1000,
|
263 |
-
return_seconds=True,
|
264 |
-
)
|
265 |
-
|
266 |
-
audio_val = [
|
267 |
-
[int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
|
268 |
-
for timestamp in speech_timestamps
|
269 |
-
]
|
270 |
-
vad_res = []
|
271 |
-
vad_res.append(dict(key=audio_key, value=audio_val))
|
272 |
-
|
273 |
-
res = manager.inference_with_vadres(
|
274 |
-
input=str(file_path), vad_res=vad_res, **cfg
|
275 |
-
)
|
276 |
-
|
277 |
-
for i, info in enumerate(res):
|
278 |
-
[start_ms, end_ms] = info["interval"]
|
279 |
-
text = info["text"]
|
280 |
-
emo = info["emo"]
|
281 |
-
sliced_audio = audio[start_ms:end_ms]
|
282 |
-
audio_save_path = (
|
283 |
-
save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
|
284 |
-
)
|
285 |
-
sliced_audio.export(audio_save_path, format=file_suffix[1:])
|
286 |
-
print(f"Exported {audio_save_path}: {text}")
|
287 |
-
|
288 |
-
transcript_save_path = (
|
289 |
-
save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
|
290 |
-
)
|
291 |
-
with open(
|
292 |
-
transcript_save_path,
|
293 |
-
"w",
|
294 |
-
encoding="utf-8",
|
295 |
-
) as f:
|
296 |
-
f.write(text)
|
297 |
-
|
298 |
-
if save_emo:
|
299 |
-
emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
|
300 |
-
with open(
|
301 |
-
emo_save_path,
|
302 |
-
"w",
|
303 |
-
encoding="utf-8",
|
304 |
-
) as f:
|
305 |
-
f.write(emo)
|
306 |
-
|
307 |
-
if audios_path.resolve() == save_path.resolve():
|
308 |
-
file_path.unlink()
|
309 |
-
|
310 |
-
|
311 |
-
if __name__ == "__main__":
|
312 |
-
main()
|
313 |
-
exit(0)
|
314 |
-
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
315 |
-
|
316 |
-
# Load the audio file
|
317 |
-
audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
|
318 |
-
model_dir = "iic/SenseVoiceSmall"
|
319 |
-
m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
|
320 |
-
m.eval()
|
321 |
-
|
322 |
-
res = m.inference(
|
323 |
-
data_in=f"{kwargs['model_path']}/example/zh.mp3",
|
324 |
-
language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
|
325 |
-
use_itn=False,
|
326 |
-
ban_emo_unk=False,
|
327 |
-
**kwargs,
|
328 |
-
)
|
329 |
-
|
330 |
-
print(res)
|
331 |
-
text = rich_transcription_postprocess(res[0][0]["text"])
|
332 |
-
print(text)
|
|
|
|
|
|
|
|
|
tools/sensevoice/vad_utils.py
DELETED
@@ -1,61 +0,0 @@
import torch
from torch.nn.utils.rnn import pad_sequence


def slice_padding_fbank(speech, speech_lengths, vad_segments):
    speech_list = []
    speech_lengths_list = []
    for i, segment in enumerate(vad_segments):

        bed_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
        speech_i = speech[0, bed_idx:end_idx]
        speech_lengths_i = end_idx - bed_idx
        speech_list.append(speech_i)
        speech_lengths_list.append(speech_lengths_i)
    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
    speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
    return feats_pad, speech_lengths_pad


def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    speech_list = []
    speech_lengths_list = []
    intervals = []
    for i, segment in enumerate(vad_segments):
        bed_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), speech_lengths)
        speech_i = speech[bed_idx:end_idx]
        speech_lengths_i = end_idx - bed_idx
        speech_list.append(speech_i)
        speech_lengths_list.append(speech_lengths_i)
        intervals.append([bed_idx // 16, end_idx // 16])

    return speech_list, speech_lengths_list, intervals


def merge_vad(vad_result, max_length=15000, min_length=0):
    new_result = []
    if len(vad_result) <= 1:
        return vad_result
    time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
    time_step = sorted(list(set(time_step)))
    if len(time_step) == 0:
        return []
    bg = 0
    for i in range(len(time_step) - 1):
        time = time_step[i]
        if time_step[i + 1] - bg < max_length:
            continue
        if time - bg > min_length:
            new_result.append([bg, time])
        # if time - bg < max_length * 1.5:
        #     new_result.append([bg, time])
        # else:
        #     split_num = int(time - bg) // max_length + 1
        #     spl_l = int(time - bg) // split_num
        #     for j in range(split_num):
        #         new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
        bg = time
    new_result.append([bg, time_step[-1]])
    return new_result

tools/smart_pad.py
DELETED
@@ -1,60 +0,0 @@
import random
from multiprocessing import Pool
from pathlib import Path

import click
import librosa
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from tools.file import AUDIO_EXTENSIONS, list_files

threshold = 10 ** (-50 / 20.0)


def process(file):
    waveform, sample_rate = torchaudio.load(str(file), backend="sox")
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    loudness = librosa.feature.rms(
        y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
    )[0]

    for i in range(len(loudness) - 1, 0, -1):
        if loudness[i] > threshold:
            break

    end_silent_time = (len(loudness) - i) * 512 / sample_rate

    if end_silent_time <= 0.3:
        random_time = random.uniform(0.3, 0.7) - end_silent_time
        waveform = F.pad(
            waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
        )

    for i in range(len(loudness)):
        if loudness[i] > threshold:
            break

    start_silent_time = i * 512 / sample_rate

    if start_silent_time > 0.02:
        waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]

    torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)


@click.command()
@click.argument("source", type=Path)
@click.option("--num-workers", type=int, default=12)
def main(source, num_workers):
    files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))

    with Pool(num_workers) as p:
        list(tqdm(p.imap_unordered(process, files), total=len(files)))


if __name__ == "__main__":
    main()

tools/vqgan/__pycache__/inference.cpython-310.pyc
DELETED
Binary file (3.5 kB)
|
|
tools/vqgan/create_train_split.py
DELETED
@@ -1,83 +0,0 @@
import math
from pathlib import Path
from random import Random

import click
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm

from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist


@click.command()
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option("--val-ratio", type=float, default=None)
@click.option("--val-count", type=int, default=None)
@click.option("--filelist", default=None, type=Path)
@click.option("--min-duration", default=None, type=float)
@click.option("--max-duration", default=None, type=float)
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
    if filelist:
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

    if min_duration is None and max_duration is None:
        filtered_files = list(map(str, [file.relative_to(root) for file in files]))
    else:
        filtered_files = []
        for file in tqdm(files):
            try:
                audio = AudioSegment.from_file(str(file))
                duration = len(audio) / 1000.0

                if min_duration is not None and duration < min_duration:
                    logger.info(
                        f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
                    )
                    continue

                if max_duration is not None and duration > max_duration:
                    logger.info(
                        f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
                    )
                    continue

                filtered_files.append(str(file.relative_to(root)))
            except Exception as e:
                logger.info(f"Error processing {file}: {e}")

    logger.info(
        f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
    )

    Random(42).shuffle(filtered_files)

    if val_count is None and val_ratio is None:
        logger.info("Validation ratio and count not specified, using min(20%, 100)")
        val_size = min(100, math.ceil(len(filtered_files) * 0.2))
    elif val_count is not None and val_ratio is not None:
        logger.error("Cannot specify both val_count and val_ratio")
        return
    elif val_count is not None:
        if val_count < 1 or val_count > len(filtered_files):
            logger.error("val_count must be between 1 and number of files")
            return
        val_size = val_count
    else:
        val_size = math.ceil(len(filtered_files) * val_ratio)

    logger.info(f"Using {val_size} files for validation")

    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[val_size:]))

    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[:val_size]))

    logger.info("Done")


if __name__ == "__main__":
    main()

tools/vqgan/extract_vq.py
DELETED
@@ -1,227 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import subprocess as sp
|
3 |
-
import sys
|
4 |
-
import time
|
5 |
-
from datetime import timedelta
|
6 |
-
from functools import lru_cache
|
7 |
-
from pathlib import Path
|
8 |
-
from random import Random
|
9 |
-
|
10 |
-
import click
|
11 |
-
import numpy as np
|
12 |
-
import torch
|
13 |
-
import torchaudio
|
14 |
-
from hydra import compose, initialize
|
15 |
-
from hydra.utils import instantiate
|
16 |
-
from lightning import LightningModule
|
17 |
-
from loguru import logger
|
18 |
-
from omegaconf import OmegaConf
|
19 |
-
|
20 |
-
from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
21 |
-
|
22 |
-
# register eval resolver
|
23 |
-
OmegaConf.register_new_resolver("eval", eval)
|
24 |
-
# This file is used to convert the audio files to text files using the Whisper model.
|
25 |
-
# It's mainly used to generate the training data for the VQ model.
|
26 |
-
|
27 |
-
|
28 |
-
RANK = int(os.environ.get("SLURM_PROCID", 0))
|
29 |
-
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
|
30 |
-
|
31 |
-
logger_format = (
|
32 |
-
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
33 |
-
"<level>{level: <8}</level> | "
|
34 |
-
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
|
35 |
-
"{extra[rank]} - <level>{message}</level>"
|
36 |
-
)
|
37 |
-
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
|
38 |
-
logger.remove()
|
39 |
-
logger.add(sys.stderr, format=logger_format)
|
40 |
-
|
41 |
-
|
42 |
-
@lru_cache(maxsize=1)
|
43 |
-
def get_model(
|
44 |
-
config_name: str = "firefly_gan_vq",
|
45 |
-
checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
46 |
-
device: str | torch.device = "cuda",
|
47 |
-
):
|
48 |
-
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
49 |
-
cfg = compose(config_name=config_name)
|
50 |
-
|
51 |
-
model = instantiate(cfg)
|
52 |
-
state_dict = torch.load(
|
53 |
-
checkpoint_path,
|
54 |
-
map_location=device,
|
55 |
-
)
|
56 |
-
if "state_dict" in state_dict:
|
57 |
-
state_dict = state_dict["state_dict"]
|
58 |
-
|
59 |
-
if any("generator" in k for k in state_dict):
|
60 |
-
state_dict = {
|
61 |
-
k.replace("generator.", ""): v
|
62 |
-
for k, v in state_dict.items()
|
63 |
-
if "generator." in k
|
64 |
-
}
|
65 |
-
|
66 |
-
model.load_state_dict(state_dict, strict=False)
|
67 |
-
model.eval()
|
68 |
-
model.to(device)
|
69 |
-
|
70 |
-
logger.info(f"Loaded model")
|
71 |
-
return model
|
72 |
-
|
73 |
-
|
74 |
-
@torch.inference_mode()
|
75 |
-
def process_batch(files: list[Path], model) -> float:
|
76 |
-
wavs = []
|
77 |
-
audio_lengths = []
|
78 |
-
new_files = []
|
79 |
-
max_length = total_time = 0
|
80 |
-
|
81 |
-
for file in files:
|
82 |
-
try:
|
83 |
-
wav, sr = torchaudio.load(
|
84 |
-
str(file), backend="sox" if sys.platform == "linux" else "soundfile"
|
85 |
-
) # Need to install libsox-dev
|
86 |
-
except Exception as e:
|
87 |
-
logger.error(f"Error reading {file}: {e}")
|
88 |
-
continue
|
89 |
-
|
90 |
-
if wav.shape[0] > 1:
|
91 |
-
wav = wav.mean(dim=0, keepdim=True)
|
92 |
-
|
93 |
-
wav = torchaudio.functional.resample(
|
94 |
-
wav.cuda(), sr, model.spec_transform.sample_rate
|
95 |
-
)[0]
|
96 |
-
total_time += len(wav) / model.spec_transform.sample_rate
|
97 |
-
max_length = max(max_length, len(wav))
|
98 |
-
|
99 |
-
wavs.append(wav)
|
100 |
-
audio_lengths.append(len(wav))
|
101 |
-
new_files.append(file)
|
102 |
-
|
103 |
-
files = new_files
|
104 |
-
|
105 |
-
# Pad to max length
|
106 |
-
for i, wav in enumerate(wavs):
|
107 |
-
wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
|
108 |
-
|
109 |
-
audios = torch.stack(wavs, dim=0)[:, None]
|
110 |
-
audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
|
111 |
-
|
112 |
-
# Calculate lengths
|
113 |
-
indices, feature_lengths = model.encode(audios, audio_lengths)
|
114 |
-
|
115 |
-
# Save to disk
|
116 |
-
outputs = indices.cpu().numpy()
|
117 |
-
|
118 |
-
for file, length, feature, audio_length in zip(
|
119 |
-
files, feature_lengths, outputs, audio_lengths
|
120 |
-
):
|
121 |
-
feature = feature[:, :length]
|
122 |
-
|
123 |
-
# (T,)
|
124 |
-
with open(file.with_suffix(".npy"), "wb") as f:
|
125 |
-
np.save(f, feature)
|
126 |
-
|
127 |
-
return total_time
|
128 |
-
|
129 |
-
|
130 |
-
@click.command()
|
131 |
-
@click.argument("folder")
|
132 |
-
@click.option("--num-workers", default=1)
|
133 |
-
@click.option("--config-name", default="firefly_gan_vq")
|
134 |
-
@click.option(
|
135 |
-
"--checkpoint-path",
|
136 |
-
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
137 |
-
)
|
138 |
-
@click.option("--batch-size", default=64)
|
139 |
-
@click.option("--filelist", default=None, type=Path)
|
140 |
-
def main(
|
141 |
-
folder: str,
|
142 |
-
num_workers: int,
|
143 |
-
config_name: str,
|
144 |
-
checkpoint_path: str,
|
145 |
-
batch_size: int,
|
146 |
-
filelist: Path,
|
147 |
-
):
|
148 |
-
if num_workers > 1 and WORLD_SIZE != num_workers:
|
149 |
-
assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
|
150 |
-
|
151 |
-
logger.info(f"Spawning {num_workers} workers")
|
152 |
-
|
153 |
-
if torch.cuda.is_available():
|
154 |
-
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
155 |
-
if visible_devices is None:
|
156 |
-
visible_devices = list(range(torch.cuda.device_count()))
|
157 |
-
else:
|
158 |
-
visible_devices = visible_devices.split(",")
|
159 |
-
else:
|
160 |
-
# Set to empty string to avoid using GPU
|
161 |
-
visible_devices = [""]
|
162 |
-
|
163 |
-
processes = []
|
164 |
-
for i in range(num_workers):
|
165 |
-
env = os.environ.copy()
|
166 |
-
env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
|
167 |
-
env["SLURM_PROCID"] = str(i)
|
168 |
-
env["SLURM_NTASKS"] = str(num_workers)
|
169 |
-
|
170 |
-
processes.append(
|
171 |
-
sp.Popen(
|
172 |
-
[sys.executable] + sys.argv.copy(),
|
173 |
-
env=env,
|
174 |
-
)
|
175 |
-
)
|
176 |
-
|
177 |
-
for p in processes:
|
178 |
-
p.wait()
|
179 |
-
|
180 |
-
logger.info(f"All workers finished")
|
181 |
-
return
|
182 |
-
|
183 |
-
# This is a worker
|
184 |
-
logger.info(f"Starting worker")
|
185 |
-
if filelist:
|
186 |
-
files = [i[0] for i in load_filelist(filelist)]
|
187 |
-
else:
|
188 |
-
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
|
189 |
-
|
190 |
-
print(f"Found {len(files)} files")
|
191 |
-
files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
|
192 |
-
|
193 |
-
total_files = len(files)
|
194 |
-
files = files[RANK::WORLD_SIZE]
|
195 |
-
logger.info(f"Processing {len(files)}/{total_files} files")
|
196 |
-
|
197 |
-
# Batch processing
|
198 |
-
total_time = 0
|
199 |
-
begin_time = time.time()
|
200 |
-
processed_files = 0
|
201 |
-
model = get_model(config_name, checkpoint_path)
|
202 |
-
|
203 |
-
for n_batch, idx in enumerate(range(0, len(files), batch_size)):
|
204 |
-
batch = files[idx : idx + batch_size]
|
205 |
-
batch_time = process_batch(batch, model)
|
206 |
-
|
207 |
-
total_time += batch_time
|
208 |
-
processed_files += len(batch)
|
209 |
-
|
210 |
-
if (n_batch + 1) % 10 == 0:
|
211 |
-
eta = (
|
212 |
-
(time.time() - begin_time)
|
213 |
-
/ processed_files
|
214 |
-
* (len(files) - processed_files)
|
215 |
-
)
|
216 |
-
logger.info(
|
217 |
-
f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
|
218 |
-
+ f"ETA: {timedelta(seconds=round(eta))}s"
|
219 |
-
)
|
220 |
-
|
221 |
-
logger.info(
|
222 |
-
f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
|
223 |
-
)
|
224 |
-
|
225 |
-
|
226 |
-
if __name__ == "__main__":
|
227 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
tools/vqgan/inference.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
|
3 |
-
import click
|
4 |
-
import hydra
|
5 |
-
import numpy as np
|
6 |
-
import soundfile as sf
|
7 |
-
import torch
|
8 |
-
import torchaudio
|
9 |
-
from hydra import compose, initialize
|
10 |
-
from hydra.utils import instantiate
|
11 |
-
from loguru import logger
|
12 |
-
from omegaconf import OmegaConf
|
13 |
-
|
14 |
-
from tools.file import AUDIO_EXTENSIONS
|
15 |
-
|
16 |
-
# register eval resolver
|
17 |
-
OmegaConf.register_new_resolver("eval", eval)
|
18 |
-
|
19 |
-
|
20 |
-
def load_model(config_name, checkpoint_path, device="cuda"):
|
21 |
-
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
22 |
-
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
23 |
-
cfg = compose(config_name=config_name)
|
24 |
-
|
25 |
-
model = instantiate(cfg)
|
26 |
-
state_dict = torch.load(
|
27 |
-
checkpoint_path,
|
28 |
-
map_location=device,
|
29 |
-
)
|
30 |
-
if "state_dict" in state_dict:
|
31 |
-
state_dict = state_dict["state_dict"]
|
32 |
-
|
33 |
-
if any("generator" in k for k in state_dict):
|
34 |
-
state_dict = {
|
35 |
-
k.replace("generator.", ""): v
|
36 |
-
for k, v in state_dict.items()
|
37 |
-
if "generator." in k
|
38 |
-
}
|
39 |
-
|
40 |
-
result = model.load_state_dict(state_dict, strict=False)
|
41 |
-
model.eval()
|
42 |
-
model.to(device)
|
43 |
-
|
44 |
-
logger.info(f"Loaded model: {result}")
|
45 |
-
return model
|
46 |
-
|
47 |
-
|
48 |
-
@torch.no_grad()
|
49 |
-
@click.command()
|
50 |
-
@click.option(
|
51 |
-
"--input-path",
|
52 |
-
"-i",
|
53 |
-
default="test.wav",
|
54 |
-
type=click.Path(exists=True, path_type=Path),
|
55 |
-
)
|
56 |
-
@click.option(
|
57 |
-
"--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
|
58 |
-
)
|
59 |
-
@click.option("--config-name", default="firefly_gan_vq")
|
60 |
-
@click.option(
|
61 |
-
"--checkpoint-path",
|
62 |
-
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
63 |
-
)
|
64 |
-
@click.option(
|
65 |
-
"--device",
|
66 |
-
"-d",
|
67 |
-
default="cuda",
|
68 |
-
)
|
69 |
-
def main(input_path, output_path, config_name, checkpoint_path, device):
|
70 |
-
model = load_model(config_name, checkpoint_path, device=device)
|
71 |
-
|
72 |
-
if input_path.suffix in AUDIO_EXTENSIONS:
|
73 |
-
logger.info(f"Processing in-place reconstruction of {input_path}")
|
74 |
-
|
75 |
-
# Load audio
|
76 |
-
audio, sr = torchaudio.load(str(input_path))
|
77 |
-
if audio.shape[0] > 1:
|
78 |
-
audio = audio.mean(0, keepdim=True)
|
79 |
-
audio = torchaudio.functional.resample(
|
80 |
-
audio, sr, model.spec_transform.sample_rate
|
81 |
-
)
|
82 |
-
|
83 |
-
audios = audio[None].to(device)
|
84 |
-
logger.info(
|
85 |
-
f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
|
86 |
-
)
|
87 |
-
|
88 |
-
# VQ Encoder
|
89 |
-
audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
|
90 |
-
indices = model.encode(audios, audio_lengths)[0][0]
|
91 |
-
|
92 |
-
logger.info(f"Generated indices of shape {indices.shape}")
|
93 |
-
|
94 |
-
# Save indices
|
95 |
-
np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
|
96 |
-
elif input_path.suffix == ".npy":
|
97 |
-
logger.info(f"Processing precomputed indices from {input_path}")
|
98 |
-
indices = np.load(input_path)
|
99 |
-
indices = torch.from_numpy(indices).to(device).long()
|
100 |
-
assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
|
101 |
-
else:
|
102 |
-
raise ValueError(f"Unknown input type: {input_path}")
|
103 |
-
|
104 |
-
# Restore
|
105 |
-
feature_lengths = torch.tensor([indices.shape[1]], device=device)
|
106 |
-
fake_audios, _ = model.decode(
|
107 |
-
indices=indices[None], feature_lengths=feature_lengths
|
108 |
-
)
|
109 |
-
audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
|
110 |
-
|
111 |
-
logger.info(
|
112 |
-
f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
|
113 |
-
)
|
114 |
-
|
115 |
-
# Save audio
|
116 |
-
fake_audio = fake_audios[0, 0].float().cpu().numpy()
|
117 |
-
sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
|
118 |
-
logger.info(f"Saved audio to {output_path}")
|
119 |
-
|
120 |
-
|
121 |
-
if __name__ == "__main__":
|
122 |
-
main()
|
|
|
|
|
|
|
|
|
|
tools/webui.py
DELETED
@@ -1,485 +0,0 @@
-import gc
-import html
-import io
-import os
-import queue
-import wave
-from argparse import ArgumentParser
-from functools import partial
-from pathlib import Path
-
-import gradio as gr
-import librosa
-import numpy as np
-import pyrootutils
-import torch
-from loguru import logger
-from transformers import AutoTokenizer
-
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
-
-from fish_speech.i18n import i18n
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
-from tools.api import decode_vq_tokens, encode_reference
-from tools.llama.generate import (
-    GenerateRequest,
-    GenerateResponse,
-    WrappedGenerateResponse,
-    launch_thread_safe_queue,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-# Make einx happy
-os.environ["EINX_FILTER_TRACEBACK"] = "false"
-
-
-HEADER_MD = f"""# Fish Speech
-
-{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
-
-{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
-
-{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
-
-{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
-"""
-
-TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
-SPACE_IMPORTED = False
-
-
-def build_html_error_message(error):
-    return f"""
-    <div style="color: red;
-    font-weight: bold;">
-        {html.escape(str(error))}
-    </div>
-    """
-
-
-@torch.inference_mode()
-def inference(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    streaming=False,
-):
-    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return (
-            None,
-            None,
-            i18n("Text is too long, please keep it under {} characters.").format(
-                args.max_gradio_length
-            ),
-        )
-
-    # Parse reference audio aka prompt
-    prompt_tokens = encode_reference(
-        decoder_model=decoder_model,
-        reference_audio=reference_audio,
-        enable_reference_audio=enable_reference_audio,
-    )
-
-    # LLAMA Inference
-    request = dict(
-        device=decoder_model.device,
-        max_new_tokens=max_new_tokens,
-        text=text,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        temperature=temperature,
-        compile=args.compile,
-        iterative_prompt=chunk_length > 0,
-        chunk_length=chunk_length,
-        max_length=2048,
-        prompt_tokens=prompt_tokens if enable_reference_audio else None,
-        prompt_text=reference_text if enable_reference_audio else None,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    if streaming:
-        yield wav_chunk_header(), None, None
-
-    segments = []
-
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            yield None, None, build_html_error_message(result.response)
-            break
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with autocast_exclude_mps(
-            device_type=decoder_model.device.type, dtype=args.precision
-        ):
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-        segments.append(fake_audios)
-
-        if streaming:
-            yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
-
-    if len(segments) == 0:
-        return (
-            None,
-            None,
-            build_html_error_message(
-                i18n("No audio generated, please check the input text.")
-            ),
-        )
-
-    # No matter streaming or not, we need to return the final audio
-    audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()
-
-
-inference_stream = partial(inference, streaming=True)
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-
-def inference_wrapper(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    batch_infer_num,
-):
-    audios = []
-    errors = []
-
-    for _ in range(batch_infer_num):
-        result = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-        )
-
-        _, audio_data, error_message = next(result)
-
-        audios.append(
-            gr.Audio(value=audio_data if audio_data else None, visible=True),
-        )
-        errors.append(
-            gr.HTML(value=error_message if error_message else None, visible=True),
-        )
-
-    for _ in range(batch_infer_num, n_audios):
-        audios.append(
-            gr.Audio(value=None, visible=False),
-        )
-        errors.append(
-            gr.HTML(value=None, visible=False),
-        )
-
-    return None, *audios, *errors
-
-
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
-    buffer = io.BytesIO()
-
-    with wave.open(buffer, "wb") as wav_file:
-        wav_file.setnchannels(channels)
-        wav_file.setsampwidth(bit_depth // 8)
-        wav_file.setframerate(sample_rate)
-
-    wav_header_bytes = buffer.getvalue()
-    buffer.close()
-    return wav_header_bytes
-
-
-def normalize_text(user_input, use_normalization):
-    if use_normalization:
-        return ChnNormedText(raw_text=user_input).normalize()
-    else:
-        return user_input
-
-
-asr_model = None
-
-
-def build_app():
-    with gr.Blocks(theme=gr.themes.Base()) as app:
-        gr.Markdown(HEADER_MD)
-
-        # Use light theme by default
-        app.load(
-            None,
-            None,
-            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
-            % args.theme,
-        )
-
-        # Inference
-        with gr.Row():
-            with gr.Column(scale=3):
-                text = gr.Textbox(
-                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
-                )
-                refined_text = gr.Textbox(
-                    label=i18n("Realtime Transform Text"),
-                    placeholder=i18n(
-                        "Normalization Result Preview (Currently Only Chinese)"
-                    ),
-                    lines=5,
-                    interactive=False,
-                )
-
-                with gr.Row():
-                    if_refine_text = gr.Checkbox(
-                        label=i18n("Text Normalization"),
-                        value=False,
-                        scale=1,
-                    )
-
-                with gr.Row():
-                    with gr.Tab(label=i18n("Advanced Config")):
-                        chunk_length = gr.Slider(
-                            label=i18n("Iterative Prompt Length, 0 means off"),
-                            minimum=50,
-                            maximum=300,
-                            value=200,
-                            step=8,
-                        )
-
-                        max_new_tokens = gr.Slider(
-                            label=i18n("Maximum tokens per batch, 0 means no limit"),
-                            minimum=0,
-                            maximum=2048,
-                            value=1024,  # 0 means no limit
-                            step=8,
-                        )
-
-                        top_p = gr.Slider(
-                            label="Top-P",
-                            minimum=0.6,
-                            maximum=0.9,
-                            value=0.7,
-                            step=0.01,
-                        )
-
-                        repetition_penalty = gr.Slider(
-                            label=i18n("Repetition Penalty"),
-                            minimum=1,
-                            maximum=1.5,
-                            value=1.2,
-                            step=0.01,
-                        )
-
-                        temperature = gr.Slider(
-                            label="Temperature",
-                            minimum=0.6,
-                            maximum=0.9,
-                            value=0.7,
-                            step=0.01,
-                        )
-
-                    with gr.Tab(label=i18n("Reference Audio")):
-                        gr.Markdown(
-                            i18n(
-                                "5 to 10 seconds of reference audio, useful for specifying speaker."
-                            )
-                        )
-
-                        enable_reference_audio = gr.Checkbox(
-                            label=i18n("Enable Reference Audio"),
-                        )
-                        reference_audio = gr.Audio(
-                            label=i18n("Reference Audio"),
-                            type="filepath",
-                        )
-                        with gr.Row():
-                            reference_text = gr.Textbox(
-                                label=i18n("Reference Text"),
-                                lines=1,
-                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                                value="",
-                            )
-                    with gr.Tab(label=i18n("Batch Inference")):
-                        batch_infer_num = gr.Slider(
-                            label="Batch infer nums",
-                            minimum=1,
-                            maximum=n_audios,
-                            step=1,
-                            value=1,
-                        )
-
-            with gr.Column(scale=3):
-                for _ in range(n_audios):
-                    with gr.Row():
-                        error = gr.HTML(
-                            label=i18n("Error Message"),
-                            visible=True if _ == 0 else False,
-                        )
-                        global_error_list.append(error)
-                    with gr.Row():
-                        audio = gr.Audio(
-                            label=i18n("Generated Audio"),
-                            type="numpy",
-                            interactive=False,
-                            visible=True if _ == 0 else False,
-                        )
-                        global_audio_list.append(audio)
-
-                with gr.Row():
-                    stream_audio = gr.Audio(
-                        label=i18n("Streaming Audio"),
-                        streaming=True,
-                        autoplay=True,
-                        interactive=False,
-                        show_download_button=True,
-                    )
-                with gr.Row():
-                    with gr.Column(scale=3):
-                        generate = gr.Button(
-                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
-                        )
-                        generate_stream = gr.Button(
-                            value="\U0001F3A7 " + i18n("Streaming Generate"),
-                            variant="primary",
-                        )
-
-        text.input(
-            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
-        )
-
-        # # Submit
-        generate.click(
-            inference_wrapper,
-            [
-                refined_text,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-                max_new_tokens,
-                chunk_length,
-                top_p,
-                repetition_penalty,
-                temperature,
-                batch_infer_num,
-            ],
-            [stream_audio, *global_audio_list, *global_error_list],
-            concurrency_limit=1,
-        )
-
-        generate_stream.click(
-            inference_stream,
-            [
-                refined_text,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-                max_new_tokens,
-                chunk_length,
-                top_p,
-                repetition_penalty,
-                temperature,
-            ],
-            [stream_audio, global_audio_list[0], global_error_list[0]],
-            concurrency_limit=10,
-        )
-    return app
-
-
-def parse_args():
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--llama-checkpoint-path",
-        type=Path,
-        default="checkpoints/fish-speech-1.4",
-    )
-    parser.add_argument(
-        "--decoder-checkpoint-path",
-        type=Path,
-        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-    )
-    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--half", action="store_true")
-    parser.add_argument("--compile", action="store_true")
-    parser.add_argument("--max-gradio-length", type=int, default=0)
-    parser.add_argument("--theme", type=str, default="light")
-
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    args.precision = torch.half if args.half else torch.bfloat16
-
-    logger.info("Loading Llama model...")
-    llama_queue = launch_thread_safe_queue(
-        checkpoint_path=args.llama_checkpoint_path,
-        device=args.device,
-        precision=args.precision,
-        compile=args.compile,
-    )
-    logger.info("Llama model loaded, loading VQ-GAN model...")
-
-    decoder_model = load_decoder_model(
-        config_name=args.decoder_config_name,
-        checkpoint_path=args.decoder_checkpoint_path,
-        device=args.device,
-    )
-
-    logger.info("Decoder model loaded, warming up...")
-
-    # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    list(
-        inference(
-            text="Hello, world!",
-            enable_reference_audio=False,
-            reference_audio=None,
-            reference_text="",
-            max_new_tokens=1024,
-            chunk_length=200,
-            top_p=0.7,
-            repetition_penalty=1.2,
-            temperature=0.7,
-        )
-    )
-
-    logger.info("Warming up done, launching the web UI...")
-
-    app = build_app()
-    app.launch(show_api=True)
tools/whisper_asr.py
DELETED
@@ -1,176 +0,0 @@
-"""
-Used to transcribe all audio files in one folder into another folder.
-e.g.
-Directory structure:
---pre_data_root
-----SP_1
-------01.wav
-------02.wav
-------......
-----SP_2
-------01.wav
-------02.wav
-------......
-Use
-python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
-to transcribe the first speaker.
-
-Use
-python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
-to transcribe the second speaker.
-
-Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
-"""
-
-import re
-from pathlib import Path
-
-import click
-import soundfile as sf
-from faster_whisper import WhisperModel
-from loguru import logger
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files
-
-
-@click.command()
-@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
-@click.option(
-    "--compute-type",
-    default="float16",
-    help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
-)
-@click.option("--audio-dir", required=True, help="Directory containing audio files")
-@click.option(
-    "--save-dir", required=True, help="Directory to save processed audio files"
-)
-@click.option(
-    "--sample-rate",
-    default=44100,
-    type=int,
-    help="Output sample rate, default to input sample rate",
-)
-@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
-@click.option("--language", default="auto", help="Language of the transcription")
-@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
-def main(
-    model_size,
-    compute_type,
-    audio_dir,
-    save_dir,
-    sample_rate,
-    device,
-    language,
-    initial_prompt,
-):
-    logger.info("Loading / Downloading Faster Whisper model...")
-
-    model = WhisperModel(
-        model_size,
-        device=device,
-        compute_type=compute_type,
-        download_root="faster_whisper",
-    )
-
-    logger.info("Model loaded.")
-
-    save_path = Path(save_dir)
-    save_path.mkdir(parents=True, exist_ok=True)
-
-    audio_files = list_files(
-        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
-    )
-
-    for file_path in tqdm(audio_files, desc="Processing audio file"):
-        file_stem = file_path.stem
-        file_suffix = file_path.suffix
-
-        rel_path = Path(file_path).relative_to(audio_dir)
-        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-
-        audio = AudioSegment.from_file(file_path)
-
-        segments, info = model.transcribe(
-            file_path,
-            beam_size=5,
-            language=None if language == "auto" else language,
-            initial_prompt=initial_prompt,
-        )
-
-        print(
-            "Detected language '%s' with probability %f"
-            % (info.language, info.language_probability)
-        )
-        print("Total len(ms): ", len(audio))
-
-        whole_text = None
-        for segment in segments:
-            id, start, end, text = (
-                segment.id,
-                segment.start,
-                segment.end,
-                segment.text,
-            )
-            print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
-            if not whole_text:
-                whole_text = text
-            else:
-                whole_text += ", " + text
-
-        whole_text += "."
-
-        audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
-        audio.export(audio_save_path, format=file_suffix[1:])
-        print(f"Exported {audio_save_path}")
-
-        transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
-        with open(
-            transcript_save_path,
-            "w",
-            encoding="utf-8",
-        ) as f:
-            f.write(whole_text)
-
-
-if __name__ == "__main__":
-    main()
-    exit(0)
-
-    audio = AudioSegment.from_wav(
-        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
-    )
-
-    model_size = "large-v3"
-
-    model = WhisperModel(
-        model_size,
-        device="cuda",
-        compute_type="float16",
-        download_root="faster_whisper",
-    )
-
-    segments, info = model.transcribe(
-        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
-        beam_size=5,
-    )
-
-    print(
-        "Detected language '%s' with probability %f"
-        % (info.language, info.language_probability)
-    )
-    print("Total len(ms): ", len(audio))
-
-    for i, segment in enumerate(segments):
-        print(
-            "Segment %03d [%.2fs -> %.2fs] %s"
-            % (i, segment.start, segment.end, segment.text)
-        )
-        start_ms = int(segment.start * 1000)
-        end_ms = int(segment.end * 1000)
-        segment_audio = audio[start_ms:end_ms]
-        segment_audio.export(f"segment_{i:03d}.wav", format="wav")
-        print(f"Exported segment_{i:03d}.wav")
-
-    print("All segments have been exported.")