yourusername commited on
Commit
66a6dc0
·
1 Parent(s): 9f5a755

:beers: cheers

Browse files
Files changed (43) hide show
  1. app.py +91 -0
  2. deepafx_st/__init__.py +4 -0
  3. deepafx_st/callbacks/audio.py +184 -0
  4. deepafx_st/callbacks/ckpt.py +33 -0
  5. deepafx_st/callbacks/params.py +87 -0
  6. deepafx_st/callbacks/plotting.py +126 -0
  7. deepafx_st/data/audio.py +177 -0
  8. deepafx_st/data/augmentations.py +235 -0
  9. deepafx_st/data/dataset.py +344 -0
  10. deepafx_st/data/proxy.py +181 -0
  11. deepafx_st/data/style.py +62 -0
  12. deepafx_st/metrics.py +157 -0
  13. deepafx_st/models/baselines.py +280 -0
  14. deepafx_st/models/controller.py +75 -0
  15. deepafx_st/models/efficient_net/LICENSE +202 -0
  16. deepafx_st/models/efficient_net/__init__.py +9 -0
  17. deepafx_st/models/efficient_net/model.py +419 -0
  18. deepafx_st/models/efficient_net/utils.py +616 -0
  19. deepafx_st/models/encoder.py +113 -0
  20. deepafx_st/models/mobilenetv2.py +226 -0
  21. deepafx_st/probes/cdpam_encoder.py +68 -0
  22. deepafx_st/probes/probe_system.py +307 -0
  23. deepafx_st/probes/random_mel.py +93 -0
  24. deepafx_st/processors/autodiff/__init__.py +0 -0
  25. deepafx_st/processors/autodiff/channel.py +28 -0
  26. deepafx_st/processors/autodiff/compressor.py +169 -0
  27. deepafx_st/processors/autodiff/fir.py +68 -0
  28. deepafx_st/processors/autodiff/peq.py +274 -0
  29. deepafx_st/processors/autodiff/signal.py +194 -0
  30. deepafx_st/processors/dsp/compressor.py +177 -0
  31. deepafx_st/processors/dsp/peq.py +323 -0
  32. deepafx_st/processors/processor.py +87 -0
  33. deepafx_st/processors/proxy/channel.py +130 -0
  34. deepafx_st/processors/proxy/proxy_system.py +289 -0
  35. deepafx_st/processors/proxy/tcn.py +199 -0
  36. deepafx_st/processors/spsa/channel.py +179 -0
  37. deepafx_st/processors/spsa/eps_scheduler.py +32 -0
  38. deepafx_st/processors/spsa/spsa_func.py +131 -0
  39. deepafx_st/system.py +563 -0
  40. deepafx_st/utils.py +277 -0
  41. deepafx_st/version.py +6 -0
  42. packages.txt +3 -0
  43. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import resampy
4
+ import torch
5
+ import torchaudio
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ from deepafx_st.system import System
9
+ from deepafx_st.utils import DSPMode
10
+
11
+ system = System.load_from_checkpoint(
12
+ hf_hub_download("nateraw/deepafx-st-libritts-autodiff", "lit_model.ckpt"), batch_size=1
13
+ ).eval()
14
+
15
+ gpu = torch.cuda.is_available()
16
+
17
+ if gpu:
18
+ system.to("cuda")
19
+
20
+
21
+ def process(input_path, reference_path):
22
+ # load audio data
23
+ x, x_sr = torchaudio.load(input_path)
24
+ r, r_sr = torchaudio.load(reference_path)
25
+
26
+ # resample if needed
27
+ if x_sr != 24000:
28
+ print("Resampling to 24000 Hz...")
29
+ x_24000 = torch.tensor(resampy.resample(x.view(-1).numpy(), x_sr, 24000))
30
+ x_24000 = x_24000.view(1, -1)
31
+ else:
32
+ x_24000 = x
33
+
34
+ if r_sr != 24000:
35
+ print("Resampling to 24000 Hz...")
36
+ r_24000 = torch.tensor(resampy.resample(r.view(-1).numpy(), r_sr, 24000))
37
+ r_24000 = r_24000.view(1, -1)
38
+ else:
39
+ r_24000 = r
40
+
41
+ # peak normalize to -12 dBFS
42
+ x_24000 = x_24000[0:1, : 24000 * 5]
43
+ x_24000 /= x_24000.abs().max()
44
+ x_24000 *= 10 ** (-12 / 20.0)
45
+ x_24000 = x_24000.view(1, 1, -1)
46
+
47
+ # peak normalize to -12 dBFS
48
+ r_24000 = r_24000[0:1, : 24000 * 5]
49
+ r_24000 /= r_24000.abs().max()
50
+ r_24000 *= 10 ** (-12 / 20.0)
51
+ r_24000 = r_24000.view(1, 1, -1)
52
+
53
+ if gpu:
54
+ x_24000 = x_24000.to("cuda")
55
+ r_24000 = r_24000.to("cuda")
56
+
57
+ with torch.no_grad():
58
+ y_hat, p, e = system(x_24000, r_24000)
59
+
60
+ y_hat = y_hat.view(1, -1)
61
+ y_hat /= y_hat.abs().max()
62
+ x_24000 /= x_24000.abs().max()
63
+
64
+ # Sqeeze to (T,), convert to numpy, and convert to int16
65
+ out_audio = (32767 * y_hat).squeeze(0).detach().cpu().numpy().astype(np.int16)
66
+
67
+ return 24000, out_audio
68
+
69
+
70
+ gr.Interface(
71
+ fn=process,
72
+ inputs=[gr.Audio(type="filepath"), gr.Audio(type="filepath")],
73
+ outputs="audio",
74
+ examples=[
75
+ [
76
+ hf_hub_download("nateraw/examples", "voice_raw.wav", repo_type="dataset", cache_dir="./data"),
77
+ hf_hub_download("nateraw/examples", "voice_produced.wav", repo_type="dataset", cache_dir="./data"),
78
+ ],
79
+ ],
80
+ title="DeepAFx-ST",
81
+ description=(
82
+ "Gradio demo for DeepAFx-ST for style transfer of audio effects with differentiable signal processing. To use it, simply"
83
+ " upload your audio files or choose from one of the examples. Read more at the links below."
84
+ ),
85
+ article=(
86
+ "<div style='text-align: center;'><a href='https://github.com/adobe-research/DeepAFx-ST' target='_blank'>Github Repo</a>"
87
+ " <center><img src='https://visitor-badge.glitch.me/badge?page_id=nateraw_deepafx-st' alt='visitor"
88
+ " badge'></center></div>"
89
+ ),
90
+ allow_flagging="never",
91
+ ).launch()
deepafx_st/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Top-level module for deepafx_st"""
3
+
4
+ from .version import version as __version__
deepafx_st/callbacks/audio.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import auraloss
2
+ import numpy as np
3
+ import pytorch_lightning as pl
4
+
5
+ from deepafx_st.callbacks.plotting import plot_multi_spectrum
6
+ from deepafx_st.metrics import (
7
+ LoudnessError,
8
+ SpectralCentroidError,
9
+ CrestFactorError,
10
+ PESQ,
11
+ MelSpectralDistance,
12
+ )
13
+
14
+
15
+ class LogAudioCallback(pl.callbacks.Callback):
16
+ def __init__(self, num_examples=4, peak_normalize=True, sample_rate=22050):
17
+ super().__init__()
18
+ self.num_examples = 4
19
+ self.peak_normalize = peak_normalize
20
+
21
+ self.metrics = {
22
+ "PESQ": PESQ(sample_rate),
23
+ "MRSTFT": auraloss.freq.MultiResolutionSTFTLoss(
24
+ fft_sizes=[32, 128, 512, 2048, 8192, 32768],
25
+ hop_sizes=[16, 64, 256, 1024, 4096, 16384],
26
+ win_lengths=[32, 128, 512, 2048, 8192, 32768],
27
+ w_sc=0.0,
28
+ w_phs=0.0,
29
+ w_lin_mag=1.0,
30
+ w_log_mag=1.0,
31
+ ),
32
+ "MSD": MelSpectralDistance(sample_rate),
33
+ "SCE": SpectralCentroidError(sample_rate),
34
+ "CFE": CrestFactorError(),
35
+ "LUFS": LoudnessError(sample_rate),
36
+ }
37
+
38
+ self.outputs = []
39
+
40
+ def on_validation_batch_end(
41
+ self,
42
+ trainer,
43
+ pl_module,
44
+ outputs,
45
+ batch,
46
+ batch_idx,
47
+ dataloader_idx,
48
+ ):
49
+ """Called when the validation batch ends."""
50
+
51
+ if outputs is not None:
52
+ examples = np.min([self.num_examples, outputs["x"].shape[0]])
53
+ self.outputs.append(outputs)
54
+
55
+ if batch_idx == 0:
56
+ for n in range(examples):
57
+ if batch_idx == 0:
58
+ self.log_audio(
59
+ outputs,
60
+ n,
61
+ pl_module.hparams.sample_rate,
62
+ pl_module.hparams.val_length,
63
+ trainer.global_step,
64
+ trainer.logger,
65
+ )
66
+
67
+ def on_validation_end(self, trainer, pl_module):
68
+ metrics = {
69
+ "PESQ": [],
70
+ "MRSTFT": [],
71
+ "MSD": [],
72
+ "SCE": [],
73
+ "CFE": [],
74
+ "LUFS": [],
75
+ }
76
+ for output in self.outputs:
77
+ for metric_name, metric in self.metrics.items():
78
+ try:
79
+ val = metric(output["y_hat"], output["y"])
80
+ metrics[metric_name].append(val)
81
+ except:
82
+ pass
83
+
84
+ # log final mean metrics
85
+ for metric_name, metric in metrics.items():
86
+ val = np.mean(metric)
87
+ trainer.logger.experiment.add_scalar(
88
+ f"metrics/{metric_name}", val, trainer.global_step
89
+ )
90
+
91
+ # clear outputs
92
+ self.outputs = []
93
+
94
+ def compute_metrics(self, metrics_dict, outputs, batch_idx, global_step):
95
+ # extract audio
96
+ y = outputs["y"][batch_idx, ...].float()
97
+ y_hat = outputs["y_hat"][batch_idx, ...].float()
98
+
99
+ # compute all metrics
100
+ for metric_name, metric in self.metrics.items():
101
+ try:
102
+ val = metric(y_hat.view(1, 1, -1), y.view(1, 1, -1))
103
+ metrics_dict[metric_name].append(val)
104
+ except:
105
+ pass
106
+
107
+ def log_audio(self, outputs, batch_idx, sample_rate, n_fft, global_step, logger):
108
+ x = outputs["x"][batch_idx, ...].float()
109
+ y = outputs["y"][batch_idx, ...].float()
110
+ y_hat = outputs["y_hat"][batch_idx, ...].float()
111
+
112
+ if self.peak_normalize:
113
+ x /= x.abs().max()
114
+ y /= y.abs().max()
115
+ y_hat /= y_hat.abs().max()
116
+
117
+ logger.experiment.add_audio(
118
+ f"x/{batch_idx+1}",
119
+ x[0:1, :],
120
+ global_step,
121
+ sample_rate=sample_rate,
122
+ )
123
+
124
+ logger.experiment.add_audio(
125
+ f"y/{batch_idx+1}",
126
+ y[0:1, :],
127
+ global_step,
128
+ sample_rate=sample_rate,
129
+ )
130
+
131
+ logger.experiment.add_audio(
132
+ f"y_hat/{batch_idx+1}",
133
+ y_hat[0:1, :],
134
+ global_step,
135
+ sample_rate=sample_rate,
136
+ )
137
+
138
+ if "y_ref" in outputs:
139
+ y_ref = outputs["y_ref"][batch_idx, ...].float()
140
+
141
+ if self.peak_normalize:
142
+ y_ref /= y_ref.abs().max()
143
+
144
+ logger.experiment.add_audio(
145
+ f"y_ref/{batch_idx+1}",
146
+ y_ref[0:1, :],
147
+ global_step,
148
+ sample_rate=sample_rate,
149
+ )
150
+ logger.experiment.add_image(
151
+ f"spec/{batch_idx+1}",
152
+ compare_spectra(
153
+ y_hat[0:1, :],
154
+ y[0:1, :],
155
+ x[0:1, :],
156
+ sample_rate=sample_rate,
157
+ n_fft=n_fft,
158
+ ),
159
+ global_step,
160
+ )
161
+
162
+
163
+ def compare_spectra(
164
+ deepafx_y_hat, y, x, baseline_y_hat=None, sample_rate=44100, n_fft=16384
165
+ ):
166
+ legend = ["Corrupted"]
167
+ signals = [x]
168
+ if baseline_y_hat is not None:
169
+ legend.append("Baseline")
170
+ signals.append(baseline_y_hat)
171
+
172
+ legend.append("DeepAFx")
173
+ signals.append(deepafx_y_hat)
174
+ legend.append("Target")
175
+ signals.append(y)
176
+
177
+ image = plot_multi_spectrum(
178
+ ys=signals,
179
+ legend=legend,
180
+ sample_rate=sample_rate,
181
+ n_fft=n_fft,
182
+ )
183
+
184
+ return image
deepafx_st/callbacks/ckpt.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import pytorch_lightning as pl
5
+
6
+
7
+ class CopyPretrainedCheckpoints(pl.callbacks.Callback):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def on_fit_start(self, trainer, pl_module):
12
+ """Before training, move the pre-trained checkpoints
13
+ to the current checkpoint directory.
14
+
15
+ """
16
+ # copy any pre-trained checkpoints to new directory
17
+ if pl_module.hparams.processor_model == "proxy":
18
+ pretrained_ckpt_dir = os.path.join(
19
+ pl_module.logger.experiment.log_dir, "pretrained_checkpoints"
20
+ )
21
+ if not os.path.isdir(pretrained_ckpt_dir):
22
+ os.makedirs(pretrained_ckpt_dir)
23
+ cp_proxy_ckpts = []
24
+ for proxy_ckpt in pl_module.hparams.proxy_ckpts:
25
+ new_ckpt = shutil.copy(
26
+ proxy_ckpt,
27
+ pretrained_ckpt_dir,
28
+ )
29
+ cp_proxy_ckpts.append(new_ckpt)
30
+ print(f"Moved checkpoint to {new_ckpt}.")
31
+ # overwrite to the paths in current experiment logs
32
+ pl_module.hparams.proxy_ckpts = cp_proxy_ckpts
33
+ print(pl_module.hparams.proxy_ckpts)
deepafx_st/callbacks/params.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pytorch_lightning as pl
3
+ import matplotlib.pyplot as plt
4
+
5
+ import deepafx_st.utils as utils
6
+
7
+
8
+ class LogParametersCallback(pl.callbacks.Callback):
9
+ def __init__(self, num_examples=4):
10
+ super().__init__()
11
+ self.num_examples = 4
12
+
13
+ def on_validation_epoch_start(self, trainer, pl_module):
14
+ """At the start of validation init storage for parameters."""
15
+ self.params = []
16
+
17
+ def on_validation_batch_end(
18
+ self,
19
+ trainer,
20
+ pl_module,
21
+ outputs,
22
+ batch,
23
+ batch_idx,
24
+ dataloader_idx,
25
+ ):
26
+ """Called when the validation batch ends.
27
+
28
+ Here we log the parameters only from the first batch.
29
+
30
+ """
31
+ if outputs is not None and batch_idx == 0:
32
+ examples = np.min([self.num_examples, outputs["x"].shape[0]])
33
+ for n in range(examples):
34
+ self.log_parameters(
35
+ outputs,
36
+ n,
37
+ pl_module.processor.ports,
38
+ trainer.global_step,
39
+ trainer.logger,
40
+ True if batch_idx == 0 else False,
41
+ )
42
+
43
+ def on_validation_epoch_end(self, trainer, pl_module):
44
+ pass
45
+
46
+ def log_parameters(self, outputs, batch_idx, ports, global_step, logger, log=True):
47
+ p = outputs["p"][batch_idx, ...]
48
+
49
+ table = ""
50
+
51
+ # table += f"""## {plugin["name"]}\n"""
52
+ table += "| Index| Name | Value | Units | Min | Max | Default | Raw Value | \n"
53
+ table += "|------|------|------:|:------|----:|----:|--------:| ---------:| \n"
54
+
55
+ start_idx = 0
56
+ # set plugin parameters based on provided normalized parameters
57
+ for port_list in ports:
58
+ for pidx, port in enumerate(port_list):
59
+ param_max = port["max"]
60
+ param_min = port["min"]
61
+ param_name = port["name"]
62
+ param_default = port["default"]
63
+ param_units = port["units"]
64
+
65
+ param_val = p[start_idx]
66
+ denorm_val = utils.denormalize(param_val, param_max, param_min)
67
+
68
+ # add values to table in row
69
+ table += f"| {start_idx + 1} | {param_name} "
70
+ if np.abs(denorm_val) > 10:
71
+ table += f"| {denorm_val:0.1f} "
72
+ table += f"| {param_units} "
73
+ table += f"| {param_min:0.1f} | {param_max:0.1f} "
74
+ table += f"| {param_default:0.1f} "
75
+ else:
76
+ table += f"| {denorm_val:0.3f} "
77
+ table += f"| {param_units} "
78
+ table += f"| {param_min:0.3f} | {param_max:0.3f} "
79
+ table += f"| {param_default:0.3f} "
80
+
81
+ table += f"| {np.squeeze(param_val):0.2f} | \n"
82
+ start_idx += 1
83
+
84
+ table += "\n\n"
85
+
86
+ if log:
87
+ logger.experiment.add_text(f"params/{batch_idx+1}", table, global_step)
deepafx_st/callbacks/plotting.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import torch
3
+ import PIL.Image
4
+ import numpy as np
5
+ import scipy.signal
6
+ import librosa.display
7
+ import matplotlib.pyplot as plt
8
+
9
+ from torch.functional import Tensor
10
+ from torchvision.transforms import ToTensor
11
+
12
+
13
+ def compute_comparison_spectrogram(
14
+ x: np.ndarray,
15
+ y: np.ndarray,
16
+ sample_rate: float = 44100,
17
+ n_fft: int = 2048,
18
+ hop_length: int = 1024,
19
+ ) -> Tensor:
20
+ X = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
21
+ X_db = librosa.amplitude_to_db(np.abs(X), ref=np.max)
22
+
23
+ Y = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
24
+ Y_db = librosa.amplitude_to_db(np.abs(Y), ref=np.max)
25
+
26
+ fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
27
+ img = librosa.display.specshow(
28
+ X_db,
29
+ ax=axs[0],
30
+ hop_length=hop_length,
31
+ x_axis="time",
32
+ y_axis="log",
33
+ sr=sample_rate,
34
+ )
35
+ # fig.colorbar(img, ax=axs[0])
36
+ img = librosa.display.specshow(
37
+ Y_db,
38
+ ax=axs[1],
39
+ hop_length=hop_length,
40
+ x_axis="time",
41
+ y_axis="log",
42
+ sr=sample_rate,
43
+ )
44
+ # fig.colorbar(img, ax=axs[1])
45
+
46
+ plt.tight_layout()
47
+
48
+ buf = io.BytesIO()
49
+ plt.savefig(buf, format="jpeg")
50
+ buf.seek(0)
51
+ image = PIL.Image.open(buf)
52
+ image = ToTensor()(image)
53
+ plt.close("all")
54
+
55
+ return image
56
+
57
+
58
+ def plot_multi_spectrum(
59
+ ys=None,
60
+ Hs=None,
61
+ legend=[],
62
+ title="Spectrum",
63
+ filename=None,
64
+ sample_rate=44100,
65
+ n_fft=1024,
66
+ zero_mean=False,
67
+ ):
68
+
69
+ if Hs is None:
70
+ Hs = []
71
+ for y in ys:
72
+ X = get_average_spectrum(y, n_fft)
73
+ X_sm = smooth_spectrum(X)
74
+ Hs.append(X_sm)
75
+
76
+ bin_width = (sample_rate / 2) / (n_fft // 2)
77
+ freqs = np.arange(0, (sample_rate / 2) + bin_width, step=bin_width)
78
+
79
+ fig, ax1 = plt.subplots()
80
+
81
+ for idx, H in enumerate(Hs):
82
+ H = np.nan_to_num(H)
83
+ H = np.clip(H, 0, np.max(H))
84
+ H_dB = 20 * np.log10(H + 1e-8)
85
+ if zero_mean:
86
+ H_dB -= np.mean(H_dB)
87
+ if "Target" in legend[idx]:
88
+ ax1.plot(freqs, H_dB, linestyle="--", color="k")
89
+ else:
90
+ ax1.plot(freqs, H_dB)
91
+
92
+ plt.legend(legend)
93
+
94
+ ax1.set_xscale("log")
95
+ ax1.set_ylim([-80, 0])
96
+ ax1.set_xlim([100, 11000])
97
+ plt.title(title)
98
+ plt.ylabel("Magnitude (dB)")
99
+ plt.xlabel("Frequency (Hz)")
100
+ plt.grid(c="lightgray", which="both")
101
+
102
+ if filename is not None:
103
+ plt.savefig(f"{filename}.png", dpi=300)
104
+
105
+ plt.tight_layout()
106
+
107
+ buf = io.BytesIO()
108
+ plt.savefig(buf, format="jpeg")
109
+ buf.seek(0)
110
+ image = PIL.Image.open(buf)
111
+ image = ToTensor()(image)
112
+ plt.close("all")
113
+
114
+ return image
115
+
116
+
117
+ def smooth_spectrum(H):
118
+ # apply Savgol filter for smoothed target curve
119
+ return scipy.signal.savgol_filter(H, 1025, 2)
120
+
121
+
122
+ def get_average_spectrum(x, n_fft):
123
+ X = torch.stft(x, n_fft, return_complex=True, normalized=True)
124
+ X = X.abs() # convert to magnitude
125
+ X = X.mean(dim=-1).view(-1) # average across frames
126
+ return X
deepafx_st/data/audio.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import torch
4
+ import warnings
5
+ import torchaudio
6
+ import pyloudnorm as pyln
7
+
8
+
9
+ class AudioFile(object):
10
+ def __init__(self, filepath, preload=False, half=False, target_loudness=None):
11
+ """Base class for audio files to handle metadata and loading.
12
+
13
+ Args:
14
+ filepath (str): Path to audio file to load from disk.
15
+ preload (bool, optional): If set, load audio data into RAM. Default: False
16
+ half (bool, optional): If set, store audio data as float16 to save space. Default: False
17
+ target_loudness (float, optional): Loudness normalize to dB LUFS value. Default:
18
+ """
19
+ super().__init__()
20
+
21
+ self.filepath = filepath
22
+ self.half = half
23
+ self.target_loudness = target_loudness
24
+ self.loaded = False
25
+
26
+ if preload:
27
+ self.load()
28
+ num_frames = self.audio.shape[-1]
29
+ num_channels = self.audio.shape[0]
30
+ else:
31
+ metadata = torchaudio.info(filepath)
32
+ audio = None
33
+ self.sample_rate = metadata.sample_rate
34
+ num_frames = metadata.num_frames
35
+ num_channels = metadata.num_channels
36
+
37
+ self.num_frames = num_frames
38
+ self.num_channels = num_channels
39
+
40
+ def load(self):
41
+ audio, sr = torchaudio.load(self.filepath, normalize=True)
42
+ self.audio = audio
43
+ self.sample_rate = sr
44
+
45
+ if self.target_loudness is not None:
46
+ self.loudness_normalize()
47
+
48
+ if self.half:
49
+ self.audio = audio.half()
50
+
51
+ self.loaded = True
52
+
53
+ def loudness_normalize(self):
54
+ meter = pyln.Meter(self.sample_rate)
55
+
56
+ # conver mono to stereo
57
+ if self.audio.shape[0] == 1:
58
+ tmp_audio = self.audio.repeat(2, 1)
59
+ else:
60
+ tmp_audio = self.audio
61
+
62
+ # measure integrated loudness
63
+ input_loudness = meter.integrated_loudness(tmp_audio.numpy().T)
64
+
65
+ # compute and apply gain
66
+ gain_dB = self.target_loudness - input_loudness
67
+ gain_ln = 10 ** (gain_dB / 20.0)
68
+ self.audio *= gain_ln
69
+
70
+ # check for potentially clipped samples
71
+ if self.audio.abs().max() >= 1.0:
72
+ warnings.warn("Possible clipped samples in output.")
73
+
74
+
75
+ class AudioFileDataset(torch.utils.data.Dataset):
76
+ """Base class for audio file datasets loaded from disk.
77
+
78
+ Datasets can be either paired or unpaired. A paired dataset requires passing the `target_dir` path.
79
+
80
+ Args:
81
+ input_dir (List[str]): List of paths to the directories containing input audio files.
82
+ target_dir (List[str], optional): List of paths to the directories containing correponding audio files. Default: []
83
+ subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
84
+ length (int, optional): Number of samples to load for each example. Default: 65536
85
+ normalize (bool, optional): Normalize audio amplitiude to -1 to 1. Default: True
86
+ train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8
87
+ val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1
88
+ preload (bool, optional): Read audio files into RAM at the start of training. Default: False
89
+ num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
90
+ ext (str, optional): Expected audio file extension. Default: "wav"
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ input_dirs,
96
+ target_dirs=[],
97
+ subset="train",
98
+ length=65536,
99
+ normalize=True,
100
+ train_per=0.8,
101
+ val_per=0.1,
102
+ preload=False,
103
+ num_examples_per_epoch=10000,
104
+ ext="wav",
105
+ ):
106
+ super().__init__()
107
+ self.input_dirs = input_dirs
108
+ self.target_dirs = target_dirs
109
+ self.subset = subset
110
+ self.length = length
111
+ self.normalize = normalize
112
+ self.train_per = train_per
113
+ self.val_per = val_per
114
+ self.preload = preload
115
+ self.num_examples_per_epoch = num_examples_per_epoch
116
+ self.ext = ext
117
+
118
+ self.input_filepaths = []
119
+ for input_dir in input_dirs:
120
+ search_path = os.path.join(input_dir, f"*.{ext}")
121
+ self.input_filepaths += glob.glob(search_path)
122
+ self.input_filepaths = sorted(self.input_filepaths)
123
+
124
+ self.target_filepaths = []
125
+ for target_dir in target_dirs:
126
+ search_path = os.path.join(target_dir, f"*.{ext}")
127
+ self.target_filepaths += glob.glob(search_path)
128
+ self.target_filepaths = sorted(self.target_filepaths)
129
+
130
+ # both sets must have same number of files in paired dataset
131
+ assert len(self.target_filepaths) == len(self.input_filepaths)
132
+
133
+ # get details about audio files
134
+ self.input_files = []
135
+ for input_filepath in self.input_filepaths:
136
+ self.input_files.append(
137
+ AudioFile(input_filepath, preload=preload, normalize=normalize)
138
+ )
139
+
140
+ self.target_files = []
141
+ if target_dir is not None:
142
+ for target_filepath in self.target_filepaths:
143
+ self.target_files.append(
144
+ AudioFile(target_filepath, preload=preload, normalize=normalize)
145
+ )
146
+
147
+ def __len__(self):
148
+ return self.num_examples_per_epoch
149
+
150
+ def __getitem__(self, idx):
151
+ """ """
152
+
153
+ # index the current audio file
154
+ input_file = self.input_files[idx]
155
+
156
+ # load the audio data if needed
157
+ if not input_file.loaded:
158
+ input_file.load()
159
+
160
+ # get a random patch of size `self.length`
161
+ start_idx = int(torch.rand() * (input_file.num_frames - self.length))
162
+ stop_idx = start_idx + self.length
163
+ input_audio = input_file.audio[:, start_idx:stop_idx]
164
+
165
+ # if there is a target file, get it (and load)
166
+ if len(self.target_files) > 0:
167
+ target_file = self.target_files[idx]
168
+
169
+ if not target_file.loaded:
170
+ target_file.load()
171
+
172
+ # use the same cropping indices
173
+ target_audio = target_file.audio[:, start_idx:stop_idx]
174
+
175
+ return input_audio, target_audio
176
+ else:
177
+ return input_audio
deepafx_st/data/augmentations.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import numpy as np
4
+
5
+
6
+ def gain(xs, min_dB=-12, max_dB=12):
7
+
8
+ gain_dB = (torch.rand(1) * (max_dB - min_dB)) + min_dB
9
+ gain_ln = 10 ** (gain_dB / 20)
10
+
11
+ for idx, x in enumerate(xs):
12
+ xs[idx] = x * gain_ln
13
+
14
+ return xs
15
+
16
+
17
+ def peaking_filter(xs, sr=44100, frequency=1000, width_q=0.707, gain_db=12):
18
+
19
+ # gain_db = ((torch.rand(1) * 6) + 6).numpy().squeeze()
20
+ # width_q = (torch.rand(1) * 4).numpy().squeeze()
21
+ # frequency = ((torch.rand(1) * 9960) + 40).numpy().squeeze()
22
+
23
+ # if torch.rand(1) > 0.5:
24
+ # gain_db = -gain_db
25
+
26
+ effects = [["equalizer", f"{frequency}", f"{width_q}", f"{gain_db}"]]
27
+
28
+ for idx, x in enumerate(xs):
29
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
30
+ x, sr, effects, channels_first=True
31
+ )
32
+ xs[idx] = y
33
+
34
+ return xs
35
+
36
+
37
+ def pitch_shift(xs, min_shift=-200, max_shift=200, sr=44100):
38
+
39
+ shift = min_shift + (torch.rand(1)).numpy().squeeze() * (max_shift - min_shift)
40
+
41
+ effects = [["pitch", f"{shift}"]]
42
+
43
+ for idx, x in enumerate(xs):
44
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
45
+ x, sr, effects, channels_first=True
46
+ )
47
+ xs[idx] = y
48
+
49
+ return xs
50
+
51
+
52
+ def time_stretch(xs, min_stretch=0.8, max_stretch=1.2, sr=44100):
53
+
54
+ stretch = min_stretch + (torch.rand(1)).numpy().squeeze() * (
55
+ max_stretch - min_stretch
56
+ )
57
+
58
+ effects = [["tempo", f"{stretch}"]]
59
+ for idx, x in enumerate(xs):
60
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
61
+ x, sr, effects, channels_first=True
62
+ )
63
+ xs[idx] = y
64
+
65
+ return xs
66
+
67
+
68
+ def frequency_corruption(xs, sr=44100):
69
+
70
+ effects = []
71
+
72
+ # apply a random number of peaking bands from 0 to 4s
73
+ bands = [[200, 2000], [800, 4000], [2000, 8000], [4000, int((sr // 2) * 0.9)]]
74
+ total_gain_db = 0.0
75
+ for band in bands:
76
+ if torch.rand(1).sum() > 0.2:
77
+ frequency = (torch.randint(band[0], band[1], [1])).numpy().squeeze()
78
+ width_q = ((torch.rand(1) * 10) + 0.1).numpy().squeeze()
79
+ gain_db = ((torch.rand(1) * 48)).numpy().squeeze()
80
+
81
+ if torch.rand(1).sum() > 0.5:
82
+ gain_db = -gain_db
83
+
84
+ total_gain_db += gain_db
85
+
86
+ if np.abs(total_gain_db) >= 24:
87
+ continue
88
+
89
+ cmd = ["equalizer", f"{frequency}", f"{width_q}", f"{gain_db}"]
90
+ effects.append(cmd)
91
+
92
+ # low shelf (bass)
93
+ if torch.rand(1).sum() > 0.2:
94
+ gain_db = ((torch.rand(1) * 24)).numpy().squeeze()
95
+ frequency = (torch.randint(20, 200, [1])).numpy().squeeze()
96
+ if torch.rand(1).sum() > 0.5:
97
+ gain_db = -gain_db
98
+ effects.append(["bass", f"{gain_db}", f"{frequency}"])
99
+
100
+ # high shelf (treble)
101
+ if torch.rand(1).sum() > 0.2:
102
+ gain_db = ((torch.rand(1) * 24)).numpy().squeeze()
103
+ frequency = (torch.randint(4000, int((sr // 2) * 0.9), [1])).numpy().squeeze()
104
+ if torch.rand(1).sum() > 0.5:
105
+ gain_db = -gain_db
106
+ effects.append(["treble", f"{gain_db}", f"{frequency}"])
107
+
108
+ for idx, x in enumerate(xs):
109
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
110
+ x.view(1, -1) * 10 ** (-48 / 20), sr, effects, channels_first=True
111
+ )
112
+ # apply gain back
113
+ y *= 10 ** (48 / 20)
114
+
115
+ xs[idx] = y
116
+
117
+ return xs
118
+
119
+
120
+ def dynamic_range_corruption(xs, sr=44100):
121
+ """Apply an expander."""
122
+
123
+ attack = (torch.rand([1]).numpy()[0] * 0.05) + 0.001
124
+ release = (torch.rand([1]).numpy()[0] * 0.2) + attack
125
+ knee = (torch.rand([1]).numpy()[0] * 12) + 0.0
126
+
127
+ # design the compressor transfer function
128
+ start = -100.0
129
+ threshold = -(
130
+ (torch.rand([1]).numpy()[0] * 20) + 10
131
+ ) # threshold from -30 to -10 dB
132
+ ratio = (torch.rand([1]).numpy()[0] * 4.0) + 1 # ratio from 1:1 to 5:1
133
+
134
+ # compute the transfer curve
135
+ point = -((-threshold / -ratio) + (-start / ratio) + -threshold)
136
+
137
+ # apply some makeup gain
138
+ makeup = torch.rand([1]).numpy()[0] * 6
139
+
140
+ effects = [
141
+ [
142
+ "compand",
143
+ f"{attack},{release}",
144
+ f"{knee}:{point},{start},{threshold},{threshold}",
145
+ f"{makeup}",
146
+ f"{start}",
147
+ ]
148
+ ]
149
+
150
+ for idx, x in enumerate(xs):
151
+ # if the input is clipping normalize it
152
+ if x.abs().max() >= 1.0:
153
+ x /= x.abs().max()
154
+ gain_db = -((torch.rand(1) * 24)).numpy().squeeze()
155
+ x *= 10 ** (gain_db / 20.0)
156
+
157
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
158
+ x.view(1, -1), sr, effects, channels_first=True
159
+ )
160
+ xs[idx] = y
161
+
162
+ return xs
163
+
164
+
165
+ def dynamic_range_compression(xs, sr=44100):
166
+ """Apply a compressor."""
167
+
168
+ attack = (torch.rand([1]).numpy()[0] * 0.05) + 0.0005
169
+ release = (torch.rand([1]).numpy()[0] * 0.2) + attack
170
+ knee = (torch.rand([1]).numpy()[0] * 12) + 0.0
171
+
172
+ # design the compressor transfer function
173
+ start = -100.0
174
+ threshold = -((torch.rand([1]).numpy()[0] * 52) + 12)
175
+ # threshold from -64 to -12 dB
176
+ ratio = (torch.rand([1]).numpy()[0] * 10.0) + 1 # ratio from 1:1 to 10:1
177
+
178
+ # compute the transfer curve
179
+ point = threshold * (1 - (1 / ratio))
180
+
181
+ # apply some makeup gain
182
+ makeup = torch.rand([1]).numpy()[0] * 6
183
+
184
+ effects = [
185
+ [
186
+ "compand",
187
+ f"{attack},{release}",
188
+ f"{knee}:{start},{threshold},{threshold},0,{point}",
189
+ f"{makeup}",
190
+ f"{start}",
191
+ f"{attack}",
192
+ ]
193
+ ]
194
+
195
+ for idx, x in enumerate(xs):
196
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
197
+ x.view(1, -1), sr, effects, channels_first=True
198
+ )
199
+ xs[idx] = y
200
+
201
+ return xs
202
+
203
+
204
+ def lowpass_filter(xs, sr=44100, frequency=4000):
205
+ effects = [["lowpass", f"{frequency}"]]
206
+
207
+ for idx, x in enumerate(xs):
208
+ y, sr = torchaudio.sox_effects.apply_effects_tensor(
209
+ x, sr, effects, channels_first=True
210
+ )
211
+ xs[idx] = y
212
+
213
+ return xs
214
+
215
+
216
+ def apply(xs, sr, augmentations):
217
+
218
+ # iterate over augmentation dict
219
+ for aug, params in augmentations.items():
220
+ if aug == "gain":
221
+ xs = gain(xs, **params)
222
+ elif aug == "peak":
223
+ xs = peaking_filter(xs, **params)
224
+ elif aug == "lowpass":
225
+ xs = lowpass_filter(xs, **params)
226
+ elif aug == "pitch":
227
+ xs = pitch_shift(xs, **params)
228
+ elif aug == "tempo":
229
+ xs = time_stretch(xs, **params)
230
+ elif aug == "freq_corrupt":
231
+ xs = frequency_corruption(xs, **params)
232
+ else:
233
+ raise RuntimeError("Invalid augmentation: {aug}")
234
+
235
+ return xs
deepafx_st/data/dataset.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import csv
4
+ import glob
5
+ import torch
6
+ import random
7
+ from tqdm import tqdm
8
+ from typing import List, Any
9
+
10
+ from deepafx_st.data.audio import AudioFile
11
+ import deepafx_st.utils as utils
12
+ import deepafx_st.data.augmentations as augmentations
13
+
14
+
15
+ class AudioDataset(torch.utils.data.Dataset):
16
+ """Audio dataset which returns an input and target file.
17
+
18
+ Args:
19
+ audio_dir (str): Path to the top level of the audio dataset.
20
+ input_dir (List[str], optional): List of paths to the directories containing input audio files. Default: ["clean"]
21
+ subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
22
+ length (int, optional): Number of samples to load for each example. Default: 65536
23
+ train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8
24
+ val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1
25
+ buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0
26
+ Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers
27
+ buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000
28
+ half (bool, optional): Sotre audio samples as float 16. Default: False
29
+ num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
30
+ random_scale_input (bool, optional): Apply random gain scaling to input utterances. Default: False
31
+ random_scale_target (bool, optional): Apply same random gain scaling to target utterances. Default: False
32
+ augmentations (dict, optional): List of augmentation types to apply to inputs. Default: []
33
+ freq_corrupt (bool, optional): Apply bad EQ filters. Default: False
34
+ drc_corrupt (bool, optional): Apply an expander to corrupt dynamic range. Default: False
35
+ ext (str, optional): Expected audio file extension. Default: "wav"
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ audio_dir,
41
+ input_dirs: List[str] = ["cleanraw"],
42
+ subset: str = "train",
43
+ length: int = 65536,
44
+ train_frac: float = 0.8,
45
+ val_per: float = 0.1,
46
+ buffer_size_gb: float = 1.0,
47
+ buffer_reload_rate: float = 1000,
48
+ half: bool = False,
49
+ num_examples_per_epoch: int = 10000,
50
+ random_scale_input: bool = False,
51
+ random_scale_target: bool = False,
52
+ augmentations: dict = {},
53
+ freq_corrupt: bool = False,
54
+ drc_corrupt: bool = False,
55
+ ext: str = "wav",
56
+ ):
57
+ super().__init__()
58
+ self.audio_dir = audio_dir
59
+ self.dataset_name = os.path.basename(audio_dir)
60
+ self.input_dirs = input_dirs
61
+ self.subset = subset
62
+ self.length = length
63
+ self.train_frac = train_frac
64
+ self.val_per = val_per
65
+ self.buffer_size_gb = buffer_size_gb
66
+ self.buffer_reload_rate = buffer_reload_rate
67
+ self.half = half
68
+ self.num_examples_per_epoch = num_examples_per_epoch
69
+ self.random_scale_input = random_scale_input
70
+ self.random_scale_target = random_scale_target
71
+ self.augmentations = augmentations
72
+ self.freq_corrupt = freq_corrupt
73
+ self.drc_corrupt = drc_corrupt
74
+ self.ext = ext
75
+
76
+ self.input_filepaths = []
77
+ for input_dir in input_dirs:
78
+ search_path = os.path.join(audio_dir, input_dir, f"*.{ext}")
79
+ self.input_filepaths += glob.glob(search_path)
80
+ self.input_filepaths = sorted(self.input_filepaths)
81
+
82
+ # create dataset split based on subset
83
+ self.input_filepaths = utils.split_dataset(
84
+ self.input_filepaths,
85
+ subset,
86
+ train_frac,
87
+ )
88
+
89
+ # get details about input audio files
90
+ input_files = {}
91
+ input_dur_frames = 0
92
+ for input_filepath in tqdm(self.input_filepaths, ncols=80):
93
+ file_id = os.path.basename(input_filepath)
94
+ audio_file = AudioFile(
95
+ input_filepath,
96
+ preload=False,
97
+ half=half,
98
+ )
99
+ if audio_file.num_frames < (self.length * 2):
100
+ continue
101
+ input_files[file_id] = audio_file
102
+ input_dur_frames += input_files[file_id].num_frames
103
+
104
+ if len(list(input_files.items())) < 1:
105
+ raise RuntimeError(f"No files found in {search_path}.")
106
+
107
+ input_dur_hr = (input_dur_frames / input_files[file_id].sample_rate) / 3600
108
+ print(
109
+ f"\nLoaded {len(input_files)} files for {subset} = {input_dur_hr:0.2f} hours."
110
+ )
111
+
112
+ self.sample_rate = input_files[file_id].sample_rate
113
+
114
+ # save a csv file with details about the train and test split
115
+ splits_dir = os.path.join("configs", "splits")
116
+ if not os.path.isdir(splits_dir):
117
+ os.makedirs(splits_dir)
118
+ csv_filepath = os.path.join(splits_dir, f"{self.dataset_name}_{self.subset}_set.csv")
119
+
120
+ with open(csv_filepath, "w") as fp:
121
+ dw = csv.DictWriter(fp, ["file_id", "filepath", "type", "subset"])
122
+ dw.writeheader()
123
+ for input_filepath in self.input_filepaths:
124
+ dw.writerow(
125
+ {
126
+ "file_id": self.get_file_id(input_filepath),
127
+ "filepath": input_filepath,
128
+ "type": "input",
129
+ "subset": self.subset,
130
+ }
131
+ )
132
+
133
+ # some setup for iteratble loading of the dataset into RAM
134
+ self.items_since_load = self.buffer_reload_rate
135
+
136
+ def __len__(self):
137
+ return self.num_examples_per_epoch
138
+
139
+ def load_audio_buffer(self):
140
+ self.input_files_loaded = {} # clear audio buffer
141
+ self.items_since_load = 0 # reset iteration counter
142
+ nbytes_loaded = 0 # counter for data in RAM
143
+
144
+ # different subset in each
145
+ random.shuffle(self.input_filepaths)
146
+
147
+ # load files into RAM
148
+ for input_filepath in self.input_filepaths:
149
+ file_id = os.path.basename(input_filepath)
150
+ audio_file = AudioFile(
151
+ input_filepath,
152
+ preload=True,
153
+ half=self.half,
154
+ )
155
+
156
+ if audio_file.num_frames < (self.length * 2):
157
+ continue
158
+
159
+ self.input_files_loaded[file_id] = audio_file
160
+
161
+ nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
162
+ nbytes_loaded += nbytes
163
+
164
+ # check the size of loaded data
165
+ if nbytes_loaded > self.buffer_size_gb * 1e9:
166
+ break
167
+
168
+ def generate_pair(self):
169
+ # ------------------------ Input audio ----------------------
170
+ rand_input_file_id = None
171
+ input_file = None
172
+ start_idx = None
173
+ stop_idx = None
174
+ while True:
175
+ rand_input_file_id = self.get_random_file_id(self.input_files_loaded.keys())
176
+
177
+ # use this random key to retrieve an input file
178
+ input_file = self.input_files_loaded[rand_input_file_id]
179
+
180
+ # load the audio data if needed
181
+ if not input_file.loaded:
182
+ raise RuntimeError("Audio not loaded.")
183
+
184
+ # get a random patch of size `self.length` x 2
185
+ start_idx, stop_idx = self.get_random_patch(
186
+ input_file, int(self.length * 2)
187
+ )
188
+ if start_idx >= 0:
189
+ break
190
+
191
+ input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
192
+ input_audio = input_audio.view(1, -1)
193
+
194
+ if self.half:
195
+ input_audio = input_audio.float()
196
+
197
+ # peak normalize to -12 dBFS
198
+ input_audio /= input_audio.abs().max()
199
+ input_audio *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
200
+
201
+ if len(list(self.augmentations.items())) > 0:
202
+ if torch.rand(1).sum() < 0.5:
203
+ input_audio_aug = augmentations.apply(
204
+ [input_audio],
205
+ self.sample_rate,
206
+ self.augmentations,
207
+ )[0]
208
+ else:
209
+ input_audio_aug = input_audio.clone()
210
+ else:
211
+ input_audio_aug = input_audio.clone()
212
+
213
+ input_audio_corrupt = input_audio_aug.clone()
214
+ # apply frequency and dynamic range corrpution (expander)
215
+ if self.freq_corrupt and torch.rand(1).sum() < 0.75:
216
+ input_audio_corrupt = augmentations.frequency_corruption(
217
+ [input_audio_corrupt], self.sample_rate
218
+ )[0]
219
+
220
+ # peak normalize again before passing through dynamic range expander
221
+ input_audio_corrupt /= input_audio_corrupt.abs().max()
222
+ input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
223
+
224
+ if self.drc_corrupt and torch.rand(1).sum() < 0.10:
225
+ input_audio_corrupt = augmentations.dynamic_range_corruption(
226
+ [input_audio_corrupt], self.sample_rate
227
+ )[0]
228
+
229
+ # ------------------------ Target audio ----------------------
230
+ # use the same augmented audio clip, add different random EQ and compressor
231
+
232
+ target_audio_corrupt = input_audio_aug.clone()
233
+ # apply frequency and dynamic range corrpution (expander)
234
+ if self.freq_corrupt and torch.rand(1).sum() < 0.75:
235
+ target_audio_corrupt = augmentations.frequency_corruption(
236
+ [target_audio_corrupt], self.sample_rate
237
+ )[0]
238
+
239
+ # peak normalize again before passing through dynamic range compressor
240
+ input_audio_corrupt /= input_audio_corrupt.abs().max()
241
+ input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
242
+
243
+ if self.drc_corrupt and torch.rand(1).sum() < 0.75:
244
+ target_audio_corrupt = augmentations.dynamic_range_compression(
245
+ [target_audio_corrupt], self.sample_rate
246
+ )[0]
247
+
248
+ return input_audio_corrupt, target_audio_corrupt
249
+
250
+ def __getitem__(self, _):
251
+ """ """
252
+
253
+ # increment counter
254
+ self.items_since_load += 1
255
+
256
+ # load next chunk into buffer if needed
257
+ if self.items_since_load > self.buffer_reload_rate:
258
+ self.load_audio_buffer()
259
+
260
+ # generate pairs for style training
261
+ input_audio, target_audio = self.generate_pair()
262
+
263
+ # ------------------------ Conform length of files -------------------
264
+ input_audio = utils.conform_length(input_audio, int(self.length * 2))
265
+ target_audio = utils.conform_length(target_audio, int(self.length * 2))
266
+
267
+ # ------------------------ Apply fade in and fade out -------------------
268
+ input_audio = utils.linear_fade(input_audio, sample_rate=self.sample_rate)
269
+ target_audio = utils.linear_fade(target_audio, sample_rate=self.sample_rate)
270
+
271
+ # ------------------------ Final normalizeation ----------------------
272
+ # always peak normalize final input to -12 dBFS
273
+ input_audio /= input_audio.abs().max()
274
+ input_audio *= 10 ** (-12.0 / 20.0)
275
+
276
+ # always peak normalize the target to -12 dBFS
277
+ target_audio /= target_audio.abs().max()
278
+ target_audio *= 10 ** (-12.0 / 20.0)
279
+
280
+ return input_audio, target_audio
281
+
282
+ @staticmethod
283
+ def get_random_file_id(keys):
284
+ # generate a random index into the keys of the input files
285
+ rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0]
286
+ # find the key (file_id) correponding to the random index
287
+ rand_input_file_id = list(keys)[rand_input_idx]
288
+
289
+ return rand_input_file_id
290
+
291
+ @staticmethod
292
+ def get_random_patch(audio_file, length, check_silence=True):
293
+ silent = True
294
+ count = 0
295
+ while silent:
296
+ count += 1
297
+ start_idx = torch.randint(0, audio_file.num_frames - length - 1, [1])[0]
298
+ # int(torch.rand(1) * (audio_file.num_frames - length))
299
+ stop_idx = start_idx + length
300
+ patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()
301
+
302
+ length = patch.shape[-1]
303
+ first_patch = patch[..., : length // 2]
304
+ second_patch = patch[..., length // 2 :]
305
+
306
+ if (
307
+ (first_patch**2).mean() > 1e-5 and (second_patch**2).mean() > 1e-5
308
+ ) or not check_silence:
309
+ silent = False
310
+
311
+ if count > 100:
312
+ print("get_random_patch count", count)
313
+ return -1, -1
314
+ # break
315
+
316
+ return start_idx, stop_idx
317
+
318
+ def get_file_id(self, filepath):
319
+ """Given a filepath extract the DAPS file id.
320
+
321
+ Args:
322
+ filepath (str): Path to an audio files in the DAPS dataset.
323
+
324
+ Returns:
325
+ file_id (str): DAPS file id of the form <participant_id>_<script_id>
326
+ file_set (str): The DAPS set to which the file belongs.
327
+ """
328
+ file_id = os.path.basename(filepath).split("_")[:2]
329
+ file_id = "_".join(file_id)
330
+ return file_id
331
+
332
+ def get_file_set(self, filepath):
333
+ """Given a filepath extract the DAPS file set name.
334
+
335
+ Args:
336
+ filepath (str): Path to an audio files in the DAPS dataset.
337
+
338
+ Returns:
339
+ file_set (str): The DAPS set to which the file belongs.
340
+ """
341
+ file_set = os.path.basename(filepath).split("_")[2:]
342
+ file_set = "_".join(file_set)
343
+ file_set = file_set.replace(f".{self.ext}", "")
344
+ return file_set
deepafx_st/data/proxy.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import glob
4
+ import torch
5
+ import random
6
+ from tqdm import tqdm
7
+
8
+ # from deepafx_st.plugins.channel import Channel
9
+ from deepafx_st.processors.processor import Processor
10
+ from deepafx_st.data.audio import AudioFile
11
+ import deepafx_st.utils as utils
12
+
13
+
14
+ class DSPProxyDataset(torch.utils.data.Dataset):
15
+ """Class for generating input-output audio from Python DSP effects.
16
+
17
+ Args:
18
+ input_dir (List[str]): List of paths to the directories containing input audio files.
19
+ processor (Processor): Processor object to create proxy of.
20
+ processor_type (str): Processor name.
21
+ subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
22
+ buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0
23
+ Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers
24
+ buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000
25
+ length (int, optional): Number of samples to load for each example. Default: 65536
26
+ num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
27
+ ext (str, optional): Expected audio file extension. Default: "wav"
28
+ hard_clip (bool, optional): Hard clip outputs between -1 and 1. Default: True
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ input_dir: str,
34
+ processor: Processor,
35
+ processor_type: str,
36
+ subset="train",
37
+ length=65536,
38
+ buffer_size_gb=1.0,
39
+ buffer_reload_rate=1000,
40
+ half=False,
41
+ num_examples_per_epoch=10000,
42
+ ext="wav",
43
+ soft_clip=True,
44
+ ):
45
+ super().__init__()
46
+ self.input_dir = input_dir
47
+ self.processor = processor
48
+ self.processor_type = processor_type
49
+ self.subset = subset
50
+ self.length = length
51
+ self.buffer_size_gb = buffer_size_gb
52
+ self.buffer_reload_rate = buffer_reload_rate
53
+ self.half = half
54
+ self.num_examples_per_epoch = num_examples_per_epoch
55
+ self.ext = ext
56
+ self.soft_clip = soft_clip
57
+
58
+ search_path = os.path.join(input_dir, f"*.{ext}")
59
+ self.input_filepaths = glob.glob(search_path)
60
+ self.input_filepaths = sorted(self.input_filepaths)
61
+
62
+ if len(self.input_filepaths) < 1:
63
+ raise RuntimeError(f"No files found in {input_dir}.")
64
+
65
+ # get training split
66
+ self.input_filepaths = utils.split_dataset(
67
+ self.input_filepaths, self.subset, 0.9
68
+ )
69
+
70
+ # get details about audio files
71
+ cnt = 0
72
+ self.input_files = {}
73
+ for input_filepath in tqdm(self.input_filepaths, ncols=80):
74
+ file_id = os.path.basename(input_filepath)
75
+ audio_file = AudioFile(
76
+ input_filepath,
77
+ preload=False,
78
+ half=half,
79
+ )
80
+ if audio_file.num_frames < self.length:
81
+ continue
82
+ self.input_files[file_id] = audio_file
83
+ self.sample_rate = self.input_files[file_id].sample_rate
84
+ cnt += 1
85
+ if cnt > 1000:
86
+ break
87
+
88
+ # some setup for iteratble loading of the dataset into RAM
89
+ self.items_since_load = self.buffer_reload_rate
90
+
91
+ def __len__(self):
92
+ return self.num_examples_per_epoch
93
+
94
+ def load_audio_buffer(self):
95
+ self.input_files_loaded = {} # clear audio buffer
96
+ self.items_since_load = 0 # reset iteration counter
97
+ nbytes_loaded = 0 # counter for data in RAM
98
+
99
+ # different subset in each
100
+ random.shuffle(self.input_filepaths)
101
+
102
+ # load files into RAM
103
+ for input_filepath in self.input_filepaths:
104
+ file_id = os.path.basename(input_filepath)
105
+ audio_file = AudioFile(
106
+ input_filepath,
107
+ preload=True,
108
+ half=self.half,
109
+ )
110
+
111
+ if audio_file.num_frames < self.length:
112
+ continue
113
+
114
+ self.input_files_loaded[file_id] = audio_file
115
+
116
+ nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
117
+ nbytes_loaded += nbytes
118
+
119
+ if nbytes_loaded > self.buffer_size_gb * 1e9:
120
+ break
121
+
122
+ def __getitem__(self, _):
123
+ """ """
124
+
125
+ # increment counter
126
+ self.items_since_load += 1
127
+
128
+ # load next chunk into buffer if needed
129
+ if self.items_since_load > self.buffer_reload_rate:
130
+ self.load_audio_buffer()
131
+
132
+ rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys())
133
+ # use this random key to retrieve an input file
134
+ input_file = self.input_files_loaded[rand_input_file_id]
135
+
136
+ # load the audio data if needed
137
+ if not input_file.loaded:
138
+ input_file.load()
139
+
140
+ # get a random patch of size `self.length`
141
+ # start_idx, stop_idx = utils.get_random_patch(input_file, self.sample_rate, self.length)
142
+ start_idx, stop_idx = utils.get_random_patch(input_file, self.length)
143
+ input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
144
+
145
+ # random scaling
146
+ input_audio /= input_audio.abs().max()
147
+ scale_dB = (torch.rand(1).squeeze().numpy() * 12) + 12
148
+ input_audio *= 10 ** (-scale_dB / 20.0)
149
+
150
+ # generate random parameters (uniform) over 0 to 1
151
+ params = torch.rand(self.processor.num_control_params)
152
+
153
+ # expects batch dim
154
+ # apply plugins with random parameters
155
+ if self.processor_type == "channel":
156
+ params[-1] = 0.5 # set makeup gain to 0dB
157
+ target_audio = self.processor(
158
+ input_audio.view(1, 1, -1),
159
+ params.view(1, -1),
160
+ )
161
+ target_audio = target_audio.view(1, -1)
162
+ elif self.processor_type == "peq":
163
+ target_audio = self.processor(
164
+ input_audio.view(1, 1, -1).numpy(),
165
+ params.view(1, -1).numpy(),
166
+ )
167
+ target_audio = torch.tensor(target_audio).view(1, -1)
168
+ elif self.processor_type == "comp":
169
+ params[-1] = 0.5 # set makeup gain to 0dB
170
+ target_audio = self.processor(
171
+ input_audio.view(1, 1, -1).numpy(),
172
+ params.view(1, -1).numpy(),
173
+ )
174
+ target_audio = torch.tensor(target_audio).view(1, -1)
175
+
176
+ # clip
177
+ if self.soft_clip:
178
+ # target_audio = target_audio.clamp(-2.0, 2.0)
179
+ target_audio = torch.tanh(target_audio / 2.0) * 2.0
180
+
181
+ return input_audio, target_audio, params
deepafx_st/data/style.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import torch
4
+ import torchaudio
5
+ from tqdm import tqdm
6
+
7
+
8
+ class StyleDataset(torch.utils.data.Dataset):
9
+ def __init__(
10
+ self,
11
+ audio_dir: str,
12
+ subset: str = "train",
13
+ sample_rate: int = 24000,
14
+ length: int = 131072,
15
+ ) -> None:
16
+ super().__init__()
17
+ self.audio_dir = audio_dir
18
+ self.subset = subset
19
+ self.sample_rate = sample_rate
20
+ self.length = length
21
+
22
+ self.style_dirs = glob.glob(os.path.join(audio_dir, subset, "*"))
23
+ self.style_dirs = [sd for sd in self.style_dirs if os.path.isdir(sd)]
24
+ self.num_classes = len(self.style_dirs)
25
+ self.class_labels = {"broadcast" : 0, "telephone": 1, "neutral": 2, "bright": 3, "warm": 4}
26
+
27
+ self.examples = []
28
+ for n, style_dir in enumerate(self.style_dirs):
29
+
30
+ # get all files in style dir
31
+ style_filepaths = glob.glob(os.path.join(style_dir, "*.wav"))
32
+ style_name = os.path.basename(style_dir)
33
+ for style_filepath in tqdm(style_filepaths, ncols=120):
34
+ # load audio file
35
+ x, sr = torchaudio.load(style_filepath)
36
+
37
+ # sum to mono if needed
38
+ if x.shape[0] > 1:
39
+ x = x.mean(dim=0, keepdim=True)
40
+
41
+ # resample
42
+ if sr != self.sample_rate:
43
+ x = torchaudio.transforms.Resample(sr, self.sample_rate)(x)
44
+
45
+ # crop length after resample
46
+ if x.shape[-1] >= self.length:
47
+ x = x[...,:self.length]
48
+
49
+ # store example
50
+ example = (x, self.class_labels[style_name])
51
+ self.examples.append(example)
52
+
53
+ print(f"Loaded {len(self.examples)} examples for {subset} subset.")
54
+
55
+ def __len__(self):
56
+ return len(self.examples)
57
+
58
+ def __getitem__(self, idx):
59
+ example = self.examples[idx]
60
+ x = example[0]
61
+ y = example[1]
62
+ return x, y
deepafx_st/metrics.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import auraloss
3
+ import resampy
4
+ import torchaudio
5
+ from pesq import pesq
6
+ import pyloudnorm as pyln
7
+
8
+
9
+ def crest_factor(x):
10
+ """Compute the crest factor of waveform."""
11
+
12
+ peak, _ = x.abs().max(dim=-1)
13
+ rms = torch.sqrt((x ** 2).mean(dim=-1))
14
+
15
+ return 20 * torch.log(peak / rms.clamp(1e-8))
16
+
17
+
18
+ def rms_energy(x):
19
+
20
+ rms = torch.sqrt((x ** 2).mean(dim=-1))
21
+
22
+ return 20 * torch.log(rms.clamp(1e-8))
23
+
24
+
25
+ def spectral_centroid(x):
26
+ """Compute the crest factor of waveform.
27
+
28
+ See: https://gist.github.com/endolith/359724
29
+
30
+ """
31
+
32
+ spectrum = torch.fft.rfft(x).abs()
33
+ normalized_spectrum = spectrum / spectrum.sum()
34
+ normalized_frequencies = torch.linspace(0, 1, spectrum.shape[-1])
35
+ spectral_centroid = torch.sum(normalized_frequencies * normalized_spectrum)
36
+
37
+ return spectral_centroid
38
+
39
+
40
+ def loudness(x, sample_rate):
41
+ """Compute the loudness in dB LUFS of waveform."""
42
+ meter = pyln.Meter(sample_rate)
43
+
44
+ # add stereo dim if needed
45
+ if x.shape[0] < 2:
46
+ x = x.repeat(2, 1)
47
+
48
+ return torch.tensor(meter.integrated_loudness(x.permute(1, 0).numpy()))
49
+
50
+
51
+ class MelSpectralDistance(torch.nn.Module):
52
+ def __init__(self, sample_rate, length=65536):
53
+ super().__init__()
54
+ self.error = auraloss.freq.MelSTFTLoss(
55
+ sample_rate,
56
+ fft_size=length,
57
+ hop_size=length,
58
+ win_length=length,
59
+ w_sc=0,
60
+ w_log_mag=1,
61
+ w_lin_mag=1,
62
+ n_mels=128,
63
+ scale_invariance=False,
64
+ )
65
+
66
+ # I think scale invariance may not work well,
67
+ # since aspects of the phase may be considered?
68
+
69
+ def forward(self, input, target):
70
+ return self.error(input, target)
71
+
72
+
73
+ class PESQ(torch.nn.Module):
74
+ def __init__(self, sample_rate):
75
+ super().__init__()
76
+ self.sample_rate = sample_rate
77
+
78
+ def forward(self, input, target):
79
+ if self.sample_rate != 16000:
80
+ target = resampy.resample(
81
+ target.view(-1).numpy(),
82
+ self.sample_rate,
83
+ 16000,
84
+ )
85
+ input = resampy.resample(
86
+ input.view(-1).numpy(),
87
+ self.sample_rate,
88
+ 16000,
89
+ )
90
+
91
+ return pesq(
92
+ 16000,
93
+ target,
94
+ input,
95
+ "wb",
96
+ )
97
+
98
+
99
+ class CrestFactorError(torch.nn.Module):
100
+ def __init__(self):
101
+ super().__init__()
102
+
103
+ def forward(self, input, target):
104
+ return torch.nn.functional.l1_loss(
105
+ crest_factor(input),
106
+ crest_factor(target),
107
+ ).item()
108
+
109
+
110
+ class RMSEnergyError(torch.nn.Module):
111
+ def __init__(self):
112
+ super().__init__()
113
+
114
+ def forward(self, input, target):
115
+ return torch.nn.functional.l1_loss(
116
+ rms_energy(input),
117
+ rms_energy(target),
118
+ ).item()
119
+
120
+
121
+ class SpectralCentroidError(torch.nn.Module):
122
+ def __init__(self, sample_rate, n_fft=2048, hop_length=512):
123
+ super().__init__()
124
+
125
+ self.spectral_centroid = torchaudio.transforms.SpectralCentroid(
126
+ sample_rate,
127
+ n_fft=n_fft,
128
+ hop_length=hop_length,
129
+ )
130
+
131
+ def forward(self, input, target):
132
+ return torch.nn.functional.l1_loss(
133
+ self.spectral_centroid(input + 1e-16).mean(),
134
+ self.spectral_centroid(target + 1e-16).mean(),
135
+ ).item()
136
+
137
+
138
+ class LoudnessError(torch.nn.Module):
139
+ def __init__(self, sample_rate: int, peak_normalize: bool = False):
140
+ super().__init__()
141
+ self.sample_rate = sample_rate
142
+ self.peak_normalize = peak_normalize
143
+
144
+ def forward(self, input, target):
145
+
146
+ if self.peak_normalize:
147
+ # peak normalize
148
+ x = input / input.abs().max()
149
+ y = target / target.abs().max()
150
+ else:
151
+ x = input
152
+ y = target
153
+
154
+ return torch.nn.functional.l1_loss(
155
+ loudness(x.view(1, -1), self.sample_rate),
156
+ loudness(y.view(1, -1), self.sample_rate),
157
+ ).item()
deepafx_st/models/baselines.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import scipy.signal
4
+ import numpy as np
5
+ import pyloudnorm as pyln
6
+ import matplotlib.pyplot as plt
7
+ from deepafx_st.processors.dsp.compressor import compressor
8
+
9
+ from tqdm import tqdm
10
+
11
+
12
+ class BaselineEQ(torch.nn.Module):
13
+ def __init__(
14
+ self,
15
+ ntaps: int = 63,
16
+ n_fft: int = 65536,
17
+ sample_rate: float = 44100,
18
+ ):
19
+ super().__init__()
20
+ self.ntaps = ntaps
21
+ self.n_fft = n_fft
22
+ self.sample_rate = sample_rate
23
+
24
+ # compute the target spectrum
25
+ # print("Computing target spectrum...")
26
+ # self.target_spec, self.sm_target_spec = self.analyze_speech_dataset(filepaths)
27
+ # self.plot_spectrum(self.target_spec, filename="targetEQ")
28
+ # self.plot_spectrum(self.sm_target_spec, filename="targetEQsm")
29
+
30
+ def forward(self, x, y):
31
+
32
+ bs, ch, s = x.size()
33
+
34
+ x = x.view(bs * ch, -1)
35
+ y = y.view(bs * ch, -1)
36
+
37
+ in_spec = self.get_average_spectrum(x)
38
+ ref_spec = self.get_average_spectrum(y)
39
+
40
+ sm_in_spec = self.smooth_spectrum(in_spec)
41
+ sm_ref_spec = self.smooth_spectrum(ref_spec)
42
+
43
+ # self.plot_spectrum(in_spec, filename="inSpec")
44
+ # self.plot_spectrum(sm_in_spec, filename="inSpecsm")
45
+
46
+ # design inverse FIR filter to match target EQ
47
+ freqs = np.linspace(0, 1.0, num=(self.n_fft // 2) + 1)
48
+ response = sm_ref_spec / sm_in_spec
49
+ response[-1] = 0.0 # zero gain at nyquist
50
+
51
+ b = scipy.signal.firwin2(
52
+ self.ntaps,
53
+ freqs * (self.sample_rate / 2),
54
+ response,
55
+ fs=self.sample_rate,
56
+ )
57
+
58
+ # scale the coefficients for less intense filter
59
+ # clearb *= 0.5
60
+
61
+ # apply the filter
62
+ x_filt = scipy.signal.lfilter(b, [1.0], x.numpy())
63
+ x_filt = torch.tensor(x_filt.astype("float32"))
64
+
65
+ if False:
66
+ # plot the filter response
67
+ w, h = scipy.signal.freqz(b, fs=self.sample_rate, worN=response.shape[-1])
68
+
69
+ fig, ax1 = plt.subplots()
70
+ ax1.set_title("Digital filter frequency response")
71
+ ax1.plot(w, 20 * np.log10(abs(h + 1e-8)))
72
+ ax1.plot(w, 20 * np.log10(abs(response + 1e-8)))
73
+
74
+ ax1.set_xscale("log")
75
+ ax1.set_ylim([-12, 12])
76
+ plt.grid(c="lightgray")
77
+ plt.savefig(f"inverse.png")
78
+
79
+ x_filt_avg_spec = self.get_average_spectrum(x_filt)
80
+ sm_x_filt_avg_spec = self.smooth_spectrum(x_filt_avg_spec)
81
+ y_avg_spec = self.get_average_spectrum(y)
82
+ sm_y_avg_spec = self.smooth_spectrum(y_avg_spec)
83
+ compare = torch.stack(
84
+ [
85
+ torch.tensor(sm_in_spec),
86
+ torch.tensor(sm_x_filt_avg_spec),
87
+ torch.tensor(sm_ref_spec),
88
+ torch.tensor(sm_y_avg_spec),
89
+ ]
90
+ )
91
+ self.plot_multi_spectrum(
92
+ compare,
93
+ legend=["in", "out", "target curve", "actual target"],
94
+ filename="outSpec",
95
+ )
96
+
97
+ return x_filt
98
+
99
+ def analyze_speech_dataset(self, filepaths, peak=-3.0):
100
+ avg_spec = []
101
+ for filepath in tqdm(filepaths, ncols=80):
102
+ x, sr = torchaudio.load(filepath)
103
+ x /= x.abs().max()
104
+ x *= 10 ** (peak / 20.0)
105
+ avg_spec.append(self.get_average_spectrum(x))
106
+ avg_specs = torch.stack(avg_spec)
107
+
108
+ avg_spec = avg_specs.mean(dim=0).numpy()
109
+ avg_spec_std = avg_specs.std(dim=0).numpy()
110
+
111
+ # self.plot_multi_spectrum(avg_specs, filename="allTargetEQs")
112
+ # self.plot_spectrum_stats(avg_spec, avg_spec_std, filename="targetEQstats")
113
+
114
+ sm_avg_spec = self.smooth_spectrum(avg_spec)
115
+
116
+ return avg_spec, sm_avg_spec
117
+
118
+ def smooth_spectrum(self, H):
119
+ # apply Savgol filter for smoothed target curve
120
+ return scipy.signal.savgol_filter(H, 1025, 2)
121
+
122
+ def get_average_spectrum(self, x):
123
+
124
+ # x = x[:, : self.n_fft]
125
+ X = torch.stft(x, self.n_fft, return_complex=True, normalized=True)
126
+ # fft_size = self.next_power_of_2(x.shape[-1])
127
+ # X = torch.fft.rfft(x, n=fft_size)
128
+
129
+ X = X.abs() # convert to magnitude
130
+ X = X.mean(dim=-1).view(-1) # average across frames
131
+
132
+ return X
133
+
134
+ @staticmethod
135
+ def next_power_of_2(x):
136
+ return 1 if x == 0 else int(2 ** np.ceil(np.log2(x)))
137
+
138
+ def plot_multi_spectrum(self, Hs, legend=[], filename=None):
139
+
140
+ bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
141
+ freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)
142
+
143
+ fig, ax1 = plt.subplots()
144
+
145
+ for H in Hs:
146
+ ax1.plot(
147
+ freqs,
148
+ 20 * np.log10(abs(H) + 1e-8),
149
+ )
150
+
151
+ plt.legend(legend)
152
+
153
+ # avg_spec = Hs.mean(dim=0).numpy()
154
+ # ax1.plot(freqs, 20 * np.log10(avg_spec), color="k", linewidth=2)
155
+
156
+ ax1.set_xscale("log")
157
+ ax1.set_ylim([-80, 0])
158
+ plt.grid(c="lightgray")
159
+
160
+ if filename is not None:
161
+ plt.savefig(f"{filename}.png")
162
+
163
+ def plot_spectrum_stats(self, H_mean, H_std, filename=None):
164
+ bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
165
+ freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)
166
+
167
+ fig, ax1 = plt.subplots()
168
+ ax1.plot(freqs, 20 * np.log10(H_mean))
169
+ ax1.plot(
170
+ freqs,
171
+ (20 * np.log10(H_mean)) + (20 * np.log10(H_std)),
172
+ linestyle="--",
173
+ color="k",
174
+ )
175
+ ax1.plot(
176
+ freqs,
177
+ (20 * np.log10(H_mean)) - (20 * np.log10(H_std)),
178
+ linestyle="--",
179
+ color="k",
180
+ )
181
+
182
+ ax1.set_xscale("log")
183
+ ax1.set_ylim([-80, 0])
184
+ plt.grid(c="lightgray")
185
+
186
+ if filename is not None:
187
+ plt.savefig(f"{filename}.png")
188
+
189
+ def plot_spectrum(self, H, legend=[], filename=None):
190
+
191
+ bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
192
+ freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)
193
+
194
+ fig, ax1 = plt.subplots()
195
+ ax1.plot(freqs, 20 * np.log10(H))
196
+ ax1.set_xscale("log")
197
+ ax1.set_ylim([-80, 0])
198
+ plt.grid(c="lightgray")
199
+
200
+ plt.legend(legend)
201
+
202
+ if filename is not None:
203
+ plt.savefig(f"{filename}.png")
204
+
205
+
206
+ class BaslineComp(torch.nn.Module):
207
+ def __init__(
208
+ self,
209
+ sample_rate: float = 44100,
210
+ ):
211
+ super().__init__()
212
+ self.sample_rate = sample_rate
213
+ self.meter = pyln.Meter(sample_rate)
214
+
215
+ def forward(self, x, y):
216
+
217
+ x_lufs = self.meter.integrated_loudness(x.view(-1).numpy())
218
+ y_lufs = self.meter.integrated_loudness(y.view(-1).numpy())
219
+
220
+ delta_lufs = y_lufs - x_lufs
221
+
222
+ threshold = 0.0
223
+ x_comp = x
224
+ x_comp_new = x
225
+ while delta_lufs > 0.5 and threshold > -80.0:
226
+ x_comp = x_comp_new # use the last setting
227
+ x_comp_new = compressor(
228
+ x.view(-1).numpy(),
229
+ self.sample_rate,
230
+ threshold=threshold,
231
+ ratio=3,
232
+ attack_time=0.001,
233
+ release_time=0.05,
234
+ knee_dB=6.0,
235
+ makeup_gain_dB=0.0,
236
+ )
237
+ x_comp_new = torch.tensor(x_comp_new)
238
+ x_comp_new /= x_comp_new.abs().max()
239
+ x_comp_new *= 10 ** (-12.0 / 20)
240
+ x_lufs = self.meter.integrated_loudness(x_comp_new.view(-1).numpy())
241
+ delta_lufs = y_lufs - x_lufs
242
+ threshold -= 0.5
243
+
244
+ return x_comp.view(1, 1, -1)
245
+
246
+
247
+ class BaselineEQAndComp(torch.nn.Module):
248
+ def __init__(
249
+ self,
250
+ ntaps=63,
251
+ n_fft=65536,
252
+ sample_rate=44100,
253
+ block_size=1024,
254
+ plugin_config=None,
255
+ ):
256
+ super().__init__()
257
+ self.eq = BaselineEQ(ntaps, n_fft, sample_rate)
258
+ self.comp = BaslineComp(sample_rate)
259
+
260
+ def forward(self, x, y):
261
+
262
+ with torch.inference_mode():
263
+ x /= x.abs().max()
264
+ y /= y.abs().max()
265
+ x *= 10 ** (-12.0 / 20)
266
+ y *= 10 ** (-12.0 / 20)
267
+
268
+ x = self.eq(x, y)
269
+
270
+ x /= x.abs().max()
271
+ y /= y.abs().max()
272
+ x *= 10 ** (-12.0 / 20)
273
+ y *= 10 ** (-12.0 / 20)
274
+
275
+ x = self.comp(x, y)
276
+
277
+ x /= x.abs().max()
278
+ x *= 10 ** (-12.0 / 20)
279
+
280
+ return x
deepafx_st/models/controller.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class StyleTransferController(torch.nn.Module):
4
+ def __init__(
5
+ self,
6
+ num_control_params,
7
+ edim,
8
+ hidden_dim=256,
9
+ agg_method="mlp",
10
+ ):
11
+ """Plugin parameter controller module to map from input to target style.
12
+
13
+ Args:
14
+ num_control_params (int): Number of plugin parameters to predicted.
15
+ edim (int): Size of the encoder representations.
16
+ hidden_dim (int, optional): Hidden size of the 3-layer parameter predictor MLP. Default: 256
17
+ agg_method (str, optional): Input/reference embed aggregation method ["conv" or "linear", "mlp"]. Default: "mlp"
18
+ """
19
+ super().__init__()
20
+ self.num_control_params = num_control_params
21
+ self.edim = edim
22
+ self.hidden_dim = hidden_dim
23
+ self.agg_method = agg_method
24
+
25
+ if agg_method == "conv":
26
+ self.agg = torch.nn.Conv1d(
27
+ 2,
28
+ 1,
29
+ kernel_size=129,
30
+ stride=1,
31
+ padding="same",
32
+ bias=False,
33
+ )
34
+ mlp_in_dim = edim
35
+ elif agg_method == "linear":
36
+ self.agg = torch.nn.Linear(edim * 2, edim)
37
+ elif agg_method == "mlp":
38
+ self.agg = None
39
+ mlp_in_dim = edim * 2
40
+ else:
41
+ raise ValueError(f"Invalid agg_method = {self.agg_method}.")
42
+
43
+ self.mlp = torch.nn.Sequential(
44
+ torch.nn.Linear(mlp_in_dim, hidden_dim),
45
+ torch.nn.LeakyReLU(0.01),
46
+ torch.nn.Linear(hidden_dim, hidden_dim),
47
+ torch.nn.LeakyReLU(0.01),
48
+ torch.nn.Linear(hidden_dim, num_control_params),
49
+ torch.nn.Sigmoid(), # normalize between 0 and 1
50
+ )
51
+
52
+ def forward(self, e_x, e_y, z=None):
53
+ """Forward pass to generate plugin parameters.
54
+
55
+ Args:
56
+ e_x (tensor): Input signal embedding of shape (batch, edim)
57
+ e_y (tensor): Target signal embedding of shape (batch, edim)
58
+ Returns:
59
+ p (tensor): Estimated control parameters of shape (batch, num_control_params)
60
+ """
61
+
62
+ # use learnable projection
63
+ if self.agg_method == "conv":
64
+ e_xy = torch.stack((e_x, e_y), dim=1) # concat on channel dim
65
+ e_xy = self.agg(e_xy)
66
+ elif self.agg_method == "linear":
67
+ e_xy = torch.cat((e_x, e_y), dim=-1) # concat on embed dim
68
+ e_xy = self.agg(e_xy)
69
+ else:
70
+ e_xy = torch.cat((e_x, e_y), dim=-1) # concat on embed dim
71
+
72
+ # pass through MLP to project to control parametesr
73
+ p = self.mlp(e_xy.squeeze(1))
74
+
75
+ return p
deepafx_st/models/efficient_net/LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
deepafx_st/models/efficient_net/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.7.1"
2
+ from .model import EfficientNet, VALID_MODELS
3
+ from .utils import (
4
+ GlobalParams,
5
+ BlockArgs,
6
+ BlockDecoder,
7
+ efficientnet,
8
+ get_model_params,
9
+ )
deepafx_st/models/efficient_net/model.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """model.py - Model and module class for EfficientNet.
2
+ They are built to mirror those in the official TensorFlow implementation.
3
+ """
4
+
5
+ # Author: lukemelas (github username)
6
+ # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7
+ # With adjustments and added comments by workingcoder (github username).
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from .utils import (
13
+ round_filters,
14
+ round_repeats,
15
+ drop_connect,
16
+ get_same_padding_conv2d,
17
+ get_model_params,
18
+ efficientnet_params,
19
+ load_pretrained_weights,
20
+ Swish,
21
+ MemoryEfficientSwish,
22
+ calculate_output_image_size
23
+ )
24
+
25
+
26
+ VALID_MODELS = (
27
+ 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
28
+ 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
29
+ 'efficientnet-b8',
30
+
31
+ # Support the construction of 'efficientnet-l2' without pretrained weights
32
+ 'efficientnet-l2'
33
+ )
34
+
35
+
36
+ class MBConvBlock(nn.Module):
37
+ """Mobile Inverted Residual Bottleneck Block.
38
+
39
+ Args:
40
+ block_args (namedtuple): BlockArgs, defined in utils.py.
41
+ global_params (namedtuple): GlobalParam, defined in utils.py.
42
+ image_size (tuple or list): [image_height, image_width].
43
+
44
+ References:
45
+ [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
46
+ [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
47
+ [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
48
+ """
49
+
50
+ def __init__(self, block_args, global_params, image_size=None):
51
+ super().__init__()
52
+ self._block_args = block_args
53
+ self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
54
+ self._bn_eps = global_params.batch_norm_epsilon
55
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
56
+ self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
57
+
58
+ # Expansion phase (Inverted Bottleneck)
59
+ inp = self._block_args.input_filters # number of input channels
60
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
61
+ if self._block_args.expand_ratio != 1:
62
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
63
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
64
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
65
+ # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
66
+
67
+ # Depthwise convolution phase
68
+ k = self._block_args.kernel_size
69
+ s = self._block_args.stride
70
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
71
+ self._depthwise_conv = Conv2d(
72
+ in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
73
+ kernel_size=k, stride=s, bias=False)
74
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
75
+ image_size = calculate_output_image_size(image_size, s)
76
+
77
+ # Squeeze and Excitation layer, if desired
78
+ if self.has_se:
79
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
80
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
81
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
82
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
83
+
84
+ # Pointwise convolution phase
85
+ final_oup = self._block_args.output_filters
86
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
87
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
88
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
89
+ self._swish = MemoryEfficientSwish()
90
+
91
+ def forward(self, inputs, drop_connect_rate=None):
92
+ """MBConvBlock's forward function.
93
+
94
+ Args:
95
+ inputs (tensor): Input tensor.
96
+ drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
97
+
98
+ Returns:
99
+ Output of this block after processing.
100
+ """
101
+
102
+ # Expansion and Depthwise Convolution
103
+ x = inputs
104
+ if self._block_args.expand_ratio != 1:
105
+ x = self._expand_conv(inputs)
106
+ x = self._bn0(x)
107
+ x = self._swish(x)
108
+
109
+ x = self._depthwise_conv(x)
110
+ x = self._bn1(x)
111
+ x = self._swish(x)
112
+
113
+ # Squeeze and Excitation
114
+ if self.has_se:
115
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
116
+ x_squeezed = self._se_reduce(x_squeezed)
117
+ x_squeezed = self._swish(x_squeezed)
118
+ x_squeezed = self._se_expand(x_squeezed)
119
+ x = torch.sigmoid(x_squeezed) * x
120
+
121
+ # Pointwise Convolution
122
+ x = self._project_conv(x)
123
+ x = self._bn2(x)
124
+
125
+ # Skip connection and drop connect
126
+ input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
127
+ if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
128
+ # The combination of skip connection and drop connect brings about stochastic depth.
129
+ if drop_connect_rate:
130
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
131
+ x = x + inputs # skip connection
132
+ return x
133
+
134
+ def set_swish(self, memory_efficient=True):
135
+ """Sets swish function as memory efficient (for training) or standard (for export).
136
+
137
+ Args:
138
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
139
+ """
140
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
141
+
142
+
143
+ class EfficientNet(nn.Module):
144
+ """EfficientNet model.
145
+ Most easily loaded with the .from_name or .from_pretrained methods.
146
+
147
+ Args:
148
+ blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
149
+ global_params (namedtuple): A set of GlobalParams shared between blocks.
150
+
151
+ References:
152
+ [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
153
+
154
+ Example:
155
+ >>> import torch
156
+ >>> from efficientnet.model import EfficientNet
157
+ >>> inputs = torch.rand(1, 3, 224, 224)
158
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
159
+ >>> model.eval()
160
+ >>> outputs = model(inputs)
161
+ """
162
+
163
+ def __init__(self, blocks_args=None, global_params=None):
164
+ super().__init__()
165
+ assert isinstance(blocks_args, list), 'blocks_args should be a list'
166
+ assert len(blocks_args) > 0, 'block args must be greater than 0'
167
+ self._global_params = global_params
168
+ self._blocks_args = blocks_args
169
+
170
+ # Batch norm parameters
171
+ bn_mom = 1 - self._global_params.batch_norm_momentum
172
+ bn_eps = self._global_params.batch_norm_epsilon
173
+
174
+ # Get stem static or dynamic convolution depending on image size
175
+ image_size = global_params.image_size
176
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
177
+
178
+ # Stem
179
+ in_channels = 3 # rgb
180
+ out_channels = round_filters(32, self._global_params) # number of output channels
181
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
182
+ self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
183
+ image_size = calculate_output_image_size(image_size, 2)
184
+
185
+ # Build blocks
186
+ self._blocks = nn.ModuleList([])
187
+ for block_args in self._blocks_args:
188
+
189
+ # Update block input and output filters based on depth multiplier.
190
+ block_args = block_args._replace(
191
+ input_filters=round_filters(block_args.input_filters, self._global_params),
192
+ output_filters=round_filters(block_args.output_filters, self._global_params),
193
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params)
194
+ )
195
+
196
+ # The first block needs to take care of stride and filter size increase.
197
+ self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
198
+ image_size = calculate_output_image_size(image_size, block_args.stride)
199
+ if block_args.num_repeat > 1: # modify block_args to keep same output size
200
+ block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
201
+ for _ in range(block_args.num_repeat - 1):
202
+ self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
203
+ # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
204
+
205
+ # Head
206
+ in_channels = block_args.output_filters # output of final block
207
+ out_channels = round_filters(1280, self._global_params)
208
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
209
+ self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
210
+ self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
211
+
212
+ # Final linear layer
213
+ self._avg_pooling = nn.AdaptiveAvgPool2d(1)
214
+ if self._global_params.include_top:
215
+ self._dropout = nn.Dropout(self._global_params.dropout_rate)
216
+ self._fc = nn.Linear(out_channels, self._global_params.num_classes)
217
+
218
+ # set activation to memory efficient swish by default
219
+ self._swish = MemoryEfficientSwish()
220
+
221
+ def set_swish(self, memory_efficient=True):
222
+ """Sets swish function as memory efficient (for training) or standard (for export).
223
+
224
+ Args:
225
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
226
+ """
227
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
228
+ for block in self._blocks:
229
+ block.set_swish(memory_efficient)
230
+
231
+ def extract_endpoints(self, inputs):
232
+ """Use convolution layer to extract features
233
+ from reduction levels i in [1, 2, 3, 4, 5].
234
+
235
+ Args:
236
+ inputs (tensor): Input tensor.
237
+
238
+ Returns:
239
+ Dictionary of last intermediate features
240
+ with reduction levels i in [1, 2, 3, 4, 5].
241
+ Example:
242
+ >>> import torch
243
+ >>> from efficientnet.model import EfficientNet
244
+ >>> inputs = torch.rand(1, 3, 224, 224)
245
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
246
+ >>> endpoints = model.extract_endpoints(inputs)
247
+ >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
248
+ >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249
+ >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250
+ >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251
+ >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
252
+ >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
253
+ """
254
+ endpoints = dict()
255
+
256
+ # Stem
257
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
258
+ prev_x = x
259
+
260
+ # Blocks
261
+ for idx, block in enumerate(self._blocks):
262
+ drop_connect_rate = self._global_params.drop_connect_rate
263
+ if drop_connect_rate:
264
+ drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
265
+ x = block(x, drop_connect_rate=drop_connect_rate)
266
+ if prev_x.size(2) > x.size(2):
267
+ endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
268
+ elif idx == len(self._blocks) - 1:
269
+ endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
270
+ prev_x = x
271
+
272
+ # Head
273
+ x = self._swish(self._bn1(self._conv_head(x)))
274
+ endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
275
+
276
+ return endpoints
277
+
278
+ def extract_features(self, inputs):
279
+ """use convolution layer to extract feature .
280
+
281
+ Args:
282
+ inputs (tensor): Input tensor.
283
+
284
+ Returns:
285
+ Output of the final convolution
286
+ layer in the efficientnet model.
287
+ """
288
+ # Stem
289
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
290
+
291
+ # Blocks
292
+ for idx, block in enumerate(self._blocks):
293
+ drop_connect_rate = self._global_params.drop_connect_rate
294
+ if drop_connect_rate:
295
+ drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
296
+ x = block(x, drop_connect_rate=drop_connect_rate)
297
+
298
+ # Head
299
+ x = self._swish(self._bn1(self._conv_head(x)))
300
+
301
+ return x
302
+
303
+ def forward(self, inputs):
304
+ """EfficientNet's forward function.
305
+ Calls extract_features to extract features, applies final linear layer, and returns logits.
306
+
307
+ Args:
308
+ inputs (tensor): Input tensor.
309
+
310
+ Returns:
311
+ Output of this model after processing.
312
+ """
313
+ # Convolution layers
314
+ x = self.extract_features(inputs)
315
+ # Pooling and final linear layer
316
+ x = self._avg_pooling(x)
317
+ if self._global_params.include_top:
318
+ x = x.flatten(start_dim=1)
319
+ x = self._dropout(x)
320
+ x = self._fc(x)
321
+ return x
322
+
323
+ @classmethod
324
+ def from_name(cls, model_name, in_channels=3, **override_params):
325
+ """Create an efficientnet model according to name.
326
+
327
+ Args:
328
+ model_name (str): Name for efficientnet.
329
+ in_channels (int): Input data's channel number.
330
+ override_params (other key word params):
331
+ Params to override model's global_params.
332
+ Optional key:
333
+ 'width_coefficient', 'depth_coefficient',
334
+ 'image_size', 'dropout_rate',
335
+ 'num_classes', 'batch_norm_momentum',
336
+ 'batch_norm_epsilon', 'drop_connect_rate',
337
+ 'depth_divisor', 'min_depth'
338
+
339
+ Returns:
340
+ An efficientnet model.
341
+ """
342
+ cls._check_model_name_is_valid(model_name)
343
+ blocks_args, global_params = get_model_params(model_name, override_params)
344
+ model = cls(blocks_args, global_params)
345
+ model._change_in_channels(in_channels)
346
+ return model
347
+
348
+ @classmethod
349
+ def from_pretrained(cls, model_name, weights_path=None, advprop=False,
350
+ in_channels=3, num_classes=1000, **override_params):
351
+ """Create an efficientnet model according to name.
352
+
353
+ Args:
354
+ model_name (str): Name for efficientnet.
355
+ weights_path (None or str):
356
+ str: path to pretrained weights file on the local disk.
357
+ None: use pretrained weights downloaded from the Internet.
358
+ advprop (bool):
359
+ Whether to load pretrained weights
360
+ trained with advprop (valid when weights_path is None).
361
+ in_channels (int): Input data's channel number.
362
+ num_classes (int):
363
+ Number of categories for classification.
364
+ It controls the output size for final linear layer.
365
+ override_params (other key word params):
366
+ Params to override model's global_params.
367
+ Optional key:
368
+ 'width_coefficient', 'depth_coefficient',
369
+ 'image_size', 'dropout_rate',
370
+ 'batch_norm_momentum',
371
+ 'batch_norm_epsilon', 'drop_connect_rate',
372
+ 'depth_divisor', 'min_depth'
373
+
374
+ Returns:
375
+ A pretrained efficientnet model.
376
+ """
377
+ model = cls.from_name(model_name, num_classes=num_classes, **override_params)
378
+ load_pretrained_weights(model, model_name, weights_path=weights_path,
379
+ load_fc=(num_classes == 1000), advprop=advprop)
380
+ model._change_in_channels(in_channels)
381
+ return model
382
+
383
+ @classmethod
384
+ def get_image_size(cls, model_name):
385
+ """Get the input image size for a given efficientnet model.
386
+
387
+ Args:
388
+ model_name (str): Name for efficientnet.
389
+
390
+ Returns:
391
+ Input image size (resolution).
392
+ """
393
+ cls._check_model_name_is_valid(model_name)
394
+ _, _, res, _ = efficientnet_params(model_name)
395
+ return res
396
+
397
+ @classmethod
398
+ def _check_model_name_is_valid(cls, model_name):
399
+ """Validates model name.
400
+
401
+ Args:
402
+ model_name (str): Name for efficientnet.
403
+
404
+ Returns:
405
+ bool: Is a valid name or not.
406
+ """
407
+ if model_name not in VALID_MODELS:
408
+ raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
409
+
410
+ def _change_in_channels(self, in_channels):
411
+ """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
412
+
413
+ Args:
414
+ in_channels (int): Input data's channel number.
415
+ """
416
+ if in_channels != 3:
417
+ Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
418
+ out_channels = round_filters(32, self._global_params)
419
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
deepafx_st/models/efficient_net/utils.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """utils.py - Helper functions for building the model and for loading model parameters.
2
+ These helper functions are built to mirror those in the official TensorFlow implementation.
3
+ """
4
+
5
+ # Author: lukemelas (github username)
6
+ # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7
+ # With adjustments and added comments by workingcoder (github username).
8
+
9
+ import re
10
+ import math
11
+ import collections
12
+ from functools import partial
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.utils import model_zoo
17
+
18
+
19
+ ################################################################################
20
+ # Help functions for model architecture
21
+ ################################################################################
22
+
23
+ # GlobalParams and BlockArgs: Two namedtuples
24
+ # Swish and MemoryEfficientSwish: Two implementations of the method
25
+ # round_filters and round_repeats:
26
+ # Functions to calculate params for scaling model width and depth ! ! !
27
+ # get_width_and_height_from_size and calculate_output_image_size
28
+ # drop_connect: A structural design
29
+ # get_same_padding_conv2d:
30
+ # Conv2dDynamicSamePadding
31
+ # Conv2dStaticSamePadding
32
+ # get_same_padding_maxPool2d:
33
+ # MaxPool2dDynamicSamePadding
34
+ # MaxPool2dStaticSamePadding
35
+ # It's an additional function, not used in EfficientNet,
36
+ # but can be used in other model (such as EfficientDet).
37
+
38
+ # Parameters for the entire model (stem, all blocks, and head)
39
+ GlobalParams = collections.namedtuple('GlobalParams', [
40
+ 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
41
+ 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
42
+ 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
43
+
44
+ # Parameters for an individual model block
45
+ BlockArgs = collections.namedtuple('BlockArgs', [
46
+ 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
47
+ 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
48
+
49
+ # Set GlobalParams and BlockArgs's defaults
50
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
51
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
52
+
53
+ # Swish activation function
54
+ if hasattr(nn, 'SiLU'):
55
+ Swish = nn.SiLU
56
+ else:
57
+ # For compatibility with old PyTorch versions
58
+ class Swish(nn.Module):
59
+ def forward(self, x):
60
+ return x * torch.sigmoid(x)
61
+
62
+
63
+ # A memory-efficient implementation of Swish function
64
+ class SwishImplementation(torch.autograd.Function):
65
+ @staticmethod
66
+ def forward(ctx, i):
67
+ result = i * torch.sigmoid(i)
68
+ ctx.save_for_backward(i)
69
+ return result
70
+
71
+ @staticmethod
72
+ def backward(ctx, grad_output):
73
+ i = ctx.saved_tensors[0]
74
+ sigmoid_i = torch.sigmoid(i)
75
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
76
+
77
+
78
+ class MemoryEfficientSwish(nn.Module):
79
+ def forward(self, x):
80
+ return SwishImplementation.apply(x)
81
+
82
+
83
+ def round_filters(filters, global_params):
84
+ """Calculate and round number of filters based on width multiplier.
85
+ Use width_coefficient, depth_divisor and min_depth of global_params.
86
+
87
+ Args:
88
+ filters (int): Filters number to be calculated.
89
+ global_params (namedtuple): Global params of the model.
90
+
91
+ Returns:
92
+ new_filters: New filters number after calculating.
93
+ """
94
+ multiplier = global_params.width_coefficient
95
+ if not multiplier:
96
+ return filters
97
+ # TODO: modify the params names.
98
+ # maybe the names (width_divisor,min_width)
99
+ # are more suitable than (depth_divisor,min_depth).
100
+ divisor = global_params.depth_divisor
101
+ min_depth = global_params.min_depth
102
+ filters *= multiplier
103
+ min_depth = min_depth or divisor # pay attention to this line when using min_depth
104
+ # follow the formula transferred from official TensorFlow implementation
105
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
106
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
107
+ new_filters += divisor
108
+ return int(new_filters)
109
+
110
+
111
+ def round_repeats(repeats, global_params):
112
+ """Calculate module's repeat number of a block based on depth multiplier.
113
+ Use depth_coefficient of global_params.
114
+
115
+ Args:
116
+ repeats (int): num_repeat to be calculated.
117
+ global_params (namedtuple): Global params of the model.
118
+
119
+ Returns:
120
+ new repeat: New repeat number after calculating.
121
+ """
122
+ multiplier = global_params.depth_coefficient
123
+ if not multiplier:
124
+ return repeats
125
+ # follow the formula transferred from official TensorFlow implementation
126
+ return int(math.ceil(multiplier * repeats))
127
+
128
+
129
+ def drop_connect(inputs, p, training):
130
+ """Drop connect.
131
+
132
+ Args:
133
+ input (tensor: BCWH): Input of this structure.
134
+ p (float: 0.0~1.0): Probability of drop connection.
135
+ training (bool): The running mode.
136
+
137
+ Returns:
138
+ output: Output after drop connection.
139
+ """
140
+ assert 0 <= p <= 1, 'p must be in range of [0,1]'
141
+
142
+ if not training:
143
+ return inputs
144
+
145
+ batch_size = inputs.shape[0]
146
+ keep_prob = 1 - p
147
+
148
+ # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
149
+ random_tensor = keep_prob
150
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
151
+ binary_tensor = torch.floor(random_tensor)
152
+
153
+ output = inputs / keep_prob * binary_tensor
154
+ return output
155
+
156
+
157
+ def get_width_and_height_from_size(x):
158
+ """Obtain height and width from x.
159
+
160
+ Args:
161
+ x (int, tuple or list): Data size.
162
+
163
+ Returns:
164
+ size: A tuple or list (H,W).
165
+ """
166
+ if isinstance(x, int):
167
+ return x, x
168
+ if isinstance(x, list) or isinstance(x, tuple):
169
+ return x
170
+ else:
171
+ raise TypeError()
172
+
173
+
174
+ def calculate_output_image_size(input_image_size, stride):
175
+ """Calculates the output image size when using Conv2dSamePadding with a stride.
176
+ Necessary for static padding. Thanks to mannatsingh for pointing this out.
177
+
178
+ Args:
179
+ input_image_size (int, tuple or list): Size of input image.
180
+ stride (int, tuple or list): Conv2d operation's stride.
181
+
182
+ Returns:
183
+ output_image_size: A list [H,W].
184
+ """
185
+ if input_image_size is None:
186
+ return None
187
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
188
+ stride = stride if isinstance(stride, int) else stride[0]
189
+ image_height = int(math.ceil(image_height / stride))
190
+ image_width = int(math.ceil(image_width / stride))
191
+ return [image_height, image_width]
192
+
193
+
194
+ # Note:
195
+ # The following 'SamePadding' functions make output size equal ceil(input size/stride).
196
+ # Only when stride equals 1, can the output size be the same as input size.
197
+ # Don't be confused by their function names ! ! !
198
+
199
+ def get_same_padding_conv2d(image_size=None):
200
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
201
+ Static padding is necessary for ONNX exporting of models.
202
+
203
+ Args:
204
+ image_size (int or tuple): Size of the image.
205
+
206
+ Returns:
207
+ Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
208
+ """
209
+ if image_size is None:
210
+ return Conv2dDynamicSamePadding
211
+ else:
212
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
213
+
214
+
215
+ class Conv2dDynamicSamePadding(nn.Conv2d):
216
+ """2D Convolutions like TensorFlow, for a dynamic image size.
217
+ The padding is operated in forward function by calculating dynamically.
218
+ """
219
+
220
+ # Tips for 'SAME' mode padding.
221
+ # Given the following:
222
+ # i: width or height
223
+ # s: stride
224
+ # k: kernel size
225
+ # d: dilation
226
+ # p: padding
227
+ # Output after Conv2d:
228
+ # o = floor((i+p-((k-1)*d+1))/s+1)
229
+ # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
230
+ # => p = (i-1)*s+((k-1)*d+1)-i
231
+
232
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
233
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
234
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
235
+
236
+ def forward(self, x):
237
+ ih, iw = x.size()[-2:]
238
+ kh, kw = self.weight.size()[-2:]
239
+ sh, sw = self.stride
240
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
241
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
242
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
243
+ if pad_h > 0 or pad_w > 0:
244
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
245
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
246
+
247
+
248
+ class Conv2dStaticSamePadding(nn.Conv2d):
249
+ """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
250
+ The padding mudule is calculated in construction function, then used in forward.
251
+ """
252
+
253
+ # With the same calculation as Conv2dDynamicSamePadding
254
+
255
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
256
+ super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
257
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
258
+
259
+ # Calculate padding based on image size and save it
260
+ assert image_size is not None
261
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
262
+ kh, kw = self.weight.size()[-2:]
263
+ sh, sw = self.stride
264
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
265
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
266
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
267
+ if pad_h > 0 or pad_w > 0:
268
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
269
+ pad_h // 2, pad_h - pad_h // 2))
270
+ else:
271
+ self.static_padding = nn.Identity()
272
+
273
+ def forward(self, x):
274
+ x = self.static_padding(x)
275
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
276
+ return x
277
+
278
+
279
+ def get_same_padding_maxPool2d(image_size=None):
280
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
281
+ Static padding is necessary for ONNX exporting of models.
282
+
283
+ Args:
284
+ image_size (int or tuple): Size of the image.
285
+
286
+ Returns:
287
+ MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
288
+ """
289
+ if image_size is None:
290
+ return MaxPool2dDynamicSamePadding
291
+ else:
292
+ return partial(MaxPool2dStaticSamePadding, image_size=image_size)
293
+
294
+
295
+ class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
296
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
297
+ The padding is operated in forward function by calculating dynamically.
298
+ """
299
+
300
+ def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
301
+ super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
302
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
303
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
304
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
305
+
306
+ def forward(self, x):
307
+ ih, iw = x.size()[-2:]
308
+ kh, kw = self.kernel_size
309
+ sh, sw = self.stride
310
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
311
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
312
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
313
+ if pad_h > 0 or pad_w > 0:
314
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
315
+ return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
316
+ self.dilation, self.ceil_mode, self.return_indices)
317
+
318
+
319
+ class MaxPool2dStaticSamePadding(nn.MaxPool2d):
320
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
321
+ The padding mudule is calculated in construction function, then used in forward.
322
+ """
323
+
324
+ def __init__(self, kernel_size, stride, image_size=None, **kwargs):
325
+ super().__init__(kernel_size, stride, **kwargs)
326
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
327
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
328
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
329
+
330
+ # Calculate padding based on image size and save it
331
+ assert image_size is not None
332
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
333
+ kh, kw = self.kernel_size
334
+ sh, sw = self.stride
335
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
336
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
337
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
338
+ if pad_h > 0 or pad_w > 0:
339
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
340
+ else:
341
+ self.static_padding = nn.Identity()
342
+
343
+ def forward(self, x):
344
+ x = self.static_padding(x)
345
+ x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
346
+ self.dilation, self.ceil_mode, self.return_indices)
347
+ return x
348
+
349
+
350
+ ################################################################################
351
+ # Helper functions for loading model params
352
+ ################################################################################
353
+
354
+ # BlockDecoder: A Class for encoding and decoding BlockArgs
355
+ # efficientnet_params: A function to query compound coefficient
356
+ # get_model_params and efficientnet:
357
+ # Functions to get BlockArgs and GlobalParams for efficientnet
358
+ # url_map and url_map_advprop: Dicts of url_map for pretrained weights
359
+ # load_pretrained_weights: A function to load pretrained weights
360
+
361
+ class BlockDecoder(object):
362
+ """Block Decoder for readability,
363
+ straight from the official TensorFlow repository.
364
+ """
365
+
366
+ @staticmethod
367
+ def _decode_block_string(block_string):
368
+ """Get a block through a string notation of arguments.
369
+
370
+ Args:
371
+ block_string (str): A string notation of arguments.
372
+ Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
373
+
374
+ Returns:
375
+ BlockArgs: The namedtuple defined at the top of this file.
376
+ """
377
+ assert isinstance(block_string, str)
378
+
379
+ ops = block_string.split('_')
380
+ options = {}
381
+ for op in ops:
382
+ splits = re.split(r'(\d.*)', op)
383
+ if len(splits) >= 2:
384
+ key, value = splits[:2]
385
+ options[key] = value
386
+
387
+ # Check stride
388
+ assert (('s' in options and len(options['s']) == 1) or
389
+ (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
390
+
391
+ return BlockArgs(
392
+ num_repeat=int(options['r']),
393
+ kernel_size=int(options['k']),
394
+ stride=[int(options['s'][0])],
395
+ expand_ratio=int(options['e']),
396
+ input_filters=int(options['i']),
397
+ output_filters=int(options['o']),
398
+ se_ratio=float(options['se']) if 'se' in options else None,
399
+ id_skip=('noskip' not in block_string))
400
+
401
+ @staticmethod
402
+ def _encode_block_string(block):
403
+ """Encode a block to a string.
404
+
405
+ Args:
406
+ block (namedtuple): A BlockArgs type argument.
407
+
408
+ Returns:
409
+ block_string: A String form of BlockArgs.
410
+ """
411
+ args = [
412
+ 'r%d' % block.num_repeat,
413
+ 'k%d' % block.kernel_size,
414
+ 's%d%d' % (block.strides[0], block.strides[1]),
415
+ 'e%s' % block.expand_ratio,
416
+ 'i%d' % block.input_filters,
417
+ 'o%d' % block.output_filters
418
+ ]
419
+ if 0 < block.se_ratio <= 1:
420
+ args.append('se%s' % block.se_ratio)
421
+ if block.id_skip is False:
422
+ args.append('noskip')
423
+ return '_'.join(args)
424
+
425
+ @staticmethod
426
+ def decode(string_list):
427
+ """Decode a list of string notations to specify blocks inside the network.
428
+
429
+ Args:
430
+ string_list (list[str]): A list of strings, each string is a notation of block.
431
+
432
+ Returns:
433
+ blocks_args: A list of BlockArgs namedtuples of block args.
434
+ """
435
+ assert isinstance(string_list, list)
436
+ blocks_args = []
437
+ for block_string in string_list:
438
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
439
+ return blocks_args
440
+
441
+ @staticmethod
442
+ def encode(blocks_args):
443
+ """Encode a list of BlockArgs to a list of strings.
444
+
445
+ Args:
446
+ blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
447
+
448
+ Returns:
449
+ block_strings: A list of strings, each string is a notation of block.
450
+ """
451
+ block_strings = []
452
+ for block in blocks_args:
453
+ block_strings.append(BlockDecoder._encode_block_string(block))
454
+ return block_strings
455
+
456
+
457
+ def efficientnet_params(model_name):
458
+ """Map EfficientNet model name to parameter coefficients.
459
+
460
+ Args:
461
+ model_name (str): Model name to be queried.
462
+
463
+ Returns:
464
+ params_dict[model_name]: A (width,depth,res,dropout) tuple.
465
+ """
466
+ params_dict = {
467
+ # Coefficients: width,depth,res,dropout
468
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
469
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
470
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
471
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
472
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
473
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
474
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
475
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
476
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
477
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
478
+ }
479
+ return params_dict[model_name]
480
+
481
+
482
+ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
483
+ dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
484
+ """Create BlockArgs and GlobalParams for efficientnet model.
485
+
486
+ Args:
487
+ width_coefficient (float)
488
+ depth_coefficient (float)
489
+ image_size (int)
490
+ dropout_rate (float)
491
+ drop_connect_rate (float)
492
+ num_classes (int)
493
+
494
+ Meaning as the name suggests.
495
+
496
+ Returns:
497
+ blocks_args, global_params.
498
+ """
499
+
500
+ # Blocks args for the whole model(efficientnet-b0 by default)
501
+ # It will be modified in the construction of EfficientNet Class according to model
502
+ blocks_args = [
503
+ 'r1_k3_s11_e1_i32_o16_se0.25',
504
+ 'r2_k3_s22_e6_i16_o24_se0.25',
505
+ 'r2_k5_s22_e6_i24_o40_se0.25',
506
+ 'r3_k3_s22_e6_i40_o80_se0.25',
507
+ 'r3_k5_s11_e6_i80_o112_se0.25',
508
+ 'r4_k5_s22_e6_i112_o192_se0.25',
509
+ 'r1_k3_s11_e6_i192_o320_se0.25',
510
+ ]
511
+ blocks_args = BlockDecoder.decode(blocks_args)
512
+
513
+ global_params = GlobalParams(
514
+ width_coefficient=width_coefficient,
515
+ depth_coefficient=depth_coefficient,
516
+ image_size=image_size,
517
+ dropout_rate=dropout_rate,
518
+
519
+ num_classes=num_classes,
520
+ batch_norm_momentum=0.99,
521
+ batch_norm_epsilon=1e-3,
522
+ drop_connect_rate=drop_connect_rate,
523
+ depth_divisor=8,
524
+ min_depth=None,
525
+ include_top=include_top,
526
+ )
527
+
528
+ return blocks_args, global_params
529
+
530
+
531
+ def get_model_params(model_name, override_params):
532
+ """Get the block args and global params for a given model name.
533
+
534
+ Args:
535
+ model_name (str): Model's name.
536
+ override_params (dict): A dict to modify global_params.
537
+
538
+ Returns:
539
+ blocks_args, global_params
540
+ """
541
+ if model_name.startswith('efficientnet'):
542
+ w, d, s, p = efficientnet_params(model_name)
543
+ # note: all models have drop connect rate = 0.2
544
+ blocks_args, global_params = efficientnet(
545
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
546
+ else:
547
+ raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
548
+ if override_params:
549
+ # ValueError will be raised here if override_params has fields not included in global_params.
550
+ global_params = global_params._replace(**override_params)
551
+ return blocks_args, global_params
552
+
553
+
554
+ # train with Standard methods
555
+ # check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
556
+ url_map = {
557
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
558
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
559
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
560
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
561
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
562
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
563
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
564
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
565
+ }
566
+
567
+ # train with Adversarial Examples(AdvProp)
568
+ # check more details in paper(Adversarial Examples Improve Image Recognition)
569
+ url_map_advprop = {
570
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
571
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
572
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
573
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
574
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
575
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
576
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
577
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
578
+ 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
579
+ }
580
+
581
+ # TODO: add the petrained weights url map of 'efficientnet-l2'
582
+
583
+
584
+ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
585
+ """Loads pretrained weights from weights path or download using url.
586
+
587
+ Args:
588
+ model (Module): The whole model of efficientnet.
589
+ model_name (str): Model name of efficientnet.
590
+ weights_path (None or str):
591
+ str: path to pretrained weights file on the local disk.
592
+ None: use pretrained weights downloaded from the Internet.
593
+ load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
594
+ advprop (bool): Whether to load pretrained weights
595
+ trained with advprop (valid when weights_path is None).
596
+ """
597
+ if isinstance(weights_path, str):
598
+ state_dict = torch.load(weights_path)
599
+ else:
600
+ # AutoAugment or Advprop (different preprocessing)
601
+ url_map_ = url_map_advprop if advprop else url_map
602
+ state_dict = model_zoo.load_url(url_map_[model_name])
603
+
604
+ if load_fc:
605
+ ret = model.load_state_dict(state_dict, strict=False)
606
+ assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
607
+ else:
608
+ state_dict.pop('_fc.weight')
609
+ state_dict.pop('_fc.bias')
610
+ ret = model.load_state_dict(state_dict, strict=False)
611
+ assert set(ret.missing_keys) == set(
612
+ ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
613
+ assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
614
+
615
+ if verbose:
616
+ print('Loaded pretrained weights for {}'.format(model_name))
deepafx_st/models/encoder.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from deepafx_st.models.mobilenetv2 import MobileNetV2
4
+ from deepafx_st.models.efficient_net import EfficientNet
5
+
6
+
7
+ class SpectralEncoder(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ num_params,
11
+ sample_rate,
12
+ encoder_model="mobilenet_v2",
13
+ embed_dim=1028,
14
+ width_mult=1,
15
+ min_level_db=-80,
16
+ ):
17
+ """Encoder operating on spectrograms.
18
+
19
+ Args:
20
+ num_params (int): Number of processor parameters to generate.
21
+ sample_rate (float): Audio sample rate for computing melspectrogram.
22
+ encoder_model (str, optional): Encoder model architecture. Default: "mobilenet_v2"
23
+ embed_dim (int, optional): Dimentionality of the encoder representations.
24
+ width_mult (int, optional): Encoder size. Default: 1
25
+ min_level_db (float, optional): Minimal dB value for the spectrogram. Default: -80
26
+ """
27
+ super().__init__()
28
+ self.num_params = num_params
29
+ self.sample_rate = sample_rate
30
+ self.encoder_model = encoder_model
31
+ self.embed_dim = embed_dim
32
+ self.width_mult = width_mult
33
+ self.min_level_db = min_level_db
34
+
35
+ # load model from torch.hub
36
+ if encoder_model == "mobilenet_v2":
37
+ self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult)
38
+ elif encoder_model == "efficient_net":
39
+ self.encoder = EfficientNet.from_name(
40
+ "efficientnet-b2",
41
+ in_channels=1,
42
+ image_size=(128, 65),
43
+ include_top=False,
44
+ )
45
+ self.embedding_projection = torch.nn.Conv2d(
46
+ in_channels=1408,
47
+ out_channels=embed_dim,
48
+ kernel_size=(1, 1),
49
+ stride=(1, 1),
50
+ padding=(0, 0),
51
+ bias=True,
52
+ )
53
+
54
+ else:
55
+ raise ValueError(f"Invalid encoder_model: {encoder_model}.")
56
+
57
+ self.window = torch.nn.Parameter(torch.hann_window(4096))
58
+
59
+ def forward(self, x):
60
+ """
61
+ Args:
62
+ x (Tensor): Input waveform of shape [batch x channels x samples]
63
+
64
+ Returns:
65
+ e (Tensor): Latent embedding produced by Encoder. [batch x embed_dim]
66
+ """
67
+ bs, chs, samp = x.size()
68
+
69
+ # compute spectrogram of waveform
70
+ X = torch.stft(
71
+ x.view(bs, -1),
72
+ 4096,
73
+ 2048,
74
+ window=self.window,
75
+ return_complex=True,
76
+ )
77
+ X_db = torch.pow(X.abs() + 1e-8, 0.3)
78
+ X_db_norm = X_db
79
+
80
+ # standardize (0, 1) 0.322970 0.278452
81
+ X_db_norm -= 0.322970
82
+ X_db_norm /= 0.278452
83
+ X_db_norm = X_db_norm.unsqueeze(1).permute(0, 1, 3, 2)
84
+
85
+ if self.encoder_model == "mobilenet_v2":
86
+ # repeat channels by 3 to fit vision model
87
+ X_db_norm = X_db_norm.repeat(1, 3, 1, 1)
88
+
89
+ # pass melspectrogram through encoder
90
+ e = self.encoder(X_db_norm)
91
+
92
+ # apply avg pooling across time for encoder embeddings
93
+ e = torch.nn.functional.adaptive_avg_pool2d(e, 1).reshape(e.shape[0], -1)
94
+
95
+ # normalize by L2 norm
96
+ norm = torch.norm(e, p=2, dim=-1, keepdim=True)
97
+ e_norm = e / norm
98
+
99
+ elif self.encoder_model == "efficient_net":
100
+
101
+ # Efficient Net internal downsamples by 32 on time and freq axis, then average pools the rest
102
+ e = self.encoder(X_db_norm)
103
+
104
+ # Adding 1x1 conv to project down or up to the requested embedding size
105
+ e = self.embedding_projection(e)
106
+ e = torch.squeeze(e, dim=3)
107
+ e = torch.squeeze(e, dim=2)
108
+
109
+ # normalize by L2 norm
110
+ norm = torch.norm(e, p=2, dim=-1, keepdim=True)
111
+ e_norm = e / norm
112
+
113
+ return e_norm
deepafx_st/models/mobilenetv2.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BSD 3-Clause License
2
+
3
+ # Copyright (c) Soumith Chintala 2016,
4
+ # All rights reserved.
5
+
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+
9
+ # * Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+
12
+ # * Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+
16
+ # * Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+
31
+ # Adaptation of the PyTorch torchvision MobileNetV2 without a classifier.
32
+ # See source here: https://pytorch.org/vision/0.8/_modules/torchvision/models/mobilenet.html#mobilenet_v2
33
+ from torch import nn
34
+
35
+
36
+ def _make_divisible(v, divisor, min_value=None):
37
+ """
38
+ This function is taken from the original tf repo.
39
+ It ensures that all layers have a channel number that is divisible by 8
40
+ It can be seen here:
41
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
42
+ :param v:
43
+ :param divisor:
44
+ :param min_value:
45
+ :return:
46
+ """
47
+ if min_value is None:
48
+ min_value = divisor
49
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
50
+ # Make sure that round down does not go down by more than 10%.
51
+ if new_v < 0.9 * v:
52
+ new_v += divisor
53
+ return new_v
54
+
55
+
56
+ class ConvBNReLU(nn.Sequential):
57
+ def __init__(
58
+ self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None
59
+ ):
60
+ padding = (kernel_size - 1) // 2
61
+ if norm_layer is None:
62
+ norm_layer = nn.BatchNorm2d
63
+ super(ConvBNReLU, self).__init__(
64
+ nn.Conv2d(
65
+ in_planes,
66
+ out_planes,
67
+ kernel_size,
68
+ stride,
69
+ padding,
70
+ groups=groups,
71
+ bias=False,
72
+ ),
73
+ norm_layer(out_planes),
74
+ nn.ReLU6(inplace=True),
75
+ )
76
+
77
+
78
+ class InvertedResidual(nn.Module):
79
+ def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
80
+ super(InvertedResidual, self).__init__()
81
+ self.stride = stride
82
+ assert stride in [1, 2]
83
+
84
+ if norm_layer is None:
85
+ norm_layer = nn.BatchNorm2d
86
+
87
+ hidden_dim = int(round(inp * expand_ratio))
88
+ self.use_res_connect = self.stride == 1 and inp == oup
89
+
90
+ layers = []
91
+ if expand_ratio != 1:
92
+ # pw
93
+ layers.append(
94
+ ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)
95
+ )
96
+ layers.extend(
97
+ [
98
+ # dw
99
+ ConvBNReLU(
100
+ hidden_dim,
101
+ hidden_dim,
102
+ stride=stride,
103
+ groups=hidden_dim,
104
+ norm_layer=norm_layer,
105
+ ),
106
+ # pw-linear
107
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
108
+ norm_layer(oup),
109
+ ]
110
+ )
111
+ self.conv = nn.Sequential(*layers)
112
+
113
+ def forward(self, x):
114
+ if self.use_res_connect:
115
+ return x + self.conv(x)
116
+ else:
117
+ return self.conv(x)
118
+
119
+
120
+ class MobileNetV2(nn.Module):
121
+ def __init__(
122
+ self,
123
+ embed_dim=1028,
124
+ width_mult=1.0,
125
+ inverted_residual_setting=None,
126
+ round_nearest=8,
127
+ block=None,
128
+ norm_layer=None,
129
+ ):
130
+ """
131
+ MobileNet V2 main class
132
+
133
+ Args:
134
+ embed_dim (int): Number of channels in the final output.
135
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
136
+ inverted_residual_setting: Network structure
137
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
138
+ Set to 1 to turn off rounding
139
+ block: Module specifying inverted residual building block for mobilenet
140
+ norm_layer: Module specifying the normalization layer to use
141
+
142
+ """
143
+ super(MobileNetV2, self).__init__()
144
+
145
+ if block is None:
146
+ block = InvertedResidual
147
+
148
+ if norm_layer is None:
149
+ norm_layer = nn.BatchNorm2d
150
+
151
+ input_channel = 32
152
+ last_channel = embed_dim / width_mult
153
+
154
+ if inverted_residual_setting is None:
155
+ inverted_residual_setting = [
156
+ # t, c, n, s
157
+ [1, 16, 1, 1],
158
+ [6, 24, 2, 2],
159
+ [6, 32, 3, 2],
160
+ [6, 64, 4, 2],
161
+ [6, 96, 3, 1],
162
+ [6, 160, 3, 2],
163
+ [6, 320, 1, 1],
164
+ ]
165
+
166
+ # only check the first element, assuming user knows t,c,n,s are required
167
+ if (
168
+ len(inverted_residual_setting) == 0
169
+ or len(inverted_residual_setting[0]) != 4
170
+ ):
171
+ raise ValueError(
172
+ "inverted_residual_setting should be non-empty "
173
+ "or a 4-element list, got {}".format(inverted_residual_setting)
174
+ )
175
+
176
+ # building first layer
177
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
178
+ self.last_channel = _make_divisible(
179
+ last_channel * max(1.0, width_mult), round_nearest
180
+ )
181
+ features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
182
+ # building inverted residual blocks
183
+ for t, c, n, s in inverted_residual_setting:
184
+ output_channel = _make_divisible(c * width_mult, round_nearest)
185
+ for i in range(n):
186
+ stride = s if i == 0 else 1
187
+ features.append(
188
+ block(
189
+ input_channel,
190
+ output_channel,
191
+ stride,
192
+ expand_ratio=t,
193
+ norm_layer=norm_layer,
194
+ )
195
+ )
196
+ input_channel = output_channel
197
+ # building last several layers
198
+ features.append(
199
+ ConvBNReLU(
200
+ input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer
201
+ )
202
+ )
203
+ # make it nn.Sequential
204
+ self.features = nn.Sequential(*features)
205
+
206
+ # weight initialization
207
+ for m in self.modules():
208
+ if isinstance(m, nn.Conv2d):
209
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
210
+ if m.bias is not None:
211
+ nn.init.zeros_(m.bias)
212
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
213
+ nn.init.ones_(m.weight)
214
+ nn.init.zeros_(m.bias)
215
+ elif isinstance(m, nn.Linear):
216
+ nn.init.normal_(m.weight, 0, 0.01)
217
+ nn.init.zeros_(m.bias)
218
+
219
+ def _forward_impl(self, x):
220
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
221
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
222
+ return self.features(x)
223
+ # return the features directly, no classifier or pooling
224
+
225
+ def forward(self, x):
226
+ return self._forward_impl(x)
deepafx_st/probes/cdpam_encoder.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2021 Pranay Manocha
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # code adapated from https://github.com/pranaymanocha/PerceptualAudio
24
+
25
+ import cdpam
26
+ import torch
27
+
28
+
29
+ class CDPAMEncoder(torch.nn.Module):
30
+ def __init__(self, cdpam_ckpt: str):
31
+ super().__init__()
32
+
33
+ # pre-trained model parameterss
34
+ encoder_layers = 16
35
+ encoder_filters = 64
36
+ input_size = 512
37
+ proj_ndim = [512, 256]
38
+ ndim = [16, 6]
39
+ classif_BN = 0
40
+ classif_act = "no"
41
+ proj_dp = 0.1
42
+ proj_BN = 1
43
+ classif_dp = 0.05
44
+
45
+ model = cdpam.models.FINnet(
46
+ encoder_layers=encoder_layers,
47
+ encoder_filters=encoder_filters,
48
+ ndim=ndim,
49
+ classif_dp=classif_dp,
50
+ classif_BN=classif_BN,
51
+ classif_act=classif_act,
52
+ input_size=input_size,
53
+ )
54
+
55
+ state = torch.load(cdpam_ckpt, map_location="cpu")["state"]
56
+ model.load_state_dict(state)
57
+ model.eval()
58
+
59
+ self.model = model
60
+ self.embed_dim = 512
61
+
62
+ def forward(self, x):
63
+
64
+ with torch.no_grad():
65
+ _, a1, c1 = self.model.base_encoder.forward(x)
66
+ a1 = torch.nn.functional.normalize(a1, dim=1)
67
+
68
+ return a1
deepafx_st/probes/probe_system.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import julius
3
+ import torchopenl3
4
+ import torchmetrics
5
+ import pytorch_lightning as pl
6
+ from typing import Tuple, List, Dict
7
+ from argparse import ArgumentParser
8
+
9
+ from deepafx_st.probes.cdpam_encoder import CDPAMEncoder
10
+ from deepafx_st.probes.random_mel import RandomMelProjection
11
+
12
+ import deepafx_st.utils as utils
13
+ from deepafx_st.utils import DSPMode
14
+ from deepafx_st.system import System
15
+ from deepafx_st.data.style import StyleDataset
16
+
17
+
18
+ class ProbeSystem(pl.LightningModule):
19
+ def __init__(
20
+ self,
21
+ audio_dir=None,
22
+ num_classes=5,
23
+ task="style",
24
+ encoder_type="deepafx_st_autodiff",
25
+ deepafx_st_autodiff_ckpt=None,
26
+ deepafx_st_spsa_ckpt=None,
27
+ deepafx_st_proxy0_ckpt=None,
28
+ probe_type="linear",
29
+ batch_size=32,
30
+ lr=3e-4,
31
+ lr_patience=20,
32
+ patience=10,
33
+ preload=False,
34
+ sample_rate=24000,
35
+ shuffle=True,
36
+ num_workers=16,
37
+ **kwargs,
38
+ ):
39
+ super().__init__()
40
+ self.save_hyperparameters()
41
+
42
+ if "deepafx_st" in self.hparams.encoder_type:
43
+
44
+ if "autodiff" in self.hparams.encoder_type:
45
+ self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_autodiff_ckpt
46
+ elif "spsa" in self.hparams.encoder_type:
47
+ self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_spsa_ckpt
48
+ elif "proxy0" in self.hparams.encoder_type:
49
+ self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_proxy0_ckpt
50
+
51
+ else:
52
+ raise RuntimeError(f"Invalid encoder_type: {self.hparams.encoder_type}")
53
+
54
+ if self.hparams.deepafx_st_ckpt is None:
55
+ raise RuntimeError(
56
+ f"Must supply {self.hparams.encoder_type}_ckpt checkpoint."
57
+ )
58
+ use_dsp = DSPMode.NONE
59
+ system = System.load_from_checkpoint(
60
+ self.hparams.deepafx_st_ckpt,
61
+ use_dsp=use_dsp,
62
+ batch_size=self.hparams.batch_size,
63
+ spsa_parallel=False,
64
+ proxy_ckpts=[],
65
+ strict=False,
66
+ )
67
+ system.eval()
68
+ self.encoder = system.encoder
69
+ self.hparams.embed_dim = self.encoder.embed_dim
70
+
71
+ # freeze weights
72
+ for name, param in self.encoder.named_parameters():
73
+ param.requires_grad = False
74
+
75
+ elif self.hparams.encoder_type == "openl3":
76
+ self.encoder = torchopenl3.models.load_audio_embedding_model(
77
+ input_repr=self.hparams.openl3_input_repr,
78
+ embedding_size=self.hparams.openl3_embedding_size,
79
+ content_type=self.hparams.openl3_content_type,
80
+ )
81
+ self.hparams.embed_dim = 6144
82
+ elif self.hparams.encoder_type == "random_mel":
83
+ self.encoder = RandomMelProjection(
84
+ self.hparams.sample_rate,
85
+ self.hparams.random_mel_embedding_size,
86
+ self.hparams.random_mel_n_mels,
87
+ self.hparams.random_mel_n_fft,
88
+ self.hparams.random_mel_hop_size,
89
+ )
90
+ self.hparams.embed_dim = self.hparams.random_mel_embedding_size
91
+ elif self.hparams.encoder_type == "cdpam":
92
+ self.encoder = CDPAMEncoder(self.hparams.cdpam_ckpt)
93
+ self.encoder.eval()
94
+ self.hparams.embed_dim = self.encoder.embed_dim
95
+ else:
96
+ raise ValueError(f"Invalid encoder_type: {self.hparams.encoder_type}")
97
+
98
+ if self.hparams.probe_type == "linear":
99
+ if self.hparams.task == "style":
100
+ self.probe = torch.nn.Sequential(
101
+ torch.nn.Linear(self.hparams.embed_dim, self.hparams.num_classes),
102
+ # torch.nn.Softmax(-1),
103
+ )
104
+ elif self.hparams.probe_type == "mlp":
105
+ if self.hparams.task == "style":
106
+ self.probe = torch.nn.Sequential(
107
+ torch.nn.Linear(self.hparams.embed_dim, 512),
108
+ torch.nn.ReLU(),
109
+ torch.nn.Linear(512, 512),
110
+ torch.nn.ReLU(),
111
+ torch.nn.Linear(512, self.hparams.num_classes),
112
+ )
113
+ self.accuracy = torchmetrics.Accuracy()
114
+ self.f1_score = torchmetrics.F1Score(self.hparams.num_classes)
115
+
116
+ def forward(self, x):
117
+ bs, chs, samp = x.size()
118
+ with torch.no_grad():
119
+ if "deepafx_st" in self.hparams.encoder_type:
120
+ x /= x.abs().max()
121
+ x *= 10 ** (-12.0 / 20) # with min 12 dBFS headroom
122
+ e = self.encoder(x)
123
+ norm = torch.norm(e, p=2, dim=-1, keepdim=True)
124
+ e = e / norm
125
+ elif self.hparams.encoder_type == "openl3":
126
+ # x = julius.resample_frac(x, self.hparams.sample_rate, 48000)
127
+ e, ts = torchopenl3.get_audio_embedding(
128
+ x,
129
+ 48000,
130
+ model=self.encoder,
131
+ input_repr="mel128",
132
+ content_type="music",
133
+ )
134
+ e = e.permute(0, 2, 1)
135
+ e = e.mean(dim=-1)
136
+ # normalize by L2 norm
137
+ norm = torch.norm(e, p=2, dim=-1, keepdim=True)
138
+ e = e / norm
139
+ elif self.hparams.encoder_type == "random_mel":
140
+ e = self.encoder(x)
141
+ norm = torch.norm(e, p=2, dim=-1, keepdim=True)
142
+ e = e / norm
143
+ elif self.hparams.encoder_type == "cdpam":
144
+ # x = julius.resample_frac(x, self.hparams.sample_rate, 22050)
145
+ x = torch.round(x * 32768)
146
+ e = self.encoder(x)
147
+
148
+ return self.probe(e)
149
+
150
+ def common_step(
151
+ self,
152
+ batch: Tuple,
153
+ batch_idx: int,
154
+ optimizer_idx: int = 0,
155
+ train: bool = True,
156
+ ):
157
+ loss = 0
158
+ x, y = batch
159
+
160
+ y_hat = self(x)
161
+
162
+ # compute CE
163
+ if self.hparams.task == "style":
164
+ loss = torch.nn.functional.cross_entropy(y_hat, y)
165
+
166
+ if not train:
167
+ # store audio data
168
+ data_dict = {"x": x.float().cpu()}
169
+ else:
170
+ data_dict = {}
171
+
172
+ self.log(
173
+ "train_loss" if train else "val_loss",
174
+ loss,
175
+ on_step=True,
176
+ on_epoch=True,
177
+ prog_bar=False,
178
+ logger=True,
179
+ sync_dist=True,
180
+ )
181
+
182
+ if not train and self.hparams.task == "style":
183
+ self.log("val_acc_step", self.accuracy(y_hat, y))
184
+ self.log("val_f1_step", self.f1_score(y_hat, y))
185
+
186
+ return loss, data_dict
187
+
188
+ def training_step(self, batch, batch_idx, optimizer_idx=0):
189
+ loss, _ = self.common_step(batch, batch_idx)
190
+ return loss
191
+
192
+ def validation_step(self, batch, batch_idx):
193
+ loss, data_dict = self.common_step(batch, batch_idx, train=False)
194
+
195
+ if batch_idx == 0:
196
+ return data_dict
197
+
198
+ def validation_epoch_end(self, outputs) -> None:
199
+ if self.hparams.task == "style":
200
+ self.log("val_acc_epoch", self.accuracy.compute())
201
+ self.log("val_f1_epoch", self.f1_score.compute())
202
+
203
+ return super().validation_epoch_end(outputs)
204
+
205
+ def configure_optimizers(self):
206
+ optimizer = torch.optim.AdamW(
207
+ self.probe.parameters(),
208
+ lr=self.hparams.lr,
209
+ betas=(0.9, 0.999),
210
+ )
211
+
212
+ ms1 = int(self.hparams.max_epochs * 0.8)
213
+ ms2 = int(self.hparams.max_epochs * 0.95)
214
+ print(
215
+ "Learning rate schedule:",
216
+ f"0 {self.hparams.lr:0.2e} -> ",
217
+ f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
218
+ f"{ms2} {self.hparams.lr*0.01:0.2e}",
219
+ )
220
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
221
+ optimizer,
222
+ milestones=[ms1, ms2],
223
+ gamma=0.1,
224
+ )
225
+
226
+ return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"}
227
+
228
+ def train_dataloader(self):
229
+
230
+ if self.hparams.task == "style":
231
+ train_dataset = StyleDataset(
232
+ self.hparams.audio_dir,
233
+ "train",
234
+ sample_rate=self.hparams.encoder_sample_rate,
235
+ )
236
+
237
+ g = torch.Generator()
238
+ g.manual_seed(0)
239
+
240
+ return torch.utils.data.DataLoader(
241
+ train_dataset,
242
+ num_workers=self.hparams.num_workers,
243
+ batch_size=self.hparams.batch_size,
244
+ shuffle=True,
245
+ worker_init_fn=utils.seed_worker,
246
+ generator=g,
247
+ pin_memory=True,
248
+ )
249
+
250
+ def val_dataloader(self):
251
+
252
+ if self.hparams.task == "style":
253
+ val_dataset = StyleDataset(
254
+ self.hparams.audio_dir,
255
+ subset="val",
256
+ sample_rate=self.hparams.encoder_sample_rate,
257
+ )
258
+
259
+ g = torch.Generator()
260
+ g.manual_seed(0)
261
+
262
+ return torch.utils.data.DataLoader(
263
+ val_dataset,
264
+ num_workers=self.hparams.num_workers,
265
+ batch_size=self.hparams.batch_size,
266
+ worker_init_fn=utils.seed_worker,
267
+ generator=g,
268
+ pin_memory=True,
269
+ )
270
+
271
+ # add any model hyperparameters here
272
+ @staticmethod
273
+ def add_model_specific_args(parent_parser):
274
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
275
+ # --- Model ---
276
+ parser.add_argument("--encoder_type", type=str, default="deeapfx2")
277
+ parser.add_argument("--probe_type", type=str, default="linear")
278
+ parser.add_argument("--task", type=str, default="style")
279
+ parser.add_argument("--encoder_sample_rate", type=int, default=24000)
280
+ # --- deeapfx2 ---
281
+ parser.add_argument("--deepafx_st_autodiff_ckpt", type=str)
282
+ parser.add_argument("--deepafx_st_spsa_ckpt", type=str)
283
+ parser.add_argument("--deepafx_st_proxy0_ckpt", type=str)
284
+
285
+ # --- cdpam ---
286
+ parser.add_argument("--cdpam_ckpt", type=str)
287
+ # --- openl3 ---
288
+ parser.add_argument("--openl3_input_repr", type=str, default="mel128")
289
+ parser.add_argument("--openl3_content_type", type=str, default="env")
290
+ parser.add_argument("--openl3_embedding_size", type=int, default=6144)
291
+ # --- random_mel ---
292
+ parser.add_argument("--random_mel_embedding_size", type=str, default=4096)
293
+ parser.add_argument("--random_mel_n_fft", type=str, default=4096)
294
+ parser.add_argument("--random_mel_hop_size", type=str, default=1024)
295
+ parser.add_argument("--random_mel_n_mels", type=str, default=128)
296
+ # --- Training ---
297
+ parser.add_argument("--audio_dir", type=str)
298
+ parser.add_argument("--num_classes", type=int, default=5)
299
+ parser.add_argument("--batch_size", type=int, default=32)
300
+ parser.add_argument("--lr", type=float, default=3e-4)
301
+ parser.add_argument("--lr_patience", type=int, default=20)
302
+ parser.add_argument("--patience", type=int, default=10)
303
+ parser.add_argument("--preload", action="store_true")
304
+ parser.add_argument("--sample_rate", type=int, default=24000)
305
+ parser.add_argument("--num_workers", type=int, default=8)
306
+
307
+ return parser
deepafx_st/probes/random_mel.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import librosa
4
+
5
+ # based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py
6
+
7
+
8
+ class RandomMelProjection(torch.nn.Module):
9
+ def __init__(
10
+ self,
11
+ sample_rate,
12
+ embed_dim=4096,
13
+ n_mels=128,
14
+ n_fft=4096,
15
+ hop_size=1024,
16
+ seed=0,
17
+ epsilon=1e-4,
18
+ ):
19
+ super().__init__()
20
+ self.sample_rate = sample_rate
21
+ self.embed_dim = embed_dim
22
+ self.n_mels = n_mels
23
+ self.n_fft = n_fft
24
+ self.hop_size = hop_size
25
+ self.seed = seed
26
+ self.epsilon = epsilon
27
+
28
+ # Set random seed
29
+ torch.random.manual_seed(self.seed)
30
+
31
+ # Create a Hann window buffer to apply to frames prior to FFT.
32
+ self.register_buffer("window", torch.hann_window(self.n_fft))
33
+
34
+ # Create a mel filter buffer.
35
+ mel_scale = torch.tensor(
36
+ librosa.filters.mel(
37
+ self.sample_rate,
38
+ n_fft=self.n_fft,
39
+ n_mels=self.n_mels,
40
+ )
41
+ )
42
+ self.register_buffer("mel_scale", mel_scale)
43
+
44
+ # Projection matrices.
45
+ normalization = math.sqrt(self.n_mels)
46
+ self.projection = torch.nn.Parameter(
47
+ torch.rand(self.n_mels, self.embed_dim) / normalization,
48
+ requires_grad=False,
49
+ )
50
+
51
+ def forward(self, x):
52
+ bs, chs, samp = x.size()
53
+
54
+ x = torch.stft(
55
+ x.view(bs, -1),
56
+ self.n_fft,
57
+ self.hop_size,
58
+ window=self.window,
59
+ return_complex=True,
60
+ )
61
+ x = x.unsqueeze(1).permute(0, 1, 3, 2)
62
+
63
+ # Apply the mel-scale filter to the power spectrum.
64
+ x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1))
65
+
66
+ # power scale
67
+ x = torch.pow(x + self.epsilon, 0.3)
68
+
69
+ # apply random projection
70
+ e = x.matmul(self.projection)
71
+
72
+ # take mean across temporal dim
73
+ e = e.mean(dim=2).view(bs, -1)
74
+
75
+ return e
76
+
77
+ def compute_frame_embedding(self, x):
78
+ # Compute the real-valued Fourier transform on windowed input signal.
79
+ x = torch.fft.rfft(x * self.window)
80
+
81
+ # Convert to a power spectrum.
82
+ x = torch.abs(x) ** 2.0
83
+
84
+ # Apply the mel-scale filter to the power spectrum.
85
+ x = torch.matmul(x, self.mel_scale.transpose(0, 1))
86
+
87
+ # Convert to a log mel spectrum.
88
+ x = torch.log(x + self.epsilon)
89
+
90
+ # Apply projection to get a 4096 dimension embedding
91
+ embedding = x.matmul(self.projection)
92
+
93
+ return embedding
deepafx_st/processors/autodiff/__init__.py ADDED
File without changes
deepafx_st/processors/autodiff/channel.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from deepafx_st.processors.autodiff.compressor import Compressor
4
+ from deepafx_st.processors.autodiff.peq import ParametricEQ
5
+ from deepafx_st.processors.autodiff.fir import FIRFilter
6
+
7
+
8
+ class AutodiffChannel(torch.nn.Module):
9
+ def __init__(self, sample_rate):
10
+ super().__init__()
11
+
12
+ self.peq = ParametricEQ(sample_rate)
13
+ self.comp = Compressor(sample_rate)
14
+ self.ports = [self.peq.ports, self.comp.ports]
15
+ self.num_control_params = (
16
+ self.peq.num_control_params + self.comp.num_control_params
17
+ )
18
+
19
+ def forward(self, x, p, sample_rate=24000, **kwargs):
20
+
21
+ # split params between EQ and Comp.
22
+ p_peq = p[:, : self.peq.num_control_params]
23
+ p_comp = p[:, self.peq.num_control_params :]
24
+
25
+ y = self.peq(x, p_peq, sample_rate)
26
+ y = self.comp(y, p_comp, sample_rate)
27
+
28
+ return y
deepafx_st/processors/autodiff/compressor.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import scipy.signal
4
+
5
+ import deepafx_st.processors.autodiff.signal
6
+ from deepafx_st.processors.processor import Processor
7
+
8
+
9
+ @torch.jit.script
10
+ def compressor(
11
+ x: torch.Tensor,
12
+ sample_rate: float,
13
+ threshold: torch.Tensor,
14
+ ratio: torch.Tensor,
15
+ attack_time: torch.Tensor,
16
+ release_time: torch.Tensor,
17
+ knee_dB: torch.Tensor,
18
+ makeup_gain_dB: torch.Tensor,
19
+ eps: float = 1e-8,
20
+ ):
21
+ """Note the `release` parameter is not used."""
22
+ # print(f"autodiff comp fs = {sample_rate}")
23
+
24
+ s = x.size() # should be one 1d
25
+
26
+ threshold = threshold.squeeze()
27
+ ratio = ratio.squeeze()
28
+ attack_time = attack_time.squeeze()
29
+ makeup_gain_dB = makeup_gain_dB.squeeze()
30
+
31
+ # uni-polar dB signal
32
+ # Turn the input signal into a uni-polar signal on the dB scale
33
+ x_G = 20 * torch.log10(torch.abs(x) + 1e-8) # x_uni casts type
34
+
35
+ # Ensure there are no values of negative infinity
36
+ x_G = torch.clamp(x_G, min=-96)
37
+
38
+ # Static characteristics with knee
39
+ y_G = torch.zeros(s).type_as(x)
40
+
41
+ ratio = ratio.view(-1)
42
+ threshold = threshold.view(-1)
43
+ attack_time = attack_time.view(-1)
44
+ release_time = release_time.view(-1)
45
+ knee_dB = knee_dB.view(-1)
46
+ makeup_gain_dB = makeup_gain_dB.view(-1)
47
+
48
+ # Below knee
49
+ idx = torch.where((2 * (x_G - threshold)) < -knee_dB)[0]
50
+ y_G[idx] = x_G[idx]
51
+
52
+ # At knee
53
+ idx = torch.where((2 * torch.abs(x_G - threshold)) <= knee_dB)[0]
54
+ y_G[idx] = x_G[idx] + (
55
+ (1 / ratio) * (((x_G[idx] - threshold + knee_dB) / 2) ** 2)
56
+ ) / (2 * knee_dB)
57
+
58
+ # Above knee threshold
59
+ idx = torch.where((2 * (x_G - threshold)) > knee_dB)[0]
60
+ y_G[idx] = threshold + ((x_G[idx] - threshold) / ratio)
61
+
62
+ x_L = x_G - y_G
63
+
64
+ # design 1-pole butterworth lowpass
65
+ fc = 1.0 / (attack_time * sample_rate)
66
+ b, a = deepafx_st.processors.autodiff.signal.butter(fc)
67
+
68
+ # apply FIR approx of IIR filter
69
+ y_L = deepafx_st.processors.autodiff.signal.approx_iir_filter(b, a, x_L)
70
+
71
+ lin_y_L = torch.pow(10.0, -y_L / 20.0) # convert back to linear
72
+ y = lin_y_L * x # apply gain
73
+
74
+ # apply makeup gain
75
+ y *= torch.pow(10.0, makeup_gain_dB / 20.0)
76
+
77
+ return y
78
+
79
+
80
+ class Compressor(Processor):
81
+ def __init__(
82
+ self,
83
+ sample_rate,
84
+ max_threshold=0.0,
85
+ min_threshold=-80,
86
+ max_ratio=20.0,
87
+ min_ratio=1.0,
88
+ max_attack=0.1,
89
+ min_attack=0.0001,
90
+ max_release=1.0,
91
+ min_release=0.005,
92
+ max_knee=12.0,
93
+ min_knee=0.0,
94
+ max_mkgain=48.0,
95
+ min_mkgain=-48.0,
96
+ eps=1e-8,
97
+ ):
98
+ """ """
99
+ super().__init__()
100
+ self.sample_rate = sample_rate
101
+ self.eps = eps
102
+ self.ports = [
103
+ {
104
+ "name": "Threshold",
105
+ "min": min_threshold,
106
+ "max": max_threshold,
107
+ "default": -12.0,
108
+ "units": "dB",
109
+ },
110
+ {
111
+ "name": "Ratio",
112
+ "min": min_ratio,
113
+ "max": max_ratio,
114
+ "default": 2.0,
115
+ "units": "",
116
+ },
117
+ {
118
+ "name": "Attack",
119
+ "min": min_attack,
120
+ "max": max_attack,
121
+ "default": 0.001,
122
+ "units": "s",
123
+ },
124
+ {
125
+ # this is a dummy parameter
126
+ "name": "Release (dummy)",
127
+ "min": min_release,
128
+ "max": max_release,
129
+ "default": 0.045,
130
+ "units": "s",
131
+ },
132
+ {
133
+ "name": "Knee",
134
+ "min": min_knee,
135
+ "max": max_knee,
136
+ "default": 6.0,
137
+ "units": "dB",
138
+ },
139
+ {
140
+ "name": "Makeup Gain",
141
+ "min": min_mkgain,
142
+ "max": max_mkgain,
143
+ "default": 0.0,
144
+ "units": "dB",
145
+ },
146
+ ]
147
+
148
+ self.num_control_params = len(self.ports)
149
+
150
+ def forward(self, x, p, sample_rate=24000, **kwargs):
151
+ """
152
+
153
+ Assume that parameters in p are normalized between 0 and 1.
154
+
155
+ x (tensor): Shape batch x 1 x samples
156
+ p (tensor): shape batch x params
157
+
158
+ """
159
+ bs, ch, s = x.size()
160
+
161
+ inputs = torch.split(x, 1, 0)
162
+ params = torch.split(p, 1, 0)
163
+
164
+ y = [] # loop over batch dimension
165
+ for input, param in zip(inputs, params):
166
+ denorm_param = self.denormalize_params(param.view(-1))
167
+ y.append(compressor(input.view(-1), sample_rate, *denorm_param))
168
+
169
+ return torch.stack(y, dim=0).view(bs, 1, -1)
deepafx_st/processors/autodiff/fir.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class FIRFilter(torch.nn.Module):
5
+ def __init__(self, num_control_params=63):
6
+ super().__init__()
7
+ self.num_control_params = num_control_params
8
+ self.adaptor = torch.nn.Linear(num_control_params, num_control_params)
9
+ #self.batched_lfilter = torch.vmap(self.lfilter)
10
+
11
+ def forward(self, x, b, **kwargs):
12
+ """Forward pass by appling FIR filter to each batch element.
13
+
14
+ Args:
15
+ x (tensor): Input signals with shape (batch x 1 x samples)
16
+ b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps)
17
+
18
+ """
19
+ bs, ch, s = x.size()
20
+ b = self.adaptor(b)
21
+
22
+ # pad input
23
+ x = torch.nn.functional.pad(x, (b.shape[-1] // 2, b.shape[-1] // 2))
24
+
25
+ # add extra dim for virutal batch dim
26
+ x = x.view(bs, 1, ch, -1)
27
+ b = b.view(bs, 1, 1, -1)
28
+
29
+ # exlcuding vmap for now
30
+ y = self.batched_lfilter(x, b).view(bs, ch, s)
31
+
32
+ return y
33
+
34
+ @staticmethod
35
+ def lfilter(x, b):
36
+ return torch.nn.functional.conv1d(x, b)
37
+
38
+
39
+ class FrequencyDomainFIRFilter(torch.nn.Module):
40
+ def __init__(self, num_control_params=31):
41
+ super().__init__()
42
+ self.num_control_params = num_control_params
43
+ self.adaptor = torch.nn.Linear(num_control_params, num_control_params)
44
+
45
+ def forward(self, x, b, **kwargs):
46
+ """Forward pass by appling FIR filter to each batch element.
47
+
48
+ Args:
49
+ x (tensor): Input signals with shape (batch x 1 x samples)
50
+ b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps)
51
+ """
52
+ bs, c, s = x.size()
53
+
54
+ b = self.adaptor(b)
55
+
56
+ # transform input to freq. domain
57
+ X = torch.fft.rfft(x.view(bs, -1))
58
+
59
+ # frequency response of filter
60
+ H = torch.fft.rfft(b.view(bs, -1))
61
+
62
+ # apply filter as multiplication in freq. domain
63
+ Y = X * H
64
+
65
+ # transform back to time domain
66
+ y = torch.fft.ifft(Y).view(bs, 1, -1)
67
+
68
+ return y
deepafx_st/processors/autodiff/peq.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import deepafx_st.processors.autodiff.signal
4
+ from deepafx_st.processors.processor import Processor
5
+
6
+
7
+ @torch.jit.script
8
+ def parametric_eq(
9
+ x: torch.Tensor,
10
+ sample_rate: float,
11
+ low_shelf_gain_dB: torch.Tensor,
12
+ low_shelf_cutoff_freq: torch.Tensor,
13
+ low_shelf_q_factor: torch.Tensor,
14
+ first_band_gain_dB: torch.Tensor,
15
+ first_band_cutoff_freq: torch.Tensor,
16
+ first_band_q_factor: torch.Tensor,
17
+ second_band_gain_dB: torch.Tensor,
18
+ second_band_cutoff_freq: torch.Tensor,
19
+ second_band_q_factor: torch.Tensor,
20
+ third_band_gain_dB: torch.Tensor,
21
+ third_band_cutoff_freq: torch.Tensor,
22
+ third_band_q_factor: torch.Tensor,
23
+ fourth_band_gain_dB: torch.Tensor,
24
+ fourth_band_cutoff_freq: torch.Tensor,
25
+ fourth_band_q_factor: torch.Tensor,
26
+ high_shelf_gain_dB: torch.Tensor,
27
+ high_shelf_cutoff_freq: torch.Tensor,
28
+ high_shelf_q_factor: torch.Tensor,
29
+ ):
30
+ """Six-band parametric EQ.
31
+
32
+ Low-shelf -> Band 1 -> Band 2 -> Band 3 -> Band 4 -> High-shelf
33
+
34
+ Args:
35
+ x (torch.Tensor): 1d signal.
36
+
37
+
38
+ """
39
+ a_s, b_s = [], []
40
+ #print(f"autodiff peq fs = {sample_rate}")
41
+
42
+ # -------- apply low-shelf filter --------
43
+ b, a = deepafx_st.processors.autodiff.signal.biqaud(
44
+ low_shelf_gain_dB,
45
+ low_shelf_cutoff_freq,
46
+ low_shelf_q_factor,
47
+ sample_rate,
48
+ "low_shelf",
49
+ )
50
+ b_s.append(b)
51
+ a_s.append(a)
52
+
53
+ # -------- apply first-band peaking filter --------
54
+ b, a = deepafx_st.processors.autodiff.signal.biqaud(
55
+ first_band_gain_dB,
56
+ first_band_cutoff_freq,
57
+ first_band_q_factor,
58
+ sample_rate,
59
+ "peaking",
60
+ )
61
+ b_s.append(b)
62
+ a_s.append(a)
63
+
64
+ # -------- apply second-band peaking filter --------
65
+ b, a = deepafx_st.processors.autodiff.signal.biqaud(
66
+ second_band_gain_dB,
67
+ second_band_cutoff_freq,
68
+ second_band_q_factor,
69
+ sample_rate,
70
+ "peaking",
71
+ )
72
+ b_s.append(b)
73
+ a_s.append(a)
74
+
75
+ # -------- apply third-band peaking filter --------
76
+ b, a = deepafx_st.processors.autodiff.signal.biqaud(
77
+ third_band_gain_dB,
78
+ third_band_cutoff_freq,
79
+ third_band_q_factor,
80
+ sample_rate,
81
+ "peaking",
82
+ )
83
+ b_s.append(b)
84
+ a_s.append(a)
85
+
86
+ # -------- apply fourth-band peaking filter --------
87
+ b, a = deepafx_st.processors.autodiff.signal.biqaud(
88
+ fourth_band_gain_dB,
89
+ fourth_band_cutoff_freq,
90
+ fourth_band_q_factor,
91
+ sample_rate,
92
+ "peaking",
93
+ )
94
+ b_s.append(b)
95
+ a_s.append(a)
96
+
97
+ # -------- apply high-shelf filter --------
98
+ b, a = deepafx_st.processors.autodiff.signal.biqaud(
99
+ high_shelf_gain_dB,
100
+ high_shelf_cutoff_freq,
101
+ high_shelf_q_factor,
102
+ sample_rate,
103
+ "high_shelf",
104
+ )
105
+ b_s.append(b)
106
+ a_s.append(a)
107
+
108
+ x = deepafx_st.processors.autodiff.signal.approx_iir_filter_cascade(
109
+ b_s, a_s, x.view(-1)
110
+ )
111
+
112
+ return x
113
+
114
+
115
+ class ParametricEQ(Processor):
116
+ def __init__(
117
+ self,
118
+ sample_rate,
119
+ min_gain_dB=-24.0,
120
+ default_gain_dB=0.0,
121
+ max_gain_dB=24.0,
122
+ min_q_factor=0.1,
123
+ default_q_factor=0.707,
124
+ max_q_factor=10,
125
+ eps=1e-8,
126
+ ):
127
+ """ """
128
+ super().__init__()
129
+ self.sample_rate = sample_rate
130
+ self.eps = eps
131
+ self.ports = [
132
+ {
133
+ "name": "Lowshelf gain",
134
+ "min": min_gain_dB,
135
+ "max": max_gain_dB,
136
+ "default": default_gain_dB,
137
+ "units": "dB",
138
+ },
139
+ {
140
+ "name": "Lowshelf cutoff",
141
+ "min": 20.0,
142
+ "max": 200.0,
143
+ "default": 100.0,
144
+ "units": "Hz",
145
+ },
146
+ {
147
+ "name": "Lowshelf Q",
148
+ "min": min_q_factor,
149
+ "max": max_q_factor,
150
+ "default": default_q_factor,
151
+ "units": "",
152
+ },
153
+ {
154
+ "name": "First band gain",
155
+ "min": min_gain_dB,
156
+ "max": max_gain_dB,
157
+ "default": default_gain_dB,
158
+ "units": "dB",
159
+ },
160
+ {
161
+ "name": "First band cutoff",
162
+ "min": 200.0,
163
+ "max": 2000.0,
164
+ "default": 400.0,
165
+ "units": "Hz",
166
+ },
167
+ {
168
+ "name": "First band Q",
169
+ "min": min_q_factor,
170
+ "max": max_q_factor,
171
+ "default": 0.707,
172
+ "units": "",
173
+ },
174
+ {
175
+ "name": "Second band gain",
176
+ "min": min_gain_dB,
177
+ "max": max_gain_dB,
178
+ "default": default_gain_dB,
179
+ "units": "dB",
180
+ },
181
+ {
182
+ "name": "Second band cutoff",
183
+ "min": 200.0,
184
+ "max": 4000.0,
185
+ "default": 1000.0,
186
+ "units": "Hz",
187
+ },
188
+ {
189
+ "name": "Second band Q",
190
+ "min": min_q_factor,
191
+ "max": max_q_factor,
192
+ "default": default_q_factor,
193
+ "units": "",
194
+ },
195
+ {
196
+ "name": "Third band gain",
197
+ "min": min_gain_dB,
198
+ "max": max_gain_dB,
199
+ "default": default_gain_dB,
200
+ "units": "dB",
201
+ },
202
+ {
203
+ "name": "Third band cutoff",
204
+ "min": 2000.0,
205
+ "max": 8000.0,
206
+ "default": 4000.0,
207
+ "units": "Hz",
208
+ },
209
+ {
210
+ "name": "Third band Q",
211
+ "min": min_q_factor,
212
+ "max": max_q_factor,
213
+ "default": default_q_factor,
214
+ "units": "",
215
+ },
216
+ {
217
+ "name": "Fourth band gain",
218
+ "min": min_gain_dB,
219
+ "max": max_gain_dB,
220
+ "default": default_gain_dB,
221
+ "units": "dB",
222
+ },
223
+ {
224
+ "name": "Fourth band cutoff",
225
+ "min": 4000.0,
226
+ "max": (24000 // 2) * 0.9,
227
+ "default": 8000.0,
228
+ "units": "Hz",
229
+ },
230
+ {
231
+ "name": "Fourth band Q",
232
+ "min": min_q_factor,
233
+ "max": max_q_factor,
234
+ "default": default_q_factor,
235
+ "units": "",
236
+ },
237
+ {
238
+ "name": "Highshelf gain",
239
+ "min": min_gain_dB,
240
+ "max": max_gain_dB,
241
+ "default": default_gain_dB,
242
+ "units": "dB",
243
+ },
244
+ {
245
+ "name": "Highshelf cutoff",
246
+ "min": 4000.0,
247
+ "max": (24000 // 2) * 0.9,
248
+ "default": 8000.0,
249
+ "units": "Hz",
250
+ },
251
+ {
252
+ "name": "Highshelf Q",
253
+ "min": min_q_factor,
254
+ "max": max_q_factor,
255
+ "default": default_q_factor,
256
+ "units": "",
257
+ },
258
+ ]
259
+
260
+ self.num_control_params = len(self.ports)
261
+
262
+ def forward(self, x, p, sample_rate=24000, **kwargs):
263
+
264
+ bs, chs, s = x.size()
265
+
266
+ inputs = torch.split(x, 1, 0)
267
+ params = torch.split(p, 1, 0)
268
+
269
+ y = [] # loop over batch dimension
270
+ for input, param in zip(inputs, params):
271
+ denorm_param = self.denormalize_params(param.view(-1))
272
+ y.append(parametric_eq(input.view(-1), sample_rate, *denorm_param))
273
+
274
+ return torch.stack(y, dim=0).view(bs, 1, -1)
deepafx_st/processors/autodiff/signal.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from typing import List
4
+
5
+
6
+ def butter(fc, fs: float = 2.0):
7
+ """
8
+
9
+ Recall Butterworth polynomials
10
+ N = 1 s + 1
11
+ N = 2 s^2 + sqrt(2s) + 1
12
+ N = 3 (s^2 + s + 1)(s + 1)
13
+ N = 4 (s^2 + 0.76536s + 1)(s^2 + 1.84776s + 1)
14
+
15
+ Scaling
16
+ LP to LP: s -> s/w_c
17
+ LP to HP: s -> w_c/s
18
+
19
+ Bilinear transform:
20
+ s = 2/T_d * (1 - z^-1)/(1 + z^-1)
21
+
22
+ For 1-pole butterworth lowpass
23
+
24
+ 1 / (s + 1) 1-pole prototype
25
+ 1 / (s/w_c + 1) LP to LP
26
+ 1 / (2/T_d * (1 - z^-1)/(1 + z^-1))/w_c + 1) Bilinear transform
27
+
28
+ """
29
+
30
+ # apply pre-warping to the cutoff
31
+ T_d = 1 / fs
32
+ w_d = (2 * math.pi * fc) / fs
33
+ # sys.exit()
34
+ w_c = (2 / T_d) * torch.tan(w_d / 2)
35
+
36
+ a0 = 2 + (T_d * w_c)
37
+ a1 = (T_d * w_c) - 2
38
+ b0 = T_d * w_c
39
+ b1 = T_d * w_c
40
+
41
+ b = torch.stack([b0, b1], dim=0).view(-1)
42
+ a = torch.stack([a0, a1], dim=0).view(-1)
43
+
44
+ # normalize
45
+ b = b.type_as(fc) / a0
46
+ a = a.type_as(fc) / a0
47
+
48
+ return b, a
49
+
50
+
51
+ def biqaud(
52
+ gain_dB: torch.Tensor,
53
+ cutoff_freq: torch.Tensor,
54
+ q_factor: torch.Tensor,
55
+ sample_rate: float,
56
+ filter_type: str = "peaking",
57
+ ):
58
+
59
+ # convert inputs to Tensors if needed
60
+ # gain_dB = torch.tensor([gain_dB])
61
+ # cutoff_freq = torch.tensor([cutoff_freq])
62
+ # q_factor = torch.tensor([q_factor])
63
+
64
+ A = 10 ** (gain_dB / 40.0)
65
+ w0 = 2 * math.pi * (cutoff_freq / sample_rate)
66
+ alpha = torch.sin(w0) / (2 * q_factor)
67
+ cos_w0 = torch.cos(w0)
68
+ sqrt_A = torch.sqrt(A)
69
+
70
+ if filter_type == "high_shelf":
71
+ b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
72
+ b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
73
+ b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
74
+ a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
75
+ a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
76
+ a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
77
+ elif filter_type == "low_shelf":
78
+ b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
79
+ b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
80
+ b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
81
+ a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
82
+ a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
83
+ a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
84
+ elif filter_type == "peaking":
85
+ b0 = 1 + alpha * A
86
+ b1 = -2 * cos_w0
87
+ b2 = 1 - alpha * A
88
+ a0 = 1 + (alpha / A)
89
+ a1 = -2 * cos_w0
90
+ a2 = 1 - (alpha / A)
91
+ else:
92
+ raise ValueError(f"Invalid filter_type: {filter_type}.")
93
+
94
+ b = torch.stack([b0, b1, b2], dim=0).view(-1)
95
+ a = torch.stack([a0, a1, a2], dim=0).view(-1)
96
+
97
+ # normalize
98
+ b = b.type_as(gain_dB) / a0
99
+ a = a.type_as(gain_dB) / a0
100
+
101
+ return b, a
102
+
103
+
104
+ def freqz(b, a, n_fft: int = 512):
105
+
106
+ B = torch.fft.rfft(b, n_fft)
107
+ A = torch.fft.rfft(a, n_fft)
108
+
109
+ H = B / A
110
+
111
+ return H
112
+
113
+
114
+ def freq_domain_filter(x, H, n_fft):
115
+
116
+ X = torch.fft.rfft(x, n_fft)
117
+
118
+ # move H to same device as input x
119
+ H = H.type_as(X)
120
+
121
+ Y = X * H
122
+
123
+ y = torch.fft.irfft(Y, n_fft)
124
+
125
+ return y
126
+
127
+
128
+ def approx_iir_filter(b, a, x):
129
+ """Approimxate the application of an IIR filter.
130
+
131
+ Args:
132
+ b (Tensor): The numerator coefficients.
133
+
134
+ """
135
+
136
+ # round up to nearest power of 2 for FFT
137
+ # n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
138
+
139
+ n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
140
+ n_fft = n_fft.int()
141
+
142
+ # move coefficients to same device as x
143
+ b = b.type_as(x).view(-1)
144
+ a = a.type_as(x).view(-1)
145
+
146
+ # compute complex response
147
+ H = freqz(b, a, n_fft=n_fft).view(-1)
148
+
149
+ # apply filter
150
+ y = freq_domain_filter(x, H, n_fft)
151
+
152
+ # crop
153
+ y = y[: x.shape[-1]]
154
+
155
+ return y
156
+
157
+
158
+ def approx_iir_filter_cascade(
159
+ b_s: List[torch.Tensor],
160
+ a_s: List[torch.Tensor],
161
+ x: torch.Tensor,
162
+ ):
163
+ """Apply a cascade of IIR filters.
164
+
165
+ Args:
166
+ b (list[Tensor]): List of tensors of shape (3)
167
+ a (list[Tensor]): List of tensors of (3)
168
+ x (torch.Tensor): 1d Tensor.
169
+ """
170
+
171
+ if len(b_s) != len(a_s):
172
+ raise RuntimeError(
173
+ f"Must have same number of coefficients. Got b: {len(b_s)} and a: {len(a_s)}."
174
+ )
175
+
176
+ # round up to nearest power of 2 for FFT
177
+ # n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
178
+ n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
179
+ n_fft = n_fft.int()
180
+
181
+ # this could be done in parallel
182
+ b = torch.stack(b_s, dim=0).type_as(x)
183
+ a = torch.stack(a_s, dim=0).type_as(x)
184
+
185
+ H = freqz(b, a, n_fft=n_fft)
186
+ H = torch.prod(H, dim=0).view(-1)
187
+
188
+ # apply filter
189
+ y = freq_domain_filter(x, H, n_fft)
190
+
191
+ # crop
192
+ y = y[: x.shape[-1]]
193
+
194
+ return y
deepafx_st/processors/dsp/compressor.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import numpy as np
4
+ import scipy.signal
5
+ from numba import jit
6
+
7
+ from deepafx_st.processors.processor import Processor
8
+
9
+
10
+ # Adapted from: https://github.com/drscotthawley/signaltrain/blob/master/signaltrain/audio.py
11
+ @jit(nopython=True)
12
+ def my_clip_min(
13
+ x: np.ndarray,
14
+ clip_min: float,
15
+ ): # does the work of np.clip(), which numba doesn't support yet
16
+ # TODO: keep an eye on Numba PR https://github.com/numba/numba/pull/3468 that fixes this
17
+ inds = np.where(x < clip_min)
18
+ x[inds] = clip_min
19
+ return x
20
+
21
+
22
+ @jit(nopython=True)
23
+ def compressor(
24
+ x: np.ndarray,
25
+ sample_rate: float,
26
+ threshold: float = -24.0,
27
+ ratio: float = 2.0,
28
+ attack_time: float = 0.01,
29
+ release_time: float = 0.01,
30
+ knee_dB: float = 0.0,
31
+ makeup_gain_dB: float = 0.0,
32
+ dtype=np.float32,
33
+ ):
34
+ """
35
+
36
+ Args:
37
+ x (np.ndarray): Input signal.
38
+ sample_rate (float): Sample rate in Hz.
39
+ threshold (float): Threhold in dB.
40
+ ratio (float): Ratio (should be >=1 , i.e. ratio:1).
41
+ attack_time (float): Attack time in seconds.
42
+ release_time (float): Release time in seconds.
43
+ knee_dB (float): Knee.
44
+ makeup_gain_dB (float): Makeup Gain.
45
+ dtype (type): Output type. Default: np.float32
46
+
47
+ Returns:
48
+ y (np.ndarray): Output signal.
49
+
50
+ """
51
+ # print(f"dsp comp fs = {sample_rate}")
52
+
53
+ N = len(x)
54
+ dtype = x.dtype
55
+ y = np.zeros(N, dtype=dtype)
56
+
57
+ # Initialize separate attack and release times
58
+ # Where do these numbers come from
59
+ alpha_A = np.exp(-np.log(9) / (sample_rate * attack_time))
60
+ alpha_R = np.exp(-np.log(9) / (sample_rate * release_time))
61
+
62
+ # Turn the input signal into a uni-polar signal on the dB scale
63
+ x_G = 20 * np.log10(np.abs(x) + 1e-8) # x_uni casts type
64
+
65
+ # Ensure there are no values of negative infinity
66
+ x_G = my_clip_min(x_G, -96)
67
+
68
+ # Static characteristics with knee
69
+ y_G = np.zeros(N, dtype=dtype)
70
+
71
+ # Below knee
72
+ idx = np.where((2 * (x_G - threshold)) < -knee_dB)
73
+ y_G[idx] = x_G[idx]
74
+
75
+ # At knee
76
+ idx = np.where((2 * np.abs(x_G - threshold)) <= knee_dB)
77
+ y_G[idx] = x_G[idx] + (
78
+ (1 / ratio) * (((x_G[idx] - threshold + knee_dB) / 2) ** 2)
79
+ ) / (2 * knee_dB)
80
+
81
+ # Above knee threshold
82
+ idx = np.where((2 * (x_G - threshold)) > knee_dB)
83
+ y_G[idx] = threshold + ((x_G[idx] - threshold) / ratio)
84
+
85
+ x_L = x_G - y_G
86
+
87
+ # this loop is slow but not vectorizable due to its cumulative, sequential nature. @autojit makes it fast(er).
88
+ y_L = np.zeros(N, dtype=dtype)
89
+ for n in range(1, N):
90
+ # smooth over the gainChange
91
+ if x_L[n] > y_L[n - 1]: # attack mode
92
+ y_L[n] = (alpha_A * y_L[n - 1]) + ((1 - alpha_A) * x_L[n])
93
+ else: # release
94
+ y_L[n] = (alpha_R * y_L[n - 1]) + ((1 - alpha_R) * x_L[n])
95
+
96
+ # Convert to linear amplitude scalar; i.e. map from dB to amplitude
97
+ lin_y_L = np.power(10.0, (-y_L / 20.0))
98
+ y = lin_y_L * x # Apply linear amplitude to input sample
99
+
100
+ y *= np.power(10.0, makeup_gain_dB / 20.0) # apply makeup gain
101
+
102
+ return y.astype(dtype)
103
+
104
+
105
+ class Compressor(Processor):
106
+ def __init__(
107
+ self,
108
+ sample_rate,
109
+ max_threshold=0.0,
110
+ min_threshold=-80,
111
+ max_ratio=20.0,
112
+ min_ratio=1.0,
113
+ max_attack=0.1,
114
+ min_attack=0.0001,
115
+ max_release=1.0,
116
+ min_release=0.005,
117
+ max_knee=12.0,
118
+ min_knee=0.0,
119
+ max_mkgain=48.0,
120
+ min_mkgain=-48.0,
121
+ eps=1e-8,
122
+ ):
123
+ """ """
124
+ super().__init__()
125
+ self.sample_rate = sample_rate
126
+ self.eps = eps
127
+ self.ports = [
128
+ {
129
+ "name": "Threshold",
130
+ "min": min_threshold,
131
+ "max": max_threshold,
132
+ "default": -12.0,
133
+ "units": "",
134
+ },
135
+ {
136
+ "name": "Ratio",
137
+ "min": min_ratio,
138
+ "max": max_ratio,
139
+ "default": 2.0,
140
+ "units": "",
141
+ },
142
+ {
143
+ "name": "Attack Time",
144
+ "min": min_attack,
145
+ "max": max_attack,
146
+ "default": 0.001,
147
+ "units": "s",
148
+ },
149
+ {
150
+ "name": "Release Time",
151
+ "min": min_release,
152
+ "max": max_release,
153
+ "default": 0.045,
154
+ "units": "s",
155
+ },
156
+ {
157
+ "name": "Knee",
158
+ "min": min_knee,
159
+ "max": max_knee,
160
+ "default": 6.0,
161
+ "units": "dB",
162
+ },
163
+ {
164
+ "name": "Makeup Gain",
165
+ "min": min_mkgain,
166
+ "max": max_mkgain,
167
+ "default": 0.0,
168
+ "units": "dB",
169
+ },
170
+ ]
171
+
172
+ self.num_control_params = len(self.ports)
173
+ self.process_fn = compressor
174
+
175
+ def forward(self, x, p, sample_rate=24000, **kwargs):
176
+ "All processing in the forward is in numpy."
177
+ return self.run_series(x, p, sample_rate)
deepafx_st/processors/dsp/peq.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import scipy.signal
4
+ from numba import jit
5
+
6
+ from deepafx_st.processors.processor import Processor
7
+
8
+
9
+ @jit(nopython=True)
10
+ def biqaud(
11
+ gain_dB: float,
12
+ cutoff_freq: float,
13
+ q_factor: float,
14
+ sample_rate: float,
15
+ filter_type: str,
16
+ ):
17
+ """Use design parameters to generate coeffieicnets for a specific filter type.
18
+
19
+ Args:
20
+ gain_dB (float): Shelving filter gain in dB.
21
+ cutoff_freq (float): Cutoff frequency in Hz.
22
+ q_factor (float): Q factor.
23
+ sample_rate (float): Sample rate in Hz.
24
+ filter_type (str): Filter type.
25
+ One of "low_shelf", "high_shelf", or "peaking"
26
+
27
+ Returns:
28
+ b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2]
29
+ a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2]
30
+ """
31
+
32
+ A = 10 ** (gain_dB / 40.0)
33
+ w0 = 2.0 * np.pi * (cutoff_freq / sample_rate)
34
+ alpha = np.sin(w0) / (2.0 * q_factor)
35
+
36
+ cos_w0 = np.cos(w0)
37
+ sqrt_A = np.sqrt(A)
38
+
39
+ if filter_type == "high_shelf":
40
+ b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
41
+ b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
42
+ b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
43
+ a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
44
+ a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
45
+ a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
46
+ elif filter_type == "low_shelf":
47
+ b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
48
+ b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
49
+ b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
50
+ a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
51
+ a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
52
+ a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
53
+ elif filter_type == "peaking":
54
+ b0 = 1 + alpha * A
55
+ b1 = -2 * cos_w0
56
+ b2 = 1 - alpha * A
57
+ a0 = 1 + alpha / A
58
+ a1 = -2 * cos_w0
59
+ a2 = 1 - alpha / A
60
+ else:
61
+ pass
62
+ # raise ValueError(f"Invalid filter_type: {filter_type}.")
63
+
64
+ b = np.array([b0, b1, b2]) / a0
65
+ a = np.array([a0, a1, a2]) / a0
66
+
67
+ return b, a
68
+
69
+
70
+ # Adapted from https://github.com/csteinmetz1/pyloudnorm/blob/master/pyloudnorm/iirfilter.py
71
+ def parametric_eq(
72
+ x: np.ndarray,
73
+ sample_rate: float,
74
+ low_shelf_gain_dB: float = 0.0,
75
+ low_shelf_cutoff_freq: float = 80.0,
76
+ low_shelf_q_factor: float = 0.707,
77
+ first_band_gain_dB: float = 0.0,
78
+ first_band_cutoff_freq: float = 300.0,
79
+ first_band_q_factor: float = 0.707,
80
+ second_band_gain_dB: float = 0.0,
81
+ second_band_cutoff_freq: float = 1000.0,
82
+ second_band_q_factor: float = 0.707,
83
+ third_band_gain_dB: float = 0.0,
84
+ third_band_cutoff_freq: float = 4000.0,
85
+ third_band_q_factor: float = 0.707,
86
+ fourth_band_gain_dB: float = 0.0,
87
+ fourth_band_cutoff_freq: float = 8000.0,
88
+ fourth_band_q_factor: float = 0.707,
89
+ high_shelf_gain_dB: float = 0.0,
90
+ high_shelf_cutoff_freq: float = 1000.0,
91
+ high_shelf_q_factor: float = 0.707,
92
+ dtype=np.float32,
93
+ ):
94
+ """Six-band parametric EQ.
95
+
96
+ Low-shelf -> Band 1 -> Band 2 -> Band 3 -> Band 4 -> High-shelf
97
+
98
+ Args:
99
+
100
+
101
+ """
102
+ # print(f"autodiff peq fs = {sample_rate}")
103
+
104
+ # -------- apply low-shelf filter --------
105
+ b, a = biqaud(
106
+ low_shelf_gain_dB,
107
+ low_shelf_cutoff_freq,
108
+ low_shelf_q_factor,
109
+ sample_rate,
110
+ "low_shelf",
111
+ )
112
+ sos0 = np.concatenate((b, a))
113
+ x = scipy.signal.lfilter(b, a, x)
114
+
115
+ # -------- apply first-band peaking filter --------
116
+ b, a = biqaud(
117
+ first_band_gain_dB,
118
+ first_band_cutoff_freq,
119
+ first_band_q_factor,
120
+ sample_rate,
121
+ "peaking",
122
+ )
123
+ sos1 = np.concatenate((b, a))
124
+ x = scipy.signal.lfilter(b, a, x)
125
+
126
+ # -------- apply second-band peaking filter --------
127
+ b, a = biqaud(
128
+ second_band_gain_dB,
129
+ second_band_cutoff_freq,
130
+ second_band_q_factor,
131
+ sample_rate,
132
+ "peaking",
133
+ )
134
+ sos2 = np.concatenate((b, a))
135
+ x = scipy.signal.lfilter(b, a, x)
136
+
137
+ # -------- apply third-band peaking filter --------
138
+ b, a = biqaud(
139
+ third_band_gain_dB,
140
+ third_band_cutoff_freq,
141
+ third_band_q_factor,
142
+ sample_rate,
143
+ "peaking",
144
+ )
145
+ sos3 = np.concatenate((b, a))
146
+ x = scipy.signal.lfilter(b, a, x)
147
+
148
+ # -------- apply fourth-band peaking filter --------
149
+ b, a = biqaud(
150
+ fourth_band_gain_dB,
151
+ fourth_band_cutoff_freq,
152
+ fourth_band_q_factor,
153
+ sample_rate,
154
+ "peaking",
155
+ )
156
+ sos4 = np.concatenate((b, a))
157
+ x = scipy.signal.lfilter(b, a, x)
158
+
159
+ # -------- apply high-shelf filter --------
160
+ b, a = biqaud(
161
+ high_shelf_gain_dB,
162
+ high_shelf_cutoff_freq,
163
+ high_shelf_q_factor,
164
+ sample_rate,
165
+ "high_shelf",
166
+ )
167
+ sos5 = np.concatenate((b, a))
168
+ x = scipy.signal.lfilter(b, a, x)
169
+
170
+ return x.astype(dtype)
171
+
172
+
173
+ class ParametricEQ(Processor):
174
+ def __init__(
175
+ self,
176
+ sample_rate,
177
+ min_gain_dB=-24.0,
178
+ default_gain_dB=0.0,
179
+ max_gain_dB=24.0,
180
+ min_q_factor=0.1,
181
+ default_q_factor=0.707,
182
+ max_q_factor=10,
183
+ eps=1e-8,
184
+ ):
185
+ """ """
186
+ super().__init__()
187
+ self.sample_rate = sample_rate
188
+ self.eps = eps
189
+ self.ports = [
190
+ {
191
+ "name": "Lowshelf gain",
192
+ "min": min_gain_dB,
193
+ "max": max_gain_dB,
194
+ "default": default_gain_dB,
195
+ "units": "dB",
196
+ },
197
+ {
198
+ "name": "Lowshelf cutoff",
199
+ "min": 20.0,
200
+ "max": 200.0,
201
+ "default": 100.0,
202
+ "units": "Hz",
203
+ },
204
+ {
205
+ "name": "Lowshelf Q",
206
+ "min": min_q_factor,
207
+ "max": max_q_factor,
208
+ "default": default_q_factor,
209
+ "units": "",
210
+ },
211
+ {
212
+ "name": "First band gain",
213
+ "min": min_gain_dB,
214
+ "max": max_gain_dB,
215
+ "default": default_gain_dB,
216
+ "units": "dB",
217
+ },
218
+ {
219
+ "name": "First band cutoff",
220
+ "min": 200.0,
221
+ "max": 2000.0,
222
+ "default": 400.0,
223
+ "units": "Hz",
224
+ },
225
+ {
226
+ "name": "First band Q",
227
+ "min": min_q_factor,
228
+ "max": max_q_factor,
229
+ "default": 0.707,
230
+ "units": "",
231
+ },
232
+ {
233
+ "name": "Second band gain",
234
+ "min": min_gain_dB,
235
+ "max": max_gain_dB,
236
+ "default": default_gain_dB,
237
+ "units": "dB",
238
+ },
239
+ {
240
+ "name": "Second band cutoff",
241
+ "min": 800.0,
242
+ "max": 4000.0,
243
+ "default": 1000.0,
244
+ "units": "Hz",
245
+ },
246
+ {
247
+ "name": "Second band Q",
248
+ "min": min_q_factor,
249
+ "max": max_q_factor,
250
+ "default": default_q_factor,
251
+ "units": "",
252
+ },
253
+ {
254
+ "name": "Third band gain",
255
+ "min": min_gain_dB,
256
+ "max": max_gain_dB,
257
+ "default": default_gain_dB,
258
+ "units": "dB",
259
+ },
260
+ {
261
+ "name": "Third band cutoff",
262
+ "min": 2000.0,
263
+ "max": 8000.0,
264
+ "default": 4000.0,
265
+ "units": "Hz",
266
+ },
267
+ {
268
+ "name": "Third band Q",
269
+ "min": min_q_factor,
270
+ "max": max_q_factor,
271
+ "default": default_q_factor,
272
+ "units": "",
273
+ },
274
+ {
275
+ "name": "Fourth band gain",
276
+ "min": min_gain_dB,
277
+ "max": max_gain_dB,
278
+ "default": default_gain_dB,
279
+ "units": "dB",
280
+ },
281
+ {
282
+ "name": "Fourth band cutoff",
283
+ "min": 4000.0,
284
+ "max": (24000 // 2) * 0.9,
285
+ "default": 8000.0,
286
+ "units": "Hz",
287
+ },
288
+ {
289
+ "name": "Fourth band Q",
290
+ "min": min_q_factor,
291
+ "max": max_q_factor,
292
+ "default": default_q_factor,
293
+ "units": "",
294
+ },
295
+ {
296
+ "name": "Highshelf gain",
297
+ "min": min_gain_dB,
298
+ "max": max_gain_dB,
299
+ "default": default_gain_dB,
300
+ "units": "dB",
301
+ },
302
+ {
303
+ "name": "Highshelf cutoff",
304
+ "min": 4000.0,
305
+ "max": (24000 // 2) * 0.9,
306
+ "default": 8000.0,
307
+ "units": "Hz",
308
+ },
309
+ {
310
+ "name": "Highshelf Q",
311
+ "min": min_q_factor,
312
+ "max": max_q_factor,
313
+ "default": default_q_factor,
314
+ "units": "",
315
+ },
316
+ ]
317
+
318
+ self.num_control_params = len(self.ports)
319
+ self.process_fn = parametric_eq
320
+
321
+ def forward(self, x, p, sample_rate=24000, **kwargs):
322
+ "All processing in the forward is in numpy."
323
+ return self.run_series(x, p, sample_rate)
deepafx_st/processors/processor.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import multiprocessing
3
+ from abc import ABC, abstractmethod
4
+ import deepafx_st.utils as utils
5
+ import numpy as np
6
+
7
+
8
+ class Processor(torch.nn.Module, ABC):
9
+ """Processor base class."""
10
+
11
+ def __init__(
12
+ self,
13
+ ):
14
+ super().__init__()
15
+
16
+ def denormalize_params(self, p):
17
+ """This method takes a tensor of parameters scaled from 0-1 and
18
+ restores them back to the original parameter range."""
19
+
20
+ # check if the number of parameters is correct
21
+ params = p # torch.split(p, 1, -1)
22
+ if len(params) != self.num_control_params:
23
+ raise RuntimeError(
24
+ f"Invalid number of parameters. ",
25
+ f"Expected {self.num_control_params} but found {len(params)} {params.shape}.",
26
+ )
27
+
28
+ # iterate over the parameters and expand from 0-1 to full range
29
+ denorm_params = []
30
+ for param, port in zip(params, self.ports):
31
+ # check if parameter exceeds range
32
+ if param > 1.0 or param < 0.0:
33
+ raise RuntimeError(
34
+ f"""Parameter '{port["name"]}' exceeds range: {param}"""
35
+ )
36
+
37
+ # denormalize and store result
38
+ denorm_params.append(utils.denormalize(param, port["max"], port["min"]))
39
+
40
+ return denorm_params
41
+
42
+ def normalize_params(self, *params):
43
+ """This method creates a vector of parameters normalized from 0-1."""
44
+
45
+ # check if the number of parameters is correct
46
+ if len(params) != self.num_control_params:
47
+ raise RuntimeError(
48
+ f"Invalid number of parameters. ",
49
+ f"Expected {self.num_control_params} but found {len(params)}.",
50
+ )
51
+
52
+ norm_params = []
53
+ for param, port in zip(params, self.ports):
54
+ norm_params.append(utils.normalize(param, port["max"], port["min"]))
55
+
56
+ p = torch.tensor(norm_params).view(1, -1)
57
+
58
+ return p
59
+
60
+ # def run_series(self, inputs, params):
61
+ # """Run the process function in a loop given a list of inputs and parameters"""
62
+ # p_b_denorm = [p for p in self.denormalize_params(params)]
63
+ # y = self.process_fn(inputs, self.sample_rate, *p_b_denorm)
64
+ # return y
65
+
66
+ def run_series(self, inputs, params, sample_rate=24000):
67
+ """Run the process function in a loop given a list of inputs and parameters"""
68
+ if params.ndim == 1:
69
+ params = np.reshape(params, (1, -1))
70
+ inputs = np.reshape(inputs, (1, -1))
71
+ bs = inputs.shape[0]
72
+ ys = []
73
+ params = np.clip(params, 0, 1)
74
+ for bidx in range(bs):
75
+ p_b_denorm = [p for p in self.denormalize_params(params[bidx, :])]
76
+ y = self.process_fn(
77
+ inputs[bidx, ...].reshape(-1),
78
+ sample_rate,
79
+ *p_b_denorm,
80
+ )
81
+ ys.append(y)
82
+ y = np.stack(ys, axis=0)
83
+ return y
84
+
85
+ @abstractmethod
86
+ def forward(self, x, p):
87
+ pass
deepafx_st/processors/proxy/channel.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from deepafx_st.processors.proxy.proxy_system import ProxySystem
3
+ from deepafx_st.utils import DSPMode
4
+
5
+
6
+ class ProxyChannel(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ proxy_system_ckpts: list,
10
+ freeze_proxies: bool = True,
11
+ dsp_mode: DSPMode = DSPMode.NONE,
12
+ num_tcns: int = 2,
13
+ tcn_nblocks: int = 4,
14
+ tcn_dilation_growth: int = 8,
15
+ tcn_channel_width: int = 64,
16
+ tcn_kernel_size: int = 13,
17
+ sample_rate: int = 24000,
18
+ ):
19
+ super().__init__()
20
+ self.freeze_proxies = freeze_proxies
21
+ self.dsp_mode = dsp_mode
22
+ self.num_tcns = num_tcns
23
+
24
+ # load the proxies
25
+ self.proxies = torch.nn.ModuleList()
26
+ self.num_control_params = 0
27
+ self.ports = []
28
+ for proxy_system_ckpt in proxy_system_ckpts:
29
+ proxy = ProxySystem.load_from_checkpoint(proxy_system_ckpt)
30
+ # freeze model parameters
31
+ if freeze_proxies:
32
+ for param in proxy.parameters():
33
+ param.requires_grad = False
34
+ self.proxies.append(proxy)
35
+ if proxy.hparams.processor == "channel":
36
+ self.ports = proxy.processor.ports
37
+ else:
38
+ self.ports.append(proxy.processor.ports)
39
+ self.num_control_params += proxy.processor.num_control_params
40
+
41
+ if len(proxy_system_ckpts) == 0:
42
+ if self.num_tcns == 2:
43
+ peq_proxy = ProxySystem(
44
+ processor="peq",
45
+ output_gain=False,
46
+ nblocks=tcn_nblocks,
47
+ dilation_growth=tcn_dilation_growth,
48
+ kernel_size=tcn_kernel_size,
49
+ channel_width=tcn_channel_width,
50
+ sample_rate=sample_rate,
51
+ )
52
+ self.proxies.append(peq_proxy)
53
+ self.ports.append(peq_proxy.processor.ports)
54
+ self.num_control_params += peq_proxy.processor.num_control_params
55
+ comp_proxy = ProxySystem(
56
+ processor="comp",
57
+ output_gain=True,
58
+ nblocks=tcn_nblocks,
59
+ dilation_growth=tcn_dilation_growth,
60
+ kernel_size=tcn_kernel_size,
61
+ channel_width=tcn_channel_width,
62
+ sample_rate=sample_rate,
63
+ )
64
+ self.proxies.append(comp_proxy)
65
+ self.ports.append(comp_proxy.processor.ports)
66
+ self.num_control_params += comp_proxy.processor.num_control_params
67
+ elif self.num_tcns == 1:
68
+ channel_proxy = ProxySystem(
69
+ processor="channel",
70
+ output_gain=True,
71
+ nblocks=tcn_nblocks,
72
+ dilation_growth=tcn_dilation_growth,
73
+ kernel_size=tcn_kernel_size,
74
+ channel_width=tcn_channel_width,
75
+ sample_rate=sample_rate,
76
+ )
77
+ self.proxies.append(channel_proxy)
78
+ for port_list in channel_proxy.processor.ports:
79
+ self.ports.append(port_list)
80
+ self.num_control_params += channel_proxy.processor.num_control_params
81
+ else:
82
+ raise ValueError(f"num_tcns must be <= 2. Asked for {self.num_tcns}.")
83
+
84
+ def forward(
85
+ self,
86
+ x: torch.Tensor,
87
+ p: torch.Tensor,
88
+ dsp_mode: DSPMode = DSPMode.NONE,
89
+ sample_rate: int = 24000,
90
+ **kwargs,
91
+ ):
92
+ # loop over the proxies and pass parameters
93
+ stop_idx = 0
94
+ for proxy in self.proxies:
95
+ start_idx = stop_idx
96
+ stop_idx += proxy.processor.num_control_params
97
+ p_subset = p[:, start_idx:stop_idx]
98
+ if dsp_mode.name == DSPMode.NONE.name:
99
+ x = proxy(
100
+ x,
101
+ p_subset,
102
+ use_dsp=False,
103
+ )
104
+ elif dsp_mode.name == DSPMode.INFER.name:
105
+ x = proxy(
106
+ x,
107
+ p_subset,
108
+ use_dsp=True,
109
+ sample_rate=sample_rate,
110
+ )
111
+ elif dsp_mode.name == DSPMode.TRAIN_INFER.name:
112
+ # Mimic gumbel softmax implementation to replace grads similar to
113
+ # https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
114
+ x_hard = proxy(
115
+ x,
116
+ p_subset,
117
+ use_dsp=True,
118
+ sample_rate=sample_rate,
119
+ )
120
+ x = proxy(
121
+ x,
122
+ p_subset,
123
+ use_dsp=False,
124
+ sample_rate=sample_rate,
125
+ )
126
+ x = (x_hard - x).detach() + x
127
+ else:
128
+ assert 0, "invalid dsp model for proxy"
129
+
130
+ return x
deepafx_st/processors/proxy/proxy_system.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from re import X
2
+ import torch
3
+ import auraloss
4
+ import pytorch_lightning as pl
5
+ from typing import Tuple, List, Dict
6
+ from argparse import ArgumentParser
7
+
8
+
9
+ import deepafx_st.utils as utils
10
+ from deepafx_st.data.proxy import DSPProxyDataset
11
+ from deepafx_st.processors.proxy.tcn import ConditionalTCN
12
+ from deepafx_st.processors.spsa.channel import SPSAChannel
13
+ from deepafx_st.processors.dsp.peq import ParametricEQ
14
+ from deepafx_st.processors.dsp.compressor import Compressor
15
+
16
+
17
+ class ProxySystem(pl.LightningModule):
18
+ def __init__(
19
+ self,
20
+ causal=True,
21
+ nblocks=4,
22
+ dilation_growth=8,
23
+ kernel_size=13,
24
+ channel_width=64,
25
+ input_dir=None,
26
+ processor="channel",
27
+ batch_size=32,
28
+ lr=3e-4,
29
+ lr_patience=20,
30
+ patience=10,
31
+ preload=False,
32
+ sample_rate=24000,
33
+ shuffle=True,
34
+ train_length=65536,
35
+ train_examples_per_epoch=10000,
36
+ val_length=131072,
37
+ val_examples_per_epoch=1000,
38
+ num_workers=16,
39
+ output_gain=False,
40
+ **kwargs,
41
+ ):
42
+ super().__init__()
43
+ self.save_hyperparameters()
44
+ #print(f"Proxy Processor: {processor} @ fs={sample_rate} Hz")
45
+
46
+ # construct both the true DSP...
47
+ if self.hparams.processor == "peq":
48
+ self.processor = ParametricEQ(self.hparams.sample_rate)
49
+ elif self.hparams.processor == "comp":
50
+ self.processor = Compressor(self.hparams.sample_rate)
51
+ elif self.hparams.processor == "channel":
52
+ self.processor = SPSAChannel(self.hparams.sample_rate)
53
+
54
+ # and the neural network proxy
55
+ self.proxy = ConditionalTCN(
56
+ self.hparams.sample_rate,
57
+ num_control_params=self.processor.num_control_params,
58
+ causal=self.hparams.causal,
59
+ nblocks=self.hparams.nblocks,
60
+ channel_width=self.hparams.channel_width,
61
+ kernel_size=self.hparams.kernel_size,
62
+ dilation_growth=self.hparams.dilation_growth,
63
+ )
64
+
65
+ self.receptive_field = self.proxy.compute_receptive_field()
66
+
67
+ self.recon_losses = {}
68
+ self.recon_loss_weights = {}
69
+
70
+ self.recon_losses["mrstft"] = auraloss.freq.MultiResolutionSTFTLoss(
71
+ fft_sizes=[32, 128, 512, 2048, 8192, 32768],
72
+ hop_sizes=[16, 64, 256, 1024, 4096, 16384],
73
+ win_lengths=[32, 128, 512, 2048, 8192, 32768],
74
+ w_sc=0.0,
75
+ w_phs=0.0,
76
+ w_lin_mag=1.0,
77
+ w_log_mag=1.0,
78
+ )
79
+ self.recon_loss_weights["mrstft"] = 1.0
80
+
81
+ self.recon_losses["l1"] = torch.nn.L1Loss()
82
+ self.recon_loss_weights["l1"] = 100.0
83
+
84
+ def forward(self, x, p, use_dsp=False, sample_rate=24000, **kwargs):
85
+ """Use the pre-trained neural network proxy effect."""
86
+ bs, chs, samp = x.size()
87
+ if not use_dsp:
88
+ y = self.proxy(x, p)
89
+ # manually apply the makeup gain parameter
90
+ if self.hparams.output_gain and not self.hparams.processor == "peq":
91
+ gain_db = (p[..., -1] * 96) - 48
92
+ gain_ln = 10 ** (gain_db / 20.0)
93
+ y *= gain_ln.view(bs, chs, 1)
94
+ else:
95
+ with torch.no_grad():
96
+ bs, chs, s = x.shape
97
+
98
+ if self.hparams.output_gain and not self.hparams.processor == "peq":
99
+ # override makeup gain
100
+ gain_db = (p[..., -1] * 96) - 48
101
+ gain_ln = 10 ** (gain_db / 20.0)
102
+ p[..., -1] = 0.5
103
+
104
+ if self.hparams.processor == "channel":
105
+ y_temp = self.processor(x.cpu(), p.cpu())
106
+ y_temp = y_temp.view(bs, chs, s).type_as(x)
107
+ else:
108
+ y_temp = self.processor(
109
+ x.cpu().numpy(),
110
+ p.cpu().numpy(),
111
+ sample_rate,
112
+ )
113
+ y_temp = torch.tensor(y_temp).view(bs, chs, s).type_as(x)
114
+
115
+ y = y_temp.type_as(x).view(bs, 1, -1)
116
+
117
+ if self.hparams.output_gain and not self.hparams.processor == "peq":
118
+ y *= gain_ln.view(bs, chs, 1)
119
+
120
+ return y
121
+
122
+ def common_step(
123
+ self,
124
+ batch: Tuple,
125
+ batch_idx: int,
126
+ optimizer_idx: int = 0,
127
+ train: bool = True,
128
+ ):
129
+ loss = 0
130
+ x, y, p = batch
131
+
132
+ y_hat = self(x, p)
133
+
134
+ # compute loss
135
+ for loss_idx, (loss_name, loss_fn) in enumerate(self.recon_losses.items()):
136
+ tmp_loss = loss_fn(y_hat.float(), y.float())
137
+ loss += self.recon_loss_weights[loss_name] * tmp_loss
138
+
139
+ self.log(
140
+ f"train_loss/{loss_name}" if train else f"val_loss/{loss_name}",
141
+ tmp_loss,
142
+ on_step=True,
143
+ on_epoch=True,
144
+ prog_bar=False,
145
+ logger=True,
146
+ sync_dist=True,
147
+ )
148
+
149
+ if not train:
150
+ # store audio data
151
+ data_dict = {
152
+ "x": x.float().cpu(),
153
+ "y": y.float().cpu(),
154
+ "p": p.float().cpu(),
155
+ "y_hat": y_hat.float().cpu(),
156
+ }
157
+ else:
158
+ data_dict = {}
159
+
160
+ self.log(
161
+ "train_loss" if train else "val_loss",
162
+ loss,
163
+ on_step=True,
164
+ on_epoch=True,
165
+ prog_bar=False,
166
+ logger=True,
167
+ sync_dist=True,
168
+ )
169
+
170
+ return loss, data_dict
171
+
172
+ def training_step(self, batch, batch_idx, optimizer_idx=0):
173
+ loss, _ = self.common_step(batch, batch_idx)
174
+ return loss
175
+
176
+ def validation_step(self, batch, batch_idx):
177
+ loss, data_dict = self.common_step(batch, batch_idx, train=False)
178
+
179
+ if batch_idx == 0:
180
+ return data_dict
181
+
182
+ def configure_optimizers(self):
183
+ optimizer = torch.optim.Adam(
184
+ self.proxy.parameters(),
185
+ lr=self.hparams.lr,
186
+ betas=(0.9, 0.999),
187
+ )
188
+
189
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
190
+ optimizer,
191
+ patience=self.hparams.lr_patience,
192
+ verbose=True,
193
+ )
194
+
195
+ return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"}
196
+
197
+ def train_dataloader(self):
198
+
199
+ train_dataset = DSPProxyDataset(
200
+ self.hparams.input_dir,
201
+ self.processor,
202
+ self.hparams.processor, # name
203
+ subset="train",
204
+ length=self.hparams.train_length,
205
+ num_examples_per_epoch=self.hparams.train_examples_per_epoch,
206
+ half=True if self.hparams.precision == 16 else False,
207
+ buffer_size_gb=self.hparams.buffer_size_gb,
208
+ buffer_reload_rate=self.hparams.buffer_reload_rate,
209
+ )
210
+
211
+ g = torch.Generator()
212
+ g.manual_seed(0)
213
+
214
+ return torch.utils.data.DataLoader(
215
+ train_dataset,
216
+ num_workers=self.hparams.num_workers,
217
+ batch_size=self.hparams.batch_size,
218
+ worker_init_fn=utils.seed_worker,
219
+ generator=g,
220
+ pin_memory=True,
221
+ )
222
+
223
+ def val_dataloader(self):
224
+
225
+ val_dataset = DSPProxyDataset(
226
+ self.hparams.input_dir,
227
+ self.processor,
228
+ self.hparams.processor, # name
229
+ subset="val",
230
+ length=self.hparams.val_length,
231
+ num_examples_per_epoch=self.hparams.val_examples_per_epoch,
232
+ half=True if self.hparams.precision == 16 else False,
233
+ buffer_size_gb=self.hparams.buffer_size_gb,
234
+ buffer_reload_rate=self.hparams.buffer_reload_rate,
235
+ )
236
+
237
+ g = torch.Generator()
238
+ g.manual_seed(0)
239
+
240
+ return torch.utils.data.DataLoader(
241
+ val_dataset,
242
+ num_workers=self.hparams.num_workers,
243
+ batch_size=self.hparams.batch_size,
244
+ worker_init_fn=utils.seed_worker,
245
+ generator=g,
246
+ pin_memory=True,
247
+ )
248
+
249
+ @staticmethod
250
+ def count_control_params(plugin_config):
251
+ num_control_params = 0
252
+
253
+ for plugin in plugin_config["plugins"]:
254
+ for port in plugin["ports"]:
255
+ if port["optim"]:
256
+ num_control_params += 1
257
+
258
+ return num_control_params
259
+
260
+ # add any model hyperparameters here
261
+ @staticmethod
262
+ def add_model_specific_args(parent_parser):
263
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
264
+ # --- Model ---
265
+ parser.add_argument("--causal", action="store_true")
266
+ parser.add_argument("--output_gain", action="store_true")
267
+ parser.add_argument("--dilation_growth", type=int, default=8)
268
+ parser.add_argument("--nblocks", type=int, default=4)
269
+ parser.add_argument("--kernel_size", type=int, default=13)
270
+ parser.add_argument("--channel_width", type=int, default=13)
271
+ # --- Training ---
272
+ parser.add_argument("--input_dir", type=str)
273
+ parser.add_argument("--processor", type=str)
274
+ parser.add_argument("--batch_size", type=int, default=32)
275
+ parser.add_argument("--lr", type=float, default=3e-4)
276
+ parser.add_argument("--lr_patience", type=int, default=20)
277
+ parser.add_argument("--patience", type=int, default=10)
278
+ parser.add_argument("--preload", action="store_true")
279
+ parser.add_argument("--sample_rate", type=int, default=24000)
280
+ parser.add_argument("--shuffle", type=bool, default=True)
281
+ parser.add_argument("--train_length", type=int, default=65536)
282
+ parser.add_argument("--train_examples_per_epoch", type=int, default=10000)
283
+ parser.add_argument("--val_length", type=int, default=131072)
284
+ parser.add_argument("--val_examples_per_epoch", type=int, default=1000)
285
+ parser.add_argument("--num_workers", type=int, default=8)
286
+ parser.add_argument("--buffer_reload_rate", type=int, default=1000)
287
+ parser.add_argument("--buffer_size_gb", type=float, default=1.0)
288
+
289
+ return parser
deepafx_st/processors/proxy/tcn.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Christian J. Steinmetz
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # TCN implementation adapted from:
16
+ # https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/tcn.py
17
+
18
+ import torch
19
+ from argparse import ArgumentParser
20
+
21
+ from deepafx_st.utils import center_crop, causal_crop
22
+
23
+
24
+ class FiLM(torch.nn.Module):
25
+ def __init__(self, num_features, cond_dim):
26
+ super().__init__()
27
+ self.num_features = num_features
28
+ self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
29
+ self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)
30
+
31
+ def forward(self, x, cond):
32
+
33
+ # project conditioning to 2 x num. conv channels
34
+ cond = self.adaptor(cond)
35
+
36
+ # split the projection into gain and bias
37
+ g, b = torch.chunk(cond, 2, dim=-1)
38
+
39
+ # add virtual channel dim if needed
40
+ if g.ndim == 2:
41
+ g = g.unsqueeze(1)
42
+ b = b.unsqueeze(1)
43
+
44
+ # reshape for application
45
+ g = g.permute(0, 2, 1)
46
+ b = b.permute(0, 2, 1)
47
+
48
+ x = self.bn(x) # apply BatchNorm without affine
49
+ x = (x * g) + b # then apply conditional affine
50
+
51
+ return x
52
+
53
+
54
+ class ConditionalTCNBlock(torch.nn.Module):
55
+ def __init__(
56
+ self, in_ch, out_ch, cond_dim, kernel_size=3, dilation=1, causal=False, **kwargs
57
+ ):
58
+ super().__init__()
59
+
60
+ self.in_ch = in_ch
61
+ self.out_ch = out_ch
62
+ self.kernel_size = kernel_size
63
+ self.dilation = dilation
64
+ self.causal = causal
65
+
66
+ self.conv1 = torch.nn.Conv1d(
67
+ in_ch,
68
+ out_ch,
69
+ kernel_size=kernel_size,
70
+ padding=0,
71
+ dilation=dilation,
72
+ bias=True,
73
+ )
74
+ self.film = FiLM(out_ch, cond_dim)
75
+ self.relu = torch.nn.PReLU(out_ch)
76
+ self.res = torch.nn.Conv1d(
77
+ in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False
78
+ )
79
+
80
+ def forward(self, x, p):
81
+ x_in = x
82
+
83
+ x = self.conv1(x)
84
+ x = self.film(x, p) # apply FiLM conditioning
85
+ x = self.relu(x)
86
+ x_res = self.res(x_in)
87
+
88
+ if self.causal:
89
+ x = x + causal_crop(x_res, x.shape[-1])
90
+ else:
91
+ x = x + center_crop(x_res, x.shape[-1])
92
+
93
+ return x
94
+
95
+
96
+ class ConditionalTCN(torch.nn.Module):
97
+ """Temporal convolutional network with conditioning module.
98
+ Args:
99
+ sample_rate (float): Audio sample rate.
100
+ num_control_params (int, optional): Dimensionality of the conditioning signal. Default: 24
101
+ ninputs (int, optional): Number of input channels (mono = 1, stereo 2). Default: 1
102
+ noutputs (int, optional): Number of output channels (mono = 1, stereo 2). Default: 1
103
+ nblocks (int, optional): Number of total TCN blocks. Default: 10
104
+ kernel_size (int, optional: Width of the convolutional kernels. Default: 3
105
+ dialation_growth (int, optional): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
106
+ channel_growth (int, optional): Compute the output channels at each black as in_ch * channel_growth. Default: 2
107
+ channel_width (int, optional): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
108
+ stack_size (int, optional): Number of blocks that constitute a single stack of blocks. Default: 10
109
+ causal (bool, optional): Causal TCN configuration does not consider future input values. Default: False
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ sample_rate,
115
+ num_control_params=24,
116
+ ninputs=1,
117
+ noutputs=1,
118
+ nblocks=10,
119
+ kernel_size=15,
120
+ dilation_growth=2,
121
+ channel_growth=1,
122
+ channel_width=64,
123
+ stack_size=10,
124
+ causal=False,
125
+ skip_connections=False,
126
+ **kwargs,
127
+ ):
128
+ super().__init__()
129
+ self.num_control_params = num_control_params
130
+ self.ninputs = ninputs
131
+ self.noutputs = noutputs
132
+ self.nblocks = nblocks
133
+ self.kernel_size = kernel_size
134
+ self.dilation_growth = dilation_growth
135
+ self.channel_growth = channel_growth
136
+ self.channel_width = channel_width
137
+ self.stack_size = stack_size
138
+ self.causal = causal
139
+ self.skip_connections = skip_connections
140
+ self.sample_rate = sample_rate
141
+
142
+ self.blocks = torch.nn.ModuleList()
143
+ for n in range(nblocks):
144
+ in_ch = out_ch if n > 0 else ninputs
145
+
146
+ if self.channel_growth > 1:
147
+ out_ch = in_ch * self.channel_growth
148
+ else:
149
+ out_ch = self.channel_width
150
+
151
+ dilation = self.dilation_growth ** (n % self.stack_size)
152
+
153
+ self.blocks.append(
154
+ ConditionalTCNBlock(
155
+ in_ch,
156
+ out_ch,
157
+ self.num_control_params,
158
+ kernel_size=self.kernel_size,
159
+ dilation=dilation,
160
+ padding="same" if self.causal else "valid",
161
+ causal=self.causal,
162
+ )
163
+ )
164
+
165
+ self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
166
+ self.receptive_field = self.compute_receptive_field()
167
+ # print(
168
+ # f"TCN receptive field: {self.receptive_field} samples",
169
+ # f" or {(self.receptive_field/self.sample_rate)*1e3:0.3f} ms",
170
+ # )
171
+
172
+ def forward(self, x, p, **kwargs):
173
+
174
+ # causally pad input signal
175
+ x = torch.nn.functional.pad(x, (self.receptive_field - 1, 0))
176
+
177
+ # iterate over blocks passing conditioning
178
+ for idx, block in enumerate(self.blocks):
179
+ x = block(x, p)
180
+ if self.skip_connections:
181
+ if idx == 0:
182
+ skips = x
183
+ else:
184
+ skips = center_crop(skips, x[-1]) + x
185
+ else:
186
+ skips = 0
187
+
188
+ # final 1x1 convolution to collapse channels
189
+ out = self.output(x + skips)
190
+
191
+ return out
192
+
193
+ def compute_receptive_field(self):
194
+ """Compute the receptive field in samples."""
195
+ rf = self.kernel_size
196
+ for n in range(1, self.nblocks):
197
+ dilation = self.dilation_growth ** (n % self.stack_size)
198
+ rf = rf + ((self.kernel_size - 1) * dilation)
199
+ return rf
deepafx_st/processors/spsa/channel.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.multiprocessing as mp
4
+
5
+ from deepafx_st.processors.dsp.peq import ParametricEQ
6
+ from deepafx_st.processors.dsp.compressor import Compressor
7
+ from deepafx_st.processors.spsa.spsa_func import SPSAFunction
8
+ from deepafx_st.utils import rademacher
9
+
10
+
11
+ def dsp_func(x, p, dsp, sample_rate=24000):
12
+
13
+ (peq, comp), meta = dsp
14
+
15
+ p_peq = p[:meta]
16
+ p_comp = p[meta:]
17
+
18
+ y = peq(x, p_peq, sample_rate)
19
+ y = comp(y, p_comp, sample_rate)
20
+
21
+ return y
22
+
23
+
24
+ class SPSAChannel(torch.nn.Module):
25
+ """
26
+
27
+ Args:
28
+ sample_rate (float): Sample rate of the plugin instance
29
+ parallel (bool, optional): Use parallel workers for DSP.
30
+
31
+ By default, this utilizes parallelized instances of the plugin channel,
32
+ where the number of workers is equal to the batch size.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ sample_rate: int,
38
+ parallel: bool = False,
39
+ batch_size: int = 8,
40
+ ):
41
+ super().__init__()
42
+
43
+ self.batch_size = batch_size
44
+ self.parallel = parallel
45
+
46
+ if self.parallel:
47
+ self.apply_func = SPSAFunction.apply
48
+
49
+ procs = {}
50
+ for b in range(self.batch_size):
51
+
52
+ peq = ParametricEQ(sample_rate)
53
+ comp = Compressor(sample_rate)
54
+ dsp = ((peq, comp), peq.num_control_params)
55
+
56
+ parent_conn, child_conn = mp.Pipe()
57
+ p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp))
58
+ p.start()
59
+ procs[b] = [p, parent_conn, child_conn]
60
+ #print(b, p)
61
+
62
+ # Update stuff for external public members TODO: fix
63
+ self.ports = [peq.ports, comp.ports]
64
+ self.num_control_params = (
65
+ comp.num_control_params + peq.num_control_params
66
+ )
67
+
68
+ self.procs = procs
69
+ #print(self.procs)
70
+
71
+ else:
72
+ self.peq = ParametricEQ(sample_rate)
73
+ self.comp = Compressor(sample_rate)
74
+ self.apply_func = SPSAFunction.apply
75
+ self.ports = [self.peq.ports, self.comp.ports]
76
+ self.num_control_params = (
77
+ self.comp.num_control_params + self.peq.num_control_params
78
+ )
79
+ self.dsp = ((self.peq, self.comp), self.peq.num_control_params)
80
+
81
+ # add one param for wet/dry mix
82
+ # self.num_control_params += 1
83
+
84
+ def __del__(self):
85
+ if hasattr(self, "procs"):
86
+ for proc_idx, proc in self.procs.items():
87
+ #print(f"Closing {proc_idx}...")
88
+ proc[0].terminate()
89
+
90
+ def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs):
91
+ """
92
+ Args:
93
+ x (Tensor): Input signal with shape: [batch x channels x samples]
94
+ p (Tensor): Audio effect control parameters with shape: [batch x parameters]
95
+ epsilon (float, optional): Twiddle parameter range for SPSA gradient estimation.
96
+
97
+ Returns:
98
+ y (Tensor): Processed audio signal.
99
+
100
+ """
101
+ if self.parallel:
102
+ y = self.apply_func(x, p, None, epsilon, self, sample_rate)
103
+
104
+ else:
105
+ # this will process on CPU in NumPy
106
+ y = self.apply_func(x, p, None, epsilon, self, sample_rate)
107
+
108
+ return y.type_as(x)
109
+
110
+ @staticmethod
111
+ def static_backward(dsp, value):
112
+
113
+ (
114
+ batch_index,
115
+ x,
116
+ params,
117
+ needs_input_grad,
118
+ needs_param_grad,
119
+ grad_output,
120
+ epsilon,
121
+ ) = value
122
+
123
+ grads_input = None
124
+ grads_params = None
125
+ ps = params.shape[-1]
126
+ factors = [1.0]
127
+
128
+ # estimate gradient w.r.t input
129
+ if needs_input_grad:
130
+ delta_k = rademacher(x.shape).numpy()
131
+ J_plus = dsp_func(x + epsilon * delta_k, params, dsp)
132
+ J_minus = dsp_func(x - epsilon * delta_k, params, dsp)
133
+ grads_input = (J_plus - J_minus) / (2.0 * epsilon)
134
+
135
+ # estimate gradient w.r.t params
136
+ grads_params_runs = []
137
+ if needs_param_grad:
138
+ for factor in factors:
139
+ params_sublist = []
140
+ delta_k = rademacher(params.shape).numpy()
141
+
142
+ # compute output in two random directions of the parameter space
143
+ params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1)
144
+ J_plus = dsp_func(x, params_plus, dsp)
145
+
146
+ params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1)
147
+ J_minus = dsp_func(x, params_minus, dsp)
148
+ grad_param = J_plus - J_minus
149
+
150
+ # compute gradient for each parameter as a function of epsilon and random direction
151
+ for sub_p_idx in range(ps):
152
+ grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx])
153
+ params_sublist.append(np.sum(grad_output * grad_p))
154
+
155
+ grads_params = np.array(params_sublist)
156
+ grads_params_runs.append(grads_params)
157
+
158
+ # average gradients
159
+ grads_params = np.mean(grads_params_runs, axis=0)
160
+
161
+ return grads_input, grads_params
162
+
163
+ @staticmethod
164
+ def static_forward(dsp, value):
165
+ batch_index, x, p, sample_rate = value
166
+ y = dsp_func(x, p, dsp, sample_rate)
167
+ return y
168
+
169
+ @staticmethod
170
+ def worker_pipe(child_conn, dsp):
171
+
172
+ while True:
173
+ msg, value = child_conn.recv()
174
+ if msg == "forward":
175
+ child_conn.send(SPSAChannel.static_forward(dsp, value))
176
+ elif msg == "backward":
177
+ child_conn.send(SPSAChannel.static_backward(dsp, value))
178
+ elif msg == "shutdown":
179
+ break
deepafx_st/processors/spsa/eps_scheduler.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class EpsilonScheduler:
5
+ def __init__(
6
+ self,
7
+ epsilon: float = 0.001,
8
+ patience: int = 10,
9
+ factor: float = 0.5,
10
+ verbose: bool = False,
11
+ ):
12
+ self.epsilon = epsilon
13
+ self.patience = patience
14
+ self.factor = factor
15
+ self.best = 1e16
16
+ self.count = 0
17
+ self.verbose = verbose
18
+
19
+ def step(self, metric: float):
20
+
21
+ if metric < self.best:
22
+ self.best = metric
23
+ self.count = 0
24
+ else:
25
+ self.count += 1
26
+ if self.verbose:
27
+ print(f"Train loss has not improved for {self.count} epochs.")
28
+ if self.count >= self.patience:
29
+ self.count = 0
30
+ self.epsilon *= self.factor
31
+ if self.verbose:
32
+ print(f"Reducing epsilon to {self.epsilon:0.2e}...")
deepafx_st/processors/spsa/spsa_func.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def spsa_func(input, params, process, i, sample_rate=24000):
5
+ return process(input.cpu(), params.cpu(), i, sample_rate).type_as(input)
6
+
7
+
8
+ class SPSAFunction(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(
11
+ ctx,
12
+ input,
13
+ params,
14
+ process,
15
+ epsilon,
16
+ thread_context,
17
+ sample_rate=24000,
18
+ ):
19
+ """Apply processor to a batch of tensors using given parameters.
20
+
21
+ Args:
22
+ input (Tensor): Audio with shape: batch x 2 x samples
23
+ params (Tensor): Processor parameters with shape: batch x params
24
+ process (function): Function that will apply processing.
25
+ epsilon (float): Perturbation strength for SPSA computation.
26
+
27
+ Returns:
28
+ output (Tensor): Processed audio with same shape as input.
29
+ """
30
+ ctx.save_for_backward(input, params)
31
+ ctx.epsilon = epsilon
32
+ ctx.process = process
33
+ ctx.thread_context = thread_context
34
+
35
+ if thread_context.parallel:
36
+
37
+ for i in range(input.shape[0]):
38
+ msg = (
39
+ "forward",
40
+ (
41
+ i,
42
+ input[i].view(-1).detach().cpu().numpy(),
43
+ params[i].view(-1).detach().cpu().numpy(),
44
+ sample_rate,
45
+ ),
46
+ )
47
+ thread_context.procs[i][1].send(msg)
48
+
49
+ z = torch.empty_like(input)
50
+ for i in range(input.shape[0]):
51
+ z[i] = torch.from_numpy(thread_context.procs[i][1].recv())
52
+ else:
53
+ z = torch.empty_like(input)
54
+ for i in range(input.shape[0]):
55
+ value = (
56
+ i,
57
+ input[i].view(-1).detach().cpu().numpy(),
58
+ params[i].view(-1).detach().cpu().numpy(),
59
+ sample_rate,
60
+ )
61
+ z[i] = torch.from_numpy(
62
+ thread_context.static_forward(thread_context.dsp, value)
63
+ )
64
+
65
+ return z
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad_output):
69
+ """Estimate gradients using SPSA."""
70
+
71
+ input, params = ctx.saved_tensors
72
+ epsilon = ctx.epsilon
73
+ needs_input_grad = ctx.needs_input_grad[0]
74
+ needs_param_grad = ctx.needs_input_grad[1]
75
+ thread_context = ctx.thread_context
76
+
77
+ grads_input = None
78
+ grads_params = None
79
+
80
+ # Receive grads
81
+ if needs_input_grad:
82
+ grads_input = torch.empty_like(input)
83
+ if needs_param_grad:
84
+ grads_params = torch.empty_like(params)
85
+
86
+ if thread_context.parallel:
87
+
88
+ for i in range(input.shape[0]):
89
+ msg = (
90
+ "backward",
91
+ (
92
+ i,
93
+ input[i].view(-1).detach().cpu().numpy(),
94
+ params[i].view(-1).detach().cpu().numpy(),
95
+ needs_input_grad,
96
+ needs_param_grad,
97
+ grad_output[i].view(-1).detach().cpu().numpy(),
98
+ epsilon,
99
+ ),
100
+ )
101
+ thread_context.procs[i][1].send(msg)
102
+
103
+ # Wait for output
104
+ for i in range(input.shape[0]):
105
+ temp1, temp2 = thread_context.procs[i][1].recv()
106
+
107
+ if temp1 is not None:
108
+ grads_input[i] = torch.from_numpy(temp1)
109
+
110
+ if temp2 is not None:
111
+ grads_params[i] = torch.from_numpy(temp2)
112
+
113
+ return grads_input, grads_params, None, None, None, None
114
+ else:
115
+ for i in range(input.shape[0]):
116
+ value = (
117
+ i,
118
+ input[i].view(-1).detach().cpu().numpy(),
119
+ params[i].view(-1).detach().cpu().numpy(),
120
+ needs_input_grad,
121
+ needs_param_grad,
122
+ grad_output[i].view(-1).detach().cpu().numpy(),
123
+ epsilon,
124
+ )
125
+ temp1, temp2 = thread_context.static_backward(thread_context.dsp, value)
126
+ if temp1 is not None:
127
+ grads_input[i] = torch.from_numpy(temp1)
128
+
129
+ if temp2 is not None:
130
+ grads_params[i] = torch.from_numpy(temp2)
131
+ return grads_input, grads_params, None, None, None, None
deepafx_st/system.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import auraloss
3
+ import torchaudio
4
+ from itertools import chain
5
+ import pytorch_lightning as pl
6
+ from argparse import ArgumentParser
7
+ from typing import Tuple, List, Dict
8
+
9
+ import deepafx_st.utils as utils
10
+ from deepafx_st.utils import DSPMode
11
+ from deepafx_st.data.dataset import AudioDataset
12
+ from deepafx_st.models.encoder import SpectralEncoder
13
+ from deepafx_st.models.controller import StyleTransferController
14
+ from deepafx_st.processors.spsa.channel import SPSAChannel
15
+ from deepafx_st.processors.spsa.eps_scheduler import EpsilonScheduler
16
+ from deepafx_st.processors.proxy.channel import ProxyChannel
17
+ from deepafx_st.processors.autodiff.channel import AutodiffChannel
18
+
19
+
20
+ class System(pl.LightningModule):
21
+ def __init__(
22
+ self,
23
+ ext="wav",
24
+ dsp_sample_rate=24000,
25
+ **kwargs,
26
+ ):
27
+ super().__init__()
28
+ self.save_hyperparameters()
29
+
30
+ self.eps_scheduler = EpsilonScheduler(
31
+ self.hparams.spsa_epsilon,
32
+ self.hparams.spsa_patience,
33
+ self.hparams.spsa_factor,
34
+ self.hparams.spsa_verbose,
35
+ )
36
+
37
+ self.hparams.dsp_mode = DSPMode.NONE
38
+
39
+ # first construct the processor, since this will dictate encoder
40
+ if self.hparams.processor_model == "spsa":
41
+ self.processor = SPSAChannel(
42
+ self.hparams.dsp_sample_rate,
43
+ self.hparams.spsa_parallel,
44
+ self.hparams.batch_size,
45
+ )
46
+ elif self.hparams.processor_model == "autodiff":
47
+ self.processor = AutodiffChannel(self.hparams.dsp_sample_rate)
48
+ elif self.hparams.processor_model == "proxy0":
49
+ # print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts)
50
+ self.hparams.dsp_mode = DSPMode.NONE
51
+ self.processor = ProxyChannel(
52
+ self.hparams.proxy_ckpts,
53
+ self.hparams.freeze_proxies,
54
+ self.hparams.dsp_mode,
55
+ sample_rate=self.hparams.dsp_sample_rate,
56
+ )
57
+ elif self.hparams.processor_model == "proxy1":
58
+ # print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts)
59
+ self.hparams.dsp_mode = DSPMode.INFER
60
+ self.processor = ProxyChannel(
61
+ self.hparams.proxy_ckpts,
62
+ self.hparams.freeze_proxies,
63
+ self.hparams.dsp_mode,
64
+ sample_rate=self.hparams.dsp_sample_rate,
65
+ )
66
+ elif self.hparams.processor_model == "proxy2":
67
+ # print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts)
68
+ self.hparams.dsp_mode = DSPMode.TRAIN_INFER
69
+ self.processor = ProxyChannel(
70
+ self.hparams.proxy_ckpts,
71
+ self.hparams.freeze_proxies,
72
+ self.hparams.dsp_mode,
73
+ sample_rate=self.hparams.dsp_sample_rate,
74
+ )
75
+ elif self.hparams.processor_model == "tcn1":
76
+ # self.processor = ConditionalTCN(self.hparams.sample_rate)
77
+ self.hparams.dsp_mode = DSPMode.NONE
78
+ self.processor = ProxyChannel(
79
+ [],
80
+ freeze_proxies=False,
81
+ dsp_mode=self.hparams.dsp_mode,
82
+ tcn_nblocks=self.hparams.tcn_nblocks,
83
+ tcn_dilation_growth=self.hparams.tcn_dilation_growth,
84
+ tcn_channel_width=self.hparams.tcn_channel_width,
85
+ tcn_kernel_size=self.hparams.tcn_kernel_size,
86
+ num_tcns=1,
87
+ sample_rate=self.hparams.sample_rate,
88
+ )
89
+ elif self.hparams.processor_model == "tcn2":
90
+ self.hparams.dsp_mode = DSPMode.NONE
91
+ self.processor = ProxyChannel(
92
+ [],
93
+ freeze_proxies=False,
94
+ dsp_mode=self.hparams.dsp_mode,
95
+ tcn_nblocks=self.hparams.tcn_nblocks,
96
+ tcn_dilation_growth=self.hparams.tcn_dilation_growth,
97
+ tcn_channel_width=self.hparams.tcn_channel_width,
98
+ tcn_kernel_size=self.hparams.tcn_kernel_size,
99
+ num_tcns=2,
100
+ sample_rate=self.hparams.sample_rate,
101
+ )
102
+ else:
103
+ raise ValueError(f"Invalid processor_model: {self.hparams.processor_model}")
104
+
105
+ if self.hparams.encoder_ckpt is not None:
106
+ # load encoder weights from a pre-trained system
107
+ system = System.load_from_checkpoint(self.hparams.encoder_ckpt)
108
+ self.encoder = system.encoder
109
+ self.hparams.encoder_embed_dim = system.encoder.embed_dim
110
+ else:
111
+ self.encoder = SpectralEncoder(
112
+ self.processor.num_control_params,
113
+ self.hparams.sample_rate,
114
+ encoder_model=self.hparams.encoder_model,
115
+ embed_dim=self.hparams.encoder_embed_dim,
116
+ width_mult=self.hparams.encoder_width_mult,
117
+ )
118
+
119
+ if self.hparams.encoder_freeze:
120
+ for param in self.encoder.parameters():
121
+ param.requires_grad = False
122
+
123
+ self.controller = StyleTransferController(
124
+ self.processor.num_control_params,
125
+ self.hparams.encoder_embed_dim,
126
+ )
127
+
128
+ if len(self.hparams.recon_losses) != len(self.hparams.recon_loss_weights):
129
+ raise ValueError("Must supply same number of weights as losses.")
130
+
131
+ self.recon_losses = torch.nn.ModuleDict()
132
+ for recon_loss in self.hparams.recon_losses:
133
+ if recon_loss == "mrstft":
134
+ self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss(
135
+ fft_sizes=[32, 128, 512, 2048, 8192, 32768],
136
+ hop_sizes=[16, 64, 256, 1024, 4096, 16384],
137
+ win_lengths=[32, 128, 512, 2048, 8192, 32768],
138
+ w_sc=0.0,
139
+ w_phs=0.0,
140
+ w_lin_mag=1.0,
141
+ w_log_mag=1.0,
142
+ )
143
+ elif recon_loss == "mrstft-md":
144
+ self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss(
145
+ fft_sizes=[128, 512, 2048, 8192],
146
+ hop_sizes=[32, 128, 512, 2048], # 1 / 4
147
+ win_lengths=[128, 512, 2048, 8192],
148
+ w_sc=0.0,
149
+ w_phs=0.0,
150
+ w_lin_mag=1.0,
151
+ w_log_mag=1.0,
152
+ )
153
+ elif recon_loss == "mrstft-sm":
154
+ self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss(
155
+ fft_sizes=[512, 2048, 8192],
156
+ hop_sizes=[256, 1024, 4096], # 1 / 4
157
+ win_lengths=[512, 2048, 8192],
158
+ w_sc=0.0,
159
+ w_phs=0.0,
160
+ w_lin_mag=1.0,
161
+ w_log_mag=1.0,
162
+ )
163
+ elif recon_loss == "melfft":
164
+ self.recon_losses[recon_loss] = auraloss.freq.MelSTFTLoss(
165
+ self.hparams.sample_rate,
166
+ fft_size=self.hparams.train_length,
167
+ hop_size=self.hparams.train_length // 2,
168
+ win_length=self.hparams.train_length,
169
+ n_mels=128,
170
+ w_sc=0.0,
171
+ device="cuda" if self.hparams.gpus > 0 else "cpu",
172
+ )
173
+ elif recon_loss == "melstft":
174
+ self.recon_losses[recon_loss] = auraloss.freq.MelSTFTLoss(
175
+ self.hparams.sample_rate,
176
+ device="cuda" if self.hparams.gpus > 0 else "cpu",
177
+ )
178
+ elif recon_loss == "l1":
179
+ self.recon_losses[recon_loss] = torch.nn.L1Loss()
180
+ elif recon_loss == "sisdr":
181
+ self.recon_losses[recon_loss] = auraloss.time.SISDRLoss()
182
+ else:
183
+ raise ValueError(
184
+ f"Invalid reconstruction loss: {self.hparams.recon_losses}"
185
+ )
186
+
187
+ def forward(
188
+ self,
189
+ x: torch.Tensor,
190
+ y: torch.Tensor = None,
191
+ e_y: torch.Tensor = None,
192
+ z: torch.Tensor = None,
193
+ dsp_mode: DSPMode = DSPMode.NONE,
194
+ analysis_length: int = 0,
195
+ sample_rate: int = 24000,
196
+ ):
197
+ """Forward pass through the system subnetworks.
198
+
199
+ Args:
200
+ x (tensor): Input audio tensor with shape (batch x 1 x samples)
201
+ y (tensor): Target audio tensor with shape (batch x 1 x samples)
202
+ e_y (tensor): Target embedding with shape (batch x edim)
203
+ z (tensor): Bottleneck latent.
204
+ dsp_mode (DSPMode): Mode of operation for the DSP blocks.
205
+ analysis_length (optional, int): Only analyze the first N samples.
206
+ sample_rate (optional, int): Desired sampling rate for the DSP blocks.
207
+
208
+ You must supply target audio `y`, `z`, or an embedding for the target `e_y`.
209
+
210
+ Returns:
211
+ y_hat (tensor): Output audio.
212
+ p (tensor):
213
+ e (tensor):
214
+
215
+ """
216
+ bs, chs, samp = x.size()
217
+
218
+ if sample_rate != self.hparams.sample_rate:
219
+ x_enc = torchaudio.transforms.Resample(
220
+ sample_rate, self.hparams.sample_rate
221
+ ).to(x.device)(x)
222
+ if y is not None:
223
+ y_enc = torchaudio.transforms.Resample(
224
+ sample_rate, self.hparams.sample_rate
225
+ ).to(x.device)(y)
226
+ else:
227
+ x_enc = x
228
+ y_enc = y
229
+
230
+ if analysis_length > 0:
231
+ x_enc = x_enc[..., :analysis_length]
232
+ if y is not None:
233
+ y_enc = y_enc[..., :analysis_length]
234
+
235
+ e_x = self.encoder(x_enc) # generate latent embedding for input
236
+
237
+ if y is not None:
238
+ e_y = self.encoder(y_enc) # generate latent embedding for target
239
+ elif e_y is None:
240
+ raise RuntimeError("Must supply y, z, or e_y. None supplied.")
241
+
242
+ # learnable comparision
243
+ p = self.controller(e_x, e_y, z=z)
244
+
245
+ # process audio conditioned on parameters
246
+ # if there are multiple channels process them using same parameters
247
+ y_hat = torch.zeros(x.shape).type_as(x)
248
+ for ch_idx in range(chs):
249
+ y_hat_ch = self.processor(
250
+ x[:, ch_idx : ch_idx + 1, :],
251
+ p,
252
+ epsilon=self.eps_scheduler.epsilon,
253
+ dsp_mode=dsp_mode,
254
+ sample_rate=sample_rate,
255
+ )
256
+ y_hat[:, ch_idx : ch_idx + 1, :] = y_hat_ch
257
+
258
+ return y_hat, p, e_x
259
+
260
+ def common_paired_step(
261
+ self,
262
+ batch: Tuple,
263
+ batch_idx: int,
264
+ optimizer_idx: int = 0,
265
+ train: bool = False,
266
+ ):
267
+ """Model step used for validation and training.
268
+
269
+ Args:
270
+ batch (Tuple[Tensor, Tensor]): Batch items containing input audio (x) and target audio (y).
271
+ batch_idx (int): Index of the batch within the current epoch.
272
+ optimizer_idx (int): Index of the optimizer, this step is called once for each optimizer.
273
+ The firs optimizer corresponds to the generator and the second optimizer,
274
+ corresponds to the adversarial loss (when in use).
275
+ train (bool): Whether step is called during training (True) or validation (False).
276
+ """
277
+ x, y = batch
278
+ loss = 0
279
+ dsp_mode = self.hparams.dsp_mode
280
+
281
+ if train and dsp_mode.INFER.name == DSPMode.INFER.name:
282
+ dsp_mode = DSPMode.NONE
283
+
284
+ # proces input audio through model
285
+ if self.hparams.style_transfer:
286
+ length = x.shape[-1]
287
+
288
+ x_A = x[..., : length // 2]
289
+ x_B = x[..., length // 2 :]
290
+
291
+ y_A = y[..., : length // 2]
292
+ y_B = y[..., length // 2 :]
293
+
294
+ if torch.rand(1).sum() > 0.5:
295
+ y_ref = y_B
296
+ y = y_A
297
+ x = x_A
298
+ else:
299
+ y_ref = y_A
300
+ y = y_B
301
+ x = x_B
302
+
303
+ y_hat, p, e = self(x, y=y_ref, dsp_mode=dsp_mode)
304
+ else:
305
+ y_ref = None
306
+ y_hat, p, e = self(x, dsp_mode=dsp_mode)
307
+
308
+ # compute reconstruction loss terms
309
+ for loss_idx, (loss_name, recon_loss_fn) in enumerate(
310
+ self.recon_losses.items()
311
+ ):
312
+ temp_loss = recon_loss_fn(y_hat, y) # reconstruction loss
313
+ loss += float(self.hparams.recon_loss_weights[loss_idx]) * temp_loss
314
+
315
+ self.log(
316
+ ("train" if train else "val") + f"_loss/{loss_name}",
317
+ temp_loss,
318
+ on_step=True,
319
+ on_epoch=True,
320
+ prog_bar=False,
321
+ logger=True,
322
+ sync_dist=True,
323
+ )
324
+
325
+ # log the overall aggregate loss
326
+ self.log(
327
+ ("train" if train else "val") + "_loss/loss",
328
+ loss,
329
+ on_step=True,
330
+ on_epoch=True,
331
+ prog_bar=False,
332
+ logger=True,
333
+ sync_dist=True,
334
+ )
335
+
336
+ # store audio data
337
+ data_dict = {
338
+ "x": x.cpu(),
339
+ "y": y.cpu(),
340
+ "p": p.cpu(),
341
+ "e": e.cpu(),
342
+ "y_hat": y_hat.cpu(),
343
+ }
344
+
345
+ if y_ref is not None:
346
+ data_dict["y_ref"] = y_ref.cpu()
347
+
348
+ return loss, data_dict
349
+
350
+ def training_step(self, batch, batch_idx, optimizer_idx=0):
351
+ loss, _ = self.common_paired_step(
352
+ batch,
353
+ batch_idx,
354
+ optimizer_idx,
355
+ train=True,
356
+ )
357
+
358
+ return loss
359
+
360
+ def training_epoch_end(self, training_step_outputs):
361
+ if self.hparams.spsa_schedule and self.hparams.processor_model == "spsa":
362
+ self.eps_scheduler.step(
363
+ self.trainer.callback_metrics[self.hparams.train_monitor],
364
+ )
365
+
366
+ def validation_step(self, batch, batch_idx):
367
+ loss, data_dict = self.common_paired_step(batch, batch_idx)
368
+
369
+ return data_dict
370
+
371
+ def optimizer_step(
372
+ self,
373
+ epoch,
374
+ batch_idx,
375
+ optimizer,
376
+ optimizer_idx,
377
+ optimizer_closure,
378
+ on_tpu=False,
379
+ using_native_amp=False,
380
+ using_lbfgs=False,
381
+ ):
382
+ if optimizer_idx == 0:
383
+ optimizer.step(closure=optimizer_closure)
384
+
385
+ def configure_optimizers(self):
386
+ # we need additional optimizer for the discriminator
387
+ optimizers = []
388
+ g_optimizer = torch.optim.Adam(
389
+ chain(
390
+ self.encoder.parameters(),
391
+ self.processor.parameters(),
392
+ self.controller.parameters(),
393
+ ),
394
+ lr=self.hparams.lr,
395
+ betas=(0.9, 0.999),
396
+ )
397
+ optimizers.append(g_optimizer)
398
+
399
+ g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
400
+ g_optimizer,
401
+ patience=self.hparams.lr_patience,
402
+ verbose=True,
403
+ )
404
+ ms1 = int(self.hparams.max_epochs * 0.8)
405
+ ms2 = int(self.hparams.max_epochs * 0.95)
406
+ print(
407
+ "Learning rate schedule:",
408
+ f"0 {self.hparams.lr:0.2e} -> ",
409
+ f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
410
+ f"{ms2} {self.hparams.lr*0.01:0.2e}",
411
+ )
412
+ g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
413
+ g_optimizer,
414
+ milestones=[ms1, ms2],
415
+ gamma=0.1,
416
+ )
417
+
418
+ lr_schedulers = {
419
+ "scheduler": g_scheduler,
420
+ }
421
+
422
+ return optimizers, lr_schedulers
423
+
424
+ def train_dataloader(self):
425
+
426
+ train_dataset = AudioDataset(
427
+ self.hparams.audio_dir,
428
+ subset="train",
429
+ train_frac=self.hparams.train_frac,
430
+ half=self.hparams.half,
431
+ length=self.hparams.train_length,
432
+ input_dirs=self.hparams.input_dirs,
433
+ random_scale_input=self.hparams.random_scale_input,
434
+ random_scale_target=self.hparams.random_scale_target,
435
+ buffer_size_gb=self.hparams.buffer_size_gb,
436
+ buffer_reload_rate=self.hparams.buffer_reload_rate,
437
+ num_examples_per_epoch=self.hparams.train_examples_per_epoch,
438
+ augmentations={
439
+ "pitch": {"sr": self.hparams.sample_rate},
440
+ "tempo": {"sr": self.hparams.sample_rate},
441
+ },
442
+ freq_corrupt=self.hparams.freq_corrupt,
443
+ drc_corrupt=self.hparams.drc_corrupt,
444
+ ext=self.hparams.ext,
445
+ )
446
+
447
+ g = torch.Generator()
448
+ g.manual_seed(0)
449
+
450
+ return torch.utils.data.DataLoader(
451
+ train_dataset,
452
+ num_workers=self.hparams.num_workers,
453
+ batch_size=self.hparams.batch_size,
454
+ worker_init_fn=utils.seed_worker,
455
+ generator=g,
456
+ pin_memory=True,
457
+ persistent_workers=True,
458
+ timeout=60,
459
+ )
460
+
461
+ def val_dataloader(self):
462
+
463
+ val_dataset = AudioDataset(
464
+ self.hparams.audio_dir,
465
+ subset="val",
466
+ half=self.hparams.half,
467
+ train_frac=self.hparams.train_frac,
468
+ length=self.hparams.val_length,
469
+ input_dirs=self.hparams.input_dirs,
470
+ buffer_size_gb=self.hparams.buffer_size_gb,
471
+ buffer_reload_rate=self.hparams.buffer_reload_rate,
472
+ random_scale_input=self.hparams.random_scale_input,
473
+ random_scale_target=self.hparams.random_scale_target,
474
+ num_examples_per_epoch=self.hparams.val_examples_per_epoch,
475
+ augmentations={},
476
+ freq_corrupt=self.hparams.freq_corrupt,
477
+ drc_corrupt=self.hparams.drc_corrupt,
478
+ ext=self.hparams.ext,
479
+ )
480
+
481
+ self.val_dataset = val_dataset
482
+
483
+ g = torch.Generator()
484
+ g.manual_seed(0)
485
+
486
+ return torch.utils.data.DataLoader(
487
+ val_dataset,
488
+ num_workers=1,
489
+ batch_size=self.hparams.batch_size,
490
+ worker_init_fn=utils.seed_worker,
491
+ generator=g,
492
+ pin_memory=True,
493
+ persistent_workers=True,
494
+ timeout=60,
495
+ )
496
+ def shutdown(self):
497
+ del self.processor
498
+
499
+ # add any model hyperparameters here
500
+ @staticmethod
501
+ def add_model_specific_args(parent_parser):
502
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
503
+ # --- Training ---
504
+ parser.add_argument("--batch_size", type=int, default=32)
505
+ parser.add_argument("--lr", type=float, default=3e-4)
506
+ parser.add_argument("--lr_patience", type=int, default=20)
507
+ parser.add_argument("--recon_losses", nargs="+", default=["l1"])
508
+ parser.add_argument("--recon_loss_weights", nargs="+", default=[1.0])
509
+ # --- Controller ---
510
+ parser.add_argument(
511
+ "--processor_model",
512
+ type=str,
513
+ help="autodiff, spsa, tcn1, tcn2, proxy0, proxy1, proxy2",
514
+ )
515
+ parser.add_argument("--controller_hidden_dim", type=int, default=256)
516
+ parser.add_argument("--style_transfer", action="store_true")
517
+ # --- Encoder ---
518
+ parser.add_argument("--encoder_model", type=str, default="mobilenet_v2")
519
+ parser.add_argument("--encoder_embed_dim", type=int, default=128)
520
+ parser.add_argument("--encoder_width_mult", type=int, default=2)
521
+ parser.add_argument("--encoder_ckpt", type=str, default=None)
522
+ parser.add_argument("--encoder_freeze", action="store_true", default=False)
523
+ # --- TCN ---
524
+ parser.add_argument("--tcn_causal", action="store_true")
525
+ parser.add_argument("--tcn_nblocks", type=int, default=4)
526
+ parser.add_argument("--tcn_dilation_growth", type=int, default=8)
527
+ parser.add_argument("--tcn_channel_width", type=int, default=32)
528
+ parser.add_argument("--tcn_kernel_size", type=int, default=13)
529
+ # --- SPSA ---
530
+ parser.add_argument("--plugin_config_file", type=str, default=None)
531
+ parser.add_argument("--spsa_epsilon", type=float, default=0.001)
532
+ parser.add_argument("--spsa_schedule", action="store_true")
533
+ parser.add_argument("--spsa_patience", type=int, default=10)
534
+ parser.add_argument("--spsa_verbose", action="store_true")
535
+ parser.add_argument("--spsa_factor", type=float, default=0.5)
536
+ parser.add_argument("--spsa_parallel", action="store_true")
537
+ # --- Proxy ----
538
+ parser.add_argument("--proxy_ckpts", nargs="+")
539
+ parser.add_argument("--freeze_proxies", action="store_true", default=False)
540
+ parser.add_argument("--use_dsp", action="store_true", default=False)
541
+ parser.add_argument("--dsp_mode", choices=DSPMode, type=DSPMode)
542
+ # --- Dataset ---
543
+ parser.add_argument("--audio_dir", type=str)
544
+ parser.add_argument("--ext", type=str, default="wav")
545
+ parser.add_argument("--input_dirs", nargs="+")
546
+ parser.add_argument("--buffer_reload_rate", type=int, default=1000)
547
+ parser.add_argument("--buffer_size_gb", type=float, default=1.0)
548
+ parser.add_argument("--sample_rate", type=int, default=24000)
549
+ parser.add_argument("--dsp_sample_rate", type=int, default=24000)
550
+ parser.add_argument("--shuffle", type=bool, default=True)
551
+ parser.add_argument("--random_scale_input", action="store_true")
552
+ parser.add_argument("--random_scale_target", action="store_true")
553
+ parser.add_argument("--freq_corrupt", action="store_true")
554
+ parser.add_argument("--drc_corrupt", action="store_true")
555
+ parser.add_argument("--train_length", type=int, default=65536)
556
+ parser.add_argument("--train_frac", type=float, default=0.8)
557
+ parser.add_argument("--half", action="store_true")
558
+ parser.add_argument("--train_examples_per_epoch", type=int, default=10000)
559
+ parser.add_argument("--val_length", type=int, default=131072)
560
+ parser.add_argument("--val_examples_per_epoch", type=int, default=1000)
561
+ parser.add_argument("--num_workers", type=int, default=16)
562
+
563
+ return parser
deepafx_st/utils.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from:
2
+ # https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/utils.py
3
+ import os
4
+ import csv
5
+ import torch
6
+ import fnmatch
7
+ import numpy as np
8
+ import random
9
+ from enum import Enum
10
+ import pyloudnorm as pyln
11
+
12
+
13
+ class DSPMode(Enum):
14
+ NONE = "none"
15
+ TRAIN_INFER = "train_infer"
16
+ INFER = "infer"
17
+
18
+ def __str__(self):
19
+ return self.value
20
+
21
+
22
+ def loudness_normalize(x, sample_rate, target_loudness=-24.0):
23
+ x = x.view(1, -1)
24
+ stereo_audio = x.repeat(2, 1).permute(1, 0).numpy()
25
+ meter = pyln.Meter(sample_rate)
26
+ loudness = meter.integrated_loudness(stereo_audio)
27
+ norm_x = pyln.normalize.loudness(
28
+ stereo_audio,
29
+ loudness,
30
+ target_loudness,
31
+ )
32
+ x = torch.tensor(norm_x).permute(1, 0)
33
+ x = x[0, :].view(1, -1)
34
+
35
+ return x
36
+
37
+
38
+ def get_random_file_id(keys):
39
+ # generate a random index into the keys of the input files
40
+ rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0]
41
+ # find the key (file_id) correponding to the random index
42
+ rand_input_file_id = list(keys)[rand_input_idx]
43
+
44
+ return rand_input_file_id
45
+
46
+
47
+ def get_random_patch(audio_file, length, check_silence=True):
48
+ silent = True
49
+ while silent:
50
+ start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
51
+ stop_idx = start_idx + length
52
+ patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()
53
+ if (patch ** 2).mean() > 1e-4 or not check_silence:
54
+ silent = False
55
+
56
+ return start_idx, stop_idx
57
+
58
+
59
+ def seed_worker(worker_id):
60
+ worker_seed = torch.initial_seed() % 2 ** 32
61
+ np.random.seed(worker_seed)
62
+ random.seed(worker_seed)
63
+
64
+
65
+ def getFilesPath(directory, extension):
66
+
67
+ n_path = []
68
+ for path, subdirs, files in os.walk(directory):
69
+ for name in files:
70
+ if fnmatch.fnmatch(name, extension):
71
+ n_path.append(os.path.join(path, name))
72
+ n_path.sort()
73
+
74
+ return n_path
75
+
76
+
77
+ def count_parameters(model, trainable_only=True):
78
+
79
+ if trainable_only:
80
+ if len(list(model.parameters())) > 0:
81
+ params = sum(p.numel() for p in model.parameters() if p.requires_grad)
82
+ else:
83
+ params = 0
84
+ else:
85
+ if len(list(model.parameters())) > 0:
86
+ params = sum(p.numel() for p in model.parameters())
87
+ else:
88
+ params = 0
89
+
90
+ return params
91
+
92
+
93
+ def system_summary(system):
94
+ print(f"Encoder: {count_parameters(system.encoder)/1e6:0.2f} M")
95
+ print(f"Processor: {count_parameters(system.processor)/1e6:0.2f} M")
96
+
97
+ if hasattr(system, "adv_loss_fn"):
98
+ for idx, disc in enumerate(system.adv_loss_fn.discriminators):
99
+ print(f"Discriminator {idx+1}: {count_parameters(disc)/1e6:0.2f} M")
100
+
101
+
102
+ def center_crop(x, length: int):
103
+ if x.shape[-1] != length:
104
+ start = (x.shape[-1] - length) // 2
105
+ stop = start + length
106
+ x = x[..., start:stop]
107
+ return x
108
+
109
+
110
+ def causal_crop(x, length: int):
111
+ if x.shape[-1] != length:
112
+ stop = x.shape[-1] - 1
113
+ start = stop - length
114
+ x = x[..., start:stop]
115
+ return x
116
+
117
+
118
+ def denormalize(norm_val, max_val, min_val):
119
+ return (norm_val * (max_val - min_val)) + min_val
120
+
121
+
122
+ def normalize(denorm_val, max_val, min_val):
123
+ return (denorm_val - min_val) / (max_val - min_val)
124
+
125
+
126
+ def get_random_patch(audio_file, length, energy_treshold=1e-4):
127
+ """Produce sample indicies for a random patch of size `length`.
128
+
129
+ This function will check the energy of the selected patch to
130
+ ensure that it is not complete silence. If silence is found,
131
+ it will continue searching for a non-silent patch.
132
+
133
+ Args:
134
+ audio_file (AudioFile): Audio file object.
135
+ length (int): Number of samples in random patch.
136
+
137
+ Returns:
138
+ start_idx (int): Starting sample index
139
+ stop_idx (int): Stop sample index
140
+ """
141
+
142
+ silent = True
143
+ while silent:
144
+ start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
145
+ stop_idx = start_idx + length
146
+ patch = audio_file.audio[:, start_idx:stop_idx]
147
+ if (patch ** 2).mean() > energy_treshold:
148
+ silent = False
149
+
150
+ return start_idx, stop_idx
151
+
152
+
153
+ def split_dataset(file_list, subset, train_frac):
154
+ """Given a list of files, split into train/val/test sets.
155
+
156
+ Args:
157
+ file_list (list): List of audio files.
158
+ subset (str): One of "train", "val", or "test".
159
+ train_frac (float): Fraction of the dataset to use for training.
160
+
161
+ Returns:
162
+ file_list (list): List of audio files corresponding to subset.
163
+ """
164
+ assert train_frac > 0.1 and train_frac < 1.0
165
+
166
+ total_num_examples = len(file_list)
167
+
168
+ train_num_examples = int(total_num_examples * train_frac)
169
+ val_num_examples = int(total_num_examples * (1 - train_frac) / 2)
170
+ test_num_examples = total_num_examples - (train_num_examples + val_num_examples)
171
+
172
+ if train_num_examples < 0:
173
+ raise ValueError(
174
+ f"No examples in training set. Try increasing train_frac: {train_frac}."
175
+ )
176
+ elif val_num_examples < 0:
177
+ raise ValueError(
178
+ f"No examples in validation set. Try decreasing train_frac: {train_frac}."
179
+ )
180
+ elif test_num_examples < 0:
181
+ raise ValueError(
182
+ f"No examples in test set. Try decreasing train_frac: {train_frac}."
183
+ )
184
+
185
+ if subset == "train":
186
+ start_idx = 0
187
+ stop_idx = train_num_examples
188
+ elif subset == "val":
189
+ start_idx = train_num_examples
190
+ stop_idx = start_idx + val_num_examples
191
+ elif subset == "test":
192
+ start_idx = train_num_examples + val_num_examples
193
+ stop_idx = start_idx + test_num_examples + 1
194
+ else:
195
+ raise ValueError("Invalid subset: {subset}.")
196
+
197
+ return file_list[start_idx:stop_idx]
198
+
199
+
200
+ def rademacher(size):
201
+ """Generates random samples from a Rademacher distribution +-1
202
+
203
+ Args:
204
+ size (int):
205
+
206
+ """
207
+ m = torch.distributions.binomial.Binomial(1, 0.5)
208
+ x = m.sample(size)
209
+ x[x == 0] = -1
210
+ return x
211
+
212
+
213
+ def get_subset(csv_file):
214
+ subset_files = []
215
+ with open(csv_file) as fp:
216
+ reader = csv.DictReader(fp)
217
+ for row in reader:
218
+ subset_files.append(row["filepath"])
219
+
220
+ return list(set(subset_files))
221
+
222
+
223
+ def conform_length(x: torch.Tensor, length: int):
224
+ """Crop or pad input on last dim to match `length`."""
225
+ if x.shape[-1] < length:
226
+ padsize = length - x.shape[-1]
227
+ x = torch.nn.functional.pad(x, (0, padsize))
228
+ elif x.shape[-1] > length:
229
+ x = x[..., :length]
230
+
231
+ return x
232
+
233
+
234
+ def linear_fade(
235
+ x: torch.Tensor,
236
+ fade_ms: float = 50.0,
237
+ sample_rate: float = 22050,
238
+ ):
239
+ """Apply fade in and fade out to last dim."""
240
+ fade_samples = int(fade_ms * 1e-3 * 22050)
241
+
242
+ fade_in = torch.linspace(0.0, 1.0, steps=fade_samples)
243
+ fade_out = torch.linspace(1.0, 0.0, steps=fade_samples)
244
+
245
+ # fade in
246
+ x[..., :fade_samples] *= fade_in
247
+
248
+ # fade out
249
+ x[..., -fade_samples:] *= fade_out
250
+
251
+ return x
252
+
253
+
254
+ # def get_random_patch(x, sample_rate, length_samples):
255
+ # length = length_samples
256
+ # silent = True
257
+ # while silent:
258
+ # start_idx = np.random.randint(0, x.shape[-1] - length - 1)
259
+ # stop_idx = start_idx + length
260
+ # x_crop = x[0:1, start_idx:stop_idx]
261
+
262
+ # # check for silence
263
+ # frames = length // sample_rate
264
+ # silent_frames = []
265
+ # for n in range(frames):
266
+ # start_idx = n * sample_rate
267
+ # stop_idx = start_idx + sample_rate
268
+ # x_frame = x_crop[0:1, start_idx:stop_idx]
269
+ # if (x_frame ** 2).mean() > 3e-4:
270
+ # silent_frames.append(False)
271
+ # else:
272
+ # silent_frames.append(True)
273
+ # silent = True if any(silent_frames) else False
274
+
275
+ # x_crop /= x_crop.abs().max()
276
+
277
+ # return x_crop
deepafx_st/version.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # !/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ '''Version info'''
4
+
5
+ short_version = '0.0'
6
+ version = '0.0.1'
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ libsndfile1
2
+ sox
3
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git+https://github.com/adobe-research/DeepAFx-ST.git
2
+ gradio
3
+ huggingface_hub