eustlb (HF staff) committed
Commit fec8e73 · 1 Parent(s): 3ab5772
README.md CHANGED
@@ -6,13 +6,43 @@ library_name: transformers
6
  pipeline_tag: automatic-speech-recognition
7
  arxiv: https://arxiv.org/abs/2410.15608
8
  ---
9
- # Model Card: Moonshine
10
 
11
  [[Blog]](https://petewarden.com/2024/10/21/introducing-moonshine-the-new-state-of-the-art-for-speech-to-text/) [[Paper]](https://arxiv.org/abs/2410.15608) [[Installation]](https://github.com/usefulsensors/moonshine/blob/main/README.md) [[Podcast]](https://notebooklm.google.com/notebook/d787d6c2-7d7b-478c-b7d5-a0be4c74ae19/audio)
12
 
13
  This is the model card for running the automatic speech recognition (ASR) models (Moonshine models) trained and released by Useful Sensors.
14
 
15
- Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2410.15608). Note, a lot of the text has been copied verbatim from the [model card](https://github.com/openai/whisper/blob/main/model-card.md) for the Whisper model developed by OpenAI, because both models serve identical purposes, and carry identical risks.
16
 
17
  ## Model Details
18
 
@@ -94,8 +124,8 @@ There are also potential dual-use concerns that come with releasing Moonshine. W
94
  if sr != 16000:
95
  audio = torchaudio.functional.resample(audio, sr, 16000)
96
 
97
- model = AutoModelForSpeechSeq2Seq.from_pretrained('usefulsensors/moonshine-base', trust_remote_code=True)
98
- tokenizer = PreTrainedTokenizerFast.from_pretrained('usefulsensors/moonshine-base')
99
 
100
  tokens = model(audio)
101
  print(tokenizer.decode(tokens[0], skip_special_tokens=True))
@@ -113,4 +143,4 @@ If you benefit from our work, please cite us:
113
  primaryClass={cs.SD},
114
  url={https://arxiv.org/abs/2410.15608},
115
  }
116
- ```
 
6
  pipeline_tag: automatic-speech-recognition
7
  arxiv: https://arxiv.org/abs/2410.15608
8
  ---
9
+ # Moonshine
10
 
11
  [[Blog]](https://petewarden.com/2024/10/21/introducing-moonshine-the-new-state-of-the-art-for-speech-to-text/) [[Paper]](https://arxiv.org/abs/2410.15608) [[Installation]](https://github.com/usefulsensors/moonshine/blob/main/README.md) [[Podcast]](https://notebooklm.google.com/notebook/d787d6c2-7d7b-478c-b7d5-a0be4c74ae19/audio)
12
 
13
  This is the model card for running the automatic speech recognition (ASR) models (Moonshine models) trained and released by Useful Sensors.
14
 
15
+ Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2410.15608). Note, a lot of the text has been copied verbatim from the [model card](https://github.com/openai/whisper/blob/main/model-card.md) for the Whisper model developed by OpenAI, because both models serve identical purposes, and carry identical risks.
16
+
17
+ ## Usage
18
+
19
+ Moonshine is supported in Hugging Face 🤗 Transformers. To run the model, first install the Transformers library. For this example, we'll also install 🤗 Datasets to load a toy audio dataset from the Hugging Face Hub, and 🤗 Accelerate to reduce the model loading time:
20
+
21
+ ```bash
22
+ pip install --upgrade pip
23
+ pip install --upgrade transformers datasets[audio] accelerate
24
+ ```
25
+
26
+ ```python
27
+ from transformers import MoonshineForConditionalGeneration, AutoProcessor
28
+ from datasets import load_dataset, Audio
29
+ import torch
30
+
31
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
32
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
33
+
34
+ model = MoonshineForConditionalGeneration.from_pretrained('UsefulSensors/moonshine-base').to(device).to(torch_dtype)
35
+ processor = AutoProcessor.from_pretrained('UsefulSensors/moonshine-base')
36
+
37
+ dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
38
+ dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
39
+ sample = dataset[0]["audio"]
40
+
41
+ inputs = processor(sample["array"], return_tensors="pt").to(device).to(torch_dtype)
42
+
43
+ generated_ids = model.generate(**inputs)
44
+ print(processor.decode(generated_ids[0], skip_special_tokens=True))
45
+ ```
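
For quick tests, the same transcription can also be run through the high-level `pipeline` API. A minimal sketch, assuming the checkpoint is registered with the `automatic-speech-recognition` pipeline in the Transformers version used above (`audio.wav` is a placeholder path):

```python
import torch
from transformers import pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Build an ASR pipeline around the Moonshine checkpoint.
pipe = pipeline(
    "automatic-speech-recognition",
    model="UsefulSensors/moonshine-base",
    torch_dtype=torch_dtype,
    device=device,
)

# "audio.wav" is a placeholder; the pipeline handles audio decoding and resampling to 16 kHz.
print(pipe("audio.wav")["text"])
```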
46
 
47
  ## Model Details
48
 
 
124
  if sr != 16000:
125
  audio = torchaudio.functional.resample(audio, sr, 16000)
126
 
127
+ model = AutoModelForSpeechSeq2Seq.from_pretrained('usefulsensors/moonshine-tiny', trust_remote_code=True)
128
+ tokenizer = PreTrainedTokenizerFast.from_pretrained('usefulsensors/moonshine-tiny')
129
 
130
  tokens = model(audio)
131
  print(tokenizer.decode(tokens[0], skip_special_tokens=True))
 
143
  primaryClass={cs.SD},
144
  url={https://arxiv.org/abs/2410.15608},
145
  }
146
+ ```
config.json CHANGED
@@ -1,20 +1,32 @@
1
  {
 
2
  "architectures": [
3
- "MoonshineModel"
4
  ],
5
- "auto_map": {
6
- "AutoConfig": "configuration_moonshine.MoonshineConfig",
7
- "AutoModelForSpeechSeq2Seq": "modeling_moonshine.MoonshineModel"
8
- },
9
- "dec_depth": 8,
10
- "dec_ff_swiglu": true,
11
- "dec_voc_size": 32768,
12
- "dim": 416,
13
- "enc_depth": 8,
14
- "enc_ff_swiglu": false,
15
- "inner_dim": 416,
16
  "model_type": "moonshine",
17
- "n_head": 8,
 
 
18
  "torch_dtype": "float32",
19
- "transformers_version": "4.46.1"
 
 
20
  }
 
1
  {
2
+ "_name_or_path": "/home/eustache_lebihan/dev/add-moonshine/moonshine-base",
3
  "architectures": [
4
+ "MoonshineForConditionalGeneration"
5
  ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 1,
9
+ "decoder_hidden_act": "silu",
10
+ "decoder_num_attention_heads": 8,
11
+ "decoder_num_hidden_layers": 8,
12
+ "decoder_num_key_value_heads": 8,
13
+ "decoder_start_token_id": 1,
14
+ "encoder_hidden_act": "gelu",
15
+ "encoder_num_attention_heads": 8,
16
+ "encoder_num_hidden_layers": 8,
17
+ "encoder_num_key_value_heads": 8,
18
+ "eos_token_id": 2,
19
+ "hidden_size": 416,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 1664,
22
+ "is_encoder_decoder": true,
23
+ "max_position_embeddings": 512,
24
  "model_type": "moonshine",
25
+ "partial_rotary_factor": 0.62,
26
+ "rope_scaling": null,
27
+ "rope_theta": 10000.0,
28
  "torch_dtype": "float32",
29
+ "transformers_version": "4.48.0.dev0",
30
+ "use_cache": true,
31
+ "vocab_size": 32768
32
  }
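
The updated `config.json` maps the checkpoint onto the native Moonshine configuration, so the hyperparameters can be read without `trust_remote_code`. A minimal sketch of inspecting them, assuming a Transformers version with built-in Moonshine support (the config above was written with 4.48.0.dev0):

```python
from transformers import AutoConfig

# Field names follow the config.json shown above.
config = AutoConfig.from_pretrained("UsefulSensors/moonshine-base")
print(config.model_type)                 # "moonshine"
print(config.hidden_size)                # 416
print(config.encoder_num_hidden_layers)  # 8
print(config.decoder_num_hidden_layers)  # 8
print(config.vocab_size)                 # 32768
```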
configuration_moonshine.py DELETED
@@ -1,32 +0,0 @@
1
- from transformers import PretrainedConfig
2
- from typing import List
3
-
4
-
5
- class MoonshineConfig(PretrainedConfig):
6
- model_type = "moonshine"
7
-
8
- def __init__(
9
- self,
10
- dim: int = 288,
11
- inner_dim: int = None,
12
- enc_depth: int = 8,
13
- dec_depth: int = 8,
14
- n_head: int = 8,
15
- dec_voc_size: int = 32768,
16
- enc_ff_swiglu: bool = False,
17
- dec_ff_swiglu: bool = True,
18
- **kwargs
19
- ):
20
- if inner_dim is None:
21
- inner_dim = dim
22
- if inner_dim % n_head != 0:
23
- raise ValueError("`inner dim` must be divisible by `n_head`")
24
- self.dim = dim
25
- self.inner_dim = inner_dim
26
- self.enc_depth = enc_depth
27
- self.dec_depth = dec_depth
28
- self.n_head = n_head
29
- self.dec_voc_size = dec_voc_size
30
- self.enc_ff_swiglu = enc_ff_swiglu
31
- self.dec_ff_swiglu = dec_ff_swiglu
32
- super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "decoder_start_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "transformers_version": "4.48.0.dev0"
7
+ }
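
The added `generation_config.json` records the special token ids used at generation time (BOS and decoder start are 1, EOS is 2). A minimal sketch of reading them back, assuming the same checkpoint name as in the usage example:

```python
from transformers import GenerationConfig

# Values follow the generation_config.json shown above.
gen_config = GenerationConfig.from_pretrained("UsefulSensors/moonshine-base")
print(gen_config.decoder_start_token_id)  # 1
print(gen_config.bos_token_id)            # 1
print(gen_config.eos_token_id)            # 2
```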
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:34a96dca0b71860f98e3f07d30e0fbea17bbce5529eebab32f8c7aff262622b4
3
- size 411541680
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e020c79d0a979a7ec099f718ff1cd2f19e92aead230d69654bca5975a8e1b862
3
+ size 246079928
modeling_moonshine.py DELETED
@@ -1,512 +0,0 @@
1
- from einops import rearrange
2
- from einops.layers.torch import Rearrange
3
- from torch import nn
4
- from transformers import PreTrainedModel
5
-
6
- import math
7
- import torch
8
-
9
- from .configuration_moonshine import MoonshineConfig
10
-
11
-
12
- class RotaryEmbedding(nn.Module):
13
- def __init__(self, dim, base=10000):
14
- super().__init__()
15
-
16
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
17
- self.register_buffer("inv_freq", inv_freq, persistent=False)
18
-
19
- def forward(self, t):
20
- freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq)
21
- freqs = torch.stack((freqs, freqs), dim=-1)
22
- return rearrange(freqs, "... d r -> ... (d r)")
23
-
24
-
25
- def rotate_half(x):
26
- x = rearrange(x, "... (d r) -> ... d r", r=2)
27
- x1, x2 = x.unbind(dim=-1)
28
- x = torch.stack((-x2, x1), dim=-1)
29
- return rearrange(x, "... d r -> ... (d r)")
30
-
31
-
32
- def apply_rotary_pos_emb(t, freqs):
33
- rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
34
-
35
- freqs = freqs[-seq_len:, :]
36
-
37
- # partial rotary embeddings, Wang et al. GPT-J
38
- t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
39
- t = t * freqs.cos() + rotate_half(t) * freqs.sin()
40
- out = torch.cat((t, t_unrotated), dim=-1)
41
-
42
- return out.type(orig_dtype)
43
-
44
-
45
- class MultiHeadAttention(nn.Module):
46
- def __init__(self, dim, inner_dim, n_head):
47
- super().__init__()
48
- self.n_head = n_head
49
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
50
- self.to_k = nn.Linear(dim, inner_dim, bias=False)
51
- self.to_v = nn.Linear(dim, inner_dim, bias=False)
52
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
53
- self.softmax = nn.Softmax(dim=-1)
54
-
55
- # Scaled dot product attention
56
- def sdp_attention(self, q, k_t, v, mask=None):
57
- d_tensor = v.shape[3]
58
-
59
- op = (q @ k_t) / math.sqrt(d_tensor)
60
- if mask is not None:
61
- op = op.masked_fill(mask, -torch.finfo(op.dtype).max)
62
- score = self.softmax(op)
63
- out = score @ v
64
-
65
- # concat and pass to linear layer
66
- out = rearrange(out, "b h n d -> b n (h d)")
67
- return self.to_out(out)
68
-
69
- def forward(self, q, k, v, rot_pos_emb=None, mask=None):
70
- # dot product with weight matrices
71
- q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
72
-
73
- q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
74
- k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
75
- v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
76
-
77
- # apply RoPE
78
- if rot_pos_emb is not None:
79
- q = apply_rotary_pos_emb(q, rot_pos_emb)
80
- k = apply_rotary_pos_emb(k, rot_pos_emb)
81
-
82
- k_t = k.transpose(2, 3)
83
-
84
- return self.sdp_attention(q, k_t, v, mask), k_t, v
85
-
86
-
87
- class MultiHeadCausalSelfAttentionWithKVCache(MultiHeadAttention):
88
- def __init__(self, dim, inner_dim, n_head):
89
- super().__init__(dim, inner_dim, n_head)
90
-
91
- def forward(self, q, k, v, k_cache, v_cache, rot_pos_emb, mask):
92
- # dot product with weight matrices
93
- q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
94
-
95
- q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
96
- k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
97
- v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
98
-
99
- # apply RoPE
100
- q = apply_rotary_pos_emb(q, rot_pos_emb)
101
- k = apply_rotary_pos_emb(k, rot_pos_emb)
102
-
103
- k_t = k.transpose(2, 3)
104
-
105
- # Append new rows to K and V caches.
106
- k_t = torch.concat((k_cache, k_t), dim=3)
107
- v = torch.concat((v_cache, v), dim=2)
108
-
109
- return super().sdp_attention(q, k_t, v, mask=mask), k_t, v
110
-
111
-
112
- class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
113
- def __init__(self, dim, inner_dim, n_head):
114
- super().__init__(dim, inner_dim, n_head)
115
-
116
- def forward(self, q, k_cache, v_cache, mask):
117
- q = self.to_q(q)
118
- q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
119
-
120
- return super().sdp_attention(q, k_cache, v_cache, mask=mask)
121
-
122
-
123
- class FFLinearGelu(nn.Module):
124
- def __init__(self, dim, ff_mult=4):
125
- super().__init__()
126
-
127
- self.ff = nn.Sequential(
128
- nn.Linear(dim, dim * ff_mult, bias=True),
129
- nn.GELU(),
130
- nn.Linear(dim * ff_mult, dim, bias=True),
131
- )
132
-
133
- def forward(self, x):
134
- return self.ff(x)
135
-
136
-
137
- class FFSwiGLU(nn.Module):
138
- def __init__(self, dim, ff_mult=4):
139
- super().__init__()
140
-
141
- self.ff_proj = nn.Linear(dim, dim * ff_mult, bias=True)
142
- self.ff_noact = nn.Linear(dim, dim * ff_mult, bias=True)
143
- self.ff_act = nn.SiLU()
144
- self.ff_out = nn.Linear(dim * ff_mult, dim, bias=True)
145
-
146
- def forward(self, x):
147
- gate = self.ff_act(self.ff_proj(x))
148
- x_noact = self.ff_noact(x)
149
- x = x_noact * gate
150
- return self.ff_out(x)
151
-
152
-
153
- class EncoderLayer(nn.Module):
154
- def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
155
- super().__init__()
156
-
157
- self.norm1 = nn.LayerNorm(dim, bias=False)
158
-
159
- self.attention = MultiHeadAttention(dim, inner_dim=inner_dim, n_head=n_head)
160
-
161
- self.norm2 = nn.LayerNorm(dim, bias=False)
162
-
163
- self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
164
-
165
- def forward(self, x, rot_pos_emb, mask):
166
- _x = x
167
- x = self.norm1(x)
168
- x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask)
169
- x = x + _x
170
-
171
- _x = x
172
- x = self.norm2(x)
173
- x = self.ff(x)
174
-
175
- x = x + _x
176
- return x
177
-
178
-
179
- class Encoder(nn.Module):
180
- def __init__(self, dim, inner_dim, n_head, n_layers, ff_swiglu):
181
- super().__init__()
182
- rot_embed_dim = max(inner_dim / n_head / 2, 32)
183
- self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
184
-
185
- self.layers = nn.ModuleList(
186
- [EncoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
187
- )
188
- self.post_norm = nn.LayerNorm(dim, bias=False)
189
-
190
- def forward(self, x, mask):
191
- pos = torch.arange(x.shape[-2], device=x.device)
192
- rot_pos_emb = self.rot_pos_emb(pos)
193
-
194
- for idx, layer in enumerate(self.layers):
195
- x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask)
196
- return self.post_norm(x)
197
-
198
-
199
- class DecoderLayer(nn.Module):
200
- def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
201
- super().__init__()
202
-
203
- self.norm1 = nn.LayerNorm(dim, bias=False)
204
-
205
- self.self_attention = MultiHeadCausalSelfAttentionWithKVCache(
206
- dim, inner_dim=inner_dim, n_head=n_head
207
- )
208
-
209
- self.norm2 = nn.LayerNorm(dim, bias=False)
210
- self.cross_attention = MultiHeadCrossAttentionWithKVCache(
211
- dim, inner_dim=inner_dim, n_head=n_head
212
- )
213
-
214
- self.norm3 = nn.LayerNorm(dim, bias=False)
215
- self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
216
-
217
- def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask):
218
- dim = x.size()[1]
219
- causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
220
- _x = x
221
- x = self.norm1(x)
222
- x, new_k_cache, new_v_cache = self.self_attention(
223
- q=x,
224
- k=x,
225
- v=x,
226
- k_cache=k_cache,
227
- v_cache=v_cache,
228
- rot_pos_emb=rot_pos_emb,
229
- mask=causal_mask,
230
- )
231
- x = x + _x
232
-
233
- _x = x
234
- x = self.norm2(x)
235
- x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask)
236
- x = x + _x
237
-
238
- _x = x
239
- x = self.norm3(x)
240
- x = self.ff(x)
241
- x = x + _x
242
-
243
- return x, new_k_cache, new_v_cache
244
-
245
-
246
- class Decoder(nn.Module):
247
- def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
248
- super().__init__()
249
-
250
- self.n_head = n_head
251
- self.d_head = inner_dim // n_head
252
-
253
- rot_embed_dim = max(inner_dim / n_head / 2, 32)
254
- self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
255
-
256
- self.layers = nn.ModuleList(
257
- [DecoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
258
- )
259
- self.final_norm = nn.LayerNorm(dim, bias=False)
260
- self.token_embedding = nn.Embedding(dec_voc_size, dim)
261
-
262
- def forward(self, x, input_mask, *args):
263
- pos = torch.arange(x.shape[1], device=x.device)
264
- rot_pos_emb = self.rot_pos_emb(pos)
265
- x = self.token_embedding(x)
266
-
267
- k_cache_new = []
268
- v_cache_new = []
269
-
270
- n_layer = len(self.layers)
271
- k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
272
- args[i : i + n_layer] for i in range(0, 4 * n_layer, n_layer)
273
- ]
274
- for idx, layer in enumerate(self.layers):
275
- x, new_k_line, new_v_line = layer(
276
- x[:, -1:],
277
- k_cache=k_cache[idx],
278
- v_cache=v_cache[idx],
279
- x_attn_k_cache=x_attn_k_cache[idx],
280
- x_attn_v_cache=x_attn_v_cache[idx],
281
- rot_pos_emb=rot_pos_emb,
282
- input_mask=input_mask,
283
- )
284
- k_cache_new.append(new_k_line)
285
- v_cache_new.append(new_v_line)
286
-
287
- x = self.final_norm(x)
288
-
289
- return x @ self.token_embedding.weight.t(), *k_cache_new, *v_cache_new
290
-
291
-
292
- class InitialDecoderLayer(nn.Module):
293
- def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
294
- super().__init__()
295
-
296
- self.norm1 = nn.LayerNorm(dim, bias=False)
297
-
298
- self.self_attention = MultiHeadAttention(
299
- dim, inner_dim=inner_dim, n_head=n_head
300
- )
301
-
302
- self.norm2 = nn.LayerNorm(dim, bias=False)
303
- self.cross_attention = MultiHeadAttention(
304
- dim, inner_dim=inner_dim, n_head=n_head
305
- )
306
-
307
- self.norm3 = nn.LayerNorm(dim, bias=False)
308
- self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
309
-
310
- def forward(self, x, context, rot_pos_emb, input_mask):
311
- dim = x.size()[1]
312
- causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
313
- _x = x
314
- x = self.norm1(x)
315
- x, new_k_cache, new_v_cache = self.self_attention(
316
- q=x,
317
- k=x,
318
- v=x,
319
- rot_pos_emb=rot_pos_emb,
320
- mask=causal_mask,
321
- )
322
- x = x + _x
323
-
324
- _x = x
325
- x = self.norm2(x)
326
- x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
327
- q=x, k=context, v=context, mask=input_mask,
328
- )
329
- x = x + _x
330
-
331
- _x = x
332
- x = self.norm3(x)
333
- x = self.ff(x)
334
- x = x + _x
335
-
336
- return x, new_k_cache, new_v_cache, x_attn_k_cache, x_attn_v_cache
337
-
338
-
339
- class DecoderInitial(Decoder):
340
- def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
341
- super().__init__(dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu)
342
- self.layers = nn.ModuleList(
343
- [
344
- InitialDecoderLayer(dim, inner_dim, n_head, ff_swiglu)
345
- for _ in range(n_layers)
346
- ]
347
- )
348
-
349
- def forward(self, x, enc_src, input_mask):
350
- pos = torch.arange(x.shape[1], device=x.device)
351
- rot_pos_emb = self.rot_pos_emb(pos)
352
- x = self.token_embedding(x)
353
-
354
- # Shape [n_layers, batch_size, n_head, seq_len, inner_dim]. Cache K transposed.
355
- n_layer = len(self.layers)
356
- k_cache = []
357
- v_cache = []
358
- x_attn_k_cache = []
359
- x_attn_v_cache = []
360
-
361
- for idx, layer in enumerate(self.layers):
362
- x, new_k_line, new_v_line, new_x_attn_k_line, new_x_attn_v_line = layer(
363
- x,
364
- enc_src,
365
- rot_pos_emb,
366
- input_mask,
367
- )
368
-
369
- k_cache.append(new_k_line)
370
- v_cache.append(new_v_line)
371
- x_attn_k_cache.append(new_x_attn_k_line)
372
- x_attn_v_cache.append(new_x_attn_v_line)
373
-
374
- x = self.final_norm(x)
375
-
376
- return (
377
- x @ self.token_embedding.weight.t(),
378
- *k_cache,
379
- *v_cache,
380
- *x_attn_k_cache,
381
- *x_attn_v_cache,
382
- )
383
-
384
-
385
- class AudioPreprocessor(nn.Module):
386
- def __init__(self, dim):
387
- super().__init__()
388
- self.audio_preprocess = nn.Sequential(
389
- nn.Conv1d(1, dim, 127, 64, bias=False),
390
- nn.Tanh(),
391
- nn.GroupNorm(1, dim),
392
- nn.Conv1d(dim, 2 * dim, 7, 3),
393
- nn.GELU(),
394
- nn.Conv1d(2 * dim, dim, 3, 2),
395
- nn.GELU(),
396
- Rearrange("... c s -> ... s c"),
397
- )
398
-
399
- def forward(self, src):
400
- assert (
401
- src.shape[-1] >= 1023
402
- ), f"src shape[-1] {src.shape[-1]} should be at least 1023"
403
- src = src.reshape((-1, 1, src.shape[-1]))
404
- return self.audio_preprocess(src)
405
-
406
-
407
- class MoonshineModelTorch(nn.Module):
408
- def __init__(
409
- self,
410
- dim,
411
- inner_dim,
412
- enc_depth,
413
- dec_depth,
414
- n_head=8,
415
- dec_voc_size=32768,
416
- enc_ff_swiglu=False,
417
- dec_ff_swiglu=False,
418
- ):
419
- super().__init__()
420
- self.preprocessor = AudioPreprocessor(dim)
421
- self.encoder = Encoder(
422
- dim, inner_dim, n_head, enc_depth, ff_swiglu=enc_ff_swiglu
423
- )
424
- self.decoder_initial = DecoderInitial(
425
- dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
426
- )
427
- self.decoder = Decoder(
428
- dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
429
- )
430
- self.dec_depth = dec_depth
431
- self.n_head = n_head
432
- self.d_head = inner_dim // n_head
433
-
434
- def generate(self, src, mask):
435
- preprocessed = self.preprocessor(src)
436
- batch_size = preprocessed.shape[0]
437
-
438
- # Get max sequence length based on number of unmasked inputs for each sample in batch.
439
- token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second.
440
- if mask is not None:
441
- seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor
442
- else:
443
- token_limit = torch.tensor([src.shape[-1] * token_limit_factor])
444
- seq_lens = torch.stack([token_limit for _ in range(batch_size)])
445
- seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze()
446
-
447
- # Preprocess mask so that it matches preprocessed audio.
448
- if mask is not None:
449
- mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
450
- mask = ~mask.reshape((batch_size, 1, 1, -1))
451
- mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1]))
452
-
453
- enc = self.encoder(preprocessed, mask)
454
-
455
- sot_token = 1
456
- eot_token = 2
457
-
458
- sot_array = [[sot_token] for _ in range(batch_size)]
459
- seq = torch.as_tensor(sot_array).to(src.device)
460
-
461
- vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask)
462
- logits = vals[0]
463
- k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
464
- vals[i : i + self.dec_depth]
465
- for i in range(1, 1 + self.dec_depth * 4, self.dec_depth)
466
- ]
467
-
468
- sample = logits[:, -1].argmax(dim=-1, keepdim=True)
469
- seq = torch.cat((seq, sample), dim=-1)
470
-
471
- eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device)
472
- while not torch.all(eot_mask):
473
- vals = self.decoder(
474
- seq,
475
- mask,
476
- *k_cache,
477
- *v_cache,
478
- *x_attn_k_cache,
479
- *x_attn_v_cache,
480
- )
481
- logits = vals[0]
482
- k_cache = vals[1 : self.dec_depth + 1]
483
- v_cache = vals[self.dec_depth + 1 :]
484
- logits = logits[:, -1] # get last token
485
- sample = logits.argmax(dim=-1, keepdim=True)
486
- # For each sample in batch detect EOT or token limit reached.
487
- eot_mask = eot_mask | (sample.squeeze() == eot_token)
488
- eot_mask = eot_mask | (seq.shape[-1] >= seq_lens)
489
- sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
490
- seq = torch.cat((seq, sample), dim=-1)
491
-
492
- return seq
493
-
494
-
495
- class MoonshineModel(PreTrainedModel):
496
- config_class = MoonshineConfig
497
-
498
- def __init__(self, config):
499
- super().__init__(config)
500
- self.model = MoonshineModelTorch(
501
- dim = config.dim,
502
- inner_dim = config.inner_dim,
503
- enc_depth = config.enc_depth,
504
- dec_depth = config.dec_depth,
505
- n_head = config.n_head,
506
- dec_voc_size = config.dec_voc_size,
507
- enc_ff_swiglu = config.enc_ff_swiglu,
508
- dec_ff_swiglu = config.dec_ff_swiglu,
509
- )
510
-
511
- def forward(self, tensor, input_mask=None):
512
- return self.model.generate(tensor, input_mask)
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "do_normalize": false,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
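
The added `preprocessor_config.json` declares a `Wav2Vec2FeatureExtractor` over raw 16 kHz mono audio with attention masks. A minimal sketch of loading it and preparing a waveform, assuming the same checkpoint name as above and a dummy one-second signal:

```python
import numpy as np
from transformers import AutoFeatureExtractor

# Per the preprocessor_config.json above, this resolves to a Wav2Vec2FeatureExtractor.
feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-base")
print(feature_extractor.sampling_rate)  # 16000

# One second of silence as a stand-in for real speech.
waveform = np.zeros(16000, dtype=np.float32)
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs.input_values.shape)  # torch.Size([1, 16000])
```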
special_tokens_map.json DELETED
@@ -1 +0,0 @@
1
- {}
 
 
tokenizer.json CHANGED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json DELETED
The diff for this file is too large to render. See raw diff