Spaces:
Running
on
Zero
Running
on
Zero
# Author: Haohe Liu | |
# Email: [email protected] | |
# Date: 11 Feb 2023 | |
import os | |
import json | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import matplotlib | |
from scipy.io import wavfile | |
from matplotlib import pyplot as plt | |
matplotlib.use("Agg") | |
import hashlib | |
import os | |
import requests | |
from tqdm import tqdm | |
URL_MAP = { | |
"vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt", | |
"vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt", | |
"melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt", | |
} | |
CKPT_MAP = { | |
"vggishish_lpaps": "vggishish16.pt", | |
"vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt", | |
"melception": "melception-21-05-10T09-28-40.pt", | |
} | |
MD5_MAP = { | |
"vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd", | |
"vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625", | |
"melception": "a71a41041e945b457c7d3d814bbcf72d", | |
} | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
def load_json(fname): | |
with open(fname, "r") as f: | |
data = json.load(f) | |
return data | |
def read_json(dataset_json_file): | |
with open(dataset_json_file, "r") as fp: | |
data_json = json.load(fp) | |
return data_json["data"] | |
def copy_test_subset_data(metadata, testset_copy_target_path): | |
# metadata = read_json(testset_metadata) | |
os.makedirs(testset_copy_target_path, exist_ok=True) | |
if len(os.listdir(testset_copy_target_path)) == len(metadata): | |
return | |
else: | |
# delete files in folder testset_copy_target_path | |
for file in os.listdir(testset_copy_target_path): | |
try: | |
os.remove(os.path.join(testset_copy_target_path, file)) | |
except Exception as e: | |
print(e) | |
print("Copying test subset data to {}".format(testset_copy_target_path)) | |
for each in tqdm(metadata): | |
cmd = "cp {} {}".format(each["wav"], os.path.join(testset_copy_target_path)) | |
os.system(cmd) | |
def listdir_nohidden(path): | |
for f in os.listdir(path): | |
if not f.startswith("."): | |
yield f | |
def get_restore_step(path): | |
checkpoints = os.listdir(path) | |
if os.path.exists(os.path.join(path, "final.ckpt")): | |
return "final.ckpt", 0 | |
elif not os.path.exists(os.path.join(path, "last.ckpt")): | |
steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in checkpoints] | |
return checkpoints[np.argmax(steps)], np.max(steps) | |
else: | |
steps = [] | |
for x in checkpoints: | |
if "last" in x: | |
if "-v" not in x: | |
fname = "last.ckpt" | |
else: | |
this_version = int(x.split(".ckpt")[0].split("-v")[1]) | |
steps.append(this_version) | |
if len(steps) == 0 or this_version > np.max(steps): | |
fname = "last-v%s.ckpt" % this_version | |
return fname, 0 | |
def download(url, local_path, chunk_size=1024): | |
os.makedirs(os.path.split(local_path)[0], exist_ok=True) | |
with requests.get(url, stream=True) as r: | |
total_size = int(r.headers.get("content-length", 0)) | |
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: | |
with open(local_path, "wb") as f: | |
for data in r.iter_content(chunk_size=chunk_size): | |
if data: | |
f.write(data) | |
pbar.update(chunk_size) | |
def md5_hash(path): | |
with open(path, "rb") as f: | |
content = f.read() | |
return hashlib.md5(content).hexdigest() | |
def get_ckpt_path(name, root, check=False): | |
assert name in URL_MAP | |
path = os.path.join(root, CKPT_MAP[name]) | |
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): | |
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) | |
download(URL_MAP[name], path) | |
md5 = md5_hash(path) | |
assert md5 == MD5_MAP[name], md5 | |
return path | |
class KeyNotFoundError(Exception): | |
def __init__(self, cause, keys=None, visited=None): | |
self.cause = cause | |
self.keys = keys | |
self.visited = visited | |
messages = list() | |
if keys is not None: | |
messages.append("Key not found: {}".format(keys)) | |
if visited is not None: | |
messages.append("Visited: {}".format(visited)) | |
messages.append("Cause:\n{}".format(cause)) | |
message = "\n".join(messages) | |
super().__init__(message) | |
def retrieve( | |
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False | |
): | |
"""Given a nested list or dict return the desired value at key expanding | |
callable nodes if necessary and :attr:`expand` is ``True``. The expansion | |
is done in-place. | |
Parameters | |
---------- | |
list_or_dict : list or dict | |
Possibly nested list or dictionary. | |
key : str | |
key/to/value, path like string describing all keys necessary to | |
consider to get to the desired value. List indices can also be | |
passed here. | |
splitval : str | |
String that defines the delimiter between keys of the | |
different depth levels in `key`. | |
default : obj | |
Value returned if :attr:`key` is not found. | |
expand : bool | |
Whether to expand callable nodes on the path or not. | |
Returns | |
------- | |
The desired value or if :attr:`default` is not ``None`` and the | |
:attr:`key` is not found returns ``default``. | |
Raises | |
------ | |
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is | |
``None``. | |
""" | |
keys = key.split(splitval) | |
success = True | |
try: | |
visited = [] | |
parent = None | |
last_key = None | |
for key in keys: | |
if callable(list_or_dict): | |
if not expand: | |
raise KeyNotFoundError( | |
ValueError( | |
"Trying to get past callable node with expand=False." | |
), | |
keys=keys, | |
visited=visited, | |
) | |
list_or_dict = list_or_dict() | |
parent[last_key] = list_or_dict | |
last_key = key | |
parent = list_or_dict | |
try: | |
if isinstance(list_or_dict, dict): | |
list_or_dict = list_or_dict[key] | |
else: | |
list_or_dict = list_or_dict[int(key)] | |
except (KeyError, IndexError, ValueError) as e: | |
raise KeyNotFoundError(e, keys=keys, visited=visited) | |
visited += [key] | |
# final expansion of retrieved value | |
if expand and callable(list_or_dict): | |
list_or_dict = list_or_dict() | |
parent[last_key] = list_or_dict | |
except KeyNotFoundError as e: | |
if default is None: | |
raise e | |
else: | |
list_or_dict = default | |
success = False | |
if not pass_success: | |
return list_or_dict | |
else: | |
return list_or_dict, success | |
def to_device(data, device): | |
if len(data) == 12: | |
( | |
ids, | |
raw_texts, | |
speakers, | |
texts, | |
src_lens, | |
max_src_len, | |
mels, | |
mel_lens, | |
max_mel_len, | |
pitches, | |
energies, | |
durations, | |
) = data | |
speakers = torch.from_numpy(speakers).long().to(device) | |
texts = torch.from_numpy(texts).long().to(device) | |
src_lens = torch.from_numpy(src_lens).to(device) | |
mels = torch.from_numpy(mels).float().to(device) | |
mel_lens = torch.from_numpy(mel_lens).to(device) | |
pitches = torch.from_numpy(pitches).float().to(device) | |
energies = torch.from_numpy(energies).to(device) | |
durations = torch.from_numpy(durations).long().to(device) | |
return ( | |
ids, | |
raw_texts, | |
speakers, | |
texts, | |
src_lens, | |
max_src_len, | |
mels, | |
mel_lens, | |
max_mel_len, | |
pitches, | |
energies, | |
durations, | |
) | |
if len(data) == 6: | |
(ids, raw_texts, speakers, texts, src_lens, max_src_len) = data | |
speakers = torch.from_numpy(speakers).long().to(device) | |
texts = torch.from_numpy(texts).long().to(device) | |
src_lens = torch.from_numpy(src_lens).to(device) | |
return (ids, raw_texts, speakers, texts, src_lens, max_src_len) | |
def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""): | |
# if losses is not None: | |
# logger.add_scalar("Loss/total_loss", losses[0], step) | |
# logger.add_scalar("Loss/mel_loss", losses[1], step) | |
# logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) | |
# logger.add_scalar("Loss/pitch_loss", losses[3], step) | |
# logger.add_scalar("Loss/energy_loss", losses[4], step) | |
# logger.add_scalar("Loss/duration_loss", losses[5], step) | |
# if(len(losses) > 6): | |
# logger.add_scalar("Loss/disc_loss", losses[6], step) | |
# logger.add_scalar("Loss/fmap_loss", losses[7], step) | |
# logger.add_scalar("Loss/r_loss", losses[8], step) | |
# logger.add_scalar("Loss/g_loss", losses[9], step) | |
# logger.add_scalar("Loss/gen_loss", losses[10], step) | |
# logger.add_scalar("Loss/diff_loss", losses[11], step) | |
if fig is not None: | |
logger.add_figure(tag, fig) | |
if audio is not None: | |
audio = audio / (max(abs(audio)) * 1.1) | |
logger.add_audio( | |
tag, | |
audio, | |
sample_rate=sampling_rate, | |
) | |
def get_mask_from_lengths(lengths, max_len=None): | |
batch_size = lengths.shape[0] | |
if max_len is None: | |
max_len = torch.max(lengths).item() | |
ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) | |
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) | |
return mask | |
def expand(values, durations): | |
out = list() | |
for value, d in zip(values, durations): | |
out += [value] * max(0, int(d)) | |
return np.array(out) | |
def synth_one_sample_val( | |
targets, predictions, vocoder, model_config, preprocess_config | |
): | |
index = np.random.choice(list(np.arange(targets[6].size(0)))) | |
basename = targets[0][index] | |
src_len = predictions[8][index].item() | |
mel_len = predictions[9][index].item() | |
mel_target = targets[6][index, :mel_len].detach().transpose(0, 1) | |
mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1) | |
postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1) | |
duration = targets[11][index, :src_len].detach().cpu().numpy() | |
if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": | |
pitch = predictions[2][index, :src_len].detach().cpu().numpy() | |
pitch = expand(pitch, duration) | |
else: | |
pitch = predictions[2][index, :mel_len].detach().cpu().numpy() | |
if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": | |
energy = predictions[3][index, :src_len].detach().cpu().numpy() | |
energy = expand(energy, duration) | |
else: | |
energy = predictions[3][index, :mel_len].detach().cpu().numpy() | |
with open( | |
os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") | |
) as f: | |
stats = json.load(f) | |
stats = stats["pitch"] + stats["energy"][:2] | |
# from datetime import datetime | |
# now = datetime.now() | |
# current_time = now.strftime("%D:%H:%M:%S") | |
# np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy()) | |
# np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy()) | |
# np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy()) | |
fig = plot_mel( | |
[ | |
(mel_prediction.cpu().numpy(), pitch, energy), | |
(postnet_mel_prediction.cpu().numpy(), pitch, energy), | |
(mel_target.cpu().numpy(), pitch, energy), | |
], | |
stats, | |
[ | |
"Raw mel spectrogram prediction", | |
"Postnet mel prediction", | |
"Ground-Truth Spectrogram", | |
], | |
) | |
if vocoder is not None: | |
from .model import vocoder_infer | |
wav_reconstruction = vocoder_infer( | |
mel_target.unsqueeze(0), | |
vocoder, | |
model_config, | |
preprocess_config, | |
)[0] | |
wav_prediction = vocoder_infer( | |
postnet_mel_prediction.unsqueeze(0), | |
vocoder, | |
model_config, | |
preprocess_config, | |
)[0] | |
else: | |
wav_reconstruction = wav_prediction = None | |
return fig, wav_reconstruction, wav_prediction, basename | |
def synth_one_sample(mel_input, mel_prediction, labels, vocoder): | |
if vocoder is not None: | |
from .model import vocoder_infer | |
wav_reconstruction = vocoder_infer( | |
mel_input.permute(0, 2, 1), | |
vocoder, | |
) | |
wav_prediction = vocoder_infer( | |
mel_prediction.permute(0, 2, 1), | |
vocoder, | |
) | |
else: | |
wav_reconstruction = wav_prediction = None | |
return wav_reconstruction, wav_prediction | |
def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path): | |
# (diff_output, diff_loss, latent_loss) = diffusion | |
basenames = targets[0] | |
for i in range(len(predictions[1])): | |
basename = basenames[i] | |
src_len = predictions[8][i].item() | |
mel_len = predictions[9][i].item() | |
mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) | |
# diff_output = diff_output[i, :mel_len].detach().transpose(0, 1) | |
# duration = predictions[5][i, :src_len].detach().cpu().numpy() | |
if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": | |
pitch = predictions[2][i, :src_len].detach().cpu().numpy() | |
# pitch = expand(pitch, duration) | |
else: | |
pitch = predictions[2][i, :mel_len].detach().cpu().numpy() | |
if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": | |
energy = predictions[3][i, :src_len].detach().cpu().numpy() | |
# energy = expand(energy, duration) | |
else: | |
energy = predictions[3][i, :mel_len].detach().cpu().numpy() | |
# import ipdb; ipdb.set_trace() | |
with open( | |
os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") | |
) as f: | |
stats = json.load(f) | |
stats = stats["pitch"] + stats["energy"][:2] | |
fig = plot_mel( | |
[ | |
(mel_prediction.cpu().numpy(), pitch, energy), | |
], | |
stats, | |
["Synthetized Spectrogram by PostNet"], | |
) | |
# np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy()) | |
plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename))) | |
plt.close() | |
from .model import vocoder_infer | |
mel_predictions = predictions[1].transpose(1, 2) | |
lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] | |
wav_predictions = vocoder_infer( | |
mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths | |
) | |
sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] | |
for wav, basename in zip(wav_predictions, basenames): | |
wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav) | |
def plot_mel(data, titles=None): | |
fig, axes = plt.subplots(len(data), 1, squeeze=False) | |
if titles is None: | |
titles = [None for i in range(len(data))] | |
for i in range(len(data)): | |
mel = data[i] | |
axes[i][0].imshow(mel, origin="lower", aspect="auto") | |
axes[i][0].set_aspect(2.5, adjustable="box") | |
axes[i][0].set_ylim(0, mel.shape[0]) | |
axes[i][0].set_title(titles[i], fontsize="medium") | |
axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) | |
axes[i][0].set_anchor("W") | |
return fig | |
def pad_1D(inputs, PAD=0): | |
def pad_data(x, length, PAD): | |
x_padded = np.pad( | |
x, (0, length - x.shape[0]), mode="constant", constant_values=PAD | |
) | |
return x_padded | |
max_len = max((len(x) for x in inputs)) | |
padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) | |
return padded | |
def pad_2D(inputs, maxlen=None): | |
def pad(x, max_len): | |
PAD = 0 | |
if np.shape(x)[0] > max_len: | |
raise ValueError("not max_len") | |
s = np.shape(x)[1] | |
x_padded = np.pad( | |
x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD | |
) | |
return x_padded[:, :s] | |
if maxlen: | |
output = np.stack([pad(x, maxlen) for x in inputs]) | |
else: | |
max_len = max(np.shape(x)[0] for x in inputs) | |
output = np.stack([pad(x, max_len) for x in inputs]) | |
return output | |
def pad(input_ele, mel_max_length=None): | |
if mel_max_length: | |
max_len = mel_max_length | |
else: | |
max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) | |
out_list = list() | |
for i, batch in enumerate(input_ele): | |
if len(batch.shape) == 1: | |
one_batch_padded = F.pad( | |
batch, (0, max_len - batch.size(0)), "constant", 0.0 | |
) | |
elif len(batch.shape) == 2: | |
one_batch_padded = F.pad( | |
batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 | |
) | |
out_list.append(one_batch_padded) | |
out_padded = torch.stack(out_list) | |
return out_padded | |