# Spaces: Sleeping  (page-status banner captured along with this file; not program text)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os | |
import torch | |
import time | |
import accelerate | |
import random | |
import numpy as np | |
from tqdm import tqdm | |
from accelerate.logging import get_logger | |
from torch.utils.data import DataLoader | |
from abc import abstractmethod | |
from pathlib import Path | |
from utils.io import save_audio | |
from utils.util import load_config | |
from models.vocoders.vocoder_inference import synthesis | |
class TTSInference(object):
    r"""Base class for text-to-speech inference.

    The constructor sets up an ``accelerate.Accelerator``, a logger, the random
    seed, the (optional) test dataloader, the model and its checkpoint.
    Subclasses are expected to override the hook methods
    ``_build_test_dataset``, ``_build_model``, ``_inference_each_batch`` and
    ``inference_for_single_utterance``; the versions defined here are no-op
    placeholders that return ``None``.
    """

    def __init__(self, args=None, cfg=None):
        r"""Prepare everything needed to run inference.

        Args:
            args: Parsed command-line namespace. The attributes read here are
                ``mode``, ``acoustics_dir``, ``checkpoint_path``,
                ``vocoder_dir``, ``output_dir`` and ``log_level``.
            cfg: Experiment configuration. ``cfg.train.random_seed`` is read
                here; other fields are read by the other methods.

        Raises:
            ValueError: If ``args`` or ``cfg`` is ``None``.
        """
        super().__init__()

        # Fail early with a clear message instead of an AttributeError below.
        if args is None or cfg is None:
            raise ValueError("Both `args` and `cfg` must be provided.")

        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

        # Experiment directory: either given directly, or two levels above the
        # checkpoint path. Always defined so later attribute access is safe.
        if args.acoustics_dir is not None:
            self.exp_dir = args.acoustics_dir
        elif args.checkpoint_path is not None:
            self.exp_dir = os.path.dirname(os.path.dirname(args.checkpoint_path))
        else:
            self.exp_dir = None

        # Init accelerator
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()
        self.device = self.accelerator.device

        # Get logger
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\tNew inference process started.\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")

        self.acoustic_model_dir = args.acoustics_dir
        self.logger.debug(f"Acoustic model dir: {args.acoustics_dir}")

        if args.vocoder_dir is not None:
            self.vocoder_dir = args.vocoder_dir
            self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")

        os.makedirs(args.output_dir, exist_ok=True)

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Setup data loader (only batch mode iterates over a dataset)
        if self.infer_type == "batch":
            with self.accelerator.main_process_first():
                self.logger.info("Building dataset...")
                start = time.monotonic_ns()
                self.test_dataloader = self._build_test_dataloader()
                end = time.monotonic_ns()
                self.logger.info(
                    f"Building dataset done in {(end - start) / 1e6:.2f}ms"
                )

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # Init with accelerate. The accelerator created above is reused: nothing
        # has been prepared on it yet, so a second instance is not needed.
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self.model = self.accelerator.prepare(self.model)
        if self.infer_type == "batch":
            self.test_dataloader = self.accelerator.prepare(self.test_dataloader)
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        # Load checkpoint (must happen after ``accelerator.prepare``)
        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            if args.acoustics_dir is not None:
                self._load_model(
                    checkpoint_dir=os.path.join(args.acoustics_dir, "checkpoint")
                )
            elif args.checkpoint_path is not None:
                self._load_model(checkpoint_path=args.checkpoint_path)
            else:
                # Kept non-fatal as before, but reported through the logger.
                self.logger.error(
                    "Either checkpoint dir or checkpoint path should be provided."
                )
            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()

    def _build_test_dataset(self):
        r"""Hook: return a ``(dataset_class, collate_class)`` pair.

        ``_build_test_dataloader`` calls the first as ``dataset(args, cfg)`` and
        the second as ``collate(cfg)``. This base version returns ``None``.
        """
        pass

    def _build_model(self):
        r"""Hook: build and return the model. This base version returns ``None``."""
        pass

    # TODO: LEGACY CODE
    def _build_test_dataloader(self):
        r"""Build the dataloader used in batch mode.

        Stores ``test_dataset``, ``test_collate`` and ``test_batch_size`` on the
        instance. The batch size is capped by the number of metadata entries.

        Returns:
            A ``torch.utils.data.DataLoader`` over the test dataset, unshuffled.
        """
        dataset_cls, collate_cls = self._build_test_dataset()
        self.test_dataset = dataset_cls(self.args, self.cfg)
        self.test_collate = collate_cls(self.cfg)
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        return DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        old_mode: bool = False,
    ):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.

        Args:
            checkpoint_dir: Directory whose entries are candidate checkpoints.
                Entries with ``"audio"`` in their path are ignored.
            checkpoint_path: Explicit checkpoint to load; takes precedence.
            old_mode: Unused; kept for interface compatibility.

        Returns:
            The path of the checkpoint that was loaded, as a string.

        Raises:
            FileNotFoundError: If ``checkpoint_dir`` contains no candidate.
        """
        if checkpoint_path is None:
            assert (
                checkpoint_dir is not None
            ), "Either checkpoint_dir or checkpoint_path must be given."
            candidates = [
                str(entry)
                for entry in Path(checkpoint_dir).glob("*")
                if "audio" not in str(entry)
            ]
            if not candidates:
                raise FileNotFoundError(
                    f"No checkpoint found in directory: {checkpoint_dir}"
                )
            # The sort key is the integer after the last "-" in the third
            # "_"-separated field from the end of the path; highest wins.
            checkpoint_path = max(
                candidates,
                key=lambda x: int(x.split("_")[-3].split("-")[-1]),
            )

        self.accelerator.load_state(str(checkpoint_path))
        return str(checkpoint_path)

    def inference(self):
        r"""Run inference in the configured mode and write the audio to disk.

        ``"single"`` mode writes ``<output_dir>/single/test_pred.wav``;
        ``"batch"`` mode writes one ``<Uid>.wav`` per metadata entry under
        ``<output_dir>/batch`` and removes a same-named ``.pt`` file there if
        one exists.

        Raises:
            ValueError: If ``self.infer_type`` is neither of the two modes.
        """
        sample_rate = self.cfg.preprocess.sample_rate

        if self.infer_type == "single":
            out_dir = os.path.join(self.args.output_dir, "single")
            os.makedirs(out_dir, exist_ok=True)

            pred_audio = self.inference_for_single_utterance()
            save_path = os.path.join(out_dir, "test_pred.wav")
            save_audio(save_path, pred_audio, sample_rate)

        elif self.infer_type == "batch":
            out_dir = os.path.join(self.args.output_dir, "batch")
            os.makedirs(out_dir, exist_ok=True)

            pred_audio_list = self.inference_for_batches()
            for item, wav in zip(self.test_dataset.metadata, pred_audio_list):
                uid = item["Uid"]
                save_audio(
                    os.path.join(out_dir, f"{uid}.wav"),
                    wav.numpy(),
                    sample_rate,
                    add_silence=True,
                    turn_up=True,
                )
                tmp_file = os.path.join(out_dir, f"{uid}.pt")
                if os.path.exists(tmp_file):
                    os.remove(tmp_file)

        else:
            raise ValueError(f"Unsupported inference mode: {self.infer_type!r}")

        print("Saved to: ", out_dir)

    def inference_for_batches(self):
        r"""Predict features for every batch, vocode them and save the audio.

        Per-utterance predictions are first stored as ``<Uid>.pt`` in
        ``args.output_dir``, then reloaded and passed to the vocoder. The
        resulting waveforms are written as ``<Uid>.wav`` in the same directory.

        Returns:
            The sequence returned by ``synthesis`` (one item per metadata
            entry), so that ``inference()`` can consume it.
        """
        output_dir = self.args.output_dir
        metadata = self.test_dataset.metadata
        batch_size = self.test_batch_size

        for batch_idx, batch in enumerate(tqdm(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)
            pred_chunks = y_pred.chunk(batch_size)
            len_chunks = mel_lens.chunk(batch_size)
            for offset, (pred, length) in enumerate(zip(pred_chunks, len_chunks)):
                # Trim to the predicted length before saving.
                pred = pred.squeeze(0)[: length.item()].detach().cpu()
                uid = metadata[batch_idx * batch_size + offset]["Uid"]
                torch.save(pred, os.path.join(output_dir, f"{uid}.pt"))

        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(output_dir, "{}.pt".format(item["Uid"]))
                ).numpy()
                for item in metadata
            ],
        )

        # Use the configured sample rate, as ``inference()`` does, rather than
        # a hard-coded value.
        # NOTE(review): assumes the vocoder output rate equals
        # cfg.preprocess.sample_rate - confirm for mismatched vocoders.
        sample_rate = self.cfg.preprocess.sample_rate
        for item, wav in zip(metadata, res):
            uid = item["Uid"]
            save_audio(
                os.path.join(output_dir, f"{uid}.wav"),
                wav.numpy(),
                sample_rate,
                add_silence=True,
                turn_up=True,
            )

        return res

    def _inference_each_batch(self, batch_data):
        r"""Hook: run the model on one batch.

        ``inference_for_batches`` unpacks the result as
        ``(y_pred, mel_lens, _)``. This base version returns ``None``.
        """
        pass

    def inference_for_single_utterance(self, text=None):
        r"""Hook: synthesize one utterance and return its audio.

        ``inference()`` calls this without arguments, hence the default for
        ``text``. This base version returns ``None``.
        """
        pass

    def synthesis_by_vocoder(self, pred):
        r"""Vocode ``pred`` with the vocoder configured on the instance.

        Reads ``self.vocoder_cfg`` and ``self.checkpoint_dir_vocoder``, which
        are not assigned anywhere in this class and must be set beforehand.

        Returns:
            Whatever ``synthesis`` returns for the ``len(pred)`` predictions.
        """
        return synthesis(
            self.vocoder_cfg,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse vocoder config.

        Args:
            vocoder_dir: Directory holding ``args.json`` and ``*.pt`` files
                whose stems are integers.

        Returns:
            ``(vocoder_cfg, ckpt_path)``: the loaded config and the path of the
            ``*.pt`` file with the largest integer stem.

        Raises:
            FileNotFoundError: If the directory contains no ``*.pt`` file.
        """
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = list(Path(vocoder_dir).glob("*.pt"))
        if not ckpt_list:
            raise FileNotFoundError(
                f"No vocoder checkpoint (*.pt) found in: {vocoder_dir}"
            )
        ckpt_path = str(max(ckpt_list, key=lambda x: int(x.stem)))
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)