import torch
import torch.nn as nn

from sonics.layers import Transformer
from sonics.layers.tokenizer import STTokenizer

class SpecTTTra(nn.Module):
    """Spectro-Temporal Tokens Transformer (SpecTTTra).

    Tokenizes a spectrogram independently along the time and frequency axes
    and feeds the concatenated tokens to a Transformer encoder.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        embed_dim,
        t_clip,
        f_clip,
        num_heads,
        num_layers,
        pre_norm=False,
        pe_learnable=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        super(SpecTTTra, self).__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.embed_dim = embed_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pre_norm = pre_norm  # applied after tokenization, before the transformer (as in CLIP)
        self.pe_learnable = pe_learnable  # learned positional encoding
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        self.st_tokenizer = STTokenizer(
            input_spec_dim,
            input_temp_dim,
            t_clip,
            f_clip,
            embed_dim,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )

    def forward(self, x):
        # Squeeze the channel dimension if it exists: (B, 1, F, T) -> (B, F, T)
        if x.dim() == 4:
            x = x.squeeze(1)

        # Spectro-temporal tokenization: T//t_clip temporal + F//f_clip spectral tokens
        spectro_temporal_tokens = self.st_tokenizer(x)

        # Positional dropout
        spectro_temporal_tokens = self.pos_drop(spectro_temporal_tokens)

        # Transformer encoder
        output = self.transformer(spectro_temporal_tokens)  # shape: (B, T/t + F/f, dim)
        return output


# Example usage:
input_spec_dim = 384  # number of frequency bins (F)
input_temp_dim = 128  # number of time frames (T)
embed_dim = 512
t_clip = 20  # clip (token) size along the temporal axis (t)
f_clip = 10  # clip (token) size along the spectral axis (f)
num_heads = 8
num_layers = 6
dim_feedforward = 512  # not consumed by SpecTTTra; mlp_ratio sets the feed-forward width
num_classes = 10  # not consumed by SpecTTTra; a downstream classification head would use this
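
# A minimal smoke test, a sketch rather than the repo's own demo: it assumes
# SpecTTTra accepts a (B, 1, F, T) spectrogram (forward() squeezes the channel
# dim) and that STTokenizer yields F//f_clip spectral plus T//t_clip temporal
# tokens, matching the shape comment in forward().
model = SpecTTTra(
    input_spec_dim=input_spec_dim,
    input_temp_dim=input_temp_dim,
    embed_dim=embed_dim,
    t_clip=t_clip,
    f_clip=f_clip,
    num_heads=num_heads,
    num_layers=num_layers,
)
x = torch.randn(2, 1, input_spec_dim, input_temp_dim)  # (B, C, F, T)
tokens = model(x)
# Expected: (B, T//t_clip + F//f_clip, embed_dim) = (2, 128//20 + 384//10, 512) = (2, 44, 512)
print(tokens.shape)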