File size: 5,404 Bytes
8c28a4c
 
 
87043db
 
 
 
 
 
 
 
 
5ba94a4
87043db
 
5ba94a4
87043db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba94a4
 
 
87043db
5ba94a4
87043db
 
 
 
 
 
411e730
87043db
 
 
 
 
 
 
 
5ba94a4
87043db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba94a4
87043db
 
 
 
 
 
5ba94a4
87043db
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
---
base_model:
- m-a-p/YuE-s1-7B-anneal-en-cot
---

# Sample Inference Script
```py
import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path
from warnings import simplefilter

sys.path.append("xcodec_mini_infer")
simplefilter("ignore")

import torch
import torchaudio
import yaml
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from rich import print

from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream

# Command-line interface: model/genre/lyrics are required, sampling knobs
# have defaults matching the model card's suggested settings.
cli = ArgumentParser()
for short, long in (("-m", "--model"), ("-g", "--genre"), ("-l", "--lyrics")):
    cli.add_argument(short, long, required=True)
cli.add_argument("-s", "--seed", type=int, default=None)
cli.add_argument("-d", "--debug", action="store_true")
cli.add_argument("--repetition_penalty", type=float, default=1.2)
cli.add_argument("--temperature", type=float, default=1.0)
cli.add_argument("--top_p", type=float, default=0.93)
args = cli.parse_args()

# Load the ExLlamaV2 model and build the dynamic generator.
# NOTE(review): lazy_load / lazy cache appear to defer weight and cache
# allocation so load_autosplit() can place them as it loads — confirm
# against the exllamav2 docs; the construction order (cache before
# load_autosplit) is what the API shows and should not be changed.
with Timer() as timer:
    config = ExLlamaV2Config(args.model)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    # Warm up once so the first real generation isn't skewed by one-time init.
    generator.warmup()

print(f"Loaded model in {timer.interval:.2f} seconds.")

def _read_text_arg(value: str) -> str:
    """Return the stripped contents of *value* if it names a file, else *value* itself, stripped."""
    path = Path(value)
    text = path.read_text(encoding="utf-8") if path.is_file() else value
    return text.strip()


genre = _read_text_arg(args.genre)
lyrics = _read_text_arg(args.lyrics)

# Split lyrics into "[tag]\nbody" segments. Each segment ends where the next
# "\n[tag]" line starts, or at end-of-input.
# BUG FIX: the previous pattern (r"\[(\w+)\](.*?)\n(?=\[|\Z)") required a
# literal trailing newline before end-of-input, but strip() above removes it,
# so the FINAL segment was silently dropped. The lookahead form below captures
# interior segments identically and also matches the last one.
lyrics = re.findall(r"\[(\w+)\](.*?)(?=\n\[|\Z)", lyrics, re.DOTALL)
lyrics = [f"[{tag}]\n{body.strip()}\n\n" for tag, body in lyrics]
lyrics_joined = "\n".join(lyrics)

# Sampler settings: constrain generation to the audio vocabulary — token
# 32002 plus the 45334..46357 range (presumably the stop token and the
# xcodec code ids; confirm against the model's tokenizer config).
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.allow_tokens(tokenizer, [32002] + list(range(45334, 46358)))
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p

# BUG FIX: `args.seed if args.seed else ...` treated an explicit seed of 0 as
# "no seed" and replaced it with a random one; compare against None instead.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]

output_joined = ""
output = []

with Timer() as timer:
    for segment in lyrics:
        current = []

        input = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}<SOA><xcodec>"
        )

        input_ids = tokenizer.encode(input, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        max_new_tokens = config.max_seq_len - input_len

        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )

        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )

        generator.enqueue(job)

        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") == "streaming":
                        text = result.get("text")

                        if text:
                            current.append(text)
                            output.append(text)

                            if args.debug:
                                print(text, end="", flush=True)

                    if result.get("eos") and current:
                        current_joined = "".join(current)
                        output_joined += (
                            f"[start_of_segment]{segment}<SOA><xcodec>"
                            f"{current_joined}<EOA>[end_of_segment]"
                        )

                        if args.debug:
                            print()

        print(f"Generated {len(current)} tokens in {inner.interval:.2f} seconds.")

print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")

# Build the xcodec SoundStream decoder from its training config and restore
# the released checkpoint weights.
with Timer() as timer:
    cfg_bytes = Path("xcodec_mini_infer/final_ckpt/config.yaml").read_bytes()
    cfg = yaml.safe_load(cfg_bytes)
    codec = SoundStream(**cfg["generator"]["config"])
    # NOTE(review): torch.load unpickles arbitrary objects — fine for this
    # trusted local checkpoint, but never point it at untrusted files.
    checkpoint = torch.load("xcodec_mini_infer/final_ckpt/ckpt_00360000.pth")
    codec.load_state_dict(checkpoint["codec_model"])
    codec = codec.eval().cuda()

print(f"Loaded codec in {timer.interval:.2f} seconds.")

with Timer() as timer, torch.inference_mode():
    # Extract the integer codes from the generated "<xcodec/0/N>" tokens.
    # ROBUSTNESS FIX: the previous per-chunk `re.match` + fixed slice
    # `int(o[10:-1])` raised ValueError if a streamed chunk ever contained
    # more than one token, and dropped a token split across two chunks.
    # Joining first and using findall handles both and is identical for
    # one-token-per-chunk streams.
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    output_ids = [int(code) for code in pattern.findall("".join(output))]

    def _decode_to_wav(codes, path):
        """Decode one stream of codec codes and write it to *path* as 16 kHz audio."""
        frames = torch.tensor([[codes]]).cuda().permute(1, 0, 2)
        audio = codec.decode(frames).squeeze(0).cpu()
        torchaudio.save(path, audio, 16000)

    # Codes are interleaved: even indices are the vocal track, odd indices
    # the instrumental track.
    _decode_to_wav(output_ids[::2], "vocal.wav")
    _decode_to_wav(output_ids[1::2], "inst.wav")

print(f"Decoded audio in {timer.interval:.2f} seconds.")
```