ShivamMore committed on
Commit
2e6f087
·
1 Parent(s): 204a314

commit name

.gitignore ADDED
@@ -0,0 +1,9 @@
+ .venv
+ *.wav
+ *.mp3
+ *.m4a
+ !prompts/*.wav
+ !prompts/*.mp3
+ !prompts/*.m4a
+ __pycache__
+ *ckpt
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 Standard Intelligence PBC
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
inference.ipynb ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "import torch as T\n",
20
+ "import torch.nn as nn\n",
21
+ "import torch.nn.functional as F\n",
22
+ "import torchaudio\n",
23
+ "from utils import load_ckpt, print_colored\n",
24
+ "from tokenizer import make_tokenizer\n",
25
+ "from model import get_hertz_dev_config\n",
26
+ "import matplotlib.pyplot as plt\n",
27
+ "from IPython.display import Audio, display\n",
28
+ "\n",
29
+ "\n",
30
+ "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n",
31
+ "# you need to install PyTorch with the correct CUDA version. Run:\n",
32
+ "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n",
33
+ "\n",
34
+ "device = 'cuda' if T.cuda.is_available() else 'cpu'\n",
35
+ "if device == 'cuda': T.cuda.set_device(0)\n",
36
+ "print_colored(f\"Using device: {device}\", \"grey\")"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# This will automatically download the required checkpoints if they aren't found locally.\n",
46
+ "audio_tokenizer = make_tokenizer(device)"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 7,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "# We have different checkpoints for the single-speaker and two-speaker models\n",
56
+ "# Set to True to load and run inference with the two-speaker model\n",
57
+ "TWO_SPEAKER = False\n",
58
+ "USE_PURE_AUDIO_ABLATION = False # We trained a base model with no text initialization at all. Toggle this to enable it.\n",
59
+ "assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER) # We only have a single-speaker version of this model.\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "model_config = get_hertz_dev_config(is_split=TWO_SPEAKER, use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION)\n",
69
+ "\n",
70
+ "generator = model_config()\n",
71
+ "generator = generator.eval().to(T.bfloat16).to(device)"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "def load_and_preprocess_audio(audio_path):\n",
81
+ " print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n",
82
+ " # Load audio file\n",
83
+ " audio_tensor, sr = torchaudio.load(audio_path)\n",
84
+ " print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n",
85
+ " \n",
86
+ " if TWO_SPEAKER:\n",
87
+ " if audio_tensor.shape[0] == 1:\n",
88
+ " print_colored(\"Converting mono to stereo...\", \"grey\")\n",
89
+ " audio_tensor = audio_tensor.repeat(2, 1)\n",
90
+ " print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n",
91
+ " else:\n",
92
+ " if audio_tensor.shape[0] == 2:\n",
93
+ " print_colored(\"Converting stereo to mono...\", \"grey\")\n",
94
+ " audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)\n",
95
+ " print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n",
96
+ " \n",
97
+ " # Resample to 16kHz if needed\n",
98
+ " if sr != 16000:\n",
99
+ " print_colored(f\"Resampling from {sr}Hz to 16000Hz...\", \"grey\")\n",
100
+ " resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)\n",
101
+ " audio_tensor = resampler(audio_tensor)\n",
102
+ " \n",
103
+ " # Clip to 5 minutes if needed\n",
104
+ " max_samples = 16000 * 60 * 5\n",
105
+ " if audio_tensor.shape[1] > max_samples:\n",
106
+ " print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n",
107
+ " audio_tensor = audio_tensor[:, :max_samples]\n",
108
+ "\n",
109
+ " \n",
110
+ " print_colored(\"Audio preprocessing complete!\", \"green\")\n",
111
+ " return audio_tensor.unsqueeze(0)\n",
112
+ "\n",
113
+ "def display_audio(audio_tensor):\n",
114
+ " audio_tensor = audio_tensor.cpu().squeeze()\n",
115
+ " if audio_tensor.ndim == 1:\n",
116
+ " audio_tensor = audio_tensor.unsqueeze(0)\n",
117
+ " audio_tensor = audio_tensor.float()\n",
118
+ "\n",
119
+ " # Make a waveform plot\n",
120
+ " plt.figure(figsize=(4, 1))\n",
121
+ " plt.plot(audio_tensor.numpy()[0], linewidth=0.5)\n",
122
+ " plt.axis('off')\n",
123
+ " plt.show()\n",
124
+ "\n",
125
+ " # Make an audio player\n",
126
+ " display(Audio(audio_tensor.numpy(), rate=16000))\n",
127
+ " print_colored(f\"Audio ready for playback ↑\", \"green\", bold=True)\n",
128
+ " \n",
129
+ " \n",
130
+ "\n",
131
+ "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n",
132
+ "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n",
133
+ "display_audio(prompt_audio)\n",
134
+ "prompt_len_seconds = 3\n",
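+ "# the audio tokenizer produces 8 latent frames per second (2000 samples each at 16 kHz)\n",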
135
+ "prompt_len = prompt_len_seconds * 8"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "print_colored(\"Encoding prompt...\", \"blue\")\n",
145
+ "with T.autocast(device_type='cuda', dtype=T.bfloat16):\n",
146
+ " if TWO_SPEAKER:\n",
147
+ " encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n",
148
+ " encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n",
149
+ " encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n",
150
+ " else:\n",
151
+ " encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n",
152
+ "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n",
153
+ "print_colored(\"Prompt encoded successfully!\", \"green\")"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "def get_completion(encoded_prompt_audio, prompt_len, gen_len=None):\n",
163
+ " prompt_len_seconds = prompt_len / 8\n",
164
+ " print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n",
165
+ " print_colored(\"Completing audio...\", \"blue\")\n",
166
+ " encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len]\n",
167
+ " with T.autocast(device_type='cuda', dtype=T.bfloat16):\n",
168
+ " completed_audio_batch = generator.completion(\n",
169
+ " encoded_prompt_audio, \n",
170
+ " temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n",
171
+ " use_cache=True,\n",
172
+ " gen_len=gen_len)\n",
173
+ "\n",
174
+ " completed_audio = completed_audio_batch\n",
175
+ " print_colored(f\"Decoding completion...\", \"blue\")\n",
176
+ " if TWO_SPEAKER:\n",
177
+ " decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n",
178
+ " decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n",
179
+ " decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n",
180
+ " else:\n",
181
+ " decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n",
182
+ " print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n",
183
+ "\n",
184
+ " print_colored(\"Preparing audio for playback...\", \"blue\")\n",
185
+ "\n",
186
+ " audio_tensor = decoded_completion.cpu().squeeze()\n",
187
+ " if audio_tensor.ndim == 1:\n",
188
+ " audio_tensor = audio_tensor.unsqueeze(0)\n",
189
+ " audio_tensor = audio_tensor.float()\n",
190
+ "\n",
191
+ " if audio_tensor.abs().max() > 1:\n",
192
+ " audio_tensor = audio_tensor / audio_tensor.abs().max()\n",
193
+ "\n",
194
+ " return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n",
195
+ "\n",
196
+ "num_completions = 10\n",
197
+ "print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n",
198
+ "for _ in range(num_completions):\n",
199
+ " completion = get_completion(encoded_prompt_audio, prompt_len, gen_len=20*8) # 20 seconds of generation\n",
200
+ " display_audio(completion)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": []
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": []
216
+ }
217
+ ],
218
+ "metadata": {
219
+ "kernelspec": {
220
+ "display_name": ".venv",
221
+ "language": "python",
222
+ "name": "python3"
223
+ },
224
+ "language_info": {
225
+ "codemirror_mode": {
226
+ "name": "ipython",
227
+ "version": 3
228
+ },
229
+ "file_extension": ".py",
230
+ "mimetype": "text/x-python",
231
+ "name": "python",
232
+ "nbconvert_exporter": "python",
233
+ "pygments_lexer": "ipython3",
234
+ "version": "3.10.12"
235
+ }
236
+ },
237
+ "nbformat": 4,
238
+ "nbformat_minor": 2
239
+ }
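For convenience, a minimal sketch (not part of this commit; save_completion is a hypothetical helper) of how a completion produced by the notebook loop above could be written to disk with torchaudio:

import torchaudio

def save_completion(completion, path="completion.wav", sample_rate=16000):
    # `completion` is the (channels, samples) float tensor returned by get_completion()
    torchaudio.save(path, completion.cpu(), sample_rate)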
inference_client.py ADDED
@@ -0,0 +1,161 @@
1
+ # Websocket client for inference_server.py: captures microphone audio with sounddevice,
+ # streams it to the server, and plays back the returned audio.
4
+ import asyncio
5
+ import websockets
6
+ import sounddevice as sd
7
+ import numpy as np
8
+ import base64
9
+ import queue
10
+ import argparse
11
+ import requests
12
+ import time
13
+
14
+ class AudioClient:
15
+ def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
16
+ # Convert ws:// to http:// for the base URL
17
+ self.base_url = server_url.replace("ws://", "http://")
18
+ self.server_url = f"{server_url}/audio"
19
+
20
+ # Set temperatures if provided
21
+ if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
22
+ self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
23
+
24
+ # Initialize queues
25
+ self.audio_queue = queue.Queue()
26
+ self.output_queue = queue.Queue()
27
+
28
+ def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False):
29
+ """Send temperature settings to server"""
30
+ params = {}
31
+ if token_temp is not None:
32
+ params['token_temp'] = token_temp
33
+ if categorical_temp is not None:
34
+ params['categorical_temp'] = categorical_temp
35
+ if gaussian_temp is not None:
36
+ params['gaussian_temp'] = gaussian_temp
37
+
38
+ response = requests.post(f"{self.base_url}/set_temperature", params=params)
39
+ print(response.json()['message'])
40
+
41
+ def audio_callback(self, indata, frames, time, status):
42
+ """This is called for each audio block"""
43
+ if status:
44
+ print(status)
45
+ # if np.isclose(indata, 0).all():
46
+ # raise Exception('Audio input is not working - received all zeros')
47
+ # Convert float32 to int16 for efficient transmission
48
+ indata_int16 = (indata.copy() * 32767).astype(np.int16)
49
+ # indata_int16 = np.zeros_like(indata_int16)
50
+ self.audio_queue.put(indata_int16)
51
+
52
+ def output_stream_callback(self, outdata, frames, time, status):
53
+ """Callback for output stream to get audio data"""
54
+ if status:
55
+ print(status)
56
+
57
+ try:
58
+ data = self.output_queue.get_nowait()
59
+ data = data.astype(np.float32) / 32767.0
60
+ if len(data) < len(outdata):
61
+ outdata[:len(data)] = data
62
+ outdata[len(data):] = 0
63
+ else:
64
+ outdata[:] = data[:len(outdata)]
65
+ except queue.Empty:
66
+ outdata.fill(0)
67
+
68
+ async def process_audio(self):
69
+ async with websockets.connect(self.server_url) as ws:
70
+ while self.running:
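+ # Each queued block is 2000 samples of 16 kHz int16 audio (125 ms); one websocket message per block.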
71
+ if not self.audio_queue.empty():
72
+ # Get recorded audio
73
+ audio_data = self.audio_queue.get()
74
+ print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')
75
+
76
+ # Convert to base64
77
+ audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')
78
+
79
+ # Send to server
80
+ time_sent = time.time()
81
+ await ws.send(f"data:audio/raw;base64,{audio_b64}")
82
+
83
+ # Receive processed audio
84
+ response = await ws.recv()
85
+ response = response.split(",")[1]
86
+ time_received = time.time()
87
+ print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms")
88
+ processed_audio = np.frombuffer(
89
+ base64.b64decode(response),
90
+ dtype=np.int16
91
+ ).reshape(-1, CHANNELS)
92
+ print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')
93
+
94
+ self.output_queue.put(processed_audio)
95
+
96
+ def start(self):
97
+ self.running = True
98
+ # Print audio device information
99
+ devices = sd.query_devices()
100
+ default_input = sd.query_devices(kind='input')
101
+ default_output = sd.query_devices(kind='output')
102
+
103
+ print("\nAudio Device Configuration:")
104
+ print("-" * 50)
105
+ print(f"Default Input Device:\n{default_input}\n")
106
+ print(f"Default Output Device:\n{default_output}\n")
107
+ print("\nAll Available Devices:")
108
+ print("-" * 50)
109
+ for i, device in enumerate(devices):
110
+ print(f"Device {i}:")
111
+ print(f"Name: {device['name']}")
112
+ print(f"Channels (in/out): {device['max_input_channels']}/{device['max_output_channels']}")
113
+ print(f"Default Sample Rate: {device['default_samplerate']}")
114
+ print()
115
+ input_device = input("Enter the index of the input device or press enter for default: ")
116
+ output_device = input("Enter the index of the output device or press enter for default: ")
117
+ if input_device == "":
118
+ input_device = default_input['index']
119
+ if output_device == "":
120
+ output_device = default_output['index']
121
+ with sd.InputStream(callback=self.audio_callback,
122
+ channels=CHANNELS,
123
+ samplerate=SAMPLE_RATE,
124
+ device=int(input_device),
125
+ blocksize=2000), \
126
+ sd.OutputStream(callback=self.output_stream_callback,
127
+ channels=CHANNELS,
128
+ samplerate=SAMPLE_RATE,
129
+ blocksize=2000,
130
+ device=int(output_device)):
131
+
132
+ asyncio.run(self.process_audio())
133
+
134
+ def stop(self):
135
+ self.running = False
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
139
+ parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
140
+ parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
141
+ parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
142
+ parser.add_argument('--server', '-s', default="ws://localhost:8000",
143
+ help='Server URL (default: ws://localhost:8000)')
144
+
145
+ args = parser.parse_args()
146
+
147
+ # Audio settings
148
+ SAMPLE_RATE = 16000
149
+ CHANNELS = 1
150
+
151
+ client = AudioClient(
152
+ server_url=args.server,
153
+ token_temp=args.token_temp,
154
+ categorical_temp=args.categorical_temp,
155
+ gaussian_temp=args.gaussian_temp
156
+ )
157
+
158
+ try:
159
+ client.start()
160
+ except KeyboardInterrupt:
161
+ client.stop()
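Both clients and the server frame audio the same way: 16 kHz int16 PCM, base64-encoded into a data:audio/raw;base64,... string per websocket message. A minimal sketch of that round trip (hypothetical helper names; standard library plus numpy only):

import base64
import numpy as np

def encode_chunk(samples: np.ndarray) -> str:
    # samples: int16 PCM at 16 kHz, exactly what audio_callback puts on the queue
    return "data:audio/raw;base64," + base64.b64encode(samples.tobytes()).decode("utf-8")

def decode_chunk(message: str) -> np.ndarray:
    # inverse of encode_chunk, mirroring the decode after ws.recv()
    return np.frombuffer(base64.b64decode(message.split(",")[1]), dtype=np.int16)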
inference_client_webrtc.py ADDED
@@ -0,0 +1,255 @@
1
+ # Streamlit + WebRTC client for inference_server.py: streams browser microphone audio
+ # to the server and plays back the generated audio in real time.
4
+ import asyncio
5
+ import websockets
6
+ import numpy as np
7
+ import base64
8
+ import argparse
9
+ import requests
10
+ import time
11
+ import torch
12
+ import torchaudio
13
+
14
+ import av
15
+ import streamlit as st
16
+ from typing import List
17
+ from streamlit_webrtc import WebRtcMode, webrtc_streamer
18
+
19
+ class AudioClient:
20
+ def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
21
+ # Convert ws:// to http:// for the base URL
22
+ self.base_url = server_url.replace("ws://", "http://")
23
+ self.server_url = f"{server_url}/audio"
24
+ self.sound_check = False
25
+
26
+ # Set temperatures if provided
27
+ if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
28
+ response_message = self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
29
+ print(response_message)
30
+
31
+ self.downsampler = torchaudio.transforms.Resample(STREAMING_SAMPLE_RATE, SAMPLE_RATE)
32
+ self.upsampler = torchaudio.transforms.Resample(SAMPLE_RATE, STREAMING_SAMPLE_RATE)
33
+ self.ws = None
34
+ self.in_buffer = None
35
+ self.out_buffer = None
36
+
37
+ def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False):
38
+ """Send temperature settings to server"""
39
+ params = {}
40
+ if token_temp is not None:
41
+ params['token_temp'] = token_temp
42
+ if categorical_temp is not None:
43
+ params['categorical_temp'] = categorical_temp
44
+ if gaussian_temp is not None:
45
+ params['gaussian_temp'] = gaussian_temp
46
+
47
+ response = requests.post(f"{self.base_url}/set_temperature", params=params)
48
+ response_message = response.json()['message']
49
+ return response_message
50
+
51
+ def _resample(self, audio_data: np.ndarray, resampler: torchaudio.transforms.Resample) -> np.ndarray:
52
+ audio_data = audio_data.astype(np.float32) / 32767.0
53
+ audio_data = resampler(torch.tensor(audio_data)).numpy()
54
+ audio_data = (audio_data * 32767.0).astype(np.int16)
55
+ return audio_data
56
+
57
+ def upsample(self, audio_data: np.ndarray) -> np.ndarray:
58
+ return self._resample(audio_data, self.upsampler)
59
+
60
+ def downsample(self, audio_data: np.ndarray) -> np.ndarray:
61
+ return self._resample(audio_data, self.downsampler)
62
+
63
+ def from_s16_format(self, audio_data: np.ndarray, channels: int) -> np.ndarray:
64
+ if channels == 2:
65
+ audio_data = audio_data.reshape(-1, 2).T
66
+ else:
67
+ audio_data = audio_data.reshape(-1)
68
+ return audio_data
69
+
70
+ def to_s16_format(self, audio_data: np.ndarray):
71
+ if len(audio_data.shape) == 2 and audio_data.shape[0] == 2:
72
+ audio_data = audio_data.T.reshape(1, -1)
73
+ elif len(audio_data.shape) == 1:
74
+ audio_data = audio_data.reshape(1, -1)
75
+ return audio_data
76
+
77
+ def to_channels(self, audio_data: np.ndarray, channels: int) -> np.ndarray:
78
+ current_channels = audio_data.shape[0] if len(audio_data.shape) == 2 else 1
79
+ if current_channels == channels:
80
+ return audio_data
81
+ elif current_channels == 1 and channels == 2:
82
+ audio_data = np.tile(audio_data, 2).reshape(2, -1)
83
+ elif current_channels == 2 and channels == 1:
84
+ audio_data = audio_data.astype(np.float32) / 32767.0
85
+ audio_data = audio_data.mean(axis=0)
86
+ audio_data = (audio_data * 32767.0).astype(np.int16)
87
+ return audio_data
88
+
89
+ async def process_audio(self, audio_data: np.ndarray) -> np.ndarray:
90
+ if self.ws is None:
91
+ self.ws = await websockets.connect(self.server_url)
92
+
93
+ audio_data = audio_data.reshape(-1, CHANNELS)
94
+ print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')
95
+
96
+ # Convert to base64
97
+ audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')
98
+
99
+ # Send to server
100
+ time_sent = time.time()
101
+ await self.ws.send(f"data:audio/raw;base64,{audio_b64}")
102
+
103
+ # Receive processed audio
104
+ response = await self.ws.recv()
105
+ response = response.split(",")[1]
106
+ time_received = time.time()
107
+ print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms")
108
+ processed_audio = np.frombuffer(
109
+ base64.b64decode(response),
110
+ dtype=np.int16
111
+ ).reshape(-1, CHANNELS)
112
+ print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')
113
+
114
+ if CHANNELS == 1:
115
+ processed_audio = processed_audio.reshape(-1)
116
+ return processed_audio
117
+
118
+ async def queued_audio_frames_callback(self, frames: List[av.AudioFrame]) -> List[av.AudioFrame]:
119
+ out_frames = []
120
+ for frame in frames:
121
+ # Read in audio
122
+ audio_data = frame.to_ndarray()
123
+
124
+ # Convert input audio from s16 format, convert to `CHANNELS` number of channels, and downsample
125
+ audio_data = self.from_s16_format(audio_data, len(frame.layout.channels))
126
+ audio_data = self.to_channels(audio_data, CHANNELS)
127
+ audio_data = self.downsample(audio_data)
128
+
129
+ # Add audio to input buffer
130
+ if self.in_buffer is None:
131
+ self.in_buffer = audio_data
132
+ else:
133
+ self.in_buffer = np.concatenate((self.in_buffer, audio_data), axis=-1)
134
+
135
+ # Take BLOCK_SIZE samples from input buffer if available for processing
136
+ if self.in_buffer.shape[0] >= BLOCK_SIZE:
137
+ audio_data = self.in_buffer[:BLOCK_SIZE]
138
+ self.in_buffer = self.in_buffer[BLOCK_SIZE:]
139
+ else:
140
+ audio_data = None
141
+
142
+ # Process audio if available and add resulting audio to output buffer
143
+ if audio_data is not None:
144
+ if not self.sound_check:
145
+ audio_data = await self.process_audio(audio_data)
146
+ if self.out_buffer is None:
147
+ self.out_buffer = audio_data
148
+ else:
149
+ self.out_buffer = np.concatenate((self.out_buffer, audio_data), axis=-1)
150
+
151
+ # Take `out_samples` samples from output buffer if available for output
152
+ out_samples = int(frame.samples * SAMPLE_RATE / STREAMING_SAMPLE_RATE)
153
+ if self.out_buffer is not None and self.out_buffer.shape[0] >= out_samples:
154
+ audio_data = self.out_buffer[:out_samples]
155
+ self.out_buffer = self.out_buffer[out_samples:]
156
+ else:
157
+ audio_data = None
158
+
159
+ # Output silence if no audio data available
160
+ if audio_data is None:
161
+ # output silence
162
+ audio_data = np.zeros(out_samples, dtype=np.int16)
163
+
164
+ # Upsample output audio, convert to original number of channels, and convert to s16 format
165
+ audio_data = self.upsample(audio_data)
166
+ audio_data = self.to_channels(audio_data, len(frame.layout.channels))
167
+ audio_data = self.to_s16_format(audio_data)
168
+
169
+ # return audio data as AudioFrame
170
+ new_frame = av.AudioFrame.from_ndarray(audio_data, format=frame.format.name, layout=frame.layout.name)
171
+ new_frame.sample_rate = frame.sample_rate
172
+ out_frames.append(new_frame)
173
+
174
+ return out_frames
175
+
176
+ def stop(self):
177
+ if self.ws is not None:
178
+ # TODO: this hangs. Figure out why.
179
+ #asyncio.get_event_loop().run_until_complete(self.ws.close())
180
+ print("Websocket closed")
181
+ self.ws = None
182
+ self.in_buffer = None
183
+ self.out_buffer = None
184
+
185
+ if __name__ == "__main__":
186
+ parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
187
+ parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
188
+ parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
189
+ parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
190
+ parser.add_argument('--server', '-s', default="ws://localhost:8000",
191
+ help='Server URL (default: ws://localhost:8000)')
192
+ parser.add_argument("--use_ice_servers", action="store_true", help="Use public STUN servers")
193
+
194
+ args = parser.parse_args()
195
+
196
+ # Audio settings
197
+ STREAMING_SAMPLE_RATE = 48000
198
+ SAMPLE_RATE = 16000
199
+ BLOCK_SIZE = 2000
200
+ CHANNELS = 1
201
+
202
+ st.title("hertz-dev webrtc demo!")
203
+ st.markdown("""
204
+ Welcome to the audio processing interface! Here you can talk live with hertz.
205
+ - Process audio in real-time through your microphone
206
+ - Adjust various temperature parameters for inference
207
+ - Test your microphone with sound check mode
208
+ - Enable/disable echo cancellation and noise suppression
209
+
210
+ To begin, click the START button below and allow microphone access.
211
+ """)
212
+
213
+ audio_client = st.session_state.get("audio_client")
214
+ if audio_client is None:
215
+ audio_client = AudioClient(
216
+ server_url=args.server,
217
+ token_temp=args.token_temp,
218
+ categorical_temp=args.categorical_temp,
219
+ gaussian_temp=args.gaussian_temp
220
+ )
221
+ st.session_state.audio_client = audio_client
222
+
223
+ with st.sidebar:
224
+ st.markdown("## Inference Settings")
225
+ token_temp_default = args.token_temp if args.token_temp is not None else 0.8
226
+ token_temp = st.slider("Token Temperature", 0.05, 2.0, token_temp_default, step=0.05)
227
+ categorical_temp_default = args.categorical_temp if args.categorical_temp is not None else 0.4
228
+ categorical_temp = st.slider("Categorical Temperature", 0.01, 1.0, categorical_temp_default, step=0.01)
229
+ gaussian_temp_default = args.gaussian_temp if args.gaussian_temp is not None else 0.1
230
+ gaussian_temp = st.slider("Gaussian Temperature", 0.01, 1.0, gaussian_temp_default, step=0.01)
231
+ if st.button("Set Temperatures"):
232
+ response_message = audio_client.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
233
+ st.write(response_message)
234
+
235
+ st.markdown("## Microphone Settings")
236
+ audio_client.sound_check = st.toggle("Sound Check (Echo)", value=False)
237
+ echo_cancellation = st.toggle("Echo Cancellation*‡", value=False)
238
+ noise_suppression = st.toggle("Noise Suppression*", value=False)
239
+ st.markdown(r"\* *Restart stream to take effect*")
240
+ st.markdown("‡ *May cause audio to cut out*")
241
+
242
+ # Use a free STUN server from Google if --use_ice_servers is given
243
+ # (found in get_ice_servers() at https://github.com/whitphx/streamlit-webrtc/blob/main/sample_utils/turn.py)
244
+ rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]} if args.use_ice_servers else None
245
+ audio_config = {"echoCancellation": echo_cancellation, "noiseSuppression": noise_suppression}
246
+ webrtc_streamer(
247
+ key="streamer",
248
+ mode=WebRtcMode.SENDRECV,
249
+ rtc_configuration=rtc_configuration,
250
+ media_stream_constraints={"audio": audio_config, "video": False},
251
+ queued_audio_frames_callback=audio_client.queued_audio_frames_callback,
252
+ on_audio_ended=audio_client.stop,
253
+ async_processing=True,
254
+ )
255
+
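The buffering above exists because browser audio arrives at 48 kHz in small frames while the model pipeline works in fixed 2000-sample chunks at 16 kHz (BLOCK_SIZE). A rough sketch of the bookkeeping, assuming 20 ms WebRTC frames (960 samples at 48 kHz; the frame size is an assumption, not something this file fixes):

STREAMING_SAMPLE_RATE = 48000   # browser/WebRTC rate
SAMPLE_RATE = 16000             # model rate
BLOCK_SIZE = 2000               # samples sent to the server per step (125 ms)

frame_samples = 960             # assumed 20 ms frame at 48 kHz
per_frame = frame_samples * SAMPLE_RATE // STREAMING_SAMPLE_RATE   # 320 samples after downsampling
frames_to_buffer = -(-BLOCK_SIZE // per_frame)                     # ceil(2000 / 320) = 7 frames (~140 ms)
print(per_frame, frames_to_buffer)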
inference_server.py ADDED
@@ -0,0 +1,172 @@
1
+ import time
2
+ import numpy as np
3
+ from fastapi import FastAPI, WebSocket
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ import base64
6
+ import uvicorn
7
+ import traceback
8
9
+ import argparse
10
+
11
+ import torch as T
12
+ import torch.nn.functional as F
13
+ import torchaudio
14
+
15
+ import os
16
+ from typing import Optional
17
+
18
+ from utils import print_colored
19
+ from model import get_hertz_dev_config
20
+
21
+
22
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--prompt_path', type=str, default='./prompts/bob_mono.wav', help="""
+ We highly recommend making your own prompt based on a conversation between you and another person.
+ bob_mono.wav seems to work better for two-channel than bob_stereo.wav.
+ """)
+ args = parser.parse_args()
29
+
30
+
31
+ device = 'cuda' if T.cuda.is_available() else 'cpu'
32
+ print_colored(f"Using device: {device}", "grey")
33
+
34
+ model_config = get_hertz_dev_config(is_split=True)
35
+
36
+ model = model_config()
37
+ model = model.eval().bfloat16().to(device)
38
+
39
+ app = FastAPI()
40
+
41
+ app.add_middleware(
42
+ CORSMiddleware,
43
+ allow_origins=["*"],
44
+ allow_credentials=True,
45
+ allow_methods=["*"],
46
+ allow_headers=["*"],
47
+ )
48
+
49
+
50
+ # Hyperparams or something.
51
+ SAMPLE_RATE = 16000 # Don't change this
52
+ TEMPS = (0.8, (0.4, 0.1)) # You can change this, but there's also an endpoint for it.
53
+ REPLAY_SECONDS = 3 # What the user hears as context.
54
+
55
+ class AudioProcessor:
56
+ def __init__(self, model, prompt_path):
57
+ self.model = model
58
+ self.prompt_path = prompt_path
59
+ self.initialize_state(prompt_path)
60
+
61
+ def initialize_state(self, prompt_path):
62
+ loaded_audio, sr = torchaudio.load(prompt_path)
63
+ self.replay_seconds = REPLAY_SECONDS
64
+
65
+ if sr != SAMPLE_RATE:
66
+ resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
67
+ loaded_audio = resampler(loaded_audio)
68
+
69
+ if loaded_audio.shape[0] == 1:
70
+ loaded_audio = loaded_audio.repeat(2, 1)
71
+
72
+ audio_length = loaded_audio.shape[-1]
73
+ num_chunks = audio_length // 2000
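+ # each 2000-sample chunk is 125 ms at 16 kHz, i.e. 8 chunks per second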
74
+ loaded_audio = loaded_audio[..., :num_chunks * 2000]
75
+
76
+ self.loaded_audio = loaded_audio.to(device)
77
+
78
+ with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
79
+ self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024)
80
+ self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
81
+ self.prompt_buffer = None
82
+ self.prompt_position = 0
83
+ self.chunks_until_live = int(self.replay_seconds * 8)
84
+ self.initialize_prompt_buffer()
85
+ print_colored("AudioProcessor state initialized", "green")
86
+
87
+ def initialize_prompt_buffer(self):
88
+ self.recorded_audio = self.loaded_audio
89
+ prompt_audio = self.loaded_audio.reshape(1, 2, -1)
90
+ prompt_audio = prompt_audio[:, :, -(16000*self.replay_seconds):].cpu().numpy()
91
+ prompt_audio_mono = prompt_audio.mean(axis=1)
92
+ self.prompt_buffer = np.array_split(prompt_audio_mono[0], int(self.replay_seconds * 8))
93
+ print_colored(f"Initialized prompt buffer with {len(self.prompt_buffer)} chunks", "grey")
94
+
95
+ async def process_audio(self, audio_data):
96
+ if self.chunks_until_live > 0:
97
+ print_colored(f"Serving from prompt buffer, {self.chunks_until_live} chunks left", "grey")
98
+ chunk = self.prompt_buffer[int(self.replay_seconds * 8) - self.chunks_until_live]
99
+ self.chunks_until_live -= 1
100
+
101
+ if self.chunks_until_live == 0:
102
+ print_colored("Switching to live processing mode", "green")
103
+
104
+ time.sleep(0.05)
105
+ return chunk
106
+
107
+ audio_tensor = T.from_numpy(audio_data).to(device)
108
+ audio_tensor = audio_tensor.reshape(1, 1, -1)
109
+ audio_tensor = T.cat([audio_tensor, self.next_model_audio], dim=1)
110
+
111
+ with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
112
+ curr_model_audio = self.model.next_audio_from_audio(
113
+ audio_tensor,
114
+ temps=TEMPS
115
+ )
116
+ print(f"Recorded audio shape {self.recorded_audio.shape}, audio tensor shape {audio_tensor.shape}")
117
+ self.recorded_audio = T.cat([self.recorded_audio.cpu(), audio_tensor.squeeze(0).cpu()], dim=-1)
118
+
119
+ self.next_model_audio = curr_model_audio
120
+
121
+ return curr_model_audio.float().cpu().numpy()
122
+
123
+ def cleanup(self):
124
+ print_colored("Cleaning up audio processor...", "blue")
125
+ os.makedirs('audio_recordings', exist_ok=True)
126
+ torchaudio.save(f'audio_recordings/{time.strftime("%d-%H-%M")}.wav', self.recorded_audio.cpu(), SAMPLE_RATE)
127
+ self.model.deinit_cache()
128
+ self.initialize_state(self.prompt_path)
129
+ print_colored("Audio processor cleanup complete", "green")
130
+
131
+ @app.post("/set_temperature")
132
+ async def set_temperature(token_temp: Optional[float] = None, categorical_temp: Optional[float] = None, gaussian_temp: Optional[float] = None):
133
+ try:
134
+ global TEMPS
135
+ TEMPS = (token_temp, (categorical_temp, gaussian_temp))
136
+
137
+ print_colored(f"Temperature updated to: {TEMPS}", "green")
138
+ return {"message": f"Temperature updated to: {TEMPS}", "status": "success"}
139
+ except Exception as e:
140
+ print_colored(f"Error setting temperature: {str(e)}", "red")
141
+ return {"message": f"Error setting temperature: {str(e)}", "status": "error"}
142
+
143
+ @app.websocket("/audio")
144
+ async def websocket_endpoint(websocket: WebSocket):
145
+ await websocket.accept()
146
+ try:
147
+ while True:
148
+ data = await websocket.receive_text()
149
+ audio_data = np.frombuffer(
150
+ base64.b64decode(data.split(",")[1]),
151
+ dtype=np.int16
152
+ )
153
+ audio_data = audio_data.astype(np.float32) / 32767.0
154
+ processed_audio = await audio_processor.process_audio(audio_data)
155
+ processed_audio = (processed_audio * 32767).astype(np.int16)
156
+
157
+ processed_data = base64.b64encode(processed_audio.tobytes()).decode('utf-8')
158
+ await websocket.send_text(f"data:audio/raw;base64,{processed_data}")
159
+
160
+ except Exception as e:
161
+ print_colored(f"WebSocket error: {e}", "red")
162
+ print_colored(f"Full traceback:\n{traceback.format_exc()}", "red")
163
+ finally:
164
+ audio_processor.cleanup()
165
+ await websocket.close()
166
+
167
+
168
+ audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
169
+
170
+ if __name__ == "__main__":
171
+ uvicorn.run(app, host="0.0.0.0", port=8000)
172
+ print("Server started")
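Once the server is running, the sampling temperatures can be changed without restarting; this is the same call the clients make (a sketch, assuming the default host and port):

import requests

resp = requests.post(
    "http://localhost:8000/set_temperature",
    params={"token_temp": 0.8, "categorical_temp": 0.4, "gaussian_temp": 0.1},
)
print(resp.json()["message"])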
ioblocks.py ADDED
@@ -0,0 +1,333 @@
1
+ from __future__ import annotations
2
+ from functools import partial
3
+ from contextlib import nullcontext
4
+ from typing import List, Tuple
5
+ from math import ceil
6
+
7
+ import torch as T
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+ from torch import Tensor, int32
12
+ from torch.amp import autocast
13
+
14
+ from einops import rearrange, pack, unpack
15
+
16
+
17
+ from utils import si_module, exists, default, maybe
18
+
19
+
20
+ @si_module
21
+ class GaussianMixtureIOLayer(nn.Module):
22
+ class Config:
23
+ latent_dim: int
24
+ dim: int
25
+ num_components: int
26
+
27
+ def __init__(self, c: Config):
28
+ super().__init__()
29
+ self.latent_dim = c.latent_dim
30
+ self.num_components = c.num_components
31
+ self.input_projection = nn.Linear(c.latent_dim, c.dim)
32
+
33
+ self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim)
34
+ self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim)
35
+ self.fc_weight = nn.Linear(c.dim, c.num_components)
36
+
37
+ def _square_plus(self, x):
38
+ return (x + T.sqrt(T.square(x) + 4)) / 2
39
+
40
+ def input(self, sampled_latents: T.Tensor) -> T.Tensor:
41
+ """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)"""
42
+ hidden = self.input_projection(sampled_latents)
43
+ return hidden
44
+
45
+ def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
46
+ """float tensor (B, L, D) -> Tuple of locs, scales, and weights"""
47
+ batch_size, seq_len, _ = h.shape
48
+
49
+ locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim)
50
+ scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim)
51
+ weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components)
52
+
53
+ return (locs, scales, weights)
54
+
55
+ def loss(self, data, dataHat):
56
+ locs, scales, weights = dataHat
57
+ log_probs = -0.5 * T.sum(
58
+ (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) +
59
+ 2 * T.log(scales) +
60
+ T.log(T.tensor(2 * T.pi)),
61
+ dim=-1
62
+ )
63
+ log_weights = F.log_softmax(weights, dim=-1)
64
+ return -T.logsumexp(log_weights + log_probs, dim=-1)
65
+
66
+
67
+ def temp_sample(self, orig_pdist, temp):
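+ # temp is None (sample the mixture directly), a single float, or a (categorical_temp, gaussian_temp) tuple as passed by the inference scripts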
68
+ locs, scales, weights = orig_pdist
69
+ if temp is None:
70
+ component_samples = locs + scales * T.randn_like(scales)
71
+ mixture_samples = F.gumbel_softmax(weights, hard=True)
72
+ sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
73
+ elif isinstance(temp, tuple):
74
+ assert len(temp) == 2
75
+ categorical_temp, gaussian_temp = temp
76
+ component_samples = locs + scales * gaussian_temp * T.randn_like(scales)
77
+ mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True)
78
+ sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
79
+ else:
80
+ component_samples = locs + scales * temp * T.randn_like(scales)
81
+ mixture_samples = F.gumbel_softmax(weights / temp, hard=True)
82
+ sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
83
+ return sampled
84
+
85
+
86
+ class GPTOutput(nn.Module):
87
+ def __init__(self, dim, vocab_size):
88
+ super().__init__()
89
+ self.output = nn.Linear(dim, vocab_size, bias=False)
90
+
91
+ def forward(self, x):
92
+ return self.output(x)
93
+
94
+
95
+ # helper functions
96
+
97
+ def pack_one(t, pattern):
98
+ return pack([t], pattern)
99
+
100
+ def unpack_one(t, ps, pattern):
101
+ return unpack(t, ps, pattern)[0]
102
+
103
+ def first(l):
104
+ return l[0]
105
+
106
+ def round_up_multiple(num, mult):
107
+ return ceil(num / mult) * mult
108
+
109
+ def get_code_utilization(codes, codebook_size, get_global=False):
110
+ if get_global and dist.is_initialized():
111
+ world_size = dist.get_world_size()
112
+ else:
113
+ world_size = 1
114
+
115
+ if world_size > 1:
116
+ gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)]
117
+ dist.all_gather(gathered_tokens, codes)
118
+ gathered_tokens = T.cat(gathered_tokens, dim=0)
119
+ else:
120
+ gathered_tokens = codes
121
+ unique_tokens = len(T.unique(gathered_tokens))
122
+ code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size)
123
+ return code_utilization
124
+
125
+ # tensor helpers
126
+
127
+ def round_ste(z: Tensor) -> Tensor:
128
+ """Round with straight through gradients."""
129
+ zhat = z.round()
130
+ return z + (zhat - z).detach()
131
+
132
+ # main class
133
+ # lucidrains fsq
134
+ @si_module
135
+ class FSQ(nn.Module):
136
+ @property
137
+ def needs_float32_params(self):
138
+ return True
139
+
140
+ class Config:
141
+ levels: List[int]
142
+ dim: int | None = None
143
+ num_codebooks: int = 1
144
+ keep_num_codebooks_dim: bool | None = None
145
+ scale: float | None = None
146
+ allowed_dtypes: Tuple[str, ...] = ('float32', 'float64')
147
+ channel_first: bool = False
148
+ projection_has_bias: bool = True
149
+ return_indices: bool = True
150
+ force_quantization_f32: bool = True
151
+ use_rms: bool = False
152
+
153
+ def __init__(self, c: Config):
154
+ super().__init__()
155
+ _levels = T.tensor(c.levels, dtype=int32)
156
+ self.register_buffer("_levels", _levels, persistent = False)
157
+
158
+ _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32)
159
+ self.register_buffer("_basis", _basis, persistent = False)
160
+
161
+ self.scale = c.scale
162
+
163
+ codebook_dim = len(c.levels)
164
+ self.codebook_dim = codebook_dim
165
+
166
+ effective_codebook_dim = codebook_dim * c.num_codebooks
167
+ self.num_codebooks = c.num_codebooks
168
+
169
+ self.allowed_dtypes = []
170
+ for dtype_str in c.allowed_dtypes:
171
+ if hasattr(T, dtype_str):
172
+ self.allowed_dtypes.append(getattr(T, dtype_str))
173
+ else:
174
+ raise ValueError(f"Invalid dtype string: {dtype_str}")
175
+
176
+ self.effective_codebook_dim = effective_codebook_dim
177
+
178
+ keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
179
+ assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim)
180
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
181
+
182
+ self.dim = default(c.dim, len(_levels) * c.num_codebooks)
183
+
184
+ self.channel_first = c.channel_first
185
+
186
+ has_projections = self.dim != effective_codebook_dim
187
+ self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
188
+ self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
189
+
190
+ self.has_projections = has_projections
191
+
192
+ self.return_indices = c.return_indices
193
+ if c.return_indices:
194
+ self.codebook_size = self._levels.prod().item()
195
+ implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
196
+ self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
197
+
198
+ # (self.allowed_dtypes was already converted to torch dtypes above)
199
+ self.force_quantization_f32 = c.force_quantization_f32
200
+
201
+ self.latent_loss = None
202
+
203
+ def latent_metric(self, codes, get_global=False):
204
+ return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)}
205
+
206
+ def repr_from_latent(self, latent):
207
+ return self.indices_to_codes(latent)
208
+
209
+ def bound(self, z, eps: float = 1e-3):
210
+ """ Bound `z`, an array of shape (..., d). """
211
+ half_l = (self._levels - 1) * (1 + eps) / 2
212
+ offset = T.where(self._levels % 2 == 0, 0.5, 0.0)
213
+ shift = (offset / half_l).atanh()
214
+ return (z + shift).tanh() * half_l - offset
215
+
216
+ def quantize(self, z):
217
+ """ Quantizes z, returns quantized zhat, same shape as z. """
218
+ quantized = round_ste(self.bound(z))
219
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
220
+ return quantized / half_width
221
+
222
+ def _scale_and_shift(self, zhat_normalized):
223
+ half_width = self._levels // 2
224
+ return (zhat_normalized * half_width) + half_width
225
+
226
+ def _scale_and_shift_inverse(self, zhat):
227
+ half_width = self._levels // 2
228
+ return (zhat - half_width) / half_width
229
+
230
+ def _indices_to_codes(self, indices):
231
+ level_indices = self.indices_to_level_indices(indices)
232
+ codes = self._scale_and_shift_inverse(level_indices)
233
+ return codes
234
+
235
+ def codes_to_indices(self, zhat):
236
+ """ Converts a `code` to an index in the codebook. """
237
+ assert zhat.shape[-1] == self.codebook_dim
238
+ zhat = self._scale_and_shift(zhat)
239
+ return (zhat * self._basis).sum(dim=-1).to(int32)
240
+
241
+ def indices_to_level_indices(self, indices):
242
+ """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
243
+ indices = rearrange(indices, '... -> ... 1')
244
+ codes_non_centered = (indices // self._basis) % self._levels
245
+ return codes_non_centered
246
+
247
+ def indices_to_codes(self, indices):
248
+ """ Inverse of `codes_to_indices`. """
249
+ assert exists(indices)
250
+
251
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
252
+
253
+ codes = self._indices_to_codes(indices)
254
+
255
+ if self.keep_num_codebooks_dim:
256
+ codes = rearrange(codes, '... c d -> ... (c d)')
257
+
258
+ codes = self.project_out(codes)
259
+
260
+ if is_img_or_video or self.channel_first:
261
+ codes = rearrange(codes, 'b ... d -> b d ...')
262
+
263
+ return codes
264
+
265
+ # @autocast(device_type='cuda', enabled = False)
266
+ def forward(self, z, return_codes=False):
267
+ """
268
+ einstein notation
269
+ b - batch
270
+ n - sequence (or flattened spatial dimensions)
271
+ d - feature dimension
272
+ c - number of codebooks
273
+ """
274
+
275
+ is_img_or_video = z.ndim >= 4
276
+ need_move_channel_last = is_img_or_video or self.channel_first
277
+
278
+ # standardize image or video into (batch, seq, dimension)
279
+
280
+ if need_move_channel_last:
281
+ z = rearrange(z, 'b d ... -> b ... d')
282
+ z, ps = pack_one(z, 'b * d')
283
+
284
+ assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
285
+
286
+ z = self.project_in(z)
287
+
288
+ z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
289
+
290
+ # whether to force quantization step to be full precision or not
291
+
292
+ force_f32 = self.force_quantization_f32
293
+ quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext
294
+
295
+ with quantization_context():
296
+ orig_dtype = z.dtype
297
+
298
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
299
+ z = z.float()
300
+
301
+ codes = self.quantize(z)
302
+
303
+ # returning indices could be optional
304
+
305
+ indices = None
306
+
307
+ if self.return_indices:
308
+ indices = self.codes_to_indices(codes)
309
+
310
+ codes = rearrange(codes, 'b n c d -> b n (c d)')
311
+
312
+ codes = codes.type(orig_dtype)
313
+
314
+ # project out
315
+ if return_codes:
316
+ return codes, indices
317
+
318
+ out = self.project_out(codes)
319
+
320
+ # reconstitute image or video dimensions
321
+
322
+ if need_move_channel_last:
323
+ out = unpack_one(out, ps, 'b * d')
324
+ out = rearrange(out, 'b ... d -> b d ...')
325
+
326
+ indices = maybe(unpack_one)(indices, ps, 'b * c')
327
+
328
+ if not self.keep_num_codebooks_dim and self.return_indices:
329
+ indices = maybe(rearrange)(indices, '... 1 -> ...')
330
+
331
+ # return quantized output and indices
332
+
333
+ return out, indices
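For orientation, a minimal round-trip sketch of the FSQ block above; the config values are the ones get_hertz_dev_config passes in model.py below, and the (batch, seq, dim) shapes follow the convention assumed by forward:

import torch as T
from ioblocks import FSQ

fsq = FSQ.Config(
    levels=[8, 8, 8, 8, 8], dim=2048, num_codebooks=1,
    keep_num_codebooks_dim=None, scale=None,
    allowed_dtypes=['float32', 'float64', 'bfloat16'],
    channel_first=False, projection_has_bias=True,
    return_indices=True, force_quantization_f32=True, use_rms=False,
)()  # si_module configs are callable and return the constructed module

z = T.randn(2, 10, 2048)                 # (batch, seq, dim)
out, indices = fsq(z)                    # quantized output + flat codebook indices
codes = fsq.indices_to_codes(indices)    # inverse mapping, back to the projected space
assert out.shape == z.shape and indices.max() < 8 ** 5   # codebook size is prod(levels)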
model.py ADDED
@@ -0,0 +1,443 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch as T
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from ioblocks import GaussianMixtureIOLayer, FSQ
8
+
9
+ from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
10
+ from tokenizer import make_tokenizer
11
+
12
+
13
+ from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
14
+ from utils import load_ckpt
15
+
16
+
17
+ @si_module
18
+ class LatentQuantizer(nn.Module):
19
+ class Config:
20
+ compressor_config: Optional[FSQ.Config] = None
21
+
22
+ dim: Optional[int] = None
23
+ ff_dim: Optional[int] = None
24
+ input_dim: int = None
25
+
26
+ from_pretrained: Optional[Tuple[str, str]] = None
27
+
28
+ def __init__(self, c: Config):
29
+ super().__init__()
30
+
31
+ if exists(c.from_pretrained):
32
+ checkpoint = load_ckpt(*c.from_pretrained)
33
+ else:
34
+ assert exists(c.compressor_config), f'compressor_config is required when from_pretrained is not set: {c}'
35
+
36
+ self.compressor = c.compressor_config()
37
+ self.ffnn = FFNN(c.dim, c.ff_dim)
38
+ self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()
39
+
40
+ if exists(c.from_pretrained):
41
+ self.load_state_dict(checkpoint)
42
+
43
+ @T.no_grad()
44
+ def forward(self, x, return_latent=False, known_latent=None):
45
+ """
46
+ x: (B, S, D)
47
+ """
48
+ if exists(known_latent):
49
+ return self.compressor.indices_to_codes(known_latent)
50
+
51
+ x = self.input(x)
52
+ x = self.ffnn(x)
53
+ x, tokens = self.compressor(x)
54
+
55
+ if return_latent:
56
+ return x, tokens
57
+ return x
58
+
59
+
60
+ @si_module
61
+ class TransformerVAE(nn.Module):
62
+ class Config:
63
+ io_config: Optional[GaussianMixtureIOLayer.Config] = None
64
+ stack_config: Optional[Stack.Config] = None
65
+ quantizer_config: Optional[LatentQuantizer.Config] = None
66
+
67
+ plex_layer: int = None
68
+ plex_roll: int = 1
69
+ split: bool = True
70
+
71
+ from_pretrained: Optional[Tuple[str, str]] = None
72
+
73
+ def __init__(self, c: Config):
74
+ super().__init__()
75
+
76
+ if exists(c.from_pretrained):
77
+ checkpoint = load_ckpt(*c.from_pretrained)
78
+ else:
79
+ assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'io_config, stack_config and quantizer_config are required when from_pretrained is not set: {c}'
80
+
81
+ self.io = c.io_config()
82
+ self.stack = c.stack_config()
83
+
84
+ self.plex_layer = c.stack_config.layers//2
85
+ self.plex_roll = c.plex_roll
86
+ self.plex_dim = c.quantizer_config.dim
87
+
88
+ assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
89
+ self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
90
+ self.out_norm = Norm(c.stack_config.dim)
91
+
92
+ if c.split:
93
+ self.io2 = c.io_config()
94
+ self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
95
+
96
+ self.io2.fc_loc = None
97
+ self.io2.fc_scale = None
98
+ self.io2.fc_weight = None
99
+
100
+ kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
101
+ head_dim = c.stack_config.dim // c.stack_config.n_head
102
+ self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
103
+ cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
104
+ self.cache_shape = cache_shape
105
+ self.cache = [None] * self.cache_num_layers
106
+
107
+ if exists(c.from_pretrained):
108
+ result = self.load_state_dict(checkpoint, strict=False)
109
+ print0_colored(result, 'yellow')
110
+
111
+ self.quantizer = c.quantizer_config().eval()
112
+ self.quantizer.requires_grad_(False)
113
+
114
+ @T.no_grad()
115
+ def quantize(self, x):
116
+ if self.c.split:
117
+ x1, x2 = x.chunk(2, dim=-1)
118
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
119
+ quantized1 = self.quantizer(x1)
120
+ quantized2 = self.quantizer(x2)
121
+ return quantized1, quantized2
122
+ else:
123
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
124
+ return self.quantizer(x)
125
+
126
+ @T.no_grad()
127
+ def untokenize(self, token_data):
128
+ return self.quantizer(None, known_latent=token_data)
129
+
130
+ def init_cache(self, bsize, device, dtype, length:int=None):
131
+ cache_shape = self.cache_shape.copy()
132
+ cache_shape[1] = length or cache_shape[1]
133
+ self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
134
+
135
+ def deinit_cache(self):
136
+ self.cache = [None] * self.cache_num_layers
137
+
138
+ @T.no_grad()
139
+ def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
140
+ if self.c.split:
141
+ x1, x2 = data.chunk(2, dim=-1)
142
+ x = self.io.input(x1) + self.io2.input(x2)
143
+ else:
144
+ x = self.io.input(data)
145
+
146
+ cache_idx = 0
147
+ for l, layer in enumerate(self.stack.layers):
148
+ if l == self.plex_layer:
149
+ if self.c.split:
150
+ plex1, plex2 = self.quantize(data)
151
+ plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
152
+ plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
153
+ if exists(next_tokens):
154
+ plex1[:, -1:] = self.untokenize(next_tokens[0])
155
+ plex2[:, -1:] = self.untokenize(next_tokens[1])
156
+ x1 = x + self.plex_projection(plex1)
157
+ x2 = x + self.plex_projection2(plex2)
158
+ else:
159
+ plex = self.quantize(data)
160
+ plex = T.roll(plex, -self.c.plex_roll, dims=1)
161
+ if exists(next_tokens):
162
+ plex[:, -1:] = self.untokenize(next_tokens)
163
+ x = x + self.plex_projection(plex)
164
+
165
+ if l < self.plex_layer:
166
+ x = layer(x, kv=self.cache[l])
167
+ else:
168
+ if self.c.split:
169
+ x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
170
+ cache_idx += 1
171
+ x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
172
+ cache_idx += 1
173
+ else:
174
+ x = layer(x, kv=self.cache[l])
175
+
176
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
177
+ if self.c.split:
178
+ x1, x2 = self.out_norm(x1), self.out_norm(x2)
179
+ out1, out2 = self.io.output(x1), self.io.output(x2)
180
+ else:
181
+ x = self.out_norm(x)
182
+ out = self.io.output(x)
183
+
184
+ if isnt(temps):
185
+ if self.c.split:
186
+ return out1, out2
187
+ else:
188
+ return out
189
+ else:
190
+ if self.c.split:
191
+ next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
192
+ next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
193
+ next_data = T.cat([next_data1, next_data2], dim=-1)
194
+ return next_data
195
+ else:
196
+ next_data = self.io.temp_sample(out, temps)[:, -1:, :]
197
+ return next_data
198
+
199
+ @si_module
200
+ class HertzDevModel(nn.Module):
201
+ class Config:
202
+ dim: int
203
+ vocab_size: int
204
+ stack_config: Optional[Stack.Config] = None
205
+ latent_size: int = 32
206
+
207
+ split: bool = True
208
+
209
+ quantizer_config: Optional[LatentQuantizer.Config] = None
210
+ resynthesizer_config: Optional[TransformerVAE.Config] = None
211
+
212
+ from_pretrained: Optional[Tuple[str, str]] = None
213
+
214
+ def __init__(self, c: Config):
215
+ super().__init__()
216
+
217
+ if exists(c.from_pretrained):
218
+ checkpoint = load_ckpt(*c.from_pretrained)
219
+ else:
220
+ assert exists(c.stack_config), f'stack_config is required when from_pretrained is not set: {c}'
221
+
222
+ self.input = nn.Linear(c.latent_size, c.dim)
223
+ if self.c.split:
224
+ self.input2 = nn.Linear(c.latent_size, c.dim)
225
+
226
+ self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)
227
+
228
+ self.layers = nn.ModuleList([
229
+ PerfBlock(
230
+ dim=c.stack_config.dim,
231
+ layer_id=l,
232
+ n_head=c.stack_config.n_head,
233
+ kv_heads=c.stack_config.kv_heads,
234
+ ff_dim=c.stack_config.ff_dim,
235
+ eps=c.stack_config.eps,
236
+ shape_rotator=self.shape_rotator,
237
+ ) for l in range(c.stack_config.layers)
238
+ ])
239
+
240
+ self.output = GPTOutput(c.dim, c.vocab_size)
241
+ if self.c.split:
242
+ self.output2 = GPTOutput(c.dim, c.vocab_size)
243
+
244
+ self.cache = [None] * c.stack_config.layers
245
+ self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
246
+ self.head_dim = c.stack_config.dim // c.stack_config.n_head
247
+
248
+ if exists(c.from_pretrained):
249
+ result = self.load_state_dict(checkpoint, strict=False)
250
+ print0_colored(result, 'yellow')
251
+
252
+ self.resynthesizer = c.resynthesizer_config().eval()
253
+ self.resynthesizer.requires_grad_(False)
254
+
255
+ self.audio_tokenizer = make_tokenizer(device='cpu')
256
+ self.audio_cache = None
257
+ self.audio_latent_cache = None
258
+ self.use_audio_cache = False
259
+
260
+ @T.no_grad()
261
+ def tokenize(self, audio_data):
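+ # audio_data: (B, C, T) waveform at 16 kHz; the audio tokenizer emits one latent per 2000 samples,
+ # and the 6*16_000 cache below keeps roughly 6 s of audio context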
262
+ orig_audio_shape = audio_data.shape
263
+ if exists(self.audio_cache):
264
+ audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
265
+ self.audio_cache = audio_data[..., -(6*16_000):]
266
+ elif self.use_audio_cache:
267
+ self.audio_cache = audio_data[..., -(6*16_000):]
268
+
269
+ if audio_data.shape[1] == 2:
270
+ enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
271
+ enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
272
+ return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
273
+ else:
274
+ return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]
275
+
276
+ @T.no_grad()
277
+ def untokenize(self, token_data):
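+ # token_data: (B, S, latent) latents; at 8 latents per second, the 6*8 cache below
+ # mirrors the ~6 s audio cache kept by tokenize()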
278
+ if exists(self.audio_latent_cache):
279
+ token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
280
+ self.audio_latent_cache = token_data[:, -(6*8):]
281
+ elif self.use_audio_cache:
282
+ self.audio_latent_cache = token_data[:, -(6*8):]
283
+
284
+ if token_data.shape[-1] == 2*self.c.latent_size:
285
+ dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size])
286
+ dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:])
287
+ return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
288
+ else:
289
+ return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]
290
+
291
+ def init_cache(self, bsize, device, dtype, length:int=None):
292
+ cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
293
+ self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
294
+ self.resynthesizer.init_cache(bsize, device, dtype, length)
295
+ self.use_audio_cache = True
296
+
297
+ def deinit_cache(self):
298
+ self.cache = [None] * len(self.layers)
299
+ self.resynthesizer.deinit_cache()
300
+ self.audio_cache = None
301
+ self.audio_latent_cache = None
302
+ self.use_audio_cache = False
303
+
304
+ @T.no_grad()
305
+ def forward(self, data):
306
+ if self.c.split:
307
+ x1, x2 = data.chunk(2, dim=-1)
308
+ x = self.input(x1) + self.input2(x2)
309
+ else:
310
+ x = self.input(data)
311
+
312
+ for l, layer in enumerate(self.layers):
313
+ x = layer(x, kv=self.cache[l])
314
+
315
+ if self.c.split:
316
+ return self.output(x), self.output2(x)
317
+ else:
318
+ return self.output(x)
319
+
320
+ @T.no_grad()
321
+ def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
322
+ latents_in = self.tokenize(audio_data)
323
+ next_latents = self.next_latent(latents_in, temps)
324
+ next_model_latent = next_latents[..., self.c.latent_size:]
325
+ audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
326
+ return audio_decoded
327
+
328
+
329
+ @T.no_grad()
330
+ def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
331
+
332
+ if self.c.split:
333
+ logits1, logits2 = self.forward(model_input)
334
+ next_logits1 = logits1[:, -1]
335
+ next_logits2 = logits2[:, -1]
336
+ next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
337
+ next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)
338
+
339
+ next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
340
+ else:
341
+ logits = self.forward(model_input)
342
+ next_logits = logits[:, -1]
343
+ next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)
344
+
345
+ next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])
346
+
347
+ return next_input
348
+
349
+
350
+ @T.no_grad()
351
+ def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
352
+ """
353
+ only accepts latent-space data.
354
+ """
355
+ if use_cache:
356
+ self.init_cache(data.shape[0], data.device, T.bfloat16)
357
+
358
+ next_input = generated = data
359
+
360
+ target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)
361
+
362
+ for _ in tqdm0(range(data.shape[1], target_len)):
363
+ model_input = next_input if use_cache else generated
364
+
365
+ next_input = self.next_latent(model_input, temps)
366
+
367
+ generated = T.cat([generated, next_input], dim=1)
368
+
369
+ if use_cache:
370
+ self.deinit_cache()
371
+ return generated
372
+
373
+
374
+
375
+ def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
376
+ if is_split:
377
+ checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
378
+ elif not use_pure_audio_ablation:
379
+ checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
380
+ else:
381
+ checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')]
382
+
383
+ quantizer_config=LatentQuantizer.Config(
384
+ from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
385
+ compressor_config=FSQ.Config(
386
+ levels=[8,8,8,8,8],
387
+ dim=2048,
388
+ num_codebooks=1,
389
+ keep_num_codebooks_dim=None,
390
+ scale=None,
391
+ allowed_dtypes=['float32', 'float64', 'bfloat16'],
392
+ channel_first=False,
393
+ projection_has_bias=True,
394
+ return_indices=True,
395
+ force_quantization_f32=True,
396
+ use_rms=False
397
+ ),
398
+ dim=2048,
399
+ ff_dim=8192,
400
+ input_dim=32
401
+ )
402
+
403
+ resynthesizer_config=TransformerVAE.Config(
404
+ io_config=GaussianMixtureIOLayer.Config(
405
+ latent_dim=32,
406
+ dim=4096,
407
+ num_components=8,
408
+ ),
409
+ stack_config=Stack.Config(
410
+ layers=8,
411
+ dim=4096,
412
+ seq_len=8192,
413
+ n_head=16,
414
+ ff_dim=11008,
415
+ kv_heads=16,
416
+ eps=1e-5,
417
+ theta=10_000
418
+ ),
419
+ quantizer_config=quantizer_config,
420
+ plex_layer=None,
421
+ plex_roll=1,
422
+ split=is_split,
423
+ from_pretrained=checkpoints[0],
424
+ )
425
+
426
+ return HertzDevModel.Config(
427
+ dim=4096,
428
+ vocab_size=32_768,
429
+ stack_config=Stack.Config(
430
+ layers=32,
431
+ dim=4096,
432
+ seq_len=2048,
433
+ n_head=32,
434
+ ff_dim=None,
435
+ kv_heads=None,
436
+ eps=1e-5,
437
+ theta=10_000,
438
+ ),
439
+ quantizer_config=quantizer_config,
440
+ resynthesizer_config=resynthesizer_config,
441
+ split=is_split,
442
+ from_pretrained=checkpoints[1],
443
+ )
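A minimal sketch of how these configs are intended to be used; Config objects are callable (see utils/blocks.py), and calling one downloads the listed checkpoints from the si-pbc/hertz-dev repo:

from model import get_hertz_dev_config

config = get_hertz_dev_config(is_split=True)   # builds the config tree only
model = config().eval()                        # downloads and loads the pretrained weights

print(model.c.vocab_size)                        # 32768 tokens per channel (8**5 FSQ codes)
print(model.audio_tokenizer.tokens_per_second)   # 16000 / 2000 = 8.0 latents per second

# Generation then goes through completion() on latent-space data, or
# next_audio_from_audio() for one 2000-sample (1/8 s) chunk at a time.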
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ torch==2.5.1
2
+ torchaudio==2.5.1
3
+ einops==0.8.0
4
+ tqdm==4.66.6
5
+ ipython==8.18.1
6
+ numpy==1.26.3
7
+ soundfile==0.12.1
8
+ websockets==13.1
9
+ requests==2.32.3
10
+ sounddevice==0.5.1
11
+ matplotlib==3.9.2
12
+ fastapi==0.115.4
13
+ uvicorn==0.32.0
14
+ huggingface-hub[hf_transfer]==0.26.2
15
+ IProgress==0.4
requirements_webrtc.txt ADDED
@@ -0,0 +1,2 @@
1
+ streamlit==1.33.0
2
+ streamlit-webrtc==0.47.9
tokenizer.py ADDED
@@ -0,0 +1,581 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Union, Tuple, Literal
4
+
5
+ import torch as T
6
+ import torch.nn as nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from utils import load_ckpt
10
+ from utils.interp import print_colored
11
+ from utils import si_module, get_activation
12
+
13
+
14
+
15
+ # Adapted from https://github.com/facebookresearch/AudioDec
16
+
17
+ def Conv1d1x1(in_channels, out_channels, bias=True):
18
+ return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
19
+
20
+
21
+ class NonCausalConv1d(nn.Module):
22
+ """1D noncausal convolution w/ 2-sides padding."""
23
+
24
+ def __init__(
25
+ self,
26
+ in_channels,
27
+ out_channels,
28
+ kernel_size,
29
+ stride=1,
30
+ padding=-1,
31
+ dilation=1,
32
+ groups=1,
33
+ bias=True):
34
+ super().__init__()
35
+ self.in_channels = in_channels
36
+ self.out_channels = out_channels
37
+ self.kernel_size = kernel_size
38
+ if padding < 0:
39
+ padding = (kernel_size - 1) // 2 * dilation
40
+ self.dilation = dilation
41
+ self.conv = nn.Conv1d(
42
+ in_channels=in_channels,
43
+ out_channels=out_channels,
44
+ kernel_size=kernel_size,
45
+ stride=stride,
46
+ padding=padding,
47
+ dilation=dilation,
48
+ groups=groups,
49
+ bias=bias,
50
+ )
51
+
52
+ def forward(self, x):
53
+ """
54
+ Args:
55
+ x (Tensor): Float tensor variable with the shape (B, C, T).
56
+ Returns:
57
+ Tensor: Float tensor variable with the shape (B, C, T).
58
+ """
59
+ x = self.conv(x)
60
+ return x
61
+
62
+
63
+ class NonCausalConvTranspose1d(nn.Module):
64
+ """1D noncausal transpose convolution."""
65
+
66
+ def __init__(
67
+ self,
68
+ in_channels,
69
+ out_channels,
70
+ kernel_size,
71
+ stride,
72
+ padding=-1,
73
+ output_padding=-1,
74
+ groups=1,
75
+ bias=True,
76
+ ):
77
+ super().__init__()
78
+ if padding < 0:
79
+ padding = (stride+1) // 2
80
+ if output_padding < 0:
81
+ output_padding = 1 if stride % 2 else 0
82
+ self.deconv = nn.ConvTranspose1d(
83
+ in_channels=in_channels,
84
+ out_channels=out_channels,
85
+ kernel_size=kernel_size,
86
+ stride=stride,
87
+ padding=padding,
88
+ output_padding=output_padding,
89
+ groups=groups,
90
+ bias=bias,
91
+ )
92
+
93
+ def forward(self, x):
94
+ """
95
+ Args:
96
+ x (Tensor): Float tensor variable with the shape (B, C, T).
97
+ Returns:
98
+ Tensor: Float tensor variable with the shape (B, C', T').
99
+ """
100
+ x = self.deconv(x)
101
+ return x
102
+
103
+
104
+ class CausalConv1d(NonCausalConv1d):
105
+ def __init__(
106
+ self,
107
+ in_channels,
108
+ out_channels,
109
+ kernel_size,
110
+ stride=1,
111
+ dilation=1,
112
+ groups=1,
113
+ bias=True
114
+ ):
115
+ super(CausalConv1d, self).__init__(
116
+ in_channels=in_channels,
117
+ out_channels=out_channels,
118
+ kernel_size=kernel_size,
119
+ stride=stride,
120
+ padding=0,
121
+ dilation=dilation,
122
+ groups=groups,
123
+ bias=bias,
124
+ )
125
+ self.stride = stride
126
+ self.pad_length = (kernel_size - 1) * dilation
127
+ def forward(self, x):
128
+ pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
129
+ x = pad(x)
130
+ return self.conv(x)
131
+
132
+
133
+ class CausalConvTranspose1d(NonCausalConvTranspose1d):
134
+ def __init__(
135
+ self,
136
+ in_channels,
137
+ out_channels,
138
+ kernel_size,
139
+ stride,
140
+ bias=True,
141
+ pad_buffer=None,
142
+ ):
143
+ super(CausalConvTranspose1d, self).__init__(
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ kernel_size=kernel_size,
147
+ stride=stride,
148
+ padding=0,
149
+ output_padding=0,
150
+ bias=bias,
151
+ )
152
+ self.stride = stride
153
+ self.pad_length = (math.ceil(kernel_size/stride) - 1)
154
+ if pad_buffer is None:
155
+ pad_buffer = T.zeros(1, in_channels, self.pad_length)
156
+ self.register_buffer("pad_buffer", pad_buffer)
157
+
158
+ def forward(self, x):
159
+ pad = nn.ReplicationPad1d((self.pad_length, 0))
160
+ x = pad(x)
161
+ return self.deconv(x)[:, :, self.stride : -self.stride]
162
+
163
+ def inference(self, x):
164
+ x = T.cat((self.pad_buffer, x), -1)
165
+ self.pad_buffer = x[:, :, -self.pad_length:]
166
+ return self.deconv(x)[:, :, self.stride : -self.stride]
167
+
168
+ def reset_buffer(self):
169
+ self.pad_buffer.zero_()
170
+
171
+
172
+ class NonCausalResUnit(nn.Module):
173
+ def __init__(
174
+ self,
175
+ in_channels,
176
+ out_channels,
177
+ kernel_size=7,
178
+ dilation=1,
179
+ bias=False,
180
+ ):
181
+ super().__init__()
182
+ self.activation = nn.ELU()
183
+ self.conv1 = NonCausalConv1d(
184
+ in_channels=in_channels,
185
+ out_channels=out_channels,
186
+ kernel_size=kernel_size,
187
+ stride=1,
188
+ dilation=dilation,
189
+ bias=bias,
190
+ )
191
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
192
+
193
+ def forward(self, x):
194
+ y = self.conv1(self.activation(x))
195
+ y = self.conv2(self.activation(y))
196
+ return x + y
197
+
198
+
199
+ class CausalResUnit(NonCausalResUnit):
200
+ def __init__(
201
+ self,
202
+ in_channels,
203
+ out_channels,
204
+ kernel_size=7,
205
+ dilation=1,
206
+ bias=False,
207
+ ):
208
+ super(CausalResUnit, self).__init__(
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ kernel_size=kernel_size,
212
+ dilation=dilation,
213
+ bias=bias,
214
+ )
215
+ self.conv1 = CausalConv1d(
216
+ in_channels=in_channels,
217
+ out_channels=out_channels,
218
+ kernel_size=kernel_size,
219
+ stride=1,
220
+ dilation=dilation,
221
+ bias=bias,
222
+ )
223
+
224
+ def inference(self, x):
225
+ y = self.conv1.inference(self.activation(x))
226
+ y = self.conv2(self.activation(y))
227
+ return x + y
228
+
229
+
230
+ class ResNetBlock(nn.Module):
231
+ def __init__(self,
232
+ in_channels,
233
+ out_channels,
234
+ stride,
235
+ kernel_size=7,
236
+ dilations=(1, 3, 9),
237
+ bias=True,
238
+ mode='encoder',
239
+ ):
240
+ super().__init__()
241
+ assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"
242
+
243
+ self.mode = mode
244
+ self.stride = stride
245
+
246
+ ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d
247
+
248
+ res_channels = in_channels if mode == 'encoder' else out_channels
249
+
250
+ res_units = [CausalResUnit(
251
+ res_channels,
252
+ res_channels,
253
+ kernel_size=kernel_size,
254
+ dilation=dilation,
255
+ ) for dilation in dilations]
256
+
257
+ if in_channels == out_channels:
258
+ if mode == 'encoder':
259
+ self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
260
+ if mode == 'decoder':
261
+ self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
262
+ conv_unit = nn.Conv1d(
263
+ in_channels=in_channels,
264
+ out_channels=out_channels,
265
+ kernel_size=1,
266
+ bias=bias,
267
+ ) if in_channels != out_channels else nn.Identity()
268
+ else:
269
+ conv_unit = ConvUnit(
270
+ in_channels=in_channels,
271
+ out_channels=out_channels,
272
+ kernel_size=(2 * stride),
273
+ stride=stride,
274
+ bias=bias,
275
+ )
276
+
277
+ if mode == 'encoder':
278
+ if in_channels == out_channels:
279
+ self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
280
+ else:
281
+ self.res_block = nn.Sequential(*res_units, conv_unit)
282
+ elif mode == 'decoder':
283
+ if in_channels == out_channels:
284
+ self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
285
+ else:
286
+ self.res_block = nn.Sequential(conv_unit, *res_units)
287
+
288
+ def forward(self, x):
289
+ out = x
290
+ for unit in self.res_block:
291
+ out = unit(out)
292
+ return out
293
+
294
+ def inference(self, x):
295
+ for unit in self.res_block:
296
+ x = unit.inference(x)
297
+ return x
298
+
299
+
300
+
301
+
302
+ @si_module
303
+ class ResNetStack(nn.Module):
304
+ """
305
+ ResNet encoder or decoder stack. Channel ratios
306
+ and strides are given in their default order, from
307
+ the data/io layer toward the middle of the model.
308
+ """
309
+ class Config:
310
+ input_channels: int = 1
311
+ output_channels: int = 1
312
+ encode_channels: int = 32
313
+ decode_channel_multiplier: int = 1
314
+ latent_dim: int = None
315
+ kernel_size: int = 7
316
+ bias: bool = True
317
+ channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
318
+ strides: Tuple[int, ...] = (3, 4, 5, 5)
319
+ mode: Literal['encoder', 'decoder'] = 'encoder'
320
+
321
+ def __init__(self, c: Config):
322
+ super().__init__()
323
+ assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
324
+
325
+ self.mode = c.mode
326
+
327
+ assert len(c.channel_ratios) == len(c.strides)
328
+ channel_ratios = (1,) + c.channel_ratios
329
+ strides = c.strides
330
+ self.middle_channels = c.encode_channels * channel_ratios[-1]
331
+ if c.mode == 'decoder':
332
+ channel_ratios = tuple(reversed(channel_ratios))
333
+ strides = tuple(reversed(strides))
334
+
335
+ self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
336
+ res_blocks = [ResNetBlock(
337
+ c.encode_channels * channel_ratios[s_idx] * self.multiplier,
338
+ c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
339
+ stride,
340
+ kernel_size=c.kernel_size,
341
+ bias=c.bias,
342
+ mode=c.mode,
343
+ ) for s_idx, stride in enumerate(strides)]
344
+
345
+ data_conv = CausalConv1d(
346
+ in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
347
+ out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
348
+ kernel_size=c.kernel_size,
349
+ stride=1,
350
+ bias=False,
351
+ )
352
+
353
+ if c.mode == 'encoder':
354
+ self.res_stack = nn.Sequential(data_conv, *res_blocks)
355
+ elif c.mode == 'decoder':
356
+ self.res_stack = nn.Sequential(*res_blocks, data_conv)
357
+
358
+ if c.latent_dim is not None:
359
+ self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
360
+ if self.multiplier != 1:
361
+ self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)
362
+
363
+ def forward(self, x, return_feats=False):
364
+ if self.c.latent_dim is not None and self.mode == 'decoder':
365
+ x = self.latent_proj(x)
366
+ if self.multiplier != 1:
367
+ x = self.multiplier_proj(x)
368
+
369
+ feats = []
370
+ for block in self.res_stack:
371
+ x = block(x)
372
+ if return_feats:
373
+ feats.append(x)
374
+ if self.c.latent_dim is not None and self.mode == 'encoder':
375
+ x = self.latent_proj(x)
376
+ if return_feats:
377
+ feats.append(x)
378
+ if return_feats:
379
+ return feats
380
+ return x
381
+
382
+ def inference(self, x):
383
+ for block in self.res_stack:
384
+ x = block.inference(x)
385
+ return x
386
+
387
+ def reset_buffer(self):
388
+ def _reset_buffer(m):
389
+ if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
390
+ m.reset_buffer()
391
+ self.apply(_reset_buffer)
392
+
393
+ def reset_parameters(self):
394
+ def _reset_parameters(m):
395
+ if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
396
+ m.weight.data.normal_(0.0, 0.01)
397
+
398
+ self.apply(_reset_parameters)
399
+
400
+
401
+ def apply_weight_norm(self):
402
+ def _apply_weight_norm(m):
403
+ if isinstance(m, nn.Conv1d) or isinstance(
404
+ m, nn.ConvTranspose1d
405
+ ):
406
+ nn.utils.parametrizations.weight_norm(m)
407
+
408
+ self.apply(_apply_weight_norm)
409
+
410
+
411
+ def remove_weight_norm(self):
412
+ def _remove_weight_norm(m):
413
+ try:
414
+ print(m)
415
+ nn.utils.parametrize.remove_parametrizations(m, 'weight')
416
+ except ValueError: # this module didn't have weight norm
417
+ return
418
+
419
+ self.apply(_remove_weight_norm)
420
+
421
+
422
+
423
+ @si_module
424
+ class GaussianZ(nn.Module):
425
+ class Config:
426
+ dim: int
427
+ latent_dim: int
428
+ bias: bool = False
429
+ use_weight_norm: bool = False
430
+
431
+ def __init__(self, c: Config):
432
+ super().__init__()
433
+
434
+ self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
435
+ self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)
436
+
437
+ if c.use_weight_norm:
438
+ self.proj_in = weight_norm(self.proj_in)
439
+ self.proj_out = weight_norm(self.proj_out)
440
+
441
+ def reparam(self, mu, logvar):
442
+ std = T.exp(logvar / 2)
443
+ eps = T.randn_like(std)
444
+ return mu + eps * std
445
+
446
+ def kl_divergence(self, mu, logvar):
447
+ return T.mean(-0.5 * T.sum(
448
+ 1 + logvar - mu.pow(2) - logvar.exp(),
449
+ dim=(1, 2))
450
+ )
451
+
452
+ def repr_from_latent(self, latent: Union[dict, T.Tensor]):
453
+ if isinstance(latent, T.Tensor):
454
+ z = latent
455
+ else:
456
+ z = self.reparam(latent['mu'], latent['logvar'])
457
+ l = self.proj_out(z)
458
+ return l
459
+
460
+ def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
461
+ mu, logvar = self.proj_in(x).chunk(2, dim=-1)
462
+ kl_div = self.kl_divergence(mu, logvar)
463
+ z = self.reparam(mu, logvar)
464
+ xhat = self.proj_out(z)
465
+ latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
466
+ return xhat, latent
467
+
468
+
469
+
470
+ @si_module
471
+ class WaveCodec(nn.Module):
472
+ class Config:
473
+ resnet_config: ResNetStack.Config = None
474
+ sample_rate: int = 16_000
475
+ use_weight_norm: bool = False
476
+
477
+ compressor_config: dataclass = None
478
+
479
+ norm_stddev: float = 1.0
480
+
481
+ def __init__(self, c: Config):
482
+ super().__init__()
483
+ self.norm_stddev = c.norm_stddev
484
+ self.encoder = c.resnet_config(mode='encoder')
485
+ self.sample_rate = c.sample_rate
486
+
487
+ self.total_stride = 1
488
+ for stride in c.resnet_config.strides:
489
+ self.total_stride *= stride
490
+ self.tokens_per_second = self.sample_rate / self.total_stride
491
+
492
+ self.compressor = c.compressor_config(dim=self.encoder.middle_channels)
493
+
494
+ self.decoder = c.resnet_config(mode='decoder')
495
+
496
+ if c.use_weight_norm:
497
+ self.encoder.apply_weight_norm()
498
+ self.decoder.apply_weight_norm()
499
+ self.encoder.reset_parameters()
500
+ self.decoder.reset_parameters()
501
+
502
+ def encode(self, data):
503
+ return self.encoder(data/self.norm_stddev)
504
+
505
+ def decode(self, latent):
506
+ return self.decoder(latent.transpose(1, 2))*self.norm_stddev
507
+
508
+ @T.no_grad()
509
+ def latent_from_data(self, data, get_parameters=False):
510
+ x = self.encode(data)
511
+ l_in = x.transpose(1, 2)
512
+ l, latent = self.compressor(l_in)
513
+ return latent['z'] if not get_parameters else {
514
+ 'mu': latent['mu'],
515
+ 'logvar': latent['logvar'],
516
+ 'z': latent['z'],
517
+ }
518
+
519
+ @T.no_grad()
520
+ def data_from_latent(self, latent):
521
+ l = self.compressor.repr_from_latent(latent)
522
+ x = self.decode(l)
523
+ return x
524
+
525
+ def process(self, x):
526
+ return self.latent_from_data(x)
527
+
528
+ def unprocess(self, latent):
529
+ return self.data_from_latent(latent)
530
+
531
+ def forward(self, audio_input):
532
+ x = self.encode(audio_input)
533
+
534
+ l_in = x.transpose(1, 2)
535
+ l, latent = self.compressor(l_in)
536
+
537
+ xhat = self.decode(l)
538
+ return xhat, latent
539
+
540
+
541
+
542
+ def make_tokenizer(device='cuda'):
543
+ generator_config = WaveCodec.Config(
544
+ resnet_config=ResNetStack.Config(
545
+ input_channels=1,
546
+ output_channels=1,
547
+ encode_channels=16,
548
+ decode_channel_multiplier=4,
549
+ kernel_size=7,
550
+ bias=True,
551
+ channel_ratios=(4, 8, 16, 16, 16, 16),
552
+ strides=(2, 2, 4, 5, 5, 5),
553
+ mode=None,
554
+ ),
555
+ use_weight_norm=True,
556
+
557
+ compressor_config=GaussianZ.Config(
558
+ dim=None,
559
+ latent_dim=32,
560
+
561
+ bias=True,
562
+ use_weight_norm=True
563
+ ),
564
+
565
+ norm_stddev=0.05,
566
+ )
567
+ checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")
568
+
569
+ tokenizer = generator_config()
570
+
571
+ load_result = tokenizer.load_state_dict(checkpoint, strict=False)
572
+ print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")
573
+
574
+ tokenizer = tokenizer.eval()
575
+ # Only convert to bfloat16 if using CUDA
576
+ if device == 'cuda':
577
+ tokenizer = tokenizer.bfloat16()
578
+ tokenizer = tokenizer.to(device)
579
+ tokenizer.requires_grad_(False)
580
+ return tokenizer
581
+
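A minimal encode/decode sketch for the codec above; the checkpoint is fetched from the Hub on first use, and the shapes assume mono 16 kHz input:

import torch as T
from tokenizer import make_tokenizer

tok = make_tokenizer(device='cpu')       # stays float32 on CPU; bfloat16 only on CUDA

audio = T.randn(1, 1, 16_000)            # (batch, channels, samples): 1 s of mono audio
latent = tok.latent_from_data(audio)     # -> (1, 8, 32): 8 latents per second, 32-dim each
recon = tok.data_from_latent(latent)     # -> waveform again, roughly (1, 1, 16_000)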
transformer.py ADDED
@@ -0,0 +1,381 @@
1
+ from typing import Optional, Tuple, MutableMapping
2
+ from typing import Union
3
+ import math
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch as T
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch import Tensor
11
+ from torch.nn.attention import SDPBackend
12
+
13
+ from einops import rearrange
14
+
15
+ from utils import si_module, default, exists, load_ckpt
16
+
17
+ CACHE_FILL_VALUE = -1
18
+
19
+ def get_cache_len(cache: Optional[Tensor]) -> int:
20
+ """
21
+ cache: (batch, seq_len, 2, kv_heads, head_dim)
22
+ """
23
+ if cache is None:
24
+ return 0
25
+ nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
26
+ length = nonzeros.sum(dim=-1).int()
27
+ assert T.all(length == length[0])
28
+ return length[0]
29
+
30
+
31
+ def rotate_half(x):
32
+ x1, x2 = x.chunk(2, dim=-1)
33
+ return torch.cat((-x2, x1), dim=-1)
34
+
35
+
36
+ def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
37
+ assert (
38
+ cos.shape[1] >= offset + x.shape[1]
39
+ ), f"Offset and/or input sequence is too large,\
40
+ \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
41
+
42
+ cos_out = cos[:, offset : offset + x.shape[1], :, :]
43
+ sin_out = sin[:, offset : offset + x.shape[1], :, :]
44
+
45
+ return (x * cos_out) + (rotate_half(x) * sin_out)
46
+
47
+
48
+ # Adapted from https://github.com/foundation-model-stack/foundation-model-stack
49
+ class ShapeRotator:
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ end: int,
54
+ theta: float = 10_000,
55
+ ):
56
+ super().__init__()
57
+ self.dim = dim
58
+ self.ratio = theta
59
+ self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
60
+ self.max_seq_len_cached: MutableMapping[int, int] = {}
61
+ self.ntk_scaling = False
62
+ self.max_seq_len = end
63
+
64
+ def compute_freqs_cis(self, device, max_seq_len=None):
65
+ alpha = 1
66
+ dev_idx = device.index
67
+ max_seq_len = default(max_seq_len, self.max_seq_len)
68
+
69
+ if dev_idx not in self.cached_freqs:
70
+ self.cached_freqs[dev_idx] = {}
71
+ if dev_idx not in self.max_seq_len_cached:
72
+ self.max_seq_len_cached[dev_idx] = 0
73
+
74
+
75
+ if self.max_seq_len_cached[dev_idx] > 0:
76
+ return 1
77
+ max_seq_len = max(max_seq_len, self.max_seq_len)
78
+
79
+ if (
80
+ 1 in self.cached_freqs[dev_idx]
81
+ and max_seq_len <= self.max_seq_len_cached[dev_idx]
82
+ ):
83
+ return 1
84
+
85
+ ratio = self.ratio
86
+ dim = self.dim
87
+
88
+ freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))
89
+
90
+ t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
91
+ freqs = torch.einsum("i,j->ij", t, freqs)
92
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
93
+
94
+ cos_to_cache = emb.cos()[None, :, None, :]
95
+ sin_to_cache = emb.sin()[None, :, None, :]
96
+
97
+ self.max_seq_len_cached[dev_idx] = max_seq_len
98
+
99
+ self.cached_freqs[dev_idx][alpha] = torch.stack(
100
+ [
101
+ cos_to_cache,
102
+ sin_to_cache,
103
+ ],
104
+ dim=-1,
105
+ )
106
+
107
+ return alpha
108
+
109
+ def rotate(
110
+ self,
111
+ q: Tensor,
112
+ k: Tensor,
113
+ offset: int = 0,
114
+ ) -> Tuple[Tensor, Tensor]:
115
+ """
116
+ Args
117
+ ----
118
+ q : torch.Tensor
119
+ Embedded query tensor, expected size is B x S x H x Eh
120
+ k : torch.Tensor
121
+ Embedded key tensor, expected size is B x S x H x Eh
122
+ """
123
+ assert len(q.size()) == 4
124
+ assert len(k.size()) == 4
125
+
126
+ seq_len = self.max_seq_len
127
+ alpha = self.compute_freqs_cis(q.device, seq_len)
128
+ freqs = self.cached_freqs[q.device.index][alpha]
129
+
130
+ freqs = freqs.float() # 1 L D/2 2 2
131
+ q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
132
+ k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)
133
+
134
+ return q_out.view_as(q), k_out.view_as(k)
135
+
136
+ class Linear(nn.Linear):
137
+ def __init__(self, *args, **kwargs):
138
+ super().__init__(*args, **kwargs, bias=False)
139
+
140
+ class Norm(nn.Module):
141
+ def __init__(self,
142
+ dim: int,
143
+ eps: float = 1e-5,) -> None:
144
+ super().__init__()
145
+ self.eps = eps
146
+ self.weight = nn.Parameter(T.ones((dim,)))
147
+
148
+ def forward(self, input: Tensor) -> Tensor:
149
+ return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
150
+
151
+
152
+ class FFNN(nn.Module):
153
+ def __init__(self,
154
+ dim: int,
155
+ expand_dim: int = None,):
156
+ super().__init__()
157
+ expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
158
+ self.dim = dim
159
+ self.expand_dim = expand_dim
160
+
161
+ self.gateup_proj = Linear(dim, 2*expand_dim)
162
+ self.down_proj = Linear(expand_dim, dim)
163
+
164
+ def forward(self, x):
165
+ gate, up = self.gateup_proj(x).chunk(2, dim=-1)
166
+ return self.down_proj(up * F.silu(gate))
167
+
168
+ class GQA(nn.Module):
169
+ def __init__(self,
170
+ dim: int,
171
+ n_head: int,
172
+ shape_rotator: ShapeRotator,
173
+ kv_heads: Optional[int] = None,
174
+ eps: float = 1e-5,
175
+ causal: bool = True,):
176
+ super().__init__()
177
+ self.n_heads = n_head
178
+ self.kv_heads = default(kv_heads, n_head)
179
+ self.head_dim = dim // n_head
180
+ self.causal = causal
181
+
182
+ self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))
183
+
184
+ self.norm_q = Norm(self.head_dim*n_head, eps=eps)
185
+ self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)
186
+
187
+ self.attn_out = Linear(dim, dim)
188
+
189
+ self.shape_rotator = shape_rotator
190
+
191
+ def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
192
+ k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
193
+ v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
194
+ x = F.scaled_dot_product_attention(
195
+ q.transpose(1, 2),
196
+ k.transpose(1, 2),
197
+ v.transpose(1, 2),
198
+ is_causal=False if (q.size(1) != k.size(1)) else self.causal,
199
+ )
200
+ x = x.transpose(1, 2).contiguous()
201
+ return x
202
+
203
+ def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
204
+ cache_len = get_cache_len(kv_cache)
205
+ q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
206
+ if exists(kv_cache):
207
+ k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
208
+ v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
209
+ kv_cache[:, :k.size(1), 0] = k
210
+ kv_cache[:, :v.size(1), 1] = v
211
+ x = self._sdpa(q, k, v)
212
+ return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))
213
+
214
+ def _project(self, x):
215
+ # split the fused projection by exact sizes so kv_heads < n_head also works
+ full_q, full_k, full_v = self.proj_qkv(x).split(
+ [self.head_dim * self.n_heads, self.head_dim * self.kv_heads, self.head_dim * self.kv_heads], dim=-1)
216
+ normed_full_q = self.norm_q(full_q).to(full_q.dtype)
217
+ normed_full_k = self.norm_k(full_k).to(full_k.dtype)
218
+
219
+ q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
220
+ k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
221
+ v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
222
+ return q, k, v
223
+
224
+ def forward(self,
225
+ x: Tensor,
226
+ kv: Optional[Tensor] = None,):
227
+ """
228
+ x: (B, S, D)
229
+ kv: (B, S, H, D)
230
+ """
231
+ q, k, v = self._project(x)
232
+ return self._attend(q, k, v, kv_cache=kv)
233
+
234
+
235
+ class PreNormAttn(nn.Module):
236
+ def __init__(self,
237
+ dim: int,
238
+ n_head: int,
239
+ shape_rotator: ShapeRotator,
240
+ kv_heads: Optional[int] = None,
241
+ eps: float = 1e-5,
242
+ causal: bool = True,):
243
+ super().__init__()
244
+ self.attn_norm = Norm(dim, eps=eps)
245
+ self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
246
+
247
+ def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
248
+ """
249
+ x: (B, S, D)
250
+ kv: (B, S, H, D)
251
+ """
252
+ return x + self.attn(self.attn_norm(x), kv)
253
+
254
+ class PreNormFFNN(nn.Module):
255
+ def __init__(self,
256
+ dim: int,
257
+ ff_dim: int,
258
+ eps: float = 1e-5,):
259
+ super().__init__()
260
+ self.ffnn_norm = Norm(dim, eps=eps)
261
+ self.ffnn = FFNN(dim, ff_dim)
262
+
263
+ def forward(self, x: Tensor) -> Tensor:
264
+ return x + self.ffnn(self.ffnn_norm(x))
265
+
266
+ class Block(nn.Module):
267
+ def __init__(self,
268
+ dim: int,
269
+ layer_id: int = 0,
270
+ n_head: int = 16,
271
+ kv_heads: Optional[int] = None,
272
+ ff_dim: Optional[int] = None,
273
+ eps: float = 1e-5,
274
+ causal: bool = True,
275
+ shape_rotator: ShapeRotator = None):
276
+ super().__init__()
277
+ self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
278
+ self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
279
+ self.dim = dim
280
+ self.layer_id = layer_id
281
+ self.head_dim = dim // n_head
282
+ self.expand_dim = self.ffnn.ffnn.expand_dim
283
+
284
+ self.reset_parameters()
285
+
286
+ def reset_parameters(self):
287
+ std = 1.0 / math.sqrt(self.dim)
288
+ nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
289
+ nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
290
+ nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)
291
+
292
+ xstd = 1.0 / math.sqrt(self.expand_dim)
293
+ nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)
294
+
295
+ def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
296
+ """
297
+ x: (B, S, D)
298
+ kv: (B, S, H, D)
299
+ """
300
+ h = self.attn(x, kv)
301
+ out = self.ffnn(h)
302
+ return out
303
+
304
+
305
+
306
+ class GPTOutput(nn.Module):
307
+ def __init__(self, dim, vocab_size):
308
+ super().__init__()
309
+ self.dim = dim
310
+ self.norm = Norm(dim)
311
+ self.output = Linear(dim, vocab_size)
312
+
313
+ self.reset_parameters()
314
+
315
+ def reset_parameters(self):
316
+ std = 1.0 / math.sqrt(self.dim**2)
317
+ nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
318
+
319
+ def forward(self, x):
320
+ return self.output(self.norm(x))
321
+
322
+ @si_module
323
+ class Stack(nn.Module):
324
+ class Config:
325
+ layers: int
326
+ dim: int
327
+ seq_len: int
328
+ n_head: int = 32
329
+ ff_dim: int = None
330
+ kv_heads: int = None
331
+ eps: float = 1e-5
332
+ theta: Union[int, float] = 10_000
333
+ causal: bool = True
334
+
335
+ from_pretrained: Optional[Tuple[str, str]] = None
336
+
337
+ def __init__(self, c: Config):
338
+ super().__init__()
339
+
340
+ from_pretrained = c.from_pretrained
341
+ if exists(from_pretrained):
342
+ checkpoint = load_ckpt(*c.from_pretrained)
343
+
344
+ self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)
345
+
346
+ self.layers = nn.ModuleList([
347
+ Block(
348
+ dim=c.dim,
349
+ layer_id=l,
350
+ n_head=c.n_head,
351
+ kv_heads=c.kv_heads,
352
+ ff_dim=c.ff_dim,
353
+ eps=c.eps,
354
+ causal=c.causal,
355
+ shape_rotator=self.shape_rotator,
356
+ ) for l in range(c.layers)
357
+ ])
358
+
359
+ kv_heads = c.kv_heads or c.n_head
360
+ head_dim = c.dim // c.n_head
361
+ cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
362
+ self.cache_shape = cache_shape
363
+ self.cache = [None] * c.layers
364
+
365
+ if exists(from_pretrained):
366
+ self.load_state_dict(checkpoint)
367
+
368
+ def init_cache(self, bsize, device, dtype, length:int=None):
369
+ if self.cache_shape is None:
370
+ return
371
+ cache_shape = self.cache_shape.copy()
372
+ cache_shape[1] = length or cache_shape[1]
373
+ self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
374
+
375
+ def deinit_cache(self):
376
+ self.cache = [None] * len(self.cache)
377
+
378
+ def forward(self, x: Tensor) -> Tensor:
379
+ for l, layer in enumerate(self.layers):
380
+ x = layer(x, kv=self.cache[l])
381
+ return x
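A toy-sized sketch of the Stack API with random weights on CPU, including the in-place KV cache that model.py relies on during incremental decoding:

import torch as T
from transformer import Stack

stack = Stack.Config(layers=2, dim=64, seq_len=128, n_head=4, ff_dim=128)()

x = T.randn(1, 16, 64)          # (batch, seq, dim)
y = stack(x)                    # pre-norm GQA + gated FFNN blocks, same shape out
assert y.shape == x.shape

stack.init_cache(bsize=1, device='cpu', dtype=T.float32, length=128)
step = stack(x[:, :1])          # caches are filled in place inside GQA._attend
stack.deinit_cache()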
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .blocks import *
2
+ from .dist import *
3
+ from .interp import *
utils/blocks.py ADDED
@@ -0,0 +1,92 @@
1
+ from dataclasses import dataclass
2
+ from typing import TypeVar, Generic, Type, Optional
3
+ from functools import wraps
4
+ import time
5
+ import random
6
+
7
+ import torch as T
8
+ import torch.nn as nn
9
+
10
+ # @TODO: remove si_module from codebase
11
+ # we use this in our research codebase to make modules from callable configs
12
+ si_module_TpV = TypeVar('si_module_TpV')
13
+ def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
14
+ if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
15
+ class Config:
16
+ pass
17
+ cls.Config = Config
18
+
19
+ cls.Config = dataclass(cls.Config)
20
+
21
+ class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
22
+ def __call__(self, *args, **kwargs) -> si_module_TpV:
23
+ if len(kwargs) > 0:
24
+ config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()}
25
+ config_dict.update(kwargs)
26
+ new_config = type(self)(**config_dict)
27
+ return cls(new_config)
28
+ else:
29
+ return cls(self, *args)
30
+
31
+ ConfigWrapper.__module__ = cls.__module__
32
+ ConfigWrapper.__name__ = f"{cls.__name__}Config"
33
+ ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"
34
+
35
+ cls.Config = ConfigWrapper
36
+
37
+ original_init = cls.__init__
38
+ def new_init(self, *args, **kwargs):
39
+ self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None)
40
+ original_init(self, *args, **kwargs)
41
+ self.register_buffer('_device_tracker', T.Tensor(), persistent=False)
42
+
43
+ cls.__init__ = new_init
44
+
45
+ @property
46
+ def device(self):
47
+ return self._device_tracker.device
48
+
49
+ @property
50
+ def dtype(self):
51
+ return self._device_tracker.dtype
52
+
53
+ cls.device = device
54
+ cls.dtype = dtype
55
+
56
+ return cls
57
+
58
+
59
+ def get_activation(nonlinear_activation, nonlinear_activation_params={}):
60
+ if hasattr(nn, nonlinear_activation):
61
+ return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
62
+ else:
63
+ raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")
64
+
65
+
66
+ def exists(v):
67
+ return v is not None
68
+
69
+ def isnt(v):
70
+ return not exists(v)
71
+
72
+ def truthyexists(v):
73
+ return exists(v) and v is not False
74
+
75
+ def truthyattr(obj, attr):
76
+ return hasattr(obj, attr) and truthyexists(getattr(obj, attr))
77
+
78
+ defaultT = TypeVar('defaultT')
79
+
80
+ def default(*args: Optional[defaultT]) -> Optional[defaultT]:
81
+ for arg in args:
82
+ if exists(arg):
83
+ return arg
84
+ return None
85
+
86
+ def maybe(fn):
87
+ @wraps(fn)
88
+ def inner(x, *args, **kwargs):
89
+ if not exists(x):
90
+ return x
91
+ return fn(x, *args, **kwargs)
92
+ return inner
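The si_module decorator above is what makes every Module.Config in this repo callable; a tiny illustration with a hypothetical Toy module (not part of the repo):

import torch.nn as nn
from utils import si_module

@si_module
class Toy(nn.Module):
    class Config:
        dim: int
        hidden: int = 16

    def __init__(self, c: Config):
        super().__init__()
        self.proj = nn.Linear(c.dim, c.hidden)

cfg = Toy.Config(dim=8)     # a dataclass holding the hyperparameters
toy = cfg()                 # calling the config constructs the module
wider = cfg(hidden=32)      # kwargs build a modified copy first, then construct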
utils/dist.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import torch as T
3
+ import re
4
+ from tqdm import tqdm
5
+ from datetime import timedelta
6
+
7
+ import requests
8
+ import hashlib
9
+
10
+ from io import BytesIO
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ def rank0():
14
+ rank = os.environ.get('RANK')
15
+ if rank is None or rank == '0':
16
+ return True
17
+ else:
18
+ return False
19
+
20
+ def local0():
21
+ local_rank = os.environ.get('LOCAL_RANK')
22
+ if local_rank is None or local_rank == '0':
23
+ return True
24
+ else:
25
+ return False
26
+ class tqdm0(tqdm):
27
+ def __init__(self, *args, **kwargs):
28
+ total = kwargs.get('total', None)
29
+ if total is None and len(args) > 0:
30
+ try:
31
+ total = len(args[0])
32
+ except TypeError:
33
+ pass
34
+ if total is not None:
35
+ kwargs['miniters'] = max(1, total // 20)
36
+ super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
37
+
38
+ def print0(*args, **kwargs):
39
+ if rank0():
40
+ print(*args, **kwargs)
41
+
42
+ _PRINTED_IDS = set()
43
+
44
+ def printonce(*args, id=None, **kwargs):
45
+ if id is None:
46
+ id = ' '.join(map(str, args))
47
+
48
+ if id not in _PRINTED_IDS:
49
+ print(*args, **kwargs)
50
+ _PRINTED_IDS.add(id)
51
+
52
+ def print0once(*args, **kwargs):
53
+ if rank0():
54
+ printonce(*args, **kwargs)
55
+
56
+ def init_dist():
57
+ if T.distributed.is_initialized():
58
+ print0('Distributed already initialized')
59
+ rank = T.distributed.get_rank()
60
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
61
+ world_size = T.distributed.get_world_size()
62
+ else:
63
+ try:
64
+ rank = int(os.environ['RANK'])
65
+ local_rank = int(os.environ['LOCAL_RANK'])
66
+ world_size = int(os.environ['WORLD_SIZE'])
67
+ device = f'cuda:{local_rank}'
68
+ T.cuda.set_device(device)
69
+ T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
70
+ print(f'Rank {rank} of {world_size}.')
71
+ except Exception as e:
72
+ print0once(f'Not initializing distributed env: {e}')
73
+ rank = 0
74
+ local_rank = 0
75
+ world_size = 1
76
+ return rank, local_rank, world_size
77
+
78
+ def load_ckpt(load_from_location, expected_hash=None):
79
+ os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1' # set to '0' for clearer errors when debugging downloads from the hub
80
+ if local0():
81
+ repo_id = "si-pbc/hertz-dev"
82
+ print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
83
+ save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt")
84
+ print0(f'Downloaded checkpoint to {save_path}')
85
+ if expected_hash is not None:
86
+ with open(save_path, 'rb') as f:
87
+ file_hash = hashlib.md5(f.read()).hexdigest()
88
+ if file_hash != expected_hash:
89
+ print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
90
+ os.remove(save_path)
91
+ return load_ckpt(load_from_location, expected_hash)
92
+ if T.distributed.is_initialized():
93
+ save_path = [save_path]
94
+ T.distributed.broadcast_object_list(save_path, src=0)
95
+ save_path = save_path[0]
96
+ loaded = T.load(save_path, weights_only=False, map_location='cpu')
97
+ print0(f'Loaded checkpoint from {save_path}')
98
+ return loaded
utils/interp.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch as T
2
+ import os
3
+
4
+ def rank0():
5
+ rank = os.environ.get('RANK')
6
+ if rank is None or rank == '0':
7
+ return True
8
+ else:
9
+ return False
10
+
11
+ def print_colored(message, color='reset', bold=False, **kwargs):
12
+ color_dict = {
13
+ 'bold': '\033[1m',
14
+ 'green': '\033[92m',
15
+ 'yellow': '\033[93m',
16
+ 'red': '\033[91m',
17
+ 'blue': '\033[94m',
18
+ 'grey': '\033[90m',
19
+ 'white': '\033[97m',
20
+ 'reset': '\033[0m'
21
+ }
22
+
23
+ color_code = color_dict.get(color.lower(), color_dict['reset'])
24
+ prefix = color_dict['bold'] if bold else ''
25
+ print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs)
26
+
27
+ def print0_colored(*args, **kwargs):
28
+ if rank0():
29
+ print_colored(*args, **kwargs)
30
+
31
+ def param_count(module):
32
+ def count_parameters(model):
33
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
34
+
35
+ total_params = count_parameters(module)
36
+ output = [f'Total model parameters: {total_params:,}', '---------------------------']
37
+
38
+ for name, child in module.named_children():
39
+ params = count_parameters(child)
40
+ output.append(f'{name} parameters: {params:,}')
41
+
42
+ return '\n'.join(output)
43
+
44
+ def model_size_estimation(module):
45
+ def estimate_size(model):
46
+ param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
47
+ buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
48
+ return param_size + buffer_size
49
+
50
+ total_size = estimate_size(module)
51
+ output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------']
52
+
53
+ for name, child in module.named_children():
54
+ child_size = estimate_size(child)
55
+ output.append(f'{name} size: {child_size / 1024**2:.2f} MB')
56
+
57
+ return '\n'.join(output)
58
+
59
+ def layer_param_distribution(module):
60
+ def count_parameters(model):
61
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
62
+
63
+ def get_layer_types(model):
64
+ layer_types = {}
65
+ for name, module in model.named_modules():
66
+ layer_type = module.__class__.__name__
67
+ params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
68
+ if params > 0:
69
+ if layer_type not in layer_types:
70
+ layer_types[layer_type] = 0
71
+ layer_types[layer_type] += params
72
+ return layer_types
73
+
74
+ total_params = count_parameters(module)
75
+ layer_types = get_layer_types(module)
76
+
77
+ output = [f'Total trainable parameters: {total_params:,}', '---------------------------']
78
+
79
+ for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
80
+ percentage = (count / total_params) * 100
81
+ output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)')
82
+
83
+ return '\n'.join(output)
84
+