Commit fa827db by Aku Rouhe
Parent(s): 59b49c5

Newer interface

Files changed:
- .gitattributes +1 -0
- hyperparams.yaml +7 -1
- interface.py +10 -1
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
hyperparams.yaml CHANGED
@@ -1,6 +1,12 @@
+n_mels: 40
+
 feature_extractor: !new:speechbrain.lobes.features.Fbank
     n_fft: 400
-    n_mels:
+    n_mels: !ref <n_mels>
+
+feature_scaler: !new:custom.FeatureScaler
+    num_in: !ref <n_mels>
+    scale: 0.5
 
 normalizer: !new:speechbrain.processing.features.InputNormalization
     norm_type: global
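For context, a minimal sketch (not part of the commit) of how SpeechBrain consumes this file: the hyperpyyaml package expands !ref <n_mels> to 40 and instantiates the !new: tags as Python objects, which means !new:custom.FeatureScaler needs a module named custom to be importable at load time. The local file path below is an assumption.

    from hyperpyyaml import load_hyperpyyaml

    # Loading resolves the !ref placeholders and constructs the !new: objects:
    # hparams["feature_extractor"] becomes an Fbank instance with n_mels=40,
    # and hparams["feature_scaler"] a FeatureScaler with num_in=40, scale=0.5.
    with open("hyperparams.yaml") as fin:
        hparams = load_hyperpyyaml(fin)

    features = hparams["feature_extractor"]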
interface.py CHANGED
@@ -1,6 +1,14 @@
 import torch
 import speechbrain as sb
 
+class FeatureScaler(torch.nn.Module):
+    def __init__(self, num_in, scale):
+        super().__init__()
+        self.scaler = torch.nn.eye(num_in) * scale
+
+    def forward(x):
+        return x * self.scaler
+
 class Custom(sb.pretrained.interfaces.Pretrained):
     MODULES_NEEDED = ["normalizer"]
     HPARAMS_NEEDED = ["feature_extractor"]
@@ -8,7 +16,8 @@ class Custom(sb.pretrained.interfaces.Pretrained):
     def feats_from_audio(self, audio, lengths=torch.tensor([1.0])):
         feats = self.hparams.feature_extractor(audio)
         normalized = self.mods.normalizer(feats, lengths)
-        return normalized
+        scaled = self.mods.feature_scaler(normalized)
+        return scaled
 
     def feats_from_file(self, path):
         audio = self.load_audio(path)
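As committed, FeatureScaler would fail at runtime: PyTorch has no torch.nn.eye (the identity constructor is torch.eye), forward is missing self, and the elementwise product x * self.scaler only broadcasts when the time dimension happens to equal num_in. Below is a corrected sketch of the same idea; the buffer registration and the matrix multiplication are assumptions about the intent, not part of the commit.

    import torch

    class FeatureScaler(torch.nn.Module):
        """Scale every feature dimension by a constant, via a diagonal matrix."""

        def __init__(self, num_in, scale):
            super().__init__()
            # torch.eye builds the (num_in, num_in) identity; register_buffer
            # lets the matrix follow the module across devices without being
            # treated as a trainable parameter.
            self.register_buffer("scaler", torch.eye(num_in) * scale)

        def forward(self, x):
            # Matrix multiplication applies the diagonal along the last
            # (feature) axis of (batch, time, num_in) inputs.
            return x @ self.scaler

    # Hypothetical smoke test: multiplying by the 0.5-scaled identity is the
    # same as scaling the features by 0.5, matching scale: 0.5 in the yaml.
    scaler = FeatureScaler(num_in=40, scale=0.5)
    feats = torch.randn(1, 100, 40)  # (batch, time, n_mels)
    assert torch.allclose(scaler(feats), feats * 0.5)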