---
base_model:
- m-a-p/YuE-s1-7B-anneal-en-cot
---

# Sample Inference Script

The script below writes `vocal.wav` and `inst.wav` from a genre prompt and `[section]`-tagged lyrics. Run it from a directory that contains the `xcodec_mini_infer` checkout, including `final_ckpt/`.

```py
"""Generate vocal and instrumental stems with an ExLlamaV2 build of YuE stage 1.

Usage:
    python <this script> -m MODEL_DIR -g GENRE -l LYRICS [-s SEED] [-d]

GENRE and LYRICS may each be a path to a UTF-8 text file or the text itself.
Lyrics are split into sections at ``[tag]`` headers (for example ``[verse]``).
Sections are generated one after another, and every previously generated
section is fed back as context.

The xcodec_mini_infer checkout (with final_ckpt/) must be in the current
working directory. Output goes to vocal.wav and inst.wav at 16 kHz.
"""

import os
import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path
from warnings import simplefilter

# Let xcodec_mini_infer's internal imports resolve.
sys.path.append("xcodec_mini_infer")
simplefilter("ignore")

import torch
import torchaudio
import yaml
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from rich import print
from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream

# The sampler may only emit these ids. 32002 is the single id outside the
# codec range, so it is the only way a segment can end before max_new_tokens,
# and generation stops on it.
# NOTE(review): presumably this is the tokenizer's end-of-audio token; confirm.
EOA_TOKEN_ID = 32002
# 1024 token ids for codebook 0 of the codec.
CODEC_TOKEN_IDS = range(45334, 46358)
# Decoded text form of one codec token. The captured number is the code index
# passed to the codec.
# NOTE(review): confirm the spelling against the tokenizer vocabulary.
CODEC_TOKEN = re.compile(r"<xcodec/0/(\d+)>")
CODEC_SAMPLE_RATE = 16000


def read_text_or_literal(value):
    """Return the stripped contents of file *value*, or *value* itself if it is not a file.

    ``os.path.isfile`` is used instead of ``Path.is_file`` because the latter
    raises ``OSError`` ("File name too long") when long inline text is passed.
    """
    text = Path(value).read_text(encoding="utf-8") if os.path.isfile(value) else value
    return text.strip()


def split_lyrics(text):
    """Split lyrics into sections that each start at a ``[tag]`` header.

    Each section is returned as ``"[tag]\\n<stripped body>\\n\\n"``. The
    lookahead also accepts end-of-text, so the last section is kept even when
    the lyrics have no trailing newline.
    """
    sections = re.findall(r"\[(\w+)\](.*?)(?=\n\[|\Z)", text, re.DOTALL)
    return [f"[{tag}]\n{body.strip()}\n\n" for tag, body in sections]


parser = ArgumentParser()
parser.add_argument("-m", "--model", required=True, help="ExLlamaV2 model directory")
parser.add_argument("-g", "--genre", required=True, help="genre text or a file containing it")
parser.add_argument("-l", "--lyrics", required=True, help="lyrics text or a file containing it")
parser.add_argument("-s", "--seed", type=int, default=None)
parser.add_argument("-d", "--debug", action="store_true")
parser.add_argument("--repetition_penalty", type=float, default=1.2)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.93)
args = parser.parse_args()

# Validate the inputs before the (slow) model load.
genre = read_text_or_literal(args.genre)
lyrics = split_lyrics(read_text_or_literal(args.lyrics))
if not lyrics:
    parser.error("no [section] headers found in the lyrics")
lyrics_joined = "\n".join(lyrics)

with Timer() as timer:
    config = ExLlamaV2Config(args.model)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    generator.warmup()
print(f"Loaded model in {timer.interval:.2f} seconds.")

gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.allow_tokens(tokenizer, [EOA_TOKEN_ID, *CODEC_TOKEN_IDS])
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p

# Compare against None so that an explicit --seed 0 is honoured.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)

output_ids = []  # codec indices from all segments, alternating vocal/inst
output_joined = ""  # previously generated segments, fed back as context

with Timer() as timer:
    for segment in lyrics:
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}"
        )
        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        max_new_tokens = config.max_seq_len - input_len
        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )
        if max_new_tokens <= 0:
            print("Context window is full; skipping the remaining segments.")
            break

        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=[EOA_TOKEN_ID],
            decode_special_tokens=True,
        )
        generator.enqueue(job)

        chunks = []
        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") != "streaming":
                        continue
                    text = result.get("text")
                    if text:
                        chunks.append(text)
                        if args.debug:
                            print(text, end="", flush=True)
        if args.debug:
            print()

        generated = "".join(chunks)
        # findall over the joined text also catches chunks that hold several tokens.
        codes = [int(index) for index in CODEC_TOKEN.findall(generated)]
        # Tokens alternate vocal/inst within a segment. Dropping an unpaired
        # trailing token keeps the two channels aligned across segments.
        output_ids.extend(codes[: len(codes) // 2 * 2])
        if generated:
            output_joined += f"[start_of_segment]{segment}{generated}[end_of_segment]"
        print(f"Generated {len(codes)} codec tokens in {inner.interval:.2f} seconds.")
print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")

if not output_ids:
    sys.exit("No codec tokens were generated; nothing to decode.")

with Timer() as timer:
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])
    state_dict = torch.load("xcodec_mini_infer/final_ckpt/ckpt_00360000.pth")
    codec.load_state_dict(state_dict["codec_model"])
    codec = codec.eval().cuda()
print(f"Loaded codec in {timer.interval:.2f} seconds.")

with Timer() as timer, torch.inference_mode():
    # Even positions are vocal tokens and odd positions are instrumental.
    tracks = (("vocal.wav", output_ids[::2]), ("inst.wav", output_ids[1::2]))
    for path, codes in tracks:
        codes = torch.tensor([[codes]]).cuda()  # (1, 1, T)
        # NOTE(review): the permute is a no-op for one codebook; kept on the
        # assumption that decode expects (codebooks, batch, T).
        audio = codec.decode(codes.permute(1, 0, 2)).squeeze(0).cpu()
        torchaudio.save(path, audio, CODEC_SAMPLE_RATE)
print(f"Decoded audio in {timer.interval:.2f} seconds.")
```