Commit 2e6f087
Parent(s): 204a314
commit name

Files changed:
- .gitignore +9 -0
- LICENSE +201 -0
- inference.ipynb +239 -0
- inference_client.py +161 -0
- inference_client_webrtc.py +255 -0
- inference_server.py +172 -0
- ioblocks.py +333 -0
- model.py +443 -0
- requirements.txt +15 -0
- requirements_webrtc.txt +2 -0
- tokenizer.py +581 -0
- transformer.py +381 -0
- utils/__init__.py +3 -0
- utils/blocks.py +92 -0
- utils/dist.py +98 -0
- utils/interp.py +84 -0
.gitignore
ADDED
@@ -0,0 +1,9 @@
.venv
*.wav
*.mp3
*.m4a
!prompts/*.wav
!prompts/*.mp3
!prompts/*.m4a
__pycache__
*ckpt
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2024 Standard Intelligence PBC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
inference.ipynb
ADDED
@@ -0,0 +1,239 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as T\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "from utils import load_ckpt, print_colored\n",
    "from tokenizer import make_tokenizer\n",
    "from model import get_hertz_dev_config\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "\n",
    "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n",
    "# you need to install PyTorch with the correct CUDA version. Run:\n",
    "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n",
    "\n",
    "device = 'cuda' if T.cuda.is_available() else 'cpu'\n",
    "T.cuda.set_device(0)\n",
    "print_colored(f\"Using device: {device}\", \"grey\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code will automatically download them if it can't find them.\n",
    "audio_tokenizer = make_tokenizer(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We have different checkpoints for the single-speaker and two-speaker models\n",
    "# Set to True to load and run inference with the two-speaker model\n",
    "TWO_SPEAKER = False\n",
    "USE_PURE_AUDIO_ABLATION = False # We trained a base model with no text initialization at all. Toggle this to enable it.\n",
    "assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER) # We only have a single-speaker version of this model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = get_hertz_dev_config(is_split=TWO_SPEAKER, use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION)\n",
    "\n",
    "generator = model_config()\n",
    "generator = generator.eval().to(T.bfloat16).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_and_preprocess_audio(audio_path):\n",
    "    print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n",
    "    # Load audio file\n",
    "    audio_tensor, sr = torchaudio.load(audio_path)\n",
    "    print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n",
    "    \n",
    "    if TWO_SPEAKER:\n",
    "        if audio_tensor.shape[0] == 1:\n",
    "            print_colored(\"Converting mono to stereo...\", \"grey\")\n",
    "            audio_tensor = audio_tensor.repeat(2, 1)\n",
    "            print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n",
    "    else:\n",
    "        if audio_tensor.shape[0] == 2:\n",
    "            print_colored(\"Converting stereo to mono...\", \"grey\")\n",
    "            audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)\n",
    "            print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n",
    "    \n",
    "    # Resample to 16kHz if needed\n",
    "    if sr != 16000:\n",
    "        print_colored(f\"Resampling from {sr}Hz to 16000Hz...\", \"grey\")\n",
    "        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)\n",
    "        audio_tensor = resampler(audio_tensor)\n",
    "    \n",
    "    # Clip to 5 minutes if needed\n",
    "    max_samples = 16000 * 60 * 5\n",
    "    if audio_tensor.shape[1] > max_samples:\n",
    "        print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n",
    "        audio_tensor = audio_tensor[:, :max_samples]\n",
    "\n",
    "    \n",
    "    print_colored(\"Audio preprocessing complete!\", \"green\")\n",
    "    return audio_tensor.unsqueeze(0)\n",
    "\n",
    "def display_audio(audio_tensor):\n",
    "    audio_tensor = audio_tensor.cpu().squeeze()\n",
    "    if audio_tensor.ndim == 1:\n",
    "        audio_tensor = audio_tensor.unsqueeze(0)\n",
    "    audio_tensor = audio_tensor.float()\n",
    "\n",
    "    # Make a waveform plot\n",
    "    plt.figure(figsize=(4, 1))\n",
    "    plt.plot(audio_tensor.numpy()[0], linewidth=0.5)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "    # Make an audio player\n",
    "    display(Audio(audio_tensor.numpy(), rate=16000))\n",
    "    print_colored(f\"Audio ready for playback ↑\", \"green\", bold=True)\n",
    "    \n",
    "    \n",
    "\n",
    "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n",
    "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n",
    "display_audio(prompt_audio)\n",
    "prompt_len_seconds = 3\n",
    "prompt_len = prompt_len_seconds * 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print_colored(\"Encoding prompt...\", \"blue\")\n",
    "with T.autocast(device_type='cuda', dtype=T.bfloat16):\n",
    "    if TWO_SPEAKER:\n",
    "        encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n",
    "        encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n",
    "        encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n",
    "    else:\n",
    "        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n",
    "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n",
    "print_colored(\"Prompt encoded successfully!\", \"green\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_completion(encoded_prompt_audio, prompt_len, gen_len=None):\n",
    "    prompt_len_seconds = prompt_len / 8\n",
    "    print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n",
    "    print_colored(\"Completing audio...\", \"blue\")\n",
    "    encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len]\n",
    "    with T.autocast(device_type='cuda', dtype=T.bfloat16):\n",
    "        completed_audio_batch = generator.completion(\n",
    "            encoded_prompt_audio, \n",
    "            temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n",
    "            use_cache=True,\n",
    "            gen_len=gen_len)\n",
    "\n",
    "        completed_audio = completed_audio_batch\n",
    "        print_colored(f\"Decoding completion...\", \"blue\")\n",
    "        if TWO_SPEAKER:\n",
    "            decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n",
    "            decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n",
    "            decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n",
    "        else:\n",
    "            decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n",
    "    print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n",
    "\n",
    "    print_colored(\"Preparing audio for playback...\", \"blue\")\n",
    "\n",
    "    audio_tensor = decoded_completion.cpu().squeeze()\n",
    "    if audio_tensor.ndim == 1:\n",
    "        audio_tensor = audio_tensor.unsqueeze(0)\n",
    "    audio_tensor = audio_tensor.float()\n",
    "\n",
    "    if audio_tensor.abs().max() > 1:\n",
    "        audio_tensor = audio_tensor / audio_tensor.abs().max()\n",
    "\n",
    "    return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n",
    "\n",
    "num_completions = 10\n",
    "print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n",
    "for _ in range(num_completions):\n",
    "    completion = get_completion(encoded_prompt_audio, prompt_len, gen_len=20*8) # 20 seconds of generation\n",
    "    display_audio(completion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
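For readers who prefer a plain script over the notebook, here is a minimal sketch of the same single-speaker flow (tokenize prompt, complete, decode). It only reuses calls that appear in the notebook above; the prompt path, CUDA-only assumption, and the final torchaudio.save step are illustrative additions, not part of the commit.

# Sketch: script-form version of inference.ipynb (single-speaker path, assumes CUDA).
import torch as T
import torchaudio
from tokenizer import make_tokenizer
from model import get_hertz_dev_config

device = 'cuda'
audio_tokenizer = make_tokenizer(device)
generator = get_hertz_dev_config(is_split=False)().eval().to(T.bfloat16).to(device)

prompt_audio, sr = torchaudio.load('./prompts/toaskanymore.wav')   # (channels, samples)
prompt_audio = prompt_audio.mean(dim=0, keepdim=True)               # force mono
if sr != 16000:
    prompt_audio = torchaudio.transforms.Resample(sr, 16000)(prompt_audio)
prompt_audio = prompt_audio.unsqueeze(0).to(device)                 # (1, 1, samples)

prompt_len = 3 * 8   # 3 seconds of prompt; the model runs at 8 latent frames per second
with T.autocast(device_type='cuda', dtype=T.bfloat16):
    latents = audio_tokenizer.latent_from_data(prompt_audio)[:, :prompt_len]
    completed = generator.completion(latents, temps=(.8, (0.5, 0.1)), use_cache=True, gen_len=20 * 8)
    waveform = audio_tokenizer.data_from_latent(completed.bfloat16())

# Assumed output step: write the decoded 16 kHz waveform to disk.
torchaudio.save('completion.wav', waveform.squeeze(0).float().cpu(), 16000)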
inference_client.py
ADDED
@@ -0,0 +1,161 @@
# server.py remains the same as before

# Updated client.py
import asyncio
import websockets
import sounddevice as sd
import numpy as np
import base64
import queue
import argparse
import requests
import time

class AudioClient:
    def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
        # Convert ws:// to http:// for the base URL
        self.base_url = server_url.replace("ws://", "http://")
        self.server_url = f"{server_url}/audio"

        # Set temperatures if provided
        if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
            self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)

        # Initialize queues
        self.audio_queue = queue.Queue()
        self.output_queue = queue.Queue()

    def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing=False):
        """Send temperature settings to server"""
        params = {}
        if token_temp is not None:
            params['token_temp'] = token_temp
        if categorical_temp is not None:
            params['categorical_temp'] = categorical_temp
        if gaussian_temp is not None:
            params['gaussian_temp'] = gaussian_temp

        response = requests.post(f"{self.base_url}/set_temperature", params=params)
        print(response.json()['message'])

    def audio_callback(self, indata, frames, time, status):
        """This is called for each audio block"""
        if status:
            print(status)
        # if np.isclose(indata, 0).all():
        #     raise Exception('Audio input is not working - received all zeros')
        # Convert float32 to int16 for efficient transmission
        indata_int16 = (indata.copy() * 32767).astype(np.int16)
        # indata_int16 = np.zeros_like(indata_int16)
        self.audio_queue.put(indata_int16)

    def output_stream_callback(self, outdata, frames, time, status):
        """Callback for output stream to get audio data"""
        if status:
            print(status)

        try:
            data = self.output_queue.get_nowait()
            data = data.astype(np.float32) / 32767.0
            if len(data) < len(outdata):
                outdata[:len(data)] = data
                outdata[len(data):] = 0
            else:
                outdata[:] = data[:len(outdata)]
        except queue.Empty:
            outdata.fill(0)

    async def process_audio(self):
        async with websockets.connect(self.server_url) as ws:
            while self.running:
                if not self.audio_queue.empty():
                    # Get recorded audio
                    audio_data = self.audio_queue.get()
                    print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')

                    # Convert to base64
                    audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')

                    # Send to server
                    time_sent = time.time()
                    await ws.send(f"data:audio/raw;base64,{audio_b64}")

                    # Receive processed audio
                    response = await ws.recv()
                    response = response.split(",")[1]
                    time_received = time.time()
                    print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms")
                    processed_audio = np.frombuffer(
                        base64.b64decode(response),
                        dtype=np.int16
                    ).reshape(-1, CHANNELS)
                    print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')

                    self.output_queue.put(processed_audio)

    def start(self):
        self.running = True
        # Print audio device information
        devices = sd.query_devices()
        default_input = sd.query_devices(kind='input')
        default_output = sd.query_devices(kind='output')

        print("\nAudio Device Configuration:")
        print("-" * 50)
        print(f"Default Input Device:\n{default_input}\n")
        print(f"Default Output Device:\n{default_output}\n")
        print("\nAll Available Devices:")
        print("-" * 50)
        for i, device in enumerate(devices):
            print(f"Device {i}:")
            print(f"Name: {device['name']}")
            print(f"Channels (in/out): {device['max_input_channels']}/{device['max_output_channels']}")
            print(f"Sample Rates: {device['default_samplerate']}")
            print()
        input_device = input("Enter the index of the input device or press enter for default: ")
        output_device = input("Enter the index of the output device or press enter for default: ")
        if input_device == "":
            input_device = default_input['index']
        if output_device == "":
            output_device = default_output['index']
        with sd.InputStream(callback=self.audio_callback,
                            channels=CHANNELS,
                            samplerate=SAMPLE_RATE,
                            device=int(input_device),
                            blocksize=2000), \
             sd.OutputStream(callback=self.output_stream_callback,
                             channels=CHANNELS,
                             samplerate=SAMPLE_RATE,
                             blocksize=2000,
                             device=int(output_device)):

            asyncio.run(self.process_audio())

    def stop(self):
        self.running = False

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
    parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
    parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
    parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
    parser.add_argument('--server', '-s', default="ws://localhost:8000",
                        help='Server URL (default: ws://localhost:8000)')

    args = parser.parse_args()

    # Audio settings
    SAMPLE_RATE = 16000
    CHANNELS = 1

    client = AudioClient(
        server_url=args.server,
        token_temp=args.token_temp,
        categorical_temp=args.categorical_temp,
        gaussian_temp=args.gaussian_temp
    )

    try:
        client.start()
    except KeyboardInterrupt:
        client.stop()
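The wire format between this client and inference_server.py is a websocket text frame per 2000-sample block: base64-encoded little-endian int16 PCM behind a "data:audio/raw;base64," prefix, answered in the same format. A tiny self-contained sketch of that framing (no server required; the silent block is just an example):

# Sketch: encode/decode one /audio websocket message as the client above does.
import base64
import numpy as np

block = np.zeros(2000, dtype=np.int16)                        # one 125 ms block at 16 kHz (silence here)
payload = base64.b64encode(block.tobytes()).decode('utf-8')
message = f"data:audio/raw;base64,{payload}"                  # what ws.send() transmits

# The server replies with the same framing; decoding mirrors the client code.
received = np.frombuffer(base64.b64decode(message.split(",")[1]), dtype=np.int16)
assert received.shape == (2000,) and received.dtype == np.int16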
inference_client_webrtc.py
ADDED
@@ -0,0 +1,255 @@
# server.py remains the same as before

# Updated client.py
import asyncio
import websockets
import numpy as np
import base64
import argparse
import requests
import time
import torch
import torchaudio

import av
import streamlit as st
from typing import List
from streamlit_webrtc import WebRtcMode, webrtc_streamer

class AudioClient:
    def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
        # Convert ws:// to http:// for the base URL
        self.base_url = server_url.replace("ws://", "http://")
        self.server_url = f"{server_url}/audio"
        self.sound_check = False

        # Set temperatures if provided
        if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
            response_message = self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
            print(response_message)

        self.downsampler = torchaudio.transforms.Resample(STREAMING_SAMPLE_RATE, SAMPLE_RATE)
        self.upsampler = torchaudio.transforms.Resample(SAMPLE_RATE, STREAMING_SAMPLE_RATE)
        self.ws = None
        self.in_buffer = None
        self.out_buffer = None

    def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing=False):
        """Send temperature settings to server"""
        params = {}
        if token_temp is not None:
            params['token_temp'] = token_temp
        if categorical_temp is not None:
            params['categorical_temp'] = categorical_temp
        if gaussian_temp is not None:
            params['gaussian_temp'] = gaussian_temp

        response = requests.post(f"{self.base_url}/set_temperature", params=params)
        response_message = response.json()['message']
        return response_message

    def _resample(self, audio_data: np.ndarray, resampler: torchaudio.transforms.Resample) -> np.ndarray:
        audio_data = audio_data.astype(np.float32) / 32767.0
        audio_data = resampler(torch.tensor(audio_data)).numpy()
        audio_data = (audio_data * 32767.0).astype(np.int16)
        return audio_data

    def upsample(self, audio_data: np.ndarray) -> np.ndarray:
        return self._resample(audio_data, self.upsampler)

    def downsample(self, audio_data: np.ndarray) -> np.ndarray:
        return self._resample(audio_data, self.downsampler)

    def from_s16_format(self, audio_data: np.ndarray, channels: int) -> np.ndarray:
        if channels == 2:
            audio_data = audio_data.reshape(-1, 2).T
        else:
            audio_data = audio_data.reshape(-1)
        return audio_data

    def to_s16_format(self, audio_data: np.ndarray):
        if len(audio_data.shape) == 2 and audio_data.shape[0] == 2:
            audio_data = audio_data.T.reshape(1, -1)
        elif len(audio_data.shape) == 1:
            audio_data = audio_data.reshape(1, -1)
        return audio_data

    def to_channels(self, audio_data: np.ndarray, channels: int) -> np.ndarray:
        current_channels = audio_data.shape[0] if len(audio_data.shape) == 2 else 1
        if current_channels == channels:
            return audio_data
        elif current_channels == 1 and channels == 2:
            audio_data = np.tile(audio_data, 2).reshape(2, -1)
        elif current_channels == 2 and channels == 1:
            audio_data = audio_data.astype(np.float32) / 32767.0
            audio_data = audio_data.mean(axis=0)
            audio_data = (audio_data * 32767.0).astype(np.int16)
        return audio_data

    async def process_audio(self, audio_data: np.ndarray) -> np.ndarray:
        if self.ws is None:
            self.ws = await websockets.connect(self.server_url)

        audio_data = audio_data.reshape(-1, CHANNELS)
        print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')

        # Convert to base64
        audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')

        # Send to server
        time_sent = time.time()
        await self.ws.send(f"data:audio/raw;base64,{audio_b64}")

        # Receive processed audio
        response = await self.ws.recv()
        response = response.split(",")[1]
        time_received = time.time()
        print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms")
        processed_audio = np.frombuffer(
            base64.b64decode(response),
            dtype=np.int16
        ).reshape(-1, CHANNELS)
        print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')

        if CHANNELS == 1:
            processed_audio = processed_audio.reshape(-1)
        return processed_audio

    async def queued_audio_frames_callback(self, frames: List[av.AudioFrame]) -> List[av.AudioFrame]:
        out_frames = []
        for frame in frames:
            # Read in audio
            audio_data = frame.to_ndarray()

            # Convert input audio from s16 format, convert to `CHANNELS` number of channels, and downsample
            audio_data = self.from_s16_format(audio_data, len(frame.layout.channels))
            audio_data = self.to_channels(audio_data, CHANNELS)
            audio_data = self.downsample(audio_data)

            # Add audio to input buffer
            if self.in_buffer is None:
                self.in_buffer = audio_data
            else:
                self.in_buffer = np.concatenate((self.in_buffer, audio_data), axis=-1)

            # Take BLOCK_SIZE samples from input buffer if available for processing
            if self.in_buffer.shape[0] >= BLOCK_SIZE:
                audio_data = self.in_buffer[:BLOCK_SIZE]
                self.in_buffer = self.in_buffer[BLOCK_SIZE:]
            else:
                audio_data = None

            # Process audio if available and add resulting audio to output buffer
            if audio_data is not None:
                if not self.sound_check:
                    audio_data = await self.process_audio(audio_data)
                if self.out_buffer is None:
                    self.out_buffer = audio_data
                else:
                    self.out_buffer = np.concatenate((self.out_buffer, audio_data), axis=-1)

            # Take `out_samples` samples from output buffer if available for output
            out_samples = int(frame.samples * SAMPLE_RATE / STREAMING_SAMPLE_RATE)
            if self.out_buffer is not None and self.out_buffer.shape[0] >= out_samples:
                audio_data = self.out_buffer[:out_samples]
                self.out_buffer = self.out_buffer[out_samples:]
            else:
                audio_data = None

            # Output silence if no audio data available
            if audio_data is None:
                # output silence
                audio_data = np.zeros(out_samples, dtype=np.int16)

            # Upsample output audio, convert to original number of channels, and convert to s16 format
            audio_data = self.upsample(audio_data)
            audio_data = self.to_channels(audio_data, len(frame.layout.channels))
            audio_data = self.to_s16_format(audio_data)

            # return audio data as AudioFrame
            new_frame = av.AudioFrame.from_ndarray(audio_data, format=frame.format.name, layout=frame.layout.name)
            new_frame.sample_rate = frame.sample_rate
            out_frames.append(new_frame)

        return out_frames

    def stop(self):
        if self.ws is not None:
            # TODO: this hangs. Figure out why.
            #asyncio.get_event_loop().run_until_complete(self.ws.close())
            print("Websocket closed")
        self.ws = None
        self.in_buffer = None
        self.out_buffer = None

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
    parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
    parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
    parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
    parser.add_argument('--server', '-s', default="ws://localhost:8000",
                        help='Server URL (default: ws://localhost:8000)')
    parser.add_argument("--use_ice_servers", action="store_true", help="Use public STUN servers")

    args = parser.parse_args()

    # Audio settings
    STREAMING_SAMPLE_RATE = 48000
    SAMPLE_RATE = 16000
    BLOCK_SIZE = 2000
    CHANNELS = 1

    st.title("hertz-dev webrtc demo!")
    st.markdown("""
    Welcome to the audio processing interface! Here you can talk live with hertz.
    - Process audio in real-time through your microphone
    - Adjust various temperature parameters for inference
    - Test your microphone with sound check mode
    - Enable/disable echo cancellation and noise suppression

    To begin, click the START button below and allow microphone access.
    """)

    audio_client = st.session_state.get("audio_client")
    if audio_client is None:
        audio_client = AudioClient(
            server_url=args.server,
            token_temp=args.token_temp,
            categorical_temp=args.categorical_temp,
            gaussian_temp=args.gaussian_temp
        )
        st.session_state.audio_client = audio_client

    with st.sidebar:
        st.markdown("## Inference Settings")
        token_temp_default = args.token_temp if args.token_temp is not None else 0.8
        token_temp = st.slider("Token Temperature", 0.05, 2.0, token_temp_default, step=0.05)
        categorical_temp_default = args.categorical_temp if args.categorical_temp is not None else 0.4
        categorical_temp = st.slider("Categorical Temperature", 0.01, 1.0, categorical_temp_default, step=0.01)
        gaussian_temp_default = args.gaussian_temp if args.gaussian_temp is not None else 0.1
        gaussian_temp = st.slider("Gaussian Temperature", 0.01, 1.0, gaussian_temp_default, step=0.01)
        if st.button("Set Temperatures"):
            response_message = audio_client.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
            st.write(response_message)

        st.markdown("## Microphone Settings")
        audio_client.sound_check = st.toggle("Sound Check (Echo)", value=False)
        echo_cancellation = st.toggle("Echo Cancellation*‡", value=False)
        noise_suppression = st.toggle("Noise Suppression*", value=False)
        st.markdown(r"\* *Restart stream to take effect*")
        st.markdown("‡ *May cause audio to cut out*")

    # Use a free STUN server from Google if --use_ice_servers is given
    # (found in get_ice_servers() at https://github.com/whitphx/streamlit-webrtc/blob/main/sample_utils/turn.py)
    rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]} if args.use_ice_servers else None
    audio_config = {"echoCancellation": echo_cancellation, "noiseSuppression": noise_suppression}
    webrtc_streamer(
        key="streamer",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=rtc_configuration,
        media_stream_constraints={"audio": audio_config, "video": False},
        queued_audio_frames_callback=audio_client.queued_audio_frames_callback,
        on_audio_ended=audio_client.stop,
        async_processing=True,
    )
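The buffering constants above line up with the rest of the commit: the browser streams at 48 kHz, the model consumes 16 kHz, and each 2000-sample block is 125 ms, i.e. 8 blocks per second, matching the "* 8" chunk counts in the notebook and server. A quick check of that arithmetic (a sketch, not part of the app):

# Framing arithmetic behind STREAMING_SAMPLE_RATE / SAMPLE_RATE / BLOCK_SIZE.
STREAMING_SAMPLE_RATE = 48000   # what the browser delivers
SAMPLE_RATE = 16000             # what the model consumes
BLOCK_SIZE = 2000               # model-rate samples per request

block_seconds = BLOCK_SIZE / SAMPLE_RATE
blocks_per_second = SAMPLE_RATE // BLOCK_SIZE
browser_samples_per_block = int(BLOCK_SIZE * STREAMING_SAMPLE_RATE / SAMPLE_RATE)

assert block_seconds == 0.125             # 125 ms of audio per websocket round trip
assert blocks_per_second == 8             # matches prompt_len_seconds * 8 and replay_seconds * 8
assert browser_samples_per_block == 6000  # 48 kHz samples buffered before each downsampled block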
inference_server.py
ADDED
@@ -0,0 +1,172 @@
import time
import numpy as np
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
import base64
import uvicorn
import traceback
import argparse

import torch as T
import torch.nn.functional as F
import torchaudio

import os
from typing import Optional

from utils import print_colored
from model import get_hertz_dev_config


parser = argparse.ArgumentParser()

parser.add_argument('--prompt_path', type=str, default='./prompts/bob_mono.wav', help="""
    We highly recommend making your own prompt based on a conversation between you and another person.
    bob_mono.wav seems to work better for two-channel than bob_stereo.wav.
    """)
args = parser.parse_args()


device = 'cuda' if T.cuda.is_available() else 'cpu'
print_colored(f"Using device: {device}", "grey")

model_config = get_hertz_dev_config(is_split=True)

model = model_config()
model = model.eval().bfloat16().to(device)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Hyperparams or something.
SAMPLE_RATE = 16000  # Don't change this
TEMPS = (0.8, (0.4, 0.1))  # You can change this, but there's also an endpoint for it.
REPLAY_SECONDS = 3  # What the user hears as context.

class AudioProcessor:
    def __init__(self, model, prompt_path):
        self.model = model
        self.prompt_path = prompt_path
        self.initialize_state(prompt_path)

    def initialize_state(self, prompt_path):
        loaded_audio, sr = torchaudio.load(prompt_path)
        self.replay_seconds = REPLAY_SECONDS

        if sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
            loaded_audio = resampler(loaded_audio)

        if loaded_audio.shape[0] == 1:
            loaded_audio = loaded_audio.repeat(2, 1)

        audio_length = loaded_audio.shape[-1]
        num_chunks = audio_length // 2000
        loaded_audio = loaded_audio[..., :num_chunks * 2000]

        self.loaded_audio = loaded_audio.to(device)

        with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
            self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024)
            self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
        self.prompt_buffer = None
        self.prompt_position = 0
        self.chunks_until_live = int(self.replay_seconds * 8)
        self.initialize_prompt_buffer()
        print_colored("AudioProcessor state initialized", "green")

    def initialize_prompt_buffer(self):
        self.recorded_audio = self.loaded_audio
        prompt_audio = self.loaded_audio.reshape(1, 2, -1)
        prompt_audio = prompt_audio[:, :, -(16000*self.replay_seconds):].cpu().numpy()
        prompt_audio_mono = prompt_audio.mean(axis=1)
        self.prompt_buffer = np.array_split(prompt_audio_mono[0], int(self.replay_seconds * 8))
        print_colored(f"Initialized prompt buffer with {len(self.prompt_buffer)} chunks", "grey")

    async def process_audio(self, audio_data):
        if self.chunks_until_live > 0:
            print_colored(f"Serving from prompt buffer, {self.chunks_until_live} chunks left", "grey")
            chunk = self.prompt_buffer[int(self.replay_seconds * 8) - self.chunks_until_live]
            self.chunks_until_live -= 1

            if self.chunks_until_live == 0:
                print_colored("Switching to live processing mode", "green")

            time.sleep(0.05)
            return chunk

        audio_tensor = T.from_numpy(audio_data).to(device)
        audio_tensor = audio_tensor.reshape(1, 1, -1)
        audio_tensor = T.cat([audio_tensor, self.next_model_audio], dim=1)

        with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
            curr_model_audio = self.model.next_audio_from_audio(
                audio_tensor,
                temps=TEMPS
            )
        print(f"Recorded audio shape {self.recorded_audio.shape}, audio tensor shape {audio_tensor.shape}")
        self.recorded_audio = T.cat([self.recorded_audio.cpu(), audio_tensor.squeeze(0).cpu()], dim=-1)

        self.next_model_audio = curr_model_audio

        return curr_model_audio.float().cpu().numpy()

    def cleanup(self):
        print_colored("Cleaning up audio processor...", "blue")
        os.makedirs('audio_recordings', exist_ok=True)
        torchaudio.save(f'audio_recordings/{time.strftime("%d-%H-%M")}.wav', self.recorded_audio.cpu(), SAMPLE_RATE)
        self.model.deinit_cache()
        self.initialize_state(self.prompt_path)
        print_colored("Audio processor cleanup complete", "green")

@app.post("/set_temperature")
async def set_temperature(token_temp: Optional[float] = None, categorical_temp: Optional[float] = None, gaussian_temp: Optional[float] = None):
    try:
        global TEMPS
        TEMPS = (token_temp, (categorical_temp, gaussian_temp))

        print_colored(f"Temperature updated to: {TEMPS}", "green")
        return {"message": f"Temperature updated to: {TEMPS}", "status": "success"}
    except Exception as e:
        print_colored(f"Error setting temperature: {str(e)}", "red")
        return {"message": f"Error setting temperature: {str(e)}", "status": "error"}

@app.websocket("/audio")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            audio_data = np.frombuffer(
                base64.b64decode(data.split(",")[1]),
                dtype=np.int16
            )
            audio_data = audio_data.astype(np.float32) / 32767.0
            processed_audio = await audio_processor.process_audio(audio_data)
            processed_audio = (processed_audio * 32767).astype(np.int16)

            processed_data = base64.b64encode(processed_audio.tobytes()).decode('utf-8')
            await websocket.send_text(f"data:audio/raw;base64,{processed_data}")

    except Exception as e:
        print_colored(f"WebSocket error: {e}", "red")
        print_colored(f"Full traceback:\n{traceback.format_exc()}", "red")
    finally:
        audio_processor.cleanup()
        await websocket.close()


audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
    print("Server started")
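The /set_temperature endpoint can also be driven by hand against a running server; this is the same call both clients make. A small sketch (host and values are assumptions; note that the endpoint replaces TEMPS wholesale, so pass all three parameters together):

# Sketch: adjust sampling temperatures on a running inference_server.py instance.
import requests

resp = requests.post(
    "http://localhost:8000/set_temperature",
    params={"token_temp": 0.8, "categorical_temp": 0.4, "gaussian_temp": 0.1},
)
print(resp.json()["message"])  # e.g. "Temperature updated to: (0.8, (0.4, 0.1))"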
ioblocks.py
ADDED
@@ -0,0 +1,333 @@
from __future__ import annotations
from functools import partial
from contextlib import nullcontext
from typing import List, Tuple
from math import ceil

import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor, int32
from torch.amp import autocast

from einops import rearrange, pack, unpack


from utils import si_module, exists, default, maybe


@si_module
class GaussianMixtureIOLayer(nn.Module):
    class Config:
        latent_dim: int
        dim: int
        num_components: int

    def __init__(self, c: Config):
        super().__init__()
        self.latent_dim = c.latent_dim
        self.num_components = c.num_components
        self.input_projection = nn.Linear(c.latent_dim, c.dim)

        self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim)
        self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim)
        self.fc_weight = nn.Linear(c.dim, c.num_components)

    def _square_plus(self, x):
        return (x + T.sqrt(T.square(x) + 4)) / 2

    def input(self, sampled_latents: T.Tensor) -> T.Tensor:
        """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)"""
        hidden = self.input_projection(sampled_latents)
        return hidden

    def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
        """float tensor (B, L, D) -> Tuple of locs, scales, and weights"""
        batch_size, seq_len, _ = h.shape

        locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim)
        scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim)
        weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components)

        return (locs, scales, weights)

    def loss(self, data, dataHat):
        locs, scales, weights = dataHat
        log_probs = -0.5 * T.sum(
            (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) +
            2 * T.log(scales) +
            T.log(T.tensor(2 * T.pi)),
            dim=-1
        )
        log_weights = F.log_softmax(weights, dim=-1)
        return -T.logsumexp(log_weights + log_probs, dim=-1)


    def temp_sample(self, orig_pdist, temp):
        locs, scales, weights = orig_pdist
        if temp is None:
            component_samples = locs + scales * T.randn_like(scales)
            mixture_samples = F.gumbel_softmax(weights, hard=True)
            sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
        elif isinstance(temp, tuple):
            assert len(temp) == 2
            categorical_temp, gaussian_temp = temp
            component_samples = locs + scales * gaussian_temp * T.randn_like(scales)
            mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True)
            sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
        else:
            component_samples = locs + scales * temp * T.randn_like(scales)
            mixture_samples = F.gumbel_softmax(weights / temp, hard=True)
            sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
        return sampled


class GPTOutput(nn.Module):
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        return self.output(x)


# helper functions

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def first(l):
    return l[0]

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

def get_code_utilization(codes, codebook_size, get_global=False):
    if get_global and dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    if world_size > 1:
        gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)]
        dist.all_gather(gathered_tokens, codes)
        gathered_tokens = T.cat(gathered_tokens, dim=0)
    else:
        gathered_tokens = codes
    unique_tokens = len(T.unique(gathered_tokens))
    code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size)
    return code_utilization

# tensor helpers

def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()

# main class
# lucidrains fsq
@si_module
class FSQ(nn.Module):
    @property
    def needs_float32_params(self):
        return True

    class Config:
        levels: List[int]
        dim: int | None = None
        num_codebooks: int = 1
        keep_num_codebooks_dim: bool | None = None
        scale: float | None = None
        allowed_dtypes: Tuple[str, ...] = ('float32', 'float64')
        channel_first: bool = False
        projection_has_bias: bool = True
        return_indices: bool = True
        force_quantization_f32: bool = True
        use_rms: bool = False

    def __init__(self, c: Config):
        super().__init__()
        _levels = T.tensor(c.levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)

        _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32)
|
159 |
+
self.register_buffer("_basis", _basis, persistent = False)
|
160 |
+
|
161 |
+
self.scale = c.scale
|
162 |
+
|
163 |
+
codebook_dim = len(c.levels)
|
164 |
+
self.codebook_dim = codebook_dim
|
165 |
+
|
166 |
+
effective_codebook_dim = codebook_dim * c.num_codebooks
|
167 |
+
self.num_codebooks = c.num_codebooks
|
168 |
+
|
169 |
+
self.allowed_dtypes = []
|
170 |
+
for dtype_str in c.allowed_dtypes:
|
171 |
+
if hasattr(T, dtype_str):
|
172 |
+
self.allowed_dtypes.append(getattr(T, dtype_str))
|
173 |
+
else:
|
174 |
+
raise ValueError(f"Invalid dtype string: {dtype_str}")
|
175 |
+
|
176 |
+
self.effective_codebook_dim = effective_codebook_dim
|
177 |
+
|
178 |
+
keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
|
179 |
+
assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim)
|
180 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
181 |
+
|
182 |
+
self.dim = default(c.dim, len(_levels) * c.num_codebooks)
|
183 |
+
|
184 |
+
self.channel_first = c.channel_first
|
185 |
+
|
186 |
+
has_projections = self.dim != effective_codebook_dim
|
187 |
+
self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
|
188 |
+
self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
|
189 |
+
|
190 |
+
self.has_projections = has_projections
|
191 |
+
|
192 |
+
self.return_indices = c.return_indices
|
193 |
+
if c.return_indices:
|
194 |
+
self.codebook_size = self._levels.prod().item()
|
195 |
+
implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
|
196 |
+
self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
|
197 |
+
|
198 |
+
self.allowed_dtypes = c.allowed_dtypes
|
199 |
+
self.force_quantization_f32 = c.force_quantization_f32
|
200 |
+
|
201 |
+
self.latent_loss = None
|
202 |
+
|
203 |
+
def latent_metric(self, codes, get_global=False):
|
204 |
+
return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)}
|
205 |
+
|
206 |
+
def repr_from_latent(self, latent):
|
207 |
+
return self.indices_to_codes(latent)
|
208 |
+
|
209 |
+
def bound(self, z, eps: float = 1e-3):
|
210 |
+
""" Bound `z`, an array of shape (..., d). """
|
211 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
212 |
+
offset = T.where(self._levels % 2 == 0, 0.5, 0.0)
|
213 |
+
shift = (offset / half_l).atanh()
|
214 |
+
return (z + shift).tanh() * half_l - offset
|
215 |
+
|
216 |
+
def quantize(self, z):
|
217 |
+
""" Quantizes z, returns quantized zhat, same shape as z. """
|
218 |
+
quantized = round_ste(self.bound(z))
|
219 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
220 |
+
return quantized / half_width
|
221 |
+
|
222 |
+
def _scale_and_shift(self, zhat_normalized):
|
223 |
+
half_width = self._levels // 2
|
224 |
+
return (zhat_normalized * half_width) + half_width
|
225 |
+
|
226 |
+
def _scale_and_shift_inverse(self, zhat):
|
227 |
+
half_width = self._levels // 2
|
228 |
+
return (zhat - half_width) / half_width
|
229 |
+
|
230 |
+
def _indices_to_codes(self, indices):
|
231 |
+
level_indices = self.indices_to_level_indices(indices)
|
232 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
233 |
+
return codes
|
234 |
+
|
235 |
+
def codes_to_indices(self, zhat):
|
236 |
+
""" Converts a `code` to an index in the codebook. """
|
237 |
+
assert zhat.shape[-1] == self.codebook_dim
|
238 |
+
zhat = self._scale_and_shift(zhat)
|
239 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
240 |
+
|
241 |
+
def indices_to_level_indices(self, indices):
|
242 |
+
""" Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
|
243 |
+
indices = rearrange(indices, '... -> ... 1')
|
244 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
245 |
+
return codes_non_centered
|
246 |
+
|
247 |
+
def indices_to_codes(self, indices):
|
248 |
+
""" Inverse of `codes_to_indices`. """
|
249 |
+
assert exists(indices)
|
250 |
+
|
251 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
252 |
+
|
253 |
+
codes = self._indices_to_codes(indices)
|
254 |
+
|
255 |
+
if self.keep_num_codebooks_dim:
|
256 |
+
codes = rearrange(codes, '... c d -> ... (c d)')
|
257 |
+
|
258 |
+
codes = self.project_out(codes)
|
259 |
+
|
260 |
+
if is_img_or_video or self.channel_first:
|
261 |
+
codes = rearrange(codes, 'b ... d -> b d ...')
|
262 |
+
|
263 |
+
return codes
|
264 |
+
|
265 |
+
# @autocast(device_type='cuda', enabled = False)
|
266 |
+
def forward(self, z, return_codes=False):
|
267 |
+
"""
|
268 |
+
einstein notation
|
269 |
+
b - batch
|
270 |
+
n - sequence (or flattened spatial dimensions)
|
271 |
+
d - feature dimension
|
272 |
+
c - number of codebook dim
|
273 |
+
"""
|
274 |
+
|
275 |
+
is_img_or_video = z.ndim >= 4
|
276 |
+
need_move_channel_last = is_img_or_video or self.channel_first
|
277 |
+
|
278 |
+
# standardize image or video into (batch, seq, dimension)
|
279 |
+
|
280 |
+
if need_move_channel_last:
|
281 |
+
z = rearrange(z, 'b d ... -> b ... d')
|
282 |
+
z, ps = pack_one(z, 'b * d')
|
283 |
+
|
284 |
+
assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
|
285 |
+
|
286 |
+
z = self.project_in(z)
|
287 |
+
|
288 |
+
z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
|
289 |
+
|
290 |
+
# whether to force quantization step to be full precision or not
|
291 |
+
|
292 |
+
force_f32 = self.force_quantization_f32
|
293 |
+
quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext
|
294 |
+
|
295 |
+
with quantization_context():
|
296 |
+
orig_dtype = z.dtype
|
297 |
+
|
298 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
299 |
+
z = z.float()
|
300 |
+
|
301 |
+
codes = self.quantize(z)
|
302 |
+
|
303 |
+
# returning indices could be optional
|
304 |
+
|
305 |
+
indices = None
|
306 |
+
|
307 |
+
if self.return_indices:
|
308 |
+
indices = self.codes_to_indices(codes)
|
309 |
+
|
310 |
+
codes = rearrange(codes, 'b n c d -> b n (c d)')
|
311 |
+
|
312 |
+
codes = codes.type(orig_dtype)
|
313 |
+
|
314 |
+
# project out
|
315 |
+
if return_codes:
|
316 |
+
return codes, indices
|
317 |
+
|
318 |
+
out = self.project_out(codes)
|
319 |
+
|
320 |
+
# reconstitute image or video dimensions
|
321 |
+
|
322 |
+
if need_move_channel_last:
|
323 |
+
out = unpack_one(out, ps, 'b * d')
|
324 |
+
out = rearrange(out, 'b ... d -> b d ...')
|
325 |
+
|
326 |
+
indices = maybe(unpack_one)(indices, ps, 'b * c')
|
327 |
+
|
328 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
329 |
+
indices = maybe(rearrange)(indices, '... 1 -> ...')
|
330 |
+
|
331 |
+
# return quantized output and indices
|
332 |
+
|
333 |
+
return out, indices
|
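For orientation, a minimal usage sketch of the FSQ block above (not part of the commit). It assumes the si_module pattern from utils, where calling a Config instance builds the module; the level/dim values are illustrative only.

import torch as T
from ioblocks import FSQ

fsq = FSQ.Config(levels=[8, 8, 8, 8, 8], dim=2048)()   # codebook of 8**5 = 32768 entries
x = T.randn(1, 16, 2048)                                # (batch, seq, dim)
out, indices = fsq(x)                                   # quantized output and integer codes
codes = fsq.indices_to_codes(indices)                   # inverse mapping back to quantized vectors
assert codes.shape == out.shape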
model.py
ADDED
@@ -0,0 +1,443 @@
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch as T
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from ioblocks import GaussianMixtureIOLayer, FSQ
|
8 |
+
|
9 |
+
from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
|
10 |
+
from tokenizer import make_tokenizer
|
11 |
+
|
12 |
+
|
13 |
+
from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
|
14 |
+
from utils import load_ckpt
|
15 |
+
|
16 |
+
|
17 |
+
@si_module
|
18 |
+
class LatentQuantizer(nn.Module):
|
19 |
+
class Config:
|
20 |
+
compressor_config: Optional[FSQ.Config] = None
|
21 |
+
|
22 |
+
dim: Optional[int] = None
|
23 |
+
ff_dim: Optional[int] = None
|
24 |
+
input_dim: int = None
|
25 |
+
|
26 |
+
from_pretrained: Optional[Tuple[str, str]] = None
|
27 |
+
|
28 |
+
def __init__(self, c: Config):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
if exists(c.from_pretrained):
|
32 |
+
checkpoint = load_ckpt(*c.from_pretrained)
|
33 |
+
else:
|
34 |
+
assert exists(c.compressor_config), f'hmm {c}'
|
35 |
+
|
36 |
+
self.compressor = c.compressor_config()
|
37 |
+
self.ffnn = FFNN(c.dim, c.ff_dim)
|
38 |
+
self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()
|
39 |
+
|
40 |
+
if exists(c.from_pretrained):
|
41 |
+
self.load_state_dict(checkpoint)
|
42 |
+
|
43 |
+
@T.no_grad()
|
44 |
+
def forward(self, x, return_latent=False, known_latent=None):
|
45 |
+
"""
|
46 |
+
x: (B, S, D)
|
47 |
+
"""
|
48 |
+
if exists(known_latent):
|
49 |
+
return self.compressor.indices_to_codes(known_latent)
|
50 |
+
|
51 |
+
x = self.input(x)
|
52 |
+
x = self.ffnn(x)
|
53 |
+
x, tokens = self.compressor(x)
|
54 |
+
|
55 |
+
if return_latent:
|
56 |
+
return x, tokens
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
@si_module
|
61 |
+
class TransformerVAE(nn.Module):
|
62 |
+
class Config:
|
63 |
+
io_config: Optional[GaussianMixtureIOLayer.Config] = None
|
64 |
+
stack_config: Optional[Stack.Config] = None
|
65 |
+
quantizer_config: Optional[LatentQuantizer.Config] = None
|
66 |
+
|
67 |
+
plex_layer: int = None
|
68 |
+
plex_roll: int = 1
|
69 |
+
split: bool = True
|
70 |
+
|
71 |
+
from_pretrained: Optional[Tuple[str, str]] = None
|
72 |
+
|
73 |
+
def __init__(self, c: Config):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
if exists(c.from_pretrained):
|
77 |
+
checkpoint = load_ckpt(*c.from_pretrained)
|
78 |
+
else:
|
79 |
+
assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'
|
80 |
+
|
81 |
+
self.io = c.io_config()
|
82 |
+
self.stack = c.stack_config()
|
83 |
+
|
84 |
+
self.plex_layer = c.stack_config.layers//2
|
85 |
+
self.plex_roll = c.plex_roll
|
86 |
+
self.plex_dim = c.quantizer_config.dim
|
87 |
+
|
88 |
+
assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
|
89 |
+
self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
|
90 |
+
self.out_norm = Norm(c.stack_config.dim)
|
91 |
+
|
92 |
+
if c.split:
|
93 |
+
self.io2 = c.io_config()
|
94 |
+
self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
|
95 |
+
|
96 |
+
self.io2.fc_loc = None
|
97 |
+
self.io2.fc_scale = None
|
98 |
+
self.io2.fc_weight = None
|
99 |
+
|
100 |
+
kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
|
101 |
+
head_dim = c.stack_config.dim // c.stack_config.n_head
|
102 |
+
self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
|
103 |
+
cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
|
104 |
+
self.cache_shape = cache_shape
|
105 |
+
self.cache = [None] * self.cache_num_layers
|
106 |
+
|
107 |
+
if exists(c.from_pretrained):
|
108 |
+
result = self.load_state_dict(checkpoint, strict=False)
|
109 |
+
print0_colored(result, 'yellow')
|
110 |
+
|
111 |
+
self.quantizer = c.quantizer_config().eval()
|
112 |
+
self.quantizer.requires_grad_(False)
|
113 |
+
|
114 |
+
@T.no_grad()
|
115 |
+
def quantize(self, x):
|
116 |
+
if self.c.split:
|
117 |
+
x1, x2 = x.chunk(2, dim=-1)
|
118 |
+
with T.autocast(device_type='cuda', dtype=T.bfloat16):
|
119 |
+
quantized1 = self.quantizer(x1)
|
120 |
+
quantized2 = self.quantizer(x2)
|
121 |
+
return quantized1, quantized2
|
122 |
+
else:
|
123 |
+
with T.autocast(device_type='cuda', dtype=T.bfloat16):
|
124 |
+
return self.quantizer(x)
|
125 |
+
|
126 |
+
@T.no_grad()
|
127 |
+
def untokenize(self, token_data):
|
128 |
+
return self.quantizer(None, known_latent=token_data)
|
129 |
+
|
130 |
+
def init_cache(self, bsize, device, dtype, length:int=None):
|
131 |
+
cache_shape = self.cache_shape.copy()
|
132 |
+
cache_shape[1] = length or cache_shape[1]
|
133 |
+
self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
|
134 |
+
|
135 |
+
def deinit_cache(self):
|
136 |
+
self.cache = [None] * self.cache_num_layers
|
137 |
+
|
138 |
+
@T.no_grad()
|
139 |
+
def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
|
140 |
+
if self.c.split:
|
141 |
+
x1, x2 = data.chunk(2, dim=-1)
|
142 |
+
x = self.io.input(x1) + self.io2.input(x2)
|
143 |
+
else:
|
144 |
+
x = self.io.input(data)
|
145 |
+
|
146 |
+
cache_idx = 0
|
147 |
+
for l, layer in enumerate(self.stack.layers):
|
148 |
+
if l == self.plex_layer:
|
149 |
+
if self.c.split:
|
150 |
+
plex1, plex2 = self.quantize(data)
|
151 |
+
plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
|
152 |
+
plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
|
153 |
+
if exists(next_tokens):
|
154 |
+
plex1[:, -1:] = self.untokenize(next_tokens[0])
|
155 |
+
plex2[:, -1:] = self.untokenize(next_tokens[1])
|
156 |
+
x1 = x + self.plex_projection(plex1)
|
157 |
+
x2 = x + self.plex_projection2(plex2)
|
158 |
+
else:
|
159 |
+
plex = self.quantize(data)
|
160 |
+
plex = T.roll(plex, -self.c.plex_roll, dims=1)
|
161 |
+
if exists(next_tokens):
|
162 |
+
plex[:, -1:] = self.untokenize(next_tokens)
|
163 |
+
x = x + self.plex_projection(plex)
|
164 |
+
|
165 |
+
if l < self.plex_layer:
|
166 |
+
x = layer(x, kv=self.cache[l])
|
167 |
+
else:
|
168 |
+
if self.c.split:
|
169 |
+
x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
|
170 |
+
cache_idx += 1
|
171 |
+
x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
|
172 |
+
cache_idx += 1
|
173 |
+
else:
|
174 |
+
x = layer(x, kv=self.cache[l])
|
175 |
+
|
176 |
+
with T.autocast(device_type='cuda', dtype=T.bfloat16):
|
177 |
+
if self.c.split:
|
178 |
+
x1, x2 = self.out_norm(x1), self.out_norm(x2)
|
179 |
+
out1, out2 = self.io.output(x1), self.io.output(x2)
|
180 |
+
else:
|
181 |
+
x = self.out_norm(x)
|
182 |
+
out = self.io.output(x)
|
183 |
+
|
184 |
+
if isnt(temps):
|
185 |
+
if self.c.split:
|
186 |
+
return out1, out2
|
187 |
+
else:
|
188 |
+
return out
|
189 |
+
else:
|
190 |
+
if self.c.split:
|
191 |
+
next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
|
192 |
+
next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
|
193 |
+
next_data = T.cat([next_data1, next_data2], dim=-1)
|
194 |
+
return next_data
|
195 |
+
else:
|
196 |
+
next_data = self.io.temp_sample(out, temps)[:, -1:, :]
|
197 |
+
return next_data
|
198 |
+
|
199 |
+
@si_module
|
200 |
+
class HertzDevModel(nn.Module):
|
201 |
+
class Config:
|
202 |
+
dim: int
|
203 |
+
vocab_size: int
|
204 |
+
stack_config: Optional[Stack.Config] = None
|
205 |
+
latent_size: int = 32
|
206 |
+
|
207 |
+
split: bool = True
|
208 |
+
|
209 |
+
quantizer_config: Optional[LatentQuantizer.Config] = None
|
210 |
+
resynthesizer_config: Optional[TransformerVAE.Config] = None
|
211 |
+
|
212 |
+
from_pretrained: Optional[Tuple[str, str]] = None
|
213 |
+
|
214 |
+
def __init__(self, c: Config):
|
215 |
+
super().__init__()
|
216 |
+
|
217 |
+
if exists(c.from_pretrained):
|
218 |
+
checkpoint = load_ckpt(*c.from_pretrained)
|
219 |
+
else:
|
220 |
+
assert (exists(c.stack_config)), f'hmm {c}'
|
221 |
+
|
222 |
+
self.input = nn.Linear(c.latent_size, c.dim)
|
223 |
+
if self.c.split:
|
224 |
+
self.input2 = nn.Linear(c.latent_size, c.dim)
|
225 |
+
|
226 |
+
self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)
|
227 |
+
|
228 |
+
self.layers = nn.ModuleList([
|
229 |
+
PerfBlock(
|
230 |
+
dim=c.stack_config.dim,
|
231 |
+
layer_id=l,
|
232 |
+
n_head=c.stack_config.n_head,
|
233 |
+
kv_heads=c.stack_config.kv_heads,
|
234 |
+
ff_dim=c.stack_config.ff_dim,
|
235 |
+
eps=c.stack_config.eps,
|
236 |
+
shape_rotator=self.shape_rotator,
|
237 |
+
) for l in range(c.stack_config.layers)
|
238 |
+
])
|
239 |
+
|
240 |
+
self.output = GPTOutput(c.dim, c.vocab_size)
|
241 |
+
if self.c.split:
|
242 |
+
self.output2 = GPTOutput(c.dim, c.vocab_size)
|
243 |
+
|
244 |
+
self.cache = [None] * c.stack_config.layers
|
245 |
+
self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
|
246 |
+
self.head_dim = c.stack_config.dim // c.stack_config.n_head
|
247 |
+
|
248 |
+
if exists(c.from_pretrained):
|
249 |
+
result = self.load_state_dict(checkpoint, strict=False)
|
250 |
+
print0_colored(result, 'yellow')
|
251 |
+
|
252 |
+
self.resynthesizer = c.resynthesizer_config().eval()
|
253 |
+
self.resynthesizer.requires_grad_(False)
|
254 |
+
|
255 |
+
self.audio_tokenizer = make_tokenizer(device='cpu')
|
256 |
+
self.audio_cache = None
|
257 |
+
self.audio_latent_cache = None
|
258 |
+
self.use_audio_cache = False
|
259 |
+
|
260 |
+
@T.no_grad()
|
261 |
+
def tokenize(self, audio_data):
|
262 |
+
orig_audio_shape = audio_data.shape
|
263 |
+
if exists(self.audio_cache):
|
264 |
+
audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
|
265 |
+
self.audio_cache = audio_data[..., -(6*16_000):]
|
266 |
+
elif self.use_audio_cache:
|
267 |
+
self.audio_cache = audio_data[..., -(6*16_000):]
|
268 |
+
|
269 |
+
if audio_data.shape[1] == 2:
|
270 |
+
enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
|
271 |
+
enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
|
272 |
+
return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
|
273 |
+
else:
|
274 |
+
return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]
|
275 |
+
|
276 |
+
@T.no_grad()
|
277 |
+
def untokenize(self, token_data):
|
278 |
+
if exists(self.audio_latent_cache):
|
279 |
+
token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
|
280 |
+
self.audio_latent_cache = token_data[:, -(6*8):]
|
281 |
+
elif self.use_audio_cache:
|
282 |
+
self.audio_latent_cache = token_data[:, -(6*8):]
|
283 |
+
|
284 |
+
if token_data.shape[-1] == 2*self.c.latent_size:
|
285 |
+
dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size])
|
286 |
+
dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:])
|
287 |
+
return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
|
288 |
+
else:
|
289 |
+
return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]
|
290 |
+
|
291 |
+
def init_cache(self, bsize, device, dtype, length:int=None):
|
292 |
+
cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
|
293 |
+
self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
|
294 |
+
self.resynthesizer.init_cache(bsize, device, dtype, length)
|
295 |
+
self.use_audio_cache = True
|
296 |
+
|
297 |
+
def deinit_cache(self):
|
298 |
+
self.cache = [None] * len(self.layers)
|
299 |
+
self.resynthesizer.deinit_cache()
|
300 |
+
self.audio_cache = None
|
301 |
+
self.audio_latent_cache = None
|
302 |
+
self.use_audio_cache = False
|
303 |
+
|
304 |
+
@T.no_grad()
|
305 |
+
def forward(self, data):
|
306 |
+
if self.c.split:
|
307 |
+
x1, x2 = data.chunk(2, dim=-1)
|
308 |
+
x = self.input(x1) + self.input2(x2)
|
309 |
+
else:
|
310 |
+
x = self.input(data)
|
311 |
+
|
312 |
+
for l, layer in enumerate(self.layers):
|
313 |
+
x = layer(x, kv=self.cache[l])
|
314 |
+
|
315 |
+
if self.c.split:
|
316 |
+
return self.output(x), self.output2(x)
|
317 |
+
else:
|
318 |
+
return self.output(x)
|
319 |
+
|
320 |
+
@T.no_grad()
|
321 |
+
def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
|
322 |
+
latents_in = self.tokenize(audio_data)
|
323 |
+
next_latents = self.next_latent(latents_in, temps)
|
324 |
+
next_model_latent = next_latents[..., self.c.latent_size:]
|
325 |
+
audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
|
326 |
+
return audio_decoded
|
327 |
+
|
328 |
+
|
329 |
+
@T.no_grad()
|
330 |
+
def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
|
331 |
+
|
332 |
+
if self.c.split:
|
333 |
+
logits1, logits2 = self.forward(model_input)
|
334 |
+
next_logits1 = logits1[:, -1]
|
335 |
+
next_logits2 = logits2[:, -1]
|
336 |
+
next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
|
337 |
+
next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)
|
338 |
+
|
339 |
+
next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
|
340 |
+
else:
|
341 |
+
logits = self.forward(model_input)
|
342 |
+
next_logits = logits[:, -1]
|
343 |
+
next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)
|
344 |
+
|
345 |
+
next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])
|
346 |
+
|
347 |
+
return next_input
|
348 |
+
|
349 |
+
|
350 |
+
@T.no_grad()
|
351 |
+
def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
|
352 |
+
"""
|
353 |
+
only accepts latent-space data.
|
354 |
+
"""
|
355 |
+
if use_cache:
|
356 |
+
self.init_cache(data.shape[0], data.device, T.bfloat16)
|
357 |
+
|
358 |
+
next_input = generated = data
|
359 |
+
|
360 |
+
target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)
|
361 |
+
|
362 |
+
for _ in tqdm0(range(data.shape[1], target_len)):
|
363 |
+
model_input = next_input if use_cache else generated
|
364 |
+
|
365 |
+
next_input = self.next_latent(model_input, temps)
|
366 |
+
|
367 |
+
generated = T.cat([generated, next_input], dim=1)
|
368 |
+
|
369 |
+
if use_cache:
|
370 |
+
self.deinit_cache()
|
371 |
+
return generated
|
372 |
+
|
373 |
+
|
374 |
+
|
375 |
+
def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
|
376 |
+
if is_split:
|
377 |
+
checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
|
378 |
+
elif not use_pure_audio_ablation:
|
379 |
+
checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
|
380 |
+
else:
|
381 |
+
checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')]
|
382 |
+
|
383 |
+
quantizer_config=LatentQuantizer.Config(
|
384 |
+
from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
|
385 |
+
compressor_config=FSQ.Config(
|
386 |
+
levels=[8,8,8,8,8],
|
387 |
+
dim=2048,
|
388 |
+
num_codebooks=1,
|
389 |
+
keep_num_codebooks_dim=None,
|
390 |
+
scale=None,
|
391 |
+
allowed_dtypes=['float32', 'float64', 'bfloat16'],
|
392 |
+
channel_first=False,
|
393 |
+
projection_has_bias=True,
|
394 |
+
return_indices=True,
|
395 |
+
force_quantization_f32=True,
|
396 |
+
use_rms=False
|
397 |
+
),
|
398 |
+
dim=2048,
|
399 |
+
ff_dim=8192,
|
400 |
+
input_dim=32
|
401 |
+
)
|
402 |
+
|
403 |
+
resynthesizer_config=TransformerVAE.Config(
|
404 |
+
io_config=GaussianMixtureIOLayer.Config(
|
405 |
+
latent_dim=32,
|
406 |
+
dim=4096,
|
407 |
+
num_components=8,
|
408 |
+
),
|
409 |
+
stack_config=Stack.Config(
|
410 |
+
layers=8,
|
411 |
+
dim=4096,
|
412 |
+
seq_len=8192,
|
413 |
+
n_head=16,
|
414 |
+
ff_dim=11008,
|
415 |
+
kv_heads=16,
|
416 |
+
eps=1e-5,
|
417 |
+
theta=10_000
|
418 |
+
),
|
419 |
+
quantizer_config=quantizer_config,
|
420 |
+
plex_layer=None,
|
421 |
+
plex_roll=1,
|
422 |
+
split=is_split,
|
423 |
+
from_pretrained=checkpoints[0],
|
424 |
+
)
|
425 |
+
|
426 |
+
return HertzDevModel.Config(
|
427 |
+
dim=4096,
|
428 |
+
vocab_size=32_768,
|
429 |
+
stack_config=Stack.Config(
|
430 |
+
layers=32,
|
431 |
+
dim=4096,
|
432 |
+
seq_len=2048,
|
433 |
+
n_head=32,
|
434 |
+
ff_dim=None,
|
435 |
+
kv_heads=None,
|
436 |
+
eps=1e-5,
|
437 |
+
theta=10_000,
|
438 |
+
),
|
439 |
+
quantizer_config=quantizer_config,
|
440 |
+
resynthesizer_config=resynthesizer_config,
|
441 |
+
split=is_split,
|
442 |
+
from_pretrained=checkpoints[1],
|
443 |
+
)
|
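A rough, hypothetical driver for the config factory above (the real entry points are inference_client.py and inference.ipynb); checkpoint download, device handling, and audio chunking are simplified here.

import torch as T
from model import get_hertz_dev_config

config = get_hertz_dev_config(is_split=True)
model = config().eval().bfloat16().to('cuda')           # si_module: calling the Config builds the module
model.init_cache(bsize=1, device='cuda', dtype=T.bfloat16)

audio = T.randn(1, 2, 16_000 * 6, device='cuda', dtype=T.bfloat16)        # ~6 s of two-channel 16 kHz audio
next_chunk = model.next_audio_from_audio(audio, temps=(0.8, (0.5, 0.1)))  # next 2000 samples (125 ms)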
requirements.txt
ADDED
@@ -0,0 +1,15 @@
1 |
+
torch==2.5.1
|
2 |
+
torchaudio==2.5.1
|
3 |
+
einops==0.8.0
|
4 |
+
tqdm==4.66.6
|
5 |
+
ipython==8.18.1
|
6 |
+
numpy==1.26.3
|
7 |
+
soundfile==0.12.1
|
8 |
+
websockets==13.1
|
9 |
+
requests==2.32.3
|
10 |
+
sounddevice==0.5.1
|
11 |
+
matplotlib==3.9.2
|
12 |
+
fastapi==0.115.4
|
13 |
+
uvicorn==0.32.0
|
14 |
+
huggingface-hub[hf_transfer]==0.26.2
|
15 |
+
IProgress==0.4
|
requirements_webrtc.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
streamlit==1.33.0
|
2 |
+
streamlit-webrtc==0.47.9
|
tokenizer.py
ADDED
@@ -0,0 +1,581 @@
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Union, Tuple, Literal
|
4 |
+
|
5 |
+
import torch as T
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
8 |
+
|
9 |
+
from utils import load_ckpt
|
10 |
+
from utils.interp import print_colored
|
11 |
+
from utils import si_module, get_activation
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Adapted from https://github.com/facebookresearch/AudioDec
|
16 |
+
|
17 |
+
def Conv1d1x1(in_channels, out_channels, bias=True):
|
18 |
+
return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
|
19 |
+
|
20 |
+
|
21 |
+
class NonCausalConv1d(nn.Module):
|
22 |
+
"""1D noncausal convolution w/ 2-sides padding."""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
in_channels,
|
27 |
+
out_channels,
|
28 |
+
kernel_size,
|
29 |
+
stride=1,
|
30 |
+
padding=-1,
|
31 |
+
dilation=1,
|
32 |
+
groups=1,
|
33 |
+
bias=True):
|
34 |
+
super().__init__()
|
35 |
+
self.in_channels = in_channels
|
36 |
+
self.out_channels = out_channels
|
37 |
+
self.kernel_size = kernel_size
|
38 |
+
if padding < 0:
|
39 |
+
padding = (kernel_size - 1) // 2 * dilation
|
40 |
+
self.dilation = dilation
|
41 |
+
self.conv = nn.Conv1d(
|
42 |
+
in_channels=in_channels,
|
43 |
+
out_channels=out_channels,
|
44 |
+
kernel_size=kernel_size,
|
45 |
+
stride=stride,
|
46 |
+
padding=padding,
|
47 |
+
dilation=dilation,
|
48 |
+
groups=groups,
|
49 |
+
bias=bias,
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
x (Tensor): Float tensor variable with the shape (B, C, T).
|
56 |
+
Returns:
|
57 |
+
Tensor: Float tensor variable with the shape (B, C, T).
|
58 |
+
"""
|
59 |
+
x = self.conv(x)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class NonCausalConvTranspose1d(nn.Module):
|
64 |
+
"""1D noncausal transpose convolution."""
|
65 |
+
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
in_channels,
|
69 |
+
out_channels,
|
70 |
+
kernel_size,
|
71 |
+
stride,
|
72 |
+
padding=-1,
|
73 |
+
output_padding=-1,
|
74 |
+
groups=1,
|
75 |
+
bias=True,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
if padding < 0:
|
79 |
+
padding = (stride+1) // 2
|
80 |
+
if output_padding < 0:
|
81 |
+
output_padding = 1 if stride % 2 else 0
|
82 |
+
self.deconv = nn.ConvTranspose1d(
|
83 |
+
in_channels=in_channels,
|
84 |
+
out_channels=out_channels,
|
85 |
+
kernel_size=kernel_size,
|
86 |
+
stride=stride,
|
87 |
+
padding=padding,
|
88 |
+
output_padding=output_padding,
|
89 |
+
groups=groups,
|
90 |
+
bias=bias,
|
91 |
+
)
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
"""
|
95 |
+
Args:
|
96 |
+
x (Tensor): Float tensor variable with the shape (B, C, T).
|
97 |
+
Returns:
|
98 |
+
Tensor: Float tensor variable with the shape (B, C', T').
|
99 |
+
"""
|
100 |
+
x = self.deconv(x)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
class CausalConv1d(NonCausalConv1d):
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
in_channels,
|
108 |
+
out_channels,
|
109 |
+
kernel_size,
|
110 |
+
stride=1,
|
111 |
+
dilation=1,
|
112 |
+
groups=1,
|
113 |
+
bias=True
|
114 |
+
):
|
115 |
+
super(CausalConv1d, self).__init__(
|
116 |
+
in_channels=in_channels,
|
117 |
+
out_channels=out_channels,
|
118 |
+
kernel_size=kernel_size,
|
119 |
+
stride=stride,
|
120 |
+
padding=0,
|
121 |
+
dilation=dilation,
|
122 |
+
groups=groups,
|
123 |
+
bias=bias,
|
124 |
+
)
|
125 |
+
self.stride = stride
|
126 |
+
self.pad_length = (kernel_size - 1) * dilation
|
127 |
+
def forward(self, x):
|
128 |
+
pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
|
129 |
+
x = pad(x)
|
130 |
+
return self.conv(x)
|
131 |
+
|
132 |
+
|
133 |
+
class CausalConvTranspose1d(NonCausalConvTranspose1d):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
in_channels,
|
137 |
+
out_channels,
|
138 |
+
kernel_size,
|
139 |
+
stride,
|
140 |
+
bias=True,
|
141 |
+
pad_buffer=None,
|
142 |
+
):
|
143 |
+
super(CausalConvTranspose1d, self).__init__(
|
144 |
+
in_channels=in_channels,
|
145 |
+
out_channels=out_channels,
|
146 |
+
kernel_size=kernel_size,
|
147 |
+
stride=stride,
|
148 |
+
padding=0,
|
149 |
+
output_padding=0,
|
150 |
+
bias=bias,
|
151 |
+
)
|
152 |
+
self.stride = stride
|
153 |
+
self.pad_length = (math.ceil(kernel_size/stride) - 1)
|
154 |
+
if pad_buffer is None:
|
155 |
+
pad_buffer = T.zeros(1, in_channels, self.pad_length)
|
156 |
+
self.register_buffer("pad_buffer", pad_buffer)
|
157 |
+
|
158 |
+
def forward(self, x):
|
159 |
+
pad = nn.ReplicationPad1d((self.pad_length, 0))
|
160 |
+
x = pad(x)
|
161 |
+
return self.deconv(x)[:, :, self.stride : -self.stride]
|
162 |
+
|
163 |
+
def inference(self, x):
|
164 |
+
x = T.cat((self.pad_buffer, x), -1)
|
165 |
+
self.pad_buffer = x[:, :, -self.pad_length:]
|
166 |
+
return self.deconv(x)[:, :, self.stride : -self.stride]
|
167 |
+
|
168 |
+
def reset_buffer(self):
|
169 |
+
self.pad_buffer.zero_()
|
170 |
+
|
171 |
+
|
172 |
+
class NonCausalResUnit(nn.Module):
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
in_channels,
|
176 |
+
out_channels,
|
177 |
+
kernel_size=7,
|
178 |
+
dilation=1,
|
179 |
+
bias=False,
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
self.activation = nn.ELU()
|
183 |
+
self.conv1 = NonCausalConv1d(
|
184 |
+
in_channels=in_channels,
|
185 |
+
out_channels=out_channels,
|
186 |
+
kernel_size=kernel_size,
|
187 |
+
stride=1,
|
188 |
+
dilation=dilation,
|
189 |
+
bias=bias,
|
190 |
+
)
|
191 |
+
self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
y = self.conv1(self.activation(x))
|
195 |
+
y = self.conv2(self.activation(y))
|
196 |
+
return x + y
|
197 |
+
|
198 |
+
|
199 |
+
class CausalResUnit(NonCausalResUnit):
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
in_channels,
|
203 |
+
out_channels,
|
204 |
+
kernel_size=7,
|
205 |
+
dilation=1,
|
206 |
+
bias=False,
|
207 |
+
):
|
208 |
+
super(CausalResUnit, self).__init__(
|
209 |
+
in_channels=in_channels,
|
210 |
+
out_channels=out_channels,
|
211 |
+
kernel_size=kernel_size,
|
212 |
+
dilation=dilation,
|
213 |
+
bias=bias,
|
214 |
+
)
|
215 |
+
self.conv1 = CausalConv1d(
|
216 |
+
in_channels=in_channels,
|
217 |
+
out_channels=out_channels,
|
218 |
+
kernel_size=kernel_size,
|
219 |
+
stride=1,
|
220 |
+
dilation=dilation,
|
221 |
+
bias=bias,
|
222 |
+
)
|
223 |
+
|
224 |
+
def inference(self, x):
|
225 |
+
y = self.conv1.inference(self.activation(x))
|
226 |
+
y = self.conv2(self.activation(y))
|
227 |
+
return x + y
|
228 |
+
|
229 |
+
|
230 |
+
class ResNetBlock(nn.Module):
|
231 |
+
def __init__(self,
|
232 |
+
in_channels,
|
233 |
+
out_channels,
|
234 |
+
stride,
|
235 |
+
kernel_size=7,
|
236 |
+
dilations=(1, 3, 9),
|
237 |
+
bias=True,
|
238 |
+
mode='encoder',
|
239 |
+
):
|
240 |
+
super().__init__()
|
241 |
+
assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"
|
242 |
+
|
243 |
+
self.mode = mode
|
244 |
+
self.stride = stride
|
245 |
+
|
246 |
+
ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d
|
247 |
+
|
248 |
+
res_channels = in_channels if mode == 'encoder' else out_channels
|
249 |
+
|
250 |
+
res_units = [CausalResUnit(
|
251 |
+
res_channels,
|
252 |
+
res_channels,
|
253 |
+
kernel_size=kernel_size,
|
254 |
+
dilation=dilation,
|
255 |
+
) for dilation in dilations]
|
256 |
+
|
257 |
+
if in_channels == out_channels:
|
258 |
+
if mode == 'encoder':
|
259 |
+
self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
260 |
+
if mode == 'decoder':
|
261 |
+
self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
|
262 |
+
conv_unit = nn.Conv1d(
|
263 |
+
in_channels=in_channels,
|
264 |
+
out_channels=out_channels,
|
265 |
+
kernel_size=1,
|
266 |
+
bias=bias,
|
267 |
+
) if in_channels != out_channels else nn.Identity()
|
268 |
+
else:
|
269 |
+
conv_unit = ConvUnit(
|
270 |
+
in_channels=in_channels,
|
271 |
+
out_channels=out_channels,
|
272 |
+
kernel_size=(2 * stride),
|
273 |
+
stride=stride,
|
274 |
+
bias=bias,
|
275 |
+
)
|
276 |
+
|
277 |
+
if mode == 'encoder':
|
278 |
+
if in_channels == out_channels:
|
279 |
+
self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
|
280 |
+
else:
|
281 |
+
self.res_block = nn.Sequential(*res_units, conv_unit)
|
282 |
+
elif mode == 'decoder':
|
283 |
+
if in_channels == out_channels:
|
284 |
+
self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
|
285 |
+
else:
|
286 |
+
self.res_block = nn.Sequential(conv_unit, *res_units)
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
out = x
|
290 |
+
for unit in self.res_block:
|
291 |
+
out = unit(out)
|
292 |
+
return out
|
293 |
+
|
294 |
+
def inference(self, x):
|
295 |
+
for unit in self.res_block:
|
296 |
+
x = unit.inference(x)
|
297 |
+
return x
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
@si_module
|
303 |
+
class ResNetStack(nn.Module):
|
304 |
+
"""
|
305 |
+
ResNet encoder or decoder stack. Channel ratios
|
306 |
+
and strides are listed in order from the
|
307 |
+
data/io layer toward the middle of the model.
|
308 |
+
"""
|
309 |
+
class Config:
|
310 |
+
input_channels: int = 1
|
311 |
+
output_channels: int = 1
|
312 |
+
encode_channels: int = 32
|
313 |
+
decode_channel_multiplier: int = 1
|
314 |
+
latent_dim: int = None
|
315 |
+
kernel_size: int = 7
|
316 |
+
bias: bool = True
|
317 |
+
channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
|
318 |
+
strides: Tuple[int, ...] = (3, 4, 5, 5)
|
319 |
+
mode: Literal['encoder', 'decoder'] = 'encoder'
|
320 |
+
|
321 |
+
def __init__(self, c: Config):
|
322 |
+
super().__init__()
|
323 |
+
assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
|
324 |
+
|
325 |
+
self.mode = c.mode
|
326 |
+
|
327 |
+
assert len(c.channel_ratios) == len(c.strides)
|
328 |
+
channel_ratios = (1,) + c.channel_ratios
|
329 |
+
strides = c.strides
|
330 |
+
self.middle_channels = c.encode_channels * channel_ratios[-1]
|
331 |
+
if c.mode == 'decoder':
|
332 |
+
channel_ratios = tuple(reversed(channel_ratios))
|
333 |
+
strides = tuple(reversed(strides))
|
334 |
+
|
335 |
+
self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
|
336 |
+
res_blocks = [ResNetBlock(
|
337 |
+
c.encode_channels * channel_ratios[s_idx] * self.multiplier,
|
338 |
+
c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
|
339 |
+
stride,
|
340 |
+
kernel_size=c.kernel_size,
|
341 |
+
bias=c.bias,
|
342 |
+
mode=c.mode,
|
343 |
+
) for s_idx, stride in enumerate(strides)]
|
344 |
+
|
345 |
+
data_conv = CausalConv1d(
|
346 |
+
in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
|
347 |
+
out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
|
348 |
+
kernel_size=c.kernel_size,
|
349 |
+
stride=1,
|
350 |
+
bias=False,
|
351 |
+
)
|
352 |
+
|
353 |
+
if c.mode == 'encoder':
|
354 |
+
self.res_stack = nn.Sequential(data_conv, *res_blocks)
|
355 |
+
elif c.mode == 'decoder':
|
356 |
+
self.res_stack = nn.Sequential(*res_blocks, data_conv)
|
357 |
+
|
358 |
+
if c.latent_dim is not None:
|
359 |
+
self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
|
360 |
+
if self.multiplier != 1:
|
361 |
+
self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)
|
362 |
+
|
363 |
+
def forward(self, x, return_feats=False):
|
364 |
+
if self.c.latent_dim is not None and self.mode == 'decoder':
|
365 |
+
x = self.latent_proj(x)
|
366 |
+
if self.multiplier != 1:
|
367 |
+
x = self.multiplier_proj(x)
|
368 |
+
|
369 |
+
feats = []
|
370 |
+
for block in self.res_stack:
|
371 |
+
x = block(x)
|
372 |
+
if return_feats:
|
373 |
+
feats.append(x)
|
374 |
+
if self.c.latent_dim is not None and self.mode == 'encoder':
|
375 |
+
x = self.latent_proj(x)
|
376 |
+
if return_feats:
|
377 |
+
feats.append(x)
|
378 |
+
if return_feats:
|
379 |
+
return feats
|
380 |
+
return x
|
381 |
+
|
382 |
+
def inference(self, x):
|
383 |
+
for block in self.res_stack:
|
384 |
+
x = block.inference(x)
|
385 |
+
return x
|
386 |
+
|
387 |
+
def reset_buffer(self):
|
388 |
+
def _reset_buffer(m):
|
389 |
+
if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
|
390 |
+
m.reset_buffer()
|
391 |
+
self.apply(_reset_buffer)
|
392 |
+
|
393 |
+
def reset_parameters(self):
|
394 |
+
def _reset_parameters(m):
|
395 |
+
if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
|
396 |
+
m.weight.data.normal_(0.0, 0.01)
|
397 |
+
|
398 |
+
self.apply(_reset_parameters)
|
399 |
+
|
400 |
+
|
401 |
+
def apply_weight_norm(self):
|
402 |
+
def _apply_weight_norm(m):
|
403 |
+
if isinstance(m, nn.Conv1d) or isinstance(
|
404 |
+
m, nn.ConvTranspose1d
|
405 |
+
):
|
406 |
+
nn.utils.parametrizations.weight_norm(m)
|
407 |
+
|
408 |
+
self.apply(_apply_weight_norm)
|
409 |
+
|
410 |
+
|
411 |
+
def remove_weight_norm(self):
|
412 |
+
def _remove_weight_norm(m):
|
413 |
+
try:
|
414 |
+
print(m)
|
415 |
+
nn.utils.remove_weight_norm(m)
|
416 |
+
except ValueError: # this module didn't have weight norm
|
417 |
+
return
|
418 |
+
|
419 |
+
self.apply(_remove_weight_norm)
|
420 |
+
|
421 |
+
|
422 |
+
|
423 |
+
@si_module
|
424 |
+
class GaussianZ(nn.Module):
|
425 |
+
class Config:
|
426 |
+
dim: int
|
427 |
+
latent_dim: int
|
428 |
+
bias: bool = False
|
429 |
+
use_weight_norm: bool = False
|
430 |
+
|
431 |
+
def __init__(self, c: Config):
|
432 |
+
super().__init__()
|
433 |
+
|
434 |
+
self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
|
435 |
+
self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)
|
436 |
+
|
437 |
+
if c.use_weight_norm:
|
438 |
+
self.proj_in = weight_norm(self.proj_in)
|
439 |
+
self.proj_out = weight_norm(self.proj_out)
|
440 |
+
|
441 |
+
def reparam(self, mu, logvar):
|
442 |
+
std = T.exp(logvar / 2)
|
443 |
+
eps = T.randn_like(std)
|
444 |
+
return mu + eps * std
|
445 |
+
|
446 |
+
def kl_divergence(self, mu, logvar):
|
447 |
+
return T.mean(-0.5 * T.sum(
|
448 |
+
1 + logvar - mu.pow(2) - logvar.exp(),
|
449 |
+
dim=(1, 2))
|
450 |
+
)
|
451 |
+
|
452 |
+
def repr_from_latent(self, latent: Union[dict, T.Tensor]):
|
453 |
+
if isinstance(latent, T.Tensor):
|
454 |
+
z = latent
|
455 |
+
else:
|
456 |
+
z = self.reparam(latent['mu'], latent['logvar'])
|
457 |
+
l = self.proj_out(z)
|
458 |
+
return l
|
459 |
+
|
460 |
+
def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
|
461 |
+
mu, logvar = self.proj_in(x).chunk(2, dim=-1)
|
462 |
+
kl_div = self.kl_divergence(mu, logvar)
|
463 |
+
z = self.reparam(mu, logvar)
|
464 |
+
xhat = self.proj_out(z)
|
465 |
+
latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
|
466 |
+
return xhat, latent
|
467 |
+
|
468 |
+
|
469 |
+
|
470 |
+
@si_module
|
471 |
+
class WaveCodec(nn.Module):
|
472 |
+
class Config:
|
473 |
+
resnet_config: ResNetStack.Config = None
|
474 |
+
sample_rate: int = 16_000
|
475 |
+
use_weight_norm: bool = False
|
476 |
+
|
477 |
+
compressor_config: dataclass = None
|
478 |
+
|
479 |
+
norm_stddev: float = 1.0
|
480 |
+
|
481 |
+
def __init__(self, c: Config):
|
482 |
+
super().__init__()
|
483 |
+
self.norm_stddev = c.norm_stddev
|
484 |
+
self.encoder = c.resnet_config(mode='encoder')
|
485 |
+
self.sample_rate = c.sample_rate
|
486 |
+
|
487 |
+
self.total_stride = 1
|
488 |
+
for stride in c.resnet_config.strides:
|
489 |
+
self.total_stride *= stride
|
490 |
+
self.tokens_per_second = self.sample_rate / self.total_stride
|
491 |
+
|
492 |
+
self.compressor = c.compressor_config(dim=self.encoder.middle_channels)
|
493 |
+
|
494 |
+
self.decoder = c.resnet_config(mode='decoder')
|
495 |
+
|
496 |
+
if c.use_weight_norm:
|
497 |
+
self.encoder.apply_weight_norm()
|
498 |
+
self.decoder.apply_weight_norm()
|
499 |
+
self.encoder.reset_parameters()
|
500 |
+
self.decoder.reset_parameters()
|
501 |
+
|
502 |
+
def encode(self, data):
|
503 |
+
return self.encoder(data/self.norm_stddev)
|
504 |
+
|
505 |
+
def decode(self, latent):
|
506 |
+
return self.decoder(latent.transpose(1, 2))*self.norm_stddev
|
507 |
+
|
508 |
+
@T.no_grad()
|
509 |
+
def latent_from_data(self, data, get_parameters=False):
|
510 |
+
x = self.encode(data)
|
511 |
+
l_in = x.transpose(1, 2)
|
512 |
+
l, latent = self.compressor(l_in)
|
513 |
+
return latent['z'] if not get_parameters else {
|
514 |
+
'mu': latent['mu'],
|
515 |
+
'logvar': latent['logvar'],
|
516 |
+
'z': latent['z'],
|
517 |
+
}
|
518 |
+
|
519 |
+
@T.no_grad()
|
520 |
+
def data_from_latent(self, latent):
|
521 |
+
l = self.compressor.repr_from_latent(latent)
|
522 |
+
x = self.decode(l)
|
523 |
+
return x
|
524 |
+
|
525 |
+
def process(self, x):
|
526 |
+
return self.latent_from_data(x)
|
527 |
+
|
528 |
+
def unprocess(self, latent):
|
529 |
+
return self.data_from_latent(latent)
|
530 |
+
|
531 |
+
def forward(self, audio_input):
|
532 |
+
x = self.encode(audio_input)
|
533 |
+
|
534 |
+
l_in = x.transpose(1, 2)
|
535 |
+
l, latent = self.compressor(l_in)
|
536 |
+
|
537 |
+
xhat = self.decode(l)
|
538 |
+
return xhat, latent
|
539 |
+
|
540 |
+
|
541 |
+
|
542 |
+
def make_tokenizer(device='cuda'):
|
543 |
+
generator_config = WaveCodec.Config(
|
544 |
+
resnet_config=ResNetStack.Config(
|
545 |
+
input_channels=1,
|
546 |
+
output_channels=1,
|
547 |
+
encode_channels=16,
|
548 |
+
decode_channel_multiplier=4,
|
549 |
+
kernel_size=7,
|
550 |
+
bias=True,
|
551 |
+
channel_ratios=(4, 8, 16, 16, 16, 16),
|
552 |
+
strides=(2, 2, 4, 5, 5, 5),
|
553 |
+
mode=None,
|
554 |
+
),
|
555 |
+
use_weight_norm=True,
|
556 |
+
|
557 |
+
compressor_config=GaussianZ.Config(
|
558 |
+
dim=None,
|
559 |
+
latent_dim=32,
|
560 |
+
|
561 |
+
bias=True,
|
562 |
+
use_weight_norm=True
|
563 |
+
),
|
564 |
+
|
565 |
+
norm_stddev=0.05,
|
566 |
+
)
|
567 |
+
checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")
|
568 |
+
|
569 |
+
tokenizer = generator_config()
|
570 |
+
|
571 |
+
load_result = tokenizer.load_state_dict(checkpoint, strict=False)
|
572 |
+
print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")
|
573 |
+
|
574 |
+
tokenizer = tokenizer.eval()
|
575 |
+
# Only convert to bfloat16 if using CUDA
|
576 |
+
if device == 'cuda':
|
577 |
+
tokenizer = tokenizer.bfloat16()
|
578 |
+
tokenizer = tokenizer.to(device)
|
579 |
+
tokenizer.requires_grad_(False)
|
580 |
+
return tokenizer
|
581 |
+
|
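A small encode/decode sketch for the codec above (not part of the commit); the file path is a placeholder and a mono 16 kHz input is assumed.

import torchaudio
from tokenizer import make_tokenizer

tok = make_tokenizer(device='cpu')
wav, sr = torchaudio.load('prompts/example.wav')        # placeholder path; expected shape (1, T) at 16 kHz
latent = tok.latent_from_data(wav.unsqueeze(0))         # (B, T//2000, 32): 8 latent frames per second
recon = tok.data_from_latent(latent)                    # back to a (B, 1, T') waveform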
transformer.py
ADDED
@@ -0,0 +1,381 @@
1 |
+
from typing import Optional, Tuple, MutableMapping
|
2 |
+
from typing import Union
|
3 |
+
import math
|
4 |
+
from contextlib import nullcontext
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch as T
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import Tensor
|
11 |
+
from torch.nn.attention import SDPBackend
|
12 |
+
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
from utils import si_module, default, exists, load_ckpt
|
16 |
+
|
17 |
+
CACHE_FILL_VALUE = -1
|
18 |
+
|
19 |
+
def get_cache_len(cache: Optional[Tensor]) -> int:
|
20 |
+
"""
|
21 |
+
cache: (batch, seq_len, 2, kv_heads, head_dim)
|
22 |
+
"""
|
23 |
+
if cache is None:
|
24 |
+
return 0
|
25 |
+
nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
|
26 |
+
length = nonzeros.sum(dim=-1).int()
|
27 |
+
assert T.all(length == length[0])
|
28 |
+
return length[0]
|
29 |
+
|
30 |
+
|
31 |
+
def rotate_half(x):
|
32 |
+
x1, x2 = x.chunk(2, dim=-1)
|
33 |
+
return torch.cat((-x2, x1), dim=-1)
|
34 |
+
|
35 |
+
|
36 |
+
def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
|
37 |
+
assert (
|
38 |
+
cos.shape[1] >= offset + x.shape[1]
|
39 |
+
), f"Offset and/or input sequence is too large,\
|
40 |
+
\n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
|
41 |
+
|
42 |
+
cos_out = cos[:, offset : offset + x.shape[1], :, :]
|
43 |
+
sin_out = sin[:, offset : offset + x.shape[1], :, :]
|
44 |
+
|
45 |
+
return (x * cos_out) + (rotate_half(x) * sin_out)
|
46 |
+
|
47 |
+
|
48 |
+
# Adapted from https://github.com/foundation-model-stack/foundation-model-stack
|
49 |
+
class ShapeRotator:
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
dim: int,
|
53 |
+
end: int,
|
54 |
+
theta: float = 10_000,
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
self.dim = dim
|
58 |
+
self.ratio = theta
|
59 |
+
self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
|
60 |
+
self.max_seq_len_cached: MutableMapping[int, int] = {}
|
61 |
+
self.ntk_scaling = False
|
62 |
+
self.max_seq_len = end
|
63 |
+
|
64 |
+
def compute_freqs_cis(self, device, max_seq_len=None):
|
65 |
+
alpha = 1
|
66 |
+
dev_idx = device.index
|
67 |
+
max_seq_len = default(max_seq_len, self.max_seq_len)
|
68 |
+
|
69 |
+
if dev_idx not in self.cached_freqs:
|
70 |
+
self.cached_freqs[dev_idx] = {}
|
71 |
+
if dev_idx not in self.max_seq_len_cached:
|
72 |
+
self.max_seq_len_cached[dev_idx] = 0
|
73 |
+
|
74 |
+
|
75 |
+
if self.max_seq_len_cached[dev_idx] > 0:
|
76 |
+
return 1
|
77 |
+
max_seq_len = max(max_seq_len, self.max_seq_len)
|
78 |
+
|
79 |
+
if (
|
80 |
+
1 in self.cached_freqs[dev_idx]
|
81 |
+
and max_seq_len <= self.max_seq_len_cached[dev_idx]
|
82 |
+
):
|
83 |
+
return 1
|
84 |
+
|
85 |
+
ratio = self.ratio
|
86 |
+
dim = self.dim
|
87 |
+
|
88 |
+
freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
89 |
+
|
90 |
+
t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
|
91 |
+
freqs = torch.einsum("i,j->ij", t, freqs)
|
92 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
93 |
+
|
94 |
+
cos_to_cache = emb.cos()[None, :, None, :]
|
95 |
+
sin_to_cache = emb.sin()[None, :, None, :]
|
96 |
+
|
97 |
+
self.max_seq_len_cached[dev_idx] = max_seq_len
|
98 |
+
|
99 |
+
self.cached_freqs[dev_idx][alpha] = torch.stack(
|
100 |
+
[
|
101 |
+
cos_to_cache,
|
102 |
+
sin_to_cache,
|
103 |
+
],
|
104 |
+
dim=-1,
|
105 |
+
)
|
106 |
+
|
107 |
+
return alpha
|
108 |
+
|
109 |
+
def rotate(
|
110 |
+
self,
|
111 |
+
q: Tensor,
|
112 |
+
k: Tensor,
|
113 |
+
offset: int = 0,
|
114 |
+
) -> Tuple[Tensor, Tensor]:
|
115 |
+
"""
|
116 |
+
Args
|
117 |
+
----
|
118 |
+
q : torch.Tensor
|
119 |
+
Embedded query tensor, expected size is B x S x H x Eh
|
120 |
+
k : torch.Tensor
|
121 |
+
Embedded key tensor, expected size is B x S x H x Eh
|
122 |
+
"""
|
123 |
+
assert len(q.size()) == 4
|
124 |
+
assert len(k.size()) == 4
|
125 |
+
|
126 |
+
seq_len = self.max_seq_len
|
127 |
+
alpha = self.compute_freqs_cis(q.device, seq_len)
|
128 |
+
freqs = self.cached_freqs[q.device.index][alpha]
|
129 |
+
|
130 |
+
freqs = freqs.float() # 1 L D/2 2 2
|
131 |
+
q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
|
132 |
+
k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)
|
133 |
+
|
134 |
+
return q_out.view_as(q), k_out.view_as(k)
|
135 |
+
|
136 |
+
class Linear(nn.Linear):
|
137 |
+
def __init__(self, *args, **kwargs):
|
138 |
+
super().__init__(*args, **kwargs, bias=False)
|
139 |
+
|
140 |
+
class Norm(nn.Module):
|
141 |
+
def __init__(self,
|
142 |
+
dim: int,
|
143 |
+
eps: float = 1e-5,) -> None:
|
144 |
+
super().__init__()
|
145 |
+
self.eps = eps
|
146 |
+
self.weight = nn.Parameter(T.ones((dim,)))
|
147 |
+
|
148 |
+
def forward(self, input: Tensor) -> Tensor:
|
149 |
+
return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
|
150 |
+
|
151 |
+
|
152 |
+
class FFNN(nn.Module):
|
153 |
+
def __init__(self,
|
154 |
+
dim: int,
|
155 |
+
expand_dim: int = None,):
|
156 |
+
super().__init__()
|
157 |
+
expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
|
158 |
+
self.dim = dim
|
159 |
+
self.expand_dim = expand_dim
|
160 |
+
|
161 |
+
self.gateup_proj = Linear(dim, 2*expand_dim)
|
162 |
+
self.down_proj = Linear(expand_dim, dim)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
gate, up = self.gateup_proj(x).chunk(2, dim=-1)
|
166 |
+
return self.down_proj(up * F.silu(gate))
|
167 |
+
|
168 |
+
class GQA(nn.Module):
|
169 |
+
def __init__(self,
|
170 |
+
dim: int,
|
171 |
+
n_head: int,
|
172 |
+
shape_rotator: ShapeRotator,
|
173 |
+
kv_heads: Optional[int] = None,
|
174 |
+
eps: float = 1e-5,
|
175 |
+
causal: bool = True,):
|
176 |
+
super().__init__()
|
177 |
+
self.n_heads = n_head
|
178 |
+
self.kv_heads = default(kv_heads, n_head)
|
179 |
+
self.head_dim = dim // n_head
|
180 |
+
self.causal = causal
|
181 |
+
|
182 |
+
self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))
|
183 |
+
|
184 |
+
self.norm_q = Norm(self.head_dim*n_head, eps=eps)
|
185 |
+
self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)
|
186 |
+
|
187 |
+
self.attn_out = Linear(dim, dim)
|
188 |
+
|
189 |
+
self.shape_rotator = shape_rotator
|
190 |
+
|
191 |
+
def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
192 |
+
k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
193 |
+
v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
194 |
+
x = F.scaled_dot_product_attention(
|
195 |
+
q.transpose(1, 2),
|
196 |
+
k.transpose(1, 2),
|
197 |
+
v.transpose(1, 2),
|
198 |
+
is_causal=False if (q.size(1) != k.size(1)) else self.causal,
|
199 |
+
)
|
200 |
+
x = x.transpose(1, 2).contiguous()
|
201 |
+
return x
|
202 |
+
|
203 |
+
def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
|
204 |
+
cache_len = get_cache_len(kv_cache)
|
205 |
+
q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
|
206 |
+
if exists(kv_cache):
|
207 |
+
k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
|
208 |
+
v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
|
209 |
+
kv_cache[:, :k.size(1), 0] = k
|
210 |
+
kv_cache[:, :v.size(1), 1] = v
|
211 |
+
x = self._sdpa(q, k, v)
|
212 |
+
return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))
|
213 |
+
|
214 |
+
def _project(self, x):
|
215 |
+
full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1)
|
216 |
+
normed_full_q = self.norm_q(full_q).to(full_q.dtype)
|
217 |
+
normed_full_k = self.norm_k(full_k).to(full_k.dtype)
|
218 |
+
|
219 |
+
q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
|
220 |
+
k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
|
221 |
+
v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
|
222 |
+
return q, k, v
|
223 |
+
|
224 |
+
def forward(self,
|
225 |
+
x: Tensor,
|
226 |
+
kv: Optional[Tensor] = None,):
|
227 |
+
"""
|
228 |
+
x: (B, S, D)
|
229 |
+
kv: (B, S, H, D)
|
230 |
+
"""
|
231 |
+
q, k, v = self._project(x)
|
232 |
+
return self._attend(q, k, v, kv_cache=kv)
|
233 |
+
|
234 |
+
|
235 |
+
class PreNormAttn(nn.Module):
|
236 |
+
def __init__(self,
|
237 |
+
dim: int,
|
238 |
+
n_head: int,
|
239 |
+
shape_rotator: ShapeRotator,
|
240 |
+
kv_heads: Optional[int] = None,
|
241 |
+
eps: float = 1e-5,
|
242 |
+
causal: bool = True,):
|
243 |
+
super().__init__()
|
244 |
+
self.attn_norm = Norm(dim, eps=eps)
|
245 |
+
self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
|
246 |
+
|
247 |
+
def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
|
248 |
+
"""
|
249 |
+
x: (B, S, D)
|
250 |
+
kv: (B, S, H, D)
|
251 |
+
"""
|
252 |
+
return x + self.attn(self.attn_norm(x), kv)
|
253 |
+
|
254 |
+
class PreNormFFNN(nn.Module):
|
255 |
+
def __init__(self,
|
256 |
+
dim: int,
|
257 |
+
ff_dim: int,
|
258 |
+
eps: float = 1e-5,):
|
259 |
+
super().__init__()
|
260 |
+
self.ffnn_norm = Norm(dim, eps=eps)
|
261 |
+
self.ffnn = FFNN(dim, ff_dim)
|
262 |
+
|
263 |
+
def forward(self, x: Tensor) -> Tensor:
|
264 |
+
return x + self.ffnn(self.ffnn_norm(x))
|
265 |
+
|
266 |
+
class Block(nn.Module):
|
267 |
+
def __init__(self,
|
268 |
+
dim: int,
|
269 |
+
layer_id: int = 0,
|
270 |
+
n_head: int = 16,
|
271 |
+
kv_heads: Optional[int] = None,
|
272 |
+
ff_dim: Optional[int] = None,
|
273 |
+
eps: float = 1e-5,
|
274 |
+
causal: bool = True,
|
275 |
+
shape_rotator: ShapeRotator = None):
|
276 |
+
super().__init__()
|
277 |
+
self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
|
278 |
+
self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
|
279 |
+
self.dim = dim
|
280 |
+
self.layer_id = layer_id
|
281 |
+
self.head_dim = dim // n_head
|
282 |
+
self.expand_dim = self.ffnn.ffnn.expand_dim
|
283 |
+
|
284 |
+
self.reset_parameters()
|
285 |
+
|
286 |
+
def reset_parameters(self):
|
287 |
+
std = 1.0 / math.sqrt(self.dim)
|
288 |
+
nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
|
289 |
+
nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
|
290 |
+
nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)
|
291 |
+
|
292 |
+
xstd = 1.0 / math.sqrt(self.expand_dim)
|
293 |
+
nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)
|
294 |
+
|
295 |
+
def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
|
296 |
+
"""
|
297 |
+
x: (B, S, D)
|
298 |
+
kv: (B, S, H, D)
|
299 |
+
"""
|
300 |
+
h = self.attn(x, kv)
|
301 |
+
out = self.ffnn(h)
|
302 |
+
return out
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
class GPTOutput(nn.Module):
|
307 |
+
def __init__(self, dim, vocab_size):
|
308 |
+
super().__init__()
|
309 |
+
self.dim = dim
|
310 |
+
self.norm = Norm(dim)
|
311 |
+
self.output = Linear(dim, vocab_size)
|
312 |
+
|
313 |
+
self.reset_parameters()
|
314 |
+
|
315 |
+
def reset_parameters(self):
|
316 |
+
std = 1.0 / math.sqrt(self.dim**2)
|
317 |
+
nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
|
318 |
+
|
319 |
+
def forward(self, x):
|
320 |
+
return self.output(self.norm(x))
|
321 |
+
|
322 |
+
@si_module
|
323 |
+
class Stack(nn.Module):
|
324 |
+
class Config:
|
325 |
+
layers: int
|
326 |
+
dim: int
|
327 |
+
seq_len: int
|
328 |
+
n_head: int = 32
|
329 |
+
ff_dim: int = None
|
330 |
+
kv_heads: int = None
|
331 |
+
eps: float = 1e-5
|
332 |
+
theta: Union[int, float] = 10_000
|
333 |
+
causal: bool = True
|
334 |
+
|
335 |
+
from_pretrained: Optional[Tuple[str, int]] = None
|
336 |
+
|
337 |
+
def __init__(self, c: Config):
|
338 |
+
super().__init__()
|
339 |
+
|
340 |
+
from_pretrained = c.from_pretrained
|
341 |
+
if exists(from_pretrained):
|
342 |
+
checkpoint = load_ckpt(c.from_pretrained)
|
343 |
+
|
344 |
+
self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)
|
345 |
+
|
346 |
+
self.layers = nn.ModuleList([
|
347 |
+
Block(
|
348 |
+
dim=c.dim,
|
349 |
+
layer_id=l,
|
350 |
+
n_head=c.n_head,
|
351 |
+
kv_heads=c.kv_heads,
|
352 |
+
ff_dim=c.ff_dim,
|
353 |
+
eps=c.eps,
|
354 |
+
causal=c.causal,
|
355 |
+
shape_rotator=self.shape_rotator,
|
356 |
+
) for l in range(c.layers)
|
357 |
+
])
|
358 |
+
|
359 |
+
kv_heads = c.kv_heads or c.n_head
|
360 |
+
head_dim = c.dim // c.n_head
|
361 |
+
cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
|
362 |
+
self.cache_shape = cache_shape
|
363 |
+
self.cache = [None] * c.layers
|
364 |
+
|
365 |
+
if exists(from_pretrained):
|
366 |
+
self.load_state_dict(checkpoint)
|
367 |
+
|
368 |
+
def init_cache(self, bsize, device, dtype, length:int=None):
|
369 |
+
if self.cache_shape is None:
|
370 |
+
return
|
371 |
+
cache_shape = self.cache_shape.copy()
|
372 |
+
cache_shape[1] = length or cache_shape[1]
|
373 |
+
self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
|
374 |
+
|
375 |
+
def deinit_cache(self):
|
376 |
+
self.cache = [None] * len(self.cache)
|
377 |
+
|
378 |
+
def forward(self, x: Tensor) -> Tensor:
|
379 |
+
for l, layer in enumerate(self.layers):
|
380 |
+
x = layer(x, kv=self.cache[l])
|
381 |
+
return x
|
utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .blocks import *
from .dist import *
from .interp import *
utils/blocks.py
ADDED
@@ -0,0 +1,92 @@
from dataclasses import dataclass
from typing import TypeVar, Generic, Type, Optional
from functools import wraps
import time
import random

import torch as T
import torch.nn as nn

# @TODO: remove si_module from codebase
# we use this in our research codebase to make modules from callable configs
si_module_TpV = TypeVar('si_module_TpV')
def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
    if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
        class Config:
            pass
        cls.Config = Config

    cls.Config = dataclass(cls.Config)

    class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
        def __call__(self, *args, **kwargs) -> si_module_TpV:
            if len(kwargs) > 0:
                config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()}
                config_dict.update(kwargs)
                new_config = type(self)(**config_dict)
                return cls(new_config)
            else:
                return cls(self, *args)

    ConfigWrapper.__module__ = cls.__module__
    ConfigWrapper.__name__ = f"{cls.__name__}Config"
    ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"

    cls.Config = ConfigWrapper

    original_init = cls.__init__
    def new_init(self, *args, **kwargs):
        self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None)
        original_init(self, *args, **kwargs)
        self.register_buffer('_device_tracker', T.Tensor(), persistent=False)

    cls.__init__ = new_init

    @property
    def device(self):
        return self._device_tracker.device

    @property
    def dtype(self):
        return self._device_tracker.dtype

    cls.device = device
    cls.dtype = dtype

    return cls


def get_activation(nonlinear_activation, nonlinear_activation_params={}):
    if hasattr(nn, nonlinear_activation):
        return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
    else:
        raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")


def exists(v):
    return v is not None

def isnt(v):
    return not exists(v)

def truthyexists(v):
    return exists(v) and v is not False

def truthyattr(obj, attr):
    return hasattr(obj, attr) and truthyexists(getattr(obj, attr))

defaultT = TypeVar('defaultT')

def default(*args: Optional[defaultT]) -> Optional[defaultT]:
    for arg in args:
        if exists(arg):
            return arg
    return None

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner
utils/dist.py
ADDED
@@ -0,0 +1,98 @@
import os
import torch as T
import re
from tqdm import tqdm
from datetime import timedelta

import requests
import hashlib

from io import BytesIO
from huggingface_hub import hf_hub_download

def rank0():
    rank = os.environ.get('RANK')
    if rank is None or rank == '0':
        return True
    else:
        return False

def local0():
    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is None or local_rank == '0':
        return True
    else:
        return False

class tqdm0(tqdm):
    def __init__(self, *args, **kwargs):
        total = kwargs.get('total', None)
        if total is None and len(args) > 0:
            try:
                total = len(args[0])
            except TypeError:
                pass
        if total is not None:
            kwargs['miniters'] = max(1, total // 20)
        super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')

def print0(*args, **kwargs):
    if rank0():
        print(*args, **kwargs)

_PRINTED_IDS = set()

def printonce(*args, id=None, **kwargs):
    if id is None:
        id = ' '.join(map(str, args))

    if id not in _PRINTED_IDS:
        print(*args, **kwargs)
        _PRINTED_IDS.add(id)

def print0once(*args, **kwargs):
    if rank0():
        printonce(*args, **kwargs)

def init_dist():
    if T.distributed.is_initialized():
        print0('Distributed already initialized')
        rank = T.distributed.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        world_size = T.distributed.get_world_size()
    else:
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            device = f'cuda:{local_rank}'
            T.cuda.set_device(device)
            T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
            print(f'Rank {rank} of {world_size}.')
        except Exception as e:
            print0once(f'Not initializing distributed env: {e}')
            rank = 0
            local_rank = 0
            world_size = 1
    return rank, local_rank, world_size

def load_ckpt(load_from_location, expected_hash=None):
    os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # Disable this to speed up debugging errors with downloading from the hub
    save_path = None  # non-zero local ranks receive the resolved path via the broadcast below
    if local0():
        repo_id = "si-pbc/hertz-dev"
        print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
        save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt")
        print0(f'Downloaded checkpoint to {save_path}')
        if expected_hash is not None:
            with open(save_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            if file_hash != expected_hash:
                print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
                os.remove(save_path)
                return load_ckpt(load_from_location, expected_hash)
    if T.distributed.is_initialized():
        save_path = [save_path]
        T.distributed.broadcast_object_list(save_path, src=0)
        save_path = save_path[0]
    loaded = T.load(save_path, weights_only=False, map_location='cpu')
    print0(f'Loaded checkpoint from {save_path}')
    return loaded
utils/interp.py
ADDED
@@ -0,0 +1,84 @@
import torch as T
import os

def rank0():
    rank = os.environ.get('RANK')
    if rank is None or rank == '0':
        return True
    else:
        return False

def print_colored(message, color='reset', bold=False, **kwargs):
    color_dict = {
        'bold': '\033[1m',
        'green': '\033[92m',
        'yellow': '\033[93m',
        'red': '\033[91m',
        'blue': '\033[94m',
        'grey': '\033[90m',
        'white': '\033[97m',
        'reset': '\033[0m'
    }

    color_code = color_dict.get(color.lower(), color_dict['reset'])
    prefix = color_dict['bold'] if bold else ''
    print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs)

def print0_colored(*args, **kwargs):
    if rank0():
        print_colored(*args, **kwargs)

def param_count(module):
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    total_params = count_parameters(module)
    output = [f'Total model parameters: {total_params:,}', '---------------------------']

    for name, child in module.named_children():
        params = count_parameters(child)
        output.append(f'{name} parameters: {params:,}')

    return '\n'.join(output)

def model_size_estimation(module):
    def estimate_size(model):
        param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
        return param_size + buffer_size

    total_size = estimate_size(module)
    output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------']

    for name, child in module.named_children():
        child_size = estimate_size(child)
        output.append(f'{name} size: {child_size / 1024**2:.2f} MB')

    return '\n'.join(output)

def layer_param_distribution(module):
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def get_layer_types(model):
        layer_types = {}
        for name, module in model.named_modules():
            layer_type = module.__class__.__name__
            params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
            if params > 0:
                if layer_type not in layer_types:
                    layer_types[layer_type] = 0
                layer_types[layer_type] += params
        return layer_types

    total_params = count_parameters(module)
    layer_types = get_layer_types(module)

    output = [f'Total trainable parameters: {total_params:,}', '---------------------------']

    for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_params) * 100
        output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)')

    return '\n'.join(output)