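"""Gradio Space for YuE lyrics-to-music generation.

Pipeline overview (as implemented below):
  1. Stage 1 (YuE-s1-7B) turns genre tags + lyrics, optionally conditioned on a
     reference audio clip, into interleaved vocal/instrumental xcodec tokens.
  2. Stage 2 (YuE-s2-1B) expands each track from 1 to 8 codebooks, 6 seconds at a time.
  3. The xcodec model and the vocoder decoders render the tokens into separate
     vocal and instrumental MP3 files.
"""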
import os
import subprocess
# Install flash attention
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Merge with the existing environment so PATH (and pip) remain available.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
import spaces
import copy
import torch
import numpy as np
from omegaconf import OmegaConf
import torchaudio
from torchaudio.transforms import Resample
import soundfile as sf
import uuid
from tqdm import tqdm
from einops import rearrange
import gradio as gr
import re
from collections import Counter
from codecmanipulator import CodecManipulator
from mmtokenizer import _MMSentencePieceTokenizer
from transformers import AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
from models.soundstream_hubert_new import SoundStream
from vocoder import build_codec_model, process_audio
from post_process_audio import replace_low_freq_with_energy_matched
# Initialize global variables and models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
codectool = CodecManipulator("xcodec", 0, 1)
codectool_stage2 = CodecManipulator("xcodec", 0, 8)
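
# BlockTokenRangeProcessor is used by stage2_generate below but was not defined in this
# file. This is a minimal implementation, modeled on the YuE reference inference code:
# it masks a half-open range of token ids so they can never be sampled.
class BlockTokenRangeProcessor(LogitsProcessor):
    def __init__(self, start_id, end_id):
        # Block every token id in [start_id, end_id).
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores
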
# Load models once at startup
def load_models():
    # Stage 1 Model
    stage1_model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s1-7B-anneal-en-cot",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2"
    ).to(device)
    stage1_model.eval()

    # Stage 2 Model
    stage2_model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s2-1B-general",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2"
    ).to(device)
    stage2_model.eval()

    # Codec Model
    model_config = OmegaConf.load('./xcodec_mini_infer/final_ckpt/config.yaml')
    codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
    parameter_dict = torch.load('./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', map_location='cpu')
    codec_model.load_state_dict(parameter_dict['codec_model'])
    codec_model.eval()

    return stage1_model, stage2_model, codec_model

stage1_model, stage2_model, codec_model = load_models()
# Helper functions
def split_lyrics(lyrics):
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

def load_audio_mono(filepath, sampling_rate=16000):
    audio, sr = torchaudio.load(filepath)
    audio = torch.mean(audio, dim=0, keepdim=True)  # Convert to mono
    if sr != sampling_rate:
        resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
        audio = resampler(audio)
    return audio

def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
    folder_path = os.path.dirname(path)
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    limit = 0.99
    max_val = wav.abs().max()
    wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)

# Stage 1 Generation
def stage1_generate(genres, lyrics_text, use_audio_prompt, audio_prompt_path, prompt_start_time, prompt_end_time):
    structured_lyrics = split_lyrics(lyrics_text)
    full_lyrics = "\n".join(structured_lyrics)
    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"] + structured_lyrics

    random_id = str(uuid.uuid4())
    output_dir = os.path.join("./output", random_id)
    os.makedirs(output_dir, exist_ok=True)
    stage1_output_set = []

    for i, p in enumerate(tqdm(prompt_texts)):
        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
        guidance_scale = 1.5 if i <= 1 else 1.2  # note: computed but not passed to generate() in this app
        if i == 0:
            continue
        if i == 1 and use_audio_prompt:
            audio_prompt = load_audio_mono(audio_prompt_path)
            audio_prompt.unsqueeze_(0)
            with torch.no_grad():
                raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                raw_codes = raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
            # xcodec runs at 50 frames per second, so seconds * 50 indexes codec frames.
            audio_prompt_codec = codectool.npy2ids(raw_codes[0])[int(prompt_start_time * 50): int(prompt_end_time * 50)]
            audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
            sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
            head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
        else:
            head_id = mmtokenizer.tokenize(prompt_texts[0])

        prompt_ids = head_id + mmtokenizer.tokenize("[start_of_segment]") + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)

        with torch.no_grad():
            output_seq = stage1_model.generate(
                input_ids=prompt_ids,
                max_new_tokens=3000,
                min_new_tokens=100,
                do_sample=True,
                top_p=0.93,
                temperature=1.0,
                repetition_penalty=1.2,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
            )

        if i > 1:
            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, prompt_ids.shape[-1]:]], dim=1)
        else:
            raw_output = output_seq

    # Save Stage 1 outputs
    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()

    vocals = []
    instrumentals = []
    for i in range(len(soa_idx)):
        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
        if codec_ids[0] == 32016:
            codec_ids = codec_ids[1:]
        # Keep an even number of tokens so they de-interleave into two tracks.
        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
        vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
        vocals.append(vocals_ids)
        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
        instrumentals.append(instrumentals_ids)

    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)
    vocal_save_path = os.path.join(output_dir, f"vocal_{random_id}.npy")
    inst_save_path = os.path.join(output_dir, f"instrumental_{random_id}.npy")
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)
    stage1_output_set.append(vocal_save_path)
    stage1_output_set.append(inst_save_path)

    return stage1_output_set, output_dir

# Stage 2 Generation
def stage2_generate(model, prompt, batch_size=16):
    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
    codec_ids = codectool.offset_tok_ids(
        codec_ids,
        global_offset=codectool.global_offset,
        codebook_size=codectool.codebook_size,
        num_codebooks=codectool.num_codebooks,
    ).astype(np.int32)

    if batch_size > 1:
        # Split the sequence into batch_size chunks of 300 frames (6 seconds) each.
        codec_list = []
        for i in range(batch_size):
            idx_begin = i * 300
            idx_end = (i + 1) * 300
            codec_list.append(codec_ids[:, idx_begin:idx_end])
        codec_ids = np.concatenate(codec_list, axis=0)
        prompt_ids = np.concatenate(
            [
                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
                codec_ids,
                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
            ],
            axis=1
        )
    else:
        prompt_ids = np.concatenate([
            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
            codec_ids.flatten(),
            np.array([mmtokenizer.stage_2])
        ]).astype(np.int32)
        prompt_ids = prompt_ids[np.newaxis, ...]

    codec_ids = torch.as_tensor(codec_ids).to(device)
    prompt_ids = torch.as_tensor(prompt_ids).to(device)
    len_prompt = prompt_ids.shape[-1]

    # Restrict sampling to the stage-2 codec token range.
    block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])

    # Teacher-forced decoding: feed one stage-1 frame, then generate the 7 remaining codebook tokens for it.
    for frames_idx in range(codec_ids.shape[1]):
        cb0 = codec_ids[:, frames_idx:frames_idx + 1]
        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
        input_ids = prompt_ids

        with torch.no_grad():
            stage2_output = model.generate(
                input_ids=input_ids,
                min_new_tokens=7,
                max_new_tokens=7,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=block_list,
            )

        assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1] - prompt_ids.shape[1]}"
        prompt_ids = stage2_output

    if batch_size > 1:
        output = prompt_ids.cpu().numpy()[:, len_prompt:]
        output_list = [output[i] for i in range(batch_size)]
        output = np.concatenate(output_list, axis=0)
    else:
        output = prompt_ids[0].cpu().numpy()[len_prompt:]

    return output

def stage2_inference(model, stage1_output_set, output_dir, batch_size=4):
    stage2_result = []
    for i in tqdm(range(len(stage1_output_set))):
        output_filename = os.path.join(output_dir, os.path.basename(stage1_output_set[i]))
        if os.path.exists(output_filename):
            continue
        prompt = np.load(stage1_output_set[i]).astype(np.int32)
        # Duration in seconds, rounded down to a multiple of 6 (each batch element covers 6 s = 300 frames).
        output_duration = prompt.shape[-1] // 50 // 6 * 6
        num_batch = output_duration // 6
        if num_batch <= batch_size:
            output = stage2_generate(model, prompt[:, :output_duration * 50], batch_size=num_batch)
        else:
            segments = []
            num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
            for seg in range(num_segments):
                start_idx = seg * batch_size * 300
                end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
                current_batch_size = batch_size if seg != num_segments - 1 or num_batch % batch_size == 0 else num_batch % batch_size
                segment = stage2_generate(model, prompt[:, start_idx:end_idx], batch_size=current_batch_size)
                segments.append(segment)
            output = np.concatenate(segments, axis=0)
        if output_duration * 50 != prompt.shape[-1]:
            ending = stage2_generate(model, prompt[:, output_duration * 50:], batch_size=1)
            output = np.concatenate([output, ending], axis=0)
        output = codectool_stage2.ids2npy(output)

        # Replace out-of-range codec ids (valid codebook entries are 0-1023) with the
        # most frequent value in that row.
        fixed_output = copy.deepcopy(output)
        for row_idx, line in enumerate(output):
            for col_idx, element in enumerate(line):
                if element < 0 or element > 1023:
                    counter = Counter(line)
                    most_frequent = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
                    fixed_output[row_idx, col_idx] = most_frequent
        np.save(output_filename, fixed_output)
        stage2_result.append(output_filename)
    return stage2_result

# Main Gradio function
@spaces.GPU()
def generate_music(genres, lyrics_text, use_audio_prompt, audio_prompt, start_time, end_time, progress=gr.Progress()):
    progress(0.1, "Running Stage 1 Generation...")
    stage1_output_set, output_dir = stage1_generate(genres, lyrics_text, use_audio_prompt, audio_prompt, start_time, end_time)

    progress(0.6, "Running Stage 2 Refinement...")
    stage2_result = stage2_inference(stage2_model, stage1_output_set, output_dir)

    progress(0.8, "Processing Audio...")
    vocal_decoder, inst_decoder = build_codec_model('./xcodec_mini_infer/decoders/config.yaml', './xcodec_mini_infer/decoders/decoder_131000.pth', './xcodec_mini_infer/decoders/decoder_151000.pth')
    vocoder_output_dir = os.path.join(output_dir, "vocoder")
    os.makedirs(vocoder_output_dir, exist_ok=True)

    for npy in stage2_result:
        if 'instrumental' in npy:
            process_audio(npy, os.path.join(vocoder_output_dir, 'instrumental.mp3'), False, None, inst_decoder, codec_model)
        else:
            process_audio(npy, os.path.join(vocoder_output_dir, 'vocal.mp3'), False, None, vocal_decoder, codec_model)

    # Return in the same order as the click() outputs below: vocal first, then instrumental.
    return [
        os.path.join(vocoder_output_dir, 'vocal.mp3'),
        os.path.join(vocoder_output_dir, 'instrumental.mp3')
    ]

# Gradio UI
with gr.Blocks(title="AI Music Generation") as demo:
    gr.Markdown("# 🎡 AI Music Generation Pipeline")

    with gr.Row():
        with gr.Column():
            genre_input = gr.Textbox(label="Genre Tags", placeholder="e.g., Pop, Happy, Female Vocal")
            lyrics_input = gr.Textbox(label="Lyrics", lines=10, placeholder="Enter lyrics with segments...")
            use_audio_prompt = gr.Checkbox(label="Use Audio Prompt")
            audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
            start_time = gr.Number(label="Start Time (sec)", value=0.0, visible=False)
            end_time = gr.Number(label="End Time (sec)", value=30.0, visible=False)
            generate_btn = gr.Button("Generate Music", variant="primary")
        with gr.Column():
            vocal_output = gr.Audio(label="Vocal Track", interactive=False)
            inst_output = gr.Audio(label="Instrumental Track", interactive=False)

    use_audio_prompt.change(
        lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
        inputs=use_audio_prompt,
        outputs=[audio_input, start_time, end_time]
    )

    generate_btn.click(
        generate_music,
        inputs=[genre_input, lyrics_input, use_audio_prompt, audio_input, start_time, end_time],
        outputs=[vocal_output, inst_output]
    )

if __name__ == "__main__":
    demo.launch()