---
base_model:
- m-a-p/YuE-s1-7B-anneal-en-cot
---
# Sample Inference Script
```py
import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path
from warnings import simplefilter
sys.path.append("xcodec_mini_infer")
simplefilter("ignore")
import torch
import torchaudio
import yaml
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Config,
ExLlamaV2Tokenizer,
Timer,
)
from exllamav2.generator import (
ExLlamaV2DynamicGenerator,
ExLlamaV2DynamicJob,
ExLlamaV2Sampler,
)
from rich import print
from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream
# Command-line interface: model, genre and lyrics are required; sampling
# knobs fall back to sensible defaults.
parser = ArgumentParser()
for flags, options in (
    (("-m", "--model"), {"required": True}),
    (("-g", "--genre"), {"required": True}),
    (("-l", "--lyrics"), {"required": True}),
    (("-s", "--seed"), {"type": int, "default": None}),
    (("-d", "--debug"), {"action": "store_true"}),
    (("--repetition_penalty",), {"type": float, "default": 1.2}),
    (("--temperature",), {"type": float, "default": 1.0}),
    (("--top_p",), {"type": float, "default": 0.93}),
):
    parser.add_argument(*flags, **options)
args = parser.parse_args()
# Load the ExLlamaV2 model and build the dynamic generator.
# NOTE(review): the call order appears intentional -- the cache is created
# lazily and then populated by load_autosplit; confirm against exllamav2 docs
# before reordering.
with Timer() as timer:
    config = ExLlamaV2Config(args.model)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)  # splits weights across available devices
    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    generator.warmup()  # pay one-time setup cost now, not on the first job
# Timer.interval is populated when the context exits.
print(f"Loaded model in {timer.interval:.2f} seconds.")
# Resolve --genre/--lyrics: each may be a path to a UTF-8 text file or the
# literal text itself.
genre = Path(args.genre)
genre = genre.read_text(encoding="utf-8") if genre.is_file() else args.genre
genre = genre.strip()
lyrics = Path(args.lyrics)
lyrics = lyrics.read_text(encoding="utf-8") if lyrics.is_file() else args.lyrics
lyrics = lyrics.strip()
# Split the lyrics into "[section]" segments.  The lookahead accepts either a
# newline before the next "[section]" header or end-of-input: the previous
# pattern (r"\[(\w+)\](.*?)\n(?=\[|\Z)") required a literal trailing newline,
# which the .strip() above removes, so the final segment was silently dropped.
lyrics = re.findall(r"\[(\w+)\](.*?)(?=\n\[|\Z)", lyrics, re.DOTALL)
lyrics = [f"[{tag}]\n{body.strip()}\n\n" for tag, body in lyrics]
lyrics_joined = "\n".join(lyrics)
# Constrain sampling to the audio vocabulary: token 32002 plus the contiguous
# id range 45334-46357 (values carried over from the original script --
# verify against the model's tokenizer config if the vocabulary changes).
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.allow_tokens(tokenizer, [32002] + list(range(45334, 46358)))
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p
# Only draw a random seed when none was given: the previous truthiness test
# (`args.seed if args.seed else ...`) silently replaced an explicit --seed 0
# with a random one.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]
output_joined = ""  # prompt-formatted context of all finished segments
output = []  # every streamed token text, across all segments
# Generate audio tokens segment by segment; each segment's prompt includes
# the full lyrics plus everything generated so far, so later segments
# condition on earlier ones.
with Timer() as timer:
    for segment in lyrics:
        current = []  # token texts streamed for this segment only
        # Renamed from `input` to avoid shadowing the builtin.
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}<SOA><xcodec>"
        )
        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        # Give the job every token of context that remains in the window.
        max_new_tokens = config.max_seq_len - input_len
        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )
        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )
        generator.enqueue(job)
        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") == "streaming":
                        text = result.get("text")
                        if text:
                            current.append(text)
                            output.append(text)
                            if args.debug:
                                print(text, end="", flush=True)
                    if result.get("eos") and current:
                        # Fold the finished segment back into the running
                        # context so the next segment's prompt includes it.
                        current_joined = "".join(current)
                        output_joined += (
                            f"[start_of_segment]{segment}<SOA><xcodec>"
                            f"{current_joined}<EOA>[end_of_segment]"
                        )
        if args.debug:
            print()  # terminate the streamed-token line
        print(f"Generated {len(current)} tokens in {inner.interval:.2f} seconds.")
print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")
# Load the xcodec SoundStream codec used to turn generated code ids back
# into waveforms.
with Timer() as timer:
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])
    # NOTE(review): torch.load unpickles arbitrary objects -- only run this
    # with a checkpoint from a trusted source.
    state_dict = torch.load("xcodec_mini_infer/final_ckpt/ckpt_00360000.pth")
    codec.load_state_dict(state_dict["codec_model"])
    codec = codec.eval().cuda()
print(f"Loaded codec in {timer.interval:.2f} seconds.")
# Convert the generated token texts back into two 16 kHz WAV files.
with Timer() as timer, torch.inference_mode():
    # Audio tokens look like "<xcodec/0/1234>"; anything else (e.g. the stop
    # token) is discarded.  Using the capture group replaces the old fixed
    # slice o[10:-1], which assumed an exact token layout.
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    matches = (pattern.match(o) for o in output)
    output_ids = [int(m.group(1)) for m in matches if m]

    def _decode_track(ids, path):
        """Decode one stream of codec ids and save it as a 16 kHz WAV."""
        codes = torch.tensor([[ids]]).cuda()
        # (batch, q, time) -> (q, batch, time); assumed layout expected by
        # SoundStream.decode -- confirm against xcodec_mini_infer.
        codes = codes.permute(1, 0, 2)
        waveform = codec.decode(codes).squeeze(0).cpu()
        torchaudio.save(path, waveform, 16000)

    # The two streams are interleaved: even positions are the vocal track,
    # odd positions the instrumental.
    _decode_track(output_ids[::2], "vocal.wav")
    _decode_track(output_ids[1::2], "inst.wav")
print(f"Decoded audio in {timer.interval:.2f} seconds.")
``` |