|
r"""
Computes the mel-spectrogram mean and standard deviation over a dataset and
stores them in a JSON file, so the model can load them for data normalisation
when needed.

Configuration is read from the Hydra yaml config under configs/data.
"""
|
import argparse |
|
import json |
|
import os |
|
import sys |
|
from pathlib import Path |
|
|
|
import rootutils |
|
import torch |
|
from hydra import compose, initialize |
|
from omegaconf import open_dict |
|
from tqdm.auto import tqdm |
|
|
|
from matcha.data.text_mel_datamodule import TextMelDataModule |
|
from matcha.utils.logging_utils import pylogger |
|
|
|
log = pylogger.get_pylogger(__name__) |
|
|
|
|
|
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int) -> dict:
    """Compute the global mel-spectrogram mean and standard deviation.

    Accumulates sums over the whole dataset in a single pass and derives the
    std from E[x^2] - E[x]^2, which avoids storing all mels in memory.

    Args:
        data_loader (torch.utils.data.DataLoader): yields batches with keys
            "y" (padded mel tensors) and "y_lengths" (unpadded frame counts).
        out_channels (int): number of mel-spectrogram channels (n_feats).

    Returns:
        dict: {"mel_mean": float, "mel_std": float} for use in normalisation.
    """
    total_mel_sum = 0
    total_mel_sq_sum = 0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"]
        mel_lengths = batch["y_lengths"]

        # Padded regions are zero, so summing the full tensor is safe as long
        # as the element count comes from the true lengths, not the padded shape.
        total_mel_len += torch.sum(mel_lengths)
        total_mel_sum += torch.sum(mels)
        total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

    # Number of valid elements = total frames * channels per frame.
    data_mean = total_mel_sum / (total_mel_len * out_channels)
    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))

    return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
|
|
|
|
|
def main():
    """Parse CLI args, load the Hydra data config, and write mel statistics.

    Reads the yaml config named by --input-config from configs/data, builds
    the TextMelDataModule, computes mean/std over the training split, and
    writes them to <config-stem>.json in the current directory.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        # Use an int default: argparse only coerces string defaults by
        # accident of implementation; the declared type should match.
        default=256,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    # Refuse to clobber previous statistics unless explicitly forced.
    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # Strip hydra bookkeeping and the datamodule _target_ so the dict can
        # be splatted directly into the TextMelDataModule constructor.
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["data_statistics"] = None  # must be unset while computing them
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        # Filelist paths in the config are relative to the project root.
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)

    # Use a context manager so the file is flushed and closed even on error
    # (the original open() leaked the handle).
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(params, f)
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|