import argparse
import os
from pathlib import Path
import sys
import torchaudio
import numpy as np
from time import time
import torch
import typing as tp
from omegaconf import OmegaConf
from vocos import VocosDecoder
from models.soundstream_hubert_new import SoundStream
from tqdm import tqdm


def build_soundstream_model(config):
    model = eval(config.generator.name)(**config.generator.config)
    return model


def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
    # Both decoders share the same Vocos config but load different weights.
    vocal_decoder = VocosDecoder.from_hparams(config_path=config_path)
    vocal_decoder.load_state_dict(torch.load(vocal_decoder_path))
    inst_decoder = VocosDecoder.from_hparams(config_path=config_path)
    inst_decoder.load_state_dict(torch.load(inst_decoder_path))
    return vocal_decoder, inst_decoder


def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
    # Avoid clipping: either rescale the whole waveform or hard-clamp it.
    limit = 0.99
    mx = wav.abs().max()
    if rescale:
        wav = wav * min(limit / mx, 1)
    else:
        wav = wav.clamp(-limit, limit)
    path = str(Path(path).with_suffix('.mp3'))
    torchaudio.save(path, wav, sample_rate=sample_rate)


def process_audio(input_file, output_file, rescale, args, decoder, soundstream):
    # Load the discrete codec tokens saved as an NPY array.
    compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
    print(f"Processing {input_file}")
    print(f"Compressed shape: {compressed.shape}")

    args.bw = float(4)
    device = f"cuda:{args.cuda_idx}"

    # Look up the codec embeddings for the tokens, then decode them with Vocos.
    compressed = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1)
    compressed = soundstream.get_embed(compressed.to(device))
    compressed = torch.tensor(compressed).to(device)

    start_time = time()
    with torch.no_grad():
        decoder.eval()
        decoder = decoder.to(device)
        out = decoder(compressed)
        out = out.detach().cpu()
    duration = time() - start_time
    rtf = (out.shape[1] / 44100.0) / duration
    print(f"Decoded in {duration:.2f}s ({rtf:.2f}x RTF)")

    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_audio(out, output_file, 44100, rescale=rescale)
    print(f"Saved: {output_file}")
    return out


def find_matching_pairs(input_folder):
    # input_folder may be a .lst file listing NPY paths instead of a directory.
    if str(input_folder).endswith('.lst'):
        with open(input_folder, 'r') as file:
            files = [line.strip() for line in file if line.strip()]
    else:
        files = list(Path(input_folder).glob('*.npy'))
    print(f"Found {len(files)} NPY files.")

    instrumental_files = {}
    vocal_files = {}
    for file in files:
        if not isinstance(file, Path):
            file = Path(file)
        name = file.stem
        if 'instrumental' in name.lower():
            base_name = name.lower().replace('instrumental', '')
            instrumental_files[base_name] = file
        elif 'vocal' in name.lower():
            # Remove only the last occurrence of 'vocal' so earlier occurrences
            # in the stem are left untouched.
            last_index = name.lower().rfind('vocal')
            if last_index != -1:
                base_name = name.lower()[:last_index] + name.lower()[last_index + len('vocal'):]
            else:
                base_name = name.lower()
            vocal_files[base_name] = file

    # Keep only base names that have both an instrumental and a vocal file.
    pairs = []
    for base_name in instrumental_files.keys():
        if base_name in vocal_files:
            pairs.append((
                instrumental_files[base_name],
                vocal_files[base_name],
                base_name
            ))
    return pairs


def main():
    parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
    parser.add_argument('--input_folder', type=Path, required=True, help='Input folder containing NPY files.')
    parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
    parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to model checkpoint.')
    parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
    parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to vocal Vocos decoder weights.')
    parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to instrumental Vocos decoder weights.')
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index to use (read as args.cuda_idx in process_audio).')
    parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
    args = parser.parse_args()

    # Validate inputs
    if not args.input_folder.exists():
        sys.exit(f"Input folder {args.input_folder} does not exist.")
    if not os.path.isfile(args.config_path):
        sys.exit(f"{args.config_path} file does not exist.")

    # Create output directories
    mix_dir = args.output_base / 'mix'
    stems_dir = args.output_base / 'stems'
    os.makedirs(mix_dir, exist_ok=True)
    os.makedirs(stems_dir, exist_ok=True)

    # Initialize models
    config_ss = OmegaConf.load("./final_ckpt/config.yaml")
    soundstream = build_soundstream_model(config_ss)
    parameter_dict = torch.load(args.resume_path)
    soundstream.load_state_dict(parameter_dict['codec_model'])
    # Move the codec model to the same device the token tensors are sent to.
    soundstream = soundstream.to(f"cuda:{args.cuda_idx}")
    soundstream.eval()
    vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)

    # Find and process matching pairs, skipping songs whose mix already exists
    pairs = find_matching_pairs(args.input_folder)
    print(f"Found {len(pairs)} matching pairs")
    pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
    print(f"{len(pairs)} to reconstruct...")

    for instrumental_file, vocal_file, base_name in tqdm(pairs):
        print(f"\nProcessing pair: {base_name}")

        # Create stems directory for this song
        song_stems_dir = stems_dir / base_name
        os.makedirs(song_stems_dir, exist_ok=True)

        try:
            # Process instrumental
            instrumental_output = process_audio(
                instrumental_file,
                song_stems_dir / 'instrumental.mp3',
                args.rescale,
                args,
                inst_decoder,
                soundstream
            )

            # Process vocal
            vocal_output = process_audio(
                vocal_file,
                song_stems_dir / 'vocal.mp3',
                args.rescale,
                args,
                vocal_decoder,
                soundstream
            )
        except IndexError as e:
            print(e)
            continue

        # Create and save mix
        try:
            mix_output = instrumental_output + vocal_output
            save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
            print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
        except RuntimeError as e:
            print(e)
            print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")


if __name__ == '__main__':
    main()

# Example usage:
# python reconstruct_separately.py --input_folder test_samples --output_base test
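
# Expected input layout (file names below are hypothetical, for illustration only):
# the script pairs NPY token files whose stems differ only by the words
# 'instrumental' and 'vocal', e.g.
#   test_samples/song1_instrumental.npy
#   test_samples/song1_vocal.npy
# Decoded stems are written to <output_base>/stems/<base_name>/instrumental.mp3
# and vocal.mp3, and the summed mix to <output_base>/mix/<base_name>.mp3, where
# <base_name> is the stem with 'instrumental' or 'vocal' removed.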