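"""Pre-match WavLM features for OpenSinger.

Example invocation (the script name and paths are illustrative, not fixed by this repo):
    python prematch_opensinger.py --data_root /data/OpenSinger --device cuda --topk 4

Outputs mirror the input tree under <data_root>/wavlm_features/ as .pt files.
"""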
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch import Tensor
from tqdm import tqdm
import resampy

from modules.wavlm_encoder import WavLMEncoder
from utils.tools import fast_cosine_dist
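
# WavLM's convolutional front end emits one frame per 320 input samples
# (20 ms at 16 kHz); waveforms are padded to a multiple of this below.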
DOWNSAMPLE_FACTOR = 320

def make_opensinger_df(root_path: Path) -> pd.DataFrame:
    all_files = []
    folders = ['ManRaw', 'WomanRaw']
    for f in folders:
        all_files.extend(list((root_path / f).rglob('*.wav')))
    # f.parts[-3][:-3] strips the 'Raw' suffix ('ManRaw' -> 'Man'); the filename
    # stem starts with the singer ID, so speaker labels look like 'Man-28'.
    speakers = [f.parts[-3][:-3] + '-' + f.stem.split('_')[0] for f in all_files]
    df = pd.DataFrame({'path': all_files, 'speaker': speakers})
    return df
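
# Assumed OpenSinger layout (illustrative; inferred from the parsing above):
#   <data_root>/ManRaw/<singer>_<song>/<singer>_<song>_<segment>.wav
# e.g. ManRaw/28_Song/28_Song_0.wav -> speaker 'Man-28', song 'Song'.
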
def main(args):
    data_root = Path(args.data_root)
    out_dir = Path(args.out_dir) if args.out_dir is not None else data_root / 'wavlm_features'
    device = torch.device(args.device)
    seed = args.seed

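    # Layer weightings over WavLM's 25 hidden states, each of shape (25, 1):
    # a one-hot on the synthesis layer, and a uniform average over the
    # matching layers (1/len(args.matching_layer) weight on each).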
    SYNTH_WEIGHTINGS = F.one_hot(torch.tensor(args.synthesis_layer), num_classes=25).float().to(device)[:, None]
    MATCH_WEIGHTINGS = F.one_hot(torch.tensor(args.matching_layer), num_classes=25).float().mean(dim=0).to(device)[:, None]
    print(f"Matching weight: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weight: {SYNTH_WEIGHTINGS.squeeze()}")

    ls_df = make_opensinger_df(data_root)

    print("Loading WavLM.")
    wavlm = WavLMEncoder('pretrained/WavLM-Large.pt', device=device)

    np.random.seed(seed)
    torch.manual_seed(seed)

    extract(ls_df, wavlm, device, data_root, out_dir, SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS, args.topk)
    print("All done!", flush=True)

def get_full_features(path, wavlm, device):
    x, sr = torchaudio.load(path)
    if sr != 16000:
        x = resampy.resample(x.numpy(), sr, 16000, axis=1)
        x = torch.from_numpy(x).to(dtype=torch.float)

    # Pad the waveform to a whole number of WavLM frames.
    n_pad = DOWNSAMPLE_FACTOR - (x.shape[-1] % DOWNSAMPLE_FACTOR)
    x = F.pad(x, (0, n_pad), value=0)

    # Extract the representation of each layer.
    wav_input_16khz = x.to(device)
    features = wavlm.get_features(wav_input_16khz)
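    # `features` is assumed to stack all 25 WavLM-Large hidden states (the conv
    # front-end output plus the 24 transformer layers), i.e. (25, n_frames, dim),
    # matching num_classes=25 in the layer weightings above.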
    return features

def extract(df: pd.DataFrame, wavlm: nn.Module, device, data_root: Path, out_dir: Path, synth_weights: Tensor, match_weights: Tensor, topk: int):
    mb = tqdm(df.groupby('speaker'), desc='Total Progress')
    for speaker, paths in mb:
        if len(paths) == 1:
            print(f"There is only one audio file for speaker {speaker}; skipping.")
            continue

        targ_paths = {}
        for i, row in paths.iterrows():
            rel_path = row.path.relative_to(data_root)
            targ_paths[row.path] = (out_dir / rel_path).with_suffix('.pt')
        if all(p.exists() for p in targ_paths.values()):
            continue

        feature_cache = {}
        synthesis_cache = {}
        # 1. Extract the WavLM features of every audio file of this speaker.
        pb = tqdm(paths.iterrows(), total=len(paths), desc=f'extracting {speaker}')
        for i, row in pb:
            feats = get_full_features(row.path, wavlm, device)
            matching_feats = (feats * match_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            synth_feats = (feats * synth_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            feature_cache[row.path] = matching_feats
            synthesis_cache[row.path] = synth_feats

        # 2. Replace the WavLM features of each singing audio with features from
        #    other songs by the same singer.
        pb = tqdm(paths.iterrows(), total=len(paths), desc=f'prematching {speaker}')
        for i, row in pb:
            targ_path = targ_paths[row.path]
            if targ_path.is_file():
                continue
            os.makedirs(targ_path.parent, exist_ok=True)

            source_feats = feature_cache[row.path]
            # Audio from the same song is excluded from the pool, since a song
            # contains repeated phrases.
            song_name = row.path.stem.split('_')[1]
            filtered_matching_feats = {key: value for key, value in feature_cache.items() if song_name not in key.stem}
            matching_pool = torch.cat(list(filtered_matching_feats.values()), dim=0)
            filtered_synth_feats = {key: value for key, value in synthesis_cache.items() if song_name not in key.stem}
            synth_pool = torch.cat(list(filtered_synth_feats.values()), dim=0)
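            # matching_pool and synth_pool are frame-aligned, both (pool_len, dim):
            # row i of synth_pool is the synthesis-layer feature of the same frame
            # whose matching-layer feature sits in row i of matching_pool.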
            # Compute distances and replace each source frame with the mean of the
            # synthesis features of its k nearest neighbours in the matching pool.
            matching_pool = matching_pool.to(device)
            synth_pool = synth_pool.to(device)
            dists = fast_cosine_dist(source_feats, matching_pool, device=device)
            best = dists.topk(k=topk, dim=-1, largest=False)  # (src_len, topk)
            out_feats = synth_pool[best.indices].mean(dim=1)  # (src_len, dim)

            # 3. Save the pre-matched sequence.
            torch.save(out_feats.cpu(), str(targ_path))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Compute matched WavLM features for an OpenSinger dataset")
    parser.add_argument('--data_root', required=True, type=str)
    parser.add_argument('--seed', default=123, type=int)
    parser.add_argument('--out_dir', type=str)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--topk', type=int, default=4)
    parser.add_argument('--matching_layer', type=int, default=[20, 21, 22, 23, 24], nargs='+')
    parser.add_argument('--synthesis_layer', type=int, default=6)

    args = parser.parse_args()
    main(args)