---
base_model:
- m-a-p/YuE-s1-7B-anneal-en-cot
---

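# Sample Inference Script

The script generates the song one lyrics section at a time with ExLlamaV2, then decodes the generated codec tokens into `vocal.wav` and `inst.wav`. Pass the stage-1 model directory with `--stage-1`. `--genre` and `--lyrics` take either a file path or the literal text, and lyrics must be divided into tagged sections such as `[verse]` or `[chorus]`. The `xcodec_mini_infer` directory (with `final_ckpt/config.yaml` and `final_ckpt/ckpt_00360000.pth`) must be in the working directory, and a CUDA GPU is required.

```py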
|
import random |
|
import re |
|
import sys |
|
from argparse import ArgumentParser |
|
from pathlib import Path |
|
|
|
sys.path.append("xcodec_mini_infer") |
|
|
|
import torch |
|
import torchaudio |
|
import yaml |
|
from exllamav2 import ( |
|
ExLlamaV2, |
|
ExLlamaV2Cache, |
|
ExLlamaV2Config, |
|
ExLlamaV2Tokenizer, |
|
Timer, |
|
) |
|
from exllamav2.generator import ( |
|
ExLlamaV2DynamicGenerator, |
|
ExLlamaV2DynamicJob, |
|
ExLlamaV2Sampler, |
|
) |
|
from rich import print |
|
|
|
from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream |
|
|
|
# Command-line interface. --genre and --lyrics accept either a path to a text
# file or the literal text itself.
parser = ArgumentParser(
    description="Generate vocal and instrumental tracks with the stage-1 model."
)
parser.add_argument(
    "-s1", "--stage-1", required=True,
    help="Path to the stage-1 model directory (passed to ExLlamaV2Config).",
)
parser.add_argument(
    "-g", "--genre", default="genre.txt",
    help="Genre tags: a path to a text file, or the tags themselves.",
)
parser.add_argument(
    "-l", "--lyrics", default="lyrics.txt",
    help="Lyrics split into [section] blocks: a path to a text file, or the text itself.",
)
parser.add_argument(
    "-d", "--debug", action="store_true",
    help="Print generated tokens to the console as they stream in.",
)
parser.add_argument(
    "-s", "--seed", type=int, default=None,
    help="Sampling seed; a random seed is chosen when omitted.",
)
parser.add_argument(
    "--sample_rate", type=int, default=16000,
    help="Sample rate written to the output WAV files.",
)
parser.add_argument(
    "--repetition_penalty", type=float, default=1.2,
    help="Token repetition penalty used while sampling.",
)
parser.add_argument(
    "--temperature", type=float, default=1.0,
    help="Sampling temperature.",
)
parser.add_argument(
    "--top_p", type=float, default=0.93,
    help="Nucleus (top-p) sampling threshold.",
)
args = parser.parse_args()
|
|
|
# Load the stage-1 model (placed with load_autosplit), then build the
# tokenizer and the dynamic generator that the generation loop uses.
with Timer() as load_timer:
    config = ExLlamaV2Config(args.stage_1)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    generator.warmup()

load_seconds = load_timer.interval
print(f"Loaded stage 1 model in {load_seconds:.2f} seconds.")
|
|
|
def _read_text_or_literal(value):
    """Return the contents of *value* if it names an existing file, else *value*.

    The result is stripped of leading and trailing whitespace in both cases.
    """
    path = Path(value)
    try:
        is_file = path.is_file()
    except OSError:
        # Long literal text (e.g. lyrics given on the command line) can exceed
        # the OS file-name limit; pathlib re-raises that error instead of
        # returning False, so treat it as "not a file".
        is_file = False
    text = path.read_text(encoding="utf-8") if is_file else value
    return text.strip()


def _split_lyrics(text):
    """Split tagged lyrics into prompt-ready segments.

    A segment starts with a ``[tag]`` header (word characters only) and runs up
    to a newline followed by the next ``[`` or the end of the text. Each segment
    is returned as ``"[tag]\\n<body>\\n\\n"`` with its body stripped.
    """
    # The pattern needs a newline right before the next tag *or* the end of the
    # text. The input is stripped, so without the appended newline the final
    # section would never match and would be silently dropped.
    segments = re.findall(r"\[(\w+)\](.*?)\n(?=\[|\Z)", text + "\n", re.DOTALL)
    return [f"[{tag}]\n{body.strip()}\n\n" for tag, body in segments]


genre = _read_text_or_literal(args.genre)

lyrics = _split_lyrics(_read_text_or_literal(args.lyrics))
if not lyrics:
    sys.exit(
        "No lyrics sections found; start each section with a tag such as "
        "[verse] or [chorus]."
    )
lyrics_joined = "\n".join(lyrics)
|
|
|
# Sampling restricted to token 32002 plus the 1024 ids 45334..46357.
# NOTE(review): presumably 32002 is <EOA> and 45334+ are the <xcodec/0/N>
# codec tokens - confirm against the tokenizer vocabulary.
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.allow_tokens(tokenizer, [32002] + list(range(45334, 46358)))
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p

# Compare against None: an explicit "--seed 0" is falsy and would otherwise be
# replaced by a random seed, making that run non-reproducible.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]
|
|
|
# State accumulated across segments:
#   output        - every generated codec token string, in order (decoded later)
#   output_joined - finished segments in prompt form, fed back as context
output_joined = ""
output = []

# One codec token as it appears in the decoded text (special tokens are decoded).
codec_token_re = re.compile(r"<xcodec/0/\d+>")

with Timer() as timer:
    for segment in lyrics:
        # Prompt: instruction, genre, full lyrics, every finished segment, then
        # the opening of the segment to generate now.
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}<SOA><xcodec>"
        )

        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        max_new_tokens = config.max_seq_len - input_len

        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )

        # No room left for this segment. Stop instead of skipping it, which
        # would leave a gap in the song (and a job with no token budget).
        if max_new_tokens <= 0:
            print("Context window is full; skipping the remaining segments.")
            break

        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )
        generator.enqueue(job)

        current = []
        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") != "streaming":
                        continue
                    text = result.get("text")
                    if text:
                        current.append(text)
                        if args.debug:
                            print(text, end="", flush=True)

        if args.debug:
            print()

        # Decoding splits the token stream by parity (even = vocal, odd =
        # instrumental). Keep each segment's count even so one segment cut off
        # mid-pair cannot swap the two stems for every later segment.
        tokens = codec_token_re.findall("".join(current))
        if len(tokens) % 2:
            tokens.pop()

        if tokens:
            output.extend(tokens)
            segment_codes = "".join(tokens)
            output_joined += (
                f"[start_of_segment]{segment}<SOA><xcodec>"
                f"{segment_codes}<EOA>[end_of_segment]"
            )

        print(f"Generated {len(tokens)} tokens in {inner.interval:.2f} seconds.")

print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")
|
|
|
# Load the xcodec decoder from the local xcodec_mini_infer directory.
with Timer() as timer:
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])

    # map_location="cpu": materialise the checkpoint in host memory instead of
    # on whatever device it was saved from. The GPU is already holding the
    # stage-1 model; the codec is moved there explicitly below.
    # NOTE(review): torch>=2.6 defaults to weights_only=True; if this checkpoint
    # stores non-tensor objects the load may fail - confirm before changing.
    state_dict = torch.load(
        "xcodec_mini_infer/final_ckpt/ckpt_00360000.pth", map_location="cpu"
    )
    codec.load_state_dict(state_dict["codec_model"])
    del state_dict  # release the host copy of the checkpoint
    codec = codec.eval().cuda()

print(f"Loaded codec in {timer.interval:.2f} seconds.")
|
|
|
# Decode the generated codec tokens into vocal and instrumental WAV files.
with Timer() as timer, torch.inference_mode():
    # Extract every code index from the generated text. Scanning the joined
    # text copes with chunks that hold several tokens, or a token split across
    # chunks, which per-chunk positional slicing cannot.
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    output_ids = [int(code) for code in pattern.findall("".join(output))]

    # With fewer than two codes one stem would be empty (and torch.tensor of an
    # empty list is a float tensor), so stop with a clear message instead.
    if len(output_ids) < 2:
        sys.exit("Not enough codec tokens were generated; nothing to decode.")

    # Even positions are vocal codes, odd positions instrumental codes.
    for stem, codes in (("vocal", output_ids[::2]), ("inst", output_ids[1::2])):
        codes = torch.tensor([[codes]], dtype=torch.long).cuda()
        codes = codes.permute(1, 0, 2)
        audio = codec.decode(codes).squeeze(0).cpu()
        torchaudio.save(f"{stem}.wav", audio, args.sample_rate)

print(f"Decoded audio in {timer.interval:.2f} seconds.")
|
```