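"""Training entry point for the FastSVC singing voice conversion model.

Example invocation (a sketch; the file name train.py and all paths below are
placeholders, adjust them to your setup):

    python train.py --data_root /path/to/dataset --config config.json --cp_path checkpoints/

For multi-GPU training the script reads LOCAL_RANK from the environment and uses the
NCCL backend, so it is meant to be started through a launcher such as torchrun:

    torchrun --nproc_per_node=2 train.py --data_root ... --config ... --cp_path ...
"""
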
import argparse
import logging
import os
import sys
import json
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.tools import ConfigWrapper
from dataset.dataset import SVCDataset
from modules.FastSVC import SVCNN
from modules.discriminator import MelGANMultiScaleDiscriminator
from optimizers.scheduler import StepLRScheduler
from loss.adversarial_loss import GeneratorAdversarialLoss
from loss.adversarial_loss import DiscriminatorAdversarialLoss
from loss.stft_loss import MultiResolutionSTFTLoss
from trainer import Trainer


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train the FastSVC model."
    )
    parser.add_argument(
        "--data_root",
        type=str,
        required=True,
        help="dataset root path.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="configuration file path.",
    )
    parser.add_argument(
        "--cp_path",
        required=True,
        type=str,
        nargs="?",
        help="checkpoint directory path.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default=False,
        # a boolean flag: type=bool would treat any non-empty string (even "False") as True
        action="store_true",
        help="whether to resume training from a checkpoint (see --pretrain / --cp_path).",
    )
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="random seed.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    local_rank = 0
    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        args.world_size = torch.cuda.device_count()
        args.distributed = args.world_size > 1
        if args.distributed:
            local_rank = int(os.environ["LOCAL_RANK"])
            torch.cuda.set_device(local_rank)
            print("Using multiple GPUs for training. n_GPU=%d." % args.world_size)
            torch.distributed.init_process_group(backend="nccl")
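    # NOTE: the multi-GPU branch above expects the launcher (e.g. torchrun) to set
    # LOCAL_RANK; with a single visible GPU, training stays single-process.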

    # random seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # suppress logging for distributed training
    if local_rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # load config
    with open(args.config) as f:
        config = ConfigWrapper(**json.load(f))
    config.training_config.rank = local_rank
    config.training_config.distributed = args.distributed
    config.interval_config.out_dir = args.cp_path
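    # The rank, distributed flag, and output directory are stashed on the config
    # object, which is later handed to the Trainer.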

    # get dataset
    train_set = SVCDataset(
        args.data_root,
        config.data_config.n_samples,
        config.data_config.sampling_rate,
        config.data_config.hop_size,
        "train",
    )
    valid_set = SVCDataset(
        args.data_root,
        config.data_config.n_samples,
        config.data_config.sampling_rate,
        config.data_config.hop_size,
        "valid",
    )
    dataset = {
        "train": train_set,
        "dev": valid_set,
    }

    # get data loader
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler

        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=local_rank,
            shuffle=True,
        )
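    # Only the training set is sharded across ranks; the dev loader below keeps
    # sampler=None, so every process iterates the full validation set.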
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=not args.distributed,
            batch_size=config.data_config.batch_size,
            num_workers=config.data_config.num_workers,
            sampler=sampler["train"],
            pin_memory=config.data_config.pin_memory,
            drop_last=True,
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False,
            batch_size=config.data_config.batch_size,
            num_workers=config.data_config.num_workers,
            sampler=sampler["dev"],
            pin_memory=config.data_config.pin_memory,
        ),
    }
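    # Note: drop_last=True above keeps every training batch the same size, while the
    # dev loader keeps its final (possibly smaller) batch.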

    # define models
    svc_mdl = SVCNN(config).to(device)
    discriminator = MelGANMultiScaleDiscriminator().to(device)
    model = {
        "generator": svc_mdl,
        "discriminator": discriminator,
    }

    # define loss criteria
    criterion = {
        "gen_adv": GeneratorAdversarialLoss(
            # keep compatibility
            **config.loss_config.generator_adv_loss_params
        ).to(device),
        "dis_adv": DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.loss_config.discriminator_adv_loss_params
        ).to(device),
    }
    criterion["stft"] = MultiResolutionSTFTLoss(
        **config.loss_config.stft,
    ).to(device)

    # define optimizers and schedulers
    optimizer = {
        "generator": torch.optim.Adam(
            model["generator"].parameters(), lr=config.optimizer_config.lr
        ),
        "discriminator": torch.optim.Adam(
            model["discriminator"].parameters(), lr=config.optimizer_config.lr
        ),
    }
    scheduler = {
        "generator": StepLRScheduler(
            optimizer["generator"],
            step_size=config.optimizer_config.scheduler_step_size,
            gamma=config.optimizer_config.scheduler_gamma,
        ),
        "discriminator": StepLRScheduler(
            optimizer["discriminator"],
            step_size=config.optimizer_config.scheduler_step_size,
            gamma=config.optimizer_config.scheduler_gamma,
        ),
    }

    if args.distributed:
        from torch.nn.parallel import DistributedDataParallel

        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if args.resume:
        if args.pretrain != "":
            trainer.load_checkpoint(args.pretrain, load_only_params=False, dst_train=args.distributed)
            logging.info(f"Successfully loaded parameters from {args.pretrain}.")
        else:
            # no explicit checkpoint given: fall back to the newest checkpoint in --cp_path
            if os.path.isdir(args.cp_path):
                dir_files = os.listdir(args.cp_path)
                cp_files = [fname for fname in dir_files if fname.startswith("checkpoint-")]
                if len(cp_files) == 0:
                    logging.info("No existing checkpoints found. Training from scratch...")
                else:
                    cp_files.sort(key=lambda fname: os.path.getmtime(os.path.join(args.cp_path, fname)))
                    latest_cp = os.path.join(args.cp_path, cp_files[-1])
                    trainer.load_checkpoint(latest_cp, load_only_params=False, dst_train=args.distributed)
                    logging.info(f"No checkpoint specified; resuming from the latest one. Successfully loaded parameters from {latest_cp}.")
            else:
                logging.info("Checkpoint directory does not exist. Training from scratch...")

    # run training loop
    try:
        trainer.run()
    finally:
        # always save a final checkpoint, even if training stops early
        trainer.save_checkpoint(
            os.path.join(config.interval_config.out_dir, f"checkpoint-{trainer.steps}steps.pkl"),
            args.distributed,
        )
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()