tobiccino committed
Commit
12da6cc
1 Parent(s): e851206
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. LICENSE +20 -0
  3. README.md +116 -12
  4. Untitled.ipynb +24 -0
  5. app.py +7 -0
  6. assets/.DS_Store +0 -0
  7. assets/hifigan/config.json +38 -0
  8. assets/infore/.DS_Store +0 -0
  9. assets/infore/lexicon.txt +0 -0
  10. assets/transcript.txt +26 -0
  11. notebooks/align_text_audio_infore_mfa.ipynb +193 -0
  12. notebooks/denoise_infore_dataset.ipynb +138 -0
  13. scripts/download_aligned_infore_dataset.py +45 -0
  14. scripts/quick_start.sh +12 -0
  15. setup.cfg +14 -0
  16. setup.py +43 -0
  17. tests/test_nat_acoustic.py +18 -0
  18. tests/test_nat_duration.py +15 -0
  19. vietTTS.egg-info/PKG-INFO +11 -0
  20. vietTTS.egg-info/SOURCES.txt +10 -0
  21. vietTTS.egg-info/dependency_links.txt +1 -0
  22. vietTTS.egg-info/requires.txt +12 -0
  23. vietTTS.egg-info/top_level.txt +1 -0
  24. vietTTS/__init__.py +0 -0
  25. vietTTS/__pycache__/__init__.cpython-39.pyc +0 -0
  26. vietTTS/__pycache__/synthesizer.cpython-39.pyc +0 -0
  27. vietTTS/hifigan/__pycache__/config.cpython-39.pyc +0 -0
  28. vietTTS/hifigan/__pycache__/mel2wave.cpython-39.pyc +0 -0
  29. vietTTS/hifigan/__pycache__/model.cpython-39.pyc +0 -0
  30. vietTTS/hifigan/config.py +6 -0
  31. vietTTS/hifigan/convert_torch_model_to_haiku.py +83 -0
  32. vietTTS/hifigan/create_mel.py +241 -0
  33. vietTTS/hifigan/data_loader.py +0 -0
  34. vietTTS/hifigan/mel2wave.py +41 -0
  35. vietTTS/hifigan/model.py +125 -0
  36. vietTTS/hifigan/torch_model.py +414 -0
  37. vietTTS/hifigan/trainer.py +0 -0
  38. vietTTS/nat/__init__.py +0 -0
  39. vietTTS/nat/__pycache__/__init__.cpython-39.pyc +0 -0
  40. vietTTS/nat/__pycache__/config.cpython-39.pyc +0 -0
  41. vietTTS/nat/__pycache__/data_loader.cpython-39.pyc +0 -0
  42. vietTTS/nat/__pycache__/model.cpython-39.pyc +0 -0
  43. vietTTS/nat/__pycache__/text2mel.cpython-39.pyc +0 -0
  44. vietTTS/nat/acoustic_tpu_trainer.py +189 -0
  45. vietTTS/nat/acoustic_trainer.py +181 -0
  46. vietTTS/nat/config.py +74 -0
  47. vietTTS/nat/data_loader.py +156 -0
  48. vietTTS/nat/dsp.py +128 -0
  49. vietTTS/nat/duration_trainer.py +142 -0
  50. vietTTS/nat/gta.py +82 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
LICENSE ADDED
@@ -0,0 +1,20 @@
+ Copyright (c) 2021 ntt123
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md CHANGED
@@ -1,12 +1,116 @@
- ---
- title: Tts
- emoji: 🔥
- colorFrom: green
- colorTo: pink
- sdk: gradio
- sdk_version: 3.18.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ A Vietnamese TTS
+ ================
+
+ Duration model + Acoustic model + HiFiGAN vocoder for a Vietnamese text-to-speech application.
+
+ Online demo at https://huggingface.co/spaces/ntt123/vietTTS.
+
+ A synthesized audio clip: [clip.wav](assets/infore/clip.wav). A Colab notebook: [notebook](https://colab.research.google.com/drive/1oczrWOQOr1Y_qLdgis1twSlNZlfPVXoY?usp=sharing).
+
+
+ 🔔 Check out the experimental `multi-speaker` branch (`git checkout multi-speaker`) for multi-speaker support. 🔔
+
+ Install
+ -------
+
+ ```sh
+ git clone https://github.com/NTT123/vietTTS.git
+ cd vietTTS
+ pip3 install -e .
+ ```
+
+
+ Quick start using pretrained models
+ -----------------------------------
+
+ ```sh
+ bash ./scripts/quick_start.sh
+ ```
+
+
+ Download InfoRe dataset
+ -----------------------
+
+ ```sh
+ python ./scripts/download_aligned_infore_dataset.py
+ ```
+
+ **Note**: this is a denoised and aligned version of the original dataset, which was donated by the InfoRe Technology company (see [here](https://www.facebook.com/groups/j2team.community/permalink/1010834009248719/)). You can download the original dataset (**InfoRe Technology 1**) [here](https://github.com/TensorSpeech/TensorFlowASR/blob/main/README.md#vietnamese).
+
+ See `notebooks/denoise_infore_dataset.ipynb` for instructions on how to denoise the dataset. We use the Montreal Forced Aligner (MFA) to align transcript and speech (TextGrid files).
+ See `notebooks/align_text_audio_infore_mfa.ipynb` for instructions on how to create the TextGrid files.
+
+ Train duration model
+ --------------------
+
+ ```sh
+ python -m vietTTS.nat.duration_trainer
+ ```
+
+
+ Train acoustic model
+ --------------------
+
+ ```sh
+ python -m vietTTS.nat.acoustic_trainer
+ ```
+
+
+ Train HiFiGAN vocoder
+ ---------------------
+
+ We use the original implementation from the HiFiGAN authors at https://github.com/jik876/hifi-gan. Use the config file at `assets/hifigan/config.json` to train your model.
+
+ ```sh
+ git clone https://github.com/jik876/hifi-gan.git
+
+ # create dataset in hifi-gan format
+ ln -sf `pwd`/train_data hifi-gan/data
+ cd hifi-gan/data
+ ls -1 *.TextGrid | sed -e 's/\.TextGrid$//' > files.txt
+ cd ..
+ head -n 100 data/files.txt > val_files.txt
+ tail -n +101 data/files.txt > train_files.txt
+ rm data/files.txt
+
+ # training
+ python train.py \
+   --config ../assets/hifigan/config.json \
+   --input_wavs_dir=data \
+   --input_training_file=train_files.txt \
+   --input_validation_file=val_files.txt
+ ```
+
+ Fine-tune on ground-truth-aligned (GTA) mel spectrograms:
+ ```sh
+ cd /path/to/vietTTS # go to the vietTTS directory
+ python -m vietTTS.nat.zero_silence_segments -o train_data # zero all [sil, sp, spn] segments
+ python -m vietTTS.nat.gta -o /path/to/hifi-gan/ft_dataset # create GTA mel spectrograms in the hifi-gan/ft_dataset directory
+
+ # turn on fine-tuning
+ cd /path/to/hifi-gan
+ python train.py \
+   --fine_tuning True \
+   --config ../assets/hifigan/config.json \
+   --input_wavs_dir=data \
+   --input_training_file=train_files.txt \
+   --input_validation_file=val_files.txt
+ ```
+
+ Then, use the following command to convert the PyTorch model to Haiku format:
+ ```sh
+ cd ..
+ python -m vietTTS.hifigan.convert_torch_model_to_haiku \
+   --config-file=assets/hifigan/config.json \
+   --checkpoint-file=hifi-gan/cp_hifigan/g_[latest_checkpoint]
+ ```
+
+ Synthesize speech
+ -----------------
+
+ ```sh
+ python -m vietTTS.synthesizer \
+   --lexicon-file=train_data/lexicon.txt \
+   --text="hôm qua em tới trường" \
+   --output=clip.wav
+ ```
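For completeness, the synthesizer CLI documented in the README above can also be driven from Python. This is a minimal sketch, not part of the commit; it assumes the package has been installed with `pip3 install -e .` and that `bash ./scripts/quick_start.sh` has already downloaded the pretrained checkpoints and `assets/infore/lexicon.txt`, and it only wraps the command shown above.

```python
# Sketch: wrap the documented synthesizer CLI in a small Python helper.
import subprocess
from pathlib import Path


def synthesize(text: str, out_wav: str = "clip.wav") -> Path:
    """Run `python3 -m vietTTS.synthesizer` with the pretrained InfoRe assets."""
    subprocess.run(
        [
            "python3", "-m", "vietTTS.synthesizer",
            "--lexicon-file=assets/infore/lexicon.txt",
            f"--text={text}",
            f"--output={out_wav}",
        ],
        check=True,
    )
    return Path(out_wav)


if __name__ == "__main__":
    print(synthesize("hôm qua em tới trường"))
```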
Untitled.ipynb ADDED
@@ -0,0 +1,24 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "b036f793-1443-4341-932c-d112386937ea",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "name": ""
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
app.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ iface.launch()
assets/.DS_Store ADDED
Binary file (6.15 kB).
 
assets/hifigan/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 16,
+     "learning_rate": 0.0002,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+
+     "upsample_rates": [8,8,2,2],
+     "upsample_kernel_sizes": [16,16,4,4],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+     "resblock_initial_channel": 256,
+
+     "segment_size": 8192,
+     "num_mels": 80,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 16000,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
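A note on the config above: the HiFiGAN generator upsamples each mel frame by the product of `upsample_rates`, and that product has to equal `hop_size` so the generated waveform lines up with the mel frames (at 16 kHz and a hop of 256, one frame covers 16 ms). A small check, assuming the file path committed above:

```python
# Sanity check: the generator's total upsampling factor must equal hop_size.
import json
import math

with open("assets/hifigan/config.json") as f:
    h = json.load(f)

total_upsampling = math.prod(h["upsample_rates"])  # 8 * 8 * 2 * 2 = 256
assert total_upsampling == h["hop_size"]
print(total_upsampling, "samples per mel frame at", h["sampling_rate"], "Hz")
```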
assets/infore/.DS_Store ADDED
Binary file (6.15 kB).
 
assets/infore/lexicon.txt ADDED
The diff for this file is too large to render.
 
assets/transcript.txt ADDED
@@ -0,0 +1,26 @@
1
+ Trăm năm trong cõi người ta,
2
+ Chữ tài chữ mệnh khéo là ghét nhau.
3
+ Trải qua một cuộc bể dâu,
4
+ Những điều trông thấy mà đau đớn lòng.
5
+ Lạ gì bỉ sắc tư phong,
6
+ Trời xanh quen thói má hồng đánh ghen.
7
+ Cảo thơm lần giở trước đèn,
8
+ Phong tình cổ lục còn truyền sử xanh.
9
+ Rằng: Năm Gia tĩnh triều Minh,
10
+ Bốn phương phẳng lặng hai kinh chữ vàng.
11
+ Có nhà viên ngoại họ Vương,
12
+ Gia tư nghỉ cũng thường thường bậc trung.
13
+ Một trai con thứ rốt lòng,
14
+ Vương Quan là chữ nối dòng nho gia.
15
+ Đầu lòng hai ả tố nga,
16
+ Thúy Kiều là chị em là Thúy Vân.
17
+ Mai cốt cách tuyết tinh thần,
18
+ Mỗi người một vẻ mười phân vẹn mười.
19
+ Vân xem trang trọng khác vời,
20
+ Khuôn trăng đầy đặn nét ngài nở nang.
21
+ Hoa cười ngọc thốt đoan trang,
22
+ Mây thua nước tóc tuyết nhường màu da.
23
+ Kiều càng sắc sảo mặn mà,
24
+ So bề tài sắc lại là phần hơn.
25
+ Làn thu thủy nét xuân sơn,
26
+ Hoa ghen thua thắm liễu hờn kém xanh.
notebooks/align_text_audio_infore_mfa.ipynb ADDED
@@ -0,0 +1,193 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### Align text and audio using Montreal Forced Aligner (MFA)"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {
14
+ "id": "IPkicKwU8IWj"
15
+ },
16
+ "outputs": [],
17
+ "source": [
18
+ "%%capture\n",
19
+ "!apt update -y\n",
20
+ "!pip install -U pip"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {
27
+ "id": "G6Z-aDd08hfk"
28
+ },
29
+ "outputs": [],
30
+ "source": [
31
+ "%%capture\n",
32
+ "%%bash\n",
33
+ "data_root=\"./infore_16k_denoised\"\n",
34
+ "mkdir -p $data_root\n",
35
+ "cd $data_root\n",
36
+ "wget https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_16k_denoised.zip -O infore.zip\n",
37
+ "unzip infore.zip "
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {
44
+ "id": "VWwgAePDXy4m"
45
+ },
46
+ "outputs": [],
47
+ "source": [
48
+ "from pathlib import Path\n",
49
+ "\n",
50
+ "txt_files = sorted(Path(\"./infore_16k_denoised\").glob(\"*.txt\"))\n",
51
+ "f = open(\"/content/words.txt\", \"w\", encoding=\"utf-8\")\n",
52
+ "for txt_file in txt_files:\n",
53
+ " wav_file = txt_file.with_suffix(\".wav\")\n",
54
+ " if not wav_file.exists():\n",
55
+ " continue\n",
56
+ " line = open(txt_file, \"r\", encoding=\"utf-8\").read()\n",
57
+ " for word in line.strip().lower().split():\n",
58
+ " f.write(word)\n",
59
+ " f.write(\"\\n\")\n",
60
+ "f.close()"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {
67
+ "id": "FktjNXbDkBLh"
68
+ },
69
+ "outputs": [],
70
+ "source": [
71
+ "black_list = (\n",
72
+ " []\n",
73
+ " + [\"q\", \"adn\", \"h\", \"stress\", \"b\", \"k\", \"mark\", \"gas\", \"cs\", \"test\", \"l\", \"hiv\"]\n",
74
+ " + [\"v\", \"d\", \"c\", \"p\", \"martin\", \"visa\", \"euro\", \"laser\", \"x\", \"real\", \"shop\"]\n",
75
+ " + [\"studio\", \"kelvin\", \"đt\", \"pop\", \"rock\", \"gara\", \"karaoke\", \"đicr\", \"đigiúp\"]\n",
76
+ " + [\"khmer\", \"ii\", \"s\", \"tr\", \"xhcn\", \"casino\", \"guitar\", \"sex\", \"oxi\", \"radio\"]\n",
77
+ " + [\"qúy\", \"asean\", \"hlv\" \"ts\", \"video\", \"virus\", \"usd\", \"robot\", \"ph\", \"album\"]\n",
78
+ " + [\"s\", \"kg\", \"km\", \"g\", \"tr\", \"đ\", \"ak\", \"d\", \"m\", \"n\"]\n",
79
+ ")"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "id": "b3nMwfzK_g0B"
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "ws = open(\"/content/words.txt\").readlines()\n",
91
+ "f = open(\"/content/lexicon.txt\", \"w\")\n",
92
+ "for w in sorted(set(ws)):\n",
93
+ " w = w.strip()\n",
94
+ "\n",
95
+ " # this is a hack to match phoneme set in the vietTTS repo\n",
96
+ " p = list(w)\n",
97
+ " p = \" \".join(p)\n",
98
+ " if w in black_list:\n",
99
+ " continue\n",
100
+ " else:\n",
101
+ " f.write(f\"{w}\\t{p}\\n\")\n",
102
+ "f.close()"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {
109
+ "id": "WuWZKTNRt1eM"
110
+ },
111
+ "outputs": [],
112
+ "source": [
113
+ "%%writefile install_mfa.sh\n",
114
+ "#!/bin/bash\n",
115
+ "\n",
116
+ "## a script to install Montreal Forced Aligner (MFA)\n",
117
+ "\n",
118
+ "root_dir=${1:-/tmp/mfa}\n",
119
+ "mkdir -p $root_dir\n",
120
+ "cd $root_dir\n",
121
+ "\n",
122
+ "# download miniconda3\n",
123
+ "wget -q --show-progress https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
124
+ "bash Miniconda3-latest-Linux-x86_64.sh -b -p $root_dir/miniconda3 -f\n",
125
+ "\n",
126
+ "#install MFA\n",
127
+ "$root_dir/miniconda3/bin/conda create -n aligner -c conda-forge montreal-forced-aligner=2.0.0rc7 -y\n",
128
+ "\n",
129
+ "echo -e \"\\n======== DONE ==========\"\n",
130
+ "echo -e \"\\nTo activate MFA, run: source $root_dir/miniconda3/bin/activate aligner\""
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {
137
+ "id": "osR7KJCNXJYq"
138
+ },
139
+ "outputs": [],
140
+ "source": [
141
+ "# download and install mfa\n",
142
+ "INSTALL_DIR = \"/tmp/mfa\" # path to install directory\n",
143
+ "!bash ./install_mfa.sh {INSTALL_DIR}"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {
150
+ "colab": {
151
+ "base_uri": "https://localhost:8080/"
152
+ },
153
+ "id": "hxbXwJZlXLPz",
154
+ "outputId": "d3e40ec5-68a7-40ec-d070-137736d7a956"
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "!source {INSTALL_DIR}/miniconda3/bin/activate aligner; \\\n",
159
+ "mfa train --clean -t ./temp -o ./infore_mfa.zip ./infore_16k_denoised lexicon.txt ./infore_textgrid"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "id": "8Z65_BtXagn1"
167
+ },
168
+ "outputs": [],
169
+ "source": [
170
+ "# copy to train directory\n",
171
+ "!mkdir -p train_data\n",
172
+ "!cp ./infore_16k_denoised/*.wav ./train_data\n",
173
+ "!cp ./infore_textgrid/*.TextGrid ./train_data"
174
+ ]
175
+ }
176
+ ],
177
+ "metadata": {
178
+ "colab": {
179
+ "collapsed_sections": [],
180
+ "name": "align-text-audio | InfoRe using MFA v2rc7.ipynb",
181
+ "provenance": []
182
+ },
183
+ "kernelspec": {
184
+ "display_name": "Python 3",
185
+ "name": "python3"
186
+ },
187
+ "language_info": {
188
+ "name": "python"
189
+ }
190
+ },
191
+ "nbformat": 4,
192
+ "nbformat_minor": 0
193
+ }
notebooks/denoise_infore_dataset.ipynb ADDED
@@ -0,0 +1,138 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "qjubHCEzYtG8"
7
+ },
8
+ "source": [
9
+ "### Step 1. Download InfoRE dataset"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "zCBJzJi6BE_o"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "%%capture\n",
21
+ "%%bash\n",
22
+ "mkdir -p /content/data\n",
23
+ "cd /content/data\n",
24
+ "wget https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_16k.zip\n",
25
+ "# unzip -P BroughtToYouByInfoRe 25hours.zip\n",
26
+ "unzip infore_16k.zip"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "metadata": {
32
+ "id": "6C47hb9nYzmB"
33
+ },
34
+ "source": [
35
+ "### Step 2. Normalize audio clip"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {
42
+ "id": "Hp9TK8PbBcQM"
43
+ },
44
+ "outputs": [],
45
+ "source": [
46
+ "%%capture\n",
47
+ "!sudo apt install -y sox\n",
48
+ "!pip install soundfile librosa\n",
49
+ "!pip install onnxruntime==1.11.1"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {
56
+ "colab": {
57
+ "base_uri": "https://localhost:8080/"
58
+ },
59
+ "id": "FW45D8xM9Mcc",
60
+ "outputId": "8d7ea7a9-ea5a-48ca-88fe-37dd4ed55d9b"
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "!mkdir -p /content/infore_16k\n",
65
+ "from pathlib import Path\n",
66
+ "import os\n",
67
+ "from tqdm.cli import tqdm\n",
68
+ "\n",
69
+ "wavs = sorted(Path(\"/content/data/InfoRe\").glob(\"*.wav\"))\n",
70
+ "for path in tqdm(wavs):\n",
71
+ " out = Path(\"/content/infore_16k\") / path.name\n",
72
+ " cmd = f\"sox {path} -c 1 -e signed-integer -b 16 -r 16k --norm=-3 {out}\"\n",
73
+ " os.system(cmd)"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {
79
+ "id": "kooiBrsQY5sQ"
80
+ },
81
+ "source": [
82
+ "### Step 3. Denoise using DNS-Challenge's baseline"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {
89
+ "id": "hsXeNkZ3Xacj"
90
+ },
91
+ "outputs": [],
92
+ "source": [
93
+ "!git clone https://github.com/microsoft/DNS-Challenge\n",
94
+ "%cd DNS-Challenge/NSNet2-baseline/\n",
95
+ "!git checkout -f 8b87a33b2892f147b5c7ad39ea978453730db269\n",
96
+ "!python run_nsnet2.py -i /content/infore_16k/ -o /content/infore_16k_denoised -m ./nsnet2-20ms-baseline.onnx"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "metadata": {
102
+ "id": "T5JtZwKgZI4r"
103
+ },
104
+ "source": [
105
+ "### Step 4. Zip the denoised dataset"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {
112
+ "id": "eeFggV0uYop_"
113
+ },
114
+ "outputs": [],
115
+ "source": [
116
+ "%cd /content\n",
117
+ "!cp /content/data/InfoRe/*.txt ./infore_16k_denoised\n",
118
+ "!cd ./infore_16k_denoised; zip -r ../infore_16k_denoised.zip ."
119
+ ]
120
+ }
121
+ ],
122
+ "metadata": {
123
+ "colab": {
124
+ "collapsed_sections": [],
125
+ "name": "prepare_infore_dataset.ipynb",
126
+ "provenance": []
127
+ },
128
+ "kernelspec": {
129
+ "display_name": "Python 3",
130
+ "name": "python3"
131
+ },
132
+ "language_info": {
133
+ "name": "python"
134
+ }
135
+ },
136
+ "nbformat": 4,
137
+ "nbformat_minor": 0
138
+ }
scripts/download_aligned_infore_dataset.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ A script to download the InfoRE dataset and textgrid files.
3
+ """
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import pooch
8
+ from pooch import Unzip
9
+ from tqdm.cli import tqdm
10
+
11
+
12
+ def download_infore_data():
13
+ """download infore wav files"""
14
+ files = pooch.retrieve(
15
+ url="https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_16k_denoised.zip",
16
+ known_hash="2445527b345fb0b1816ce3c8f09bae419d6bbe251f16d6c74d8dd95ef9fb0737",
17
+ processor=Unzip(),
18
+ progressbar=True,
19
+ )
20
+ data_dir = Path(sorted(files)[0]).parent
21
+ return data_dir
22
+
23
+
24
+ def download_textgrid():
25
+ """download textgrid files"""
26
+ files = pooch.retrieve(
27
+ url="https://huggingface.co/datasets/ntt123/infore/resolve/main/infore_tg.zip",
28
+ known_hash="26e4f53025220097ea95dc266657de8d65104b0a17a6ffba778fc016c8dd36d7",
29
+ processor=Unzip(),
30
+ progressbar=True,
31
+ )
32
+ data_dir = Path(sorted(files)[0]).parent
33
+ return data_dir
34
+
35
+
36
+ DATA_ROOT = Path("./train_data")
37
+ DATA_ROOT.mkdir(parents=True, exist_ok=True)
38
+ wav_dir = download_infore_data()
39
+ tg_dir = download_textgrid()
40
+
41
+ for path in tqdm(tg_dir.glob("*.TextGrid")):
42
+ wav_name = path.with_suffix(".wav").name
43
+ wav_src = wav_dir / wav_name
44
+ shutil.copy(path, DATA_ROOT)
45
+ shutil.copy(wav_src, DATA_ROOT)
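The script above copies each TextGrid together with its matching wav into `./train_data`. A quick follow-up sketch (hypothetical, not part of the commit) to confirm that every TextGrid found its wav after the download:

```python
# Sketch: confirm TextGrid/wav pairing in ./train_data after the download script runs.
from pathlib import Path

train_dir = Path("./train_data")
missing = [
    tg.name for tg in sorted(train_dir.glob("*.TextGrid"))
    if not tg.with_suffix(".wav").exists()
]
print(f"{len(missing)} TextGrid files are missing a matching wav")
```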
scripts/quick_start.sh ADDED
@@ -0,0 +1,12 @@
+ if [ ! -f assets/infore/hifigan/g_01140000 ]; then
+   echo "Downloading models..."
+   mkdir -p assets/infore/{nat,hifigan}
+   wget https://huggingface.co/ntt123/viettts_infore_16k/resolve/main/duration_latest_ckpt.pickle -O assets/infore/nat/duration_latest_ckpt.pickle
+   wget https://huggingface.co/ntt123/viettts_infore_16k/resolve/main/acoustic_latest_ckpt.pickle -O assets/infore/nat/acoustic_latest_ckpt.pickle
+   wget https://huggingface.co/ntt123/viettts_infore_16k/resolve/main/g_01140000 -O assets/infore/hifigan/g_01140000
+   python3 -m vietTTS.hifigan.convert_torch_model_to_haiku --config-file=assets/hifigan/config.json --checkpoint-file=assets/infore/hifigan/g_01140000
+ fi
+
+ echo "Generate audio clip"
+ text=`cat assets/transcript.txt`
+ python3 -m vietTTS.synthesizer --text "$text" --output assets/infore/clip.wav --lexicon-file assets/infore/lexicon.txt --silence-duration 0.2
setup.cfg ADDED
@@ -0,0 +1,14 @@
+ [pep8]
+ max-line-length = 120
+ indent-size = 2
+
+ [pycodestyle]
+ max-line-length = 120
+
+ [yapf]
+ based_on_style = pep8
+ column_limit = 120
+
+ [tool:pytest]
+ testpaths=
+     tests
setup.py ADDED
@@ -0,0 +1,43 @@
+ from setuptools import setup
+
+ __version__ = "0.4.1"
+ url = "https://github.com/ntt123/vietTTS"
+
+ install_requires = [
+     "dm-haiku",
+     "einops",
+     "fire",
+     "gdown",
+     "jax",
+     "jaxlib",
+     "librosa",
+     "optax",
+     "tabulate",
+     "textgrid @ git+https://github.com/kylebgorman/textgrid.git",
+     "tqdm",
+     "matplotlib",
+ ]
+ setup_requires = []
+ tests_require = []
+
+ setup(
+     name="vietTTS",
+     version=__version__,
+     description="A vietnamese text-to-speech library.",
+     author="ntt123",
+     url=url,
+     keywords=[
+         "text-to-speech",
+         "tts",
+         "deep-learning",
+         "dm-haiku",
+         "jax",
+         "vietnamese",
+         "speech-synthesis",
+     ],
+     install_requires=install_requires,
+     setup_requires=setup_requires,
+     tests_require=tests_require,
+     packages=["vietTTS"],
+     python_requires=">=3.7",
+ )
tests/test_nat_acoustic.py ADDED
@@ -0,0 +1,18 @@
+ import haiku
+ import haiku as hk
+ import jax.numpy as jnp
+ import jax.random
+ from vietTTS.nat.config import FLAGS
+ from vietTTS.nat.model import AcousticModel
+
+
+ @hk.testing.transform_and_run
+ def test_duration():
+     net = AcousticModel()
+     token = jnp.zeros((2, 10), dtype=jnp.int32)
+     lengths = jnp.zeros((2,), dtype=jnp.int32)
+     durations = jnp.zeros((2, 10), dtype=jnp.float32)
+     mel = jnp.zeros((2, 20, 160), dtype=jnp.float32)
+     o1, o2 = net(token, mel, lengths, durations)
+     assert o1.shape == (2, 20, 160)
+     assert o2.shape == (2, 20, 160)
tests/test_nat_duration.py ADDED
@@ -0,0 +1,15 @@
+ import haiku
+ import haiku as hk
+ import jax.numpy as jnp
+ import jax.random
+ from vietTTS.nat.config import FLAGS
+ from vietTTS.nat.model import DurationModel
+
+
+ @hk.testing.transform_and_run
+ def test_duration():
+     net = DurationModel()
+     p = jnp.zeros((2, 10), dtype=jnp.int32)
+     l = jnp.zeros((2,), dtype=jnp.int32)
+     o = net(p, l)
+     assert o.shape == (2, 10, 1)
vietTTS.egg-info/PKG-INFO ADDED
@@ -0,0 +1,11 @@
+ Metadata-Version: 1.2
+ Name: vietTTS
+ Version: 0.4.1
+ Summary: A vietnamese text-to-speech library.
+ Home-page: https://github.com/ntt123/vietTTS
+ Author: ntt123
+ License: UNKNOWN
+ Description: UNKNOWN
+ Keywords: text-to-speech,tts,deep-learning,dm-haiku,jax,vietnamese,speech-synthesis
+ Platform: UNKNOWN
+ Requires-Python: >=3.7
vietTTS.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,10 @@
+ README.md
+ setup.cfg
+ setup.py
+ vietTTS/__init__.py
+ vietTTS/synthesizer.py
+ vietTTS.egg-info/PKG-INFO
+ vietTTS.egg-info/SOURCES.txt
+ vietTTS.egg-info/dependency_links.txt
+ vietTTS.egg-info/requires.txt
+ vietTTS.egg-info/top_level.txt
vietTTS.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
vietTTS.egg-info/requires.txt ADDED
@@ -0,0 +1,12 @@
+ dm-haiku
+ einops
+ fire
+ gdown
+ jax
+ jaxlib
+ librosa
+ optax
+ tabulate
+ textgrid@ git+https://github.com/kylebgorman/textgrid.git
+ tqdm
+ matplotlib
vietTTS.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ vietTTS
vietTTS/__init__.py ADDED
File without changes
vietTTS/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes).
 
vietTTS/__pycache__/synthesizer.cpython-39.pyc ADDED
Binary file (1.39 kB).
 
vietTTS/hifigan/__pycache__/config.cpython-39.pyc ADDED
Binary file (437 Bytes).
 
vietTTS/hifigan/__pycache__/mel2wave.cpython-39.pyc ADDED
Binary file (1.51 kB).
 
vietTTS/hifigan/__pycache__/model.cpython-39.pyc ADDED
Binary file (3.8 kB).
 
vietTTS/hifigan/config.py ADDED
@@ -0,0 +1,6 @@
+ from pathlib import Path
+ from typing import NamedTuple
+
+
+ class FLAGS:
+     ckpt_dir = Path("./assets/infore/hifigan")
vietTTS/hifigan/convert_torch_model_to_haiku.py ADDED
@@ -0,0 +1,83 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import pickle
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from .config import FLAGS
10
+ from .torch_model import Generator
11
+
12
+
13
+ class AttrDict(dict):
14
+ def __init__(self, *args, **kwargs):
15
+ super(AttrDict, self).__init__(*args, **kwargs)
16
+ self.__dict__ = self
17
+
18
+
19
+ def load_checkpoint(filepath, device):
20
+ assert os.path.isfile(filepath)
21
+ print("Loading '{}'".format(filepath))
22
+ checkpoint_dict = torch.load(filepath, map_location=device)
23
+ print("Complete.")
24
+ return checkpoint_dict
25
+
26
+
27
+ def convert_to_haiku(a, h, device):
28
+ generator = Generator(h).to(device)
29
+ state_dict_g = load_checkpoint(a.checkpoint_file, device)
30
+ generator.load_state_dict(state_dict_g["generator"])
31
+ generator.eval()
32
+ generator.remove_weight_norm()
33
+ hk_map = {}
34
+ for a, b in generator.state_dict().items():
35
+ print(a, b.shape)
36
+ if a.startswith("conv_pre"):
37
+ a = "generator/~/conv1_d"
38
+ elif a.startswith("conv_post"):
39
+ a = "generator/~/conv1_d_1"
40
+ elif a.startswith("ups."):
41
+ ii = a.split(".")[1]
42
+ a = f"generator/~/ups_{ii}"
43
+ elif a.startswith("resblocks."):
44
+ _, x, y, z, _ = a.split(".")
45
+ ver = h.resblock
46
+ a = f"generator/~/res_block{ver}_{x}/~/{y}_{z}"
47
+ print(a, b.shape)
48
+ if a not in hk_map:
49
+ hk_map[a] = {}
50
+ if len(b.shape) == 1:
51
+ hk_map[a]["b"] = b.numpy()
52
+ else:
53
+ if "ups" in a:
54
+ hk_map[a]["w"] = np.rot90(b.numpy(), k=1, axes=(0, 2))
55
+ elif "conv" in a:
56
+ hk_map[a]["w"] = np.swapaxes(b.numpy(), 0, 2)
57
+ else:
58
+ hk_map[a]["w"] = b.numpy()
59
+
60
+ FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
61
+ with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "wb") as f:
62
+ pickle.dump(hk_map, f)
63
+
64
+
65
+ def main():
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument("--checkpoint-file", required=True)
68
+ parser.add_argument("--config-file", required=True)
69
+ a = parser.parse_args()
70
+
71
+ config_file = a.config_file
72
+ with open(config_file) as f:
73
+ data = f.read()
74
+
75
+ json_config = json.loads(data)
76
+ h = AttrDict(json_config)
77
+
78
+ device = torch.device("cpu")
79
+ convert_to_haiku(a, h, device)
80
+
81
+
82
+ if __name__ == "__main__":
83
+ main()
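The axis shuffling in `convert_to_haiku` above exists because the two frameworks lay out convolution weights differently: PyTorch `Conv1d` stores weights as `(out_channels, in_channels, kernel)`, while Haiku's `Conv1D` expects `(kernel, in_channels, out_channels)`; the `rot90` branch applies the corresponding rearrangement (including a flip along the kernel axis) for the transposed upsampling convolutions. A minimal sketch of the plain-convolution case, using the `conv_pre` shape from the committed config:

```python
# Sketch: map a PyTorch Conv1d weight into the layout Haiku's Conv1D expects.
import numpy as np

torch_weight = np.zeros((512, 80, 7))           # (out_channels, in_channels, kernel), as in conv_pre
haiku_weight = np.swapaxes(torch_weight, 0, 2)  # -> (kernel, in_channels, out_channels)
assert haiku_weight.shape == (7, 80, 512)
```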
vietTTS/hifigan/create_mel.py ADDED
@@ -0,0 +1,241 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ import numpy as np
7
+ from librosa.util import normalize
8
+ from scipy.io.wavfile import read
9
+ from librosa.filters import mel as librosa_mel_fn
10
+
11
+ MAX_WAV_VALUE = 32768.0
12
+
13
+
14
+ def load_wav(full_path):
15
+ sampling_rate, data = read(full_path)
16
+ return data, sampling_rate
17
+
18
+
19
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
20
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
21
+
22
+
23
+ def dynamic_range_decompression(x, C=1):
24
+ return np.exp(x) / C
25
+
26
+
27
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
28
+ return torch.log(torch.clamp(x, min=clip_val) * C)
29
+
30
+
31
+ def dynamic_range_decompression_torch(x, C=1):
32
+ return torch.exp(x) / C
33
+
34
+
35
+ def spectral_normalize_torch(magnitudes):
36
+ output = dynamic_range_compression_torch(magnitudes)
37
+ return output
38
+
39
+
40
+ def spectral_de_normalize_torch(magnitudes):
41
+ output = dynamic_range_decompression_torch(magnitudes)
42
+ return output
43
+
44
+
45
+ mel_basis = {}
46
+ hann_window = {}
47
+
48
+
49
+ def mel_spectrogram(
50
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
51
+ ):
52
+ if torch.min(y) < -1.0:
53
+ print("min value is ", torch.min(y))
54
+ if torch.max(y) > 1.0:
55
+ print("max value is ", torch.max(y))
56
+
57
+ global mel_basis, hann_window
58
+ if fmax not in mel_basis:
59
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
60
+ mel_basis[str(fmax) + "_" + str(y.device)] = (
61
+ torch.from_numpy(mel).float().to(y.device)
62
+ )
63
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
64
+
65
+ y = torch.nn.functional.pad(
66
+ y.unsqueeze(1),
67
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
68
+ mode="reflect",
69
+ )
70
+ y = y.squeeze(1)
71
+
72
+ spec = torch.stft(
73
+ y,
74
+ n_fft,
75
+ hop_length=hop_size,
76
+ win_length=win_size,
77
+ window=hann_window[str(y.device)],
78
+ center=center,
79
+ pad_mode="reflect",
80
+ normalized=False,
81
+ onesided=True,
82
+ )
83
+
84
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
85
+
86
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
87
+ spec = spectral_normalize_torch(spec)
88
+
89
+ return spec
90
+
91
+
92
+ def get_dataset_filelist(a):
93
+ with open(a.input_training_file, "r", encoding="utf-8") as fi:
94
+ training_files = [
95
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
96
+ for x in fi.read().split("\n")
97
+ if len(x) > 0
98
+ ]
99
+
100
+ with open(a.input_validation_file, "r", encoding="utf-8") as fi:
101
+ validation_files = [
102
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
103
+ for x in fi.read().split("\n")
104
+ if len(x) > 0
105
+ ]
106
+ return training_files, validation_files
107
+
108
+
109
+ class MelDataset(torch.utils.data.Dataset):
110
+ def __init__(
111
+ self,
112
+ training_files,
113
+ segment_size,
114
+ n_fft,
115
+ num_mels,
116
+ hop_size,
117
+ win_size,
118
+ sampling_rate,
119
+ fmin,
120
+ fmax,
121
+ split=True,
122
+ shuffle=True,
123
+ n_cache_reuse=1,
124
+ device=None,
125
+ fmax_loss=None,
126
+ fine_tuning=False,
127
+ base_mels_path=None,
128
+ ):
129
+ self.audio_files = training_files
130
+ random.seed(1234)
131
+ if shuffle:
132
+ random.shuffle(self.audio_files)
133
+ self.segment_size = segment_size
134
+ self.sampling_rate = sampling_rate
135
+ self.split = split
136
+ self.n_fft = n_fft
137
+ self.num_mels = num_mels
138
+ self.hop_size = hop_size
139
+ self.win_size = win_size
140
+ self.fmin = fmin
141
+ self.fmax = fmax
142
+ self.fmax_loss = fmax_loss
143
+ self.cached_wav = None
144
+ self.n_cache_reuse = n_cache_reuse
145
+ self._cache_ref_count = 0
146
+ self.device = device
147
+ self.fine_tuning = fine_tuning
148
+ self.base_mels_path = base_mels_path
149
+
150
+ def __getitem__(self, index):
151
+ filename = self.audio_files[index]
152
+ if self._cache_ref_count == 0:
153
+ audio, sampling_rate = load_wav(filename)
154
+ audio = audio / MAX_WAV_VALUE
155
+ if not self.fine_tuning:
156
+ audio = normalize(audio) * 0.95
157
+ self.cached_wav = audio
158
+ if sampling_rate != self.sampling_rate:
159
+ raise ValueError(
160
+ "{} SR doesn't match target {} SR".format(
161
+ sampling_rate, self.sampling_rate
162
+ )
163
+ )
164
+ self._cache_ref_count = self.n_cache_reuse
165
+ else:
166
+ audio = self.cached_wav
167
+ self._cache_ref_count -= 1
168
+
169
+ audio = torch.FloatTensor(audio)
170
+ audio = audio.unsqueeze(0)
171
+
172
+ if not self.fine_tuning:
173
+ if self.split:
174
+ if audio.size(1) >= self.segment_size:
175
+ max_audio_start = audio.size(1) - self.segment_size
176
+ audio_start = random.randint(0, max_audio_start)
177
+ audio = audio[:, audio_start : audio_start + self.segment_size]
178
+ else:
179
+ audio = torch.nn.functional.pad(
180
+ audio, (0, self.segment_size - audio.size(1)), "constant"
181
+ )
182
+
183
+ mel = mel_spectrogram(
184
+ audio,
185
+ self.n_fft,
186
+ self.num_mels,
187
+ self.sampling_rate,
188
+ self.hop_size,
189
+ self.win_size,
190
+ self.fmin,
191
+ self.fmax,
192
+ center=False,
193
+ )
194
+ else:
195
+ mel = np.load(
196
+ os.path.join(
197
+ self.base_mels_path,
198
+ os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
199
+ )
200
+ )
201
+ mel = torch.from_numpy(mel)
202
+
203
+ if len(mel.shape) < 3:
204
+ mel = mel.unsqueeze(0)
205
+
206
+ if self.split:
207
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
208
+
209
+ if audio.size(1) >= self.segment_size:
210
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
211
+ mel = mel[:, :, mel_start : mel_start + frames_per_seg]
212
+ audio = audio[
213
+ :,
214
+ mel_start
215
+ * self.hop_size : (mel_start + frames_per_seg)
216
+ * self.hop_size,
217
+ ]
218
+ else:
219
+ mel = torch.nn.functional.pad(
220
+ mel, (0, frames_per_seg - mel.size(2)), "constant"
221
+ )
222
+ audio = torch.nn.functional.pad(
223
+ audio, (0, self.segment_size - audio.size(1)), "constant"
224
+ )
225
+
226
+ mel_loss = mel_spectrogram(
227
+ audio,
228
+ self.n_fft,
229
+ self.num_mels,
230
+ self.sampling_rate,
231
+ self.hop_size,
232
+ self.win_size,
233
+ self.fmin,
234
+ self.fmax_loss,
235
+ center=False,
236
+ )
237
+
238
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
239
+
240
+ def __len__(self):
241
+ return len(self.audio_files)
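One detail worth spelling out about `mel_spectrogram` above: with `center=False` and the explicit reflect padding of `(n_fft - hop_size) / 2` samples on each side, a segment of `segment_size` samples yields exactly `segment_size / hop_size` frames, which is what `frames_per_seg` in `__getitem__` relies on. A worked check with the values from `assets/hifigan/config.json`:

```python
# Worked framing check for mel_spectrogram with the committed config values.
n_fft, hop_size, segment_size = 1024, 256, 8192
padded = segment_size + 2 * ((n_fft - hop_size) // 2)  # 8960 samples after reflect padding
frames = 1 + (padded - n_fft) // hop_size              # STFT frame count with center=False
assert frames == segment_size // hop_size == 32
```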
vietTTS/hifigan/data_loader.py ADDED
File without changes
vietTTS/hifigan/mel2wave.py ADDED
@@ -0,0 +1,41 @@
1
+ import json
2
+ import os
3
+ import pickle
4
+
5
+ import haiku as hk
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+
10
+ from .config import FLAGS
11
+ from .model import Generator
12
+
13
+
14
+ class AttrDict(dict):
15
+ def __init__(self, *args, **kwargs):
16
+ super(AttrDict, self).__init__(*args, **kwargs)
17
+ self.__dict__ = self
18
+
19
+
20
+ def mel2wave(mel):
21
+ config_file = "assets/hifigan/config.json"
22
+ MAX_WAV_VALUE = 32768.0
23
+ with open(config_file) as f:
24
+ data = f.read()
25
+ json_config = json.loads(data)
26
+ h = AttrDict(json_config)
27
+
28
+ @hk.transform_with_state
29
+ def forward(x):
30
+ net = Generator(h)
31
+ return net(x)
32
+
33
+ rng = next(hk.PRNGSequence(42))
34
+
35
+ with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "rb") as f:
36
+ params = pickle.load(f)
37
+ aux = {}
38
+ wav, aux = forward.apply(params, aux, rng, mel)
39
+ wav = jnp.squeeze(wav)
40
+ audio = jax.device_get(wav)
41
+ return audio
vietTTS/hifigan/model.py ADDED
@@ -0,0 +1,125 @@
1
+ import haiku as hk
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ LRELU_SLOPE = 0.1
6
+
7
+
8
+ def get_padding(kernel_size, dilation=1):
9
+ p = int((kernel_size * dilation - dilation) / 2)
10
+ return ((p, p),)
11
+
12
+
13
+ class ResBlock1(hk.Module):
14
+ def __init__(
15
+ self, h, channels, kernel_size=3, dilation=(1, 3, 5), name="resblock1"
16
+ ):
17
+ super().__init__(name=name)
18
+
19
+ self.h = h
20
+ self.convs1 = [
21
+ hk.Conv1D(
22
+ channels,
23
+ kernel_size,
24
+ 1,
25
+ rate=dilation[i],
26
+ padding=get_padding(kernel_size, dilation[i]),
27
+ name=f"convs1_{i}",
28
+ )
29
+ for i in range(3)
30
+ ]
31
+
32
+ self.convs2 = [
33
+ hk.Conv1D(
34
+ channels,
35
+ kernel_size,
36
+ 1,
37
+ rate=1,
38
+ padding=get_padding(kernel_size, 1),
39
+ name=f"convs2_{i}",
40
+ )
41
+ for i in range(3)
42
+ ]
43
+
44
+ def __call__(self, x):
45
+ for c1, c2 in zip(self.convs1, self.convs2):
46
+ xt = jax.nn.leaky_relu(x, LRELU_SLOPE)
47
+ xt = c1(xt)
48
+ xt = jax.nn.leaky_relu(xt, LRELU_SLOPE)
49
+ xt = c2(xt)
50
+ x = xt + x
51
+ return x
52
+
53
+
54
+ class ResBlock2(hk.Module):
55
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), name="ResBlock2"):
56
+ super().__init__(name=name)
57
+ self.h = h
58
+ self.convs = [
59
+ hk.Conv1D(
60
+ channels,
61
+ kernel_size,
62
+ 1,
63
+ rate=dilation[i],
64
+ padding=get_padding(kernel_size, dilation[i]),
65
+ )
66
+ for i in range(2)
67
+ ]
68
+
69
+ def __call__(self, x):
70
+ for c in self.convs:
71
+ xt = jax.nn.leaky_relu(x, LRELU_SLOPE)
72
+ xt = c(xt)
73
+ x = xt + x
74
+ return x
75
+
76
+
77
+ class Generator(hk.Module):
78
+ def __init__(self, h):
79
+ super().__init__()
80
+ self.h = h
81
+ self.num_kernels = len(h.resblock_kernel_sizes)
82
+ self.num_upsamples = len(h.upsample_rates)
83
+ self.conv_pre = hk.Conv1D(h.upsample_initial_channel, 7, 1, padding=((3, 3),))
84
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
85
+ self.ups = []
86
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
87
+ self.ups.append(
88
+ hk.Conv1DTranspose(
89
+ h.upsample_initial_channel // (2 ** (i + 1)),
90
+ kernel_shape=k,
91
+ stride=u,
92
+ padding="SAME",
93
+ name=f"ups_{i}",
94
+ )
95
+ )
96
+
97
+ self.resblocks = []
98
+
99
+ for i in range(len(self.ups)):
100
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
101
+ for j, (k, d) in enumerate(
102
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
103
+ ):
104
+ self.resblocks.append(
105
+ resblock(h, ch, k, d, name=f"res_block1_{len(self.resblocks)}")
106
+ )
107
+ self.conv_post = hk.Conv1D(1, 7, 1, padding=((3, 3),))
108
+
109
+ def __call__(self, x):
110
+ x = self.conv_pre(x)
111
+ for i in range(self.num_upsamples):
112
+ x = jax.nn.leaky_relu(x, LRELU_SLOPE)
113
+
114
+ x = self.ups[i](x)
115
+ xs = None
116
+ for j in range(self.num_kernels):
117
+ if xs is None:
118
+ xs = self.resblocks[i * self.num_kernels + j](x)
119
+ else:
120
+ xs += self.resblocks[i * self.num_kernels + j](x)
121
+ x = xs / self.num_kernels
122
+ x = jax.nn.leaky_relu(x) # default pytorch value
123
+ x = self.conv_post(x)
124
+ x = jnp.tanh(x)
125
+ return x
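As a rough end-to-end check of the Haiku `Generator` above (a sketch, assuming the committed `assets/hifigan/config.json` and an environment with JAX and Haiku installed): feeding it a batch of mel frames should return 256 waveform samples per frame, matching the upsample-rate product noted earlier.

```python
# Sketch: shape check for the Haiku Generator defined in vietTTS/hifigan/model.py.
import json

import haiku as hk
import jax
import jax.numpy as jnp

from vietTTS.hifigan.model import Generator


class AttrDict(dict):
    """Same helper as in vietTTS/hifigan/mel2wave.py: a dict with attribute access."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


with open("assets/hifigan/config.json") as f:
    h = AttrDict(json.load(f))


@hk.transform_with_state
def forward(mel):
    return Generator(h)(mel)


mel = jnp.zeros((1, 50, h.num_mels))  # (batch, frames, mel bins); hk.Conv1D is channels-last
params, state = forward.init(jax.random.PRNGKey(0), mel)
wav, _ = forward.apply(params, state, None, mel)
assert wav.shape == (1, 50 * 256, 1)  # one mel frame -> 256 waveform samples
```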
vietTTS/hifigan/torch_model.py ADDED
@@ -0,0 +1,414 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
5
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
6
+
7
+ # from utils import init_weights, get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def init_weights(m, mean=0.0, std=0.01):
17
+ classname = m.__class__.__name__
18
+ if classname.find("Conv") != -1:
19
+ m.weight.data.normal_(mean, std)
20
+
21
+
22
+ class ResBlock1(torch.nn.Module):
23
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
24
+ super(ResBlock1, self).__init__()
25
+ self.h = h
26
+ self.convs1 = nn.ModuleList(
27
+ [
28
+ weight_norm(
29
+ Conv1d(
30
+ channels,
31
+ channels,
32
+ kernel_size,
33
+ 1,
34
+ dilation=dilation[0],
35
+ padding=get_padding(kernel_size, dilation[0]),
36
+ )
37
+ ),
38
+ weight_norm(
39
+ Conv1d(
40
+ channels,
41
+ channels,
42
+ kernel_size,
43
+ 1,
44
+ dilation=dilation[1],
45
+ padding=get_padding(kernel_size, dilation[1]),
46
+ )
47
+ ),
48
+ weight_norm(
49
+ Conv1d(
50
+ channels,
51
+ channels,
52
+ kernel_size,
53
+ 1,
54
+ dilation=dilation[2],
55
+ padding=get_padding(kernel_size, dilation[2]),
56
+ )
57
+ ),
58
+ ]
59
+ )
60
+ self.convs1.apply(init_weights)
61
+
62
+ self.convs2 = nn.ModuleList(
63
+ [
64
+ weight_norm(
65
+ Conv1d(
66
+ channels,
67
+ channels,
68
+ kernel_size,
69
+ 1,
70
+ dilation=1,
71
+ padding=get_padding(kernel_size, 1),
72
+ )
73
+ ),
74
+ weight_norm(
75
+ Conv1d(
76
+ channels,
77
+ channels,
78
+ kernel_size,
79
+ 1,
80
+ dilation=1,
81
+ padding=get_padding(kernel_size, 1),
82
+ )
83
+ ),
84
+ weight_norm(
85
+ Conv1d(
86
+ channels,
87
+ channels,
88
+ kernel_size,
89
+ 1,
90
+ dilation=1,
91
+ padding=get_padding(kernel_size, 1),
92
+ )
93
+ ),
94
+ ]
95
+ )
96
+ self.convs2.apply(init_weights)
97
+
98
+ def forward(self, x):
99
+ for c1, c2 in zip(self.convs1, self.convs2):
100
+ xt = F.leaky_relu(x, LRELU_SLOPE)
101
+ xt = c1(xt)
102
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
103
+ xt = c2(xt)
104
+ x = xt + x
105
+ return x
106
+
107
+ def remove_weight_norm(self):
108
+ for l in self.convs1:
109
+ remove_weight_norm(l)
110
+ for l in self.convs2:
111
+ remove_weight_norm(l)
112
+
113
+
114
+ class ResBlock2(torch.nn.Module):
115
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
116
+ super(ResBlock2, self).__init__()
117
+ self.h = h
118
+ self.convs = nn.ModuleList(
119
+ [
120
+ weight_norm(
121
+ Conv1d(
122
+ channels,
123
+ channels,
124
+ kernel_size,
125
+ 1,
126
+ dilation=dilation[0],
127
+ padding=get_padding(kernel_size, dilation[0]),
128
+ )
129
+ ),
130
+ weight_norm(
131
+ Conv1d(
132
+ channels,
133
+ channels,
134
+ kernel_size,
135
+ 1,
136
+ dilation=dilation[1],
137
+ padding=get_padding(kernel_size, dilation[1]),
138
+ )
139
+ ),
140
+ ]
141
+ )
142
+ self.convs.apply(init_weights)
143
+
144
+ def forward(self, x):
145
+ for c in self.convs:
146
+ xt = F.leaky_relu(x, LRELU_SLOPE)
147
+ xt = c(xt)
148
+ x = xt + x
149
+ return x
150
+
151
+ def remove_weight_norm(self):
152
+ for l in self.convs:
153
+ remove_weight_norm(l)
154
+
155
+
156
+ class Generator(torch.nn.Module):
157
+ def __init__(self, h):
158
+ super(Generator, self).__init__()
159
+ self.h = h
160
+ self.num_kernels = len(h.resblock_kernel_sizes)
161
+ self.num_upsamples = len(h.upsample_rates)
162
+ self.conv_pre = weight_norm(
163
+ Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
164
+ )
165
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
166
+
167
+ self.ups = nn.ModuleList()
168
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
169
+ self.ups.append(
170
+ weight_norm(
171
+ ConvTranspose1d(
172
+ h.upsample_initial_channel // (2**i),
173
+ h.upsample_initial_channel // (2 ** (i + 1)),
174
+ k,
175
+ u,
176
+ padding=(k - u) // 2,
177
+ )
178
+ )
179
+ )
180
+
181
+ self.resblocks = nn.ModuleList()
182
+ for i in range(len(self.ups)):
183
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
184
+ for j, (k, d) in enumerate(
185
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
186
+ ):
187
+ self.resblocks.append(resblock(h, ch, k, d))
188
+
189
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
190
+ self.ups.apply(init_weights)
191
+ self.conv_post.apply(init_weights)
192
+
193
+ def forward(self, x):
194
+ x = self.conv_pre(x)
195
+ for i in range(self.num_upsamples):
196
+ x = F.leaky_relu(x, LRELU_SLOPE)
197
+ x = self.ups[i](x)
198
+ xs = None
199
+ for j in range(self.num_kernels):
200
+ if xs is None:
201
+ xs = self.resblocks[i * self.num_kernels + j](x)
202
+ else:
203
+ xs += self.resblocks[i * self.num_kernels + j](x)
204
+ x = xs / self.num_kernels
205
+ x = F.leaky_relu(x)
206
+ x = self.conv_post(x)
207
+ x = torch.tanh(x)
208
+
209
+ return x
210
+
211
+ def remove_weight_norm(self):
212
+ print("Removing weight norm...")
213
+ for l in self.ups:
214
+ remove_weight_norm(l)
215
+ for l in self.resblocks:
216
+ l.remove_weight_norm()
217
+ remove_weight_norm(self.conv_pre)
218
+ remove_weight_norm(self.conv_post)
219
+
220
+
221
+ class DiscriminatorP(torch.nn.Module):
222
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
223
+ super(DiscriminatorP, self).__init__()
224
+ self.period = period
225
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
226
+ self.convs = nn.ModuleList(
227
+ [
228
+ norm_f(
229
+ Conv2d(
230
+ 1,
231
+ 32,
232
+ (kernel_size, 1),
233
+ (stride, 1),
234
+ padding=(get_padding(5, 1), 0),
235
+ )
236
+ ),
237
+ norm_f(
238
+ Conv2d(
239
+ 32,
240
+ 128,
241
+ (kernel_size, 1),
242
+ (stride, 1),
243
+ padding=(get_padding(5, 1), 0),
244
+ )
245
+ ),
246
+ norm_f(
247
+ Conv2d(
248
+ 128,
249
+ 512,
250
+ (kernel_size, 1),
251
+ (stride, 1),
252
+ padding=(get_padding(5, 1), 0),
253
+ )
254
+ ),
255
+ norm_f(
256
+ Conv2d(
257
+ 512,
258
+ 1024,
259
+ (kernel_size, 1),
260
+ (stride, 1),
261
+ padding=(get_padding(5, 1), 0),
262
+ )
263
+ ),
264
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
265
+ ]
266
+ )
267
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
268
+
269
+ def forward(self, x):
270
+ fmap = []
271
+
272
+ # 1d to 2d
273
+ b, c, t = x.shape
274
+ if t % self.period != 0: # pad first
275
+ n_pad = self.period - (t % self.period)
276
+ x = F.pad(x, (0, n_pad), "reflect")
277
+ t = t + n_pad
278
+ x = x.view(b, c, t // self.period, self.period)
279
+
280
+ for l in self.convs:
281
+ x = l(x)
282
+ x = F.leaky_relu(x, LRELU_SLOPE)
283
+ fmap.append(x)
284
+ x = self.conv_post(x)
285
+ fmap.append(x)
286
+ x = torch.flatten(x, 1, -1)
287
+
288
+ return x, fmap
289
+
290
+
291
+ class MultiPeriodDiscriminator(torch.nn.Module):
292
+ def __init__(self):
293
+ super(MultiPeriodDiscriminator, self).__init__()
294
+ self.discriminators = nn.ModuleList(
295
+ [
296
+ DiscriminatorP(2),
297
+ DiscriminatorP(3),
298
+ DiscriminatorP(5),
299
+ DiscriminatorP(7),
300
+ DiscriminatorP(11),
301
+ ]
302
+ )
303
+
304
+ def forward(self, y, y_hat):
305
+ y_d_rs = []
306
+ y_d_gs = []
307
+ fmap_rs = []
308
+ fmap_gs = []
309
+ for i, d in enumerate(self.discriminators):
310
+ y_d_r, fmap_r = d(y)
311
+ y_d_g, fmap_g = d(y_hat)
312
+ y_d_rs.append(y_d_r)
313
+ fmap_rs.append(fmap_r)
314
+ y_d_gs.append(y_d_g)
315
+ fmap_gs.append(fmap_g)
316
+
317
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
318
+
319
+
320
+ class DiscriminatorS(torch.nn.Module):
321
+ def __init__(self, use_spectral_norm=False):
322
+ super(DiscriminatorS, self).__init__()
323
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
324
+ self.convs = nn.ModuleList(
325
+ [
326
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
327
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
328
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
329
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
330
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
331
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
332
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
333
+ ]
334
+ )
335
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
336
+
337
+ def forward(self, x):
338
+ fmap = []
339
+ for l in self.convs:
340
+ x = l(x)
341
+ x = F.leaky_relu(x, LRELU_SLOPE)
342
+ fmap.append(x)
343
+ x = self.conv_post(x)
344
+ fmap.append(x)
345
+ x = torch.flatten(x, 1, -1)
346
+
347
+ return x, fmap
348
+
349
+
350
+ class MultiScaleDiscriminator(torch.nn.Module):
351
+ def __init__(self):
352
+ super(MultiScaleDiscriminator, self).__init__()
353
+ self.discriminators = nn.ModuleList(
354
+ [
355
+ DiscriminatorS(use_spectral_norm=True),
356
+ DiscriminatorS(),
357
+ DiscriminatorS(),
358
+ ]
359
+ )
360
+ self.meanpools = nn.ModuleList(
361
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
362
+ )
363
+
364
+ def forward(self, y, y_hat):
365
+ y_d_rs = []
366
+ y_d_gs = []
367
+ fmap_rs = []
368
+ fmap_gs = []
369
+ for i, d in enumerate(self.discriminators):
370
+ if i != 0:
371
+ y = self.meanpools[i - 1](y)
372
+ y_hat = self.meanpools[i - 1](y_hat)
373
+ y_d_r, fmap_r = d(y)
374
+ y_d_g, fmap_g = d(y_hat)
375
+ y_d_rs.append(y_d_r)
376
+ fmap_rs.append(fmap_r)
377
+ y_d_gs.append(y_d_g)
378
+ fmap_gs.append(fmap_g)
379
+
380
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
381
+
382
+
383
+ def feature_loss(fmap_r, fmap_g):
384
+ loss = 0
385
+ for dr, dg in zip(fmap_r, fmap_g):
386
+ for rl, gl in zip(dr, dg):
387
+ loss += torch.mean(torch.abs(rl - gl))
388
+
389
+ return loss * 2
390
+
391
+
392
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
393
+ loss = 0
394
+ r_losses = []
395
+ g_losses = []
396
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
397
+ r_loss = torch.mean((1 - dr) ** 2)
398
+ g_loss = torch.mean(dg**2)
399
+ loss += r_loss + g_loss
400
+ r_losses.append(r_loss.item())
401
+ g_losses.append(g_loss.item())
402
+
403
+ return loss, r_losses, g_losses
404
+
405
+
406
+ def generator_loss(disc_outputs):
407
+ loss = 0
408
+ gen_losses = []
409
+ for dg in disc_outputs:
410
+ l = torch.mean((1 - dg) ** 2)
411
+ gen_losses.append(l)
412
+ loss += l
413
+
414
+ return loss, gen_losses
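For reference, the three helpers that close `torch_model.py` implement the HiFi-GAN training objectives in least-squares (LSGAN) form: `discriminator_loss` sums `mean((1 - D(y))^2) + mean(D(y_hat)^2)` over all sub-discriminators, `generator_loss` sums `mean((1 - D(y_hat))^2)`, and `feature_loss` is twice the accumulated L1 distance between real and generated feature maps (the feature-matching term).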
vietTTS/hifigan/trainer.py ADDED
File without changes
vietTTS/nat/__init__.py ADDED
File without changes
vietTTS/nat/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (158 Bytes).
 
vietTTS/nat/__pycache__/config.cpython-39.pyc ADDED
Binary file (2.43 kB).
 
vietTTS/nat/__pycache__/data_loader.cpython-39.pyc ADDED
Binary file (4.25 kB).
 
vietTTS/nat/__pycache__/model.cpython-39.pyc ADDED
Binary file (6.93 kB).
 
vietTTS/nat/__pycache__/text2mel.cpython-39.pyc ADDED
Binary file (4.01 kB).
 
vietTTS/nat/acoustic_tpu_trainer.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import pickle
3
+ from functools import partial
4
+ from typing import Deque
5
+
6
+ import fire
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import jax.tools.colab_tpu
10
+ import matplotlib.pyplot as plt
11
+ import optax
12
+ from tqdm.auto import tqdm
13
+
14
+ from .acoustic_trainer import initial_state, loss_vag, val_loss_fn
15
+ from .config import FLAGS
16
+ from .data_loader import load_textgrid_wav
17
+ from .dsp import MelFilter
18
+ from .utils import print_flags
19
+
20
+
21
+ def setup_colab_tpu():
22
+ jax.tools.colab_tpu.setup_tpu()
23
+
24
+
25
+ def train(
26
+ batch_size: int = 32,
27
+ steps_per_update: int = 10,
28
+ learning_rate: float = 1024e-6,
29
+ ):
30
+ """Train acoustic model on multiple cores (TPU)."""
31
+ lr_schedule = optax.exponential_decay(learning_rate, 50_000, 0.5, staircase=True)
32
+
33
+ optimizer = optax.chain(
34
+ optax.clip_by_global_norm(1.0),
35
+ optax.adamw(lr_schedule, weight_decay=FLAGS.weight_decay),
36
+ )
37
+
38
+ def update_step(prev_state, inputs):
39
+ params, aux, rng, optim_state = prev_state
40
+ rng, new_rng = jax.random.split(rng)
41
+ (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
42
+ grads = jax.lax.pmean(grads, axis_name="i")
43
+ updates, new_optim_state = optimizer.update(grads, optim_state, params)
44
+ new_params = optax.apply_updates(params, updates)
45
+ next_state = (new_params, new_aux, new_rng, new_optim_state)
46
+ return next_state, loss
47
+
48
+ @partial(jax.pmap, axis_name="i")
49
+ def update(params, aux, rng, optim_state, inputs):
50
+ states, losses = jax.lax.scan(
51
+ update_step, (params, aux, rng, optim_state), inputs
52
+ )
53
+ return states, jnp.mean(losses)
54
+
55
+ print(jax.devices())
56
+ num_devices = jax.device_count()
57
+ train_data_iter = load_textgrid_wav(
58
+ FLAGS.data_dir,
59
+ FLAGS.max_phoneme_seq_len,
60
+ batch_size * num_devices * steps_per_update,
61
+ FLAGS.max_wave_len,
62
+ "train",
63
+ )
64
+ val_data_iter = load_textgrid_wav(
65
+ FLAGS.data_dir,
66
+ FLAGS.max_phoneme_seq_len,
67
+ batch_size,
68
+ FLAGS.max_wave_len,
69
+ "val",
70
+ )
71
+ melfilter = MelFilter(
72
+ FLAGS.sample_rate,
73
+ FLAGS.n_fft,
74
+ FLAGS.mel_dim,
75
+ FLAGS.fmin,
76
+ FLAGS.fmax,
77
+ )
78
+ batch = next(train_data_iter)
79
+ batch = jax.tree_map(lambda x: x[:1], batch)
80
+ batch = batch._replace(mels=melfilter(batch.wavs.astype(jnp.float32) / (2**15)))
81
+ params, aux, rng, optim_state = initial_state(optimizer, batch)
82
+ losses = Deque(maxlen=1000)
83
+ val_losses = Deque(maxlen=100)
84
+
85
+ last_step = -steps_per_update
86
+
87
+ # loading latest checkpoint
88
+ ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
89
+ if ckpt_fn.exists():
90
+ print("Resuming from latest checkpoint at", ckpt_fn)
91
+ with open(ckpt_fn, "rb") as f:
92
+ dic = pickle.load(f)
93
+ last_step, params, aux, rng, optim_state = (
94
+ dic["step"],
95
+ dic["params"],
96
+ dic["aux"],
97
+ dic["rng"],
98
+ dic["optim_state"],
99
+ )
100
+
101
+ tr = tqdm(
102
+ range(
103
+ last_step + steps_per_update, FLAGS.num_training_steps + 1, steps_per_update
104
+ ),
105
+ desc="training",
106
+ total=FLAGS.num_training_steps // steps_per_update + 1,
107
+ initial=last_step // steps_per_update + 1,
108
+ )
109
+
110
+ params, aux, rng, optim_state = jax.device_put_replicated(
111
+ (params, aux, rng, optim_state), jax.devices()
112
+ )
113
+
114
+ def batch_reshape(batch):
115
+ return jax.tree_map(
116
+ lambda x: jnp.reshape(x, (num_devices, steps_per_update, -1) + x.shape[1:]),
117
+ batch,
118
+ )
119
+
120
+ for step in tr:
121
+ batch = next(train_data_iter)
122
+ batch = batch_reshape(batch)
123
+ (params, aux, rng, optim_state), loss = update(
124
+ params, aux, rng, optim_state, batch
125
+ )
126
+ losses.append(loss)
127
+
128
+ if step % 10 == 0:
129
+ val_batch = next(val_data_iter)
130
+ val_loss, val_aux, predicted_mel, gt_mel = val_loss_fn(
131
+ *jax.tree_map(lambda x: x[0], (params, aux, rng)), val_batch
132
+ )
133
+ val_losses.append(val_loss)
134
+ attn = jax.device_get(val_aux["acoustic_model"]["attn"])
135
+ predicted_mel = jax.device_get(predicted_mel[0])
136
+ gt_mel = jax.device_get(gt_mel[0])
137
+
138
+ if step % 1000 == 0:
139
+ loss = jnp.mean(sum(losses)).item() / len(losses)
140
+ val_loss = sum(val_losses).item() / len(val_losses)
141
+ tr.write(f"step {step} train loss {loss:.3f} val loss {val_loss:.3f}")
142
+
143
+ # saving predicted mels
144
+ plt.figure(figsize=(10, 10))
145
+ plt.subplot(3, 1, 1)
146
+ plt.imshow(predicted_mel.T, origin="lower", aspect="auto")
147
+ plt.subplot(3, 1, 2)
148
+ plt.imshow(gt_mel.T, origin="lower", aspect="auto")
149
+ plt.subplot(3, 1, 3)
150
+ plt.imshow(attn.T, origin="lower", aspect="auto")
151
+ plt.tight_layout()
152
+ plt.savefig(FLAGS.ckpt_dir / f"mel_{step:06d}.png")
153
+ plt.close()
154
+
155
+ # saving checkpoint
156
+ with open(ckpt_fn, "wb") as f:
157
+ params_, aux_, rng_, optim_state_ = jax.tree_map(
158
+ lambda x: x[0], (params, aux, rng, optim_state)
159
+ )
160
+ pickle.dump(
161
+ {
162
+ "step": step,
163
+ "params": params_,
164
+ "aux": aux_,
165
+ "rng": rng_,
166
+ "optim_state": optim_state_,
167
+ },
168
+ f,
169
+ )
170
+
171
+
172
+ if __name__ == "__main__":
173
+ # we don't use these flags.
174
+ del FLAGS.batch_size
175
+ del FLAGS.learning_rate
176
+ del FLAGS.duration_learning_rate
177
+ del FLAGS.duration_lstm_dim
178
+ del FLAGS.duration_embed_dropout_rate
179
+
180
+ print_flags(FLAGS.__dict__)
181
+
182
+ if "COLAB_TPU_ADDR" in os.environ:
183
+ setup_colab_tpu()
184
+
185
+ if not FLAGS.ckpt_dir.exists():
186
+ print("Create checkpoint dir at", FLAGS.ckpt_dir)
187
+ FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
188
+
189
+ fire.Fire(train)
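
Note on the pattern above: gradients are averaged across TPU cores with jax.lax.pmean inside jax.pmap, and steps_per_update optimizer steps are folded into a single compiled call with jax.lax.scan. The following is a minimal, self-contained sketch of that pattern (not part of this commit), with a toy quadratic loss standing in for the acoustic model; it should also run unchanged on a single CPU/GPU device. The trainer itself is exposed through fire, so it should be launchable as, e.g., python -m vietTTS.nat.acoustic_tpu_trainer --batch_size 32 --steps_per_update 10.

from functools import partial

import jax
import jax.numpy as jnp
import optax

optimizer = optax.sgd(0.1)

def loss_fn(params, x):
    # toy loss: fit w so that w * x == 1
    return jnp.mean(jnp.square(params["w"] * x - 1.0))

def update_step(state, x):
    params, opt_state = state
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    grads = jax.lax.pmean(grads, axis_name="i")  # all-reduce gradients across devices
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), loss

@partial(jax.pmap, axis_name="i")
def update(params, opt_state, xs):
    # xs: [steps_per_update, batch_per_device]; lax.scan runs that many optimizer
    # steps inside one compiled call, mirroring the trainer's update function.
    (params, opt_state), losses = jax.lax.scan(update_step, (params, opt_state), xs)
    return params, opt_state, jnp.mean(losses)

devices = jax.local_devices()
init = {"w": jnp.ones(())}
params = jax.device_put_replicated(init, devices)
opt_state = jax.device_put_replicated(optimizer.init(init), devices)
xs = jnp.ones((len(devices), 10, 8))  # [devices, steps_per_update, batch_per_device]
params, opt_state, loss = update(params, opt_state, xs)
print(loss)  # one mean loss per device
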
vietTTS/nat/acoustic_trainer.py ADDED
@@ -0,0 +1,181 @@
1
+ import pickle
2
+ from functools import partial
3
+ from typing import Deque
4
+
5
+ import haiku as hk
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import matplotlib.pyplot as plt
9
+ import optax
10
+ from tqdm.auto import tqdm
11
+ from vietTTS.nat.config import AcousticInput
12
+
13
+ from .config import FLAGS, AcousticInput
14
+ from .data_loader import load_textgrid_wav
15
+ from .dsp import MelFilter
16
+ from .model import AcousticModel
17
+ from .utils import print_flags
18
+
19
+
20
+ @hk.transform_with_state
21
+ def net(x):
22
+ return AcousticModel(is_training=True)(x)
23
+
24
+
25
+ @hk.transform_with_state
26
+ def val_net(x):
27
+ return AcousticModel(is_training=False)(x)
28
+
29
+
30
+ def loss_fn(params, aux, rng, inputs: AcousticInput, is_training=True):
31
+ """Compute loss"""
32
+ melfilter = MelFilter(
33
+ FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
34
+ )
35
+ wavs = inputs.wavs.astype(jnp.float32) / (2**15)
36
+ mels = melfilter(wavs)
37
+ B, L, D = mels.shape
38
+ go_frame = jnp.zeros((B, 1, D), dtype=jnp.float32)
39
+ inp_mels = jnp.concatenate((go_frame, mels[:, :-1, :]), axis=1)
40
+ n_frames = inputs.durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
41
+ inputs = inputs._replace(mels=inp_mels, durations=n_frames)
42
+ model = net if is_training else val_net
43
+ (mel1_hat, mel2_hat), new_aux = model.apply(params, aux, rng, inputs)
44
+ loss1 = (jnp.square(mel1_hat - mels) + jnp.square(mel2_hat - mels)) / 2
45
+ loss2 = (jnp.abs(mel1_hat - mels) + jnp.abs(mel2_hat - mels)) / 2
46
+ loss = jnp.mean((loss1 + loss2) / 2, axis=-1)
47
+ num_frames = (inputs.wav_lengths // (FLAGS.n_fft // 4))[:, None]
48
+ mask = jnp.arange(0, L)[None, :] < num_frames
49
+ loss = jnp.sum(loss * mask) / jnp.sum(mask)
50
+ return (loss, new_aux) if is_training else (loss, new_aux, mel2_hat, mels)
51
+
52
+
53
+ train_loss_fn = partial(loss_fn, is_training=True)
54
+ val_loss_fn = jax.jit(partial(loss_fn, is_training=False))
55
+
56
+ loss_vag = jax.value_and_grad(train_loss_fn, has_aux=True)
57
+
58
+
59
+ def initial_state(optimizer, batch):
60
+ rng = jax.random.PRNGKey(42)
61
+ params, aux = hk.transform_with_state(lambda x: AcousticModel(True)(x)).init(
62
+ rng, batch
63
+ )
64
+ optim_state = optimizer.init(params)
65
+ return params, aux, rng, optim_state
66
+
67
+
68
+ def train():
69
+
70
+ optimizer = optax.chain(
71
+ optax.clip_by_global_norm(1.0),
72
+ optax.adamw(FLAGS.learning_rate, weight_decay=FLAGS.weight_decay),
73
+ )
74
+
75
+ @jax.jit
76
+ def update(params, aux, rng, optim_state, inputs):
77
+ rng, new_rng = jax.random.split(rng)
78
+ (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
79
+ updates, new_optim_state = optimizer.update(grads, optim_state, params)
80
+ new_params = optax.apply_updates(params, updates)
81
+ return loss, (new_params, new_aux, new_rng, new_optim_state)
82
+
83
+ train_data_iter = load_textgrid_wav(
84
+ FLAGS.data_dir,
85
+ FLAGS.max_phoneme_seq_len,
86
+ FLAGS.batch_size,
87
+ FLAGS.max_wave_len,
88
+ "train",
89
+ )
90
+ val_data_iter = load_textgrid_wav(
91
+ FLAGS.data_dir,
92
+ FLAGS.max_phoneme_seq_len,
93
+ FLAGS.batch_size,
94
+ FLAGS.max_wave_len,
95
+ "val",
96
+ )
97
+ melfilter = MelFilter(
98
+ FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
99
+ )
100
+ batch = next(train_data_iter)
101
+ batch = batch._replace(mels=melfilter(batch.wavs.astype(jnp.float32) / (2**15)))
102
+ params, aux, rng, optim_state = initial_state(optimizer, batch)
103
+ losses = Deque(maxlen=1000)
104
+ val_losses = Deque(maxlen=100)
105
+
106
+ last_step = -1
107
+
108
+ # loading latest checkpoint
109
+ ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
110
+ if ckpt_fn.exists():
111
+ print("Resuming from latest checkpoint at", ckpt_fn)
112
+ with open(ckpt_fn, "rb") as f:
113
+ dic = pickle.load(f)
114
+ last_step, params, aux, rng, optim_state = (
115
+ dic["step"],
116
+ dic["params"],
117
+ dic["aux"],
118
+ dic["rng"],
119
+ dic["optim_state"],
120
+ )
121
+
122
+ tr = tqdm(
123
+ range(last_step + 1, FLAGS.num_training_steps + 1),
124
+ desc="training",
125
+ total=FLAGS.num_training_steps + 1,
126
+ initial=last_step + 1,
127
+ )
128
+ for step in tr:
129
+ batch = next(train_data_iter)
130
+ loss, (params, aux, rng, optim_state) = update(
131
+ params, aux, rng, optim_state, batch
132
+ )
133
+ losses.append(loss)
134
+
135
+ if step % 10 == 0:
136
+ val_batch = next(val_data_iter)
137
+ val_loss, val_aux, predicted_mel, gt_mel = val_loss_fn(
138
+ params, aux, rng, val_batch
139
+ )
140
+ val_losses.append(val_loss)
141
+ attn = jax.device_get(val_aux["acoustic_model"]["attn"])
142
+ predicted_mel = jax.device_get(predicted_mel[0])
143
+ gt_mel = jax.device_get(gt_mel[0])
144
+
145
+ if step % 1000 == 0:
146
+ loss = sum(losses).item() / len(losses)
147
+ val_loss = sum(val_losses).item() / len(val_losses)
148
+ tr.write(f"step {step} train loss {loss:.3f} val loss {val_loss:.3f}")
149
+
150
+ # saving predicted mels
151
+ plt.figure(figsize=(10, 10))
152
+ plt.subplot(3, 1, 1)
153
+ plt.imshow(predicted_mel.T, origin="lower", aspect="auto")
154
+ plt.subplot(3, 1, 2)
155
+ plt.imshow(gt_mel.T, origin="lower", aspect="auto")
156
+ plt.subplot(3, 1, 3)
157
+ plt.imshow(attn.T, origin="lower", aspect="auto")
158
+ plt.tight_layout()
159
+ plt.savefig(FLAGS.ckpt_dir / f"mel_{step:06d}.png")
160
+ plt.close()
161
+
162
+ # saving checkpoint
163
+ with open(ckpt_fn, "wb") as f:
164
+ pickle.dump(
165
+ {
166
+ "step": step,
167
+ "params": params,
168
+ "aux": aux,
169
+ "rng": rng,
170
+ "optim_state": optim_state,
171
+ },
172
+ f,
173
+ )
174
+
175
+
176
+ if __name__ == "__main__":
177
+ print_flags(FLAGS.__dict__)
178
+ if not FLAGS.ckpt_dir.exists():
179
+ print("Create checkpoint dir at", FLAGS.ckpt_dir)
180
+ FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
181
+ train()
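
The single-device trainer shares loss_fn with the TPU trainer: the decoder is teacher-forced with a zero "go" frame followed by the ground-truth mels shifted right by one frame, and frames beyond each utterance's true length are masked out of the combined L1/L2 loss. A small illustrative sketch of just that input and mask construction (not part of this commit), using random arrays in place of real spectrograms:

import numpy as np
import jax.numpy as jnp

B, L, D = 2, 6, 4                      # batch, frames, mel bins (toy sizes)
mels = jnp.asarray(np.random.randn(B, L, D).astype(np.float32))

# the decoder is fed a zero "go" frame followed by ground-truth mels shifted right by one
go_frame = jnp.zeros((B, 1, D), dtype=jnp.float32)
inp_mels = jnp.concatenate((go_frame, mels[:, :-1, :]), axis=1)
assert inp_mels.shape == mels.shape

# frames past each utterance's true length are excluded from the loss
num_frames = jnp.array([[6], [3]])     # per-utterance frame counts, shape [B, 1]
mask = jnp.arange(0, L)[None, :] < num_frames
per_frame_loss = jnp.mean(jnp.abs(inp_mels - mels), axis=-1)
loss = jnp.sum(per_frame_loss * mask) / jnp.sum(mask)
print(loss)
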
vietTTS/nat/config.py ADDED
@@ -0,0 +1,74 @@
1
+ from argparse import Namespace
2
+ from pathlib import Path
3
+ from typing import NamedTuple
4
+
5
+ from jax.numpy import ndarray
6
+
7
+
8
+ class FLAGS(Namespace):
9
+ """Configurations"""
10
+
11
+ duration_lstm_dim = 256
12
+ vocab_size = 256
13
+ duration_embed_dropout_rate = 0.5
14
+ num_training_steps = 200_000
15
+ postnet_dim = 512
16
+ acoustic_decoder_dim = 512
17
+ acoustic_encoder_dim = 256
18
+
19
+ # dataset
20
+ max_phoneme_seq_len = 256 * 1
21
+ assert max_phoneme_seq_len % 256 == 0 # prevent compilation error on Colab T4 GPU
22
+ max_wave_len = 1024 * 64 * 3
23
+
24
+ # Montreal Forced Aligner
25
+ special_phonemes = ["sil", "sp", "spn", " "] # [sil], [sp] [spn] [word end]
26
+ sil_index = special_phonemes.index("sil")
27
+ sp_index = sil_index # no use of "sp"
28
+ word_end_index = special_phonemes.index(" ")
29
+ _normal_phonemes = (
30
+ []
31
+ + ["a", "b", "c", "d", "e", "g", "h", "i", "k", "l"]
32
+ + ["m", "n", "o", "p", "q", "r", "s", "t", "u", "v"]
33
+ + ["x", "y", "à", "á", "â", "ã", "è", "é", "ê", "ì"]
34
+ + ["í", "ò", "ó", "ô", "õ", "ù", "ú", "ý", "ă", "đ"]
35
+ + ["ĩ", "ũ", "ơ", "ư", "ạ", "ả", "ấ", "ầ", "ẩ", "ẫ"]
36
+ + ["ậ", "ắ", "ằ", "ẳ", "ẵ", "ặ", "ẹ", "ẻ", "ẽ", "ế"]
37
+ + ["ề", "ể", "ễ", "ệ", "ỉ", "ị", "ọ", "ỏ", "ố", "ồ"]
38
+ + ["ổ", "ỗ", "ộ", "ớ", "ờ", "ở", "ỡ", "ợ", "ụ", "ủ"]
39
+ + ["ứ", "ừ", "ử", "ữ", "ự", "ỳ", "ỵ", "ỷ", "ỹ"]
40
+ )
41
+
42
+ # dsp
43
+ mel_dim = 80
44
+ n_fft = 1024
45
+ sample_rate = 16000
46
+ fmin = 0.0
47
+ fmax = 8000
48
+
49
+ # training
50
+ batch_size = 64
51
+ learning_rate = 1e-4
52
+ duration_learning_rate = 1e-4
53
+ max_grad_norm = 1.0
54
+ weight_decay = 1e-4
55
+ token_mask_prob = 0.1
56
+
57
+ # ckpt
58
+ ckpt_dir = Path("assets/infore/nat")
59
+ data_dir = Path("train_data")
60
+
61
+
62
+ class DurationInput(NamedTuple):
63
+ phonemes: ndarray
64
+ lengths: ndarray
65
+ durations: ndarray
66
+
67
+
68
+ class AcousticInput(NamedTuple):
69
+ phonemes: ndarray
70
+ lengths: ndarray
71
+ durations: ndarray
72
+ wavs: ndarray
73
+ wav_lengths: ndarray
74
+ mels: ndarray
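
With these dsp settings the mel hop length is n_fft // 4 = 256 samples, which is how both trainers convert phoneme durations in seconds into mel-frame counts (durations * sample_rate / (n_fft // 4)). A quick sanity check, not part of this commit:

sample_rate = 16000
n_fft = 1024
hop_length = n_fft // 4        # 256 samples, i.e. 16 ms per mel frame
duration_s = 0.1               # a 100 ms phoneme
n_frames = duration_s * sample_rate / hop_length
print(n_frames)                # 6.25 frames
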
vietTTS/nat/data_loader.py ADDED
@@ -0,0 +1,156 @@
1
+ import random
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import textgrid
6
+ from scipy.io import wavfile
7
+
8
+ from .config import FLAGS, AcousticInput, DurationInput
9
+
10
+
11
+ def load_phonemes_set():
12
+ S = FLAGS.special_phonemes + FLAGS._normal_phonemes
13
+ return S
14
+
15
+
16
+ def pad_seq(s, maxlen, value=0):
17
+ assert maxlen >= len(s)
18
+ return tuple(s) + (value,) * (maxlen - len(s))
19
+
20
+
21
+ def is_in_word(phone, word):
22
+ def time_in_word(time, word):
23
+ return (word.minTime - 1e-3) < time and (word.maxTime + 1e-3) > time
24
+
25
+ return time_in_word(phone.minTime, word) and time_in_word(phone.maxTime, word)
26
+
27
+
28
+ def load_textgrid(fn: Path):
29
+ """load textgrid file"""
30
+ tg = textgrid.TextGrid.fromFile(str(fn.resolve()))
31
+ data = []
32
+ words = list(tg[0])
33
+ widx = 0
34
+ assert tg[1][0].minTime == 0, "The first phoneme has to start at time 0"
35
+ for p in tg[1]:
36
+ if p not in words[widx]:
37
+ widx = widx + 1
38
+ if len(words[widx - 1].mark) > 0:
39
+ data.append((FLAGS.special_phonemes[FLAGS.word_end_index], 0.0))
40
+ if widx >= len(words):
41
+ break
42
+ assert p in words[widx], "mismatched word vs phoneme"
43
+ mark = p.mark.strip().lower()
44
+ if len(mark) == 0:
45
+ mark = "sil"
46
+ data.append((mark, p.duration()))
47
+ return data
48
+
49
+
50
+ def textgrid_data_loader(data_dir: Path, seq_len: int, batch_size: int, mode: str):
51
+ """load all textgrid files in the directory"""
52
+ tg_files = sorted(data_dir.glob("*.TextGrid"))
53
+ random.Random(42).shuffle(tg_files)
54
+ L = len(tg_files) * 95 // 100
55
+ assert mode in ["train", "val"]
56
+ phonemes = load_phonemes_set()
57
+ if mode == "train":
58
+ tg_files = tg_files[:L]
59
+ if mode == "val":
60
+ tg_files = tg_files[L:]
61
+
62
+ data = []
63
+ for fn in tg_files:
64
+ ps, ds = zip(*load_textgrid(fn))
65
+ ps = [phonemes.index(p) for p in ps]
66
+ l = len(ps)
67
+ ps = pad_seq(ps, seq_len, 0)
68
+ ds = pad_seq(ds, seq_len, 0)
69
+ data.append((ps, ds, l))
70
+
71
+ batch = []
72
+ while True:
73
+ random.shuffle(data)
74
+ for e in data:
75
+ batch.append(e)
76
+ if len(batch) == batch_size:
77
+ ps, ds, lengths = zip(*batch)
78
+ ps = np.array(ps, dtype=np.int32)
79
+ ds = np.array(ds, dtype=np.float32)
80
+ lengths = np.array(lengths, dtype=np.int32)
81
+ yield DurationInput(ps, lengths, ds)
82
+ batch = []
83
+
84
+
85
+ def load_textgrid_wav(
86
+ data_dir: Path, token_seq_len: int, batch_size, pad_wav_len, mode: str
87
+ ):
88
+ """load wav and textgrid files to memory."""
89
+ tg_files = sorted(data_dir.glob("*.TextGrid"))
90
+ random.Random(42).shuffle(tg_files)
91
+ L = len(tg_files) * 95 // 100
92
+ assert mode in ["train", "val", "gta"]
93
+ phonemes = load_phonemes_set()
94
+ if mode == "gta":
95
+ tg_files = tg_files # all files
96
+ elif mode == "train":
97
+ tg_files = tg_files[:L]
98
+ elif mode == "val":
99
+ tg_files = tg_files[L:]
100
+
101
+ data = []
102
+ for fn in tg_files:
103
+ ps, ds = zip(*load_textgrid(fn))
104
+ ps = [phonemes.index(p) for p in ps]
105
+ l = len(ps)
106
+ ps = pad_seq(ps, token_seq_len, 0)
107
+ ds = pad_seq(ds, token_seq_len, 0)
108
+
109
+ wav_file = data_dir / f"{fn.stem}.wav"
110
+ sr, y = wavfile.read(wav_file)
111
+ y = np.copy(y)
112
+ start_time = 0
113
+ for i, (phone_idx, duration) in enumerate(zip(ps, ds)):
114
+ left = int(start_time * sr)  # was `l`, which shadowed the phoneme count above
115
+ end_time = start_time + duration
116
+ r = int(end_time * sr)
117
+ if i == len(ps) - 1:
118
+ r = len(y)
119
+ if phone_idx < len(FLAGS.special_phonemes):
120
+ y[left:r] = 0
121
+ start_time = end_time
122
+
123
+ if len(y) > pad_wav_len:
124
+ y = y[:pad_wav_len]
125
+
126
+ # # normalize to match hifigan preprocessing
127
+ # y = y.astype(np.float32)
128
+ # y = y / np.max(np.abs(y))
129
+ # y = y * 0.95
130
+ # y = y * (2 ** 15)
131
+ # y = y.astype(np.int16)
132
+
133
+ wav_length = len(y)
134
+ y = np.pad(y, (0, pad_wav_len - len(y)))
135
+ data.append((fn.stem, ps, ds, l, y, wav_length))
136
+
137
+ batch = []
138
+ while True:
139
+ random.shuffle(data)
140
+ for idx, e in enumerate(data):
141
+ batch.append(e)
142
+ if len(batch) == batch_size or (mode == "gta" and idx == len(data) - 1):
143
+ names, ps, ds, lengths, wavs, wav_lengths = zip(*batch)
144
+ ps = np.array(ps, dtype=np.int32)
145
+ ds = np.array(ds, dtype=np.float32)
146
+ lengths = np.array(lengths, dtype=np.int32)
147
+ wavs = np.array(wavs, dtype=np.int16)
148
+ wav_lengths = np.array(wav_lengths, dtype=np.int32)
149
+ if mode == "gta":
150
+ yield names, AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
151
+ else:
152
+ yield AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
153
+ batch = []
154
+ if mode == "gta":
155
+ assert len(batch) == 0
156
+ break
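
Both loaders are infinite generators that yield fixed-shape, padded batches. An illustrative usage sketch (not part of this commit), assuming train_data/ already contains the MFA-aligned *.TextGrid / *.wav pairs:

from pathlib import Path

from vietTTS.nat.data_loader import load_textgrid_wav, textgrid_data_loader

acoustic_iter = load_textgrid_wav(
    Path("train_data"), token_seq_len=256, batch_size=4, pad_wav_len=1024 * 64 * 3, mode="train"
)
batch = next(acoustic_iter)
print(batch.phonemes.shape)   # (4, 256)  padded phoneme ids
print(batch.durations.shape)  # (4, 256)  per-phoneme durations in seconds
print(batch.wavs.shape)       # (4, 196608)  zero-padded int16 waveforms
print(batch.wav_lengths)      # true waveform lengths before padding

duration_iter = textgrid_data_loader(Path("train_data"), seq_len=256, batch_size=4, mode="train")
print(next(duration_iter).durations.shape)  # (4, 256)
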
vietTTS/nat/dsp.py ADDED
@@ -0,0 +1,128 @@
1
+ from functools import partial
2
+ from typing import Optional
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import librosa
7
+ from einops import rearrange
8
+ from jax.numpy import ndarray
9
+
10
+
11
+ def rolling_window(a: ndarray, window: int, hop_length: int):
12
+ """return a stack of overlap subsequence of an array.
13
+ ``return jnp.stack( [a[0:10], a[5:15], a[10:20],...], axis=0)``
14
+ Source: https://github.com/google/jax/issues/3171
15
+ Args:
16
+ a (ndarray): input array of shape `[L, ...]`
17
+ window (int): length of each subarray (window).
18
+ hop_length (int): distance between neighbouring windows.
19
+ """
20
+
21
+ idx = (
22
+ jnp.arange(window)[:, None]
23
+ + jnp.arange((len(a) - window) // hop_length + 1)[None, :] * hop_length
24
+ )
25
+ return a[idx]
26
+
27
+
28
+ @partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
29
+ def stft(
30
+ y: ndarray,
31
+ n_fft: int = 2048,
32
+ hop_length: Optional[int] = None,
33
+ win_length: Optional[int] = None,
34
+ window: str = "hann",
35
+ center: bool = True,
36
+ pad_mode: str = "reflect",
37
+ ):
38
+ """A jax reimplementation of ``librosa.stft`` function."""
39
+
40
+ if win_length is None:
41
+ win_length = n_fft
42
+
43
+ if hop_length is None:
44
+ hop_length = win_length // 4
45
+
46
+ if window == "hann":
47
+ fft_window = jnp.hanning(win_length + 1)[:-1]
48
+ else:
49
+ raise RuntimeError(f"{window} window function is not supported!")
50
+
51
+ pad_len = (n_fft - win_length) // 2
52
+ fft_window = jnp.pad(fft_window, (pad_len, pad_len), mode="constant")
53
+ fft_window = fft_window[:, None]
54
+ if center:
55
+ y = jnp.pad(y, int(n_fft // 2), mode=pad_mode)
56
+
57
+ # jax does not support ``np.lib.stride_tricks.as_strided`` function
58
+ # see https://github.com/google/jax/issues/3171 for comments.
59
+ y_frames = rolling_window(y, n_fft, hop_length) * fft_window
60
+ stft_matrix = jnp.fft.fft(y_frames, axis=0)
61
+ d = int(1 + n_fft // 2)
62
+ return stft_matrix[:d]
63
+
64
+
65
+ @partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
66
+ def batched_stft(
67
+ y: ndarray,
68
+ n_fft: int,
69
+ hop_length: int,
70
+ win_length: int,
71
+ window: str,
72
+ center: bool = True,
73
+ pad_mode: str = "reflect",
74
+ ):
75
+ """Batched version of ``stft`` function.
76
+ TN => FTN
77
+ """
78
+
79
+ assert len(y.shape) >= 2
80
+ if window == "hann":
81
+ fft_window = jnp.hanning(win_length + 1)[:-1]
82
+ else:
83
+ raise RuntimeError(f"{window} window function is not supported!")
84
+ pad_len = (n_fft - win_length) // 2
85
+ if pad_len > 0:
86
+ fft_window = jnp.pad(fft_window, (pad_len, pad_len), mode="constant")
87
+ win_length = n_fft
88
+ else:
89
+ fft_window = fft_window
90
+ if center:
91
+ pad_width = ((n_fft // 2, n_fft // 2),) + ((0, 0),) * (len(y.shape) - 1)
92
+ y = jnp.pad(y, pad_width, mode=pad_mode)
93
+
94
+ # jax does not support ``np.lib.stride_tricks.as_strided`` function
95
+ # see https://github.com/google/jax/issues/3171 for comments.
96
+ y_frames = rolling_window(y, n_fft, hop_length)
97
+ fft_window = jnp.reshape(fft_window, (-1,) + (1,) * (len(y.shape)))
98
+ y_frames = y_frames * fft_window
99
+ stft_matrix = jnp.fft.fft(y_frames, axis=0)
100
+ d = int(1 + n_fft // 2)
101
+ return stft_matrix[:d]
102
+
103
+
104
+ class MelFilter:
105
+ """Convert waveform to mel spectrogram."""
106
+
107
+ def __init__(self, sample_rate: int, n_fft: int, n_mels: int, fmin=0.0, fmax=8000):
108
+ self.melfb = jax.device_put(
109
+ librosa.filters.mel(
110
+ sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
111
+ )
112
+ )
113
+ self.n_fft = n_fft
114
+
115
+ def __call__(self, y: ndarray) -> ndarray:
116
+ hop_length = self.n_fft // 4
117
+ window_length = self.n_fft
118
+ assert len(y.shape) == 2
119
+ y = rearrange(y, "n s -> s n")
120
+ p = (self.n_fft - hop_length) // 2
121
+ y = jnp.pad(y, ((p, p), (0, 0)), mode="reflect")
122
+ spec = batched_stft(
123
+ y, self.n_fft, hop_length, window_length, "hann", False, "reflect"
124
+ )
125
+ mag = jnp.sqrt(jnp.square(spec.real) + jnp.square(spec.imag) + 1e-9)
126
+ mel = jnp.einsum("ms,sfn->nfm", self.melfb, mag)
127
+ cond = jnp.log(jnp.clip(mel, a_min=1e-5, a_max=None))
128
+ return cond
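
MelFilter expects a batch of waveforms already scaled to [-1, 1] (the trainers divide int16 samples by 2**15) and returns log-mel spectrograms computed with the JAX stft above. A small usage sketch on random noise (not part of this commit), assuming jax, librosa and einops are installed:

import numpy as np
import jax.numpy as jnp

from vietTTS.nat.dsp import MelFilter

melfilter = MelFilter(sample_rate=16000, n_fft=1024, n_mels=80, fmin=0.0, fmax=8000)
wavs = np.random.randint(-(2**15), 2**15, size=(2, 16000), dtype=np.int16)  # 1 s of noise
mels = melfilter(jnp.asarray(wavs, dtype=jnp.float32) / (2**15))
print(mels.shape)  # (batch, n_frames, 80) log-mel spectrograms
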
vietTTS/nat/duration_trainer.py ADDED
@@ -0,0 +1,142 @@
1
+ from functools import partial
2
+ from typing import Deque
3
+
4
+ import haiku as hk
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import optax
10
+ from tqdm.auto import tqdm
11
+ from vietTTS.nat.config import DurationInput
12
+
13
+ from .config import FLAGS
14
+ from .data_loader import textgrid_data_loader
15
+ from .model import DurationModel
16
+ from .utils import load_latest_ckpt, print_flags, save_ckpt
17
+
18
+
19
+ def loss_fn(params, aux, rng, x: DurationInput, is_training=True):
20
+ """return the l1 loss"""
21
+
22
+ @hk.transform_with_state
23
+ def net(x):
24
+ return DurationModel(is_training=is_training)(x)
25
+
26
+ if is_training:
27
+ # randomly mask tokens with [WORD END] token
28
+ # during training to avoid overfitting
29
+ m_rng, rng = jax.random.split(rng, 2)
30
+ m = jax.random.bernoulli(m_rng, FLAGS.token_mask_prob, x.phonemes.shape)
31
+ x = x._replace(phonemes=jnp.where(m, FLAGS.word_end_index, x.phonemes))
32
+ durations, aux = net.apply(params, aux, rng, x)
33
+ mask = jnp.arange(0, x.phonemes.shape[1])[None, :] < x.lengths[:, None]
34
+ # do not predict durations for [WORD END] tokens
35
+ mask = jnp.where(x.phonemes == FLAGS.word_end_index, False, mask)
36
+ masked_loss = jnp.abs(durations - x.durations) * mask
37
+ loss = jnp.sum(masked_loss) / jnp.sum(mask)
38
+ return loss, aux
39
+
40
+
41
+ forward_fn = jax.jit(
42
+ hk.transform_with_state(lambda x: DurationModel(is_training=False)(x)).apply
43
+ )
44
+
45
+
46
+ def predict_duration(params, aux, rng, x: DurationInput):
47
+ d, _ = forward_fn(params, aux, rng, x)
48
+ return d, x.durations
49
+
50
+
51
+ val_loss_fn = jax.jit(partial(loss_fn, is_training=False))
52
+
53
+ loss_vag = jax.value_and_grad(loss_fn, has_aux=True)
54
+
55
+ optimizer = optax.chain(
56
+ optax.clip_by_global_norm(FLAGS.max_grad_norm),
57
+ optax.adamw(FLAGS.duration_learning_rate, weight_decay=FLAGS.weight_decay),
58
+ )
59
+
60
+
61
+ @jax.jit
62
+ def update(params, aux, rng, optim_state, inputs: DurationInput):
63
+ rng, new_rng = jax.random.split(rng)
64
+ (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
65
+ updates, new_optim_state = optimizer.update(grads, optim_state, params)
66
+ new_params = optax.apply_updates(params, updates)
67
+ return loss, (new_params, new_aux, new_rng, new_optim_state)
68
+
69
+
70
+ def initial_state(batch):
71
+ rng = jax.random.PRNGKey(42)
72
+ params, aux = hk.transform_with_state(lambda x: DurationModel(True)(x)).init(
73
+ rng, batch
74
+ )
75
+ optim_state = optimizer.init(params)
76
+ return params, aux, rng, optim_state
77
+
78
+
79
+ def plot_val_duration(step: int, batch, params, aux, rng):
80
+ fn = FLAGS.ckpt_dir / f"duration_{step:06d}.png"
81
+ predicted_dur, gt_dur = predict_duration(params, aux, rng, batch)
82
+ L = batch.lengths[0]
83
+ x = np.arange(0, L) * 3
84
+ plt.plot(predicted_dur[0, :L])
85
+ plt.plot(gt_dur[0, :L])
86
+ plt.legend(["predicted", "gt"])
87
+ plt.title("Phoneme durations")
88
+ plt.savefig(fn)
89
+ plt.close()
90
+
91
+
92
+ def train():
93
+ train_data_iter = textgrid_data_loader(
94
+ FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="train"
95
+ )
96
+ val_data_iter = textgrid_data_loader(
97
+ FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="val"
98
+ )
99
+ losses = Deque(maxlen=1000)
100
+ val_losses = Deque(maxlen=100)
101
+ latest_ckpt = load_latest_ckpt(FLAGS.ckpt_dir)
102
+ if latest_ckpt is not None:
103
+ last_step, params, aux, rng, optim_state = latest_ckpt
104
+ else:
105
+ last_step = -1
106
+ print("Generate random initial states...")
107
+ params, aux, rng, optim_state = initial_state(next(train_data_iter))
108
+
109
+ tr = tqdm(
110
+ range(last_step + 1, 1 + FLAGS.num_training_steps),
111
+ total=1 + FLAGS.num_training_steps,
112
+ initial=last_step + 1,
113
+ ncols=80,
114
+ desc="training",
115
+ )
116
+ for step in tr:
117
+ batch = next(train_data_iter)
118
+ loss, (params, aux, rng, optim_state) = update(
119
+ params, aux, rng, optim_state, batch
120
+ )
121
+ losses.append(loss)
122
+
123
+ if step % 10 == 0:
124
+ val_loss, _ = val_loss_fn(params, aux, rng, next(val_data_iter))
125
+ val_losses.append(val_loss)
126
+
127
+ if step % 1000 == 0:
128
+ loss = sum(losses).item() / len(losses)
129
+ val_loss = sum(val_losses).item() / len(val_losses)
130
+ plot_val_duration(step, next(val_data_iter), params, aux, rng)
131
+ tr.write(
132
+ f" {step:>6d}/{FLAGS.num_training_steps:>6d} | train loss {loss:.5f} | val loss {val_loss:.5f}"
133
+ )
134
+ save_ckpt(step, params, aux, rng, optim_state, ckpt_dir=FLAGS.ckpt_dir)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ print_flags(FLAGS.__dict__)
139
+ if not FLAGS.ckpt_dir.exists():
140
+ print("Create checkpoint dir at", FLAGS.ckpt_dir)
141
+ FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
142
+ train()
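
The duration loss above uses a simple token-mask regularization: during training each phoneme is independently replaced by the word-boundary token with probability token_mask_prob, and word-boundary positions (as well as padding) are excluded from the L1 loss. A toy sketch of that masking (not part of this commit), with made-up phoneme ids:

import jax
import jax.numpy as jnp

token_mask_prob = 0.1
word_end_index = 3                           # index of " " in FLAGS.special_phonemes
rng = jax.random.PRNGKey(0)

phonemes = jnp.array([[5, 7, 3, 9, 0, 0]])   # one padded phoneme-id sequence
lengths = jnp.array([4])                     # its true length

m = jax.random.bernoulli(rng, token_mask_prob, phonemes.shape)
masked_phonemes = jnp.where(m, word_end_index, phonemes)

mask = jnp.arange(0, phonemes.shape[1])[None, :] < lengths[:, None]   # drop padding
mask = jnp.where(masked_phonemes == word_end_index, False, mask)      # drop word boundaries
print(masked_phonemes)
print(mask)
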
vietTTS/nat/gta.py ADDED
@@ -0,0 +1,82 @@
1
+ import pickle
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ import haiku as hk
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from tqdm.auto import tqdm
10
+ from vietTTS.nat.config import AcousticInput
11
+
12
+ from .config import FLAGS, AcousticInput
13
+ from .data_loader import load_textgrid_wav
14
+ from .dsp import MelFilter
15
+ from .model import AcousticModel
16
+
17
+
18
+ @hk.transform_with_state
19
+ def net(x):
20
+ return AcousticModel(is_training=True)(x)
21
+
22
+
23
+ @hk.transform_with_state
24
+ def val_net(x):
25
+ return AcousticModel(is_training=False)(x)
26
+
27
+
28
+ def forward_fn_(params, aux, rng, inputs: AcousticInput):
29
+ melfilter = MelFilter(
30
+ FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
31
+ )
32
+ mels = melfilter(inputs.wavs.astype(jnp.float32) / (2**15))
33
+ B, L, D = mels.shape
34
+ inp_mels = jnp.concatenate(
35
+ (jnp.zeros((B, 1, D), dtype=jnp.float32), mels[:, :-1, :]), axis=1
36
+ )
37
+ n_frames = inputs.durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
38
+ inputs = inputs._replace(mels=inp_mels, durations=n_frames)
39
+ (mel1_hat, mel2_hat), new_aux = val_net.apply(params, aux, rng, inputs)
40
+ return mel2_hat
41
+
42
+
43
+ forward_fn = jax.jit(forward_fn_)
44
+
45
+
46
+ def generate_gta(out_dir: Path):
47
+ out_dir.mkdir(parents=True, exist_ok=True)
48
+ data_iter = load_textgrid_wav(
49
+ FLAGS.data_dir,
50
+ FLAGS.max_phoneme_seq_len,
51
+ FLAGS.batch_size,
52
+ FLAGS.max_wave_len,
53
+ "gta",
54
+ )
55
+ ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
56
+ print("Resuming from latest checkpoint at", ckpt_fn)
57
+ with open(ckpt_fn, "rb") as f:
58
+ dic = pickle.load(f)
59
+ _, params, aux, rng, _ = (
60
+ dic["step"],
61
+ dic["params"],
62
+ dic["aux"],
63
+ dic["rng"],
64
+ dic["optim_state"],
65
+ )
66
+
67
+ tr = tqdm(data_iter)
68
+ for names, batch in tr:
69
+ lengths = batch.wav_lengths
70
+ predicted_mel = forward_fn(params, aux, rng, batch)
71
+ mel = jax.device_get(predicted_mel)
72
+ for idx, fn in enumerate(names):
73
+ file = out_dir / f"{fn}.npy"
74
+ tr.write(f"saving to file {file}")
75
+ l = lengths[idx] // (FLAGS.n_fft // 4)
76
+ np.save(file, mel[idx, :l].T)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ parser = ArgumentParser()
81
+ parser.add_argument("-o", "--output-dir", type=Path, default="gta")
82
+ generate_gta(parser.parse_args().output_dir)
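
generate_gta runs the trained acoustic model in teacher-forcing mode and writes one ground-truth-aligned (GTA) mel per utterance as <name>.npy, which is commonly used to fine-tune the HiFi-GAN vocoder on the acoustic model's actual outputs. It can be run as a module (python -m vietTTS.nat.gta -o gta) or called directly; the sketch below (not part of this commit) assumes the acoustic checkpoint at assets/infore/nat/acoustic_latest_ckpt.pickle and the aligned data in train_data/ already exist:

from pathlib import Path

from vietTTS.nat.gta import generate_gta

# writes one <utterance>.npy ground-truth-aligned mel per clip into gta/
generate_gta(Path("gta"))
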