Spaces:
Build error
Build error
ddd
commited on
Commit
·
b93970c
1
Parent(s):
aee7e5a
Add application file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +21 -0
- README.md +83 -12
- checkpoints/.gitkeep +0 -0
- configs/config_base.yaml +42 -0
- configs/singing/base.yaml +42 -0
- configs/singing/fs2.yaml +3 -0
- configs/tts/base.yaml +95 -0
- configs/tts/base_zh.yaml +3 -0
- configs/tts/fs2.yaml +80 -0
- configs/tts/hifigan.yaml +21 -0
- configs/tts/lj/base_mel2wav.yaml +3 -0
- configs/tts/lj/base_text2mel.yaml +13 -0
- configs/tts/lj/fs2.yaml +3 -0
- configs/tts/lj/hifigan.yaml +3 -0
- configs/tts/lj/pwg.yaml +3 -0
- configs/tts/pwg.yaml +110 -0
- data/processed/ljspeech/dict.txt +77 -0
- data/processed/ljspeech/metadata_phone.csv +0 -0
- data/processed/ljspeech/mfa_dict.txt +0 -0
- data/processed/ljspeech/phone_set.json +1 -0
- data_gen/singing/binarize.py +398 -0
- data_gen/tts/base_binarizer.py +224 -0
- data_gen/tts/bin/binarize.py +20 -0
- data_gen/tts/binarizer_zh.py +59 -0
- data_gen/tts/data_gen_utils.py +347 -0
- data_gen/tts/txt_processors/base_text_processor.py +8 -0
- data_gen/tts/txt_processors/en.py +78 -0
- data_gen/tts/txt_processors/zh.py +41 -0
- data_gen/tts/txt_processors/zh_g2pM.py +72 -0
- docs/README-SVS-opencpop-cascade.md +111 -0
- docs/README-SVS-opencpop-e2e.md +106 -0
- docs/README-SVS-popcs.md +63 -0
- docs/README-SVS.md +44 -0
- docs/README-TTS.md +63 -0
- docs/README-zh.md +212 -0
- inference/svs/base_svs_infer.py +265 -0
- inference/svs/ds_cascade.py +54 -0
- inference/svs/ds_e2e.py +67 -0
- inference/svs/gradio/gradio_settings.yaml +19 -0
- inference/svs/gradio/infer.py +91 -0
- inference/svs/opencpop/cpop_pinyin2ph.txt +418 -0
- inference/svs/opencpop/map.py +8 -0
- modules/__init__.py +0 -0
- modules/commons/common_layers.py +668 -0
- modules/commons/espnet_positional_embedding.py +113 -0
- modules/commons/ssim.py +391 -0
- modules/diffsinger_midi/fs2.py +118 -0
- modules/fastspeech/fs2.py +255 -0
- modules/fastspeech/pe.py +149 -0
- modules/fastspeech/tts_modules.py +357 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 Jinglin Liu
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,83 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
|
2 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
+
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
+
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
| [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
|
6 |
+
|
7 |
+
This repository is the official PyTorch implementation of our AAAI-2022 [paper](https://arxiv.org/abs/2105.02446), in which we propose DiffSinger (for Singing-Voice-Synthesis) and DiffSpeech (for Text-to-Speech).
|
8 |
+
|
9 |
+
<table style="width:100%">
|
10 |
+
<tr>
|
11 |
+
<th>DiffSinger/DiffSpeech at training</th>
|
12 |
+
<th>DiffSinger/DiffSpeech at inference</th>
|
13 |
+
</tr>
|
14 |
+
<tr>
|
15 |
+
<td><img src="resources/model_a.png" alt="Training" height="300"></td>
|
16 |
+
<td><img src="resources/model_b.png" alt="Inference" height="300"></td>
|
17 |
+
</tr>
|
18 |
+
</table>
|
19 |
+
|
20 |
+
:tada: :tada: :tada: **Updates**:
|
21 |
+
- Mar.2, 2022: [MIDI-new-version](docs/README-SVS-opencpop-e2e.md): A substantial improvement :sparkles:
|
22 |
+
- Mar.1, 2022: [NeuralSVB](https://github.com/MoonInTheRiver/NeuralSVB), for singing voice beautifying, has been released :sparkles: :sparkles: :sparkles: .
|
23 |
+
- Feb.13, 2022: [NATSpeech](https://github.com/NATSpeech/NATSpeech), the improved code framework, which contains the implementations of DiffSpeech and our NeurIPS-2021 work [PortaSpeech](https://openreview.net/forum?id=xmJsuh8xlq) has been released :sparkles: :sparkles: :sparkles:.
|
24 |
+
- Jan.29, 2022: support [MIDI-old-version](docs/README-SVS-opencpop-cascade.md) SVS. :construction: :pick: :hammer_and_wrench:
|
25 |
+
- Jan.13, 2022: support SVS, release PopCS dataset.
|
26 |
+
- Dec.19, 2021: support TTS. [HuggingFace🤗 Demo](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
|
27 |
+
|
28 |
+
:rocket: **News**:
|
29 |
+
- Feb.24, 2022: Our new work, NeuralSVB was accepted by ACL-2022 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2202.13277). [Demo Page](https://neuralsvb.github.io).
|
30 |
+
- Dec.01, 2021: DiffSinger was accepted by AAAI-2022.
|
31 |
+
- Sep.29, 2021: Our recent work `PortaSpeech: Portable and High-Quality Generative Text-to-Speech` was accepted by NeurIPS-2021 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2109.15166) .
|
32 |
+
- May.06, 2021: We submitted DiffSinger to Arxiv [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446).
|
33 |
+
|
34 |
+
## Environments
|
35 |
+
```sh
|
36 |
+
conda create -n your_env_name python=3.8
|
37 |
+
source activate your_env_name
|
38 |
+
pip install -r requirements_2080.txt (GPU 2080Ti, CUDA 10.2)
|
39 |
+
or pip install -r requirements_3090.txt (GPU 3090, CUDA 11.4)
|
40 |
+
```
|
41 |
+
|
42 |
+
## Documents
|
43 |
+
- [Run DiffSpeech (TTS version)](docs/README-TTS.md).
|
44 |
+
- [Run DiffSinger (SVS version)](docs/README-SVS.md).
|
45 |
+
|
46 |
+
## Tensorboard
|
47 |
+
```sh
|
48 |
+
tensorboard --logdir_spec exp_name
|
49 |
+
```
|
50 |
+
<table style="width:100%">
|
51 |
+
<tr>
|
52 |
+
<td><img src="resources/tfb.png" alt="Tensorboard" height="250"></td>
|
53 |
+
</tr>
|
54 |
+
</table>
|
55 |
+
|
56 |
+
## Audio Demos
|
57 |
+
Old audio samples can be found in our [demo page](https://diffsinger.github.io/). Audio samples generated by this repository are listed here:
|
58 |
+
|
59 |
+
### TTS audio samples
|
60 |
+
Speech samples (test set of LJSpeech) can be found in [resources/demos_1213](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/demos_1213).
|
61 |
+
|
62 |
+
### SVS audio samples
|
63 |
+
Singing samples (test set of PopCS) can be found in [resources/demos_0112](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/demos_0112).
|
64 |
+
|
65 |
+
## Citation
|
66 |
+
@article{liu2021diffsinger,
|
67 |
+
title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
|
68 |
+
author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
|
69 |
+
journal={arXiv preprint arXiv:2105.02446},
|
70 |
+
volume={2},
|
71 |
+
year={2021}}
|
72 |
+
|
73 |
+
|
74 |
+
## Acknowledgements
|
75 |
+
Our codes are based on the following repos:
|
76 |
+
* [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch)
|
77 |
+
* [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
|
78 |
+
* [ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
|
79 |
+
* [HifiGAN](https://github.com/jik876/hifi-gan)
|
80 |
+
* [espnet](https://github.com/espnet/espnet)
|
81 |
+
* [DiffWave](https://github.com/lmnt-com/diffwave)
|
82 |
+
|
83 |
+
Also thanks [Keon Lee](https://github.com/keonlee9420/DiffSinger) for fast implementation of our work.
|
checkpoints/.gitkeep
ADDED
File without changes
|
configs/config_base.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# task
|
2 |
+
binary_data_dir: ''
|
3 |
+
work_dir: '' # experiment directory.
|
4 |
+
infer: false # infer
|
5 |
+
seed: 1234
|
6 |
+
debug: false
|
7 |
+
save_codes:
|
8 |
+
- configs
|
9 |
+
- modules
|
10 |
+
- tasks
|
11 |
+
- utils
|
12 |
+
- usr
|
13 |
+
|
14 |
+
#############
|
15 |
+
# dataset
|
16 |
+
#############
|
17 |
+
ds_workers: 1
|
18 |
+
test_num: 100
|
19 |
+
valid_num: 100
|
20 |
+
endless_ds: false
|
21 |
+
sort_by_len: true
|
22 |
+
|
23 |
+
#########
|
24 |
+
# train and eval
|
25 |
+
#########
|
26 |
+
load_ckpt: ''
|
27 |
+
save_ckpt: true
|
28 |
+
save_best: false
|
29 |
+
num_ckpt_keep: 3
|
30 |
+
clip_grad_norm: 0
|
31 |
+
accumulate_grad_batches: 1
|
32 |
+
log_interval: 100
|
33 |
+
num_sanity_val_steps: 5 # steps of validation at the beginning
|
34 |
+
check_val_every_n_epoch: 10
|
35 |
+
val_check_interval: 2000
|
36 |
+
max_epochs: 1000
|
37 |
+
max_updates: 160000
|
38 |
+
max_tokens: 31250
|
39 |
+
max_sentences: 100000
|
40 |
+
max_eval_tokens: -1
|
41 |
+
max_eval_sentences: -1
|
42 |
+
test_input_dir: ''
|
configs/singing/base.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/base.yaml
|
3 |
+
- configs/tts/base_zh.yaml
|
4 |
+
|
5 |
+
|
6 |
+
datasets: []
|
7 |
+
test_prefixes: []
|
8 |
+
test_num: 0
|
9 |
+
valid_num: 0
|
10 |
+
|
11 |
+
pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
|
12 |
+
binarizer_cls: data_gen.singing.binarize.SingingBinarizer
|
13 |
+
pre_align_args:
|
14 |
+
use_tone: false # for ZH
|
15 |
+
forced_align: mfa
|
16 |
+
use_sox: true
|
17 |
+
hop_size: 128 # Hop size.
|
18 |
+
fft_size: 512 # FFT size.
|
19 |
+
win_size: 512 # FFT size.
|
20 |
+
max_frames: 8000
|
21 |
+
fmin: 50 # Minimum freq in mel basis calculation.
|
22 |
+
fmax: 11025 # Maximum frequency in mel basis calculation.
|
23 |
+
pitch_type: frame
|
24 |
+
|
25 |
+
hidden_size: 256
|
26 |
+
mel_loss: "ssim:0.5|l1:0.5"
|
27 |
+
lambda_f0: 0.0
|
28 |
+
lambda_uv: 0.0
|
29 |
+
lambda_energy: 0.0
|
30 |
+
lambda_ph_dur: 0.0
|
31 |
+
lambda_sent_dur: 0.0
|
32 |
+
lambda_word_dur: 0.0
|
33 |
+
predictor_grad: 0.0
|
34 |
+
use_spk_embed: true
|
35 |
+
use_spk_id: false
|
36 |
+
|
37 |
+
max_tokens: 20000
|
38 |
+
max_updates: 400000
|
39 |
+
num_spk: 100
|
40 |
+
save_f0: true
|
41 |
+
use_gt_dur: true
|
42 |
+
use_gt_f0: true
|
configs/singing/fs2.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/fs2.yaml
|
3 |
+
- configs/singing/base.yaml
|
configs/tts/base.yaml
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# task
|
2 |
+
base_config: configs/config_base.yaml
|
3 |
+
task_cls: ''
|
4 |
+
#############
|
5 |
+
# dataset
|
6 |
+
#############
|
7 |
+
raw_data_dir: ''
|
8 |
+
processed_data_dir: ''
|
9 |
+
binary_data_dir: ''
|
10 |
+
dict_dir: ''
|
11 |
+
pre_align_cls: ''
|
12 |
+
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
|
13 |
+
pre_align_args:
|
14 |
+
use_tone: true # for ZH
|
15 |
+
forced_align: mfa
|
16 |
+
use_sox: false
|
17 |
+
txt_processor: en
|
18 |
+
allow_no_txt: false
|
19 |
+
denoise: false
|
20 |
+
binarization_args:
|
21 |
+
shuffle: false
|
22 |
+
with_txt: true
|
23 |
+
with_wav: false
|
24 |
+
with_align: true
|
25 |
+
with_spk_embed: true
|
26 |
+
with_f0: true
|
27 |
+
with_f0cwt: true
|
28 |
+
|
29 |
+
loud_norm: false
|
30 |
+
endless_ds: true
|
31 |
+
reset_phone_dict: true
|
32 |
+
|
33 |
+
test_num: 100
|
34 |
+
valid_num: 100
|
35 |
+
max_frames: 1550
|
36 |
+
max_input_tokens: 1550
|
37 |
+
audio_num_mel_bins: 80
|
38 |
+
audio_sample_rate: 22050
|
39 |
+
hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
|
40 |
+
win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
|
41 |
+
fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
42 |
+
fmax: 7600 # To be increased/reduced depending on data.
|
43 |
+
fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter
|
44 |
+
min_level_db: -100
|
45 |
+
num_spk: 1
|
46 |
+
mel_vmin: -6
|
47 |
+
mel_vmax: 1.5
|
48 |
+
ds_workers: 4
|
49 |
+
|
50 |
+
#########
|
51 |
+
# model
|
52 |
+
#########
|
53 |
+
dropout: 0.1
|
54 |
+
enc_layers: 4
|
55 |
+
dec_layers: 4
|
56 |
+
hidden_size: 384
|
57 |
+
num_heads: 2
|
58 |
+
prenet_dropout: 0.5
|
59 |
+
prenet_hidden_size: 256
|
60 |
+
stop_token_weight: 5.0
|
61 |
+
enc_ffn_kernel_size: 9
|
62 |
+
dec_ffn_kernel_size: 9
|
63 |
+
ffn_act: gelu
|
64 |
+
ffn_padding: 'SAME'
|
65 |
+
|
66 |
+
|
67 |
+
###########
|
68 |
+
# optimization
|
69 |
+
###########
|
70 |
+
lr: 2.0
|
71 |
+
warmup_updates: 8000
|
72 |
+
optimizer_adam_beta1: 0.9
|
73 |
+
optimizer_adam_beta2: 0.98
|
74 |
+
weight_decay: 0
|
75 |
+
clip_grad_norm: 1
|
76 |
+
|
77 |
+
|
78 |
+
###########
|
79 |
+
# train and eval
|
80 |
+
###########
|
81 |
+
max_tokens: 30000
|
82 |
+
max_sentences: 100000
|
83 |
+
max_eval_sentences: 1
|
84 |
+
max_eval_tokens: 60000
|
85 |
+
train_set_name: 'train'
|
86 |
+
valid_set_name: 'valid'
|
87 |
+
test_set_name: 'test'
|
88 |
+
vocoder: pwg
|
89 |
+
vocoder_ckpt: ''
|
90 |
+
profile_infer: false
|
91 |
+
out_wav_norm: false
|
92 |
+
save_gt: false
|
93 |
+
save_f0: false
|
94 |
+
gen_dir_name: ''
|
95 |
+
use_denoise: false
|
configs/tts/base_zh.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pre_align_args:
|
2 |
+
txt_processor: zh_g2pM
|
3 |
+
binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer
|
configs/tts/fs2.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/tts/base.yaml
|
2 |
+
task_cls: tasks.tts.fs2.FastSpeech2Task
|
3 |
+
|
4 |
+
# model
|
5 |
+
hidden_size: 256
|
6 |
+
dropout: 0.1
|
7 |
+
encoder_type: fft # fft|tacotron|tacotron2|conformer
|
8 |
+
encoder_K: 8 # for tacotron encoder
|
9 |
+
decoder_type: fft # fft|rnn|conv|conformer
|
10 |
+
use_pos_embed: true
|
11 |
+
|
12 |
+
# duration
|
13 |
+
predictor_hidden: -1
|
14 |
+
predictor_kernel: 5
|
15 |
+
predictor_layers: 2
|
16 |
+
dur_predictor_kernel: 3
|
17 |
+
dur_predictor_layers: 2
|
18 |
+
predictor_dropout: 0.5
|
19 |
+
|
20 |
+
# pitch and energy
|
21 |
+
use_pitch_embed: true
|
22 |
+
pitch_type: ph # frame|ph|cwt
|
23 |
+
use_uv: true
|
24 |
+
cwt_hidden_size: 128
|
25 |
+
cwt_layers: 2
|
26 |
+
cwt_loss: l1
|
27 |
+
cwt_add_f0_loss: false
|
28 |
+
cwt_std_scale: 0.8
|
29 |
+
|
30 |
+
pitch_ar: false
|
31 |
+
#pitch_embed_type: 0q
|
32 |
+
pitch_loss: 'l1' # l1|l2|ssim
|
33 |
+
pitch_norm: log
|
34 |
+
use_energy_embed: false
|
35 |
+
|
36 |
+
# reference encoder and speaker embedding
|
37 |
+
use_spk_id: false
|
38 |
+
use_split_spk_id: false
|
39 |
+
use_spk_embed: false
|
40 |
+
use_var_enc: false
|
41 |
+
lambda_commit: 0.25
|
42 |
+
ref_norm_layer: bn
|
43 |
+
pitch_enc_hidden_stride_kernel:
|
44 |
+
- 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
|
45 |
+
- 0,2,5
|
46 |
+
- 0,2,5
|
47 |
+
dur_enc_hidden_stride_kernel:
|
48 |
+
- 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
|
49 |
+
- 0,2,3
|
50 |
+
- 0,1,3
|
51 |
+
|
52 |
+
|
53 |
+
# mel
|
54 |
+
mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5
|
55 |
+
|
56 |
+
# loss lambda
|
57 |
+
lambda_f0: 1.0
|
58 |
+
lambda_uv: 1.0
|
59 |
+
lambda_energy: 0.1
|
60 |
+
lambda_ph_dur: 1.0
|
61 |
+
lambda_sent_dur: 1.0
|
62 |
+
lambda_word_dur: 1.0
|
63 |
+
predictor_grad: 0.1
|
64 |
+
|
65 |
+
# train and eval
|
66 |
+
pretrain_fs_ckpt: ''
|
67 |
+
warmup_updates: 2000
|
68 |
+
max_tokens: 32000
|
69 |
+
max_sentences: 100000
|
70 |
+
max_eval_sentences: 1
|
71 |
+
max_updates: 120000
|
72 |
+
num_valid_plots: 5
|
73 |
+
num_test_samples: 0
|
74 |
+
test_ids: []
|
75 |
+
use_gt_dur: false
|
76 |
+
use_gt_f0: false
|
77 |
+
|
78 |
+
# exp
|
79 |
+
dur_loss: mse # huber|mol
|
80 |
+
norm_type: gn
|
configs/tts/hifigan.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/tts/pwg.yaml
|
2 |
+
task_cls: tasks.vocoder.hifigan.HifiGanTask
|
3 |
+
resblock: "1"
|
4 |
+
adam_b1: 0.8
|
5 |
+
adam_b2: 0.99
|
6 |
+
upsample_rates: [ 8,8,2,2 ]
|
7 |
+
upsample_kernel_sizes: [ 16,16,4,4 ]
|
8 |
+
upsample_initial_channel: 128
|
9 |
+
resblock_kernel_sizes: [ 3,7,11 ]
|
10 |
+
resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ]
|
11 |
+
|
12 |
+
lambda_mel: 45.0
|
13 |
+
|
14 |
+
max_samples: 8192
|
15 |
+
max_sentences: 16
|
16 |
+
|
17 |
+
generator_params:
|
18 |
+
lr: 0.0002 # Generator's learning rate.
|
19 |
+
aux_context_window: 0 # Context window size for auxiliary feature.
|
20 |
+
discriminator_optimizer_params:
|
21 |
+
lr: 0.0002 # Discriminator's learning rate.
|
configs/tts/lj/base_mel2wav.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
raw_data_dir: 'data/raw/LJSpeech-1.1'
|
2 |
+
processed_data_dir: 'data/processed/ljspeech'
|
3 |
+
binary_data_dir: 'data/binary/ljspeech_wav'
|
configs/tts/lj/base_text2mel.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
raw_data_dir: 'data/raw/LJSpeech-1.1'
|
2 |
+
processed_data_dir: 'data/processed/ljspeech'
|
3 |
+
binary_data_dir: 'data/binary/ljspeech'
|
4 |
+
pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
|
5 |
+
|
6 |
+
pitch_type: cwt
|
7 |
+
mel_loss: l1
|
8 |
+
num_test_samples: 20
|
9 |
+
test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294,
|
10 |
+
316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ]
|
11 |
+
use_energy_embed: false
|
12 |
+
test_num: 523
|
13 |
+
valid_num: 348
|
configs/tts/lj/fs2.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/fs2.yaml
|
3 |
+
- configs/tts/lj/base_text2mel.yaml
|
configs/tts/lj/hifigan.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/hifigan.yaml
|
3 |
+
- configs/tts/lj/base_mel2wav.yaml
|
configs/tts/lj/pwg.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/pwg.yaml
|
3 |
+
- configs/tts/lj/base_mel2wav.yaml
|
configs/tts/pwg.yaml
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/tts/base.yaml
|
2 |
+
task_cls: tasks.vocoder.pwg.PwgTask
|
3 |
+
|
4 |
+
binarization_args:
|
5 |
+
with_wav: true
|
6 |
+
with_spk_embed: false
|
7 |
+
with_align: false
|
8 |
+
test_input_dir: ''
|
9 |
+
|
10 |
+
###########
|
11 |
+
# train and eval
|
12 |
+
###########
|
13 |
+
max_samples: 25600
|
14 |
+
max_sentences: 5
|
15 |
+
max_eval_sentences: 1
|
16 |
+
max_updates: 1000000
|
17 |
+
val_check_interval: 2000
|
18 |
+
|
19 |
+
|
20 |
+
###########################################################
|
21 |
+
# FEATURE EXTRACTION SETTING #
|
22 |
+
###########################################################
|
23 |
+
sampling_rate: 22050 # Sampling rate.
|
24 |
+
fft_size: 1024 # FFT size.
|
25 |
+
hop_size: 256 # Hop size.
|
26 |
+
win_length: null # Window length.
|
27 |
+
# If set to null, it will be the same as fft_size.
|
28 |
+
window: "hann" # Window function.
|
29 |
+
num_mels: 80 # Number of mel basis.
|
30 |
+
fmin: 80 # Minimum freq in mel basis calculation.
|
31 |
+
fmax: 7600 # Maximum frequency in mel basis calculation.
|
32 |
+
format: "hdf5" # Feature file format. "npy" or "hdf5" is supported.
|
33 |
+
|
34 |
+
###########################################################
|
35 |
+
# GENERATOR NETWORK ARCHITECTURE SETTING #
|
36 |
+
###########################################################
|
37 |
+
generator_params:
|
38 |
+
in_channels: 1 # Number of input channels.
|
39 |
+
out_channels: 1 # Number of output channels.
|
40 |
+
kernel_size: 3 # Kernel size of dilated convolution.
|
41 |
+
layers: 30 # Number of residual block layers.
|
42 |
+
stacks: 3 # Number of stacks i.e., dilation cycles.
|
43 |
+
residual_channels: 64 # Number of channels in residual conv.
|
44 |
+
gate_channels: 128 # Number of channels in gated conv.
|
45 |
+
skip_channels: 64 # Number of channels in skip conv.
|
46 |
+
aux_channels: 80 # Number of channels for auxiliary feature conv.
|
47 |
+
# Must be the same as num_mels.
|
48 |
+
aux_context_window: 2 # Context window size for auxiliary feature.
|
49 |
+
# If set to 2, previous 2 and future 2 frames will be considered.
|
50 |
+
dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
|
51 |
+
use_weight_norm: true # Whether to use weight norm.
|
52 |
+
# If set to true, it will be applied to all of the conv layers.
|
53 |
+
upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
|
54 |
+
upsample_params: # Upsampling network parameters.
|
55 |
+
upsample_scales: [4, 4, 4, 4] # Upsampling scales. Prodcut of these must be the same as hop size.
|
56 |
+
use_pitch_embed: false
|
57 |
+
|
58 |
+
###########################################################
|
59 |
+
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
|
60 |
+
###########################################################
|
61 |
+
discriminator_params:
|
62 |
+
in_channels: 1 # Number of input channels.
|
63 |
+
out_channels: 1 # Number of output channels.
|
64 |
+
kernel_size: 3 # Number of output channels.
|
65 |
+
layers: 10 # Number of conv layers.
|
66 |
+
conv_channels: 64 # Number of chnn layers.
|
67 |
+
bias: true # Whether to use bias parameter in conv.
|
68 |
+
use_weight_norm: true # Whether to use weight norm.
|
69 |
+
# If set to true, it will be applied to all of the conv layers.
|
70 |
+
nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
|
71 |
+
nonlinear_activation_params: # Nonlinear function parameters
|
72 |
+
negative_slope: 0.2 # Alpha in LeakyReLU.
|
73 |
+
|
74 |
+
###########################################################
|
75 |
+
# STFT LOSS SETTING #
|
76 |
+
###########################################################
|
77 |
+
stft_loss_params:
|
78 |
+
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
|
79 |
+
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
|
80 |
+
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
|
81 |
+
window: "hann_window" # Window function for STFT-based loss
|
82 |
+
use_mel_loss: false
|
83 |
+
|
84 |
+
###########################################################
|
85 |
+
# ADVERSARIAL LOSS SETTING #
|
86 |
+
###########################################################
|
87 |
+
lambda_adv: 4.0 # Loss balancing coefficient.
|
88 |
+
|
89 |
+
###########################################################
|
90 |
+
# OPTIMIZER & SCHEDULER SETTING #
|
91 |
+
###########################################################
|
92 |
+
generator_optimizer_params:
|
93 |
+
lr: 0.0001 # Generator's learning rate.
|
94 |
+
eps: 1.0e-6 # Generator's epsilon.
|
95 |
+
weight_decay: 0.0 # Generator's weight decay coefficient.
|
96 |
+
generator_scheduler_params:
|
97 |
+
step_size: 200000 # Generator's scheduler step size.
|
98 |
+
gamma: 0.5 # Generator's scheduler gamma.
|
99 |
+
# At each step size, lr will be multiplied by this parameter.
|
100 |
+
generator_grad_norm: 10 # Generator's gradient norm.
|
101 |
+
discriminator_optimizer_params:
|
102 |
+
lr: 0.00005 # Discriminator's learning rate.
|
103 |
+
eps: 1.0e-6 # Discriminator's epsilon.
|
104 |
+
weight_decay: 0.0 # Discriminator's weight decay coefficient.
|
105 |
+
discriminator_scheduler_params:
|
106 |
+
step_size: 200000 # Discriminator's scheduler step size.
|
107 |
+
gamma: 0.5 # Discriminator's scheduler gamma.
|
108 |
+
# At each step size, lr will be multiplied by this parameter.
|
109 |
+
discriminator_grad_norm: 1 # Discriminator's gradient norm.
|
110 |
+
disc_start_steps: 40000 # Number of steps to start to train discriminator.
|
data/processed/ljspeech/dict.txt
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
! !
|
2 |
+
, ,
|
3 |
+
. .
|
4 |
+
; ;
|
5 |
+
<BOS> <BOS>
|
6 |
+
<EOS> <EOS>
|
7 |
+
? ?
|
8 |
+
AA0 AA0
|
9 |
+
AA1 AA1
|
10 |
+
AA2 AA2
|
11 |
+
AE0 AE0
|
12 |
+
AE1 AE1
|
13 |
+
AE2 AE2
|
14 |
+
AH0 AH0
|
15 |
+
AH1 AH1
|
16 |
+
AH2 AH2
|
17 |
+
AO0 AO0
|
18 |
+
AO1 AO1
|
19 |
+
AO2 AO2
|
20 |
+
AW0 AW0
|
21 |
+
AW1 AW1
|
22 |
+
AW2 AW2
|
23 |
+
AY0 AY0
|
24 |
+
AY1 AY1
|
25 |
+
AY2 AY2
|
26 |
+
B B
|
27 |
+
CH CH
|
28 |
+
D D
|
29 |
+
DH DH
|
30 |
+
EH0 EH0
|
31 |
+
EH1 EH1
|
32 |
+
EH2 EH2
|
33 |
+
ER0 ER0
|
34 |
+
ER1 ER1
|
35 |
+
ER2 ER2
|
36 |
+
EY0 EY0
|
37 |
+
EY1 EY1
|
38 |
+
EY2 EY2
|
39 |
+
F F
|
40 |
+
G G
|
41 |
+
HH HH
|
42 |
+
IH0 IH0
|
43 |
+
IH1 IH1
|
44 |
+
IH2 IH2
|
45 |
+
IY0 IY0
|
46 |
+
IY1 IY1
|
47 |
+
IY2 IY2
|
48 |
+
JH JH
|
49 |
+
K K
|
50 |
+
L L
|
51 |
+
M M
|
52 |
+
N N
|
53 |
+
NG NG
|
54 |
+
OW0 OW0
|
55 |
+
OW1 OW1
|
56 |
+
OW2 OW2
|
57 |
+
OY0 OY0
|
58 |
+
OY1 OY1
|
59 |
+
OY2 OY2
|
60 |
+
P P
|
61 |
+
R R
|
62 |
+
S S
|
63 |
+
SH SH
|
64 |
+
T T
|
65 |
+
TH TH
|
66 |
+
UH0 UH0
|
67 |
+
UH1 UH1
|
68 |
+
UH2 UH2
|
69 |
+
UW0 UW0
|
70 |
+
UW1 UW1
|
71 |
+
UW2 UW2
|
72 |
+
V V
|
73 |
+
W W
|
74 |
+
Y Y
|
75 |
+
Z Z
|
76 |
+
ZH ZH
|
77 |
+
| |
|
data/processed/ljspeech/metadata_phone.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/processed/ljspeech/mfa_dict.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/processed/ljspeech/phone_set.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
["!", ",", ".", ";", "<BOS>", "<EOS>", "?", "AA0", "AA1", "AA2", "AE0", "AE1", "AE2", "AH0", "AH1", "AH2", "AO0", "AO1", "AO2", "AW0", "AW1", "AW2", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH0", "EH1", "EH2", "ER0", "ER1", "ER2", "EY0", "EY1", "EY2", "F", "G", "HH", "IH0", "IH1", "IH2", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG", "OW0", "OW1", "OW2", "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH0", "UH1", "UH2", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH", "|"]
|
data_gen/singing/binarize.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from copy import deepcopy
|
4 |
+
import pandas as pd
|
5 |
+
import logging
|
6 |
+
from tqdm import tqdm
|
7 |
+
import json
|
8 |
+
import glob
|
9 |
+
import re
|
10 |
+
from resemblyzer import VoiceEncoder
|
11 |
+
import traceback
|
12 |
+
import numpy as np
|
13 |
+
import pretty_midi
|
14 |
+
import librosa
|
15 |
+
from scipy.interpolate import interp1d
|
16 |
+
import torch
|
17 |
+
from textgrid import TextGrid
|
18 |
+
|
19 |
+
from utils.hparams import hparams
|
20 |
+
from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch
|
21 |
+
from utils.pitch_utils import f0_to_coarse
|
22 |
+
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
|
23 |
+
from data_gen.tts.binarizer_zh import ZhBinarizer
|
24 |
+
from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
|
25 |
+
from vocoders.base_vocoder import VOCODERS
|
26 |
+
|
27 |
+
|
28 |
+
class SingingBinarizer(BaseBinarizer):
|
29 |
+
def __init__(self, processed_data_dir=None):
|
30 |
+
if processed_data_dir is None:
|
31 |
+
processed_data_dir = hparams['processed_data_dir']
|
32 |
+
self.processed_data_dirs = processed_data_dir.split(",")
|
33 |
+
self.binarization_args = hparams['binarization_args']
|
34 |
+
self.pre_align_args = hparams['pre_align_args']
|
35 |
+
self.item2txt = {}
|
36 |
+
self.item2ph = {}
|
37 |
+
self.item2wavfn = {}
|
38 |
+
self.item2f0fn = {}
|
39 |
+
self.item2tgfn = {}
|
40 |
+
self.item2spk = {}
|
41 |
+
|
42 |
+
def split_train_test_set(self, item_names):
|
43 |
+
item_names = deepcopy(item_names)
|
44 |
+
test_item_names = [x for x in item_names if any([ts in x for ts in hparams['test_prefixes']])]
|
45 |
+
train_item_names = [x for x in item_names if x not in set(test_item_names)]
|
46 |
+
logging.info("train {}".format(len(train_item_names)))
|
47 |
+
logging.info("test {}".format(len(test_item_names)))
|
48 |
+
return train_item_names, test_item_names
|
49 |
+
|
50 |
+
def load_meta_data(self):
|
51 |
+
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
|
52 |
+
wav_suffix = '_wf0.wav'
|
53 |
+
txt_suffix = '.txt'
|
54 |
+
ph_suffix = '_ph.txt'
|
55 |
+
tg_suffix = '.TextGrid'
|
56 |
+
all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
|
57 |
+
|
58 |
+
for piece_path in all_wav_pieces:
|
59 |
+
item_name = raw_item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)]
|
60 |
+
if len(self.processed_data_dirs) > 1:
|
61 |
+
item_name = f'ds{ds_id}_{item_name}'
|
62 |
+
self.item2txt[item_name] = open(f'{piece_path.replace(wav_suffix, txt_suffix)}').readline()
|
63 |
+
self.item2ph[item_name] = open(f'{piece_path.replace(wav_suffix, ph_suffix)}').readline()
|
64 |
+
self.item2wavfn[item_name] = piece_path
|
65 |
+
|
66 |
+
self.item2spk[item_name] = re.split('-|#', piece_path.split('/')[-2])[0]
|
67 |
+
if len(self.processed_data_dirs) > 1:
|
68 |
+
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
|
69 |
+
self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix)
|
70 |
+
print('spkers: ', set(self.item2spk.values()))
|
71 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
72 |
+
if self.binarization_args['shuffle']:
|
73 |
+
random.seed(1234)
|
74 |
+
random.shuffle(self.item_names)
|
75 |
+
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
|
76 |
+
|
77 |
+
@property
|
78 |
+
def train_item_names(self):
|
79 |
+
return self._train_item_names
|
80 |
+
|
81 |
+
@property
|
82 |
+
def valid_item_names(self):
|
83 |
+
return self._test_item_names
|
84 |
+
|
85 |
+
@property
|
86 |
+
def test_item_names(self):
|
87 |
+
return self._test_item_names
|
88 |
+
|
89 |
+
def process(self):
|
90 |
+
self.load_meta_data()
|
91 |
+
os.makedirs(hparams['binary_data_dir'], exist_ok=True)
|
92 |
+
self.spk_map = self.build_spk_map()
|
93 |
+
print("| spk_map: ", self.spk_map)
|
94 |
+
spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
|
95 |
+
json.dump(self.spk_map, open(spk_map_fn, 'w'))
|
96 |
+
|
97 |
+
self.phone_encoder = self._phone_encoder()
|
98 |
+
self.process_data('valid')
|
99 |
+
self.process_data('test')
|
100 |
+
self.process_data('train')
|
101 |
+
|
102 |
+
def _phone_encoder(self):
|
103 |
+
ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
|
104 |
+
ph_set = []
|
105 |
+
if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
|
106 |
+
for ph_sent in self.item2ph.values():
|
107 |
+
ph_set += ph_sent.split(' ')
|
108 |
+
ph_set = sorted(set(ph_set))
|
109 |
+
json.dump(ph_set, open(ph_set_fn, 'w'))
|
110 |
+
print("| Build phone set: ", ph_set)
|
111 |
+
else:
|
112 |
+
ph_set = json.load(open(ph_set_fn, 'r'))
|
113 |
+
print("| Load phone set: ", ph_set)
|
114 |
+
return build_phone_encoder(hparams['binary_data_dir'])
|
115 |
+
|
116 |
+
# @staticmethod
|
117 |
+
# def get_pitch(wav_fn, spec, res):
|
118 |
+
# wav_suffix = '_wf0.wav'
|
119 |
+
# f0_suffix = '_f0.npy'
|
120 |
+
# f0fn = wav_fn.replace(wav_suffix, f0_suffix)
|
121 |
+
# pitch_info = np.load(f0fn)
|
122 |
+
# f0 = [x[1] for x in pitch_info]
|
123 |
+
# spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
|
124 |
+
# f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
|
125 |
+
# f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
|
126 |
+
# # f0_x_coor = np.arange(0, 1, 1 / len(f0))
|
127 |
+
# # f0_x_coor[-1] = 1
|
128 |
+
# # f0 = interp1d(f0_x_coor, f0, 'nearest')(spec_x_coor)[:len(spec)]
|
129 |
+
# if sum(f0) == 0:
|
130 |
+
# raise BinarizationError("Empty f0")
|
131 |
+
# assert len(f0) == len(spec), (len(f0), len(spec))
|
132 |
+
# pitch_coarse = f0_to_coarse(f0)
|
133 |
+
#
|
134 |
+
# # vis f0
|
135 |
+
# # import matplotlib.pyplot as plt
|
136 |
+
# # from textgrid import TextGrid
|
137 |
+
# # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid')
|
138 |
+
# # fig = plt.figure(figsize=(12, 6))
|
139 |
+
# # plt.pcolor(spec.T, vmin=-5, vmax=0)
|
140 |
+
# # ax = plt.gca()
|
141 |
+
# # ax2 = ax.twinx()
|
142 |
+
# # ax2.plot(f0, color='red')
|
143 |
+
# # ax2.set_ylim(0, 800)
|
144 |
+
# # itvs = TextGrid.fromFile(tg_fn)[0]
|
145 |
+
# # for itv in itvs:
|
146 |
+
# # x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size']
|
147 |
+
# # plt.vlines(x=x, ymin=0, ymax=80, color='black')
|
148 |
+
# # plt.text(x=x, y=20, s=itv.mark, color='black')
|
149 |
+
# # plt.savefig('tmp/20211229_singing_plots_test.png')
|
150 |
+
#
|
151 |
+
# res['f0'] = f0
|
152 |
+
# res['pitch'] = pitch_coarse
|
153 |
+
|
154 |
+
@classmethod
|
155 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
156 |
+
if hparams['vocoder'] in VOCODERS:
|
157 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
158 |
+
else:
|
159 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
160 |
+
res = {
|
161 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
162 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
163 |
+
}
|
164 |
+
try:
|
165 |
+
if binarization_args['with_f0']:
|
166 |
+
# cls.get_pitch(wav_fn, mel, res)
|
167 |
+
cls.get_pitch(wav, mel, res)
|
168 |
+
if binarization_args['with_txt']:
|
169 |
+
try:
|
170 |
+
# print(ph)
|
171 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
172 |
+
except:
|
173 |
+
traceback.print_exc()
|
174 |
+
raise BinarizationError(f"Empty phoneme")
|
175 |
+
if binarization_args['with_align']:
|
176 |
+
cls.get_align(tg_fn, ph, mel, phone_encoded, res)
|
177 |
+
except BinarizationError as e:
|
178 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
179 |
+
return None
|
180 |
+
return res
|
181 |
+
|
182 |
+
|
183 |
+
class MidiSingingBinarizer(SingingBinarizer):
|
184 |
+
item2midi = {}
|
185 |
+
item2midi_dur = {}
|
186 |
+
item2is_slur = {}
|
187 |
+
item2ph_durs = {}
|
188 |
+
item2wdb = {}
|
189 |
+
|
190 |
+
def load_meta_data(self):
|
191 |
+
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
|
192 |
+
meta_midi = json.load(open(os.path.join(processed_data_dir, 'meta.json'))) # [list of dict]
|
193 |
+
|
194 |
+
for song_item in meta_midi:
|
195 |
+
item_name = raw_item_name = song_item['item_name']
|
196 |
+
if len(self.processed_data_dirs) > 1:
|
197 |
+
item_name = f'ds{ds_id}_{item_name}'
|
198 |
+
self.item2wavfn[item_name] = song_item['wav_fn']
|
199 |
+
self.item2txt[item_name] = song_item['txt']
|
200 |
+
|
201 |
+
self.item2ph[item_name] = ' '.join(song_item['phs'])
|
202 |
+
self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP', '<SIL>'] else 0 for x in song_item['phs']]
|
203 |
+
self.item2ph_durs[item_name] = song_item['ph_dur']
|
204 |
+
|
205 |
+
self.item2midi[item_name] = song_item['notes']
|
206 |
+
self.item2midi_dur[item_name] = song_item['notes_dur']
|
207 |
+
self.item2is_slur[item_name] = song_item['is_slur']
|
208 |
+
self.item2spk[item_name] = 'pop-cs'
|
209 |
+
if len(self.processed_data_dirs) > 1:
|
210 |
+
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
|
211 |
+
|
212 |
+
print('spkers: ', set(self.item2spk.values()))
|
213 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
214 |
+
if self.binarization_args['shuffle']:
|
215 |
+
random.seed(1234)
|
216 |
+
random.shuffle(self.item_names)
|
217 |
+
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
|
218 |
+
|
219 |
+
@staticmethod
|
220 |
+
def get_pitch(wav_fn, wav, spec, ph, res):
|
221 |
+
wav_suffix = '.wav'
|
222 |
+
# midi_suffix = '.mid'
|
223 |
+
wav_dir = 'wavs'
|
224 |
+
f0_dir = 'f0'
|
225 |
+
|
226 |
+
item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
|
227 |
+
res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
|
228 |
+
res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
|
229 |
+
res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
|
230 |
+
res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
|
231 |
+
assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
|
232 |
+
res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
|
233 |
+
|
234 |
+
# gt f0.
|
235 |
+
gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
|
236 |
+
if sum(gt_f0) == 0:
|
237 |
+
raise BinarizationError("Empty **gt** f0")
|
238 |
+
res['f0'] = gt_f0
|
239 |
+
res['pitch'] = gt_pitch_coarse
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
def get_align(ph_durs, mel, phone_encoded, res, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']):
|
243 |
+
mel2ph = np.zeros([mel.shape[0]], int)
|
244 |
+
startTime = 0
|
245 |
+
|
246 |
+
for i_ph in range(len(ph_durs)):
|
247 |
+
start_frame = int(startTime * audio_sample_rate / hop_size + 0.5)
|
248 |
+
end_frame = int((startTime + ph_durs[i_ph]) * audio_sample_rate / hop_size + 0.5)
|
249 |
+
mel2ph[start_frame:end_frame] = i_ph + 1
|
250 |
+
startTime = startTime + ph_durs[i_ph]
|
251 |
+
|
252 |
+
# print('ph durs: ', ph_durs)
|
253 |
+
# print('mel2ph: ', mel2ph, len(mel2ph))
|
254 |
+
res['mel2ph'] = mel2ph
|
255 |
+
# res['dur'] = None
|
256 |
+
|
257 |
+
@classmethod
|
258 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
259 |
+
if hparams['vocoder'] in VOCODERS:
|
260 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
261 |
+
else:
|
262 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
263 |
+
res = {
|
264 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
265 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
266 |
+
}
|
267 |
+
try:
|
268 |
+
if binarization_args['with_f0']:
|
269 |
+
cls.get_pitch(wav_fn, wav, mel, ph, res)
|
270 |
+
if binarization_args['with_txt']:
|
271 |
+
try:
|
272 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
273 |
+
except:
|
274 |
+
traceback.print_exc()
|
275 |
+
raise BinarizationError(f"Empty phoneme")
|
276 |
+
if binarization_args['with_align']:
|
277 |
+
cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
|
278 |
+
except BinarizationError as e:
|
279 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
280 |
+
return None
|
281 |
+
return res
|
282 |
+
|
283 |
+
|
284 |
+
class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
|
285 |
+
pass
|
286 |
+
|
287 |
+
|
288 |
+
class OpencpopBinarizer(MidiSingingBinarizer):
|
289 |
+
item2midi = {}
|
290 |
+
item2midi_dur = {}
|
291 |
+
item2is_slur = {}
|
292 |
+
item2ph_durs = {}
|
293 |
+
item2wdb = {}
|
294 |
+
|
295 |
+
def split_train_test_set(self, item_names):
|
296 |
+
item_names = deepcopy(item_names)
|
297 |
+
test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])]
|
298 |
+
train_item_names = [x for x in item_names if x not in set(test_item_names)]
|
299 |
+
logging.info("train {}".format(len(train_item_names)))
|
300 |
+
logging.info("test {}".format(len(test_item_names)))
|
301 |
+
return train_item_names, test_item_names
|
302 |
+
|
303 |
+
def load_meta_data(self):
|
304 |
+
raw_data_dir = hparams['raw_data_dir']
|
305 |
+
# meta_midi = json.load(open(os.path.join(raw_data_dir, 'meta.json'))) # [list of dict]
|
306 |
+
utterance_labels = open(os.path.join(raw_data_dir, 'transcriptions.txt')).readlines()
|
307 |
+
|
308 |
+
for utterance_label in utterance_labels:
|
309 |
+
song_info = utterance_label.split('|')
|
310 |
+
item_name = raw_item_name = song_info[0]
|
311 |
+
self.item2wavfn[item_name] = f'{raw_data_dir}/wavs/{item_name}.wav'
|
312 |
+
self.item2txt[item_name] = song_info[1]
|
313 |
+
|
314 |
+
self.item2ph[item_name] = song_info[2]
|
315 |
+
# self.item2wdb[item_name] = list(np.nonzero([1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in song_info[2].split()])[0])
|
316 |
+
self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in song_info[2].split()]
|
317 |
+
self.item2ph_durs[item_name] = [float(x) for x in song_info[5].split(" ")]
|
318 |
+
|
319 |
+
self.item2midi[item_name] = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
|
320 |
+
for x in song_info[3].split(" ")]
|
321 |
+
self.item2midi_dur[item_name] = [float(x) for x in song_info[4].split(" ")]
|
322 |
+
self.item2is_slur[item_name] = [int(x) for x in song_info[6].split(" ")]
|
323 |
+
self.item2spk[item_name] = 'opencpop'
|
324 |
+
|
325 |
+
print('spkers: ', set(self.item2spk.values()))
|
326 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
327 |
+
if self.binarization_args['shuffle']:
|
328 |
+
random.seed(1234)
|
329 |
+
random.shuffle(self.item_names)
|
330 |
+
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
|
331 |
+
|
332 |
+
@staticmethod
|
333 |
+
def get_pitch(wav_fn, wav, spec, ph, res):
|
334 |
+
wav_suffix = '.wav'
|
335 |
+
# midi_suffix = '.mid'
|
336 |
+
wav_dir = 'wavs'
|
337 |
+
f0_dir = 'text_f0_align'
|
338 |
+
|
339 |
+
item_name = os.path.splitext(os.path.basename(wav_fn))[0]
|
340 |
+
res['pitch_midi'] = np.asarray(OpencpopBinarizer.item2midi[item_name])
|
341 |
+
res['midi_dur'] = np.asarray(OpencpopBinarizer.item2midi_dur[item_name])
|
342 |
+
res['is_slur'] = np.asarray(OpencpopBinarizer.item2is_slur[item_name])
|
343 |
+
res['word_boundary'] = np.asarray(OpencpopBinarizer.item2wdb[item_name])
|
344 |
+
assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
|
345 |
+
|
346 |
+
# gt f0.
|
347 |
+
# f0 = None
|
348 |
+
# f0_suffix = '_f0.npy'
|
349 |
+
# f0fn = wav_fn.replace(wav_suffix, f0_suffix).replace(wav_dir, f0_dir)
|
350 |
+
# pitch_info = np.load(f0fn)
|
351 |
+
# f0 = [x[1] for x in pitch_info]
|
352 |
+
# spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
|
353 |
+
#
|
354 |
+
# f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
|
355 |
+
# f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
|
356 |
+
# if sum(f0) == 0:
|
357 |
+
# raise BinarizationError("Empty **gt** f0")
|
358 |
+
#
|
359 |
+
# pitch_coarse = f0_to_coarse(f0)
|
360 |
+
# res['f0'] = f0
|
361 |
+
# res['pitch'] = pitch_coarse
|
362 |
+
|
363 |
+
# gt f0.
|
364 |
+
gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
|
365 |
+
if sum(gt_f0) == 0:
|
366 |
+
raise BinarizationError("Empty **gt** f0")
|
367 |
+
res['f0'] = gt_f0
|
368 |
+
res['pitch'] = gt_pitch_coarse
|
369 |
+
|
370 |
+
@classmethod
|
371 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
372 |
+
if hparams['vocoder'] in VOCODERS:
|
373 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
374 |
+
else:
|
375 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
376 |
+
res = {
|
377 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
378 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
379 |
+
}
|
380 |
+
try:
|
381 |
+
if binarization_args['with_f0']:
|
382 |
+
cls.get_pitch(wav_fn, wav, mel, ph, res)
|
383 |
+
if binarization_args['with_txt']:
|
384 |
+
try:
|
385 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
386 |
+
except:
|
387 |
+
traceback.print_exc()
|
388 |
+
raise BinarizationError(f"Empty phoneme")
|
389 |
+
if binarization_args['with_align']:
|
390 |
+
cls.get_align(OpencpopBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
|
391 |
+
except BinarizationError as e:
|
392 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
393 |
+
return None
|
394 |
+
return res
|
395 |
+
|
396 |
+
|
397 |
+
if __name__ == "__main__":
|
398 |
+
SingingBinarizer().process()
|
data_gen/tts/base_binarizer.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
|
4 |
+
from utils.multiprocess_utils import chunked_multiprocess_run
|
5 |
+
import random
|
6 |
+
import traceback
|
7 |
+
import json
|
8 |
+
from resemblyzer import VoiceEncoder
|
9 |
+
from tqdm import tqdm
|
10 |
+
from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder
|
11 |
+
from utils.hparams import set_hparams, hparams
|
12 |
+
import numpy as np
|
13 |
+
from utils.indexed_datasets import IndexedDatasetBuilder
|
14 |
+
from vocoders.base_vocoder import VOCODERS
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
|
18 |
+
class BinarizationError(Exception):
|
19 |
+
pass
|
20 |
+
|
21 |
+
|
22 |
+
class BaseBinarizer:
|
23 |
+
def __init__(self, processed_data_dir=None):
|
24 |
+
if processed_data_dir is None:
|
25 |
+
processed_data_dir = hparams['processed_data_dir']
|
26 |
+
self.processed_data_dirs = processed_data_dir.split(",")
|
27 |
+
self.binarization_args = hparams['binarization_args']
|
28 |
+
self.pre_align_args = hparams['pre_align_args']
|
29 |
+
self.forced_align = self.pre_align_args['forced_align']
|
30 |
+
tg_dir = None
|
31 |
+
if self.forced_align == 'mfa':
|
32 |
+
tg_dir = 'mfa_outputs'
|
33 |
+
if self.forced_align == 'kaldi':
|
34 |
+
tg_dir = 'kaldi_outputs'
|
35 |
+
self.item2txt = {}
|
36 |
+
self.item2ph = {}
|
37 |
+
self.item2wavfn = {}
|
38 |
+
self.item2tgfn = {}
|
39 |
+
self.item2spk = {}
|
40 |
+
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
|
41 |
+
self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str)
|
42 |
+
for r_idx, r in self.meta_df.iterrows():
|
43 |
+
item_name = raw_item_name = r['item_name']
|
44 |
+
if len(self.processed_data_dirs) > 1:
|
45 |
+
item_name = f'ds{ds_id}_{item_name}'
|
46 |
+
self.item2txt[item_name] = r['txt']
|
47 |
+
self.item2ph[item_name] = r['ph']
|
48 |
+
self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1])
|
49 |
+
self.item2spk[item_name] = r.get('spk', 'SPK1')
|
50 |
+
if len(self.processed_data_dirs) > 1:
|
51 |
+
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
|
52 |
+
if tg_dir is not None:
|
53 |
+
self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid"
|
54 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
55 |
+
if self.binarization_args['shuffle']:
|
56 |
+
random.seed(1234)
|
57 |
+
random.shuffle(self.item_names)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def train_item_names(self):
|
61 |
+
return self.item_names[hparams['test_num']+hparams['valid_num']:]
|
62 |
+
|
63 |
+
@property
|
64 |
+
def valid_item_names(self):
|
65 |
+
return self.item_names[0: hparams['test_num']+hparams['valid_num']] #
|
66 |
+
|
67 |
+
@property
|
68 |
+
def test_item_names(self):
|
69 |
+
return self.item_names[0: hparams['test_num']] # Audios for MOS testing are in 'test_ids'
|
70 |
+
|
71 |
+
def build_spk_map(self):
|
72 |
+
spk_map = set()
|
73 |
+
for item_name in self.item_names:
|
74 |
+
spk_name = self.item2spk[item_name]
|
75 |
+
spk_map.add(spk_name)
|
76 |
+
spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
|
77 |
+
assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
|
78 |
+
return spk_map
|
79 |
+
|
80 |
+
def item_name2spk_id(self, item_name):
|
81 |
+
return self.spk_map[self.item2spk[item_name]]
|
82 |
+
|
83 |
+
def _phone_encoder(self):
|
84 |
+
ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
|
85 |
+
ph_set = []
|
86 |
+
if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
|
87 |
+
for processed_data_dir in self.processed_data_dirs:
|
88 |
+
ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()]
|
89 |
+
ph_set = sorted(set(ph_set))
|
90 |
+
json.dump(ph_set, open(ph_set_fn, 'w'))
|
91 |
+
else:
|
92 |
+
ph_set = json.load(open(ph_set_fn, 'r'))
|
93 |
+
print("| phone set: ", ph_set)
|
94 |
+
return build_phone_encoder(hparams['binary_data_dir'])
|
95 |
+
|
96 |
+
def meta_data(self, prefix):
|
97 |
+
if prefix == 'valid':
|
98 |
+
item_names = self.valid_item_names
|
99 |
+
elif prefix == 'test':
|
100 |
+
item_names = self.test_item_names
|
101 |
+
else:
|
102 |
+
item_names = self.train_item_names
|
103 |
+
for item_name in item_names:
|
104 |
+
ph = self.item2ph[item_name]
|
105 |
+
txt = self.item2txt[item_name]
|
106 |
+
tg_fn = self.item2tgfn.get(item_name)
|
107 |
+
wav_fn = self.item2wavfn[item_name]
|
108 |
+
spk_id = self.item_name2spk_id(item_name)
|
109 |
+
yield item_name, ph, txt, tg_fn, wav_fn, spk_id
|
110 |
+
|
111 |
+
def process(self):
|
112 |
+
os.makedirs(hparams['binary_data_dir'], exist_ok=True)
|
113 |
+
self.spk_map = self.build_spk_map()
|
114 |
+
print("| spk_map: ", self.spk_map)
|
115 |
+
spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
|
116 |
+
json.dump(self.spk_map, open(spk_map_fn, 'w'))
|
117 |
+
|
118 |
+
self.phone_encoder = self._phone_encoder()
|
119 |
+
self.process_data('valid')
|
120 |
+
self.process_data('test')
|
121 |
+
self.process_data('train')
|
122 |
+
|
123 |
+
def process_data(self, prefix):
|
124 |
+
data_dir = hparams['binary_data_dir']
|
125 |
+
args = []
|
126 |
+
builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
|
127 |
+
lengths = []
|
128 |
+
f0s = []
|
129 |
+
total_sec = 0
|
130 |
+
if self.binarization_args['with_spk_embed']:
|
131 |
+
voice_encoder = VoiceEncoder().cuda()
|
132 |
+
|
133 |
+
meta_data = list(self.meta_data(prefix))
|
134 |
+
for m in meta_data:
|
135 |
+
args.append(list(m) + [self.phone_encoder, self.binarization_args])
|
136 |
+
num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3))
|
137 |
+
for f_id, (_, item) in enumerate(
|
138 |
+
zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))):
|
139 |
+
if item is None:
|
140 |
+
continue
|
141 |
+
item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
|
142 |
+
if self.binarization_args['with_spk_embed'] else None
|
143 |
+
if not self.binarization_args['with_wav'] and 'wav' in item:
|
144 |
+
print("del wav")
|
145 |
+
del item['wav']
|
146 |
+
builder.add_item(item)
|
147 |
+
lengths.append(item['len'])
|
148 |
+
total_sec += item['sec']
|
149 |
+
if item.get('f0') is not None:
|
150 |
+
f0s.append(item['f0'])
|
151 |
+
builder.finalize()
|
152 |
+
np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
|
153 |
+
if len(f0s) > 0:
|
154 |
+
f0s = np.concatenate(f0s, 0)
|
155 |
+
f0s = f0s[f0s != 0]
|
156 |
+
np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
|
157 |
+
print(f"| {prefix} total duration: {total_sec:.3f}s")
|
158 |
+
|
159 |
+
@classmethod
|
160 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
161 |
+
if hparams['vocoder'] in VOCODERS:
|
162 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
163 |
+
else:
|
164 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
165 |
+
res = {
|
166 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
167 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
168 |
+
}
|
169 |
+
try:
|
170 |
+
if binarization_args['with_f0']:
|
171 |
+
cls.get_pitch(wav, mel, res)
|
172 |
+
if binarization_args['with_f0cwt']:
|
173 |
+
cls.get_f0cwt(res['f0'], res)
|
174 |
+
if binarization_args['with_txt']:
|
175 |
+
try:
|
176 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
177 |
+
except:
|
178 |
+
traceback.print_exc()
|
179 |
+
raise BinarizationError(f"Empty phoneme")
|
180 |
+
if binarization_args['with_align']:
|
181 |
+
cls.get_align(tg_fn, ph, mel, phone_encoded, res)
|
182 |
+
except BinarizationError as e:
|
183 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
184 |
+
return None
|
185 |
+
return res
|
186 |
+
|
187 |
+
@staticmethod
|
188 |
+
def get_align(tg_fn, ph, mel, phone_encoded, res):
|
189 |
+
if tg_fn is not None and os.path.exists(tg_fn):
|
190 |
+
mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
|
191 |
+
else:
|
192 |
+
raise BinarizationError(f"Align not found")
|
193 |
+
if mel2ph.max() - 1 >= len(phone_encoded):
|
194 |
+
raise BinarizationError(
|
195 |
+
f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
|
196 |
+
res['mel2ph'] = mel2ph
|
197 |
+
res['dur'] = dur
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
def get_pitch(wav, mel, res):
|
201 |
+
f0, pitch_coarse = get_pitch(wav, mel, hparams)
|
202 |
+
if sum(f0) == 0:
|
203 |
+
raise BinarizationError("Empty f0")
|
204 |
+
res['f0'] = f0
|
205 |
+
res['pitch'] = pitch_coarse
|
206 |
+
|
207 |
+
@staticmethod
|
208 |
+
def get_f0cwt(f0, res):
|
209 |
+
from utils.cwt import get_cont_lf0, get_lf0_cwt
|
210 |
+
uv, cont_lf0_lpf = get_cont_lf0(f0)
|
211 |
+
logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
|
212 |
+
cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
|
213 |
+
Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
|
214 |
+
if np.any(np.isnan(Wavelet_lf0)):
|
215 |
+
raise BinarizationError("NaN CWT")
|
216 |
+
res['cwt_spec'] = Wavelet_lf0
|
217 |
+
res['cwt_scales'] = scales
|
218 |
+
res['f0_mean'] = logf0s_mean_org
|
219 |
+
res['f0_std'] = logf0s_std_org
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == "__main__":
|
223 |
+
set_hparams()
|
224 |
+
BaseBinarizer().process()
|
data_gen/tts/bin/binarize.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
4 |
+
|
5 |
+
import importlib
|
6 |
+
from utils.hparams import set_hparams, hparams
|
7 |
+
|
8 |
+
|
9 |
+
def binarize():
|
10 |
+
binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
|
11 |
+
pkg = ".".join(binarizer_cls.split(".")[:-1])
|
12 |
+
cls_name = binarizer_cls.split(".")[-1]
|
13 |
+
binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
|
14 |
+
print("| Binarizer: ", binarizer_cls)
|
15 |
+
binarizer_cls().process()
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
set_hparams()
|
20 |
+
binarize()
|
data_gen/tts/binarizer_zh.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
4 |
+
|
5 |
+
from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU
|
6 |
+
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
|
7 |
+
from data_gen.tts.data_gen_utils import get_mel2ph
|
8 |
+
from utils.hparams import set_hparams, hparams
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
class ZhBinarizer(BaseBinarizer):
|
13 |
+
@staticmethod
|
14 |
+
def get_align(tg_fn, ph, mel, phone_encoded, res):
|
15 |
+
if tg_fn is not None and os.path.exists(tg_fn):
|
16 |
+
_, dur = get_mel2ph(tg_fn, ph, mel, hparams)
|
17 |
+
else:
|
18 |
+
raise BinarizationError(f"Align not found")
|
19 |
+
ph_list = ph.split(" ")
|
20 |
+
assert len(dur) == len(ph_list)
|
21 |
+
mel2ph = []
|
22 |
+
# 分隔符的时长分配给韵母
|
23 |
+
dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
|
24 |
+
for i in range(len(dur)):
|
25 |
+
p = ph_list[i]
|
26 |
+
if p[0] != '<' and not p[0].isalpha():
|
27 |
+
uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0
|
28 |
+
j = 0
|
29 |
+
while j < len(uv_) and not uv_[j]:
|
30 |
+
j += 1
|
31 |
+
dur[i - 1] += j
|
32 |
+
dur[i] -= j
|
33 |
+
if dur[i] < 100:
|
34 |
+
dur[i - 1] += dur[i]
|
35 |
+
dur[i] = 0
|
36 |
+
# 声母和韵母等长
|
37 |
+
for i in range(len(dur)):
|
38 |
+
p = ph_list[i]
|
39 |
+
if p in ALL_SHENMU:
|
40 |
+
p_next = ph_list[i + 1]
|
41 |
+
if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU):
|
42 |
+
print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
|
43 |
+
f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
|
44 |
+
continue
|
45 |
+
total = dur[i + 1] + dur[i]
|
46 |
+
dur[i] = total // 2
|
47 |
+
dur[i + 1] = total - dur[i]
|
48 |
+
for i in range(len(dur)):
|
49 |
+
mel2ph += [i + 1] * dur[i]
|
50 |
+
mel2ph = np.array(mel2ph)
|
51 |
+
if mel2ph.max() - 1 >= len(phone_encoded):
|
52 |
+
raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}")
|
53 |
+
res['mel2ph'] = mel2ph
|
54 |
+
res['dur'] = dur
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
set_hparams()
|
59 |
+
ZhBinarizer().process()
|
data_gen/tts/data_gen_utils.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
warnings.filterwarnings("ignore")
|
4 |
+
|
5 |
+
import parselmouth
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from skimage.transform import resize
|
9 |
+
from utils.text_encoder import TokenTextEncoder
|
10 |
+
from utils.pitch_utils import f0_to_coarse
|
11 |
+
import struct
|
12 |
+
import webrtcvad
|
13 |
+
from scipy.ndimage.morphology import binary_dilation
|
14 |
+
import librosa
|
15 |
+
import numpy as np
|
16 |
+
from utils import audio
|
17 |
+
import pyloudnorm as pyln
|
18 |
+
import re
|
19 |
+
import json
|
20 |
+
from collections import OrderedDict
|
21 |
+
|
22 |
+
PUNCS = '!,.?;:'
|
23 |
+
|
24 |
+
int16_max = (2 ** 15) - 1
|
25 |
+
|
26 |
+
|
27 |
+
def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
|
28 |
+
"""
|
29 |
+
Ensures that segments without voice in the waveform remain no longer than a
|
30 |
+
threshold determined by the VAD parameters in params.py.
|
31 |
+
:param wav: the raw waveform as a numpy array of floats
|
32 |
+
:param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
|
33 |
+
:return: the same waveform with silences trimmed away (length <= original wav length)
|
34 |
+
"""
|
35 |
+
|
36 |
+
## Voice Activation Detection
|
37 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
38 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
39 |
+
sampling_rate = 16000
|
40 |
+
wav_raw, sr = librosa.core.load(path, sr=sr)
|
41 |
+
|
42 |
+
if norm:
|
43 |
+
meter = pyln.Meter(sr) # create BS.1770 meter
|
44 |
+
loudness = meter.integrated_loudness(wav_raw)
|
45 |
+
wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
|
46 |
+
if np.abs(wav_raw).max() > 1.0:
|
47 |
+
wav_raw = wav_raw / np.abs(wav_raw).max()
|
48 |
+
|
49 |
+
wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best')
|
50 |
+
|
51 |
+
vad_window_length = 30 # In milliseconds
|
52 |
+
# Number of frames to average together when performing the moving average smoothing.
|
53 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
54 |
+
vad_moving_average_width = 8
|
55 |
+
|
56 |
+
# Compute the voice detection window size
|
57 |
+
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
58 |
+
|
59 |
+
# Trim the end of the audio to have a multiple of the window size
|
60 |
+
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
61 |
+
|
62 |
+
# Convert the float waveform to 16-bit mono PCM
|
63 |
+
pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
|
64 |
+
|
65 |
+
# Perform voice activation detection
|
66 |
+
voice_flags = []
|
67 |
+
vad = webrtcvad.Vad(mode=3)
|
68 |
+
for window_start in range(0, len(wav), samples_per_window):
|
69 |
+
window_end = window_start + samples_per_window
|
70 |
+
voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
|
71 |
+
sample_rate=sampling_rate))
|
72 |
+
voice_flags = np.array(voice_flags)
|
73 |
+
|
74 |
+
# Smooth the voice detection with a moving average
|
75 |
+
def moving_average(array, width):
|
76 |
+
array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
|
77 |
+
ret = np.cumsum(array_padded, dtype=float)
|
78 |
+
ret[width:] = ret[width:] - ret[:-width]
|
79 |
+
return ret[width - 1:] / width
|
80 |
+
|
81 |
+
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
82 |
+
audio_mask = np.round(audio_mask).astype(np.bool)
|
83 |
+
|
84 |
+
# Dilate the voiced regions
|
85 |
+
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
86 |
+
audio_mask = np.repeat(audio_mask, samples_per_window)
|
87 |
+
audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
|
88 |
+
if return_raw_wav:
|
89 |
+
return wav_raw, audio_mask, sr
|
90 |
+
return wav_raw[audio_mask], audio_mask, sr
|
91 |
+
|
92 |
+
|
93 |
+
def process_utterance(wav_path,
|
94 |
+
fft_size=1024,
|
95 |
+
hop_size=256,
|
96 |
+
win_length=1024,
|
97 |
+
window="hann",
|
98 |
+
num_mels=80,
|
99 |
+
fmin=80,
|
100 |
+
fmax=7600,
|
101 |
+
eps=1e-6,
|
102 |
+
sample_rate=22050,
|
103 |
+
loud_norm=False,
|
104 |
+
min_level_db=-100,
|
105 |
+
return_linear=False,
|
106 |
+
trim_long_sil=False, vocoder='pwg'):
|
107 |
+
if isinstance(wav_path, str):
|
108 |
+
if trim_long_sil:
|
109 |
+
wav, _, _ = trim_long_silences(wav_path, sample_rate)
|
110 |
+
else:
|
111 |
+
wav, _ = librosa.core.load(wav_path, sr=sample_rate)
|
112 |
+
else:
|
113 |
+
wav = wav_path
|
114 |
+
|
115 |
+
if loud_norm:
|
116 |
+
meter = pyln.Meter(sample_rate) # create BS.1770 meter
|
117 |
+
loudness = meter.integrated_loudness(wav)
|
118 |
+
wav = pyln.normalize.loudness(wav, loudness, -22.0)
|
119 |
+
if np.abs(wav).max() > 1:
|
120 |
+
wav = wav / np.abs(wav).max()
|
121 |
+
|
122 |
+
# get amplitude spectrogram
|
123 |
+
x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
|
124 |
+
win_length=win_length, window=window, pad_mode="constant")
|
125 |
+
spc = np.abs(x_stft) # (n_bins, T)
|
126 |
+
|
127 |
+
# get mel basis
|
128 |
+
fmin = 0 if fmin == -1 else fmin
|
129 |
+
fmax = sample_rate / 2 if fmax == -1 else fmax
|
130 |
+
mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)
|
131 |
+
mel = mel_basis @ spc
|
132 |
+
|
133 |
+
if vocoder == 'pwg':
|
134 |
+
mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
|
135 |
+
else:
|
136 |
+
assert False, f'"{vocoder}" is not in ["pwg"].'
|
137 |
+
|
138 |
+
l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
|
139 |
+
wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
|
140 |
+
wav = wav[:mel.shape[1] * hop_size]
|
141 |
+
|
142 |
+
if not return_linear:
|
143 |
+
return wav, mel
|
144 |
+
else:
|
145 |
+
spc = audio.amp_to_db(spc)
|
146 |
+
spc = audio.normalize(spc, {'min_level_db': min_level_db})
|
147 |
+
return wav, mel, spc
|
148 |
+
|
149 |
+
|
150 |
+
def get_pitch(wav_data, mel, hparams):
|
151 |
+
"""
|
152 |
+
|
153 |
+
:param wav_data: [T]
|
154 |
+
:param mel: [T, 80]
|
155 |
+
:param hparams:
|
156 |
+
:return:
|
157 |
+
"""
|
158 |
+
time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000
|
159 |
+
f0_min = 80
|
160 |
+
f0_max = 750
|
161 |
+
|
162 |
+
if hparams['hop_size'] == 128:
|
163 |
+
pad_size = 4
|
164 |
+
elif hparams['hop_size'] == 256:
|
165 |
+
pad_size = 2
|
166 |
+
else:
|
167 |
+
assert False
|
168 |
+
|
169 |
+
f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
|
170 |
+
time_step=time_step / 1000, voicing_threshold=0.6,
|
171 |
+
pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
|
172 |
+
lpad = pad_size * 2
|
173 |
+
rpad = len(mel) - len(f0) - lpad
|
174 |
+
f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
|
175 |
+
# mel and f0 are extracted by 2 different libraries. we should force them to have the same length.
|
176 |
+
# Attention: we find that new version of some libraries could cause ``rpad'' to be a negetive value...
|
177 |
+
# Just to be sure, we recommend users to set up the same environments as them in requirements_auto.txt (by Anaconda)
|
178 |
+
delta_l = len(mel) - len(f0)
|
179 |
+
assert np.abs(delta_l) <= 8
|
180 |
+
if delta_l > 0:
|
181 |
+
f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
|
182 |
+
f0 = f0[:len(mel)]
|
183 |
+
pitch_coarse = f0_to_coarse(f0)
|
184 |
+
return f0, pitch_coarse
|
185 |
+
|
186 |
+
|
187 |
+
def remove_empty_lines(text):
|
188 |
+
"""remove empty lines"""
|
189 |
+
assert (len(text) > 0)
|
190 |
+
assert (isinstance(text, list))
|
191 |
+
text = [t.strip() for t in text]
|
192 |
+
if "" in text:
|
193 |
+
text.remove("")
|
194 |
+
return text
|
195 |
+
|
196 |
+
|
197 |
+
class TextGrid(object):
|
198 |
+
def __init__(self, text):
|
199 |
+
text = remove_empty_lines(text)
|
200 |
+
self.text = text
|
201 |
+
self.line_count = 0
|
202 |
+
self._get_type()
|
203 |
+
self._get_time_intval()
|
204 |
+
self._get_size()
|
205 |
+
self.tier_list = []
|
206 |
+
self._get_item_list()
|
207 |
+
|
208 |
+
def _extract_pattern(self, pattern, inc):
|
209 |
+
"""
|
210 |
+
Parameters
|
211 |
+
----------
|
212 |
+
pattern : regex to extract pattern
|
213 |
+
inc : increment of line count after extraction
|
214 |
+
Returns
|
215 |
+
-------
|
216 |
+
group : extracted info
|
217 |
+
"""
|
218 |
+
try:
|
219 |
+
group = re.match(pattern, self.text[self.line_count]).group(1)
|
220 |
+
self.line_count += inc
|
221 |
+
except AttributeError:
|
222 |
+
raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
|
223 |
+
return group
|
224 |
+
|
225 |
+
def _get_type(self):
|
226 |
+
self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
|
227 |
+
|
228 |
+
def _get_time_intval(self):
|
229 |
+
self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
|
230 |
+
self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
|
231 |
+
|
232 |
+
def _get_size(self):
|
233 |
+
self.size = int(self._extract_pattern(r"size = (.*)", 2))
|
234 |
+
|
235 |
+
def _get_item_list(self):
|
236 |
+
"""Only supports IntervalTier currently"""
|
237 |
+
for itemIdx in range(1, self.size + 1):
|
238 |
+
tier = OrderedDict()
|
239 |
+
item_list = []
|
240 |
+
tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
|
241 |
+
tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
|
242 |
+
if tier_class != "IntervalTier":
|
243 |
+
raise NotImplementedError("Only IntervalTier class is supported currently")
|
244 |
+
tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
|
245 |
+
tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
|
246 |
+
tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
|
247 |
+
tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
|
248 |
+
for i in range(int(tier_size)):
|
249 |
+
item = OrderedDict()
|
250 |
+
item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
|
251 |
+
item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
|
252 |
+
item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
|
253 |
+
item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
|
254 |
+
item_list.append(item)
|
255 |
+
tier["idx"] = tier_idx
|
256 |
+
tier["class"] = tier_class
|
257 |
+
tier["name"] = tier_name
|
258 |
+
tier["xmin"] = tier_xmin
|
259 |
+
tier["xmax"] = tier_xmax
|
260 |
+
tier["size"] = tier_size
|
261 |
+
tier["items"] = item_list
|
262 |
+
self.tier_list.append(tier)
|
263 |
+
|
264 |
+
def toJson(self):
|
265 |
+
_json = OrderedDict()
|
266 |
+
_json["file_type"] = self.file_type
|
267 |
+
_json["xmin"] = self.xmin
|
268 |
+
_json["xmax"] = self.xmax
|
269 |
+
_json["size"] = self.size
|
270 |
+
_json["tiers"] = self.tier_list
|
271 |
+
return json.dumps(_json, ensure_ascii=False, indent=2)
|
272 |
+
|
273 |
+
|
274 |
+
def get_mel2ph(tg_fn, ph, mel, hparams):
|
275 |
+
ph_list = ph.split(" ")
|
276 |
+
with open(tg_fn, "r") as f:
|
277 |
+
tg = f.readlines()
|
278 |
+
tg = remove_empty_lines(tg)
|
279 |
+
tg = TextGrid(tg)
|
280 |
+
tg = json.loads(tg.toJson())
|
281 |
+
split = np.ones(len(ph_list) + 1, np.float) * -1
|
282 |
+
tg_idx = 0
|
283 |
+
ph_idx = 0
|
284 |
+
tg_align = [x for x in tg['tiers'][-1]['items']]
|
285 |
+
tg_align_ = []
|
286 |
+
for x in tg_align:
|
287 |
+
x['xmin'] = float(x['xmin'])
|
288 |
+
x['xmax'] = float(x['xmax'])
|
289 |
+
if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
|
290 |
+
x['text'] = ''
|
291 |
+
if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
|
292 |
+
tg_align_[-1]['xmax'] = x['xmax']
|
293 |
+
continue
|
294 |
+
tg_align_.append(x)
|
295 |
+
tg_align = tg_align_
|
296 |
+
tg_len = len([x for x in tg_align if x['text'] != ''])
|
297 |
+
ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
|
298 |
+
assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
|
299 |
+
while tg_idx < len(tg_align) or ph_idx < len(ph_list):
|
300 |
+
if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
|
301 |
+
split[ph_idx] = 1e8
|
302 |
+
ph_idx += 1
|
303 |
+
continue
|
304 |
+
x = tg_align[tg_idx]
|
305 |
+
if x['text'] == '' and ph_idx == len(ph_list):
|
306 |
+
tg_idx += 1
|
307 |
+
continue
|
308 |
+
assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
|
309 |
+
ph = ph_list[ph_idx]
|
310 |
+
if x['text'] == '' and not is_sil_phoneme(ph):
|
311 |
+
assert False, (ph_list, tg_align)
|
312 |
+
if x['text'] != '' and is_sil_phoneme(ph):
|
313 |
+
ph_idx += 1
|
314 |
+
else:
|
315 |
+
assert (x['text'] == '' and is_sil_phoneme(ph)) \
|
316 |
+
or x['text'].lower() == ph.lower() \
|
317 |
+
or x['text'].lower() == 'sil', (x['text'], ph)
|
318 |
+
split[ph_idx] = x['xmin']
|
319 |
+
if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
|
320 |
+
split[ph_idx - 1] = split[ph_idx]
|
321 |
+
ph_idx += 1
|
322 |
+
tg_idx += 1
|
323 |
+
assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
|
324 |
+
assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
|
325 |
+
mel2ph = np.zeros([mel.shape[0]], np.int)
|
326 |
+
split[0] = 0
|
327 |
+
split[-1] = 1e8
|
328 |
+
for i in range(len(split) - 1):
|
329 |
+
assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
|
330 |
+
split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
|
331 |
+
for ph_idx in range(len(ph_list)):
|
332 |
+
mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
|
333 |
+
mel2ph_torch = torch.from_numpy(mel2ph)
|
334 |
+
T_t = len(ph_list)
|
335 |
+
dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
|
336 |
+
dur = dur[1:].numpy()
|
337 |
+
return mel2ph, dur
|
338 |
+
|
339 |
+
|
340 |
+
def build_phone_encoder(data_dir):
|
341 |
+
phone_list_file = os.path.join(data_dir, 'phone_set.json')
|
342 |
+
phone_list = json.load(open(phone_list_file))
|
343 |
+
return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
|
344 |
+
|
345 |
+
|
346 |
+
def is_sil_phoneme(p):
|
347 |
+
return not p[0].isalpha()
|
data_gen/tts/txt_processors/base_text_processor.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseTxtProcessor:
|
2 |
+
@staticmethod
|
3 |
+
def sp_phonemes():
|
4 |
+
return ['|']
|
5 |
+
|
6 |
+
@classmethod
|
7 |
+
def process(cls, txt, pre_align_args):
|
8 |
+
raise NotImplementedError
|
data_gen/tts/txt_processors/en.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from data_gen.tts.data_gen_utils import PUNCS
|
3 |
+
from g2p_en import G2p
|
4 |
+
import unicodedata
|
5 |
+
from g2p_en.expand import normalize_numbers
|
6 |
+
from nltk import pos_tag
|
7 |
+
from nltk.tokenize import TweetTokenizer
|
8 |
+
|
9 |
+
from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
|
10 |
+
|
11 |
+
|
12 |
+
class EnG2p(G2p):
|
13 |
+
word_tokenize = TweetTokenizer().tokenize
|
14 |
+
|
15 |
+
def __call__(self, text):
|
16 |
+
# preprocessing
|
17 |
+
words = EnG2p.word_tokenize(text)
|
18 |
+
tokens = pos_tag(words) # tuples of (word, tag)
|
19 |
+
|
20 |
+
# steps
|
21 |
+
prons = []
|
22 |
+
for word, pos in tokens:
|
23 |
+
if re.search("[a-z]", word) is None:
|
24 |
+
pron = [word]
|
25 |
+
|
26 |
+
elif word in self.homograph2features: # Check homograph
|
27 |
+
pron1, pron2, pos1 = self.homograph2features[word]
|
28 |
+
if pos.startswith(pos1):
|
29 |
+
pron = pron1
|
30 |
+
else:
|
31 |
+
pron = pron2
|
32 |
+
elif word in self.cmu: # lookup CMU dict
|
33 |
+
pron = self.cmu[word][0]
|
34 |
+
else: # predict for oov
|
35 |
+
pron = self.predict(word)
|
36 |
+
|
37 |
+
prons.extend(pron)
|
38 |
+
prons.extend([" "])
|
39 |
+
|
40 |
+
return prons[:-1]
|
41 |
+
|
42 |
+
|
43 |
+
class TxtProcessor(BaseTxtProcessor):
|
44 |
+
g2p = EnG2p()
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def preprocess_text(text):
|
48 |
+
text = normalize_numbers(text)
|
49 |
+
text = ''.join(char for char in unicodedata.normalize('NFD', text)
|
50 |
+
if unicodedata.category(char) != 'Mn') # Strip accents
|
51 |
+
text = text.lower()
|
52 |
+
text = re.sub("[\'\"()]+", "", text)
|
53 |
+
text = re.sub("[-]+", " ", text)
|
54 |
+
text = re.sub(f"[^ a-z{PUNCS}]", "", text)
|
55 |
+
text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> !
|
56 |
+
text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
|
57 |
+
text = text.replace("i.e.", "that is")
|
58 |
+
text = text.replace("i.e.", "that is")
|
59 |
+
text = text.replace("etc.", "etc")
|
60 |
+
text = re.sub(f"([{PUNCS}])", r" \1 ", text)
|
61 |
+
text = re.sub(rf"\s+", r" ", text)
|
62 |
+
return text
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def process(cls, txt, pre_align_args):
|
66 |
+
txt = cls.preprocess_text(txt).strip()
|
67 |
+
phs = cls.g2p(txt)
|
68 |
+
phs_ = []
|
69 |
+
n_word_sep = 0
|
70 |
+
for p in phs:
|
71 |
+
if p.strip() == '':
|
72 |
+
phs_ += ['|']
|
73 |
+
n_word_sep += 1
|
74 |
+
else:
|
75 |
+
phs_ += p.split(" ")
|
76 |
+
phs = phs_
|
77 |
+
assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"")
|
78 |
+
return phs, txt
|
data_gen/tts/txt_processors/zh.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from pypinyin import pinyin, Style
|
3 |
+
from data_gen.tts.data_gen_utils import PUNCS
|
4 |
+
from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
|
5 |
+
from utils.text_norm import NSWNormalizer
|
6 |
+
|
7 |
+
|
8 |
+
class TxtProcessor(BaseTxtProcessor):
|
9 |
+
table = {ord(f): ord(t) for f, t in zip(
|
10 |
+
u':,。!?【】()%#@&1234567890',
|
11 |
+
u':,.!?[]()%#@&1234567890')}
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def preprocess_text(text):
|
15 |
+
text = text.translate(TxtProcessor.table)
|
16 |
+
text = NSWNormalizer(text).normalize(remove_punc=False)
|
17 |
+
text = re.sub("[\'\"()]+", "", text)
|
18 |
+
text = re.sub("[-]+", " ", text)
|
19 |
+
text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text)
|
20 |
+
text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
|
21 |
+
text = re.sub(f"([{PUNCS}])", r" \1 ", text)
|
22 |
+
text = re.sub(rf"\s+", r"", text)
|
23 |
+
return text
|
24 |
+
|
25 |
+
@classmethod
|
26 |
+
def process(cls, txt, pre_align_args):
|
27 |
+
txt = cls.preprocess_text(txt)
|
28 |
+
shengmu = pinyin(txt, style=Style.INITIALS) # https://blog.csdn.net/zhoulei124/article/details/89055403
|
29 |
+
yunmu_finals = pinyin(txt, style=Style.FINALS)
|
30 |
+
yunmu_tone3 = pinyin(txt, style=Style.FINALS_TONE3)
|
31 |
+
yunmu = [[t[0] + '5'] if t[0] == f[0] else t for f, t in zip(yunmu_finals, yunmu_tone3)] \
|
32 |
+
if pre_align_args['use_tone'] else yunmu_finals
|
33 |
+
|
34 |
+
assert len(shengmu) == len(yunmu)
|
35 |
+
phs = ["|"]
|
36 |
+
for a, b, c in zip(shengmu, yunmu, yunmu_finals):
|
37 |
+
if a[0] == c[0]:
|
38 |
+
phs += [a[0], "|"]
|
39 |
+
else:
|
40 |
+
phs += [a[0], b[0], "|"]
|
41 |
+
return phs, txt
|
data_gen/tts/txt_processors/zh_g2pM.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import jieba
|
3 |
+
from pypinyin import pinyin, Style
|
4 |
+
from data_gen.tts.data_gen_utils import PUNCS
|
5 |
+
from data_gen.tts.txt_processors import zh
|
6 |
+
from g2pM import G2pM
|
7 |
+
|
8 |
+
ALL_SHENMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j',
|
9 |
+
'q', 'x', 'r', 'z', 'c', 's', 'y', 'w']
|
10 |
+
ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian',
|
11 |
+
'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong', 'ou',
|
12 |
+
'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn']
|
13 |
+
|
14 |
+
|
15 |
+
class TxtProcessor(zh.TxtProcessor):
|
16 |
+
model = G2pM()
|
17 |
+
|
18 |
+
@staticmethod
|
19 |
+
def sp_phonemes():
|
20 |
+
return ['|', '#']
|
21 |
+
|
22 |
+
@classmethod
|
23 |
+
def process(cls, txt, pre_align_args):
|
24 |
+
txt = cls.preprocess_text(txt)
|
25 |
+
ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True)
|
26 |
+
seg_list = '#'.join(jieba.cut(txt))
|
27 |
+
assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list)
|
28 |
+
|
29 |
+
# 加入词边界'#'
|
30 |
+
ph_list_ = []
|
31 |
+
seg_idx = 0
|
32 |
+
for p in ph_list:
|
33 |
+
p = p.replace("u:", "v")
|
34 |
+
if seg_list[seg_idx] == '#':
|
35 |
+
ph_list_.append('#')
|
36 |
+
seg_idx += 1
|
37 |
+
else:
|
38 |
+
ph_list_.append("|")
|
39 |
+
seg_idx += 1
|
40 |
+
if re.findall('[\u4e00-\u9fff]', p):
|
41 |
+
if pre_align_args['use_tone']:
|
42 |
+
p = pinyin(p, style=Style.TONE3, strict=True)[0][0]
|
43 |
+
if p[-1] not in ['1', '2', '3', '4', '5']:
|
44 |
+
p = p + '5'
|
45 |
+
else:
|
46 |
+
p = pinyin(p, style=Style.NORMAL, strict=True)[0][0]
|
47 |
+
|
48 |
+
finished = False
|
49 |
+
if len([c.isalpha() for c in p]) > 1:
|
50 |
+
for shenmu in ALL_SHENMU:
|
51 |
+
if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric():
|
52 |
+
ph_list_ += [shenmu, p.lstrip(shenmu)]
|
53 |
+
finished = True
|
54 |
+
break
|
55 |
+
if not finished:
|
56 |
+
ph_list_.append(p)
|
57 |
+
|
58 |
+
ph_list = ph_list_
|
59 |
+
|
60 |
+
# 去除静音符号周围的词边界标记 [..., '#', ',', '#', ...]
|
61 |
+
sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes()
|
62 |
+
ph_list_ = []
|
63 |
+
for i in range(0, len(ph_list), 1):
|
64 |
+
if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes):
|
65 |
+
ph_list_.append(ph_list[i])
|
66 |
+
ph_list = ph_list_
|
67 |
+
return ph_list, txt
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == '__main__':
|
71 |
+
phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True})
|
72 |
+
print(phs)
|
docs/README-SVS-opencpop-cascade.md
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
|
2 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
+
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
+
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
|
6 |
+
## DiffSinger (MIDI version SVS)
|
7 |
+
### 0. Data Acquirement
|
8 |
+
For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
|
9 |
+
|
10 |
+
The pipeline below is designed for Opencpop dataset:
|
11 |
+
|
12 |
+
### 1. Preparation
|
13 |
+
|
14 |
+
#### Data Preparation
|
15 |
+
a) Download and extract Opencpop, then create a link to the dataset folder: `ln -s /xxx/opencpop data/raw/`
|
16 |
+
|
17 |
+
b) Run the following scripts to pack the dataset for training/inference.
|
18 |
+
|
19 |
+
```sh
|
20 |
+
export PYTHONPATH=.
|
21 |
+
CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/midi/cascade/opencs/aux_rel.yaml
|
22 |
+
|
23 |
+
# `data/binary/opencpop-midi-dp` will be generated.
|
24 |
+
```
|
25 |
+
|
26 |
+
#### Vocoder Preparation
|
27 |
+
We provide the pre-trained model of [HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip) which is specially designed for SVS with NSF mechanism.
|
28 |
+
Please unzip this file into `checkpoints` before training your acoustic model.
|
29 |
+
|
30 |
+
(Update: You can also move [a ckpt with more training steps](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt) into this vocoder directory)
|
31 |
+
|
32 |
+
This singing vocoder is trained on ~70 hours singing data, which can be viewed as a universal vocoder.
|
33 |
+
|
34 |
+
#### Exp Name Preparation
|
35 |
+
```bash
|
36 |
+
export MY_FS_EXP_NAME=0302_opencpop_fs_midi
|
37 |
+
export MY_DS_EXP_NAME=0303_opencpop_ds58_midi
|
38 |
+
```
|
39 |
+
|
40 |
+
```
|
41 |
+
.
|
42 |
+
|--data
|
43 |
+
|--raw
|
44 |
+
|--opencpop
|
45 |
+
|--segments
|
46 |
+
|--transcriptions.txt
|
47 |
+
|--wavs
|
48 |
+
|--checkpoints
|
49 |
+
|--MY_FS_EXP_NAME (optional)
|
50 |
+
|--MY_DS_EXP_NAME (optional)
|
51 |
+
|--0109_hifigan_bigpopcs_hop128
|
52 |
+
|--model_ckpt_steps_1512000.ckpt
|
53 |
+
|--config.yaml
|
54 |
+
```
|
55 |
+
|
56 |
+
### 2. Training Example
|
57 |
+
First, you need a pre-trained FFT-Singer checkpoint. You can use the pre-trained model, or train FFT-Singer from scratch, run:
|
58 |
+
```sh
|
59 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/aux_rel.yaml --exp_name $MY_FS_EXP_NAME --reset
|
60 |
+
```
|
61 |
+
|
62 |
+
Then, to train DiffSinger, run:
|
63 |
+
|
64 |
+
```sh
|
65 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset
|
66 |
+
```
|
67 |
+
|
68 |
+
Remember to adjust the "fs2_ckpt" parameter in `usr/configs/midi/cascade/opencs/ds60_rel.yaml` to fit your path.
|
69 |
+
|
70 |
+
### 3. Inference Example
|
71 |
+
```sh
|
72 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
|
73 |
+
```
|
74 |
+
|
75 |
+
We also provide:
|
76 |
+
- the pre-trained model of DiffSinger;
|
77 |
+
- the pre-trained model of FFT-Singer;
|
78 |
+
|
79 |
+
They can be found in [here](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/adjust-receptive-field.zip).
|
80 |
+
|
81 |
+
Remember to put the pre-trained models in `checkpoints` directory.
|
82 |
+
|
83 |
+
### 4. Inference from raw inputs
|
84 |
+
```sh
|
85 |
+
python inference/svs/ds_e2e.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME
|
86 |
+
```
|
87 |
+
Raw inputs:
|
88 |
+
```
|
89 |
+
inp = {
|
90 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
91 |
+
'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
|
92 |
+
'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
|
93 |
+
'input_type': 'word'
|
94 |
+
} # user input: Chinese characters
|
95 |
+
or,
|
96 |
+
inp = {
|
97 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
98 |
+
'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
99 |
+
'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
100 |
+
'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
|
101 |
+
'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
|
102 |
+
'input_type': 'phoneme'
|
103 |
+
} # input like Opencpop dataset.
|
104 |
+
```
|
105 |
+
|
106 |
+
### 5. Some issues.
|
107 |
+
a) the HifiGAN-Singing is trained on our [vocoder dataset](https://dl.acm.org/doi/abs/10.1145/3474085.3475437) and the training set of [PopCS](https://arxiv.org/abs/2105.02446). Opencpop is the out-of-domain dataset (unseen speaker). This may cause the deterioration of audio quality, and we are considering fine-tuning this vocoder on the training set of Opencpop.
|
108 |
+
|
109 |
+
b) in this version of codes, we used the melody frontend ([lyric + MIDI]->[F0+ph_dur]) to predict F0 contour and phoneme duration.
|
110 |
+
|
111 |
+
c) generated audio demos can be found in [MY_DS_EXP_NAME](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/adjust-receptive-field.zip).
|
docs/README-SVS-opencpop-e2e.md
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
|
2 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
+
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
+
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
|
6 |
+
Substantial update: We 1) **abandon** the explicit prediction of the F0 curve; 2) increase the receptive field of the denoiser; 3) make the linguistic encoder more robust.
|
7 |
+
**By doing so, 1) the synthesized recordings are more natural in terms of pitch; 2) the pipeline is simpler.**
|
8 |
+
|
9 |
+
简而言之,把F0曲线的动态性交给生成式模型去捕捉,而不再是以前那样用MSE约束对数域F0。
|
10 |
+
|
11 |
+
## DiffSinger (MIDI version SVS)
|
12 |
+
### 0. Data Acquirement
|
13 |
+
For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
|
14 |
+
|
15 |
+
The pipeline below is designed for Opencpop dataset:
|
16 |
+
|
17 |
+
### 1. Preparation
|
18 |
+
|
19 |
+
#### Data Preparation
|
20 |
+
a) Download and extract Opencpop, then create a link to the dataset folder: `ln -s /xxx/opencpop data/raw/`
|
21 |
+
|
22 |
+
b) Run the following scripts to pack the dataset for training/inference.
|
23 |
+
|
24 |
+
```sh
|
25 |
+
export PYTHONPATH=.
|
26 |
+
CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/midi/cascade/opencs/aux_rel.yaml
|
27 |
+
|
28 |
+
# `data/binary/opencpop-midi-dp` will be generated.
|
29 |
+
```
|
30 |
+
|
31 |
+
#### Vocoder Preparation
|
32 |
+
We provide the pre-trained model of [HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip) which is specially designed for SVS with NSF mechanism.
|
33 |
+
|
34 |
+
Also, please unzip pre-trained vocoder and [this pendant for vocoder](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0102_xiaoma_pe.zip) into `checkpoints` before training your acoustic model.
|
35 |
+
|
36 |
+
(Update: You can also move [a ckpt with more training steps](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt) into this vocoder directory)
|
37 |
+
|
38 |
+
This singing vocoder is trained on ~70 hours singing data, which can be viewed as a universal vocoder.
|
39 |
+
|
40 |
+
#### Exp Name Preparation
|
41 |
+
```bash
|
42 |
+
export MY_DS_EXP_NAME=0228_opencpop_ds100_rel
|
43 |
+
```
|
44 |
+
|
45 |
+
```
|
46 |
+
.
|
47 |
+
|--data
|
48 |
+
|--raw
|
49 |
+
|--opencpop
|
50 |
+
|--segments
|
51 |
+
|--transcriptions.txt
|
52 |
+
|--wavs
|
53 |
+
|--checkpoints
|
54 |
+
|--MY_DS_EXP_NAME (optional)
|
55 |
+
|--0109_hifigan_bigpopcs_hop128 (vocoder)
|
56 |
+
|--model_ckpt_steps_1512000.ckpt
|
57 |
+
|--config.yaml
|
58 |
+
```
|
59 |
+
|
60 |
+
### 2. Training Example
|
61 |
+
```sh
|
62 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name $MY_DS_EXP_NAME --reset
|
63 |
+
```
|
64 |
+
|
65 |
+
### 3. Inference from packed test set
|
66 |
+
```sh
|
67 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
|
68 |
+
```
|
69 |
+
|
70 |
+
We also provide:
|
71 |
+
- the pre-trained model of DiffSinger;
|
72 |
+
|
73 |
+
They can be found in [here](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0228_opencpop_ds100_rel.zip).
|
74 |
+
|
75 |
+
Remember to put the pre-trained models in `checkpoints` directory.
|
76 |
+
|
77 |
+
### 4. Inference from raw inputs
|
78 |
+
```sh
|
79 |
+
python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name $MY_DS_EXP_NAME
|
80 |
+
```
|
81 |
+
Raw inputs:
|
82 |
+
```
|
83 |
+
inp = {
|
84 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
85 |
+
'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
|
86 |
+
'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
|
87 |
+
'input_type': 'word'
|
88 |
+
} # user input: Chinese characters
|
89 |
+
or,
|
90 |
+
inp = {
|
91 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
92 |
+
'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
93 |
+
'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
94 |
+
'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
|
95 |
+
'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
|
96 |
+
'input_type': 'phoneme'
|
97 |
+
} # input like Opencpop dataset.
|
98 |
+
```
|
99 |
+
|
100 |
+
### 5. Some issues.
|
101 |
+
a) the HifiGAN-Singing is trained on our [vocoder dataset](https://dl.acm.org/doi/abs/10.1145/3474085.3475437) and the training set of [PopCS](https://arxiv.org/abs/2105.02446). Opencpop is the out-of-domain dataset (unseen speaker). This may cause the deterioration of audio quality, and we are considering fine-tuning this vocoder on the training set of Opencpop.
|
102 |
+
|
103 |
+
b) in this version of codes, we used the melody frontend ([lyric + MIDI]->[ph_dur]) to predict phoneme duration. F0 curve is implicitly predicted together with mel-spectrogram.
|
104 |
+
|
105 |
+
c) example [generated audio](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/demos_0221/DS/).
|
106 |
+
More generated audio demos can be found in [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0228_opencpop_ds100_rel.zip).
|
docs/README-SVS-popcs.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## DiffSinger (SVS version)
|
2 |
+
|
3 |
+
### 0. Data Acquirement
|
4 |
+
- See in [apply_form](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md).
|
5 |
+
- Dataset [preview](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_preview.zip).
|
6 |
+
|
7 |
+
### 1. Preparation
|
8 |
+
#### Data Preparation
|
9 |
+
a) Download and extract PopCS, then create a link to the dataset folder: `ln -s /xxx/popcs/ data/processed/popcs`
|
10 |
+
|
11 |
+
b) Run the following scripts to pack the dataset for training/inference.
|
12 |
+
```sh
|
13 |
+
export PYTHONPATH=.
|
14 |
+
CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/popcs_ds_beta6.yaml
|
15 |
+
# `data/binary/popcs-pmf0` will be generated.
|
16 |
+
```
|
17 |
+
|
18 |
+
#### Vocoder Preparation
|
19 |
+
We provide the pre-trained model of [HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip) which is specially designed for SVS with NSF mechanism.
|
20 |
+
Please unzip this file into `checkpoints` before training your acoustic model.
|
21 |
+
|
22 |
+
(Update: You can also move [a ckpt with more training steps](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt) into this vocoder directory)
|
23 |
+
|
24 |
+
This singing vocoder is trained on ~70 hours singing data, which can be viewed as a universal vocoder.
|
25 |
+
|
26 |
+
### 2. Training Example
|
27 |
+
First, you need a pre-trained FFT-Singer checkpoint. You can use the [pre-trained model](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip), or train FFT-Singer from scratch, run:
|
28 |
+
|
29 |
+
```sh
|
30 |
+
# First, train fft-singer;
|
31 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset
|
32 |
+
# Then, infer fft-singer;
|
33 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset --infer
|
34 |
+
```
|
35 |
+
|
36 |
+
Then, to train DiffSinger, run:
|
37 |
+
```sh
|
38 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset
|
39 |
+
```
|
40 |
+
|
41 |
+
Remember to adjust the "fs2_ckpt" parameter in `usr/configs/popcs_ds_beta6_offline.yaml` to fit your path.
|
42 |
+
|
43 |
+
### 3. Inference Example
|
44 |
+
```sh
|
45 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset --infer
|
46 |
+
```
|
47 |
+
|
48 |
+
We also provide:
|
49 |
+
- the pre-trained model of [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_ds_beta6_offline_pmf0_1230.zip);
|
50 |
+
- the pre-trained model of [FFT-Singer](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip) for the shallow diffusion mechanism in DiffSinger;
|
51 |
+
|
52 |
+
Remember to put the pre-trained models in `checkpoints` directory.
|
53 |
+
|
54 |
+
*Note that:*
|
55 |
+
|
56 |
+
- *the original PWG version vocoder in the paper we used has been put into commercial use, so we provide this HifiGAN version vocoder as a substitute.*
|
57 |
+
- *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-old-version](README-SVS-opencpop-cascade.md)) or a joint prediction with spectrograms(like [MIDI-new-version](README-SVS-opencpop-e2e.md)).*
|
58 |
+
|
59 |
+
[1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
|
60 |
+
|
61 |
+
[2] SEQUENCE-TO-SEQUENCE SINGING SYNTHESIS USING THE FEED-FORWARD TRANSFORMER. ICASSP 2020.
|
62 |
+
|
63 |
+
[3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
|
docs/README-SVS.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## DiffSinger (SVS version)
|
2 |
+
|
3 |
+
### PART1. [Run DiffSinger on PopCS](README-SVS-popcs.md)
|
4 |
+
In this part, we only focus on spectrum modeling (acoustic model) and assume the ground-truth (GT) F0 to be given as the pitch information following these papers [1][2][3].
|
5 |
+
|
6 |
+
Thus, the pipeline of this part can be summarized as:
|
7 |
+
|
8 |
+
```
|
9 |
+
[lyrics] -> [linguistic representation] (Frontend)
|
10 |
+
[linguistic representation] + [GT F0] + [GT phoneme duration] -> [mel-spectrogram] (Acoustic model)
|
11 |
+
[mel-spectrogram] + [GT F0] -> [waveform] (Vocoder)
|
12 |
+
```
|
13 |
+
|
14 |
+
|
15 |
+
[1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
|
16 |
+
|
17 |
+
[2] SEQUENCE-TO-SEQUENCE SINGING SYNTHESIS USING THE FEED-FORWARD TRANSFORMER. ICASSP 2020.
|
18 |
+
|
19 |
+
[3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
|
20 |
+
|
21 |
+
### PART2. [Run DiffSinger on Opencpop](README-SVS-opencpop-cascade.md)
|
22 |
+
Thanks [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI label, **Jan.20, 2022**. (Also thanks to my co-author [Yi Ren](https://github.com/RayeRen), who applied for the dataset and did some preprocessing works for this part).
|
23 |
+
|
24 |
+
Since there are elaborately annotated MIDI labels, we are able to supplement the pipeline in PART 1 by adding a naive melody frontend.
|
25 |
+
|
26 |
+
#### 2.1
|
27 |
+
Thus, the pipeline of [this part](README-SVS-opencpop-cascade.md) can be summarized as:
|
28 |
+
|
29 |
+
```
|
30 |
+
[lyrics] + [MIDI] -> [linguistic representation (with MIDI information)] + [predicted F0] + [predicted phoneme duration] (Melody frontend)
|
31 |
+
[linguistic representation] + [predicted F0] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
|
32 |
+
[mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
|
33 |
+
```
|
34 |
+
|
35 |
+
#### 2.2
|
36 |
+
In 2.1, we find that if we predict F0 explicitly in the melody frontend, there will be many bad cases of uv/v prediction. Then, we abandon the explicit prediction of the F0 curve in the melody frontend but make a joint prediction with spectrograms.
|
37 |
+
|
38 |
+
Thus, the pipeline of [this part](README-SVS-opencpop-e2e.md) can be summarized as:
|
39 |
+
```
|
40 |
+
[lyrics] + [MIDI] -> [linguistic representation] + [predicted phoneme duration] (Melody frontend)
|
41 |
+
[linguistic representation (with MIDI information)] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
|
42 |
+
[mel-spectrogram] -> [predicted F0] (Pitch extractor)
|
43 |
+
[mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
|
44 |
+
```
|
docs/README-TTS.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## DiffSpeech (TTS version)
|
2 |
+
### 1. Preparation
|
3 |
+
|
4 |
+
#### Data Preparation
|
5 |
+
a) Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), then create a link to the dataset folder: `ln -s /xxx/LJSpeech-1.1/ data/raw/`
|
6 |
+
|
7 |
+
b) Download and Unzip the [ground-truth duration](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/mfa_outputs.tar) extracted by [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz): `tar -xvf mfa_outputs.tar; mv mfa_outputs data/processed/ljspeech/`
|
8 |
+
|
9 |
+
c) Run the following scripts to pack the dataset for training/inference.
|
10 |
+
|
11 |
+
```sh
|
12 |
+
export PYTHONPATH=.
|
13 |
+
CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config configs/tts/lj/fs2.yaml
|
14 |
+
|
15 |
+
# `data/binary/ljspeech` will be generated.
|
16 |
+
```
|
17 |
+
|
18 |
+
#### Vocoder Preparation
|
19 |
+
We provide the pre-trained model of [HifiGAN](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0414_hifi_lj_1.zip) vocoder.
|
20 |
+
Please unzip this file into `checkpoints` before training your acoustic model.
|
21 |
+
|
22 |
+
### 2. Training Example
|
23 |
+
|
24 |
+
First, you need a pre-trained FastSpeech2 checkpoint. You can use the [pre-trained model](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip), or train FastSpeech2 from scratch, run:
|
25 |
+
```sh
|
26 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config configs/tts/lj/fs2.yaml --exp_name fs2_lj_1 --reset
|
27 |
+
```
|
28 |
+
Then, to train DiffSpeech, run:
|
29 |
+
```sh
|
30 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset
|
31 |
+
```
|
32 |
+
|
33 |
+
Remember to adjust the "fs2_ckpt" parameter in `usr/configs/lj_ds_beta6.yaml` to fit your path.
|
34 |
+
|
35 |
+
### 3. Inference Example
|
36 |
+
|
37 |
+
```sh
|
38 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset --infer
|
39 |
+
```
|
40 |
+
|
41 |
+
We also provide:
|
42 |
+
- the pre-trained model of [DiffSpeech](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/lj_ds_beta6_1213.zip);
|
43 |
+
- the individual pre-trained model of [FastSpeech 2](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip) for the shallow diffusion mechanism in DiffSpeech;
|
44 |
+
|
45 |
+
Remember to put the pre-trained models in `checkpoints` directory.
|
46 |
+
|
47 |
+
## Mel Visualization
|
48 |
+
Along vertical axis, DiffSpeech: [0-80]; FastSpeech2: [80-160].
|
49 |
+
|
50 |
+
<table style="width:100%">
|
51 |
+
<tr>
|
52 |
+
<th>DiffSpeech vs. FastSpeech 2</th>
|
53 |
+
</tr>
|
54 |
+
<tr>
|
55 |
+
<td><img src="resources/diffspeech-fs2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
|
56 |
+
</tr>
|
57 |
+
<tr>
|
58 |
+
<td><img src="resources/diffspeech-fs2-1.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
|
59 |
+
</tr>
|
60 |
+
<tr>
|
61 |
+
<td><img src="resources/diffspeech-fs2-2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
|
62 |
+
</tr>
|
63 |
+
</table>
|
docs/README-zh.md
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
|
2 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
+
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
+
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
| [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
|
6 |
+
| [English README](../README.md)
|
7 |
+
|
8 |
+
本仓库包含了我们的AAAI-2022 [论文](https://arxiv.org/abs/2105.02446)中提出的DiffSpeech (用于语音合成) 与 DiffSinger (用于歌声合成) 的官方Pytorch实现。
|
9 |
+
|
10 |
+
<table style="width:100%">
|
11 |
+
<tr>
|
12 |
+
<th>DiffSinger/DiffSpeech训练阶段</th>
|
13 |
+
<th>DiffSinger/DiffSpeech推理阶段</th>
|
14 |
+
</tr>
|
15 |
+
<tr>
|
16 |
+
<td><img src="resources/model_a.png" alt="Training" height="300"></td>
|
17 |
+
<td><img src="resources/model_b.png" alt="Inference" height="300"></td>
|
18 |
+
</tr>
|
19 |
+
</table>
|
20 |
+
|
21 |
+
:tada: :tada: :tada: **一些重要更新**:
|
22 |
+
- Mar.2, 2022: [MIDI-新版](README-SVS-opencpop-e2e.md): 重大更新 :sparkles:
|
23 |
+
- Mar.1, 2022: [NeuralSVB](https://github.com/MoonInTheRiver/NeuralSVB), 为了歌声美化任务的代码,开源了 :sparkles: :sparkles: :sparkles: .
|
24 |
+
- Feb.13, 2022: [NATSpeech](https://github.com/NATSpeech/NATSpeech), 一个升级后的代码框架, 包含了DiffSpeech和我们NeurIPS-2021的工作[PortaSpeech](https://openreview.net/forum?id=xmJsuh8xlq) 已经开源! :sparkles: :sparkles: :sparkles:.
|
25 |
+
- Jan.29, 2022: 支持了[MIDI-旧版](README-SVS-opencpop-cascade.md) 版本的歌声合成系统.
|
26 |
+
- Jan.13, 2022: 支持了歌声合成系统, 开源了PopCS数据集.
|
27 |
+
- Dec.19, 2021: 支持了语音合成系统. [HuggingFace🤗 Demo](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
|
28 |
+
|
29 |
+
:rocket: **新闻**:
|
30 |
+
- Feb.24, 2022: 我们的新工作`NeuralSVB` 被 ACL-2022 接收 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2202.13277). [音频演示](https://neuralsvb.github.io).
|
31 |
+
- Dec.01, 2021: DiffSinger被AAAI-2022接收.
|
32 |
+
- Sep.29, 2021: 我们的新工作`PortaSpeech: Portable and High-Quality Generative Text-to-Speech` 被NeurIPS-2021接收 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2109.15166) .
|
33 |
+
- May.06, 2021: 我们把这篇DiffSinger提交到了公开论文网站: Arxiv [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446).
|
34 |
+
|
35 |
+
## 安装依赖
|
36 |
+
```sh
|
37 |
+
conda create -n your_env_name python=3.8
|
38 |
+
source activate your_env_name
|
39 |
+
pip install -r requirements_2080.txt (GPU 2080Ti, CUDA 10.2)
|
40 |
+
or pip install -r requirements_3090.txt (GPU 3090, CUDA 11.4)
|
41 |
+
```
|
42 |
+
|
43 |
+
## DiffSpeech (语音合成的版本)
|
44 |
+
### 1. 准备工作
|
45 |
+
|
46 |
+
#### 数据准备
|
47 |
+
a) 下载并解压 [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), 创建软链接: `ln -s /xxx/LJSpeech-1.1/ data/raw/`
|
48 |
+
|
49 |
+
b) 下载并解压 [我们用MFA预处理好的对齐](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/mfa_outputs.tar): `tar -xvf mfa_outputs.tar; mv mfa_outputs data/processed/ljspeech/`
|
50 |
+
|
51 |
+
c) 按照如下脚本给数据集打包,打包后的二进制文件用于后续的训练和推理.
|
52 |
+
|
53 |
+
```sh
|
54 |
+
export PYTHONPATH=.
|
55 |
+
CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config configs/tts/lj/fs2.yaml
|
56 |
+
|
57 |
+
# `data/binary/ljspeech` will be generated.
|
58 |
+
```
|
59 |
+
|
60 |
+
#### 声码器准备
|
61 |
+
我们提供了[HifiGAN](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0414_hifi_lj_1.zip)声码器的预训练模型.
|
62 |
+
请在训练声学模型前,先把声码器文件解压到`checkpoints`里。
|
63 |
+
|
64 |
+
### 2. 训练样例
|
65 |
+
|
66 |
+
首先你需要一个预训练好的FastSpeech2存档点. 你可以用[我们预训练好的模型](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip), 或者跑下面这个指令从零开始训练FastSpeech2:
|
67 |
+
```sh
|
68 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config configs/tts/lj/fs2.yaml --exp_name fs2_lj_1 --reset
|
69 |
+
```
|
70 |
+
然后为了训练DiffSpeech, 运行:
|
71 |
+
```sh
|
72 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset
|
73 |
+
```
|
74 |
+
|
75 |
+
记得针对你的路径修改`usr/configs/lj_ds_beta6.yaml`里"fs2_ckpt"这个参数.
|
76 |
+
|
77 |
+
### 3. 推理样例
|
78 |
+
|
79 |
+
```sh
|
80 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset --infer
|
81 |
+
```
|
82 |
+
|
83 |
+
我们也提供了:
|
84 |
+
- [DiffSpeech](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/lj_ds_beta6_1213.zip)的预训练模型;
|
85 |
+
- [FastSpeech 2](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip)的预训练模型, 这是为了DiffSpeech里的浅扩散机制;
|
86 |
+
|
87 |
+
记得把预训练模型放在 `checkpoints` 目录.
|
88 |
+
|
89 |
+
## DiffSinger (歌声合成的版本)
|
90 |
+
|
91 |
+
### 0. 数据获取
|
92 |
+
- 见 [申请表](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md).
|
93 |
+
- 数据集 [预览](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_preview.zip).
|
94 |
+
|
95 |
+
### 1. Preparation
|
96 |
+
#### 数据准备
|
97 |
+
a) 下载并解压PopCS, 创建软链接: `ln -s /xxx/popcs/ data/processed/popcs`
|
98 |
+
|
99 |
+
b) 按照如下脚本给数据集打包,打包后的二进制文件用于后续的训练和推理.
|
100 |
+
```sh
|
101 |
+
export PYTHONPATH=.
|
102 |
+
CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/popcs_ds_beta6.yaml
|
103 |
+
# `data/binary/popcs-pmf0` 会生成出来.
|
104 |
+
```
|
105 |
+
|
106 |
+
#### 声码器准备
|
107 |
+
我们提供了[HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip)的预训练模型, 它专门为了歌声合成系统设计, 采用了NSF的技术。
|
108 |
+
请在训练声学模型前,先把声码器文件解压到`checkpoints`里。
|
109 |
+
|
110 |
+
(更新: 你也可以将我们提供的[训练更多步数的存档点](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt)放到声码器的文件夹里)
|
111 |
+
|
112 |
+
这个声码器是在大约70小时的较大数据集上训练的, 可以被认为是一个通用声码器。
|
113 |
+
|
114 |
+
### 2. 训练样例
|
115 |
+
首先你需要一个预训练好的FFT-Singer. 你可以用[我们预训练好的模型](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip), 或者用如下脚本从零训练FFT-Singer:
|
116 |
+
|
117 |
+
```sh
|
118 |
+
# First, train fft-singer;
|
119 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset
|
120 |
+
# Then, infer fft-singer;
|
121 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset --infer
|
122 |
+
```
|
123 |
+
|
124 |
+
然后, 为了训练DiffSinger, 运行:
|
125 |
+
```sh
|
126 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset
|
127 |
+
```
|
128 |
+
|
129 |
+
记得针对你的路径修改`usr/configs/popcs_ds_beta6_offline.yaml`里"fs2_ckpt"这个参数.
|
130 |
+
|
131 |
+
### 3. 推理样例
|
132 |
+
```sh
|
133 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset --infer
|
134 |
+
```
|
135 |
+
|
136 |
+
我们也提供了:
|
137 |
+
- [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_ds_beta6_offline_pmf0_1230.zip)的预训练模型;
|
138 |
+
- [FFT-Singer](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip)的预训练模型, 这是为了DiffSinger里的浅扩散机制;
|
139 |
+
|
140 |
+
记得把预训练模型放在 `checkpoints` 目录.
|
141 |
+
|
142 |
+
*请注意:*
|
143 |
+
|
144 |
+
-*我们原始论文中的PWG版本声码器已投入商业使用,因此我们提供此HifiGAN版本声码器作为替代品。*
|
145 |
+
|
146 |
+
-*我们这篇论文假设提供真实的F0来进行实验,如[1][2][3]等前作所做的那样,重点在频谱建模上,而非F0曲线的预测。如果你想对MIDI数据进行实验,从MIDI和歌词预测F0曲线(显式或隐式),请查看文档[MIDI-old-version](README-SVS-opencpop-cascade.md) 或 [MIDI-new-version](README-SVS-opencpop-e2e.md)。目前已经支持的MIDI数据集有: Opencpop*
|
147 |
+
|
148 |
+
[1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
|
149 |
+
|
150 |
+
[2] SEQUENCE-TO-SEQUENCE SINGING SYNTHESIS USING THE FEED-FORWARD TRANSFORMER. ICASSP 2020.
|
151 |
+
|
152 |
+
[3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
|
153 |
+
|
154 |
+
## Tensorboard
|
155 |
+
```sh
|
156 |
+
tensorboard --logdir_spec exp_name
|
157 |
+
```
|
158 |
+
<table style="width:100%">
|
159 |
+
<tr>
|
160 |
+
<td><img src="resources/tfb.png" alt="Tensorboard" height="250"></td>
|
161 |
+
</tr>
|
162 |
+
</table>
|
163 |
+
|
164 |
+
## Mel 可视化
|
165 |
+
沿着纵轴, DiffSpeech: [0-80]; FastSpeech2: [80-160].
|
166 |
+
|
167 |
+
<table style="width:100%">
|
168 |
+
<tr>
|
169 |
+
<th>DiffSpeech vs. FastSpeech 2</th>
|
170 |
+
</tr>
|
171 |
+
<tr>
|
172 |
+
<td><img src="resources/diffspeech-fs2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
|
173 |
+
</tr>
|
174 |
+
<tr>
|
175 |
+
<td><img src="resources/diffspeech-fs2-1.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
|
176 |
+
</tr>
|
177 |
+
<tr>
|
178 |
+
<td><img src="resources/diffspeech-fs2-2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
|
179 |
+
</tr>
|
180 |
+
</table>
|
181 |
+
|
182 |
+
## Audio Demos
|
183 |
+
音频样本可以看我们的[样例页](https://diffsinger.github.io/).
|
184 |
+
|
185 |
+
我们也放了部分由DiffSpeech+HifiGAN (标记为[P]) 和 GTmel+HifiGAN (标记为[G]) 生成的测试集音频样例在:[resources/demos_1213](../resources/demos_1213).
|
186 |
+
|
187 |
+
(对应这个预训练参数:[DiffSpeech](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/lj_ds_beta6_1213.zip))
|
188 |
+
|
189 |
+
---
|
190 |
+
:rocket: :rocket: :rocket: **更新:**
|
191 |
+
|
192 |
+
新生成的歌声样例在:[resources/demos_0112](../resources/demos_0112).
|
193 |
+
|
194 |
+
## Citation
|
195 |
+
如果本仓库对你的研究和工作有用,请引用以下论文:
|
196 |
+
|
197 |
+
@article{liu2021diffsinger,
|
198 |
+
title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
|
199 |
+
author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
|
200 |
+
journal={arXiv preprint arXiv:2105.02446},
|
201 |
+
volume={2},
|
202 |
+
year={2021}}
|
203 |
+
|
204 |
+
|
205 |
+
## 鸣谢
|
206 |
+
我们的代码基于如下仓库:
|
207 |
+
* [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch)
|
208 |
+
* [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
|
209 |
+
* [ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
|
210 |
+
* [HifiGAN](https://github.com/jik876/hifi-gan)
|
211 |
+
* [espnet](https://github.com/espnet/espnet)
|
212 |
+
* [DiffWave](https://github.com/lmnt-com/diffwave)
|
inference/svs/base_svs_infer.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from modules.hifigan.hifigan import HifiGanGenerator
|
6 |
+
from vocoders.hifigan import HifiGAN
|
7 |
+
from inference.svs.opencpop.map import cpop_pinyin2ph_func
|
8 |
+
|
9 |
+
from utils import load_ckpt
|
10 |
+
from utils.hparams import set_hparams, hparams
|
11 |
+
from utils.text_encoder import TokenTextEncoder
|
12 |
+
from pypinyin import pinyin, lazy_pinyin, Style
|
13 |
+
import librosa
|
14 |
+
import glob
|
15 |
+
import re
|
16 |
+
|
17 |
+
|
18 |
+
class BaseSVSInfer:
|
19 |
+
def __init__(self, hparams, device=None):
|
20 |
+
if device is None:
|
21 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
22 |
+
self.hparams = hparams
|
23 |
+
self.device = device
|
24 |
+
|
25 |
+
phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g",
|
26 |
+
"h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k", "l", "m", "n", "o",
|
27 |
+
"ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "ui", "un", "uo", "v",
|
28 |
+
"van", "ve", "vn", "w", "x", "y", "z", "zh"]
|
29 |
+
self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
|
30 |
+
self.pinyin2phs = cpop_pinyin2ph_func()
|
31 |
+
self.spk_map = {'opencpop': 0}
|
32 |
+
|
33 |
+
self.model = self.build_model()
|
34 |
+
self.model.eval()
|
35 |
+
self.model.to(self.device)
|
36 |
+
self.vocoder = self.build_vocoder()
|
37 |
+
self.vocoder.eval()
|
38 |
+
self.vocoder.to(self.device)
|
39 |
+
|
40 |
+
def build_model(self):
|
41 |
+
raise NotImplementedError
|
42 |
+
|
43 |
+
def forward_model(self, inp):
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
def build_vocoder(self):
|
47 |
+
base_dir = hparams['vocoder_ckpt']
|
48 |
+
config_path = f'{base_dir}/config.yaml'
|
49 |
+
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
|
50 |
+
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
|
51 |
+
print('| load HifiGAN: ', ckpt)
|
52 |
+
ckpt_dict = torch.load(ckpt, map_location="cpu")
|
53 |
+
config = set_hparams(config_path, global_hparams=False)
|
54 |
+
state = ckpt_dict["state_dict"]["model_gen"]
|
55 |
+
vocoder = HifiGanGenerator(config)
|
56 |
+
vocoder.load_state_dict(state, strict=True)
|
57 |
+
vocoder.remove_weight_norm()
|
58 |
+
vocoder = vocoder.eval().to(self.device)
|
59 |
+
return vocoder
|
60 |
+
|
61 |
+
def run_vocoder(self, c, **kwargs):
|
62 |
+
c = c.transpose(2, 1) # [B, 80, T]
|
63 |
+
f0 = kwargs.get('f0') # [B, T]
|
64 |
+
if f0 is not None and hparams.get('use_nsf'):
|
65 |
+
# f0 = torch.FloatTensor(f0).to(self.device)
|
66 |
+
y = self.vocoder(c, f0).view(-1)
|
67 |
+
else:
|
68 |
+
y = self.vocoder(c).view(-1)
|
69 |
+
# [T]
|
70 |
+
return y[None]
|
71 |
+
|
72 |
+
def preprocess_word_level_input(self, inp):
|
73 |
+
# Pypinyin can't solve polyphonic words
|
74 |
+
text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
|
75 |
+
.replace('那么长', '那么常').replace('多长', '多常') \
|
76 |
+
.replace('很长', '很常') # We hope someone could provide a better g2p module for us by opening pull requests.
|
77 |
+
|
78 |
+
# lyric
|
79 |
+
pinyins = lazy_pinyin(text_raw, strict=False)
|
80 |
+
ph_per_word_lst = [self.pinyin2phs[pinyin.strip()] for pinyin in pinyins if pinyin.strip() in self.pinyin2phs]
|
81 |
+
|
82 |
+
# Note
|
83 |
+
note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
|
84 |
+
mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']
|
85 |
+
|
86 |
+
if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
|
87 |
+
print('Pass word-notes check.')
|
88 |
+
else:
|
89 |
+
print('The number of words does\'t match the number of notes\' windows. ',
|
90 |
+
'You should split the note(s) for each word by | mark.')
|
91 |
+
print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
|
92 |
+
print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
|
93 |
+
return None
|
94 |
+
|
95 |
+
note_lst = []
|
96 |
+
ph_lst = []
|
97 |
+
midi_dur_lst = []
|
98 |
+
is_slur = []
|
99 |
+
for idx, ph_per_word in enumerate(ph_per_word_lst):
|
100 |
+
# for phs in one word:
|
101 |
+
# single ph like ['ai'] or multiple phs like ['n', 'i']
|
102 |
+
ph_in_this_word = ph_per_word.split()
|
103 |
+
|
104 |
+
# for notes in one word:
|
105 |
+
# single note like ['D4'] or multiple notes like ['D4', 'E4'] which means a 'slur' here.
|
106 |
+
note_in_this_word = note_per_word_lst[idx].split()
|
107 |
+
midi_dur_in_this_word = mididur_per_word_lst[idx].split()
|
108 |
+
# process for the model input
|
109 |
+
# Step 1.
|
110 |
+
# Deal with note of 'not slur' case or the first note of 'slur' case
|
111 |
+
# j ie
|
112 |
+
# F#4/Gb4 F#4/Gb4
|
113 |
+
# 0 0
|
114 |
+
for ph in ph_in_this_word:
|
115 |
+
ph_lst.append(ph)
|
116 |
+
note_lst.append(note_in_this_word[0])
|
117 |
+
midi_dur_lst.append(midi_dur_in_this_word[0])
|
118 |
+
is_slur.append(0)
|
119 |
+
# step 2.
|
120 |
+
# Deal with the 2nd, 3rd... notes of 'slur' case
|
121 |
+
# j ie ie
|
122 |
+
# F#4/Gb4 F#4/Gb4 C#4/Db4
|
123 |
+
# 0 0 1
|
124 |
+
if len(note_in_this_word) > 1: # is_slur = True, we should repeat the YUNMU to match the 2nd, 3rd... notes.
|
125 |
+
for idx in range(1, len(note_in_this_word)):
|
126 |
+
ph_lst.append(ph_in_this_word[1])
|
127 |
+
note_lst.append(note_in_this_word[idx])
|
128 |
+
midi_dur_lst.append(midi_dur_in_this_word[idx])
|
129 |
+
is_slur.append(1)
|
130 |
+
ph_seq = ' '.join(ph_lst)
|
131 |
+
|
132 |
+
if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
|
133 |
+
print(len(ph_lst), len(note_lst), len(midi_dur_lst))
|
134 |
+
print('Pass word-notes check.')
|
135 |
+
else:
|
136 |
+
print('The number of words does\'t match the number of notes\' windows. ',
|
137 |
+
'You should split the note(s) for each word by | mark.')
|
138 |
+
return None
|
139 |
+
return ph_seq, note_lst, midi_dur_lst, is_slur
|
140 |
+
|
141 |
+
def preprocess_phoneme_level_input(self, inp):
|
142 |
+
ph_seq = inp['ph_seq']
|
143 |
+
note_lst = inp['note_seq'].split()
|
144 |
+
midi_dur_lst = inp['note_dur_seq'].split()
|
145 |
+
is_slur = inp['is_slur_seq'].split()
|
146 |
+
print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
|
147 |
+
if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
|
148 |
+
print('Pass word-notes check.')
|
149 |
+
else:
|
150 |
+
print('The number of words does\'t match the number of notes\' windows. ',
|
151 |
+
'You should split the note(s) for each word by | mark.')
|
152 |
+
return None
|
153 |
+
return ph_seq, note_lst, midi_dur_lst, is_slur
|
154 |
+
|
155 |
+
def preprocess_input(self, inp, input_type='word'):
|
156 |
+
"""
|
157 |
+
|
158 |
+
:param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
|
159 |
+
:return:
|
160 |
+
"""
|
161 |
+
|
162 |
+
item_name = inp.get('item_name', '<ITEM_NAME>')
|
163 |
+
spk_name = inp.get('spk_name', 'opencpop')
|
164 |
+
|
165 |
+
# single spk
|
166 |
+
spk_id = self.spk_map[spk_name]
|
167 |
+
|
168 |
+
# get ph seq, note lst, midi dur lst, is slur lst.
|
169 |
+
if input_type == 'word':
|
170 |
+
ret = self.preprocess_word_level_input(inp)
|
171 |
+
elif input_type == 'phoneme': # like transcriptions.txt in Opencpop dataset.
|
172 |
+
ret = self.preprocess_phoneme_level_input(inp)
|
173 |
+
else:
|
174 |
+
print('Invalid input type.')
|
175 |
+
return None
|
176 |
+
|
177 |
+
if ret:
|
178 |
+
ph_seq, note_lst, midi_dur_lst, is_slur = ret
|
179 |
+
else:
|
180 |
+
print('==========> Preprocess_word_level or phone_level input wrong.')
|
181 |
+
return None
|
182 |
+
|
183 |
+
# convert note lst to midi id; convert note dur lst to midi duration
|
184 |
+
try:
|
185 |
+
midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
|
186 |
+
for x in note_lst]
|
187 |
+
midi_dur_lst = [float(x) for x in midi_dur_lst]
|
188 |
+
except Exception as e:
|
189 |
+
print(e)
|
190 |
+
print('Invalid Input Type.')
|
191 |
+
return None
|
192 |
+
|
193 |
+
ph_token = self.ph_encoder.encode(ph_seq)
|
194 |
+
item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
|
195 |
+
'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
|
196 |
+
'is_slur': np.asarray(is_slur), }
|
197 |
+
item['ph_len'] = len(item['ph_token'])
|
198 |
+
return item
|
199 |
+
|
200 |
+
def input_to_batch(self, item):
|
201 |
+
item_names = [item['item_name']]
|
202 |
+
text = [item['text']]
|
203 |
+
ph = [item['ph']]
|
204 |
+
txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
|
205 |
+
txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
|
206 |
+
spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device)
|
207 |
+
|
208 |
+
pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
|
209 |
+
midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
|
210 |
+
is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)
|
211 |
+
|
212 |
+
batch = {
|
213 |
+
'item_name': item_names,
|
214 |
+
'text': text,
|
215 |
+
'ph': ph,
|
216 |
+
'txt_tokens': txt_tokens,
|
217 |
+
'txt_lengths': txt_lengths,
|
218 |
+
'spk_ids': spk_ids,
|
219 |
+
'pitch_midi': pitch_midi,
|
220 |
+
'midi_dur': midi_dur,
|
221 |
+
'is_slur': is_slur
|
222 |
+
}
|
223 |
+
return batch
|
224 |
+
|
225 |
+
def postprocess_output(self, output):
|
226 |
+
return output
|
227 |
+
|
228 |
+
def infer_once(self, inp):
|
229 |
+
inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
|
230 |
+
output = self.forward_model(inp)
|
231 |
+
output = self.postprocess_output(output)
|
232 |
+
return output
|
233 |
+
|
234 |
+
@classmethod
|
235 |
+
def example_run(cls, inp):
|
236 |
+
from utils.audio import save_wav
|
237 |
+
set_hparams(print_hparams=False)
|
238 |
+
infer_ins = cls(hparams)
|
239 |
+
out = infer_ins.infer_once(inp)
|
240 |
+
os.makedirs('infer_out', exist_ok=True)
|
241 |
+
save_wav(out, f'infer_out/example_out.wav', hparams['audio_sample_rate'])
|
242 |
+
|
243 |
+
|
244 |
+
# if __name__ == '__main__':
|
245 |
+
# debug
|
246 |
+
# a = BaseSVSInfer(hparams)
|
247 |
+
# a.preprocess_input({'text': '你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP',
|
248 |
+
# 'notes': 'D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest',
|
249 |
+
# 'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
|
250 |
+
# })
|
251 |
+
|
252 |
+
# b = {
|
253 |
+
# 'text': '小酒窝长睫毛AP是你最美的记号',
|
254 |
+
# 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
|
255 |
+
# 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340'
|
256 |
+
# }
|
257 |
+
# c = {
|
258 |
+
# 'text': '小酒窝长睫毛AP是你最美的记号',
|
259 |
+
# 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
260 |
+
# 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
261 |
+
# 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
|
262 |
+
# 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0'
|
263 |
+
# } # input like Opencpop dataset.
|
264 |
+
# a.preprocess_input(b)
|
265 |
+
# a.preprocess_input(c, input_type='phoneme')
|
inference/svs/ds_cascade.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# from inference.tts.fs import FastSpeechInfer
|
3 |
+
# from modules.tts.fs2_orig import FastSpeech2Orig
|
4 |
+
from inference.svs.base_svs_infer import BaseSVSInfer
|
5 |
+
from utils import load_ckpt
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from usr.diff.shallow_diffusion_tts import GaussianDiffusion
|
8 |
+
from usr.diffsinger_task import DIFF_DECODERS
|
9 |
+
|
10 |
+
class DiffSingerCascadeInfer(BaseSVSInfer):
|
11 |
+
def build_model(self):
|
12 |
+
model = GaussianDiffusion(
|
13 |
+
phone_encoder=self.ph_encoder,
|
14 |
+
out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
|
15 |
+
timesteps=hparams['timesteps'],
|
16 |
+
K_step=hparams['K_step'],
|
17 |
+
loss_type=hparams['diff_loss_type'],
|
18 |
+
spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
|
19 |
+
)
|
20 |
+
model.eval()
|
21 |
+
load_ckpt(model, hparams['work_dir'], 'model')
|
22 |
+
return model
|
23 |
+
|
24 |
+
def forward_model(self, inp):
|
25 |
+
sample = self.input_to_batch(inp)
|
26 |
+
txt_tokens = sample['txt_tokens'] # [B, T_t]
|
27 |
+
spk_id = sample.get('spk_ids')
|
28 |
+
with torch.no_grad():
|
29 |
+
output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
|
30 |
+
pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
|
31 |
+
is_slur=sample['is_slur'])
|
32 |
+
mel_out = output['mel_out'] # [B, T,80]
|
33 |
+
f0_pred = output['f0_denorm']
|
34 |
+
wav_out = self.run_vocoder(mel_out, f0=f0_pred)
|
35 |
+
wav_out = wav_out.cpu().numpy()
|
36 |
+
return wav_out[0]
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == '__main__':
|
40 |
+
inp = {
|
41 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
42 |
+
'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
|
43 |
+
'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
|
44 |
+
'input_type': 'word'
|
45 |
+
} # user input: Chinese characters
|
46 |
+
c = {
|
47 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
48 |
+
'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
49 |
+
'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
50 |
+
'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
|
51 |
+
'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
|
52 |
+
'input_type': 'phoneme'
|
53 |
+
} # input like Opencpop dataset.
|
54 |
+
DiffSingerCascadeInfer.example_run(inp)
|
inference/svs/ds_e2e.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# from inference.tts.fs import FastSpeechInfer
|
3 |
+
# from modules.tts.fs2_orig import FastSpeech2Orig
|
4 |
+
from inference.svs.base_svs_infer import BaseSVSInfer
|
5 |
+
from utils import load_ckpt
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from usr.diff.shallow_diffusion_tts import GaussianDiffusion
|
8 |
+
from usr.diffsinger_task import DIFF_DECODERS
|
9 |
+
from modules.fastspeech.pe import PitchExtractor
|
10 |
+
import utils
|
11 |
+
|
12 |
+
|
13 |
+
class DiffSingerE2EInfer(BaseSVSInfer):
|
14 |
+
def build_model(self):
|
15 |
+
model = GaussianDiffusion(
|
16 |
+
phone_encoder=self.ph_encoder,
|
17 |
+
out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
|
18 |
+
timesteps=hparams['timesteps'],
|
19 |
+
K_step=hparams['K_step'],
|
20 |
+
loss_type=hparams['diff_loss_type'],
|
21 |
+
spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
|
22 |
+
)
|
23 |
+
model.eval()
|
24 |
+
load_ckpt(model, hparams['work_dir'], 'model')
|
25 |
+
|
26 |
+
if hparams.get('pe_enable') is not None and hparams['pe_enable']:
|
27 |
+
self.pe = PitchExtractor().cuda()
|
28 |
+
utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
|
29 |
+
self.pe.eval()
|
30 |
+
return model
|
31 |
+
|
32 |
+
def forward_model(self, inp):
|
33 |
+
sample = self.input_to_batch(inp)
|
34 |
+
txt_tokens = sample['txt_tokens'] # [B, T_t]
|
35 |
+
spk_id = sample.get('spk_ids')
|
36 |
+
with torch.no_grad():
|
37 |
+
output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
|
38 |
+
pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
|
39 |
+
is_slur=sample['is_slur'])
|
40 |
+
mel_out = output['mel_out'] # [B, T,80]
|
41 |
+
if hparams.get('pe_enable') is not None and hparams['pe_enable']:
|
42 |
+
f0_pred = self.pe(mel_out)['f0_denorm_pred'] # pe predict from Pred mel
|
43 |
+
else:
|
44 |
+
f0_pred = output['f0_denorm']
|
45 |
+
wav_out = self.run_vocoder(mel_out, f0=f0_pred)
|
46 |
+
wav_out = wav_out.cpu().numpy()
|
47 |
+
return wav_out[0]
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
inp = {
|
51 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
52 |
+
'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
|
53 |
+
'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
|
54 |
+
'input_type': 'word'
|
55 |
+
} # user input: Chinese characters
|
56 |
+
c = {
|
57 |
+
'text': '小酒窝长睫毛AP是你最美的记号',
|
58 |
+
'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
59 |
+
'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
60 |
+
'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
|
61 |
+
'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
|
62 |
+
'input_type': 'phoneme'
|
63 |
+
} # input like Opencpop dataset.
|
64 |
+
DiffSingerE2EInfer.example_run(inp)
|
65 |
+
|
66 |
+
|
67 |
+
# python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
|
inference/svs/gradio/gradio_settings.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
title: 'DiffSinger'
|
2 |
+
description: |
|
3 |
+
Gradio demo for DiffSinger.
|
4 |
+
|
5 |
+
请给每个汉字分配音高和时值, 每个字对应的音高和时值需要用|分隔符隔开。需要保证分隔符分割出来的音符窗口与汉字个数(AP或SP也算一个汉字)一致。
|
6 |
+
|
7 |
+
article: |
|
8 |
+
Link to <a href='https://github.com/MoonInTheRiver/DiffSinger' style='color:blue;' target='_blank\'>Github REPO</a>
|
9 |
+
example_inputs:
|
10 |
+
- |-
|
11 |
+
你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP<sep>D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest<sep>0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590
|
12 |
+
- |-
|
13 |
+
小酒窝长睫毛AP是你最美的记号<sep>C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4<sep>0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340
|
14 |
+
|
15 |
+
#inference_cls: inference.svs.ds_cascade.DiffSingerCascadeInfer
|
16 |
+
#exp_name: 0303_opencpop_ds58_midi
|
17 |
+
|
18 |
+
inference_cls: inference.svs.ds_e2e.DiffSingerE2EInfer
|
19 |
+
exp_name: 0228_opencpop_ds100_rel
|
inference/svs/gradio/infer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import re
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import yaml
|
6 |
+
from gradio.inputs import Textbox
|
7 |
+
|
8 |
+
from inference.svs.base_svs_infer import BaseSVSInfer
|
9 |
+
from utils.hparams import set_hparams
|
10 |
+
from utils.hparams import hparams as hp
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
class GradioInfer:
|
15 |
+
def __init__(self, exp_name, inference_cls, title, description, article, example_inputs):
|
16 |
+
self.exp_name = exp_name
|
17 |
+
self.title = title
|
18 |
+
self.description = description
|
19 |
+
self.article = article
|
20 |
+
self.example_inputs = example_inputs
|
21 |
+
pkg = ".".join(inference_cls.split(".")[:-1])
|
22 |
+
cls_name = inference_cls.split(".")[-1]
|
23 |
+
self.inference_cls = getattr(importlib.import_module(pkg), cls_name)
|
24 |
+
|
25 |
+
def greet(self, text, notes, notes_duration):
|
26 |
+
PUNCS = '。?;:'
|
27 |
+
sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
|
28 |
+
sents_notes = re.split(rf'([{PUNCS}])', notes.replace('\n', ','))
|
29 |
+
sents_notes_dur = re.split(rf'([{PUNCS}])', notes_duration.replace('\n', ','))
|
30 |
+
|
31 |
+
if sents[-1] not in list(PUNCS):
|
32 |
+
sents = sents + ['']
|
33 |
+
sents_notes = sents_notes + ['']
|
34 |
+
sents_notes_dur = sents_notes_dur + ['']
|
35 |
+
|
36 |
+
audio_outs = []
|
37 |
+
s, n, n_dur = "", "", ""
|
38 |
+
for i in range(0, len(sents), 2):
|
39 |
+
if len(sents[i]) > 0:
|
40 |
+
s += sents[i] + sents[i + 1]
|
41 |
+
n += sents_notes[i] + sents_notes[i+1]
|
42 |
+
n_dur += sents_notes_dur[i] + sents_notes_dur[i+1]
|
43 |
+
if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
|
44 |
+
audio_out = self.infer_ins.infer_once({
|
45 |
+
'text': s,
|
46 |
+
'notes': n,
|
47 |
+
'notes_duration': n_dur,
|
48 |
+
})
|
49 |
+
audio_out = audio_out * 32767
|
50 |
+
audio_out = audio_out.astype(np.int16)
|
51 |
+
audio_outs.append(audio_out)
|
52 |
+
audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
|
53 |
+
s = ""
|
54 |
+
n = ""
|
55 |
+
audio_outs = np.concatenate(audio_outs)
|
56 |
+
return hp['audio_sample_rate'], audio_outs
|
57 |
+
|
58 |
+
def run(self):
|
59 |
+
set_hparams(exp_name=self.exp_name, print_hparams=False)
|
60 |
+
infer_cls = self.inference_cls
|
61 |
+
self.infer_ins: BaseSVSInfer = infer_cls(hp)
|
62 |
+
example_inputs = self.example_inputs
|
63 |
+
for i in range(len(example_inputs)):
|
64 |
+
text, notes, notes_dur = example_inputs[i].split('<sep>')
|
65 |
+
example_inputs[i] = [text, notes, notes_dur]
|
66 |
+
|
67 |
+
iface = gr.Interface(fn=self.greet,
|
68 |
+
inputs=[
|
69 |
+
Textbox(lines=2, placeholder=None, default=example_inputs[0][0], label="input text"),
|
70 |
+
Textbox(lines=2, placeholder=None, default=example_inputs[0][1], label="input note"),
|
71 |
+
Textbox(lines=2, placeholder=None, default=example_inputs[0][2], label="input duration")]
|
72 |
+
,
|
73 |
+
outputs="audio",
|
74 |
+
allow_flagging="never",
|
75 |
+
title=self.title,
|
76 |
+
description=self.description,
|
77 |
+
article=self.article,
|
78 |
+
examples=example_inputs,
|
79 |
+
enable_queue=True)
|
80 |
+
iface.launch(share=True,)# cache_examples=True)
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
gradio_config = yaml.safe_load(open('inference/svs/gradio/gradio_settings.yaml'))
|
85 |
+
g = GradioInfer(**gradio_config)
|
86 |
+
g.run()
|
87 |
+
|
88 |
+
|
89 |
+
# python inference/svs/gradio/infer.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
90 |
+
# python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
91 |
+
# CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
|
inference/svs/opencpop/cpop_pinyin2ph.txt
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
| a | a |
|
2 |
+
| ai | ai |
|
3 |
+
| an | an |
|
4 |
+
| ang | ang |
|
5 |
+
| ao | ao |
|
6 |
+
| ba | b a |
|
7 |
+
| bai | b ai |
|
8 |
+
| ban | b an |
|
9 |
+
| bang | b ang |
|
10 |
+
| bao | b ao |
|
11 |
+
| bei | b ei |
|
12 |
+
| ben | b en |
|
13 |
+
| beng | b eng |
|
14 |
+
| bi | b i |
|
15 |
+
| bian | b ian |
|
16 |
+
| biao | b iao |
|
17 |
+
| bie | b ie |
|
18 |
+
| bin | b in |
|
19 |
+
| bing | b ing |
|
20 |
+
| bo | b o |
|
21 |
+
| bu | b u |
|
22 |
+
| ca | c a |
|
23 |
+
| cai | c ai |
|
24 |
+
| can | c an |
|
25 |
+
| cang | c ang |
|
26 |
+
| cao | c ao |
|
27 |
+
| ce | c e |
|
28 |
+
| cei | c ei |
|
29 |
+
| cen | c en |
|
30 |
+
| ceng | c eng |
|
31 |
+
| cha | ch a |
|
32 |
+
| chai | ch ai |
|
33 |
+
| chan | ch an |
|
34 |
+
| chang | ch ang |
|
35 |
+
| chao | ch ao |
|
36 |
+
| che | ch e |
|
37 |
+
| chen | ch en |
|
38 |
+
| cheng | ch eng |
|
39 |
+
| chi | ch i |
|
40 |
+
| chong | ch ong |
|
41 |
+
| chou | ch ou |
|
42 |
+
| chu | ch u |
|
43 |
+
| chua | ch ua |
|
44 |
+
| chuai | ch uai |
|
45 |
+
| chuan | ch uan |
|
46 |
+
| chuang | ch uang |
|
47 |
+
| chui | ch ui |
|
48 |
+
| chun | ch un |
|
49 |
+
| chuo | ch uo |
|
50 |
+
| ci | c i |
|
51 |
+
| cong | c ong |
|
52 |
+
| cou | c ou |
|
53 |
+
| cu | c u |
|
54 |
+
| cuan | c uan |
|
55 |
+
| cui | c ui |
|
56 |
+
| cun | c un |
|
57 |
+
| cuo | c uo |
|
58 |
+
| da | d a |
|
59 |
+
| dai | d ai |
|
60 |
+
| dan | d an |
|
61 |
+
| dang | d ang |
|
62 |
+
| dao | d ao |
|
63 |
+
| de | d e |
|
64 |
+
| dei | d ei |
|
65 |
+
| den | d en |
|
66 |
+
| deng | d eng |
|
67 |
+
| di | d i |
|
68 |
+
| dia | d ia |
|
69 |
+
| dian | d ian |
|
70 |
+
| diao | d iao |
|
71 |
+
| die | d ie |
|
72 |
+
| ding | d ing |
|
73 |
+
| diu | d iu |
|
74 |
+
| dong | d ong |
|
75 |
+
| dou | d ou |
|
76 |
+
| du | d u |
|
77 |
+
| duan | d uan |
|
78 |
+
| dui | d ui |
|
79 |
+
| dun | d un |
|
80 |
+
| duo | d uo |
|
81 |
+
| e | e |
|
82 |
+
| ei | ei |
|
83 |
+
| en | en |
|
84 |
+
| eng | eng |
|
85 |
+
| er | er |
|
86 |
+
| fa | f a |
|
87 |
+
| fan | f an |
|
88 |
+
| fang | f ang |
|
89 |
+
| fei | f ei |
|
90 |
+
| fen | f en |
|
91 |
+
| feng | f eng |
|
92 |
+
| fo | f o |
|
93 |
+
| fou | f ou |
|
94 |
+
| fu | f u |
|
95 |
+
| ga | g a |
|
96 |
+
| gai | g ai |
|
97 |
+
| gan | g an |
|
98 |
+
| gang | g ang |
|
99 |
+
| gao | g ao |
|
100 |
+
| ge | g e |
|
101 |
+
| gei | g ei |
|
102 |
+
| gen | g en |
|
103 |
+
| geng | g eng |
|
104 |
+
| gong | g ong |
|
105 |
+
| gou | g ou |
|
106 |
+
| gu | g u |
|
107 |
+
| gua | g ua |
|
108 |
+
| guai | g uai |
|
109 |
+
| guan | g uan |
|
110 |
+
| guang | g uang |
|
111 |
+
| gui | g ui |
|
112 |
+
| gun | g un |
|
113 |
+
| guo | g uo |
|
114 |
+
| ha | h a |
|
115 |
+
| hai | h ai |
|
116 |
+
| han | h an |
|
117 |
+
| hang | h ang |
|
118 |
+
| hao | h ao |
|
119 |
+
| he | h e |
|
120 |
+
| hei | h ei |
|
121 |
+
| hen | h en |
|
122 |
+
| heng | h eng |
|
123 |
+
| hm | h m |
|
124 |
+
| hng | h ng |
|
125 |
+
| hong | h ong |
|
126 |
+
| hou | h ou |
|
127 |
+
| hu | h u |
|
128 |
+
| hua | h ua |
|
129 |
+
| huai | h uai |
|
130 |
+
| huan | h uan |
|
131 |
+
| huang | h uang |
|
132 |
+
| hui | h ui |
|
133 |
+
| hun | h un |
|
134 |
+
| huo | h uo |
|
135 |
+
| ji | j i |
|
136 |
+
| jia | j ia |
|
137 |
+
| jian | j ian |
|
138 |
+
| jiang | j iang |
|
139 |
+
| jiao | j iao |
|
140 |
+
| jie | j ie |
|
141 |
+
| jin | j in |
|
142 |
+
| jing | j ing |
|
143 |
+
| jiong | j iong |
|
144 |
+
| jiu | j iu |
|
145 |
+
| ju | j v |
|
146 |
+
| juan | j van |
|
147 |
+
| jue | j ve |
|
148 |
+
| jun | j vn |
|
149 |
+
| ka | k a |
|
150 |
+
| kai | k ai |
|
151 |
+
| kan | k an |
|
152 |
+
| kang | k ang |
|
153 |
+
| kao | k ao |
|
154 |
+
| ke | k e |
|
155 |
+
| kei | k ei |
|
156 |
+
| ken | k en |
|
157 |
+
| keng | k eng |
|
158 |
+
| kong | k ong |
|
159 |
+
| kou | k ou |
|
160 |
+
| ku | k u |
|
161 |
+
| kua | k ua |
|
162 |
+
| kuai | k uai |
|
163 |
+
| kuan | k uan |
|
164 |
+
| kuang | k uang |
|
165 |
+
| kui | k ui |
|
166 |
+
| kun | k un |
|
167 |
+
| kuo | k uo |
|
168 |
+
| la | l a |
|
169 |
+
| lai | l ai |
|
170 |
+
| lan | l an |
|
171 |
+
| lang | l ang |
|
172 |
+
| lao | l ao |
|
173 |
+
| le | l e |
|
174 |
+
| lei | l ei |
|
175 |
+
| leng | l eng |
|
176 |
+
| li | l i |
|
177 |
+
| lia | l ia |
|
178 |
+
| lian | l ian |
|
179 |
+
| liang | l iang |
|
180 |
+
| liao | l iao |
|
181 |
+
| lie | l ie |
|
182 |
+
| lin | l in |
|
183 |
+
| ling | l ing |
|
184 |
+
| liu | l iu |
|
185 |
+
| lo | l o |
|
186 |
+
| long | l ong |
|
187 |
+
| lou | l ou |
|
188 |
+
| lu | l u |
|
189 |
+
| luan | l uan |
|
190 |
+
| lun | l un |
|
191 |
+
| luo | l uo |
|
192 |
+
| lv | l v |
|
193 |
+
| lve | l ve |
|
194 |
+
| m | m |
|
195 |
+
| ma | m a |
|
196 |
+
| mai | m ai |
|
197 |
+
| man | m an |
|
198 |
+
| mang | m ang |
|
199 |
+
| mao | m ao |
|
200 |
+
| me | m e |
|
201 |
+
| mei | m ei |
|
202 |
+
| men | m en |
|
203 |
+
| meng | m eng |
|
204 |
+
| mi | m i |
|
205 |
+
| mian | m ian |
|
206 |
+
| miao | m iao |
|
207 |
+
| mie | m ie |
|
208 |
+
| min | m in |
|
209 |
+
| ming | m ing |
|
210 |
+
| miu | m iu |
|
211 |
+
| mo | m o |
|
212 |
+
| mou | m ou |
|
213 |
+
| mu | m u |
|
214 |
+
| n | n |
|
215 |
+
| na | n a |
|
216 |
+
| nai | n ai |
|
217 |
+
| nan | n an |
|
218 |
+
| nang | n ang |
|
219 |
+
| nao | n ao |
|
220 |
+
| ne | n e |
|
221 |
+
| nei | n ei |
|
222 |
+
| nen | n en |
|
223 |
+
| neng | n eng |
|
224 |
+
| ng | n g |
|
225 |
+
| ni | n i |
|
226 |
+
| nian | n ian |
|
227 |
+
| niang | n iang |
|
228 |
+
| niao | n iao |
|
229 |
+
| nie | n ie |
|
230 |
+
| nin | n in |
|
231 |
+
| ning | n ing |
|
232 |
+
| niu | n iu |
|
233 |
+
| nong | n ong |
|
234 |
+
| nou | n ou |
|
235 |
+
| nu | n u |
|
236 |
+
| nuan | n uan |
|
237 |
+
| nun | n un |
|
238 |
+
| nuo | n uo |
|
239 |
+
| nv | n v |
|
240 |
+
| nve | n ve |
|
241 |
+
| o | o |
|
242 |
+
| ou | ou |
|
243 |
+
| pa | p a |
|
244 |
+
| pai | p ai |
|
245 |
+
| pan | p an |
|
246 |
+
| pang | p ang |
|
247 |
+
| pao | p ao |
|
248 |
+
| pei | p ei |
|
249 |
+
| pen | p en |
|
250 |
+
| peng | p eng |
|
251 |
+
| pi | p i |
|
252 |
+
| pian | p ian |
|
253 |
+
| piao | p iao |
|
254 |
+
| pie | p ie |
|
255 |
+
| pin | p in |
|
256 |
+
| ping | p ing |
|
257 |
+
| po | p o |
|
258 |
+
| pou | p ou |
|
259 |
+
| pu | p u |
|
260 |
+
| qi | q i |
|
261 |
+
| qia | q ia |
|
262 |
+
| qian | q ian |
|
263 |
+
| qiang | q iang |
|
264 |
+
| qiao | q iao |
|
265 |
+
| qie | q ie |
|
266 |
+
| qin | q in |
|
267 |
+
| qing | q ing |
|
268 |
+
| qiong | q iong |
|
269 |
+
| qiu | q iu |
|
270 |
+
| qu | q v |
|
271 |
+
| quan | q van |
|
272 |
+
| que | q ve |
|
273 |
+
| qun | q vn |
|
274 |
+
| ran | r an |
|
275 |
+
| rang | r ang |
|
276 |
+
| rao | r ao |
|
277 |
+
| re | r e |
|
278 |
+
| ren | r en |
|
279 |
+
| reng | r eng |
|
280 |
+
| ri | r i |
|
281 |
+
| rong | r ong |
|
282 |
+
| rou | r ou |
|
283 |
+
| ru | r u |
|
284 |
+
| rua | r ua |
|
285 |
+
| ruan | r uan |
|
286 |
+
| rui | r ui |
|
287 |
+
| run | r un |
|
288 |
+
| ruo | r uo |
|
289 |
+
| sa | s a |
|
290 |
+
| sai | s ai |
|
291 |
+
| san | s an |
|
292 |
+
| sang | s ang |
|
293 |
+
| sao | s ao |
|
294 |
+
| se | s e |
|
295 |
+
| sen | s en |
|
296 |
+
| seng | s eng |
|
297 |
+
| sha | sh a |
|
298 |
+
| shai | sh ai |
|
299 |
+
| shan | sh an |
|
300 |
+
| shang | sh ang |
|
301 |
+
| shao | sh ao |
|
302 |
+
| she | sh e |
|
303 |
+
| shei | sh ei |
|
304 |
+
| shen | sh en |
|
305 |
+
| sheng | sh eng |
|
306 |
+
| shi | sh i |
|
307 |
+
| shou | sh ou |
|
308 |
+
| shu | sh u |
|
309 |
+
| shua | sh ua |
|
310 |
+
| shuai | sh uai |
|
311 |
+
| shuan | sh uan |
|
312 |
+
| shuang | sh uang |
|
313 |
+
| shui | sh ui |
|
314 |
+
| shun | sh un |
|
315 |
+
| shuo | sh uo |
|
316 |
+
| si | s i |
|
317 |
+
| song | s ong |
|
318 |
+
| sou | s ou |
|
319 |
+
| su | s u |
|
320 |
+
| suan | s uan |
|
321 |
+
| sui | s ui |
|
322 |
+
| sun | s un |
|
323 |
+
| suo | s uo |
|
324 |
+
| ta | t a |
|
325 |
+
| tai | t ai |
|
326 |
+
| tan | t an |
|
327 |
+
| tang | t ang |
|
328 |
+
| tao | t ao |
|
329 |
+
| te | t e |
|
330 |
+
| tei | t ei |
|
331 |
+
| teng | t eng |
|
332 |
+
| ti | t i |
|
333 |
+
| tian | t ian |
|
334 |
+
| tiao | t iao |
|
335 |
+
| tie | t ie |
|
336 |
+
| ting | t ing |
|
337 |
+
| tong | t ong |
|
338 |
+
| tou | t ou |
|
339 |
+
| tu | t u |
|
340 |
+
| tuan | t uan |
|
341 |
+
| tui | t ui |
|
342 |
+
| tun | t un |
|
343 |
+
| tuo | t uo |
|
344 |
+
| wa | w a |
|
345 |
+
| wai | w ai |
|
346 |
+
| wan | w an |
|
347 |
+
| wang | w ang |
|
348 |
+
| wei | w ei |
|
349 |
+
| wen | w en |
|
350 |
+
| weng | w eng |
|
351 |
+
| wo | w o |
|
352 |
+
| wu | w u |
|
353 |
+
| xi | x i |
|
354 |
+
| xia | x ia |
|
355 |
+
| xian | x ian |
|
356 |
+
| xiang | x iang |
|
357 |
+
| xiao | x iao |
|
358 |
+
| xie | x ie |
|
359 |
+
| xin | x in |
|
360 |
+
| xing | x ing |
|
361 |
+
| xiong | x iong |
|
362 |
+
| xiu | x iu |
|
363 |
+
| xu | x v |
|
364 |
+
| xuan | x van |
|
365 |
+
| xue | x ve |
|
366 |
+
| xun | x vn |
|
367 |
+
| ya | y a |
|
368 |
+
| yan | y an |
|
369 |
+
| yang | y ang |
|
370 |
+
| yao | y ao |
|
371 |
+
| ye | y e |
|
372 |
+
| yi | y i |
|
373 |
+
| yin | y in |
|
374 |
+
| ying | y ing |
|
375 |
+
| yo | y o |
|
376 |
+
| yong | y ong |
|
377 |
+
| you | y ou |
|
378 |
+
| yu | y v |
|
379 |
+
| yuan | y van |
|
380 |
+
| yue | y ve |
|
381 |
+
| yun | y vn |
|
382 |
+
| za | z a |
|
383 |
+
| zai | z ai |
|
384 |
+
| zan | z an |
|
385 |
+
| zang | z ang |
|
386 |
+
| zao | z ao |
|
387 |
+
| ze | z e |
|
388 |
+
| zei | z ei |
|
389 |
+
| zen | z en |
|
390 |
+
| zeng | z eng |
|
391 |
+
| zha | zh a |
|
392 |
+
| zhai | zh ai |
|
393 |
+
| zhan | zh an |
|
394 |
+
| zhang | zh ang |
|
395 |
+
| zhao | zh ao |
|
396 |
+
| zhe | zh e |
|
397 |
+
| zhei | zh ei |
|
398 |
+
| zhen | zh en |
|
399 |
+
| zheng | zh eng |
|
400 |
+
| zhi | zh i |
|
401 |
+
| zhong | zh ong |
|
402 |
+
| zhou | zh ou |
|
403 |
+
| zhu | zh u |
|
404 |
+
| zhua | zh ua |
|
405 |
+
| zhuai | zh uai |
|
406 |
+
| zhuan | zh uan |
|
407 |
+
| zhuang | zh uang |
|
408 |
+
| zhui | zh ui |
|
409 |
+
| zhun | zh un |
|
410 |
+
| zhuo | zh uo |
|
411 |
+
| zi | z i |
|
412 |
+
| zong | z ong |
|
413 |
+
| zou | z ou |
|
414 |
+
| zu | z u |
|
415 |
+
| zuan | z uan |
|
416 |
+
| zui | z ui |
|
417 |
+
| zun | z un |
|
418 |
+
| zuo | z uo |
|
inference/svs/opencpop/map.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def cpop_pinyin2ph_func():
|
2 |
+
# In the README file of opencpop dataset, they defined a "pinyin to phoneme mapping table"
|
3 |
+
pinyin2phs = {'AP': 'AP', 'SP': 'SP'}
|
4 |
+
with open('inference/svs/opencpop/cpop_pinyin2ph.txt') as rf:
|
5 |
+
for line in rf.readlines():
|
6 |
+
elements = [x.strip() for x in line.split('|') if x.strip() != '']
|
7 |
+
pinyin2phs[elements[0]] = elements[1]
|
8 |
+
return pinyin2phs
|
modules/__init__.py
ADDED
File without changes
|
modules/commons/common_layers.py
ADDED
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import Parameter
|
5 |
+
import torch.onnx.operators
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import utils
|
8 |
+
|
9 |
+
|
10 |
+
class Reshape(nn.Module):
|
11 |
+
def __init__(self, *args):
|
12 |
+
super(Reshape, self).__init__()
|
13 |
+
self.shape = args
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
return x.view(self.shape)
|
17 |
+
|
18 |
+
|
19 |
+
class Permute(nn.Module):
|
20 |
+
def __init__(self, *args):
|
21 |
+
super(Permute, self).__init__()
|
22 |
+
self.args = args
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return x.permute(self.args)
|
26 |
+
|
27 |
+
|
28 |
+
class LinearNorm(torch.nn.Module):
|
29 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
30 |
+
super(LinearNorm, self).__init__()
|
31 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
32 |
+
|
33 |
+
torch.nn.init.xavier_uniform_(
|
34 |
+
self.linear_layer.weight,
|
35 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.linear_layer(x)
|
39 |
+
|
40 |
+
|
41 |
+
class ConvNorm(torch.nn.Module):
|
42 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
43 |
+
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
44 |
+
super(ConvNorm, self).__init__()
|
45 |
+
if padding is None:
|
46 |
+
assert (kernel_size % 2 == 1)
|
47 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
48 |
+
|
49 |
+
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
50 |
+
kernel_size=kernel_size, stride=stride,
|
51 |
+
padding=padding, dilation=dilation,
|
52 |
+
bias=bias)
|
53 |
+
|
54 |
+
torch.nn.init.xavier_uniform_(
|
55 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
56 |
+
|
57 |
+
def forward(self, signal):
|
58 |
+
conv_signal = self.conv(signal)
|
59 |
+
return conv_signal
|
60 |
+
|
61 |
+
|
62 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
|
63 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
64 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
65 |
+
if padding_idx is not None:
|
66 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
67 |
+
return m
|
68 |
+
|
69 |
+
|
70 |
+
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
|
71 |
+
if not export and torch.cuda.is_available():
|
72 |
+
try:
|
73 |
+
from apex.normalization import FusedLayerNorm
|
74 |
+
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
75 |
+
except ImportError:
|
76 |
+
pass
|
77 |
+
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
78 |
+
|
79 |
+
|
80 |
+
def Linear(in_features, out_features, bias=True):
|
81 |
+
m = nn.Linear(in_features, out_features, bias)
|
82 |
+
nn.init.xavier_uniform_(m.weight)
|
83 |
+
if bias:
|
84 |
+
nn.init.constant_(m.bias, 0.)
|
85 |
+
return m
|
86 |
+
|
87 |
+
|
88 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
89 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
90 |
+
|
91 |
+
Padding symbols are ignored.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
95 |
+
super().__init__()
|
96 |
+
self.embedding_dim = embedding_dim
|
97 |
+
self.padding_idx = padding_idx
|
98 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
99 |
+
init_size,
|
100 |
+
embedding_dim,
|
101 |
+
padding_idx,
|
102 |
+
)
|
103 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
107 |
+
"""Build sinusoidal embeddings.
|
108 |
+
|
109 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
110 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
111 |
+
"""
|
112 |
+
half_dim = embedding_dim // 2
|
113 |
+
emb = math.log(10000) / (half_dim - 1)
|
114 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
115 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
116 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
117 |
+
if embedding_dim % 2 == 1:
|
118 |
+
# zero pad
|
119 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
120 |
+
if padding_idx is not None:
|
121 |
+
emb[padding_idx, :] = 0
|
122 |
+
return emb
|
123 |
+
|
124 |
+
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
|
125 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
126 |
+
bsz, seq_len = input.shape[:2]
|
127 |
+
max_pos = self.padding_idx + 1 + seq_len
|
128 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
129 |
+
# recompute/expand embeddings if needed
|
130 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
131 |
+
max_pos,
|
132 |
+
self.embedding_dim,
|
133 |
+
self.padding_idx,
|
134 |
+
)
|
135 |
+
self.weights = self.weights.to(self._float_tensor)
|
136 |
+
|
137 |
+
if incremental_state is not None:
|
138 |
+
# positions is the same for every token when decoding a single step
|
139 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
140 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
141 |
+
|
142 |
+
positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
|
143 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
144 |
+
|
145 |
+
def max_positions(self):
|
146 |
+
"""Maximum number of supported positions."""
|
147 |
+
return int(1e5) # an arbitrary large number
|
148 |
+
|
149 |
+
|
150 |
+
class ConvTBC(nn.Module):
|
151 |
+
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
|
152 |
+
super(ConvTBC, self).__init__()
|
153 |
+
self.in_channels = in_channels
|
154 |
+
self.out_channels = out_channels
|
155 |
+
self.kernel_size = kernel_size
|
156 |
+
self.padding = padding
|
157 |
+
|
158 |
+
self.weight = torch.nn.Parameter(torch.Tensor(
|
159 |
+
self.kernel_size, in_channels, out_channels))
|
160 |
+
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
|
161 |
+
|
162 |
+
def forward(self, input):
|
163 |
+
return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
|
164 |
+
|
165 |
+
|
166 |
+
class MultiheadAttention(nn.Module):
|
167 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
168 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
169 |
+
encoder_decoder_attention=False):
|
170 |
+
super().__init__()
|
171 |
+
self.embed_dim = embed_dim
|
172 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
173 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
174 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
175 |
+
|
176 |
+
self.num_heads = num_heads
|
177 |
+
self.dropout = dropout
|
178 |
+
self.head_dim = embed_dim // num_heads
|
179 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
180 |
+
self.scaling = self.head_dim ** -0.5
|
181 |
+
|
182 |
+
self.self_attention = self_attention
|
183 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
184 |
+
|
185 |
+
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
|
186 |
+
'value to be of the same size'
|
187 |
+
|
188 |
+
if self.qkv_same_dim:
|
189 |
+
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
|
190 |
+
else:
|
191 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
192 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
193 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
194 |
+
|
195 |
+
if bias:
|
196 |
+
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
|
197 |
+
else:
|
198 |
+
self.register_parameter('in_proj_bias', None)
|
199 |
+
|
200 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
201 |
+
|
202 |
+
if add_bias_kv:
|
203 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
204 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
205 |
+
else:
|
206 |
+
self.bias_k = self.bias_v = None
|
207 |
+
|
208 |
+
self.add_zero_attn = add_zero_attn
|
209 |
+
|
210 |
+
self.reset_parameters()
|
211 |
+
|
212 |
+
self.enable_torch_version = False
|
213 |
+
if hasattr(F, "multi_head_attention_forward"):
|
214 |
+
self.enable_torch_version = True
|
215 |
+
else:
|
216 |
+
self.enable_torch_version = False
|
217 |
+
self.last_attn_probs = None
|
218 |
+
|
219 |
+
def reset_parameters(self):
|
220 |
+
if self.qkv_same_dim:
|
221 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
222 |
+
else:
|
223 |
+
nn.init.xavier_uniform_(self.k_proj_weight)
|
224 |
+
nn.init.xavier_uniform_(self.v_proj_weight)
|
225 |
+
nn.init.xavier_uniform_(self.q_proj_weight)
|
226 |
+
|
227 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
228 |
+
if self.in_proj_bias is not None:
|
229 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
230 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
231 |
+
if self.bias_k is not None:
|
232 |
+
nn.init.xavier_normal_(self.bias_k)
|
233 |
+
if self.bias_v is not None:
|
234 |
+
nn.init.xavier_normal_(self.bias_v)
|
235 |
+
|
236 |
+
def forward(
|
237 |
+
self,
|
238 |
+
query, key, value,
|
239 |
+
key_padding_mask=None,
|
240 |
+
incremental_state=None,
|
241 |
+
need_weights=True,
|
242 |
+
static_kv=False,
|
243 |
+
attn_mask=None,
|
244 |
+
before_softmax=False,
|
245 |
+
need_head_weights=False,
|
246 |
+
enc_dec_attn_constraint_mask=None,
|
247 |
+
reset_attn_weight=None
|
248 |
+
):
|
249 |
+
"""Input shape: Time x Batch x Channel
|
250 |
+
|
251 |
+
Args:
|
252 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
253 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
254 |
+
padding elements are indicated by 1s.
|
255 |
+
need_weights (bool, optional): return the attention weights,
|
256 |
+
averaged over heads (default: False).
|
257 |
+
attn_mask (ByteTensor, optional): typically used to
|
258 |
+
implement causal attention, where the mask prevents the
|
259 |
+
attention from looking forward in time (default: None).
|
260 |
+
before_softmax (bool, optional): return the raw attention
|
261 |
+
weights and values before the attention softmax.
|
262 |
+
need_head_weights (bool, optional): return the attention
|
263 |
+
weights for each head. Implies *need_weights*. Default:
|
264 |
+
return the average attention weights over all heads.
|
265 |
+
"""
|
266 |
+
if need_head_weights:
|
267 |
+
need_weights = True
|
268 |
+
|
269 |
+
tgt_len, bsz, embed_dim = query.size()
|
270 |
+
assert embed_dim == self.embed_dim
|
271 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
272 |
+
|
273 |
+
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
|
274 |
+
if self.qkv_same_dim:
|
275 |
+
return F.multi_head_attention_forward(query, key, value,
|
276 |
+
self.embed_dim, self.num_heads,
|
277 |
+
self.in_proj_weight,
|
278 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
279 |
+
self.add_zero_attn, self.dropout,
|
280 |
+
self.out_proj.weight, self.out_proj.bias,
|
281 |
+
self.training, key_padding_mask, need_weights,
|
282 |
+
attn_mask)
|
283 |
+
else:
|
284 |
+
return F.multi_head_attention_forward(query, key, value,
|
285 |
+
self.embed_dim, self.num_heads,
|
286 |
+
torch.empty([0]),
|
287 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
288 |
+
self.add_zero_attn, self.dropout,
|
289 |
+
self.out_proj.weight, self.out_proj.bias,
|
290 |
+
self.training, key_padding_mask, need_weights,
|
291 |
+
attn_mask, use_separate_proj_weight=True,
|
292 |
+
q_proj_weight=self.q_proj_weight,
|
293 |
+
k_proj_weight=self.k_proj_weight,
|
294 |
+
v_proj_weight=self.v_proj_weight)
|
295 |
+
|
296 |
+
if incremental_state is not None:
|
297 |
+
print('Not implemented error.')
|
298 |
+
exit()
|
299 |
+
else:
|
300 |
+
saved_state = None
|
301 |
+
|
302 |
+
if self.self_attention:
|
303 |
+
# self-attention
|
304 |
+
q, k, v = self.in_proj_qkv(query)
|
305 |
+
elif self.encoder_decoder_attention:
|
306 |
+
# encoder-decoder attention
|
307 |
+
q = self.in_proj_q(query)
|
308 |
+
if key is None:
|
309 |
+
assert value is None
|
310 |
+
k = v = None
|
311 |
+
else:
|
312 |
+
k = self.in_proj_k(key)
|
313 |
+
v = self.in_proj_v(key)
|
314 |
+
|
315 |
+
else:
|
316 |
+
q = self.in_proj_q(query)
|
317 |
+
k = self.in_proj_k(key)
|
318 |
+
v = self.in_proj_v(value)
|
319 |
+
q *= self.scaling
|
320 |
+
|
321 |
+
if self.bias_k is not None:
|
322 |
+
assert self.bias_v is not None
|
323 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
324 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
325 |
+
if attn_mask is not None:
|
326 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
327 |
+
if key_padding_mask is not None:
|
328 |
+
key_padding_mask = torch.cat(
|
329 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
330 |
+
|
331 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
332 |
+
if k is not None:
|
333 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
334 |
+
if v is not None:
|
335 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
336 |
+
|
337 |
+
if saved_state is not None:
|
338 |
+
print('Not implemented error.')
|
339 |
+
exit()
|
340 |
+
|
341 |
+
src_len = k.size(1)
|
342 |
+
|
343 |
+
# This is part of a workaround to get around fork/join parallelism
|
344 |
+
# not supporting Optional types.
|
345 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
346 |
+
key_padding_mask = None
|
347 |
+
|
348 |
+
if key_padding_mask is not None:
|
349 |
+
assert key_padding_mask.size(0) == bsz
|
350 |
+
assert key_padding_mask.size(1) == src_len
|
351 |
+
|
352 |
+
if self.add_zero_attn:
|
353 |
+
src_len += 1
|
354 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
355 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
356 |
+
if attn_mask is not None:
|
357 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
358 |
+
if key_padding_mask is not None:
|
359 |
+
key_padding_mask = torch.cat(
|
360 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
361 |
+
|
362 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
363 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
364 |
+
|
365 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
366 |
+
|
367 |
+
if attn_mask is not None:
|
368 |
+
if len(attn_mask.shape) == 2:
|
369 |
+
attn_mask = attn_mask.unsqueeze(0)
|
370 |
+
elif len(attn_mask.shape) == 3:
|
371 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
372 |
+
bsz * self.num_heads, tgt_len, src_len)
|
373 |
+
attn_weights = attn_weights + attn_mask
|
374 |
+
|
375 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
376 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
377 |
+
attn_weights = attn_weights.masked_fill(
|
378 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
379 |
+
-1e9,
|
380 |
+
)
|
381 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
382 |
+
|
383 |
+
if key_padding_mask is not None:
|
384 |
+
# don't attend to padding symbols
|
385 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
386 |
+
attn_weights = attn_weights.masked_fill(
|
387 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
388 |
+
-1e9,
|
389 |
+
)
|
390 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
391 |
+
|
392 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
393 |
+
|
394 |
+
if before_softmax:
|
395 |
+
return attn_weights, v
|
396 |
+
|
397 |
+
attn_weights_float = utils.softmax(attn_weights, dim=-1)
|
398 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
399 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
400 |
+
|
401 |
+
if reset_attn_weight is not None:
|
402 |
+
if reset_attn_weight:
|
403 |
+
self.last_attn_probs = attn_probs.detach()
|
404 |
+
else:
|
405 |
+
assert self.last_attn_probs is not None
|
406 |
+
attn_probs = self.last_attn_probs
|
407 |
+
attn = torch.bmm(attn_probs, v)
|
408 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
409 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
410 |
+
attn = self.out_proj(attn)
|
411 |
+
|
412 |
+
if need_weights:
|
413 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
414 |
+
if not need_head_weights:
|
415 |
+
# average attention weights over heads
|
416 |
+
attn_weights = attn_weights.mean(dim=0)
|
417 |
+
else:
|
418 |
+
attn_weights = None
|
419 |
+
|
420 |
+
return attn, (attn_weights, attn_logits)
|
421 |
+
|
422 |
+
def in_proj_qkv(self, query):
|
423 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
424 |
+
|
425 |
+
def in_proj_q(self, query):
|
426 |
+
if self.qkv_same_dim:
|
427 |
+
return self._in_proj(query, end=self.embed_dim)
|
428 |
+
else:
|
429 |
+
bias = self.in_proj_bias
|
430 |
+
if bias is not None:
|
431 |
+
bias = bias[:self.embed_dim]
|
432 |
+
return F.linear(query, self.q_proj_weight, bias)
|
433 |
+
|
434 |
+
def in_proj_k(self, key):
|
435 |
+
if self.qkv_same_dim:
|
436 |
+
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
|
437 |
+
else:
|
438 |
+
weight = self.k_proj_weight
|
439 |
+
bias = self.in_proj_bias
|
440 |
+
if bias is not None:
|
441 |
+
bias = bias[self.embed_dim:2 * self.embed_dim]
|
442 |
+
return F.linear(key, weight, bias)
|
443 |
+
|
444 |
+
def in_proj_v(self, value):
|
445 |
+
if self.qkv_same_dim:
|
446 |
+
return self._in_proj(value, start=2 * self.embed_dim)
|
447 |
+
else:
|
448 |
+
weight = self.v_proj_weight
|
449 |
+
bias = self.in_proj_bias
|
450 |
+
if bias is not None:
|
451 |
+
bias = bias[2 * self.embed_dim:]
|
452 |
+
return F.linear(value, weight, bias)
|
453 |
+
|
454 |
+
def _in_proj(self, input, start=0, end=None):
|
455 |
+
weight = self.in_proj_weight
|
456 |
+
bias = self.in_proj_bias
|
457 |
+
weight = weight[start:end, :]
|
458 |
+
if bias is not None:
|
459 |
+
bias = bias[start:end]
|
460 |
+
return F.linear(input, weight, bias)
|
461 |
+
|
462 |
+
|
463 |
+
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
|
464 |
+
return attn_weights
|
465 |
+
|
466 |
+
|
467 |
+
class Swish(torch.autograd.Function):
|
468 |
+
@staticmethod
|
469 |
+
def forward(ctx, i):
|
470 |
+
result = i * torch.sigmoid(i)
|
471 |
+
ctx.save_for_backward(i)
|
472 |
+
return result
|
473 |
+
|
474 |
+
@staticmethod
|
475 |
+
def backward(ctx, grad_output):
|
476 |
+
i = ctx.saved_variables[0]
|
477 |
+
sigmoid_i = torch.sigmoid(i)
|
478 |
+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
479 |
+
|
480 |
+
|
481 |
+
class CustomSwish(nn.Module):
|
482 |
+
def forward(self, input_tensor):
|
483 |
+
return Swish.apply(input_tensor)
|
484 |
+
|
485 |
+
|
486 |
+
class TransformerFFNLayer(nn.Module):
|
487 |
+
def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
|
488 |
+
super().__init__()
|
489 |
+
self.kernel_size = kernel_size
|
490 |
+
self.dropout = dropout
|
491 |
+
self.act = act
|
492 |
+
if padding == 'SAME':
|
493 |
+
self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
|
494 |
+
elif padding == 'LEFT':
|
495 |
+
self.ffn_1 = nn.Sequential(
|
496 |
+
nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
|
497 |
+
nn.Conv1d(hidden_size, filter_size, kernel_size)
|
498 |
+
)
|
499 |
+
self.ffn_2 = Linear(filter_size, hidden_size)
|
500 |
+
if self.act == 'swish':
|
501 |
+
self.swish_fn = CustomSwish()
|
502 |
+
|
503 |
+
def forward(self, x, incremental_state=None):
|
504 |
+
# x: T x B x C
|
505 |
+
if incremental_state is not None:
|
506 |
+
assert incremental_state is None, 'Nar-generation does not allow this.'
|
507 |
+
exit(1)
|
508 |
+
|
509 |
+
x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
|
510 |
+
x = x * self.kernel_size ** -0.5
|
511 |
+
|
512 |
+
if incremental_state is not None:
|
513 |
+
x = x[-1:]
|
514 |
+
if self.act == 'gelu':
|
515 |
+
x = F.gelu(x)
|
516 |
+
if self.act == 'relu':
|
517 |
+
x = F.relu(x)
|
518 |
+
if self.act == 'swish':
|
519 |
+
x = self.swish_fn(x)
|
520 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
521 |
+
x = self.ffn_2(x)
|
522 |
+
return x
|
523 |
+
|
524 |
+
|
525 |
+
class BatchNorm1dTBC(nn.Module):
|
526 |
+
def __init__(self, c):
|
527 |
+
super(BatchNorm1dTBC, self).__init__()
|
528 |
+
self.bn = nn.BatchNorm1d(c)
|
529 |
+
|
530 |
+
def forward(self, x):
|
531 |
+
"""
|
532 |
+
|
533 |
+
:param x: [T, B, C]
|
534 |
+
:return: [T, B, C]
|
535 |
+
"""
|
536 |
+
x = x.permute(1, 2, 0) # [B, C, T]
|
537 |
+
x = self.bn(x) # [B, C, T]
|
538 |
+
x = x.permute(2, 0, 1) # [T, B, C]
|
539 |
+
return x
|
540 |
+
|
541 |
+
|
542 |
+
class EncSALayer(nn.Module):
|
543 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
|
544 |
+
relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
|
545 |
+
super().__init__()
|
546 |
+
self.c = c
|
547 |
+
self.dropout = dropout
|
548 |
+
self.num_heads = num_heads
|
549 |
+
if num_heads > 0:
|
550 |
+
if norm == 'ln':
|
551 |
+
self.layer_norm1 = LayerNorm(c)
|
552 |
+
elif norm == 'bn':
|
553 |
+
self.layer_norm1 = BatchNorm1dTBC(c)
|
554 |
+
self.self_attn = MultiheadAttention(
|
555 |
+
self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
|
556 |
+
)
|
557 |
+
if norm == 'ln':
|
558 |
+
self.layer_norm2 = LayerNorm(c)
|
559 |
+
elif norm == 'bn':
|
560 |
+
self.layer_norm2 = BatchNorm1dTBC(c)
|
561 |
+
self.ffn = TransformerFFNLayer(
|
562 |
+
c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
|
563 |
+
|
564 |
+
def forward(self, x, encoder_padding_mask=None, **kwargs):
|
565 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
566 |
+
if layer_norm_training is not None:
|
567 |
+
self.layer_norm1.training = layer_norm_training
|
568 |
+
self.layer_norm2.training = layer_norm_training
|
569 |
+
if self.num_heads > 0:
|
570 |
+
residual = x
|
571 |
+
x = self.layer_norm1(x)
|
572 |
+
x, _, = self.self_attn(
|
573 |
+
query=x,
|
574 |
+
key=x,
|
575 |
+
value=x,
|
576 |
+
key_padding_mask=encoder_padding_mask
|
577 |
+
)
|
578 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
579 |
+
x = residual + x
|
580 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
581 |
+
|
582 |
+
residual = x
|
583 |
+
x = self.layer_norm2(x)
|
584 |
+
x = self.ffn(x)
|
585 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
586 |
+
x = residual + x
|
587 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
588 |
+
return x
|
589 |
+
|
590 |
+
|
591 |
+
class DecSALayer(nn.Module):
|
592 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
|
593 |
+
super().__init__()
|
594 |
+
self.c = c
|
595 |
+
self.dropout = dropout
|
596 |
+
self.layer_norm1 = LayerNorm(c)
|
597 |
+
self.self_attn = MultiheadAttention(
|
598 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
599 |
+
)
|
600 |
+
self.layer_norm2 = LayerNorm(c)
|
601 |
+
self.encoder_attn = MultiheadAttention(
|
602 |
+
c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
|
603 |
+
)
|
604 |
+
self.layer_norm3 = LayerNorm(c)
|
605 |
+
self.ffn = TransformerFFNLayer(
|
606 |
+
c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
|
607 |
+
|
608 |
+
def forward(
|
609 |
+
self,
|
610 |
+
x,
|
611 |
+
encoder_out=None,
|
612 |
+
encoder_padding_mask=None,
|
613 |
+
incremental_state=None,
|
614 |
+
self_attn_mask=None,
|
615 |
+
self_attn_padding_mask=None,
|
616 |
+
attn_out=None,
|
617 |
+
reset_attn_weight=None,
|
618 |
+
**kwargs,
|
619 |
+
):
|
620 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
621 |
+
if layer_norm_training is not None:
|
622 |
+
self.layer_norm1.training = layer_norm_training
|
623 |
+
self.layer_norm2.training = layer_norm_training
|
624 |
+
self.layer_norm3.training = layer_norm_training
|
625 |
+
residual = x
|
626 |
+
x = self.layer_norm1(x)
|
627 |
+
x, _ = self.self_attn(
|
628 |
+
query=x,
|
629 |
+
key=x,
|
630 |
+
value=x,
|
631 |
+
key_padding_mask=self_attn_padding_mask,
|
632 |
+
incremental_state=incremental_state,
|
633 |
+
attn_mask=self_attn_mask
|
634 |
+
)
|
635 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
636 |
+
x = residual + x
|
637 |
+
|
638 |
+
residual = x
|
639 |
+
x = self.layer_norm2(x)
|
640 |
+
if encoder_out is not None:
|
641 |
+
x, attn = self.encoder_attn(
|
642 |
+
query=x,
|
643 |
+
key=encoder_out,
|
644 |
+
value=encoder_out,
|
645 |
+
key_padding_mask=encoder_padding_mask,
|
646 |
+
incremental_state=incremental_state,
|
647 |
+
static_kv=True,
|
648 |
+
enc_dec_attn_constraint_mask=None, #utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
|
649 |
+
reset_attn_weight=reset_attn_weight
|
650 |
+
)
|
651 |
+
attn_logits = attn[1]
|
652 |
+
else:
|
653 |
+
assert attn_out is not None
|
654 |
+
x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
|
655 |
+
attn_logits = None
|
656 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
657 |
+
x = residual + x
|
658 |
+
|
659 |
+
residual = x
|
660 |
+
x = self.layer_norm3(x)
|
661 |
+
x = self.ffn(x, incremental_state=incremental_state)
|
662 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
663 |
+
x = residual + x
|
664 |
+
# if len(attn_logits.size()) > 3:
|
665 |
+
# indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
|
666 |
+
# attn_logits = attn_logits.gather(1,
|
667 |
+
# indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
|
668 |
+
return x, attn_logits
|
modules/commons/espnet_positional_embedding.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class PositionalEncoding(torch.nn.Module):
|
6 |
+
"""Positional encoding.
|
7 |
+
Args:
|
8 |
+
d_model (int): Embedding dimension.
|
9 |
+
dropout_rate (float): Dropout rate.
|
10 |
+
max_len (int): Maximum input length.
|
11 |
+
reverse (bool): Whether to reverse the input position.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
15 |
+
"""Construct an PositionalEncoding object."""
|
16 |
+
super(PositionalEncoding, self).__init__()
|
17 |
+
self.d_model = d_model
|
18 |
+
self.reverse = reverse
|
19 |
+
self.xscale = math.sqrt(self.d_model)
|
20 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
21 |
+
self.pe = None
|
22 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
23 |
+
|
24 |
+
def extend_pe(self, x):
|
25 |
+
"""Reset the positional encodings."""
|
26 |
+
if self.pe is not None:
|
27 |
+
if self.pe.size(1) >= x.size(1):
|
28 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
29 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
30 |
+
return
|
31 |
+
pe = torch.zeros(x.size(1), self.d_model)
|
32 |
+
if self.reverse:
|
33 |
+
position = torch.arange(
|
34 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
35 |
+
).unsqueeze(1)
|
36 |
+
else:
|
37 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
38 |
+
div_term = torch.exp(
|
39 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
40 |
+
* -(math.log(10000.0) / self.d_model)
|
41 |
+
)
|
42 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
43 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
44 |
+
pe = pe.unsqueeze(0)
|
45 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
46 |
+
|
47 |
+
def forward(self, x: torch.Tensor):
|
48 |
+
"""Add positional encoding.
|
49 |
+
Args:
|
50 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
51 |
+
Returns:
|
52 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
53 |
+
"""
|
54 |
+
self.extend_pe(x)
|
55 |
+
x = x * self.xscale + self.pe[:, : x.size(1)]
|
56 |
+
return self.dropout(x)
|
57 |
+
|
58 |
+
|
59 |
+
class ScaledPositionalEncoding(PositionalEncoding):
|
60 |
+
"""Scaled positional encoding module.
|
61 |
+
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
62 |
+
Args:
|
63 |
+
d_model (int): Embedding dimension.
|
64 |
+
dropout_rate (float): Dropout rate.
|
65 |
+
max_len (int): Maximum input length.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
69 |
+
"""Initialize class."""
|
70 |
+
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
71 |
+
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
72 |
+
|
73 |
+
def reset_parameters(self):
|
74 |
+
"""Reset parameters."""
|
75 |
+
self.alpha.data = torch.tensor(1.0)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
"""Add positional encoding.
|
79 |
+
Args:
|
80 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
81 |
+
Returns:
|
82 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
83 |
+
"""
|
84 |
+
self.extend_pe(x)
|
85 |
+
x = x + self.alpha * self.pe[:, : x.size(1)]
|
86 |
+
return self.dropout(x)
|
87 |
+
|
88 |
+
|
89 |
+
class RelPositionalEncoding(PositionalEncoding):
|
90 |
+
"""Relative positional encoding module.
|
91 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
92 |
+
Args:
|
93 |
+
d_model (int): Embedding dimension.
|
94 |
+
dropout_rate (float): Dropout rate.
|
95 |
+
max_len (int): Maximum input length.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
99 |
+
"""Initialize class."""
|
100 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
"""Compute positional encoding.
|
104 |
+
Args:
|
105 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
106 |
+
Returns:
|
107 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
108 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
109 |
+
"""
|
110 |
+
self.extend_pe(x)
|
111 |
+
x = x * self.xscale
|
112 |
+
pos_emb = self.pe[:, : x.size(1)]
|
113 |
+
return self.dropout(x) + self.dropout(pos_emb)
|
modules/commons/ssim.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# '''
|
2 |
+
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
|
3 |
+
# '''
|
4 |
+
#
|
5 |
+
# import torch
|
6 |
+
# import torch.jit
|
7 |
+
# import torch.nn.functional as F
|
8 |
+
#
|
9 |
+
#
|
10 |
+
# @torch.jit.script
|
11 |
+
# def create_window(window_size: int, sigma: float, channel: int):
|
12 |
+
# '''
|
13 |
+
# Create 1-D gauss kernel
|
14 |
+
# :param window_size: the size of gauss kernel
|
15 |
+
# :param sigma: sigma of normal distribution
|
16 |
+
# :param channel: input channel
|
17 |
+
# :return: 1D kernel
|
18 |
+
# '''
|
19 |
+
# coords = torch.arange(window_size, dtype=torch.float)
|
20 |
+
# coords -= window_size // 2
|
21 |
+
#
|
22 |
+
# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
|
23 |
+
# g /= g.sum()
|
24 |
+
#
|
25 |
+
# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
|
26 |
+
# return g
|
27 |
+
#
|
28 |
+
#
|
29 |
+
# @torch.jit.script
|
30 |
+
# def _gaussian_filter(x, window_1d, use_padding: bool):
|
31 |
+
# '''
|
32 |
+
# Blur input with 1-D kernel
|
33 |
+
# :param x: batch of tensors to be blured
|
34 |
+
# :param window_1d: 1-D gauss kernel
|
35 |
+
# :param use_padding: padding image before conv
|
36 |
+
# :return: blured tensors
|
37 |
+
# '''
|
38 |
+
# C = x.shape[1]
|
39 |
+
# padding = 0
|
40 |
+
# if use_padding:
|
41 |
+
# window_size = window_1d.shape[3]
|
42 |
+
# padding = window_size // 2
|
43 |
+
# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
|
44 |
+
# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
|
45 |
+
# return out
|
46 |
+
#
|
47 |
+
#
|
48 |
+
# @torch.jit.script
|
49 |
+
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
|
50 |
+
# '''
|
51 |
+
# Calculate ssim index for X and Y
|
52 |
+
# :param X: images [B, C, H, N_bins]
|
53 |
+
# :param Y: images [B, C, H, N_bins]
|
54 |
+
# :param window: 1-D gauss kernel
|
55 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
56 |
+
# :param use_padding: padding image before conv
|
57 |
+
# :return:
|
58 |
+
# '''
|
59 |
+
#
|
60 |
+
# K1 = 0.01
|
61 |
+
# K2 = 0.03
|
62 |
+
# compensation = 1.0
|
63 |
+
#
|
64 |
+
# C1 = (K1 * data_range) ** 2
|
65 |
+
# C2 = (K2 * data_range) ** 2
|
66 |
+
#
|
67 |
+
# mu1 = _gaussian_filter(X, window, use_padding)
|
68 |
+
# mu2 = _gaussian_filter(Y, window, use_padding)
|
69 |
+
# sigma1_sq = _gaussian_filter(X * X, window, use_padding)
|
70 |
+
# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
|
71 |
+
# sigma12 = _gaussian_filter(X * Y, window, use_padding)
|
72 |
+
#
|
73 |
+
# mu1_sq = mu1.pow(2)
|
74 |
+
# mu2_sq = mu2.pow(2)
|
75 |
+
# mu1_mu2 = mu1 * mu2
|
76 |
+
#
|
77 |
+
# sigma1_sq = compensation * (sigma1_sq - mu1_sq)
|
78 |
+
# sigma2_sq = compensation * (sigma2_sq - mu2_sq)
|
79 |
+
# sigma12 = compensation * (sigma12 - mu1_mu2)
|
80 |
+
#
|
81 |
+
# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
|
82 |
+
# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
|
83 |
+
# cs_map = cs_map.clamp_min(0.)
|
84 |
+
# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
|
85 |
+
#
|
86 |
+
# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
|
87 |
+
# cs = cs_map.mean(dim=(1, 2, 3))
|
88 |
+
#
|
89 |
+
# return ssim_val, cs
|
90 |
+
#
|
91 |
+
#
|
92 |
+
# @torch.jit.script
|
93 |
+
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
|
94 |
+
# '''
|
95 |
+
# interface of ms-ssim
|
96 |
+
# :param X: a batch of images, (N,C,H,W)
|
97 |
+
# :param Y: a batch of images, (N,C,H,W)
|
98 |
+
# :param window: 1-D gauss kernel
|
99 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
100 |
+
# :param weights: weights for different levels
|
101 |
+
# :param use_padding: padding image before conv
|
102 |
+
# :param eps: use for avoid grad nan.
|
103 |
+
# :return:
|
104 |
+
# '''
|
105 |
+
# levels = weights.shape[0]
|
106 |
+
# cs_vals = []
|
107 |
+
# ssim_vals = []
|
108 |
+
# for _ in range(levels):
|
109 |
+
# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
|
110 |
+
# # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
|
111 |
+
# ssim_val = ssim_val.clamp_min(eps)
|
112 |
+
# cs = cs.clamp_min(eps)
|
113 |
+
# cs_vals.append(cs)
|
114 |
+
#
|
115 |
+
# ssim_vals.append(ssim_val)
|
116 |
+
# padding = (X.shape[2] % 2, X.shape[3] % 2)
|
117 |
+
# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
|
118 |
+
# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
|
119 |
+
#
|
120 |
+
# cs_vals = torch.stack(cs_vals, dim=0)
|
121 |
+
# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
|
122 |
+
# return ms_ssim_val
|
123 |
+
#
|
124 |
+
#
|
125 |
+
# class SSIM(torch.jit.ScriptModule):
|
126 |
+
# __constants__ = ['data_range', 'use_padding']
|
127 |
+
#
|
128 |
+
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
|
129 |
+
# '''
|
130 |
+
# :param window_size: the size of gauss kernel
|
131 |
+
# :param window_sigma: sigma of normal distribution
|
132 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
133 |
+
# :param channel: input channels (default: 3)
|
134 |
+
# :param use_padding: padding image before conv
|
135 |
+
# '''
|
136 |
+
# super().__init__()
|
137 |
+
# assert window_size % 2 == 1, 'Window size must be odd.'
|
138 |
+
# window = create_window(window_size, window_sigma, channel)
|
139 |
+
# self.register_buffer('window', window)
|
140 |
+
# self.data_range = data_range
|
141 |
+
# self.use_padding = use_padding
|
142 |
+
#
|
143 |
+
# @torch.jit.script_method
|
144 |
+
# def forward(self, X, Y):
|
145 |
+
# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
|
146 |
+
# return r[0]
|
147 |
+
#
|
148 |
+
#
|
149 |
+
# class MS_SSIM(torch.jit.ScriptModule):
|
150 |
+
# __constants__ = ['data_range', 'use_padding', 'eps']
|
151 |
+
#
|
152 |
+
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
|
153 |
+
# levels=None, eps=1e-8):
|
154 |
+
# '''
|
155 |
+
# class for ms-ssim
|
156 |
+
# :param window_size: the size of gauss kernel
|
157 |
+
# :param window_sigma: sigma of normal distribution
|
158 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
159 |
+
# :param channel: input channels
|
160 |
+
# :param use_padding: padding image before conv
|
161 |
+
# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
|
162 |
+
# :param levels: number of downsampling
|
163 |
+
# :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
|
164 |
+
# '''
|
165 |
+
# super().__init__()
|
166 |
+
# assert window_size % 2 == 1, 'Window size must be odd.'
|
167 |
+
# self.data_range = data_range
|
168 |
+
# self.use_padding = use_padding
|
169 |
+
# self.eps = eps
|
170 |
+
#
|
171 |
+
# window = create_window(window_size, window_sigma, channel)
|
172 |
+
# self.register_buffer('window', window)
|
173 |
+
#
|
174 |
+
# if weights is None:
|
175 |
+
# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
|
176 |
+
# weights = torch.tensor(weights, dtype=torch.float)
|
177 |
+
#
|
178 |
+
# if levels is not None:
|
179 |
+
# weights = weights[:levels]
|
180 |
+
# weights = weights / weights.sum()
|
181 |
+
#
|
182 |
+
# self.register_buffer('weights', weights)
|
183 |
+
#
|
184 |
+
# @torch.jit.script_method
|
185 |
+
# def forward(self, X, Y):
|
186 |
+
# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
|
187 |
+
# use_padding=self.use_padding, eps=self.eps)
|
188 |
+
#
|
189 |
+
#
|
190 |
+
# if __name__ == '__main__':
|
191 |
+
# print('Simple Test')
|
192 |
+
# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
|
193 |
+
# img1 = im / 255
|
194 |
+
# img2 = img1 * 0.5
|
195 |
+
#
|
196 |
+
# losser = SSIM(data_range=1.).cuda()
|
197 |
+
# loss = losser(img1, img2).mean()
|
198 |
+
#
|
199 |
+
# losser2 = MS_SSIM(data_range=1.).cuda()
|
200 |
+
# loss2 = losser2(img1, img2).mean()
|
201 |
+
#
|
202 |
+
# print(loss.item())
|
203 |
+
# print(loss2.item())
|
204 |
+
#
|
205 |
+
# if __name__ == '__main__':
|
206 |
+
# print('Training Test')
|
207 |
+
# import cv2
|
208 |
+
# import torch.optim
|
209 |
+
# import numpy as np
|
210 |
+
# import imageio
|
211 |
+
# import time
|
212 |
+
#
|
213 |
+
# out_test_video = False
|
214 |
+
# # 最好不要直接输出gif图,会非常大,最好先输出mkv文件后用ffmpeg转换到GIF
|
215 |
+
# video_use_gif = False
|
216 |
+
#
|
217 |
+
# im = cv2.imread('test_img1.jpg', 1)
|
218 |
+
# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
|
219 |
+
#
|
220 |
+
# if out_test_video:
|
221 |
+
# if video_use_gif:
|
222 |
+
# fps = 0.5
|
223 |
+
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
|
224 |
+
# suffix = '.gif'
|
225 |
+
# else:
|
226 |
+
# fps = 5
|
227 |
+
# out_wh = (im.shape[1], im.shape[0])
|
228 |
+
# suffix = '.mkv'
|
229 |
+
# video_last_time = time.perf_counter()
|
230 |
+
# video = imageio.get_writer('ssim_test' + suffix, fps=fps)
|
231 |
+
#
|
232 |
+
# # 测试ssim
|
233 |
+
# print('Training SSIM')
|
234 |
+
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
|
235 |
+
# rand_im.requires_grad = True
|
236 |
+
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
|
237 |
+
# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
|
238 |
+
# ssim_score = 0
|
239 |
+
# while ssim_score < 0.999:
|
240 |
+
# optim.zero_grad()
|
241 |
+
# loss = losser(rand_im, t_im)
|
242 |
+
# (-loss).sum().backward()
|
243 |
+
# ssim_score = loss.item()
|
244 |
+
# optim.step()
|
245 |
+
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
|
246 |
+
# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
|
247 |
+
#
|
248 |
+
# if out_test_video:
|
249 |
+
# if time.perf_counter() - video_last_time > 1. / fps:
|
250 |
+
# video_last_time = time.perf_counter()
|
251 |
+
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
|
252 |
+
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
|
253 |
+
# if isinstance(out_frame, cv2.UMat):
|
254 |
+
# out_frame = out_frame.get()
|
255 |
+
# video.append_data(out_frame)
|
256 |
+
#
|
257 |
+
# cv2.imshow('ssim', r_im)
|
258 |
+
# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
|
259 |
+
# cv2.waitKey(1)
|
260 |
+
#
|
261 |
+
# if out_test_video:
|
262 |
+
# video.close()
|
263 |
+
#
|
264 |
+
# # 测试ms_ssim
|
265 |
+
# if out_test_video:
|
266 |
+
# if video_use_gif:
|
267 |
+
# fps = 0.5
|
268 |
+
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
|
269 |
+
# suffix = '.gif'
|
270 |
+
# else:
|
271 |
+
# fps = 5
|
272 |
+
# out_wh = (im.shape[1], im.shape[0])
|
273 |
+
# suffix = '.mkv'
|
274 |
+
# video_last_time = time.perf_counter()
|
275 |
+
# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
|
276 |
+
#
|
277 |
+
# print('Training MS_SSIM')
|
278 |
+
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
|
279 |
+
# rand_im.requires_grad = True
|
280 |
+
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
|
281 |
+
# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
|
282 |
+
# ssim_score = 0
|
283 |
+
# while ssim_score < 0.999:
|
284 |
+
# optim.zero_grad()
|
285 |
+
# loss = losser(rand_im, t_im)
|
286 |
+
# (-loss).sum().backward()
|
287 |
+
# ssim_score = loss.item()
|
288 |
+
# optim.step()
|
289 |
+
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
|
290 |
+
# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
|
291 |
+
#
|
292 |
+
# if out_test_video:
|
293 |
+
# if time.perf_counter() - video_last_time > 1. / fps:
|
294 |
+
# video_last_time = time.perf_counter()
|
295 |
+
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
|
296 |
+
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
|
297 |
+
# if isinstance(out_frame, cv2.UMat):
|
298 |
+
# out_frame = out_frame.get()
|
299 |
+
# video.append_data(out_frame)
|
300 |
+
#
|
301 |
+
# cv2.imshow('ms_ssim', r_im)
|
302 |
+
# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
|
303 |
+
# cv2.waitKey(1)
|
304 |
+
#
|
305 |
+
# if out_test_video:
|
306 |
+
# video.close()
|
307 |
+
|
308 |
+
"""
|
309 |
+
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
|
310 |
+
"""
|
311 |
+
|
312 |
+
import torch
|
313 |
+
import torch.nn.functional as F
|
314 |
+
from torch.autograd import Variable
|
315 |
+
import numpy as np
|
316 |
+
from math import exp
|
317 |
+
|
318 |
+
|
319 |
+
def gaussian(window_size, sigma):
|
320 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
321 |
+
return gauss / gauss.sum()
|
322 |
+
|
323 |
+
|
324 |
+
def create_window(window_size, channel):
|
325 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
326 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
327 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
328 |
+
return window
|
329 |
+
|
330 |
+
|
331 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
332 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
333 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
334 |
+
|
335 |
+
mu1_sq = mu1.pow(2)
|
336 |
+
mu2_sq = mu2.pow(2)
|
337 |
+
mu1_mu2 = mu1 * mu2
|
338 |
+
|
339 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
340 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
341 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
342 |
+
|
343 |
+
C1 = 0.01 ** 2
|
344 |
+
C2 = 0.03 ** 2
|
345 |
+
|
346 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
347 |
+
|
348 |
+
if size_average:
|
349 |
+
return ssim_map.mean()
|
350 |
+
else:
|
351 |
+
return ssim_map.mean(1)
|
352 |
+
|
353 |
+
|
354 |
+
class SSIM(torch.nn.Module):
|
355 |
+
def __init__(self, window_size=11, size_average=True):
|
356 |
+
super(SSIM, self).__init__()
|
357 |
+
self.window_size = window_size
|
358 |
+
self.size_average = size_average
|
359 |
+
self.channel = 1
|
360 |
+
self.window = create_window(window_size, self.channel)
|
361 |
+
|
362 |
+
def forward(self, img1, img2):
|
363 |
+
(_, channel, _, _) = img1.size()
|
364 |
+
|
365 |
+
if channel == self.channel and self.window.data.type() == img1.data.type():
|
366 |
+
window = self.window
|
367 |
+
else:
|
368 |
+
window = create_window(self.window_size, channel)
|
369 |
+
|
370 |
+
if img1.is_cuda:
|
371 |
+
window = window.cuda(img1.get_device())
|
372 |
+
window = window.type_as(img1)
|
373 |
+
|
374 |
+
self.window = window
|
375 |
+
self.channel = channel
|
376 |
+
|
377 |
+
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
|
378 |
+
|
379 |
+
|
380 |
+
window = None
|
381 |
+
|
382 |
+
|
383 |
+
def ssim(img1, img2, window_size=11, size_average=True):
|
384 |
+
(_, channel, _, _) = img1.size()
|
385 |
+
global window
|
386 |
+
if window is None:
|
387 |
+
window = create_window(window_size, channel)
|
388 |
+
if img1.is_cuda:
|
389 |
+
window = window.cuda(img1.get_device())
|
390 |
+
window = window.type_as(img1)
|
391 |
+
return _ssim(img1, img2, window, window_size, channel, size_average)
|
modules/diffsinger_midi/fs2.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.commons.common_layers import *
|
2 |
+
from modules.commons.common_layers import Embedding
|
3 |
+
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
|
4 |
+
EnergyPredictor, FastspeechEncoder
|
5 |
+
from utils.cwt import cwt2f0
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
|
8 |
+
from modules.fastspeech.fs2 import FastSpeech2
|
9 |
+
|
10 |
+
|
11 |
+
class FastspeechMIDIEncoder(FastspeechEncoder):
|
12 |
+
def forward_embedding(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
|
13 |
+
# embed tokens and positions
|
14 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
15 |
+
x = x + midi_embedding + midi_dur_embedding + slur_embedding
|
16 |
+
if hparams['use_pos_embed']:
|
17 |
+
if hparams.get('rel_pos') is not None and hparams['rel_pos']:
|
18 |
+
x = self.embed_positions(x)
|
19 |
+
else:
|
20 |
+
positions = self.embed_positions(txt_tokens)
|
21 |
+
x = x + positions
|
22 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
23 |
+
return x
|
24 |
+
|
25 |
+
def forward(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
|
26 |
+
"""
|
27 |
+
|
28 |
+
:param txt_tokens: [B, T]
|
29 |
+
:return: {
|
30 |
+
'encoder_out': [T x B x C]
|
31 |
+
}
|
32 |
+
"""
|
33 |
+
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
|
34 |
+
x = self.forward_embedding(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, H]
|
35 |
+
x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
FS_ENCODERS = {
|
40 |
+
'fft': lambda hp, embed_tokens, d: FastspeechMIDIEncoder(
|
41 |
+
embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
|
42 |
+
num_heads=hp['num_heads']),
|
43 |
+
}
|
44 |
+
|
45 |
+
|
46 |
+
class FastSpeech2MIDI(FastSpeech2):
|
47 |
+
def __init__(self, dictionary, out_dims=None):
|
48 |
+
super().__init__(dictionary, out_dims)
|
49 |
+
del self.encoder
|
50 |
+
self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
|
51 |
+
self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx)
|
52 |
+
self.midi_dur_layer = Linear(1, self.hidden_size)
|
53 |
+
self.is_slur_embed = Embedding(2, self.hidden_size)
|
54 |
+
|
55 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
56 |
+
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
57 |
+
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
58 |
+
ret = {}
|
59 |
+
|
60 |
+
midi_embedding = self.midi_embed(kwargs['pitch_midi'])
|
61 |
+
midi_dur_embedding, slur_embedding = 0, 0
|
62 |
+
if kwargs.get('midi_dur') is not None:
|
63 |
+
midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
|
64 |
+
if kwargs.get('is_slur') is not None:
|
65 |
+
slur_embedding = self.is_slur_embed(kwargs['is_slur'])
|
66 |
+
encoder_out = self.encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, C]
|
67 |
+
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
68 |
+
|
69 |
+
# add ref style embed
|
70 |
+
# Not implemented
|
71 |
+
# variance encoder
|
72 |
+
var_embed = 0
|
73 |
+
|
74 |
+
# encoder_out_dur denotes encoder outputs for duration predictor
|
75 |
+
# in speech adaptation, duration predictor use old speaker embedding
|
76 |
+
if hparams['use_spk_embed']:
|
77 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
78 |
+
elif hparams['use_spk_id']:
|
79 |
+
spk_embed_id = spk_embed
|
80 |
+
if spk_embed_dur_id is None:
|
81 |
+
spk_embed_dur_id = spk_embed_id
|
82 |
+
if spk_embed_f0_id is None:
|
83 |
+
spk_embed_f0_id = spk_embed_id
|
84 |
+
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
85 |
+
spk_embed_dur = spk_embed_f0 = spk_embed
|
86 |
+
if hparams['use_split_spk_id']:
|
87 |
+
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
88 |
+
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
89 |
+
else:
|
90 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
91 |
+
|
92 |
+
# add dur
|
93 |
+
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
94 |
+
|
95 |
+
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
96 |
+
|
97 |
+
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
98 |
+
|
99 |
+
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
100 |
+
decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
101 |
+
|
102 |
+
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
103 |
+
|
104 |
+
# add pitch and energy embed
|
105 |
+
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
106 |
+
if hparams['use_pitch_embed']:
|
107 |
+
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
108 |
+
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
109 |
+
if hparams['use_energy_embed']:
|
110 |
+
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
111 |
+
|
112 |
+
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
113 |
+
|
114 |
+
if skip_decoder:
|
115 |
+
return ret
|
116 |
+
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
117 |
+
|
118 |
+
return ret
|
modules/fastspeech/fs2.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.commons.common_layers import *
|
2 |
+
from modules.commons.common_layers import Embedding
|
3 |
+
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
|
4 |
+
EnergyPredictor, FastspeechEncoder
|
5 |
+
from utils.cwt import cwt2f0
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
|
8 |
+
|
9 |
+
FS_ENCODERS = {
|
10 |
+
'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
|
11 |
+
embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
|
12 |
+
num_heads=hp['num_heads']),
|
13 |
+
}
|
14 |
+
|
15 |
+
FS_DECODERS = {
|
16 |
+
'fft': lambda hp: FastspeechDecoder(
|
17 |
+
hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
class FastSpeech2(nn.Module):
|
22 |
+
def __init__(self, dictionary, out_dims=None):
|
23 |
+
super().__init__()
|
24 |
+
self.dictionary = dictionary
|
25 |
+
self.padding_idx = dictionary.pad()
|
26 |
+
self.enc_layers = hparams['enc_layers']
|
27 |
+
self.dec_layers = hparams['dec_layers']
|
28 |
+
self.hidden_size = hparams['hidden_size']
|
29 |
+
self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
|
30 |
+
self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
|
31 |
+
self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
|
32 |
+
self.out_dims = out_dims
|
33 |
+
if out_dims is None:
|
34 |
+
self.out_dims = hparams['audio_num_mel_bins']
|
35 |
+
self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
|
36 |
+
|
37 |
+
if hparams['use_spk_id']:
|
38 |
+
self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
|
39 |
+
if hparams['use_split_spk_id']:
|
40 |
+
self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
|
41 |
+
self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
|
42 |
+
elif hparams['use_spk_embed']:
|
43 |
+
self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
|
44 |
+
predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
|
45 |
+
self.dur_predictor = DurationPredictor(
|
46 |
+
self.hidden_size,
|
47 |
+
n_chans=predictor_hidden,
|
48 |
+
n_layers=hparams['dur_predictor_layers'],
|
49 |
+
dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
|
50 |
+
kernel_size=hparams['dur_predictor_kernel'])
|
51 |
+
self.length_regulator = LengthRegulator()
|
52 |
+
if hparams['use_pitch_embed']:
|
53 |
+
self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
|
54 |
+
if hparams['pitch_type'] == 'cwt':
|
55 |
+
h = hparams['cwt_hidden_size']
|
56 |
+
cwt_out_dims = 10
|
57 |
+
if hparams['use_uv']:
|
58 |
+
cwt_out_dims = cwt_out_dims + 1
|
59 |
+
self.cwt_predictor = nn.Sequential(
|
60 |
+
nn.Linear(self.hidden_size, h),
|
61 |
+
PitchPredictor(
|
62 |
+
h,
|
63 |
+
n_chans=predictor_hidden,
|
64 |
+
n_layers=hparams['predictor_layers'],
|
65 |
+
dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
|
66 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
|
67 |
+
self.cwt_stats_layers = nn.Sequential(
|
68 |
+
nn.Linear(self.hidden_size, h), nn.ReLU(),
|
69 |
+
nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
self.pitch_predictor = PitchPredictor(
|
73 |
+
self.hidden_size,
|
74 |
+
n_chans=predictor_hidden,
|
75 |
+
n_layers=hparams['predictor_layers'],
|
76 |
+
dropout_rate=hparams['predictor_dropout'],
|
77 |
+
odim=2 if hparams['pitch_type'] == 'frame' else 1,
|
78 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
|
79 |
+
if hparams['use_energy_embed']:
|
80 |
+
self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
|
81 |
+
self.energy_predictor = EnergyPredictor(
|
82 |
+
self.hidden_size,
|
83 |
+
n_chans=predictor_hidden,
|
84 |
+
n_layers=hparams['predictor_layers'],
|
85 |
+
dropout_rate=hparams['predictor_dropout'], odim=1,
|
86 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
|
87 |
+
|
88 |
+
def build_embedding(self, dictionary, embed_dim):
|
89 |
+
num_embeddings = len(dictionary)
|
90 |
+
emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
|
91 |
+
return emb
|
92 |
+
|
93 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
94 |
+
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
95 |
+
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
96 |
+
ret = {}
|
97 |
+
encoder_out = self.encoder(txt_tokens) # [B, T, C]
|
98 |
+
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
99 |
+
|
100 |
+
# add ref style embed
|
101 |
+
# Not implemented
|
102 |
+
# variance encoder
|
103 |
+
var_embed = 0
|
104 |
+
|
105 |
+
# encoder_out_dur denotes encoder outputs for duration predictor
|
106 |
+
# in speech adaptation, duration predictor use old speaker embedding
|
107 |
+
if hparams['use_spk_embed']:
|
108 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
109 |
+
elif hparams['use_spk_id']:
|
110 |
+
spk_embed_id = spk_embed
|
111 |
+
if spk_embed_dur_id is None:
|
112 |
+
spk_embed_dur_id = spk_embed_id
|
113 |
+
if spk_embed_f0_id is None:
|
114 |
+
spk_embed_f0_id = spk_embed_id
|
115 |
+
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
116 |
+
spk_embed_dur = spk_embed_f0 = spk_embed
|
117 |
+
if hparams['use_split_spk_id']:
|
118 |
+
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
119 |
+
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
120 |
+
else:
|
121 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
122 |
+
|
123 |
+
# add dur
|
124 |
+
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
125 |
+
|
126 |
+
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
127 |
+
|
128 |
+
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
129 |
+
|
130 |
+
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
131 |
+
decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
132 |
+
|
133 |
+
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
134 |
+
|
135 |
+
# add pitch and energy embed
|
136 |
+
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
137 |
+
if hparams['use_pitch_embed']:
|
138 |
+
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
139 |
+
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
140 |
+
if hparams['use_energy_embed']:
|
141 |
+
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
142 |
+
|
143 |
+
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
144 |
+
|
145 |
+
if skip_decoder:
|
146 |
+
return ret
|
147 |
+
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
148 |
+
|
149 |
+
return ret
|
150 |
+
|
151 |
+
def add_dur(self, dur_input, mel2ph, txt_tokens, ret):
|
152 |
+
"""
|
153 |
+
|
154 |
+
:param dur_input: [B, T_txt, H]
|
155 |
+
:param mel2ph: [B, T_mel]
|
156 |
+
:param txt_tokens: [B, T_txt]
|
157 |
+
:param ret:
|
158 |
+
:return:
|
159 |
+
"""
|
160 |
+
src_padding = txt_tokens == 0
|
161 |
+
dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
|
162 |
+
if mel2ph is None:
|
163 |
+
dur, xs = self.dur_predictor.inference(dur_input, src_padding)
|
164 |
+
ret['dur'] = xs
|
165 |
+
ret['dur_choice'] = dur
|
166 |
+
mel2ph = self.length_regulator(dur, src_padding).detach()
|
167 |
+
# from modules.fastspeech.fake_modules import FakeLengthRegulator
|
168 |
+
# fake_lr = FakeLengthRegulator()
|
169 |
+
# fake_mel2ph = fake_lr(dur, (1 - src_padding.long()).sum(-1))[..., 0].detach()
|
170 |
+
# print(mel2ph == fake_mel2ph)
|
171 |
+
else:
|
172 |
+
ret['dur'] = self.dur_predictor(dur_input, src_padding)
|
173 |
+
ret['mel2ph'] = mel2ph
|
174 |
+
return mel2ph
|
175 |
+
|
176 |
+
def add_energy(self, decoder_inp, energy, ret):
|
177 |
+
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
178 |
+
ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
|
179 |
+
if energy is None:
|
180 |
+
energy = energy_pred
|
181 |
+
energy = torch.clamp(energy * 256 // 4, max=255).long()
|
182 |
+
energy_embed = self.energy_embed(energy)
|
183 |
+
return energy_embed
|
184 |
+
|
185 |
+
def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
|
186 |
+
if hparams['pitch_type'] == 'ph':
|
187 |
+
pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
|
188 |
+
pitch_padding = encoder_out.sum().abs() == 0
|
189 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
|
190 |
+
if f0 is None:
|
191 |
+
f0 = pitch_pred[:, :, 0]
|
192 |
+
ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
|
193 |
+
pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
|
194 |
+
pitch = F.pad(pitch, [1, 0])
|
195 |
+
pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel]
|
196 |
+
pitch_embed = self.pitch_embed(pitch)
|
197 |
+
return pitch_embed
|
198 |
+
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
199 |
+
|
200 |
+
pitch_padding = mel2ph == 0
|
201 |
+
|
202 |
+
if hparams['pitch_type'] == 'cwt':
|
203 |
+
pitch_padding = None
|
204 |
+
ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
|
205 |
+
stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
|
206 |
+
mean = ret['f0_mean'] = stats_out[:, 0]
|
207 |
+
std = ret['f0_std'] = stats_out[:, 1]
|
208 |
+
cwt_spec = cwt_out[:, :, :10]
|
209 |
+
if f0 is None:
|
210 |
+
std = std * hparams['cwt_std_scale']
|
211 |
+
f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
|
212 |
+
if hparams['use_uv']:
|
213 |
+
assert cwt_out.shape[-1] == 11
|
214 |
+
uv = cwt_out[:, :, -1] > 0
|
215 |
+
elif hparams['pitch_ar']:
|
216 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
|
217 |
+
if f0 is None:
|
218 |
+
f0 = pitch_pred[:, :, 0]
|
219 |
+
else:
|
220 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
|
221 |
+
if f0 is None:
|
222 |
+
f0 = pitch_pred[:, :, 0]
|
223 |
+
if hparams['use_uv'] and uv is None:
|
224 |
+
uv = pitch_pred[:, :, 1] > 0
|
225 |
+
ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
|
226 |
+
if pitch_padding is not None:
|
227 |
+
f0[pitch_padding] = 0
|
228 |
+
|
229 |
+
pitch = f0_to_coarse(f0_denorm) # start from 0
|
230 |
+
pitch_embed = self.pitch_embed(pitch)
|
231 |
+
return pitch_embed
|
232 |
+
|
233 |
+
def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
|
234 |
+
x = decoder_inp # [B, T, H]
|
235 |
+
x = self.decoder(x)
|
236 |
+
x = self.mel_out(x)
|
237 |
+
return x * tgt_nonpadding
|
238 |
+
|
239 |
+
def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
|
240 |
+
f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
|
241 |
+
f0 = torch.cat(
|
242 |
+
[f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
|
243 |
+
f0_norm = norm_f0(f0, None, hparams)
|
244 |
+
return f0_norm
|
245 |
+
|
246 |
+
def out2mel(self, out):
|
247 |
+
return out
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def mel_norm(x):
|
251 |
+
return (x + 5.5) / (6.3 / 2) - 1
|
252 |
+
|
253 |
+
@staticmethod
|
254 |
+
def mel_denorm(x):
|
255 |
+
return (x + 1) * (6.3 / 2) - 5.5
|
modules/fastspeech/pe.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.commons.common_layers import *
|
2 |
+
from utils.hparams import hparams
|
3 |
+
from modules.fastspeech.tts_modules import PitchPredictor
|
4 |
+
from utils.pitch_utils import denorm_f0
|
5 |
+
|
6 |
+
|
7 |
+
class Prenet(nn.Module):
|
8 |
+
def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
|
9 |
+
super(Prenet, self).__init__()
|
10 |
+
padding = kernel // 2
|
11 |
+
self.layers = []
|
12 |
+
self.strides = strides if strides is not None else [1] * n_layers
|
13 |
+
for l in range(n_layers):
|
14 |
+
self.layers.append(nn.Sequential(
|
15 |
+
nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
|
16 |
+
nn.ReLU(),
|
17 |
+
nn.BatchNorm1d(out_dim)
|
18 |
+
))
|
19 |
+
in_dim = out_dim
|
20 |
+
self.layers = nn.ModuleList(self.layers)
|
21 |
+
self.out_proj = nn.Linear(out_dim, out_dim)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
"""
|
25 |
+
|
26 |
+
:param x: [B, T, 80]
|
27 |
+
:return: [L, B, T, H], [B, T, H]
|
28 |
+
"""
|
29 |
+
padding_mask = x.abs().sum(-1).eq(0).data # [B, T]
|
30 |
+
nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T]
|
31 |
+
x = x.transpose(1, 2)
|
32 |
+
hiddens = []
|
33 |
+
for i, l in enumerate(self.layers):
|
34 |
+
nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
|
35 |
+
x = l(x) * nonpadding_mask_TB
|
36 |
+
hiddens.append(x)
|
37 |
+
hiddens = torch.stack(hiddens, 0) # [L, B, H, T]
|
38 |
+
hiddens = hiddens.transpose(2, 3) # [L, B, T, H]
|
39 |
+
x = self.out_proj(x.transpose(1, 2)) # [B, T, H]
|
40 |
+
x = x * nonpadding_mask_TB.transpose(1, 2)
|
41 |
+
return hiddens, x
|
42 |
+
|
43 |
+
|
44 |
+
class ConvBlock(nn.Module):
|
45 |
+
def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
|
46 |
+
super().__init__()
|
47 |
+
self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
|
48 |
+
self.norm = norm
|
49 |
+
if self.norm == 'bn':
|
50 |
+
self.norm = nn.BatchNorm1d(n_chans)
|
51 |
+
elif self.norm == 'in':
|
52 |
+
self.norm = nn.InstanceNorm1d(n_chans, affine=True)
|
53 |
+
elif self.norm == 'gn':
|
54 |
+
self.norm = nn.GroupNorm(n_chans // 16, n_chans)
|
55 |
+
elif self.norm == 'ln':
|
56 |
+
self.norm = LayerNorm(n_chans // 16, n_chans)
|
57 |
+
elif self.norm == 'wn':
|
58 |
+
self.conv = torch.nn.utils.weight_norm(self.conv.conv)
|
59 |
+
self.dropout = nn.Dropout(dropout)
|
60 |
+
self.relu = nn.ReLU()
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
"""
|
64 |
+
|
65 |
+
:param x: [B, C, T]
|
66 |
+
:return: [B, C, T]
|
67 |
+
"""
|
68 |
+
x = self.conv(x)
|
69 |
+
if not isinstance(self.norm, str):
|
70 |
+
if self.norm == 'none':
|
71 |
+
pass
|
72 |
+
elif self.norm == 'ln':
|
73 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
74 |
+
else:
|
75 |
+
x = self.norm(x)
|
76 |
+
x = self.relu(x)
|
77 |
+
x = self.dropout(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class ConvStacks(nn.Module):
|
82 |
+
def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
|
83 |
+
dropout=0, strides=None, res=True):
|
84 |
+
super().__init__()
|
85 |
+
self.conv = torch.nn.ModuleList()
|
86 |
+
self.kernel_size = kernel_size
|
87 |
+
self.res = res
|
88 |
+
self.in_proj = Linear(idim, n_chans)
|
89 |
+
if strides is None:
|
90 |
+
strides = [1] * n_layers
|
91 |
+
else:
|
92 |
+
assert len(strides) == n_layers
|
93 |
+
for idx in range(n_layers):
|
94 |
+
self.conv.append(ConvBlock(
|
95 |
+
n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
|
96 |
+
self.out_proj = Linear(n_chans, odim)
|
97 |
+
|
98 |
+
def forward(self, x, return_hiddens=False):
|
99 |
+
"""
|
100 |
+
|
101 |
+
:param x: [B, T, H]
|
102 |
+
:return: [B, T, H]
|
103 |
+
"""
|
104 |
+
x = self.in_proj(x)
|
105 |
+
x = x.transpose(1, -1) # (B, idim, Tmax)
|
106 |
+
hiddens = []
|
107 |
+
for f in self.conv:
|
108 |
+
x_ = f(x)
|
109 |
+
x = x + x_ if self.res else x_ # (B, C, Tmax)
|
110 |
+
hiddens.append(x)
|
111 |
+
x = x.transpose(1, -1)
|
112 |
+
x = self.out_proj(x) # (B, Tmax, H)
|
113 |
+
if return_hiddens:
|
114 |
+
hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
|
115 |
+
return x, hiddens
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class PitchExtractor(nn.Module):
|
120 |
+
def __init__(self, n_mel_bins=80, conv_layers=2):
|
121 |
+
super().__init__()
|
122 |
+
self.hidden_size = hparams['hidden_size']
|
123 |
+
self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
|
124 |
+
self.conv_layers = conv_layers
|
125 |
+
|
126 |
+
self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
|
127 |
+
if self.conv_layers > 0:
|
128 |
+
self.mel_encoder = ConvStacks(
|
129 |
+
idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
|
130 |
+
self.pitch_predictor = PitchPredictor(
|
131 |
+
self.hidden_size, n_chans=self.predictor_hidden,
|
132 |
+
n_layers=5, dropout_rate=0.1, odim=2,
|
133 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
|
134 |
+
|
135 |
+
def forward(self, mel_input=None):
|
136 |
+
ret = {}
|
137 |
+
mel_hidden = self.mel_prenet(mel_input)[1]
|
138 |
+
if self.conv_layers > 0:
|
139 |
+
mel_hidden = self.mel_encoder(mel_hidden)
|
140 |
+
|
141 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)
|
142 |
+
|
143 |
+
pitch_padding = mel_input.abs().sum(-1) == 0
|
144 |
+
use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
|
145 |
+
|
146 |
+
ret['f0_denorm_pred'] = denorm_f0(
|
147 |
+
pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
|
148 |
+
hparams, pitch_padding=pitch_padding)
|
149 |
+
return ret
|
modules/fastspeech/tts_modules.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from modules.commons.espnet_positional_embedding import RelPositionalEncoding
|
9 |
+
from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
|
10 |
+
from utils.hparams import hparams
|
11 |
+
|
12 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 2000
|
13 |
+
DEFAULT_MAX_TARGET_POSITIONS = 2000
|
14 |
+
|
15 |
+
|
16 |
+
class TransformerEncoderLayer(nn.Module):
|
17 |
+
def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
|
18 |
+
super().__init__()
|
19 |
+
self.hidden_size = hidden_size
|
20 |
+
self.dropout = dropout
|
21 |
+
self.num_heads = num_heads
|
22 |
+
self.op = EncSALayer(
|
23 |
+
hidden_size, num_heads, dropout=dropout,
|
24 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
25 |
+
kernel_size=kernel_size
|
26 |
+
if kernel_size is not None else hparams['enc_ffn_kernel_size'],
|
27 |
+
padding=hparams['ffn_padding'],
|
28 |
+
norm=norm, act=hparams['ffn_act'])
|
29 |
+
|
30 |
+
def forward(self, x, **kwargs):
|
31 |
+
return self.op(x, **kwargs)
|
32 |
+
|
33 |
+
|
34 |
+
######################
|
35 |
+
# fastspeech modules
|
36 |
+
######################
|
37 |
+
class LayerNorm(torch.nn.LayerNorm):
|
38 |
+
"""Layer normalization module.
|
39 |
+
:param int nout: output dim size
|
40 |
+
:param int dim: dimension to be normalized
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, nout, dim=-1):
|
44 |
+
"""Construct an LayerNorm object."""
|
45 |
+
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
46 |
+
self.dim = dim
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""Apply layer normalization.
|
50 |
+
:param torch.Tensor x: input tensor
|
51 |
+
:return: layer normalized tensor
|
52 |
+
:rtype torch.Tensor
|
53 |
+
"""
|
54 |
+
if self.dim == -1:
|
55 |
+
return super(LayerNorm, self).forward(x)
|
56 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
57 |
+
|
58 |
+
|
59 |
+
class DurationPredictor(torch.nn.Module):
|
60 |
+
"""Duration predictor module.
|
61 |
+
This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
62 |
+
The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
|
63 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
64 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
65 |
+
Note:
|
66 |
+
The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
|
67 |
+
the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
|
71 |
+
"""Initilize duration predictor module.
|
72 |
+
Args:
|
73 |
+
idim (int): Input dimension.
|
74 |
+
n_layers (int, optional): Number of convolutional layers.
|
75 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
76 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
77 |
+
dropout_rate (float, optional): Dropout rate.
|
78 |
+
offset (float, optional): Offset value to avoid nan in log domain.
|
79 |
+
"""
|
80 |
+
super(DurationPredictor, self).__init__()
|
81 |
+
self.offset = offset
|
82 |
+
self.conv = torch.nn.ModuleList()
|
83 |
+
self.kernel_size = kernel_size
|
84 |
+
self.padding = padding
|
85 |
+
for idx in range(n_layers):
|
86 |
+
in_chans = idim if idx == 0 else n_chans
|
87 |
+
self.conv += [torch.nn.Sequential(
|
88 |
+
torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
|
89 |
+
if padding == 'SAME'
|
90 |
+
else (kernel_size - 1, 0), 0),
|
91 |
+
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
|
92 |
+
torch.nn.ReLU(),
|
93 |
+
LayerNorm(n_chans, dim=1),
|
94 |
+
torch.nn.Dropout(dropout_rate)
|
95 |
+
)]
|
96 |
+
if hparams['dur_loss'] in ['mse', 'huber']:
|
97 |
+
odims = 1
|
98 |
+
elif hparams['dur_loss'] == 'mog':
|
99 |
+
odims = 15
|
100 |
+
elif hparams['dur_loss'] == 'crf':
|
101 |
+
odims = 32
|
102 |
+
from torchcrf import CRF
|
103 |
+
self.crf = CRF(odims, batch_first=True)
|
104 |
+
self.linear = torch.nn.Linear(n_chans, odims)
|
105 |
+
|
106 |
+
def _forward(self, xs, x_masks=None, is_inference=False):
|
107 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
108 |
+
for f in self.conv:
|
109 |
+
xs = f(xs) # (B, C, Tmax)
|
110 |
+
if x_masks is not None:
|
111 |
+
xs = xs * (1 - x_masks.float())[:, None, :]
|
112 |
+
|
113 |
+
xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
|
114 |
+
xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C)
|
115 |
+
if is_inference:
|
116 |
+
return self.out2dur(xs), xs
|
117 |
+
else:
|
118 |
+
if hparams['dur_loss'] in ['mse']:
|
119 |
+
xs = xs.squeeze(-1) # (B, Tmax)
|
120 |
+
return xs
|
121 |
+
|
122 |
+
def out2dur(self, xs):
|
123 |
+
if hparams['dur_loss'] in ['mse']:
|
124 |
+
# NOTE: calculate in log domain
|
125 |
+
xs = xs.squeeze(-1) # (B, Tmax)
|
126 |
+
dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
|
127 |
+
elif hparams['dur_loss'] == 'mog':
|
128 |
+
return NotImplementedError
|
129 |
+
elif hparams['dur_loss'] == 'crf':
|
130 |
+
dur = torch.LongTensor(self.crf.decode(xs)).cuda()
|
131 |
+
return dur
|
132 |
+
|
133 |
+
def forward(self, xs, x_masks=None):
|
134 |
+
"""Calculate forward propagation.
|
135 |
+
Args:
|
136 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
137 |
+
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
|
138 |
+
Returns:
|
139 |
+
Tensor: Batch of predicted durations in log domain (B, Tmax).
|
140 |
+
"""
|
141 |
+
return self._forward(xs, x_masks, False)
|
142 |
+
|
143 |
+
def inference(self, xs, x_masks=None):
|
144 |
+
"""Inference duration.
|
145 |
+
Args:
|
146 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
147 |
+
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
|
148 |
+
Returns:
|
149 |
+
LongTensor: Batch of predicted durations in linear domain (B, Tmax).
|
150 |
+
"""
|
151 |
+
return self._forward(xs, x_masks, True)
|
152 |
+
|
153 |
+
|
154 |
+
class LengthRegulator(torch.nn.Module):
|
155 |
+
def __init__(self, pad_value=0.0):
|
156 |
+
super(LengthRegulator, self).__init__()
|
157 |
+
self.pad_value = pad_value
|
158 |
+
|
159 |
+
def forward(self, dur, dur_padding=None, alpha=1.0):
|
160 |
+
"""
|
161 |
+
Example (no batch dim version):
|
162 |
+
1. dur = [2,2,3]
|
163 |
+
2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
|
164 |
+
3. token_mask = [[1,1,0,0,0,0,0],
|
165 |
+
[0,0,1,1,0,0,0],
|
166 |
+
[0,0,0,0,1,1,1]]
|
167 |
+
4. token_idx * token_mask = [[1,1,0,0,0,0,0],
|
168 |
+
[0,0,2,2,0,0,0],
|
169 |
+
[0,0,0,0,3,3,3]]
|
170 |
+
5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
|
171 |
+
|
172 |
+
:param dur: Batch of durations of each frame (B, T_txt)
|
173 |
+
:param dur_padding: Batch of padding of each frame (B, T_txt)
|
174 |
+
:param alpha: duration rescale coefficient
|
175 |
+
:return:
|
176 |
+
mel2ph (B, T_speech)
|
177 |
+
"""
|
178 |
+
assert alpha > 0
|
179 |
+
dur = torch.round(dur.float() * alpha).long()
|
180 |
+
if dur_padding is not None:
|
181 |
+
dur = dur * (1 - dur_padding.long())
|
182 |
+
token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
|
183 |
+
dur_cumsum = torch.cumsum(dur, 1)
|
184 |
+
dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
|
185 |
+
|
186 |
+
pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
|
187 |
+
token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
|
188 |
+
mel2ph = (token_idx * token_mask.long()).sum(1)
|
189 |
+
return mel2ph
|
190 |
+
|
191 |
+
|
192 |
+
class PitchPredictor(torch.nn.Module):
|
193 |
+
def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
|
194 |
+
dropout_rate=0.1, padding='SAME'):
|
195 |
+
"""Initilize pitch predictor module.
|
196 |
+
Args:
|
197 |
+
idim (int): Input dimension.
|
198 |
+
n_layers (int, optional): Number of convolutional layers.
|
199 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
200 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
201 |
+
dropout_rate (float, optional): Dropout rate.
|
202 |
+
"""
|
203 |
+
super(PitchPredictor, self).__init__()
|
204 |
+
self.conv = torch.nn.ModuleList()
|
205 |
+
self.kernel_size = kernel_size
|
206 |
+
self.padding = padding
|
207 |
+
for idx in range(n_layers):
|
208 |
+
in_chans = idim if idx == 0 else n_chans
|
209 |
+
self.conv += [torch.nn.Sequential(
|
210 |
+
torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
|
211 |
+
if padding == 'SAME'
|
212 |
+
else (kernel_size - 1, 0), 0),
|
213 |
+
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
|
214 |
+
torch.nn.ReLU(),
|
215 |
+
LayerNorm(n_chans, dim=1),
|
216 |
+
torch.nn.Dropout(dropout_rate)
|
217 |
+
)]
|
218 |
+
self.linear = torch.nn.Linear(n_chans, odim)
|
219 |
+
self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
|
220 |
+
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
|
221 |
+
|
222 |
+
def forward(self, xs):
|
223 |
+
"""
|
224 |
+
|
225 |
+
:param xs: [B, T, H]
|
226 |
+
:return: [B, T, H]
|
227 |
+
"""
|
228 |
+
positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
|
229 |
+
xs = xs + positions
|
230 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
231 |
+
for f in self.conv:
|
232 |
+
xs = f(xs) # (B, C, Tmax)
|
233 |
+
# NOTE: calculate in log domain
|
234 |
+
xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H)
|
235 |
+
return xs
|
236 |
+
|
237 |
+
|
238 |
+
class EnergyPredictor(PitchPredictor):
|
239 |
+
pass
|
240 |
+
|
241 |
+
|
242 |
+
def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
|
243 |
+
B, _ = mel2ph.shape
|
244 |
+
dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
|
245 |
+
dur = dur[:, 1:]
|
246 |
+
if max_dur is not None:
|
247 |
+
dur = dur.clamp(max=max_dur)
|
248 |
+
return dur
|
249 |
+
|
250 |
+
|
251 |
+
class FFTBlocks(nn.Module):
|
252 |
+
def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
|
253 |
+
use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
|
254 |
+
super().__init__()
|
255 |
+
self.num_layers = num_layers
|
256 |
+
embed_dim = self.hidden_size = hidden_size
|
257 |
+
self.dropout = dropout if dropout is not None else hparams['dropout']
|
258 |
+
self.use_pos_embed = use_pos_embed
|
259 |
+
self.use_last_norm = use_last_norm
|
260 |
+
if use_pos_embed:
|
261 |
+
self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
|
262 |
+
self.padding_idx = 0
|
263 |
+
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
|
264 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
265 |
+
embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
266 |
+
)
|
267 |
+
|
268 |
+
self.layers = nn.ModuleList([])
|
269 |
+
self.layers.extend([
|
270 |
+
TransformerEncoderLayer(self.hidden_size, self.dropout,
|
271 |
+
kernel_size=ffn_kernel_size, num_heads=num_heads)
|
272 |
+
for _ in range(self.num_layers)
|
273 |
+
])
|
274 |
+
if self.use_last_norm:
|
275 |
+
if norm == 'ln':
|
276 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
277 |
+
elif norm == 'bn':
|
278 |
+
self.layer_norm = BatchNorm1dTBC(embed_dim)
|
279 |
+
else:
|
280 |
+
self.layer_norm = None
|
281 |
+
|
282 |
+
def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
|
283 |
+
"""
|
284 |
+
:param x: [B, T, C]
|
285 |
+
:param padding_mask: [B, T]
|
286 |
+
:return: [B, T, C] or [L, B, T, C]
|
287 |
+
"""
|
288 |
+
padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
|
289 |
+
nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
|
290 |
+
if self.use_pos_embed:
|
291 |
+
positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
|
292 |
+
x = x + positions
|
293 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
294 |
+
# B x T x C -> T x B x C
|
295 |
+
x = x.transpose(0, 1) * nonpadding_mask_TB
|
296 |
+
hiddens = []
|
297 |
+
for layer in self.layers:
|
298 |
+
x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
|
299 |
+
hiddens.append(x)
|
300 |
+
if self.use_last_norm:
|
301 |
+
x = self.layer_norm(x) * nonpadding_mask_TB
|
302 |
+
if return_hiddens:
|
303 |
+
x = torch.stack(hiddens, 0) # [L, T, B, C]
|
304 |
+
x = x.transpose(1, 2) # [L, B, T, C]
|
305 |
+
else:
|
306 |
+
x = x.transpose(0, 1) # [B, T, C]
|
307 |
+
return x
|
308 |
+
|
309 |
+
|
310 |
+
class FastspeechEncoder(FFTBlocks):
|
311 |
+
def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
|
312 |
+
hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
|
313 |
+
kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
|
314 |
+
num_layers = hparams['dec_layers'] if num_layers is None else num_layers
|
315 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
|
316 |
+
use_pos_embed=False) # use_pos_embed_alpha for compatibility
|
317 |
+
self.embed_tokens = embed_tokens
|
318 |
+
self.embed_scale = math.sqrt(hidden_size)
|
319 |
+
self.padding_idx = 0
|
320 |
+
if hparams.get('rel_pos') is not None and hparams['rel_pos']:
|
321 |
+
self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
|
322 |
+
else:
|
323 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
324 |
+
hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
325 |
+
)
|
326 |
+
|
327 |
+
def forward(self, txt_tokens):
|
328 |
+
"""
|
329 |
+
|
330 |
+
:param txt_tokens: [B, T]
|
331 |
+
:return: {
|
332 |
+
'encoder_out': [T x B x C]
|
333 |
+
}
|
334 |
+
"""
|
335 |
+
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
|
336 |
+
x = self.forward_embedding(txt_tokens) # [B, T, H]
|
337 |
+
x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
|
338 |
+
return x
|
339 |
+
|
340 |
+
def forward_embedding(self, txt_tokens):
|
341 |
+
# embed tokens and positions
|
342 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
343 |
+
if hparams['use_pos_embed']:
|
344 |
+
positions = self.embed_positions(txt_tokens)
|
345 |
+
x = x + positions
|
346 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
class FastspeechDecoder(FFTBlocks):
|
351 |
+
def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
|
352 |
+
num_heads = hparams['num_heads'] if num_heads is None else num_heads
|
353 |
+
hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
|
354 |
+
kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
|
355 |
+
num_layers = hparams['dec_layers'] if num_layers is None else num_layers
|
356 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
|
357 |
+
|