import argparse
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch import Tensor
from tqdm import tqdm
import resampy
from modules.wavlm_encoder import WavLMEncoder
from utils.tools import fast_cosine_dist
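
# WavLM's convolutional feature extractor downsamples 16 kHz audio by a factor of 320,
# i.e. one feature frame per ~20 ms of audio; inputs are padded to a multiple of this below.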
DOWNSAMPLE_FACTOR = 320
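
# Assumed OpenSinger layout (inferred from the path handling below, not verified here):
#   <root>/ManRaw/<singer>_<song>/<singer>_<song>_<segment>.wav   (and likewise WomanRaw/)
# so path.parts[-3] is 'ManRaw'/'WomanRaw' and the filename stem starts with the singer index.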
def make_opensinger_df(root_path: Path) -> pd.DataFrame:
    all_files = []
    folders = ['ManRaw', 'WomanRaw']
    for f in folders:
        all_files.extend(list((root_path/f).rglob('*.wav')))
    # f.parts[-3] is 'ManRaw' or 'WomanRaw'; stripping the 'Raw' suffix gives the gender prefix,
    # which is combined with the singer index from the filename, e.g. 'Man-12'.
    speakers = [f.parts[-3][:-3] + '-' + f.stem.split('_')[0] for f in all_files]
    df = pd.DataFrame({'path': all_files, 'speaker': speakers})
    return df

def main(args):
    data_root = Path(args.data_root)
    out_dir = Path(args.out_dir) if args.out_dir is not None else data_root/'wavlm_features'
    device = torch.device(args.device)
    seed = args.seed

    # Layer weightings over the 25 per-layer WavLM features: synthesis uses a one-hot weight on a
    # single layer, matching averages one-hot weights uniformly over the selected layers.
    SYNTH_WEIGHTINGS = F.one_hot(torch.tensor(args.synthesis_layer), num_classes=25).float().to(device)[:, None]
    MATCH_WEIGHTINGS = F.one_hot(torch.tensor(args.matching_layer), num_classes=25).float().mean(dim=0).to(device)[:, None]
    print(f"Matching weight: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weight: {SYNTH_WEIGHTINGS.squeeze()}")

    ls_df = make_opensinger_df(data_root)

    print("Loading WavLM.")
    wavlm = WavLMEncoder('pretrained/WavLM-Large.pt', device=device)

    np.random.seed(seed)
    torch.manual_seed(seed)

    extract(ls_df, wavlm, device, data_root, out_dir, SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS, topk=args.topk)
    print("All done!", flush=True)
@torch.inference_mode()
def get_full_features(path, wavlm, device):
    x, sr = torchaudio.load(path)
    if sr != 16000:
        x = resampy.resample(x.numpy(), sr, 16000, axis=1)
        x = torch.from_numpy(x).to(dtype=torch.float)
    # pad on the right so the length is a whole number of WavLM frames
    n_pad = DOWNSAMPLE_FACTOR - (x.shape[-1] % DOWNSAMPLE_FACTOR)
    x = F.pad(x, (0, n_pad), value=0)
    # extract the representation of each layer
    wav_input_16khz = x.to(device)
    features = wavlm.get_features(wav_input_16khz)
    return features
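
# Prematching, kNN-VC style: for each frame of a clip, find its k nearest neighbours (by cosine
# distance on the matching features) among frames of the *other* songs by the same singer, and
# save the mean of those neighbours' synthesis-layer features as the pre-matched sequence.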
@torch.inference_mode()
def extract(df: pd.DataFrame, wavlm: nn.Module, device, data_root: Path, out_dir: Path, synth_weights: Tensor, match_weights: Tensor, topk: int = 4):
    mb = tqdm(df.groupby('speaker'), desc='Total Progress')
    for speaker, paths in mb:
        if len(paths) == 1:
            print(f"There is only one audio clip for speaker {speaker}, skipping.")
            continue
        targ_paths = {}
        for i, row in paths.iterrows():
            rel_path = row.path.relative_to(data_root)
            targ_paths[row.path] = (out_dir/rel_path).with_suffix('.pt')
        if all(p.exists() for p in targ_paths.values()):
            continue

        feature_cache = {}
        synthesis_cache = {}
        # 1. extract the WavLM features of all of this speaker's audio
        pb = tqdm(paths.iterrows(), total=len(paths), desc=f'extracting {speaker}')
        for i, row in pb:
            feats = get_full_features(row.path, wavlm, device)
            matching_feats = (feats * match_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            synth_feats = (feats * synth_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            feature_cache[row.path] = matching_feats
            synthesis_cache[row.path] = synth_feats

        # 2. replace the WavLM features of each clip with features drawn from other songs by the same singer
        pb = tqdm(paths.iterrows(), total=len(paths), desc=f'prematching {speaker}')
        for i, row in pb:
            targ_path = targ_paths[row.path]
            if targ_path.is_file():
                continue
            os.makedirs(targ_path.parent, exist_ok=True)

            source_feats = feature_cache[row.path]
            # clips of the same song are excluded from the pools, since a song contains repeated phrases
            song_name = row.path.stem.split('_')[1]
            filtered_matching_feats = {key: value for key, value in feature_cache.items() if song_name not in key.stem}
            matching_pool = list(filtered_matching_feats.values())
            matching_pool = torch.concat(matching_pool, dim=0)
            filtered_synth_feats = {key: value for key, value in synthesis_cache.items() if song_name not in key.stem}
            synth_pool = list(filtered_synth_feats.values())
            synth_pool = torch.concat(synth_pool, dim=0)

            # calculate cosine distances and replace each source frame with the mean of its k nearest neighbours
            matching_pool = matching_pool.to(device)
            synth_pool = synth_pool.to(device)
            dists = fast_cosine_dist(source_feats, matching_pool, device=device)
            best = dists.topk(k=topk, dim=-1, largest=False)  # (src_len, topk)
            out_feats = synth_pool[best.indices].mean(dim=1)  # (src_len, dim)

            # 3. save the pre-matched feature sequence
            torch.save(out_feats.cpu(), str(targ_path))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Compute prematched WavLM features for the OpenSinger dataset")
    parser.add_argument('--data_root', required=True, type=str)
    parser.add_argument('--seed', default=123, type=int)
    parser.add_argument('--out_dir', type=str)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--topk', type=int, default=4)
    parser.add_argument('--matching_layer', type=int, default=[20, 21, 22, 23, 24], nargs='+')
    parser.add_argument('--synthesis_layer', type=int, default=6)
    args = parser.parse_args()
    main(args)
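
# Example invocation (the script file name below is hypothetical):
#   python prematch_dataset.py --data_root /path/to/OpenSinger --device cuda --topk 4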