# Utility functions: cosine-distance feature matching, wav loading,
# config wrapping, and GAN checkpoint save/load.
import os

import numpy as np
import soundfile as sf
import torch
def fast_cosine_dist(source_feats: torch.Tensor, matching_pool: torch.Tensor, device='cpu'):
    """
    Computes the cosine distance between source features and a matching pool of features.
    Like torch.cdist, but fixed dim=-1 and for cosine distance.

    The dot products are recovered from squared Euclidean distances via the
    identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>, which reuses the
    optimized `torch.cdist` kernel instead of materializing a full matmul.

    Args:
        source_feats (torch.Tensor): Tensor of source features with shape (n_source_feats, feat_dim).
        matching_pool (torch.Tensor): Tensor of matching pool features with shape (n_matching_feats, feat_dim).
        device (str, optional): Device to perform the computation on. Defaults to 'cpu'.

    Returns:
        torch.Tensor: Tensor of shape (n_source_feats, n_matching_feats) holding
        the cosine distances between the source features and the pool features.
    """
    source_feats = source_feats.to(device)
    matching_pool = matching_pool.to(device)
    source_norms = torch.norm(source_feats, p=2, dim=-1)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1)
    # <a, b> = (||a||^2 + ||b||^2 - ||a - b||^2) / 2, recovered from cdist.
    # (Both tensors are already on `device`; no second .to(device) is needed.)
    dotprod = -torch.cdist(source_feats[None], matching_pool[None], p=2)[0] ** 2 \
        + source_norms[:, None] ** 2 + matching_norms[None] ** 2
    dotprod /= 2
    # Cosine distance = 1 - cosine similarity.
    dists = 1 - (dotprod / (source_norms[:, None] * matching_norms[None]))
    return dists
def load_wav(wav_path, sr=None):
    """
    Loads a waveform from a wav file.

    Args:
        wav_path (str): Path to the wav file.
        sr (int, optional): Expected sample rate in Hz.
            If `sr` is specified and the loaded audio has a different sample rate,
            an AssertionError is raised. Defaults to None.

    Returns:
        Tuple[np.ndarray, int]: Tuple containing the loaded waveform as a NumPy
        array (mono, peak-limited to [-1.0, 1.0]) and the sample rate.
    """
    wav, fs = sf.read(wav_path)
    if wav.ndim != 1:
        # Multi-channel audio comes back as (frames, channels): keep channel 0.
        # BUGFIX: report the channel count (shape[1]), not the array rank (ndim).
        print('The wav file %s has %d channels, select the first one to proceed.' % (wav_path, wav.shape[1]))
        wav = wav[:, 0]
    # BUGFIX: sample rates here are in Hz, not kHz, so say so in the message.
    assert sr is None or fs == sr, f'{sr} Hz audio is required. Got {fs}'
    peak = np.abs(wav).max()
    if peak > 1.0:
        # Normalize so samples stay within [-1.0, 1.0].
        wav /= peak
    return wav, fs
class ConfigWrapper(object):
    """
    Wrapper dict class to avoid annoying key dict indexing like:
    `config.sample_rate` instead of `config["sample_rate"]`.

    Nested dicts are wrapped recursively, so `config.model.hidden_dim` works.
    Supports both attribute access and `config["key"]` item access, plus the
    dict-like protocol (`keys`, `items`, `values`, `len`, `in`).
    """
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (not type ==) so dict subclasses are wrapped too.
            if isinstance(v, dict):
                v = ConfigWrapper(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def to_dict_type(self):
        """Recursively convert back to a plain nested dict."""
        return {
            key: (value if not isinstance(value, ConfigWrapper) else value.to_dict_type())
            for key, value in dict(**self).items()
        }

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
def save_checkpoint(steps, epochs, model, optimizer, scheduler, checkpoint_path, dst_train=False):
    """Save checkpoint.

    Args:
        steps (int): Number of training steps completed.
        epochs (int): Number of epochs completed.
        model (dict): {"generator": ..., "discriminator": ...} modules.
        optimizer (dict): {"generator": ..., "discriminator": ...} optimizers.
        scheduler (dict): {"generator": ..., "discriminator": ...} LR schedulers.
        checkpoint_path (str): Checkpoint path to be saved.
        dst_train (bool): If True, the models are wrapped (presumably in
            DistributedDataParallel -- TODO confirm with the trainer), so the
            inner `.module` state dicts are saved instead.
    """
    state_dict = {
        "optimizer": {
            "generator": optimizer["generator"].state_dict(),
            "discriminator": optimizer["discriminator"].state_dict(),
        },
        "scheduler": {
            "generator": scheduler["generator"].state_dict(),
            "discriminator": scheduler["discriminator"].state_dict(),
        },
        "steps": steps,
        "epochs": epochs,
    }
    # Unwrap the distributed wrapper so checkpoints load in single-process runs.
    if dst_train:
        state_dict["model"] = {
            "generator": model["generator"].module.state_dict(),
            "discriminator": model["discriminator"].module.state_dict(),
        }
    else:
        state_dict["model"] = {
            "generator": model["generator"].state_dict(),
            "discriminator": model["discriminator"].state_dict(),
        }
    # exist_ok=True avoids the check-then-create race; the guard avoids
    # os.makedirs("") raising when checkpoint_path has no directory component.
    dirname = os.path.dirname(checkpoint_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    torch.save(state_dict, checkpoint_path)
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, load_only_params=False, dst_train=False):
    """Load checkpoint.

    Args:
        model (dict): {"generator": ..., "discriminator": ...} modules to load into.
        optimizer (dict): {"generator": ..., "discriminator": ...} optimizers.
        scheduler (dict): {"generator": ..., "discriminator": ...} LR schedulers.
        checkpoint_path (str): Checkpoint path to be loaded.
        load_only_params (bool): Whether to load only model parameters
            (optimizer and scheduler state are skipped).
        dst_train (bool): If True, the models are wrapped, so load into `.module`.

    Returns:
        Tuple[int, int]: (steps, epochs) stored in the checkpoint.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    if dst_train:
        model["generator"].module.load_state_dict(
            state_dict["model"]["generator"]
        )
        model["discriminator"].module.load_state_dict(
            state_dict["model"]["discriminator"]
        )
    else:
        model["generator"].load_state_dict(state_dict["model"]["generator"])
        model["discriminator"].load_state_dict(
            state_dict["model"]["discriminator"]
        )
    # BUGFIX: `load_only_params` was previously ignored -- optimizer and
    # scheduler state were always restored. Honor the documented contract.
    if not load_only_params:
        optimizer["generator"].load_state_dict(
            state_dict["optimizer"]["generator"]
        )
        optimizer["discriminator"].load_state_dict(
            state_dict["optimizer"]["discriminator"]
        )
        scheduler["generator"].load_state_dict(
            state_dict["scheduler"]["generator"]
        )
        scheduler["discriminator"].load_state_dict(
            state_dict["scheduler"]["discriminator"]
        )
    steps = state_dict["steps"]
    epochs = state_dict["epochs"]
    return steps, epochs