upload
This view is limited to 50 files because the commit contains too many changes.
- .DS_Store +0 -0
- LICENSE +20 -0
- README.md +116 -12
- Untitled.ipynb +24 -0
- app.py +7 -0
- assets/.DS_Store +0 -0
- assets/hifigan/config.json +38 -0
- assets/infore/.DS_Store +0 -0
- assets/infore/lexicon.txt +0 -0
- assets/transcript.txt +26 -0
- notebooks/align_text_audio_infore_mfa.ipynb +193 -0
- notebooks/denoise_infore_dataset.ipynb +138 -0
- scripts/download_aligned_infore_dataset.py +45 -0
- scripts/quick_start.sh +12 -0
- setup.cfg +14 -0
- setup.py +43 -0
- tests/test_nat_acoustic.py +18 -0
- tests/test_nat_duration.py +15 -0
- vietTTS.egg-info/PKG-INFO +11 -0
- vietTTS.egg-info/SOURCES.txt +10 -0
- vietTTS.egg-info/dependency_links.txt +1 -0
- vietTTS.egg-info/requires.txt +12 -0
- vietTTS.egg-info/top_level.txt +1 -0
- vietTTS/__init__.py +0 -0
- vietTTS/__pycache__/__init__.cpython-39.pyc +0 -0
- vietTTS/__pycache__/synthesizer.cpython-39.pyc +0 -0
- vietTTS/hifigan/__pycache__/config.cpython-39.pyc +0 -0
- vietTTS/hifigan/__pycache__/mel2wave.cpython-39.pyc +0 -0
- vietTTS/hifigan/__pycache__/model.cpython-39.pyc +0 -0
- vietTTS/hifigan/config.py +6 -0
- vietTTS/hifigan/convert_torch_model_to_haiku.py +83 -0
- vietTTS/hifigan/create_mel.py +241 -0
- vietTTS/hifigan/data_loader.py +0 -0
- vietTTS/hifigan/mel2wave.py +41 -0
- vietTTS/hifigan/model.py +125 -0
- vietTTS/hifigan/torch_model.py +414 -0
- vietTTS/hifigan/trainer.py +0 -0
- vietTTS/nat/__init__.py +0 -0
- vietTTS/nat/__pycache__/__init__.cpython-39.pyc +0 -0
- vietTTS/nat/__pycache__/config.cpython-39.pyc +0 -0
- vietTTS/nat/__pycache__/data_loader.cpython-39.pyc +0 -0
- vietTTS/nat/__pycache__/model.cpython-39.pyc +0 -0
- vietTTS/nat/__pycache__/text2mel.cpython-39.pyc +0 -0
- vietTTS/nat/acoustic_tpu_trainer.py +189 -0
- vietTTS/nat/acoustic_trainer.py +181 -0
- vietTTS/nat/config.py +74 -0
- vietTTS/nat/data_loader.py +156 -0
- vietTTS/nat/dsp.py +128 -0
- vietTTS/nat/duration_trainer.py +142 -0
- vietTTS/nat/gta.py +82 -0
.DS_Store
ADDED
Binary file (6.15 kB).
LICENSE
ADDED
@@ -0,0 +1,20 @@
Copyright (c) 2021 ntt123

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,116 @@
[the 12 removed lines are not rendered in this view]

A Vietnamese TTS
================

Duration model + Acoustic model + HiFiGAN vocoder for a Vietnamese text-to-speech application.

Online demo at https://huggingface.co/spaces/ntt123/vietTTS.

A synthesized audio clip: [clip.wav](assets/infore/clip.wav). A colab notebook: [notebook](https://colab.research.google.com/drive/1oczrWOQOr1Y_qLdgis1twSlNZlfPVXoY?usp=sharing).

🔔 Check out the experimental `multi-speaker` branch (`git checkout multi-speaker`) for multi-speaker support. 🔔

Install
-------

```sh
git clone https://github.com/NTT123/vietTTS.git
cd vietTTS
pip3 install -e .
```

Quick start using pretrained models
-----------------------------------

```sh
bash ./scripts/quick_start.sh
```

Download InfoRe dataset
-----------------------

```sh
python ./scripts/download_aligned_infore_dataset.py
```

**Note**: this is a denoised and aligned version of the original dataset, donated by the InfoRe Technology company (see [here](https://www.facebook.com/groups/j2team.community/permalink/1010834009248719/)). You can download the original dataset (**InfoRe Technology 1**) [here](https://github.com/TensorSpeech/TensorFlowASR/blob/main/README.md#vietnamese).

See `notebooks/denoise_infore_dataset.ipynb` for instructions on how to denoise the dataset. We use the Montreal Forced Aligner (MFA) to align transcripts and speech (TextGrid files).
See `notebooks/align_text_audio_infore_mfa.ipynb` for instructions on how to create the TextGrid files.

Train duration model
--------------------

```sh
python -m vietTTS.nat.duration_trainer
```

Train acoustic model
--------------------

```sh
python -m vietTTS.nat.acoustic_trainer
```

Train HiFiGAN vocoder
---------------------

We use the original implementation from the HiFiGAN authors at https://github.com/jik876/hifi-gan. Use the config file at `assets/hifigan/config.json` to train your model.

```sh
git clone https://github.com/jik876/hifi-gan.git

# create dataset in hifi-gan format
ln -sf `pwd`/train_data hifi-gan/data
cd hifi-gan/data
ls -1 *.TextGrid | sed -e 's/\.TextGrid$//' > files.txt
cd ..
head -n 100 data/files.txt > val_files.txt
tail -n +101 data/files.txt > train_files.txt
rm data/files.txt

# training
python train.py \
  --config ../assets/hifigan/config.json \
  --input_wavs_dir=data \
  --input_training_file=train_files.txt \
  --input_validation_file=val_files.txt
```

Finetune on Ground-Truth Aligned (GTA) melspectrograms:

```sh
cd /path/to/vietTTS # go to vietTTS directory
python -m vietTTS.nat.zero_silence_segments -o train_data # zero all [sil, sp, spn] segments
python -m vietTTS.nat.gta -o /path/to/hifi-gan/ft_dataset # create GTA melspectrograms at the hifi-gan/ft_dataset directory

# turn on finetuning
cd /path/to/hifi-gan
python train.py \
  --fine_tuning True \
  --config ../assets/hifigan/config.json \
  --input_wavs_dir=data \
  --input_training_file=train_files.txt \
  --input_validation_file=val_files.txt
```

Then, use the following command to convert the PyTorch model to Haiku format:

```sh
cd ..
python -m vietTTS.hifigan.convert_torch_model_to_haiku \
  --config-file=assets/hifigan/config.json \
  --checkpoint-file=hifi-gan/cp_hifigan/g_[latest_checkpoint]
```

Synthesize speech
-----------------

```sh
python -m vietTTS.synthesizer \
  --lexicon-file=train_data/lexicon.txt \
  --text="hôm qua em tới trường" \
  --output=clip.wav
```
Untitled.ipynb
ADDED
@@ -0,0 +1,24 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b036f793-1443-4341-932c-d112386937ea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
app.py
ADDED
@@ -0,0 +1,7 @@
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
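Note: `app.py` above is only the placeholder Gradio app for the Hugging Face Space; it does not call the TTS models. Below is a minimal sketch of how it could be wired to the synthesizer. The `text2mel` function name is an assumption based on the `text2mel.cpython-39.pyc` file elsewhere in this commit; its signature is not confirmed by the diff.

```python
# Hypothetical TTS-wired app.py (a sketch, NOT part of this commit).
import gradio as gr

from vietTTS.hifigan.mel2wave import mel2wave   # present in this commit
from vietTTS.nat.text2mel import text2mel       # assumed API

def tts(text):
    mel = text2mel(text)   # text -> mel spectrogram (assumed signature)
    wav = mel2wave(mel)    # mel -> waveform, see vietTTS/hifigan/mel2wave.py
    return (16000, wav)    # Gradio audio output: (sample_rate, samples)

iface = gr.Interface(fn=tts, inputs="text", outputs="audio")
iface.launch()
```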
assets/.DS_Store
ADDED
Binary file (6.15 kB).
assets/hifigan/config.json
ADDED
@@ -0,0 +1,38 @@
{
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [8,8,2,2],
    "upsample_kernel_sizes": [16,16,4,4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "resblock_initial_channel": 256,

    "segment_size": 8192,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,

    "sampling_rate": 16000,

    "fmin": 0,
    "fmax": 8000,
    "fmax_for_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
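Note: this config couples the vocoder to the mel front end. The product of `upsample_rates` must equal `hop_size`, so each mel frame expands into exactly one hop of waveform. A quick arithmetic check:

```python
# Consistency check for assets/hifigan/config.json.
upsample_rates = [8, 8, 2, 2]
hop_size = 256
sampling_rate = 16000

total = 1
for r in upsample_rates:
    total *= r
assert total == hop_size         # 8 * 8 * 2 * 2 == 256

print(sampling_rate / hop_size)  # 62.5 mel frames per second of audio
```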
assets/infore/.DS_Store
ADDED
Binary file (6.15 kB).
assets/infore/lexicon.txt
ADDED
The diff for this file is too large to render.
assets/transcript.txt
ADDED
@@ -0,0 +1,26 @@
Trăm năm trong cõi người ta,
Chữ tài chữ mệnh khéo là ghét nhau.
Trải qua một cuộc bể dâu,
Những điều trông thấy mà đau đớn lòng.
Lạ gì bỉ sắc tư phong,
Trời xanh quen thói má hồng đánh ghen.
Cảo thơm lần giở trước đèn,
Phong tình cổ lục còn truyền sử xanh.
Rằng: Năm Gia tĩnh triều Minh,
Bốn phương phẳng lặng hai kinh chữ vàng.
Có nhà viên ngoại họ Vương,
Gia tư nghỉ cũng thường thường bậc trung.
Một trai con thứ rốt lòng,
Vương Quan là chữ nối dòng nho gia.
Đầu lòng hai ả tố nga,
Thúy Kiều là chị em là Thúy Vân.
Mai cốt cách tuyết tinh thần,
Mỗi người một vẻ mười phân vẹn mười.
Vân xem trang trọng khác vời,
Khuôn trăng đầy đặn nét ngài nở nang.
Hoa cười ngọc thốt đoan trang,
Mây thua nước tóc tuyết nhường màu da.
Kiều càng sắc sảo mặn mà,
So bề tài sắc lại là phần hơn.
Làn thu thủy nét xuân sơn,
Hoa ghen thua thắm liễu hờn kém xanh.
notebooks/align_text_audio_infore_mfa.ipynb
ADDED
@@ -0,0 +1,193 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Align text and audio using Montreal Forced Aligner (MFA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IPkicKwU8IWj"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!apt update -y\n",
    "!pip install -U pip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G6Z-aDd08hfk"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "%%bash\n",
    "data_root=\"./infore_16k_denoised\"\n",
    "mkdir -p $data_root\n",
    "cd $data_root\n",
    "wget https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_16k_denoised.zip -O infore.zip\n",
    "unzip infore.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VWwgAePDXy4m"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "txt_files = sorted(Path(\"./infore_16k_denoised\").glob(\"*.txt\"))\n",
    "f = open(\"/content/words.txt\", \"w\", encoding=\"utf-8\")\n",
    "for txt_file in txt_files:\n",
    "  wav_file = txt_file.with_suffix(\".wav\")\n",
    "  if not wav_file.exists():\n",
    "    continue\n",
    "  line = open(txt_file, \"r\", encoding=\"utf-8\").read()\n",
    "  for word in line.strip().lower().split():\n",
    "    f.write(word)\n",
    "    f.write(\"\\n\")\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FktjNXbDkBLh"
   },
   "outputs": [],
   "source": [
    "black_list = (\n",
    "  []\n",
    "  + [\"q\", \"adn\", \"h\", \"stress\", \"b\", \"k\", \"mark\", \"gas\", \"cs\", \"test\", \"l\", \"hiv\"]\n",
    "  + [\"v\", \"d\", \"c\", \"p\", \"martin\", \"visa\", \"euro\", \"laser\", \"x\", \"real\", \"shop\"]\n",
    "  + [\"studio\", \"kelvin\", \"đt\", \"pop\", \"rock\", \"gara\", \"karaoke\", \"đicr\", \"đigiúp\"]\n",
    "  + [\"khmer\", \"ii\", \"s\", \"tr\", \"xhcn\", \"casino\", \"guitar\", \"sex\", \"oxi\", \"radio\"]\n",
    "  + [\"qúy\", \"asean\", \"hlv\", \"ts\", \"video\", \"virus\", \"usd\", \"robot\", \"ph\", \"album\"]\n",
    "  + [\"s\", \"kg\", \"km\", \"g\", \"tr\", \"đ\", \"ak\", \"d\", \"m\", \"n\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b3nMwfzK_g0B"
   },
   "outputs": [],
   "source": [
    "ws = open(\"/content/words.txt\").readlines()\n",
    "f = open(\"/content/lexicon.txt\", \"w\")\n",
    "for w in sorted(set(ws)):\n",
    "  w = w.strip()\n",
    "\n",
    "  # this is a hack to match phoneme set in the vietTTS repo\n",
    "  p = list(w)\n",
    "  p = \" \".join(p)\n",
    "  if w in black_list:\n",
    "    continue\n",
    "  else:\n",
    "    f.write(f\"{w}\\t{p}\\n\")\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WuWZKTNRt1eM"
   },
   "outputs": [],
   "source": [
    "%%writefile install_mfa.sh\n",
    "#!/bin/bash\n",
    "\n",
    "## a script to install Montreal Forced Aligner (MFA)\n",
    "\n",
    "root_dir=${1:-/tmp/mfa}\n",
    "mkdir -p $root_dir\n",
    "cd $root_dir\n",
    "\n",
    "# download miniconda3\n",
    "wget -q --show-progress https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
    "bash Miniconda3-latest-Linux-x86_64.sh -b -p $root_dir/miniconda3 -f\n",
    "\n",
    "# install MFA\n",
    "$root_dir/miniconda3/bin/conda create -n aligner -c conda-forge montreal-forced-aligner=2.0.0rc7 -y\n",
    "\n",
    "echo -e \"\\n======== DONE ==========\"\n",
    "echo -e \"\\nTo activate MFA, run: source $root_dir/miniconda3/bin/activate aligner\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "osR7KJCNXJYq"
   },
   "outputs": [],
   "source": [
    "# download and install mfa\n",
    "INSTALL_DIR = \"/tmp/mfa\"  # path to install directory\n",
    "!bash ./install_mfa.sh {INSTALL_DIR}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "hxbXwJZlXLPz",
    "outputId": "d3e40ec5-68a7-40ec-d070-137736d7a956"
   },
   "outputs": [],
   "source": [
    "!source {INSTALL_DIR}/miniconda3/bin/activate aligner; \\\n",
    "mfa train --clean -t ./temp -o ./infore_mfa.zip ./infore_16k_denoised lexicon.txt ./infore_textgrid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8Z65_BtXagn1"
   },
   "outputs": [],
   "source": [
    "# copy to train directory\n",
    "!mkdir -p train_data\n",
    "!cp ./infore_16k_denoised/*.wav ./train_data\n",
    "!cp ./infore_textgrid/*.TextGrid ./train_data"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "align-text-audio | InfoRe using MFA v2rc7.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
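Note: the lexicon cell above uses a deliberate shortcut. Each character of a word is treated as its own "phoneme", so MFA trains a grapheme-level aligner and no hand-built Vietnamese phone set is needed. (The original cell also had a missing comma between "hlv" and "ts" in the black list, which Python would silently concatenate into "hlvts"; it is fixed above.) A one-line illustration of the entry format the cell writes:

```python
# One lexicon entry as produced by the notebook: word<TAB>graphemes.
word = "trường"
print(f"{word}\t{' '.join(list(word))}")  # trường	t r ư ờ n g
```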
notebooks/denoise_infore_dataset.ipynb
ADDED
@@ -0,0 +1,138 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qjubHCEzYtG8"
   },
   "source": [
    "### Step 1. Download InfoRE dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zCBJzJi6BE_o"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "%%bash\n",
    "mkdir -p /content/data\n",
    "cd /content/data\n",
    "wget https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_16k.zip\n",
    "# unzip -P BroughtToYouByInfoRe 25hours.zip\n",
    "unzip infore_16k.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6C47hb9nYzmB"
   },
   "source": [
    "### Step 2. Normalize audio clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hp9TK8PbBcQM"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!sudo apt install -y sox\n",
    "!pip install soundfile librosa\n",
    "!pip install onnxruntime==1.11.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FW45D8xM9Mcc",
    "outputId": "8d7ea7a9-ea5a-48ca-88fe-37dd4ed55d9b"
   },
   "outputs": [],
   "source": [
    "!mkdir -p /content/infore_16k\n",
    "from pathlib import Path\n",
    "import os\n",
    "from tqdm.cli import tqdm\n",
    "\n",
    "wavs = sorted(Path(\"/content/data/InfoRe\").glob(\"*.wav\"))\n",
    "for path in tqdm(wavs):\n",
    "  out = Path(\"/content/infore_16k\") / path.name\n",
    "  cmd = f\"sox {path} -c 1 -e signed-integer -b 16 -r 16k --norm=-3 {out}\"\n",
    "  os.system(cmd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kooiBrsQY5sQ"
   },
   "source": [
    "### Step 3. Denoise using DNS-Challenge's baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hsXeNkZ3Xacj"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/microsoft/DNS-Challenge\n",
    "%cd DNS-Challenge/NSNet2-baseline/\n",
    "!git checkout -f 8b87a33b2892f147b5c7ad39ea978453730db269\n",
    "!python run_nsnet2.py -i /content/infore_16k/ -o /content/infore_16k_denoised -m ./nsnet2-20ms-baseline.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T5JtZwKgZI4r"
   },
   "source": [
    "### Step 4. Zip the denoised dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eeFggV0uYop_"
   },
   "outputs": [],
   "source": [
    "%cd /content\n",
    "!cp /content/data/InfoRe/*.txt ./infore_16k_denoised\n",
    "!cd ./infore_16k_denoised; zip -r ../infore_16k_denoised.zip ."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "prepare_infore_dataset.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
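Note: the sox command in Step 2 downmixes to one channel, converts to 16-bit signed PCM, resamples to 16 kHz, and peak-normalizes to -3 dBFS. A small sanity check on an output file (a sketch; the file name is hypothetical, and `soundfile` is the package installed earlier in the notebook):

```python
import soundfile as sf

# "0001.wav" is a hypothetical example name from /content/infore_16k.
data, sr = sf.read("/content/infore_16k/0001.wav")
assert sr == 16000              # from `-r 16k`
assert data.ndim == 1           # mono, from `-c 1`
assert abs(data).max() <= 0.71  # about -3 dBFS, from `--norm=-3`
```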
scripts/download_aligned_infore_dataset.py
ADDED
@@ -0,0 +1,45 @@
"""
A script to download the InfoRE dataset and textgrid files.
"""
import shutil
from pathlib import Path

import pooch
from pooch import Unzip
from tqdm.cli import tqdm


def download_infore_data():
    """download infore wav files"""
    files = pooch.retrieve(
        url="https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_16k_denoised.zip",
        known_hash="2445527b345fb0b1816ce3c8f09bae419d6bbe251f16d6c74d8dd95ef9fb0737",
        processor=Unzip(),
        progressbar=True,
    )
    data_dir = Path(sorted(files)[0]).parent
    return data_dir


def download_textgrid():
    """download textgrid files"""
    files = pooch.retrieve(
        url="https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_tg.zip",
        known_hash="26e4f53025220097ea95dc266657de8d65104b0a17a6ffba778fc016c8dd36d7",
        processor=Unzip(),
        progressbar=True,
    )
    data_dir = Path(sorted(files)[0]).parent
    return data_dir


DATA_ROOT = Path("./train_data")
DATA_ROOT.mkdir(parents=True, exist_ok=True)
wav_dir = download_infore_data()
tg_dir = download_textgrid()

for path in tqdm(tg_dir.glob("*.TextGrid")):
    wav_name = path.with_suffix(".wav").name
    wav_src = wav_dir / wav_name
    shutil.copy(path, DATA_ROOT)
    shutil.copy(wav_src, DATA_ROOT)
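Note: `pooch.retrieve` verifies each download against its SHA-256 `known_hash`, caches the archive, and (with the `Unzip` processor) returns the extracted member paths; the loop then copies each TextGrid together with the wav sharing its stem. A quick sketch to confirm the pairing after the script runs:

```python
from pathlib import Path

train_data = Path("./train_data")
textgrids = sorted(train_data.glob("*.TextGrid"))
missing = [p.name for p in textgrids if not p.with_suffix(".wav").exists()]
print(f"{len(textgrids)} TextGrid files, {len(missing)} without a matching wav")
```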
scripts/quick_start.sh
ADDED
@@ -0,0 +1,12 @@
if [ ! -f assets/infore/hifigan/g_01140000 ]; then
  echo "Downloading models..."
  mkdir -p assets/infore/{nat,hifigan}
  wget https://huggingface.co/ntt123/viettts_infore_16k/resolve/main/duration_latest_ckpt.pickle -O assets/infore/nat/duration_latest_ckpt.pickle
  wget https://huggingface.co/ntt123/viettts_infore_16k/resolve/main/acoustic_latest_ckpt.pickle -O assets/infore/nat/acoustic_latest_ckpt.pickle
  wget https://huggingface.co/ntt123/viettts_infore_16k/resolve/main/g_01140000 -O assets/infore/hifigan/g_01140000
  python3 -m vietTTS.hifigan.convert_torch_model_to_haiku --config-file=assets/hifigan/config.json --checkpoint-file=assets/infore/hifigan/g_01140000
fi

echo "Generate audio clip"
text=`cat assets/transcript.txt`
python3 -m vietTTS.synthesizer --text "$text" --output assets/infore/clip.wav --lexicon-file assets/infore/lexicon.txt --silence-duration 0.2
setup.cfg
ADDED
@@ -0,0 +1,14 @@
[pep8]
max-line-length = 120
indent-size = 2

[pycodestyle]
max-line-length = 120

[yapf]
based_on_style = pep8
column_limit = 120

[tool:pytest]
testpaths=
    tests
setup.py
ADDED
@@ -0,0 +1,43 @@
from setuptools import setup

__version__ = "0.4.1"
url = "https://github.com/ntt123/vietTTS"

install_requires = [
    "dm-haiku",
    "einops",
    "fire",
    "gdown",
    "jax",
    "jaxlib",
    "librosa",
    "optax",
    "tabulate",
    "textgrid @ git+https://github.com/kylebgorman/textgrid.git",
    "tqdm",
    "matplotlib",
]
setup_requires = []
tests_require = []

setup(
    name="vietTTS",
    version=__version__,
    description="A vietnamese text-to-speech library.",
    author="ntt123",
    url=url,
    keywords=[
        "text-to-speech",
        "tts",
        "deep-learning",
        "dm-haiku",
        "jax",
        "vietnamese",
        "speech-synthesis",
    ],
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    packages=["vietTTS"],
    python_requires=">=3.7",
)
tests/test_nat_acoustic.py
ADDED
@@ -0,0 +1,18 @@
import haiku
import haiku as hk
import jax.numpy as jnp
import jax.random
from vietTTS.nat.config import FLAGS
from vietTTS.nat.model import AcousticModel


@hk.testing.transform_and_run
def test_duration():
    net = AcousticModel()
    token = jnp.zeros((2, 10), dtype=jnp.int32)
    lengths = jnp.zeros((2,), dtype=jnp.int32)
    durations = jnp.zeros((2, 10), dtype=jnp.float32)
    mel = jnp.zeros((2, 20, 160), dtype=jnp.float32)
    o1, o2 = net(token, mel, lengths, durations)
    assert o1.shape == (2, 20, 160)
    assert o2.shape == (2, 20, 160)
tests/test_nat_duration.py
ADDED
@@ -0,0 +1,15 @@
import haiku
import haiku as hk
import jax.numpy as jnp
import jax.random
from vietTTS.nat.config import FLAGS
from vietTTS.nat.model import DurationModel


@hk.testing.transform_and_run
def test_duration():
    net = DurationModel()
    p = jnp.zeros((2, 10), dtype=jnp.int32)
    l = jnp.zeros((2,), dtype=jnp.int32)
    o = net(p, l)
    assert o.shape == (2, 10, 1)
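Note: both tests use `hk.testing.transform_and_run`, which wraps the function body in a Haiku transform and runs init and apply with a test RNG, so modules can be constructed and called directly inside an ordinary pytest function. A minimal self-contained sketch of the same pattern:

```python
import haiku as hk
import jax.numpy as jnp


@hk.testing.transform_and_run
def test_linear_shape():
    # The decorator supplies the hk.transform / init / apply plumbing.
    layer = hk.Linear(4)
    y = layer(jnp.zeros((2, 3)))
    assert y.shape == (2, 4)
```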
vietTTS.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,11 @@
Metadata-Version: 1.2
Name: vietTTS
Version: 0.4.1
Summary: A vietnamese text-to-speech library.
Home-page: https://github.com/ntt123/vietTTS
Author: ntt123
License: UNKNOWN
Description: UNKNOWN
Keywords: text-to-speech,tts,deep-learning,dm-haiku,jax,vietnamese,speech-synthesis
Platform: UNKNOWN
Requires-Python: >=3.7
vietTTS.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,10 @@
README.md
setup.cfg
setup.py
vietTTS/__init__.py
vietTTS/synthesizer.py
vietTTS.egg-info/PKG-INFO
vietTTS.egg-info/SOURCES.txt
vietTTS.egg-info/dependency_links.txt
vietTTS.egg-info/requires.txt
vietTTS.egg-info/top_level.txt
vietTTS.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
[a single blank line]
vietTTS.egg-info/requires.txt
ADDED
@@ -0,0 +1,12 @@
dm-haiku
einops
fire
gdown
jax
jaxlib
librosa
optax
tabulate
textgrid@ git+https://github.com/kylebgorman/textgrid.git
tqdm
matplotlib
vietTTS.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
vietTTS
vietTTS/__init__.py
ADDED
File without changes
vietTTS/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (154 Bytes).
vietTTS/__pycache__/synthesizer.cpython-39.pyc
ADDED
Binary file (1.39 kB).
vietTTS/hifigan/__pycache__/config.cpython-39.pyc
ADDED
Binary file (437 Bytes).
vietTTS/hifigan/__pycache__/mel2wave.cpython-39.pyc
ADDED
Binary file (1.51 kB).
vietTTS/hifigan/__pycache__/model.cpython-39.pyc
ADDED
Binary file (3.8 kB).
vietTTS/hifigan/config.py
ADDED
@@ -0,0 +1,6 @@
from pathlib import Path
from typing import NamedTuple


class FLAGS:
    ckpt_dir = Path("./assets/infore/hifigan")
vietTTS/hifigan/convert_torch_model_to_haiku.py
ADDED
@@ -0,0 +1,83 @@
import argparse
import json
import os
import pickle

import numpy as np
import torch

from .config import FLAGS
from .torch_model import Generator


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def convert_to_haiku(a, h, device):
    generator = Generator(h).to(device)
    state_dict_g = load_checkpoint(a.checkpoint_file, device)
    generator.load_state_dict(state_dict_g["generator"])
    generator.eval()
    generator.remove_weight_norm()
    hk_map = {}
    for a, b in generator.state_dict().items():
        print(a, b.shape)
        if a.startswith("conv_pre"):
            a = "generator/~/conv1_d"
        elif a.startswith("conv_post"):
            a = "generator/~/conv1_d_1"
        elif a.startswith("ups."):
            ii = a.split(".")[1]
            a = f"generator/~/ups_{ii}"
        elif a.startswith("resblocks."):
            _, x, y, z, _ = a.split(".")
            ver = h.resblock
            a = f"generator/~/res_block{ver}_{x}/~/{y}_{z}"
        print(a, b.shape)
        if a not in hk_map:
            hk_map[a] = {}
        if len(b.shape) == 1:
            hk_map[a]["b"] = b.numpy()
        else:
            if "ups" in a:
                hk_map[a]["w"] = np.rot90(b.numpy(), k=1, axes=(0, 2))
            elif "conv" in a:
                hk_map[a]["w"] = np.swapaxes(b.numpy(), 0, 2)
            else:
                hk_map[a]["w"] = b.numpy()

    FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "wb") as f:
        pickle.dump(hk_map, f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-file", required=True)
    parser.add_argument("--config-file", required=True)
    a = parser.parse_args()

    config_file = a.config_file
    with open(config_file) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)

    device = torch.device("cpu")
    convert_to_haiku(a, h, device)


if __name__ == "__main__":
    main()
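Note: the core of the converter is the weight-layout change. PyTorch `Conv1d` stores weights as `(out_channels, in_channels, kernel)` while Haiku's `Conv1D` expects `(kernel, in_channels, out_channels)`, which is exactly `np.swapaxes(w, 0, 2)`; for the transposed convolutions, `np.rot90(w, k=1, axes=(0, 2))` additionally reverses the kernel axis, absorbing the kernel-flip convention difference between the two frameworks. A small shape check (channel sizes here are arbitrary):

```python
import numpy as np

out_ch, in_ch, k = 512, 80, 7

w_torch = np.zeros((out_ch, in_ch, k))  # PyTorch Conv1d layout
w_haiku = np.swapaxes(w_torch, 0, 2)    # Haiku Conv1D layout
assert w_haiku.shape == (k, in_ch, out_ch)

w = np.arange(2 * 3 * 4).reshape(2, 3, 4)
w_rot = np.rot90(w, k=1, axes=(0, 2))   # swaps axes 0 and 2 and flips one
assert w_rot.shape == (4, 3, 2)
```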
vietTTS/hifigan/create_mel.py
ADDED
@@ -0,0 +1,241 @@
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = (
            torch.from_numpy(mel).float().to(y.device)
        )
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


def get_dataset_filelist(a):
    with open(a.input_training_file, "r", encoding="utf-8") as fi:
        training_files = [
            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
            for x in fi.read().split("\n")
            if len(x) > 0
        ]

    with open(a.input_validation_file, "r", encoding="utf-8") as fi:
        validation_files = [
            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
            for x in fi.read().split("\n")
            if len(x) > 0
        ]
    return training_files, validation_files


class MelDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        training_files,
        segment_size,
        n_fft,
        num_mels,
        hop_size,
        win_size,
        sampling_rate,
        fmin,
        fmax,
        split=True,
        shuffle=True,
        n_cache_reuse=1,
        device=None,
        fmax_loss=None,
        fine_tuning=False,
        base_mels_path=None,
    ):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    "{} SR doesn't match target {} SR".format(
                        sampling_rate, self.sampling_rate
                    )
                )
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start : audio_start + self.segment_size]
                else:
                    audio = torch.nn.functional.pad(
                        audio, (0, self.segment_size - audio.size(1)), "constant"
                    )

            mel = mel_spectrogram(
                audio,
                self.n_fft,
                self.num_mels,
                self.sampling_rate,
                self.hop_size,
                self.win_size,
                self.fmin,
                self.fmax,
                center=False,
            )
        else:
            mel = np.load(
                os.path.join(
                    self.base_mels_path,
                    os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
                )
            )
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start : mel_start + frames_per_seg]
                    audio = audio[
                        :,
                        mel_start
                        * self.hop_size : (mel_start + frames_per_seg)
                        * self.hop_size,
                    ]
                else:
                    mel = torch.nn.functional.pad(
                        mel, (0, frames_per_seg - mel.size(2)), "constant"
                    )
                    audio = torch.nn.functional.pad(
                        audio, (0, self.segment_size - audio.size(1)), "constant"
                    )

        mel_loss = mel_spectrogram(
            audio,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.fmax_loss,
            center=False,
        )

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)
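Note: `create_mel.py` mirrors HiFi-GAN's original `meldataset.py`. Two version caveats worth flagging rather than silently editing the commit: `librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)` relies on the positional signature of librosa < 0.10 (newer releases require keyword arguments), and `torch.stft` without `return_complex` is an error in PyTorch >= 2.0. A usage sketch of `mel_spectrogram` with the values from `assets/hifigan/config.json`:

```python
import torch
# from vietTTS.hifigan.create_mel import mel_spectrogram

y = torch.zeros(1, 16000)  # one second of silence, shape (batch, samples)
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=16000,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)  # torch.Size([1, 80, 62]): 80 mel bins, ~62 frames per second
```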
vietTTS/hifigan/data_loader.py
ADDED
File without changes
vietTTS/hifigan/mel2wave.py
ADDED
@@ -0,0 +1,41 @@
import json
import os
import pickle

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from .config import FLAGS
from .model import Generator


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def mel2wave(mel):
    config_file = "assets/hifigan/config.json"
    MAX_WAV_VALUE = 32768.0
    with open(config_file) as f:
        data = f.read()
    json_config = json.loads(data)
    h = AttrDict(json_config)

    @hk.transform_with_state
    def forward(x):
        net = Generator(h)
        return net(x)

    rng = next(hk.PRNGSequence(42))

    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "rb") as f:
        params = pickle.load(f)
    aux = {}
    wav, aux = forward.apply(params, aux, rng, mel)
    wav = jnp.squeeze(wav)
    audio = jax.device_get(wav)
    return audio
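Note: `mel2wave` re-creates the Generator inside `hk.transform_with_state`, loads the pickled parameters written by the converter, and applies the network. Haiku's `Conv1D` is channels-last, so the mel input should be shaped `(batch, frames, num_mels)`. A hedged usage sketch (dummy input; writing the wav with `soundfile` is an assumption, not part of the commit):

```python
import numpy as np
import soundfile as sf
# from vietTTS.hifigan.mel2wave import mel2wave

mel = np.zeros((1, 100, 80), dtype=np.float32)  # 100 frames, 80 mel bins
wav = mel2wave(mel)
print(wav.shape)  # (25600,): 100 frames * hop_size 256 samples
sf.write("out.wav", np.asarray(wav), 16000)
```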
vietTTS/hifigan/model.py
ADDED
@@ -0,0 +1,125 @@
import haiku as hk
import jax
import jax.numpy as jnp

LRELU_SLOPE = 0.1


def get_padding(kernel_size, dilation=1):
    p = int((kernel_size * dilation - dilation) / 2)
    return ((p, p),)


class ResBlock1(hk.Module):
    def __init__(
        self, h, channels, kernel_size=3, dilation=(1, 3, 5), name="resblock1"
    ):
        super().__init__(name=name)

        self.h = h
        self.convs1 = [
            hk.Conv1D(
                channels,
                kernel_size,
                1,
                rate=dilation[i],
                padding=get_padding(kernel_size, dilation[i]),
                name=f"convs1_{i}",
            )
            for i in range(3)
        ]

        self.convs2 = [
            hk.Conv1D(
                channels,
                kernel_size,
                1,
                rate=1,
                padding=get_padding(kernel_size, 1),
                name=f"convs2_{i}",
            )
            for i in range(3)
        ]

    def __call__(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = jax.nn.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = jax.nn.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x


class ResBlock2(hk.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), name="ResBlock2"):
        super().__init__(name=name)
        self.h = h
        self.convs = [
            hk.Conv1D(
                channels,
                kernel_size,
                1,
                rate=dilation[i],
                padding=get_padding(kernel_size, dilation[i]),
            )
            for i in range(2)
        ]

    def __call__(self, x):
        for c in self.convs:
            xt = jax.nn.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x


class Generator(hk.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = hk.Conv1D(h.upsample_initial_channel, 7, 1, padding=((3, 3),))
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2
        self.ups = []
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                hk.Conv1DTranspose(
                    h.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_shape=k,
                    stride=u,
                    padding="SAME",
                    name=f"ups_{i}",
                )
            )

        self.resblocks = []

        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock(h, ch, k, d, name=f"res_block1_{len(self.resblocks)}")
                )
        self.conv_post = hk.Conv1D(1, 7, 1, padding=((3, 3),))

    def __call__(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = jax.nn.leaky_relu(x, LRELU_SLOPE)

            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = jax.nn.leaky_relu(x)  # default pytorch value
        x = self.conv_post(x)
        x = jnp.tanh(x)
        return x
vietTTS/hifigan/torch_model.py
ADDED
@@ -0,0 +1,414 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

# from utils import init_weights, get_padding

LRELU_SLOPE = 0.1


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(
            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)

[the diff rendering is truncated here; lines 344-414 of torch_model.py are not shown]
|
344 |
+
fmap.append(x)
|
345 |
+
x = torch.flatten(x, 1, -1)
|
346 |
+
|
347 |
+
return x, fmap
|
348 |
+
|
349 |
+
|
350 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
351 |
+
def __init__(self):
|
352 |
+
super(MultiScaleDiscriminator, self).__init__()
|
353 |
+
self.discriminators = nn.ModuleList(
|
354 |
+
[
|
355 |
+
DiscriminatorS(use_spectral_norm=True),
|
356 |
+
DiscriminatorS(),
|
357 |
+
DiscriminatorS(),
|
358 |
+
]
|
359 |
+
)
|
360 |
+
self.meanpools = nn.ModuleList(
|
361 |
+
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
|
362 |
+
)
|
363 |
+
|
364 |
+
def forward(self, y, y_hat):
|
365 |
+
y_d_rs = []
|
366 |
+
y_d_gs = []
|
367 |
+
fmap_rs = []
|
368 |
+
fmap_gs = []
|
369 |
+
for i, d in enumerate(self.discriminators):
|
370 |
+
if i != 0:
|
371 |
+
y = self.meanpools[i - 1](y)
|
372 |
+
y_hat = self.meanpools[i - 1](y_hat)
|
373 |
+
y_d_r, fmap_r = d(y)
|
374 |
+
y_d_g, fmap_g = d(y_hat)
|
375 |
+
y_d_rs.append(y_d_r)
|
376 |
+
fmap_rs.append(fmap_r)
|
377 |
+
y_d_gs.append(y_d_g)
|
378 |
+
fmap_gs.append(fmap_g)
|
379 |
+
|
380 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
381 |
+
|
382 |
+
|
383 |
+
def feature_loss(fmap_r, fmap_g):
|
384 |
+
loss = 0
|
385 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
386 |
+
for rl, gl in zip(dr, dg):
|
387 |
+
loss += torch.mean(torch.abs(rl - gl))
|
388 |
+
|
389 |
+
return loss * 2
|
390 |
+
|
391 |
+
|
392 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
393 |
+
loss = 0
|
394 |
+
r_losses = []
|
395 |
+
g_losses = []
|
396 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
397 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
398 |
+
g_loss = torch.mean(dg**2)
|
399 |
+
loss += r_loss + g_loss
|
400 |
+
r_losses.append(r_loss.item())
|
401 |
+
g_losses.append(g_loss.item())
|
402 |
+
|
403 |
+
return loss, r_losses, g_losses
|
404 |
+
|
405 |
+
|
406 |
+
def generator_loss(disc_outputs):
|
407 |
+
loss = 0
|
408 |
+
gen_losses = []
|
409 |
+
for dg in disc_outputs:
|
410 |
+
l = torch.mean((1 - dg) ** 2)
|
411 |
+
gen_losses.append(l)
|
412 |
+
loss += l
|
413 |
+
|
414 |
+
return loss, gen_losses
|
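A minimal sketch of how the modules above combine in a HiFi-GAN loss computation; the waveform tensors and their shapes are illustrative assumptions, not part of the upload:

import torch

mpd = MultiPeriodDiscriminator()

y = torch.randn(2, 1, 8192)      # real waveform batch (batch, 1, samples)
y_hat = torch.randn(2, 1, 8192)  # generator output of the same length

# discriminator side: push real scores toward 1, fake scores toward 0
y_d_rs, y_d_gs, _, _ = mpd(y, y_hat.detach())
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)

# generator side: adversarial term plus feature matching on intermediate maps
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
loss_gen, _ = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)

The same pattern applies to MultiScaleDiscriminator; the full objective also adds an L1 mel-spectrogram loss.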
vietTTS/hifigan/trainer.py
ADDED
File without changes

vietTTS/nat/__init__.py
ADDED
File without changes

vietTTS/nat/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (158 Bytes)

vietTTS/nat/__pycache__/config.cpython-39.pyc
ADDED
Binary file (2.43 kB)

vietTTS/nat/__pycache__/data_loader.cpython-39.pyc
ADDED
Binary file (4.25 kB)

vietTTS/nat/__pycache__/model.cpython-39.pyc
ADDED
Binary file (6.93 kB)

vietTTS/nat/__pycache__/text2mel.cpython-39.pyc
ADDED
Binary file (4.01 kB)
vietTTS/nat/acoustic_tpu_trainer.py
ADDED
@@ -0,0 +1,189 @@
import os
import pickle
from functools import partial
from typing import Deque

import fire
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import matplotlib.pyplot as plt
import optax
from tqdm.auto import tqdm

from .acoustic_trainer import initial_state, loss_vag, val_loss_fn
from .config import FLAGS
from .data_loader import load_textgrid_wav
from .dsp import MelFilter
from .utils import print_flags


def setup_colab_tpu():
    jax.tools.colab_tpu.setup_tpu()


def train(
    batch_size: int = 32,
    steps_per_update: int = 10,
    learning_rate: float = 1024e-6,
):
    """Train acoustic model on multiple cores (TPU)."""
    lr_schedule = optax.exponential_decay(learning_rate, 50_000, 0.5, staircase=True)

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(lr_schedule, weight_decay=FLAGS.weight_decay),
    )

    def update_step(prev_state, inputs):
        params, aux, rng, optim_state = prev_state
        rng, new_rng = jax.random.split(rng)
        (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
        grads = jax.lax.pmean(grads, axis_name="i")
        updates, new_optim_state = optimizer.update(grads, optim_state, params)
        new_params = optax.apply_updates(params, updates)
        next_state = (new_params, new_aux, new_rng, new_optim_state)
        return next_state, loss

    @partial(jax.pmap, axis_name="i")
    def update(params, aux, rng, optim_state, inputs):
        states, losses = jax.lax.scan(
            update_step, (params, aux, rng, optim_state), inputs
        )
        return states, jnp.mean(losses)

    print(jax.devices())
    num_devices = jax.device_count()
    train_data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        batch_size * num_devices * steps_per_update,
        FLAGS.max_wave_len,
        "train",
    )
    val_data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        batch_size,
        FLAGS.max_wave_len,
        "val",
    )
    melfilter = MelFilter(
        FLAGS.sample_rate,
        FLAGS.n_fft,
        FLAGS.mel_dim,
        FLAGS.fmin,
        FLAGS.fmax,
    )
    batch = next(train_data_iter)
    batch = jax.tree_map(lambda x: x[:1], batch)
    batch = batch._replace(mels=melfilter(batch.wavs.astype(jnp.float32) / (2**15)))
    params, aux, rng, optim_state = initial_state(optimizer, batch)
    losses = Deque(maxlen=1000)
    val_losses = Deque(maxlen=100)

    last_step = -steps_per_update

    # loading latest checkpoint
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    if ckpt_fn.exists():
        print("Resuming from latest checkpoint at", ckpt_fn)
        with open(ckpt_fn, "rb") as f:
            dic = pickle.load(f)
        last_step, params, aux, rng, optim_state = (
            dic["step"],
            dic["params"],
            dic["aux"],
            dic["rng"],
            dic["optim_state"],
        )

    tr = tqdm(
        range(
            last_step + steps_per_update, FLAGS.num_training_steps + 1, steps_per_update
        ),
        desc="training",
        total=FLAGS.num_training_steps // steps_per_update + 1,
        initial=last_step // steps_per_update + 1,
    )

    params, aux, rng, optim_state = jax.device_put_replicated(
        (params, aux, rng, optim_state), jax.devices()
    )

    def batch_reshape(batch):
        return jax.tree_map(
            lambda x: jnp.reshape(x, (num_devices, steps_per_update, -1) + x.shape[1:]),
            batch,
        )

    for step in tr:
        batch = next(train_data_iter)
        batch = batch_reshape(batch)
        (params, aux, rng, optim_state), loss = update(
            params, aux, rng, optim_state, batch
        )
        losses.append(loss)

        if step % 10 == 0:
            val_batch = next(val_data_iter)
            val_loss, val_aux, predicted_mel, gt_mel = val_loss_fn(
                *jax.tree_map(lambda x: x[0], (params, aux, rng)), val_batch
            )
            val_losses.append(val_loss)
            attn = jax.device_get(val_aux["acoustic_model"]["attn"])
            predicted_mel = jax.device_get(predicted_mel[0])
            gt_mel = jax.device_get(gt_mel[0])

        if step % 1000 == 0:
            loss = jnp.mean(sum(losses)).item() / len(losses)
            val_loss = sum(val_losses).item() / len(val_losses)
            tr.write(f"step {step} train loss {loss:.3f} val loss {val_loss:.3f}")

            # saving predicted mels
            plt.figure(figsize=(10, 10))
            plt.subplot(3, 1, 1)
            plt.imshow(predicted_mel.T, origin="lower", aspect="auto")
            plt.subplot(3, 1, 2)
            plt.imshow(gt_mel.T, origin="lower", aspect="auto")
            plt.subplot(3, 1, 3)
            plt.imshow(attn.T, origin="lower", aspect="auto")
            plt.tight_layout()
            plt.savefig(FLAGS.ckpt_dir / f"mel_{step:06d}.png")
            plt.close()

            # saving checkpoint
            with open(ckpt_fn, "wb") as f:
                params_, aux_, rng_, optim_state_ = jax.tree_map(
                    lambda x: x[0], (params, aux, rng, optim_state)
                )
                pickle.dump(
                    {
                        "step": step,
                        "params": params_,
                        "aux": aux_,
                        "rng": rng_,
                        "optim_state": optim_state_,
                    },
                    f,
                )


if __name__ == "__main__":
    # we don't use these flags.
    del FLAGS.batch_size
    del FLAGS.learning_rate
    del FLAGS.duration_learning_rate
    del FLAGS.duration_lstm_dim
    del FLAGS.duration_embed_dropout_rate

    print_flags(FLAGS.__dict__)

    if "COLAB_TPU_ADDR" in os.environ:
        setup_colab_tpu()

    if not FLAGS.ckpt_dir.exists():
        print("Create checkpoint dir at", FLAGS.ckpt_dir)
        FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)

    fire.Fire(train)
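A minimal shape sketch of the data layout `batch_reshape` hands to `update`: jax.pmap maps axis 0 over devices while jax.lax.scan consumes axis 1 one optimizer step at a time inside each device. The sizes below are assumed examples:

import numpy as np

num_devices, steps_per_update, batch_size = 8, 10, 32  # assumed example sizes
phonemes = np.zeros((batch_size * num_devices * steps_per_update, 256), np.int32)
reshaped = phonemes.reshape(num_devices, steps_per_update, -1, *phonemes.shape[1:])
assert reshaped.shape == (8, 10, 32, 256)  # (devices, scan steps, per-device batch, ...)

Folding steps_per_update steps into one pmap call amortizes host-device dispatch overhead, which matters on TPUs.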
vietTTS/nat/acoustic_trainer.py
ADDED
@@ -0,0 +1,181 @@
import pickle
from functools import partial
from typing import Deque

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from tqdm.auto import tqdm

from .config import FLAGS, AcousticInput
from .data_loader import load_textgrid_wav
from .dsp import MelFilter
from .model import AcousticModel
from .utils import print_flags


@hk.transform_with_state
def net(x):
    return AcousticModel(is_training=True)(x)


@hk.transform_with_state
def val_net(x):
    return AcousticModel(is_training=False)(x)


def loss_fn(params, aux, rng, inputs: AcousticInput, is_training=True):
    """Compute the masked L1+L2 reconstruction loss over valid mel frames."""
    melfilter = MelFilter(
        FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
    )
    wavs = inputs.wavs.astype(jnp.float32) / (2**15)
    mels = melfilter(wavs)
    B, L, D = mels.shape
    go_frame = jnp.zeros((B, 1, D), dtype=jnp.float32)
    inp_mels = jnp.concatenate((go_frame, mels[:, :-1, :]), axis=1)
    n_frames = inputs.durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
    inputs = inputs._replace(mels=inp_mels, durations=n_frames)
    model = net if is_training else val_net
    (mel1_hat, mel2_hat), new_aux = model.apply(params, aux, rng, inputs)
    loss1 = (jnp.square(mel1_hat - mels) + jnp.square(mel2_hat - mels)) / 2
    loss2 = (jnp.abs(mel1_hat - mels) + jnp.abs(mel2_hat - mels)) / 2
    loss = jnp.mean((loss1 + loss2) / 2, axis=-1)
    num_frames = (inputs.wav_lengths // (FLAGS.n_fft // 4))[:, None]
    mask = jnp.arange(0, L)[None, :] < num_frames
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return (loss, new_aux) if is_training else (loss, new_aux, mel2_hat, mels)


train_loss_fn = partial(loss_fn, is_training=True)
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

loss_vag = jax.value_and_grad(train_loss_fn, has_aux=True)


def initial_state(optimizer, batch):
    rng = jax.random.PRNGKey(42)
    params, aux = hk.transform_with_state(lambda x: AcousticModel(True)(x)).init(
        rng, batch
    )
    optim_state = optimizer.init(params)
    return params, aux, rng, optim_state


def train():
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(FLAGS.learning_rate, weight_decay=FLAGS.weight_decay),
    )

    @jax.jit
    def update(params, aux, rng, optim_state, inputs):
        rng, new_rng = jax.random.split(rng)
        (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
        updates, new_optim_state = optimizer.update(grads, optim_state, params)
        # optax.apply_updates takes (params, updates), in that order
        new_params = optax.apply_updates(params, updates)
        return loss, (new_params, new_aux, new_rng, new_optim_state)

    train_data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        FLAGS.batch_size,
        FLAGS.max_wave_len,
        "train",
    )
    val_data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        FLAGS.batch_size,
        FLAGS.max_wave_len,
        "val",
    )
    melfilter = MelFilter(
        FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
    )
    batch = next(train_data_iter)
    batch = batch._replace(mels=melfilter(batch.wavs.astype(jnp.float32) / (2**15)))
    params, aux, rng, optim_state = initial_state(optimizer, batch)
    losses = Deque(maxlen=1000)
    val_losses = Deque(maxlen=100)

    last_step = -1

    # loading latest checkpoint
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    if ckpt_fn.exists():
        print("Resuming from latest checkpoint at", ckpt_fn)
        with open(ckpt_fn, "rb") as f:
            dic = pickle.load(f)
        last_step, params, aux, rng, optim_state = (
            dic["step"],
            dic["params"],
            dic["aux"],
            dic["rng"],
            dic["optim_state"],
        )

    tr = tqdm(
        range(last_step + 1, FLAGS.num_training_steps + 1),
        desc="training",
        total=FLAGS.num_training_steps + 1,
        initial=last_step + 1,
    )
    for step in tr:
        batch = next(train_data_iter)
        loss, (params, aux, rng, optim_state) = update(
            params, aux, rng, optim_state, batch
        )
        losses.append(loss)

        if step % 10 == 0:
            val_batch = next(val_data_iter)
            val_loss, val_aux, predicted_mel, gt_mel = val_loss_fn(
                params, aux, rng, val_batch
            )
            val_losses.append(val_loss)
            attn = jax.device_get(val_aux["acoustic_model"]["attn"])
            predicted_mel = jax.device_get(predicted_mel[0])
            gt_mel = jax.device_get(gt_mel[0])

        if step % 1000 == 0:
            loss = sum(losses).item() / len(losses)
            val_loss = sum(val_losses).item() / len(val_losses)
            tr.write(f"step {step} train loss {loss:.3f} val loss {val_loss:.3f}")

            # saving predicted mels
            plt.figure(figsize=(10, 10))
            plt.subplot(3, 1, 1)
            plt.imshow(predicted_mel.T, origin="lower", aspect="auto")
            plt.subplot(3, 1, 2)
            plt.imshow(gt_mel.T, origin="lower", aspect="auto")
            plt.subplot(3, 1, 3)
            plt.imshow(attn.T, origin="lower", aspect="auto")
            plt.tight_layout()
            plt.savefig(FLAGS.ckpt_dir / f"mel_{step:06d}.png")
            plt.close()

            # saving checkpoint
            with open(ckpt_fn, "wb") as f:
                pickle.dump(
                    {
                        "step": step,
                        "params": params,
                        "aux": aux,
                        "rng": rng,
                        "optim_state": optim_state,
                    },
                    f,
                )


if __name__ == "__main__":
    print_flags(FLAGS.__dict__)
    if not FLAGS.ckpt_dir.exists():
        print("Create checkpoint dir at", FLAGS.ckpt_dir)
        FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
    train()
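A small sketch of the teacher-forcing input built in `loss_fn` above: at frame t the decoder conditions on the ground-truth frame t-1, with an all-zero "go" frame at t=0. The shapes below are assumed examples:

import jax.numpy as jnp

mels = jnp.ones((2, 5, 80))              # (batch, frames, mel_dim)
go = jnp.zeros((2, 1, 80))               # all-zero "go" frame
inp = jnp.concatenate((go, mels[:, :-1, :]), axis=1)
assert inp.shape == mels.shape           # same length, shifted right by one frame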
vietTTS/nat/config.py
ADDED
@@ -0,0 +1,74 @@
from argparse import Namespace
from pathlib import Path
from typing import NamedTuple

from jax.numpy import ndarray


class FLAGS(Namespace):
    """Configurations"""

    duration_lstm_dim = 256
    vocab_size = 256
    duration_embed_dropout_rate = 0.5
    num_training_steps = 200_000
    postnet_dim = 512
    acoustic_decoder_dim = 512
    acoustic_encoder_dim = 256

    # dataset
    max_phoneme_seq_len = 256 * 1
    assert max_phoneme_seq_len % 256 == 0  # prevent compilation error on Colab T4 GPU
    max_wave_len = 1024 * 64 * 3

    # Montreal Forced Aligner
    special_phonemes = ["sil", "sp", "spn", " "]  # [sil], [sp], [spn], [word end]
    sil_index = special_phonemes.index("sil")
    sp_index = sil_index  # no use of "sp"
    word_end_index = special_phonemes.index(" ")
    _normal_phonemes = (
        []
        + ["a", "b", "c", "d", "e", "g", "h", "i", "k", "l"]
        + ["m", "n", "o", "p", "q", "r", "s", "t", "u", "v"]
        + ["x", "y", "à", "á", "â", "ã", "è", "é", "ê", "ì"]
        + ["í", "ò", "ó", "ô", "õ", "ù", "ú", "ý", "ă", "đ"]
        + ["ĩ", "ũ", "ơ", "ư", "ạ", "ả", "ấ", "ầ", "ẩ", "ẫ"]
        + ["ậ", "ắ", "ằ", "ẳ", "ẵ", "ặ", "ẹ", "ẻ", "ẽ", "ế"]
        + ["ề", "ể", "ễ", "ệ", "ỉ", "ị", "ọ", "ỏ", "ố", "ồ"]
        + ["ổ", "ỗ", "ộ", "ớ", "ờ", "ở", "ỡ", "ợ", "ụ", "ủ"]
        + ["ứ", "ừ", "ử", "ữ", "ự", "ỳ", "ỵ", "ỷ", "ỹ"]
    )

    # dsp
    mel_dim = 80
    n_fft = 1024
    sample_rate = 16000
    fmin = 0.0
    fmax = 8000

    # training
    batch_size = 64
    learning_rate = 1e-4
    duration_learning_rate = 1e-4
    max_grad_norm = 1.0
    weight_decay = 1e-4
    token_mask_prob = 0.1

    # ckpt
    ckpt_dir = Path("assets/infore/nat")
    data_dir = Path("train_data")


class DurationInput(NamedTuple):
    phonemes: ndarray
    lengths: ndarray
    durations: ndarray


class AcousticInput(NamedTuple):
    phonemes: ndarray
    lengths: ndarray
    durations: ndarray
    wavs: ndarray
    wav_lengths: ndarray
    mels: ndarray
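A quick sketch of the phoneme indexing implied by this config: the special phonemes occupy the first indices and the Vietnamese characters follow, which is exactly what `load_phonemes_set` in `data_loader.py` relies on:

phonemes = FLAGS.special_phonemes + FLAGS._normal_phonemes
assert phonemes.index("sil") == FLAGS.sil_index == 0
assert phonemes.index(" ") == FLAGS.word_end_index == 3
assert phonemes.index("a") == len(FLAGS.special_phonemes)  # first normal phoneme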
vietTTS/nat/data_loader.py
ADDED
@@ -0,0 +1,156 @@
import random
from pathlib import Path

import numpy as np
import textgrid
from scipy.io import wavfile

from .config import FLAGS, AcousticInput, DurationInput


def load_phonemes_set():
    S = FLAGS.special_phonemes + FLAGS._normal_phonemes
    return S


def pad_seq(s, maxlen, value=0):
    assert maxlen >= len(s)
    return tuple(s) + (value,) * (maxlen - len(s))


def is_in_word(phone, word):
    def time_in_word(time, word):
        return (word.minTime - 1e-3) < time and (word.maxTime + 1e-3) > time

    return time_in_word(phone.minTime, word) and time_in_word(phone.maxTime, word)


def load_textgrid(fn: Path):
    """load textgrid file"""
    tg = textgrid.TextGrid.fromFile(str(fn.resolve()))
    data = []
    words = list(tg[0])
    widx = 0
    assert tg[1][0].minTime == 0, "The first phoneme has to start at time 0"
    for p in tg[1]:
        if p not in words[widx]:
            widx = widx + 1
            if len(words[widx - 1].mark) > 0:
                data.append((FLAGS.special_phonemes[FLAGS.word_end_index], 0.0))
            if widx >= len(words):
                break
        assert p in words[widx], "mismatched word vs phoneme"
        mark = p.mark.strip().lower()
        if len(mark) == 0:
            mark = "sil"
        data.append((mark, p.duration()))
    return data


def textgrid_data_loader(data_dir: Path, seq_len: int, batch_size: int, mode: str):
    """load all textgrid files in the directory"""
    tg_files = sorted(data_dir.glob("*.TextGrid"))
    random.Random(42).shuffle(tg_files)
    L = len(tg_files) * 95 // 100
    assert mode in ["train", "val"]
    phonemes = load_phonemes_set()
    if mode == "train":
        tg_files = tg_files[:L]
    if mode == "val":
        tg_files = tg_files[L:]

    data = []
    for fn in tg_files:
        ps, ds = zip(*load_textgrid(fn))
        ps = [phonemes.index(p) for p in ps]
        l = len(ps)
        ps = pad_seq(ps, seq_len, 0)
        ds = pad_seq(ds, seq_len, 0)
        data.append((ps, ds, l))

    batch = []
    while True:
        random.shuffle(data)
        for e in data:
            batch.append(e)
            if len(batch) == batch_size:
                ps, ds, lengths = zip(*batch)
                ps = np.array(ps, dtype=np.int32)
                ds = np.array(ds, dtype=np.float32)
                lengths = np.array(lengths, dtype=np.int32)
                yield DurationInput(ps, lengths, ds)
                batch = []


def load_textgrid_wav(
    data_dir: Path, token_seq_len: int, batch_size, pad_wav_len, mode: str
):
    """load wav and textgrid files to memory."""
    tg_files = sorted(data_dir.glob("*.TextGrid"))
    random.Random(42).shuffle(tg_files)
    L = len(tg_files) * 95 // 100
    assert mode in ["train", "val", "gta"]
    phonemes = load_phonemes_set()
    if mode == "gta":
        tg_files = tg_files  # all files
    elif mode == "train":
        tg_files = tg_files[:L]
    elif mode == "val":
        tg_files = tg_files[L:]

    data = []
    for fn in tg_files:
        ps, ds = zip(*load_textgrid(fn))
        ps = [phonemes.index(p) for p in ps]
        l = len(ps)
        ps = pad_seq(ps, token_seq_len, 0)
        ds = pad_seq(ds, token_seq_len, 0)

        wav_file = data_dir / f"{fn.stem}.wav"
        sr, y = wavfile.read(wav_file)
        y = np.copy(y)
        start_time = 0
        for i, (phone_idx, duration) in enumerate(zip(ps, ds)):
            # use distinct names for the sample indices so that ``l``,
            # the phoneme sequence length computed above, is not clobbered
            left = int(start_time * sr)
            end_time = start_time + duration
            right = int(end_time * sr)
            if i == len(ps) - 1:
                right = len(y)
            if phone_idx < len(FLAGS.special_phonemes):
                y[left:right] = 0  # silence non-speech segments
            start_time = end_time

        if len(y) > pad_wav_len:
            y = y[:pad_wav_len]

        # # normalize to match hifigan preprocessing
        # y = y.astype(np.float32)
        # y = y / np.max(np.abs(y))
        # y = y * 0.95
        # y = y * (2 ** 15)
        # y = y.astype(np.int16)

        wav_length = len(y)
        y = np.pad(y, (0, pad_wav_len - len(y)))
        data.append((fn.stem, ps, ds, l, y, wav_length))

    batch = []
    while True:
        random.shuffle(data)
        for idx, e in enumerate(data):
            batch.append(e)
            if len(batch) == batch_size or (mode == "gta" and idx == len(data) - 1):
                names, ps, ds, lengths, wavs, wav_lengths = zip(*batch)
                ps = np.array(ps, dtype=np.int32)
                ds = np.array(ds, dtype=np.float32)
                lengths = np.array(lengths, dtype=np.int32)
                wavs = np.array(wavs, dtype=np.int16)
                wav_lengths = np.array(wav_lengths, dtype=np.int32)
                if mode == "gta":
                    yield names, AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
                else:
                    yield AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
                batch = []
        if mode == "gta":
            assert len(batch) == 0
            break
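A minimal usage sketch, assuming TextGrid/wav pairs exist in FLAGS.data_dir; the generator yields fixed-shape, zero-padded batches indefinitely (except in "gta" mode, where it stops after one pass):

it = load_textgrid_wav(FLAGS.data_dir, 256, 4, 1024 * 64 * 3, "train")
batch = next(it)
print(batch.phonemes.shape)   # (4, 256)    int32 phoneme ids, zero padded
print(batch.wavs.shape)       # (4, 196608) int16 samples, zero padded
print(batch.mels)             # None; mels are computed later from the wavs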
vietTTS/nat/dsp.py
ADDED
@@ -0,0 +1,128 @@
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
import librosa
from einops import rearrange
from jax.numpy import ndarray


def rolling_window(a: ndarray, window: int, hop_length: int):
    """return a stack of overlap subsequence of an array.
    ``return jnp.stack( [a[0:10], a[5:15], a[10:20],...], axis=0)``
    Source: https://github.com/google/jax/issues/3171
    Args:
      a (ndarray): input array of shape `[L, ...]`
      window (int): length of each subarray (window).
      hop_length (int): distance between neighbouring windows.
    """

    idx = (
        jnp.arange(window)[:, None]
        + jnp.arange((len(a) - window) // hop_length + 1)[None, :] * hop_length
    )
    return a[idx]


@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
def stft(
    y: ndarray,
    n_fft: int = 2048,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: str = "hann",
    center: bool = True,
    pad_mode: str = "reflect",
):
    """A jax reimplementation of ``librosa.stft`` function."""

    if win_length is None:
        win_length = n_fft

    if hop_length is None:
        hop_length = win_length // 4

    if window == "hann":
        fft_window = jnp.hanning(win_length + 1)[:-1]
    else:
        raise RuntimeError(f"{window} window function is not supported!")

    pad_len = (n_fft - win_length) // 2
    fft_window = jnp.pad(fft_window, (pad_len, pad_len), mode="constant")
    fft_window = fft_window[:, None]
    if center:
        y = jnp.pad(y, int(n_fft // 2), mode=pad_mode)

    # jax does not support ``np.lib.stride_tricks.as_strided`` function
    # see https://github.com/google/jax/issues/3171 for comments.
    y_frames = rolling_window(y, n_fft, hop_length) * fft_window
    stft_matrix = jnp.fft.fft(y_frames, axis=0)
    d = int(1 + n_fft // 2)
    return stft_matrix[:d]


@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
def batched_stft(
    y: ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window: str,
    center: bool = True,
    pad_mode: str = "reflect",
):
    """Batched version of ``stft`` function.
    TN => FTN
    """

    assert len(y.shape) >= 2
    if window == "hann":
        fft_window = jnp.hanning(win_length + 1)[:-1]
    else:
        raise RuntimeError(f"{window} window function is not supported!")
    pad_len = (n_fft - win_length) // 2
    if pad_len > 0:
        fft_window = jnp.pad(fft_window, (pad_len, pad_len), mode="constant")
        win_length = n_fft
    if center:
        pad_width = ((n_fft // 2, n_fft // 2),) + ((0, 0),) * (len(y.shape) - 1)
        y = jnp.pad(y, pad_width, mode=pad_mode)

    # jax does not support ``np.lib.stride_tricks.as_strided`` function
    # see https://github.com/google/jax/issues/3171 for comments.
    y_frames = rolling_window(y, n_fft, hop_length)
    fft_window = jnp.reshape(fft_window, (-1,) + (1,) * (len(y.shape)))
    y_frames = y_frames * fft_window
    stft_matrix = jnp.fft.fft(y_frames, axis=0)
    d = int(1 + n_fft // 2)
    return stft_matrix[:d]


class MelFilter:
    """Convert waveform to mel spectrogram."""

    def __init__(self, sample_rate: int, n_fft: int, n_mels: int, fmin=0.0, fmax=8000):
        self.melfb = jax.device_put(
            librosa.filters.mel(
                sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
            )
        )
        self.n_fft = n_fft

    def __call__(self, y: ndarray) -> ndarray:
        hop_length = self.n_fft // 4
        window_length = self.n_fft
        assert len(y.shape) == 2
        y = rearrange(y, "n s -> s n")
        p = (self.n_fft - hop_length) // 2
        y = jnp.pad(y, ((p, p), (0, 0)), mode="reflect")
        spec = batched_stft(
            y, self.n_fft, hop_length, window_length, "hann", False, "reflect"
        )
        mag = jnp.sqrt(jnp.square(spec.real) + jnp.square(spec.imag) + 1e-9)
        mel = jnp.einsum("ms,sfn->nfm", self.melfb, mag)
        cond = jnp.log(jnp.clip(mel, a_min=1e-5, a_max=None))
        return cond
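A minimal sketch of `MelFilter` on a dummy batch; with n_fft=1024 the hop length is n_fft // 4 = 256 samples, so a 16000-sample (one second) waveform maps to roughly 62 log-mel frames:

import jax.numpy as jnp

melfilter = MelFilter(sample_rate=16000, n_fft=1024, n_mels=80)
wav = jnp.zeros((2, 16000), dtype=jnp.float32)  # (batch, samples), scaled to [-1, 1]
mel = melfilter(wav)
print(mel.shape)  # (2, 62, 80): (batch, frames, n_mels) log-mel features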
vietTTS/nat/duration_trainer.py
ADDED
@@ -0,0 +1,142 @@
from functools import partial
from typing import Deque

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm.auto import tqdm

from .config import FLAGS, DurationInput
from .data_loader import textgrid_data_loader
from .model import DurationModel
from .utils import load_latest_ckpt, print_flags, save_ckpt


def loss_fn(params, aux, rng, x: DurationInput, is_training=True):
    """return the l1 loss"""

    @hk.transform_with_state
    def net(x):
        return DurationModel(is_training=is_training)(x)

    if is_training:
        # randomly mask tokens with [WORD END] token
        # during training to avoid overfitting
        m_rng, rng = jax.random.split(rng, 2)
        m = jax.random.bernoulli(m_rng, FLAGS.token_mask_prob, x.phonemes.shape)
        x = x._replace(phonemes=jnp.where(m, FLAGS.word_end_index, x.phonemes))
    durations, aux = net.apply(params, aux, rng, x)
    mask = jnp.arange(0, x.phonemes.shape[1])[None, :] < x.lengths[:, None]
    # do NOT predict the [WORD END] token
    mask = jnp.where(x.phonemes == FLAGS.word_end_index, False, mask)
    masked_loss = jnp.abs(durations - x.durations) * mask
    loss = jnp.sum(masked_loss) / jnp.sum(mask)
    return loss, aux


forward_fn = jax.jit(
    hk.transform_with_state(lambda x: DurationModel(is_training=False)(x)).apply
)


def predict_duration(params, aux, rng, x: DurationInput):
    d, _ = forward_fn(params, aux, rng, x)
    return d, x.durations


val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(FLAGS.duration_learning_rate, weight_decay=FLAGS.weight_decay),
)


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
    rng, new_rng = jax.random.split(rng)
    (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
    updates, new_optim_state = optimizer.update(grads, optim_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, (new_params, new_aux, new_rng, new_optim_state)


def initial_state(batch):
    rng = jax.random.PRNGKey(42)
    params, aux = hk.transform_with_state(lambda x: DurationModel(True)(x)).init(
        rng, batch
    )
    optim_state = optimizer.init(params)
    return params, aux, rng, optim_state


def plot_val_duration(step: int, batch, params, aux, rng):
    fn = FLAGS.ckpt_dir / f"duration_{step:06d}.png"
    predicted_dur, gt_dur = predict_duration(params, aux, rng, batch)
    L = batch.lengths[0]
    x = np.arange(0, L) * 3
    plt.plot(predicted_dur[0, :L])
    plt.plot(gt_dur[0, :L])
    plt.legend(["predicted", "gt"])
    plt.title("Phoneme durations")
    plt.savefig(fn)
    plt.close()


def train():
    train_data_iter = textgrid_data_loader(
        FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="train"
    )
    val_data_iter = textgrid_data_loader(
        FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="val"
    )
    losses = Deque(maxlen=1000)
    val_losses = Deque(maxlen=100)
    latest_ckpt = load_latest_ckpt(FLAGS.ckpt_dir)
    if latest_ckpt is not None:
        last_step, params, aux, rng, optim_state = latest_ckpt
    else:
        last_step = -1
        print("Generate random initial states...")
        params, aux, rng, optim_state = initial_state(next(train_data_iter))

    tr = tqdm(
        range(last_step + 1, 1 + FLAGS.num_training_steps),
        total=1 + FLAGS.num_training_steps,
        initial=last_step + 1,
        ncols=80,
        desc="training",
    )
    for step in tr:
        batch = next(train_data_iter)
        loss, (params, aux, rng, optim_state) = update(
            params, aux, rng, optim_state, batch
        )
        losses.append(loss)

        if step % 10 == 0:
            val_loss, _ = val_loss_fn(params, aux, rng, next(val_data_iter))
            val_losses.append(val_loss)

        if step % 1000 == 0:
            loss = sum(losses).item() / len(losses)
            val_loss = sum(val_losses).item() / len(val_losses)
            plot_val_duration(step, next(val_data_iter), params, aux, rng)
            tr.write(
                f" {step:>6d}/{FLAGS.num_training_steps:>6d} | train loss {loss:.5f} | val loss {val_loss:.5f}"
            )
            save_ckpt(step, params, aux, rng, optim_state, ckpt_dir=FLAGS.ckpt_dir)


if __name__ == "__main__":
    print_flags(FLAGS.__dict__)
    if not FLAGS.ckpt_dir.exists():
        print("Create checkpoint dir at", FLAGS.ckpt_dir)
        FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
    train()
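A small sketch of the token-masking regularizer used in `loss_fn` above: during training each phoneme is independently replaced by the [word end] token with probability FLAGS.token_mask_prob (the token ids below are assumed examples):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
phonemes = jnp.arange(10)[None, :]                                    # (1, 10) ids
m = jax.random.bernoulli(rng, FLAGS.token_mask_prob, phonemes.shape)  # True => mask
masked = jnp.where(m, FLAGS.word_end_index, phonemes)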
vietTTS/nat/gta.py
ADDED
@@ -0,0 +1,82 @@
import pickle
from argparse import ArgumentParser
from pathlib import Path

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm

from .config import FLAGS, AcousticInput
from .data_loader import load_textgrid_wav
from .dsp import MelFilter
from .model import AcousticModel


@hk.transform_with_state
def net(x):
    return AcousticModel(is_training=True)(x)


@hk.transform_with_state
def val_net(x):
    return AcousticModel(is_training=False)(x)


def forward_fn_(params, aux, rng, inputs: AcousticInput):
    melfilter = MelFilter(
        FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
    )
    mels = melfilter(inputs.wavs.astype(jnp.float32) / (2**15))
    B, L, D = mels.shape
    inp_mels = jnp.concatenate(
        (jnp.zeros((B, 1, D), dtype=jnp.float32), mels[:, :-1, :]), axis=1
    )
    n_frames = inputs.durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
    inputs = inputs._replace(mels=inp_mels, durations=n_frames)
    (mel1_hat, mel2_hat), new_aux = val_net.apply(params, aux, rng, inputs)
    return mel2_hat


forward_fn = jax.jit(forward_fn_)


def generate_gta(out_dir: Path):
    out_dir.mkdir(parents=True, exist_ok=True)
    data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        FLAGS.batch_size,
        FLAGS.max_wave_len,
        "gta",
    )
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    print("Resuming from latest checkpoint at", ckpt_fn)
    with open(ckpt_fn, "rb") as f:
        dic = pickle.load(f)
    _, params, aux, rng, _ = (
        dic["step"],
        dic["params"],
        dic["aux"],
        dic["rng"],
        dic["optim_state"],
    )

    tr = tqdm(data_iter)
    for names, batch in tr:
        lengths = batch.wav_lengths
        predicted_mel = forward_fn(params, aux, rng, batch)
        mel = jax.device_get(predicted_mel)
        for idx, fn in enumerate(names):
            file = out_dir / f"{fn}.npy"
            tr.write(f"saving to file {file}")
            l = lengths[idx] // (FLAGS.n_fft // 4)
            np.save(file, mel[idx, :l].T)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-o", "--output-dir", type=Path, default="gta")
    generate_gta(parser.parse_args().output_dir)
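A likely invocation, assuming the package is installed and `acoustic_latest_ckpt.pickle` exists under FLAGS.ckpt_dir; it writes one ground-truth-aligned mel per utterance ({name}.npy), which is typically used to fine-tune the HiFi-GAN vocoder on the acoustic model's own outputs:

python -m vietTTS.nat.gta --output-dir gta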