kevinwang676's picture
Upload folder using huggingface_hub
cfdc687
raw
history blame
5.52 kB
import os
import torch
import numpy as np
import soundfile as sf
def fast_cosine_dist(source_feats: torch.Tensor, matching_pool: torch.Tensor, device='cpu'):
    """
    Pairwise cosine distance between two sets of feature vectors.

    Behaves like `torch.cdist` restricted to the last dimension, but returns
    cosine distance (1 - cosine similarity) instead of a p-norm distance.

    Args:
        source_feats (torch.Tensor): Query features, shape (n_source_feats, feat_dim).
        matching_pool (torch.Tensor): Candidate features, shape (n_matching_feats, feat_dim).
        device (str, optional): Device both inputs are moved to before computing. Defaults to 'cpu'.

    Returns:
        torch.Tensor: Distances of shape (n_source_feats, n_matching_feats), located on `device`.
            Entries involving a zero-norm vector are not finite (division by zero).
    """
    queries = source_feats.to(device)
    pool = matching_pool.to(device)

    query_norms = torch.norm(queries, p=2, dim=-1)
    pool_norms = torch.norm(pool, p=2, dim=-1)

    # Recover the dot product from the squared Euclidean distance:
    #   ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b)
    # cdist expects a batch dimension, hence the [None] / [0] round trip.
    squared_euclid = torch.cdist(queries[None], pool[None], p=2)[0] ** 2
    dot = (-squared_euclid + query_norms[:, None] ** 2 + pool_norms[None] ** 2) / 2

    # Cosine similarity is the dot product divided by the product of norms.
    return 1 - (dot / (query_norms[:, None] * pool_norms[None]))
def load_wav(wav_path, sr=None):
    """
    Loads a waveform from a wav file.

    Multi-channel audio is reduced to its first channel, and a waveform whose
    absolute peak exceeds 1.0 is scaled down so the peak becomes 1.0.

    Args:
        wav_path (str): Path to the wav file.
        sr (int, optional): Required sample rate. When given and the file's
            sample rate differs, an AssertionError is raised.
            Defaults to None (any sample rate is accepted).

    Returns:
        Tuple[np.ndarray, int]: The 1-D waveform and the sample rate reported by the file.

    Raises:
        AssertionError: If `sr` is given and does not match the file's sample rate.
    """
    wav, fs = sf.read(wav_path)

    if wav.ndim != 1:
        # soundfile returns multi-channel audio as (frames, channels), so the
        # channel count is shape[1]; ndim would always be 2 here.
        print('The wav file %s has %d channels, select the first one to proceed.' % (wav_path, wav.shape[1]))
        wav = wav[:, 0]

    # Raised explicitly (rather than via `assert`) so the check survives `python -O`,
    # while keeping the exception type callers may already be catching.
    if sr is not None and fs != sr:
        raise AssertionError(f'{sr} kHz audio is required. Got {fs}')

    # An empty file has no peak; skip normalisation instead of crashing in max().
    if wav.size:
        peak = np.abs(wav).max()
        if peak > 1.0:
            wav /= peak

    return wav, fs
class ConfigWrapper(object):
    """
    Attribute-style view over a (possibly nested) configuration dict.

    Lets callers write `config.sample_rate` instead of `config["sample_rate"]`;
    both spellings work. Nested dict values (including dict subclasses) passed
    to the constructor are wrapped recursively. Dicts nested inside other
    containers such as lists are stored unchanged.

    Entries are stored as instance attributes, so keys must be strings, and a
    key that collides with a method name (e.g. "keys", "items") will shadow
    that method on the instance.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (not an exact type check) so dict subclasses are wrapped too.
            if isinstance(v, dict):
                v = ConfigWrapper(**v)
            self[k] = v

    def keys(self):
        """Return a view of the stored keys."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of the stored (key, value) pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the stored values."""
        return self.__dict__.values()

    def to_dict_type(self):
        """Convert back to plain dicts, recursively unwrapping nested ConfigWrapper values."""
        return {
            key: (value.to_dict_type() if isinstance(value, ConfigWrapper) else value)
            for key, value in self.items()
        }

    def __iter__(self):
        # Without this, `for k in cfg` falls back to integer indexing through
        # __getitem__ and fails with a TypeError from getattr.
        return iter(self.__dict__)

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        # Missing keys raise AttributeError (from getattr), not KeyError.
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)
def save_checkpoint(steps, epochs, model, optimizer, scheduler, checkpoint_path, dst_train=False):
    """Save a training checkpoint to disk.

    The checkpoint is a dict with keys "optimizer", "scheduler", "steps",
    "epochs" and "model"; the optimizer, scheduler and model entries each hold
    a "generator" and a "discriminator" state dict.

    Args:
        steps (int): Step counter to store.
        epochs (int): Epoch counter to store.
        model (dict): Maps "generator" / "discriminator" to the models.
        optimizer (dict): Maps "generator" / "discriminator" to the optimizers.
        scheduler (dict): Maps "generator" / "discriminator" to the schedulers.
        checkpoint_path (str): Checkpoint path to be saved. Missing parent
            directories are created; a bare filename is written to the current
            working directory.
        dst_train (bool): When True, state dicts are taken from each model's
            `.module` attribute (presumably a DataParallel-style wrapper used
            for distributed training -- confirm against the caller).
    """
    def _unwrap(net):
        # Wrapped models keep the real network under `.module`.
        return net.module if dst_train else net

    state_dict = {
        "optimizer": {
            "generator": optimizer["generator"].state_dict(),
            "discriminator": optimizer["discriminator"].state_dict(),
        },
        "scheduler": {
            "generator": scheduler["generator"].state_dict(),
            "discriminator": scheduler["discriminator"].state_dict(),
        },
        "steps": steps,
        "epochs": epochs,
        "model": {
            "generator": _unwrap(model["generator"]).state_dict(),
            "discriminator": _unwrap(model["discriminator"]).state_dict(),
        },
    }

    # dirname is '' for a bare filename, and os.makedirs('') raises, so only
    # create directories when there is one. exist_ok avoids the race between
    # checking for the directory and creating it.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(state_dict, checkpoint_path)
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, load_only_params=False, dst_train=False):
    """Load a training checkpoint from disk.

    Expects a dict with "model", "steps" and "epochs" entries and, unless
    `load_only_params` is set, "optimizer" and "scheduler" entries; the model,
    optimizer and scheduler entries each hold a "generator" and a
    "discriminator" state dict.

    Args:
        model (dict): Maps "generator" / "discriminator" to the models to restore.
        optimizer (dict): Maps "generator" / "discriminator" to the optimizers
            to restore. Not accessed when `load_only_params` is True.
        scheduler (dict): Maps "generator" / "discriminator" to the schedulers
            to restore. Not accessed when `load_only_params` is True.
        checkpoint_path (str): Checkpoint path to be loaded.
        load_only_params (bool): Whether to load only model parameters. When
            True, optimizer and scheduler state are left untouched.
        dst_train (bool): When True, parameters are loaded into each model's
            `.module` attribute (presumably a DataParallel-style wrapper used
            for distributed training -- confirm against the caller).

    Returns:
        Tuple[int, int]: The `steps` and `epochs` counters stored in the
        checkpoint (returned regardless of `load_only_params`).
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    for name in ("generator", "discriminator"):
        # Wrapped models keep the real network under `.module`.
        target = model[name].module if dst_train else model[name]
        target.load_state_dict(state_dict["model"][name])

    if not load_only_params:
        for name in ("generator", "discriminator"):
            optimizer[name].load_state_dict(state_dict["optimizer"][name])
        for name in ("generator", "discriminator"):
            scheduler[name].load_state_dict(state_dict["scheduler"][name])

    return state_dict["steps"], state_dict["epochs"]