Initial Commit
- .gitignore +124 -0
- README copy.md +80 -0
- app.py +145 -0
- requirements.txt +24 -0
- sonics/__init__.py +5 -0
- sonics/layers/__init__.py +6 -0
- sonics/layers/augment.py +244 -0
- sonics/layers/embedding.py +33 -0
- sonics/layers/feature.py +146 -0
- sonics/layers/tokenizer.py +117 -0
- sonics/layers/transformer.py +176 -0
- sonics/models/__init__.py +3 -0
- sonics/models/hf_model.py +108 -0
- sonics/models/model.py +128 -0
- sonics/models/spectttra.py +85 -0
- sonics/models/vit.py +101 -0
- sonics/utils/config.py +24 -0
- sonics/utils/dataset.py +137 -0
- sonics/utils/losses.py +65 -0
- sonics/utils/metrics.py +149 -0
- sonics/utils/perf.py +107 -0
- sonics/utils/scheduler.py +95 -0
- sonics/utils/seed.py +22 -0
.gitignore
ADDED
@@ -0,0 +1,124 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
README copy.md
ADDED
@@ -0,0 +1,80 @@
# SONICS: Synthetic Or Not - Identifying Counterfeit Songs

This repository contains the official source code for our paper **SONICS: Synthetic Or Not - Identifying Counterfeit Songs**.

## System Configuration

- Disk Space: 150GB
- GPU Memory: 48GB
- RAM: 32GB
- Python Version: 3.10
- OS: Ubuntu 20.04
- CUDA Version: 12.4

## Installation

```shell
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

## Dataset

[As a part of our submission, we are not providing our dataset. It will be published after the final decision.]

After downloading the dataset, the folder structure should look like the following:

```
parentFolder
│
├──sonics
│
├──dataset
│   ├──real_songs
│   │   └──xxx.mp3
│   ├──fake_songs
│   │   └──yyy.mp3
│   ├──real_songs.csv
│   └──fake_songs.csv
```

After downloading the dataset, split it into train, val, and test sets by running the following from the parent folder:

```shell
python data_split.py
```

> **Note:** `real_songs.csv` and `fake_songs.csv` contain the song metadata (filepath, duration, split, etc.), and the config file contains the path to this metadata.

> **Note:** Output files, including checkpoints and model predictions, will be saved in the `./output/<experiment_name>/` folder.

## Training

Choose any config from the `config` folder and run the following:

```shell
python train.py --config <path to the config file>
```

## Testing

Choose any config from the `config` folder and run the following:

```shell
python test.py --config <path to the config file> --ckpt_path <path to the checkpoint file>
```

## Model Profiling

Choose any config from the `config` folder and run the following:

```shell
python model_profile.py --config <path to the config file> --batch_size 12
```

## Acknowledgement

We have utilized the code and models provided in the following repository:

- [Pytorch Image Models](https://github.com/huggingface/pytorch-image-models)
app.py
ADDED
@@ -0,0 +1,145 @@
import os
import math
import gradio as gr
import torch
import librosa
import pandas as pd
import numpy as np

from sonics import HFAudioClassifier


# Constants
MODEL_IDS = {
    "SpecTTTra-α (5s)": "awsaf49/sonics-spectttra-alpha-5s",
    "SpecTTTra-β (5s)": "awsaf49/sonics-spectttra-beta-5s",
    "SpecTTTra-γ (5s)": "awsaf49/sonics-spectttra-gamma-5s",
    "SpecTTTra-α (120s)": "awsaf49/sonics-spectttra-alpha-120s",
    "SpecTTTra-β (120s)": "awsaf49/sonics-spectttra-beta-120s",
    "SpecTTTra-γ (120s)": "awsaf49/sonics-spectttra-gamma-120s",
}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_cache = {}


def load_model(model_name):
    """Load model if not already cached"""
    if model_name not in model_cache:
        model_id = MODEL_IDS[model_name]
        model = HFAudioClassifier.from_pretrained(model_id)
        model = model.to(device)
        model.eval()
        model_cache[model_name] = model
    return model_cache[model_name]


def process_audio(audio_path, model_name):
    """Process audio file and return prediction"""
    try:
        # Load model
        model = load_model(model_name)

        # Get max time from model config
        max_time = model.config.audio.max_time

        # Load and process audio
        audio, sr = librosa.load(audio_path, sr=16000)
        duration = len(audio) / sr

        # Calculate chunk size and middle position
        chunk_samples = int(max_time * sr)
        total_chunks = len(audio) // chunk_samples
        middle_chunk_idx = total_chunks // 2

        # Extract middle chunk
        start = middle_chunk_idx * chunk_samples
        end = start + chunk_samples
        chunk = audio[start:end]

        # Pad if needed (shouldn't be necessary for middle chunk)
        if len(chunk) < chunk_samples:
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)))

        # Convert to tensor and get prediction
        with torch.no_grad():
            chunk = torch.from_numpy(chunk).float().to(device)
            pred = model(chunk.unsqueeze(0))
            prob = torch.sigmoid(pred).cpu().numpy()[0]

        # Get prediction
        output = {"Real": 1 - prob, "Fake": prob}

        return output

    except Exception as e:
        return {
            "Duration": "Error",
            "Prediction": f"Error: {str(e)}",
            "Confidence": "N/A",
        }


def predict(audio_file, model_name):
    """Gradio interface function"""
    if audio_file is None:
        return {
            "Duration": "No file",
            "Prediction": "Please upload an audio file",
            "Confidence": "N/A",
        }

    return process_audio(audio_file, model_name)


# Create Gradio interface
css = """
.heading {
    text-align: center;
    margin-bottom: 2rem;
}
.logo {
    max-width: 250px;
    margin: 0 auto;
    display: block;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div class="heading">
            <img src="https://i.postimg.cc/3Jx3yZ5b/real-vs-fake-sonics-w-logo.jpg" class="logo">
            <h1>SONICS: Synthetic Or Not - Identifying Counterfeit Songs</h1>
            <h3><span style="color:red;"><b>ICLR 2025 [Poster]</b></span></h3>
        </div>
        """
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio", type="filepath")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_IDS.keys()),
                value="SpecTTTra-γ (5s)",
                label="Select Model",
            )
            submit_btn = gr.Button("Predict")

        with gr.Column():
            output = gr.Label(label="Result", num_top_classes=2)

    submit_btn.click(fn=predict, inputs=[audio_input, model_dropdown], outputs=[output])

    gr.Markdown(
        """
        ## Resources
        - 📄 [Paper](https://openreview.net/forum?id=PY7KSh29Z8)
        - 🎵 [Dataset](https://huggingface.co/datasets/awsaf49/sonics)
        - 🔬 [ArXiv](https://arxiv.org/abs/2408.14080)
        - 💻 [GitHub](https://github.com/awsaf49/sonics)
        """
    )

demo.launch()
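The same inference path can be used programmatically, outside the Gradio UI. A minimal sketch, assuming a placeholder file path (`song.mp3`) and one of the Hub model IDs listed above; unlike `process_audio`, it simply takes the first `max_time` seconds of the track:

```python
import librosa
import torch
from sonics import HFAudioClassifier

model = HFAudioClassifier.from_pretrained("awsaf49/sonics-spectttra-alpha-5s").eval()
max_time = model.config.audio.max_time            # clip length (seconds) the model expects

audio, sr = librosa.load("song.mp3", sr=16000)    # "song.mp3" is a placeholder path
chunk = torch.from_numpy(audio[: int(max_time * sr)]).float()

with torch.no_grad():
    fake_prob = torch.sigmoid(model(chunk.unsqueeze(0))).item()
print({"Real": 1 - fake_prob, "Fake": fake_prob})
```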
requirements.txt
ADDED
@@ -0,0 +1,24 @@
# Core libraries
torch>=2.4.0
torchaudio>=2.4.0

# Audio processing
librosa>=0.9.0

# Data processing
pandas>=1.3.0

# Visualization
matplotlib>=3.4.0
tqdm>=4.60.0

# ML utilities
scikit-learn>=1.0.0

# FLOPs profiling
fvcore
timm>=1.0.7

# gradio
gradio>=4.0.0
sonics/__init__.py
ADDED
@@ -0,0 +1,5 @@
from sonics.utils.seed import set_seed
from sonics.utils.config import dict2cfg
from sonics.utils.dataset import get_dataloader
from sonics.utils.scheduler import get_scheduler
from sonics.models.hf_model import HFAudioClassifier
sonics/layers/__init__.py
ADDED
@@ -0,0 +1,6 @@
from sonics.layers.tokenizer import Tokenizer1D, STTokenizer
from sonics.layers.embedding import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
)
from sonics.layers.transformer import Transformer
sonics/layers/augment.py
ADDED
@@ -0,0 +1,244 @@
import math
from typing import Tuple
import torch
import torch.nn as nn
from torchaudio.transforms import SpecAugment
from torch import Tensor
from torchvision.transforms import functional as F


class AugmentLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Initialize MixUp
        self.mixup = MixUp(
            alpha=cfg.augment.mixup_alpha,
            num_classes=cfg.num_classes,
            p=cfg.augment.mixup_p,
            inplace=True,
        )

        # Initialize other augmentations
        self.time_freq_mask = SpecAugment(
            n_time_masks=cfg.augment.n_time_masks,
            time_mask_param=cfg.augment.time_mask_param,
            n_freq_masks=cfg.augment.n_freq_masks,
            freq_mask_param=cfg.augment.freq_mask_param,
            p=cfg.augment.time_freq_mask_p,
            zero_masking=True,
        )

    def forward(self, spec, y=None):
        # Apply MixUp or CutMix with RandomChoice
        if y is not None:
            # img = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
            spec, y = self.mixup(spec, y)
            # spec = img.squeeze(1)  # shape: (batch_size, n_mels, n_frames)

        # Apply TimeMasking and FrequencyMasking
        spec = self.time_freq_mask(spec)
        return spec, y


class MixUp(torch.nn.Module):
    """Randomly apply MixUp to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )

        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 3 and batch.ndim != 2:
            raise ValueError(
                f"Batch ndim should be 3 (b, f, t) or 2 (b, n). Got {batch.ndim}"
            )
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1 and self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)

        target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s


# Todo: height of spec should be 1, adjust it for audio input (bs, n_samples)
class CutMix(torch.nn.Module):
    """Randomly apply CutMix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                "Please provide a valid positive value for the num_classes."
            )
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1 and self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)

        target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        _, H, W = F.get_dimensions(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
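As a quick illustration of how `MixUp` from this file is applied to a batch of mel-spectrograms during training, a minimal sketch with hypothetical shapes and hyperparameters (the binary task uses a single logit, so `num_classes=1` skips one-hot encoding):

```python
import torch

mixup = MixUp(num_classes=1, p=1.0, alpha=0.5)   # p=1.0 so the example always mixes
specs = torch.randn(8, 128, 1280)                # (B, n_mels, n_frames), placeholder sizes
labels = torch.randint(0, 2, (8,))               # 0 = real, 1 = fake
mixed_specs, soft_labels = mixup(specs, labels)  # soft_labels become floats in [0, 1]
```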
sonics/layers/embedding.py
ADDED
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn


class SinusoidPositionalEncoding(nn.Module):
    def __init__(self, token_dim, max_len=5000):
        super(SinusoidPositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, token_dim)  # shape: (max_len, token_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
            1
        )  # shape: (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, token_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / token_dim)
        )  # shape: (token_dim // 2)
        pe[:, 0::2] = torch.sin(position * div_term)  # shape: (max_len, token_dim // 2)
        pe[:, 1::2] = torch.cos(position * div_term)  # shape: (max_len, token_dim // 2)
        pe = pe.unsqueeze(0)  # shape: (1, max_len, token_dim)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]  # shape: (batch_size, seq_len, token_dim)
        return x


class LearnedPositionalEncoding(nn.Module):
    def __init__(self, token_dim, num_tokens):
        super(LearnedPositionalEncoding, self).__init__()
        self.pe = nn.Parameter(torch.randn(1, num_tokens, token_dim) * 0.02)

    def forward(self, x):
        x = x + self.pe
        return x
sonics/layers/feature.py
ADDED
@@ -0,0 +1,146 @@
import torch
import numpy as np
import torch.nn as nn

try:
    from torch.amp import autocast

    torch_amp_new = True
except ImportError:
    from torch.cuda.amp import autocast

    torch_amp_new = False

from torchaudio.transforms import AmplitudeToDB, MelSpectrogram


class FeatureExtractor(nn.Module):
    def __init__(
        self,
        cfg,
    ):
        """
        Feature extraction module.

        Args:
            params (dict): Parameters for the spectrogram.
            aug_config (dict, optional): Configuration for data augmentation. Defaults to None.
            top_db (float, optional): Threshold for computing the amplitude to dB. Defaults to None.
            norm (str, optional): Normalization method. Defaults to "min_max".
        """
        super().__init__()

        self.audio2melspec = MelSpectrogram(
            n_fft=cfg.melspec.n_fft,
            hop_length=cfg.melspec.hop_length,
            win_length=cfg.melspec.win_length,
            n_mels=cfg.melspec.n_mels,
            sample_rate=cfg.audio.sample_rate,
            f_min=cfg.melspec.f_min,
            f_max=cfg.melspec.f_max,
            power=cfg.melspec.power,
        )
        self.amplitude_to_db = AmplitudeToDB(top_db=cfg.melspec.top_db)

        if cfg.melspec.norm == "mean_std":
            self.normalizer = MeanStdNorm()
        elif cfg.melspec.norm == "min_max":
            self.normalizer = MinMaxNorm()
        elif cfg.melspec.norm == "simple":
            self.normalizer = SimpleNorm()
        else:
            self.normalizer = nn.Identity()

    def forward(self, x):
        """
        Forward pass of the feature extractor.

        Args:
            x (torch.Tensor): Input audio data.

        Returns:
            torch.Tensor: Extracted features.
        """

        with (
            autocast("cuda", enabled=False)
            if torch_amp_new
            else autocast(enabled=False)
        ):
            melspec = self.audio2melspec(x.float())
            melspec = self.amplitude_to_db(melspec)
            melspec = self.normalizer(melspec)

        return melspec


class MinMaxNorm(nn.Module):
    def __init__(self, eps=1e-6):
        """
        Module for performing min-max normalization on input data.

        Args:
            eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Forward pass of the min-max normalization module.

        Args:
            X (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Normalized data.
        """
        min_ = torch.amax(X, dim=(1, 2), keepdim=True)
        max_ = torch.amin(X, dim=(1, 2), keepdim=True)
        return (X - min_) / (max_ - min_ + self.eps)


class SimpleNorm(nn.Module):
    def __init__(self):
        """
        Module for performing simple normalization on input data.
        """
        super().__init__()

    def forward(self, x):
        """
        Forward pass of the simple normalization module.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Normalized data.
        """
        return (x - 40) / 80


class MeanStdNorm(nn.Module):
    def __init__(self, eps=1e-6):
        """
        Module for performing mean and standard deviation normalization on input data.

        Args:
            eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Forward pass of the mean and standard deviation normalization module.

        Args:
            X (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Normalized data.
        """
        mean = X.mean((1, 2), keepdim=True)
        std = X.reshape(X.size(0), -1).std(1, keepdim=True).unsqueeze(-1)
        return (X - mean) / (std + self.eps)
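`FeatureExtractor` is driven entirely by the config namespace. A minimal sketch of the fields it reads; every numeric value below is an assumed placeholder, not the project's actual setting, which lives in the config files:

```python
import torch
from types import SimpleNamespace

# Hypothetical config values for illustration only.
cfg = SimpleNamespace(
    audio=SimpleNamespace(sample_rate=16000),
    melspec=SimpleNamespace(
        n_fft=2048, hop_length=512, win_length=2048, n_mels=128,
        f_min=20, f_max=8000, power=2, top_db=80, norm="mean_std",
    ),
)
fe = FeatureExtractor(cfg)
audio = torch.randn(4, 16000 * 5)   # (B, n_samples): 5 s of audio at 16 kHz
spec = fe(audio)                    # (B, n_mels, n_frames) log-mel spectrogram
```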
sonics/layers/tokenizer.py
ADDED
@@ -0,0 +1,117 @@
import math
import torch
import torch.nn as nn
from sonics.layers.embedding import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
)


class STTokenizer(nn.Module):
    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        t_clip,
        f_clip,
        embed_dim,
        pre_norm=False,
        pe_learnable=False,
    ):
        super(STTokenizer, self).__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm
        self.pe_learnable = pe_learnable

        self.num_temporal_tokens = math.floor(
            (input_temp_dim - t_clip) / t_clip + 1
        )  # floor((1280 - 5) / 5 + 1) = 256
        self.num_spectral_tokens = math.floor(
            (input_spec_dim - f_clip) / f_clip + 1
        )  # floor((128 - 3) / 3 + 1) = 42
        # L_out = floor((L_in + 2*p - d*(k - 1) - 1) / s + 1) (ref: PyTorch docs)
        self.num_tokens = (
            self.num_temporal_tokens + self.num_spectral_tokens
        )  # 256 + 42 = 298
        # For ViT, num_tokens = (1280 * 128)//(5 * 3) = 10922 :)

        self.temporal_tokenizer = Tokenizer1D(
            input_spec_dim,
            embed_dim,
            clip_size=t_clip,
            num_clips=self.num_temporal_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        self.spectral_tokenizer = Tokenizer1D(
            input_temp_dim,
            embed_dim,
            clip_size=f_clip,
            num_clips=self.num_spectral_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

    def forward(self, x):
        # Temporal tokenization
        temporal_input = x  # shape: (B, F, T)
        temporal_tokens = self.temporal_tokenizer(
            temporal_input
        )  # shape: (B, T/t, dim)

        # Spectral tokenization
        spectral_input = x.permute(0, 2, 1)  # shape: (batch_size, T, F)
        spectral_tokens = self.spectral_tokenizer(
            spectral_input
        )  # shape: (B, F/f, dim)

        spectro_temporal_tokens = torch.cat(
            (temporal_tokens, spectral_tokens), dim=1
        )  # shape: (B, T/t + F/f, dim)
        return spectro_temporal_tokens


class Tokenizer1D(nn.Module):
    """Temporal/Spectral Tokenizer

    Whisper uses a temporal tokenizer, but its time_clip_size is too small and its stride is 1,
    so the complexity is very high. We use stride=clip_size to reduce complexity.
    """

    def __init__(
        self,
        input_dim,
        token_dim,
        clip_size,
        num_clips,
        pre_norm=False,
        pe_learnable=False,
    ):
        super(Tokenizer1D, self).__init__()
        self.conv1d = nn.Conv1d(
            input_dim,
            token_dim,
            clip_size,
            stride=clip_size,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        self.act = nn.GELU()
        self.pos_encoder = (
            SinusoidPositionalEncoding(token_dim)
            if not pe_learnable
            else LearnedPositionalEncoding(token_dim, num_clips)
        )
        self.norm_pre = nn.LayerNorm(token_dim, eps=1e-6) if pre_norm else nn.Identity()

    def forward(self, x):
        x = x  # (F, T)
        x = self.conv1d(x)  # (F, T) -> (dim, T/t)
        x = self.act(x)
        x = x.transpose(1, 2)  # (dim, T/t) -> (T/t, dim)
        x = self.pos_encoder(x)  # add position embeds
        x = self.norm_pre(x)
        return x
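The token-count comments above can be checked with a few lines of arithmetic. The spectrogram size (128 mel bins by 1280 frames) and clip sizes (t_clip=5, f_clip=3) are the values assumed in those comments, used here purely for illustration:

```python
import math

F_dim, T_dim = 128, 1280     # n_mels, n_frames (values assumed in the comments above)
t_clip, f_clip = 5, 3

temporal_tokens = math.floor((T_dim - t_clip) / t_clip + 1)   # 256
spectral_tokens = math.floor((F_dim - f_clip) / f_clip + 1)   # 42
print(temporal_tokens + spectral_tokens)                       # 298 tokens for SpecTTTra

patch_tokens = (F_dim * T_dim) // (t_clip * f_clip)            # 10922 patch tokens for a ViT
print(patch_tokens)                                            # the complexity gap the docstring refers to
```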
sonics/layers/transformer.py
ADDED
@@ -0,0 +1,176 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final

from timm.layers import (
    Mlp,
    DropPath,
    use_fused_attn,
)


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class Transformer(nn.Module):
    """
    Transformer layer, taken from timm library
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
    ):
        super(Transformer, self).__init__()
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    proj_drop=proj_drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
sonics/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from sonics.models.model import AudioClassifier
from sonics.models.spectttra import SpecTTTra
from sonics.models.vit import ViT
sonics/models/hf_model.py
ADDED
@@ -0,0 +1,108 @@
import os
import json
import torch
import torch.nn as nn
from .model import AudioClassifier
from ..utils.config import dict2cfg, cfg2dict
from huggingface_hub import HfApi, create_repo, hf_hub_download


class HFAudioClassifier(AudioClassifier):
    """Hugging Face compatible AudioClassifier model"""

    def __init__(self, config):
        if isinstance(config, dict):
            config = dict2cfg(config)
        self.config = config
        super().__init__(self.config)

    @classmethod
    def from_pretrained(cls, model_id, cache_dir=None, map_location="cpu", strict=False):
        # Check if model_id is a local path
        is_local = os.path.exists(model_id)

        if is_local:
            # Load from local checkpoint
            config_file = os.path.join(model_id, "config.json")
            model_file = os.path.join(model_id, "pytorch_model.bin")
        else:
            # Download from HF Hub
            config_file = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
            model_file = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", cache_dir=cache_dir)

        # Read config
        config = None
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

        # Create model
        model = cls(config)

        # Load weights
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location=torch.device(map_location))
            model.load_state_dict(state_dict, strict=strict)
            model.eval()
        else:
            raise FileNotFoundError(f"Model weights not found at {model_file}")

        return model

    def push_to_hub(self, repo_id, token=None, commit_message=None, private=False):
        """Push model and config to Hugging Face Hub.

        Args:
            repo_id (str): Repository ID on HuggingFace Hub (e.g., 'username/model-name')
            token (str, optional): HuggingFace token. If None, will use token from ~/.huggingface/token
            commit_message (str, optional): Commit message for the push
            private (bool, optional): Whether to make the repository private
        """

        # Create repo if it doesn't exist
        api = HfApi()
        try:
            create_repo(repo_id, private=private, token=token, exist_ok=True)
        except Exception as e:
            print(f"Repository creation failed: {e}")
            return

        # Save config
        config = cfg2dict(self.config)
        with open("config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)

        # Save model weights
        torch.save(self.cpu().state_dict(), "pytorch_model.bin")
        self.to(self.device if hasattr(self, 'device') else 'cuda' if torch.cuda.is_available() else 'cpu')  # restore device

        # Push files to hub
        files_to_push = ["config.json", "pytorch_model.bin"]
        for file in files_to_push:
            api.upload_file(
                path_or_fileobj=file,
                path_in_repo=file,
                repo_id=repo_id,
                token=token,
                commit_message=commit_message or f"Upload {file}"
            )
            os.remove(file)  # Clean up local files

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights and configuration to a directory.

        Args:
            save_directory (str): Directory to save files in
            **kwargs: Additional arguments passed to save functions
        """
        os.makedirs(save_directory, exist_ok=True)

        # Save config
        config = cfg2dict(self.config)
        config_file = os.path.join(save_directory, "config.json")
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)

        # Save model weights
        model_file = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.cpu().state_dict(), model_file)
        self.to(self.device if hasattr(self, 'device') else 'cuda' if torch.cuda.is_available() else 'cpu')  # restore device
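A short sketch of the checkpoint round-trip these helpers provide; the local directory name is illustrative, and the Hub ID matches those listed in `app.py`:

```python
from sonics import HFAudioClassifier

# Download config.json + pytorch_model.bin from the Hub and rebuild the model
model = HFAudioClassifier.from_pretrained("awsaf49/sonics-spectttra-gamma-5s")

# Write the same two files to a local folder ...
model.save_pretrained("./sonics-spectttra-gamma-5s")

# ... which from_pretrained also accepts, since it checks for a local path first
model = HFAudioClassifier.from_pretrained("./sonics-spectttra-gamma-5s")
```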
sonics/models/model.py
ADDED
@@ -0,0 +1,128 @@
from sonics.models.spectttra import SpecTTTra
from sonics.models.vit import ViT
from sonics.layers.feature import FeatureExtractor
from sonics.layers.augment import AugmentLayer
import torch.nn as nn
import torch.nn.functional as F
import timm


def use_global_pool(model_name):
    """
    Check if the model requires global pooling or not.
    """
    no_global_pool = ["timm"]
    return False if any(x in model_name for x in no_global_pool) else True


def get_embed_dim(model_name, encoder):
    """
    Get the embedding dimension of the encoder.
    """
    if "timm" in model_name:
        return encoder.head_hidden_size
    else:
        return encoder.embed_dim


def use_init_weights(model_name):
    """
    Check if the model requires initialization of weights or not.
    """
    has_init_weights = ["timm"]
    return False if any(x in model_name for x in has_init_weights) else True


class AudioClassifier(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.model_name = cfg.model.name
        self.input_shape = cfg.model.input_shape
        self.num_classes = cfg.num_classes
        self.ft_extractor = FeatureExtractor(cfg)
        self.augment = AugmentLayer(cfg)
        self.encoder = self.get_encoder(cfg)
        self.embed_dim = get_embed_dim(self.model_name, self.encoder)
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)
        self.use_init_weights = getattr(cfg.model, "use_init_weights", True)

        # Initialize weights
        (
            self.initialize_weights()
            if self.use_init_weights and use_init_weights(self.model_name)
            else None
        )

    def get_encoder(self, cfg):
        if cfg.model.name == "SpecTTTra":
            model = SpecTTTra(
                input_spec_dim=cfg.model.input_shape[0],
                input_temp_dim=cfg.model.input_shape[1],
                embed_dim=cfg.model.embed_dim,
                t_clip=cfg.model.t_clip,
                f_clip=cfg.model.f_clip,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pre_norm=cfg.model.pre_norm,
                pe_learnable=cfg.model.pe_learnable,
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif cfg.model.name == "ViT":
            model = ViT(
                image_size=cfg.model.input_shape,
                patch_size=cfg.model.patch_size,
                embed_dim=cfg.model.embed_dim,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pe_learnable=cfg.model.pe_learnable,
                patch_norm=getattr(cfg.model, "patch_norm", False),
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif "timm" in cfg.model.name:
            model_name = cfg.model.name.replace("timm-", "")
            model = timm.create_model(
                model_name,
                pretrained=cfg.model.pretrained,
                in_chans=1,
                num_classes=0,
            )
        else:
            raise ValueError(f"Model {cfg.model.name} not supported in V1.")
        return model

    def forward(self, audio, y=None):
        spec = self.ft_extractor(audio)  # shape: (batch_size, n_mels, n_frames)
        if self.training:
            spec, y = self.augment(spec, y)
        spec = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
        spec = F.interpolate(spec, size=tuple(self.input_shape), mode="bilinear")
        features = self.encoder(spec)
        embeds = features.mean(dim=1) if use_global_pool(self.model_name) else features
        preds = self.classifier(embeds)
        return preds if y is None else (preds, y)

    def initialize_weights(self):
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.startswith("classifier"):
                    nn.init.zeros_(module.weight)
                    nn.init.constant_(module.bias, 0.0)
                else:
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.normal_(module.bias, std=1e-6)
            elif isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif hasattr(module, "init_weights"):
                module.init_weights()
sonics/models/spectttra.py
ADDED
@@ -0,0 +1,85 @@
import torch.nn as nn
from sonics.layers import Transformer
from sonics.layers.tokenizer import STTokenizer


class SpecTTTra(nn.Module):
    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        embed_dim,
        t_clip,
        f_clip,
        num_heads,
        num_layers,
        pre_norm=False,
        pe_learnable=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        super(SpecTTTra, self).__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.embed_dim = embed_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pre_norm = (
            pre_norm  # applied after tokenization before transformer (used in CLIP)
        )
        self.pe_learnable = pe_learnable  # learned positional encoding
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        self.st_tokenizer = STTokenizer(
            input_spec_dim,
            input_temp_dim,
            t_clip,
            f_clip,
            embed_dim,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )

    def forward(self, x):
        # Squeeze the channel dimension if it exists
        if x.dim() == 4:
            x = x.squeeze(1)

        # Spectro-temporal tokenization
        spectro_temporal_tokens = self.st_tokenizer(x)

        # Positional dropout
        spectro_temporal_tokens = self.pos_drop(spectro_temporal_tokens)

        # Transformer
        output = self.transformer(spectro_temporal_tokens)  # shape: (B, T/t + F/f, dim)

        return output


# Example usage:
input_spec_dim = 384
input_temp_dim = 128
embed_dim = 512
t_clip = 20  # This means t
f_clip = 10  # This means f
num_heads = 8
num_layers = 6
dim_feedforward = 512
num_classes = 10
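Following the loose example values at the end of the file, a hedged instantiation sketch; the input tensor and the worked-out token count are illustrative only:

```python
import torch

model = SpecTTTra(
    input_spec_dim=384, input_temp_dim=128, embed_dim=512,
    t_clip=20, f_clip=10, num_heads=8, num_layers=6,
)
x = torch.randn(2, 384, 128)   # (B, F, T) spectrogram-like input
tokens = model(x)              # (2, 6 + 38, 512): 6 temporal + 38 spectral tokens
```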
sonics/models/vit.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn as nn
from sonics.layers import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
    Transformer,
)
from timm.layers import PatchEmbed


class ViT(nn.Module):
    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        num_heads,
        num_layers,
        pe_learnable=False,
        patch_norm=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        super().__init__()
        assert (
            image_size[0] % patch_size == 0 and image_size[1] % patch_size == 0
        ), "Image dimensions must be divisible by patch size."

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pe_learnable = pe_learnable
        self.patch_norm = patch_norm
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)

        # self.patch_conv = nn.Conv2d(
        #     1, embed_dim, kernel_size=patch_size, stride=patch_size
        # )  # Original ViT has 3 input channels
        self.patch_encoder = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=1,
            embed_dim=embed_dim,
            norm_layer=nn.LayerNorm if patch_norm else None,
        )
        self.pos_encoder = (
            SinusoidPositionalEncoding(embed_dim)
            if not pe_learnable
            else LearnedPositionalEncoding(embed_dim, self.num_patches)
        )
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )

    def forward(self, x):
        B = x.shape[0]
        # x = x.unsqueeze(1)  # B x 1 x n_mels x n_frames  # taken care of in the AudioClassifier
        if x.dim() == 3:
            x = x.unsqueeze(1)  # timm PatchEmbed expects 4D tensor

        # Convolutional patch embedding
        # patches = self.patch_conv(x)  # B x embed_dim x num_patches_h x num_patches_w
        patches = self.patch_encoder(x)

        # # Reshape patches
        # patches = patches.permute(
        #     0, 2, 3, 1
        # ).contiguous()  # B x num_patches_h x num_patches_w x embed_dim
        # patches = patches.view(B, -1, patches.size(-1))  # B x num_patches x embed_dim

        # Add positional embeddings
        embeddings = self.pos_encoder(patches)

        # Positional dropout
        embeddings = self.pos_drop(embeddings)

        # Transformer encoding
        output = self.transformer(embeddings)  # B x num_patches x embed_dim

        return output


batch_size = 1
input_height = 128
input_width = 384 * 6 * 4
patch_size = 16
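A minimal usage sketch (not part of the commit) showing how the ViT backbone above could be driven with the module-level constants; embed_dim, num_heads, and num_layers below are illustrative values, not taken from the repository's configs:

import torch

model = ViT(
    image_size=(input_height, input_width),  # (128, 9216): n_mels x n_frames
    patch_size=patch_size,                   # 16 -> (128 // 16) * (9216 // 16) = 4608 patches
    embed_dim=384,   # illustrative
    num_heads=6,     # illustrative
    num_layers=12,   # illustrative
)
dummy_spec = torch.randn(batch_size, input_height, input_width)  # B x n_mels x n_frames
tokens = model(dummy_spec)  # B x num_patches x embed_dim -> (1, 4608, 384)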
sonics/utils/config.py
ADDED
@@ -0,0 +1,24 @@
from types import SimpleNamespace


def dict2cfg(d):
    """
    Converts a dictionary into a SimpleNamespace
    """
    for k, v in d.items():
        if type(v) == dict:
            d[k] = SimpleNamespace(**v)
    c = SimpleNamespace(**d)
    c.audio.max_len = int(c.audio.max_time * c.audio.sample_rate)
    return c


def cfg2dict(cfg):
    """
    Converts a SimpleNamespace into a dictionary without modifying the original cfg.
    """
    d = vars(cfg).copy()  # Make a shallow copy of the cfg's __dict__
    for k, v in d.items():
        if isinstance(v, SimpleNamespace):
            d[k] = cfg2dict(v)  # Recursively convert nested SimpleNamespace objects
    return d
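A short sketch of how dict2cfg / cfg2dict round-trip a nested config; the keys shown are illustrative, except that dict2cfg requires audio.max_time and audio.sample_rate to exist:

cfg = dict2cfg({
    "audio": {"max_time": 5, "sample_rate": 16000},
    "model": {"name": "vit"},
})
print(cfg.audio.max_len)       # 80000 = 5 * 16000
print(cfg2dict(cfg)["audio"])  # back to a plain dict; the original cfg is untouched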
sonics/utils/dataset.py
ADDED
@@ -0,0 +1,137 @@
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import torch
import librosa


class AudioDataset(Dataset):
    def __init__(
        self,
        filepaths,
        labels,
        skip_times=None,
        num_classes=1,
        normalize="std",
        max_len=32000,
        random_sampling=True,
        train=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filepaths = filepaths
        self.labels = labels
        self.skip_times = skip_times
        self.num_classes = num_classes
        self.random_sampling = random_sampling
        self.normalize = normalize
        self.max_len = max_len
        self.train = train
        if not self.train:
            assert (
                not self.random_sampling
            ), "Ensure random_sampling is disabled for val"

    def __len__(self):
        return len(self.filepaths)

    def crop_or_pad(self, audio, max_len, random_sampling=True):
        audio_len = audio.shape[0]
        if random_sampling:
            diff_len = abs(max_len - audio_len)
            if audio_len < max_len:
                pad1 = np.random.randint(0, diff_len)
                pad2 = diff_len - pad1
                audio = np.pad(audio, (pad1, pad2), mode="constant")
            elif audio_len > max_len:
                idx = np.random.randint(0, diff_len)
                audio = audio[idx : (idx + max_len)]
        else:
            if audio_len < max_len:
                audio = np.pad(audio, (0, max_len - audio_len), mode="constant")
            elif audio_len > max_len:
                # Crop from the beginning
                # audio = audio[:max_len]

                # Crop from 3/4 of the audio
                # eq: l = (3x + t + x) => idx = 3x = (l - t) / 4 * 3
                idx = int((audio_len - max_len) / 4 * 3)
                audio = audio[idx : (idx + max_len)]
        return audio

    def __getitem__(self, idx):
        # Load audio
        audio, sr = librosa.load(self.filepaths[idx], sr=None)
        target = np.array([self.labels[idx]])

        # Trim start of audio (torchaudio.transforms.vad)
        if self.skip_times is not None:
            skip_time = self.skip_times[idx]
            audio = audio[int(skip_time * sr):]

        # Ensure fixed length
        audio = self.crop_or_pad(audio, self.max_len, self.random_sampling)

        if self.normalize == "std":
            audio /= np.maximum(np.std(audio), 1e-6)
        elif self.normalize == "minmax":
            audio -= np.min(audio)
            audio /= np.maximum(np.max(audio), 1e-6)

        audio = torch.from_numpy(audio).float()
        target = torch.from_numpy(target).float().squeeze()
        return {
            "audio": audio,
            "target": target,
        }


def get_dataloader(
    filepaths,
    labels,
    skip_times=None,
    batch_size=8,
    num_classes=1,
    max_len=32000,
    random_sampling=True,
    normalize="std",
    train=False,
    # drop_last=False,
    pin_memory=True,
    worker_init_fn=None,
    collate_fn=None,
    num_workers=0,
    distributed=False,
):
    dataset = AudioDataset(
        filepaths,
        labels,
        skip_times=skip_times,
        num_classes=num_classes,
        max_len=max_len,
        random_sampling=random_sampling,
        normalize=normalize,
        train=train,
    )

    if distributed:
        # drop_last is set to True to validate properly
        # Ref: https://discuss.pytorch.org/t/how-do-i-validate-with-pytorch-distributeddataparallel/172269/8
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=train, drop_last=not train
        )
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None) and train,
        # drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        sampler=sampler,
    )
    return dataloader
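A hedged example of building a validation loader with the helper above; the file paths are hypothetical and max_len assumes 16 kHz audio with a 5-second window:

filepaths = ["real_song.mp3", "fake_song.mp3"]  # hypothetical paths
labels = [0, 1]
val_loader = get_dataloader(
    filepaths,
    labels,
    batch_size=2,
    max_len=16000 * 5,       # 5 s at 16 kHz (assumed sample rate)
    random_sampling=False,   # required when train=False
    train=False,
)
batch = next(iter(val_loader))
print(batch["audio"].shape, batch["target"].shape)  # torch.Size([2, 80000]) torch.Size([2])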
sonics/utils/losses.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    def __init__(self, label_smoothing=0.0, **kwargs):
        super(BCEWithLogitsLoss, self).__init__(**kwargs)
        self.label_smoothing = label_smoothing

    def forward(self, input, target):
        if self.label_smoothing:
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        return super(BCEWithLogitsLoss, self).forward(input, target)


class SigmoidFocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, label_smoothing=0.0, reduction="mean"):
        """
        Args:
            alpha (float): Weighting factor in range (0,1) to balance positive vs negative examples.
            gamma (float): Focusing parameter to reduce the relative loss for well-classified examples.
            label_smoothing (float): Label smoothing factor to reduce the confidence of the true label.
            reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
                'none': no reduction will be applied,
                'mean': the sum of the output will be divided by the number of elements in the output,
                'sum': the output will be summed.
        """
        super(SigmoidFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, input, target):
        """
        Args:
            input (Tensor): Predicted logits for each example.
            target (Tensor): Ground truth binary labels (0 or 1) for each example.
        """
        if self.label_smoothing:
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing

        p = torch.sigmoid(input)

        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        p_t = p * target + (1 - p) * (1 - target)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        # Check reduction option and return loss accordingly
        if self.reduction == "none":
            pass
        elif self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        else:
            raise ValueError(
                f"Invalid value for arg 'reduction': '{self.reduction}'.\nSupported reduction modes: 'none', 'mean', 'sum'"
            )
        return loss
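A quick sketch comparing the two criteria on a toy batch of logits; the alpha, gamma, and label_smoothing values are illustrative, not repository defaults:

import torch

logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])

bce = BCEWithLogitsLoss(label_smoothing=0.1)
focal = SigmoidFocalLoss(alpha=0.25, gamma=2.0, label_smoothing=0.1)

print(bce(logits, targets))    # smoothed BCE, scalar by default ('mean' reduction)
print(focal(logits, targets))  # focal loss down-weights easy, well-classified examples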
sonics/utils/metrics.py
ADDED
@@ -0,0 +1,149 @@
import numpy as np
import pandas as pd
from sklearn import metrics

np.seterr(divide="ignore", invalid="ignore")


class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class F1Meter:
    def __init__(self, average="binary"):
        self.average = average
        self.reset()

    def update(self, y_true, y_pred):
        self.y_true = np.concatenate([self.y_true, y_true])
        self.y_pred = np.concatenate([self.y_pred, y_pred])
        self.avg = metrics.f1_score(self.y_true, self.y_pred, average=self.average)

    def reset(self):
        self.y_true = np.array([])
        self.y_pred = np.array([])


class SensitivityMeter:
    def __init__(self, average="binary"):
        self.average = average
        self.reset()

    def update(self, y_true, y_pred):
        self.y_true = np.concatenate([self.y_true, y_true])
        self.y_pred = np.concatenate([self.y_pred, y_pred])
        self.avg = metrics.recall_score(
            self.y_true, self.y_pred, pos_label=1, average=self.average
        )

    def reset(self):
        self.y_true = np.array([])
        self.y_pred = np.array([])


class SpecificityMeter:
    def __init__(self, average="binary"):
        self.average = average
        self.reset()

    def update(self, y_true, y_pred):
        self.y_true = np.concatenate([self.y_true, y_true])
        self.y_pred = np.concatenate([self.y_pred, y_pred])
        self.avg = metrics.recall_score(
            self.y_true, self.y_pred, pos_label=0, average=self.average
        )

    def reset(self):
        self.y_true = np.array([])
        self.y_pred = np.array([])


class AccuracyMeter:
    def __init__(self):
        self.reset()

    def update(self, y_true, y_pred):
        self.y_true = np.concatenate([self.y_true, y_true])
        self.y_pred = np.concatenate([self.y_pred, y_pred])
        self.avg = metrics.balanced_accuracy_score(self.y_true, self.y_pred)

    def reset(self):
        self.y_true = np.array([])
        self.y_pred = np.array([])


def get_part_result(test_pred_df):
    # Create `singer` column to store whether the singer is seen or unseen
    test_pred_df["singer"] = test_pred_df.artist_overlap.map(
        lambda x: "seen" if x else "unseen"
    )

    # Create `fake_type` column to store different types of fake songs
    test_pred_df["fake_type"] = test_pred_df.label

    # Create `length` column to store different duration type songs
    test_pred_df["length"] = test_pred_df["duration_part"] = test_pred_df[
        "duration"
    ].map(lambda t: "short" if t <= 60 else ("long" if t > 120 else "medium"))

    # Initialize an empty DataFrame to store results
    part_result_df = pd.DataFrame()

    # Loop through the specified categories
    for cat in ["algorithm", "singer", "fake_type", "length"]:
        # Filter the dataframe based on the condition for each category
        if cat in ["algorithm", "fake_type"]:
            cat_df = test_pred_df.query("target == 1")
        elif cat == "singer":
            cat_df = test_pred_df.query("target == 0")
        else:
            cat_df = test_pred_df.copy()

        # Compute metrics for each partition
        for part in cat_df[cat].unique():
            part_df = cat_df[cat_df[cat] == part]
            y_true = part_df.y_true.values.astype(int)
            y_pred = (part_df.y_pred.values > 0.5).astype(int)

            # Compute TPR for `algorithm`, `fake_type`; TNR for `singer` and F1 for `length`
            score = (
                metrics.recall_score(
                    y_true, y_pred, pos_label=1 if cat != "singer" else 0
                )
                if cat != "length"
                else metrics.f1_score(y_true, y_pred, average="macro")
            )

            # Create a DataFrame for the current result
            result_df = pd.DataFrame(
                {
                    "category": [cat],
                    "partition": [part],
                    "score": [score],
                    "size": [len(part_df)],
                }
            )

            # Concatenate the result with the existing DataFrame
            part_result_df = pd.concat([part_result_df, result_df], ignore_index=True)

    # Create a dictionary with the results
    result_dict = {
        f"{row['category']}/{row['partition']}": row["score"]
        for _, row in part_result_df.iterrows()
    }

    return part_result_df, result_dict
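A small sketch of the meter pattern: each update concatenates the new batch and recomputes the running score over everything seen so far (the labels below are made up):

import numpy as np

f1 = F1Meter()
sens = SensitivityMeter()
for y_true, y_pred in [(np.array([1, 0, 1]), np.array([1, 0, 0])),
                       (np.array([0, 1]), np.array([0, 1]))]:
    f1.update(y_true, y_pred)
    sens.update(y_true, y_pred)
print(f1.avg, sens.avg)  # cumulative F1 and true-positive rate over all batches so far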
sonics/utils/perf.py
ADDED
@@ -0,0 +1,107 @@
import time
import torch
import pandas as pd
from fvcore.nn import FlopCountAnalysis, ActivationCountAnalysis


def profile_model(model, input_tensor, display=False):
    flops = calculate_flops(model, input_tensor[0:1, ...])  # (1, n_mels, n_frames)
    acts = calculate_activations(model, input_tensor[0:1, ...])  # (1, n_mels, n_frames)
    params = calculate_params(model)
    speed = calculate_speed(model, input_tensor[0:1, ...])  # (1, n_mels, n_frames)
    memory = calculate_memory(model, input_tensor)  # (B, n_mels, n_frames)
    profile_data = {
        "Metric": [
            "FLOPs (G)",
            "Activations (M)",
            "Params (M)",
            "Memory (GB)",
            "Speed (A/S)",
        ],
        "Value": [flops, acts, params, memory, speed],
    }
    profile_df = pd.DataFrame(profile_data).set_index("Metric").T
    if display:
        print(profile_df.to_markdown(index=False, tablefmt="grid"))
    return profile_df


def calculate_speed(model, input_tensor, num_runs=100, warmup_runs=5):
    model.eval()

    if torch.cuda.is_available():
        # Warm-up iterations
        with torch.no_grad():
            for _ in range(warmup_runs):
                _ = model(input_tensor)

        # Create CUDA events for timing
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        # Actual timing
        start.record()
        with torch.no_grad():
            for _ in range(num_runs):
                _ = model(input_tensor)
        end.record()

        # Synchronize to wait for the events to be recorded
        torch.cuda.synchronize()

        # Calculate elapsed time
        elapsed_time = start.elapsed_time(end)  # in milliseconds
        latency = elapsed_time / num_runs / 1000.0  # convert to seconds
    else:
        # Warm-up iterations
        with torch.no_grad():
            for _ in range(warmup_runs):
                _ = model(input_tensor)

        # Actual timing
        start = time.time()
        with torch.no_grad():
            for _ in range(num_runs):
                _ = model(input_tensor)
        end = time.time()

        # Calculate elapsed time
        latency = (end - start) / num_runs

    return 1.0 / latency


def calculate_flops(model, input_tensor):
    """Calculate FLOPs in GigaFLOPs.

    Models often report MACs as FLOPs, e.g. ConvNeXt, timm library.

    Reference:
    1. https://github.com/huggingface/pytorch-image-models/blob/main/benchmark.py#L206
    2. https://github.com/facebookresearch/fvcore/issues/69
    """
    flops = FlopCountAnalysis(model, input_tensor).total()
    return flops / 1e9  # in GigaFLOPs


def calculate_activations(model, input_tensor):
    acts = ActivationCountAnalysis(model, input_tensor).total()
    return acts / 1e6  # in Millions


def calculate_params(model):
    return sum(p.numel() for p in model.parameters()) / 1e6  # in Millions


def calculate_memory(model, input_tensor):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device=None)
        start_memory = torch.cuda.max_memory_allocated(device=None)
        model.train()
        _ = model(input_tensor)
        end_memory = torch.cuda.max_memory_allocated(device=None)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device=None)
        memory = (end_memory - start_memory) / (1024**3)  # in GB
    else:
        memory = 0
    return memory
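A usage sketch of profile_model on a stand-in model; the (n_mels, n_frames) shape and the toy module are illustrative, since any torch module accepting the same input works:

import torch

toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(128 * 100, 1))
dummy_batch = torch.randn(4, 128, 100)  # (B, n_mels, n_frames), illustrative shape
profile_df = profile_model(toy_model, dummy_batch, display=True)
# Columns: FLOPs (G), Activations (M), Params (M), Memory (GB), Speed (A/S)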
sonics/utils/scheduler.py
ADDED
@@ -0,0 +1,95 @@
import math
from torch.optim.lr_scheduler import LambdaLR
from functools import partial


def get_scheduler(
    optimizer,
    start_lr,
    max_lr,
    min_lr,
    warmup_epochs,
    sustain_epochs,
    total_epochs,
    decay,
    mode="cosine",
):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (max_lr - start_lr) / warmup_epochs * epoch + start_lr

        elif epoch < warmup_epochs + sustain_epochs:
            return max_lr

        elif mode == "exponential":
            return (max_lr - min_lr) * decay ** (
                epoch - warmup_epochs - sustain_epochs
            ) + min_lr

        elif mode == "step":
            return max_lr * decay ** ((epoch - warmup_epochs - sustain_epochs) // 2)

        elif mode == "cosine":
            decay_total_epochs = total_epochs - warmup_epochs - sustain_epochs + 3
            decay_epoch_index = epoch - warmup_epochs - sustain_epochs
            phase = math.pi * decay_epoch_index / decay_total_epochs
            cosine_decay = 0.5 * (1 + math.cos(phase))
            return (max_lr - min_lr) * cosine_decay + min_lr

        else:
            raise ValueError(
                f"Unsupported mode '{mode}'. Supported modes are 'exponential', 'step', 'cosine'."
            )

    return LambdaLR(optimizer, lr_lambda)


def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float,
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    return max(
        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    )


def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
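A brief sketch of wiring the cosine-with-warmup schedule into a training loop; the stand-in model, learning rate, and step counts are illustrative:

import torch

model = torch.nn.Linear(10, 1)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    # ... forward / backward would go here ...
    optimizer.step()
    scheduler.step()  # scales the base lr by the warmup/cosine factor each step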
sonics/utils/seed.py
ADDED
@@ -0,0 +1,22 @@
import random
import numpy as np
import torch
import os


def set_seed(seed, cudnn=False):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # May affect performance ref: https://pytorch.org/docs/stable/notes/randomness.html
    if torch.backends.cudnn.is_available() and cudnn:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)
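A closing sketch of how these helpers are typically combined with a DataLoader so that worker subprocesses also get distinct, reproducible numpy seeds; the dataset variable is assumed to exist:

import torch

set_seed(42, cudnn=True)  # cudnn=True trades speed for determinism
loader = torch.utils.data.DataLoader(
    dataset,                        # any torch Dataset (assumed)
    batch_size=8,
    shuffle=True,
    num_workers=4,
    worker_init_fn=worker_init_fn,  # re-seeds numpy in each worker process
)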