import math

import torch
import torch.nn as nn

from sonics.layers.embedding import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
)


class STTokenizer(nn.Module):
    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        t_clip,
        f_clip,
        embed_dim,
        pre_norm=False,
        pe_learnable=False,
    ):
        super(STTokenizer, self).__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm
        self.pe_learnable = pe_learnable

        # L_out = floor((L_in + 2*p - d*(k - 1) - 1) / s + 1) (ref: PyTorch Conv1d docs);
        # with p=0, d=1, and k=s=clip this reduces to floor((L_in - clip) / clip + 1).
        self.num_temporal_tokens = math.floor(
            (input_temp_dim - t_clip) / t_clip + 1
        )  # e.g. floor((1280 - 5) / 5 + 1) = 256
        self.num_spectral_tokens = math.floor(
            (input_spec_dim - f_clip) / f_clip + 1
        )  # e.g. floor((128 - 3) / 3 + 1) = 42
        self.num_tokens = (
            self.num_temporal_tokens + self.num_spectral_tokens
        )  # 256 + 42 = 298
        # For a ViT-style 2D patchifier, num_tokens = (1280 * 128) // (5 * 3) = 10922 :)

        self.temporal_tokenizer = Tokenizer1D(
            input_spec_dim,
            embed_dim,
            clip_size=t_clip,
            num_clips=self.num_temporal_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        self.spectral_tokenizer = Tokenizer1D(
            input_temp_dim,
            embed_dim,
            clip_size=f_clip,
            num_clips=self.num_spectral_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

    def forward(self, x):
        # Temporal tokenization: each group of t_clip time frames becomes one token
        temporal_input = x  # shape: (B, F, T)
        temporal_tokens = self.temporal_tokenizer(
            temporal_input
        )  # shape: (B, T/t, dim)

        # Spectral tokenization: each group of f_clip frequency bins becomes one token
        spectral_input = x.permute(0, 2, 1)  # shape: (B, T, F)
        spectral_tokens = self.spectral_tokenizer(
            spectral_input
        )  # shape: (B, F/f, dim)

        spectro_temporal_tokens = torch.cat(
            (temporal_tokens, spectral_tokens), dim=1
        )  # shape: (B, T/t + F/f, dim)
        return spectro_temporal_tokens


class Tokenizer1D(nn.Module):
    """Temporal/Spectral Tokenizer

    Whisper uses a temporal tokenizer, but its clip (kernel) size is small and
    its stride is 1, so the resulting token count makes complexity very high.
    We use stride=clip_size (non-overlapping clips) to reduce complexity.
    """

    def __init__(
        self,
        input_dim,
        token_dim,
        clip_size,
        num_clips,
        pre_norm=False,
        pe_learnable=False,
    ):
        super(Tokenizer1D, self).__init__()
        self.conv1d = nn.Conv1d(
            input_dim,
            token_dim,
            clip_size,
            stride=clip_size,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        self.act = nn.GELU()
        self.pos_encoder = (
            SinusoidPositionalEncoding(token_dim)
            if not pe_learnable
            else LearnedPositionalEncoding(token_dim, num_clips)
        )
        self.norm_pre = nn.LayerNorm(token_dim, eps=1e-6) if pre_norm else nn.Identity()

    def forward(self, x):
        x = self.conv1d(x)  # (B, F, T) -> (B, dim, T/t)
        x = self.act(x)
        x = x.transpose(1, 2)  # (B, dim, T/t) -> (B, T/t, dim)
        x = self.pos_encoder(x)  # add positional embeddings
        x = self.norm_pre(x)
        return x
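

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; the embed_dim value
# and batch size below are illustrative assumptions, not a canonical config).
# A (B, F, T) = (2, 128, 1280) mel spectrogram yields 256 temporal + 42
# spectral clips, i.e. 298 tokens, matching the counts computed above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = STTokenizer(
        input_spec_dim=128,   # F: number of mel/frequency bins
        input_temp_dim=1280,  # T: number of time frames
        t_clip=5,             # temporal clip (patch) size
        f_clip=3,             # spectral clip (patch) size
        embed_dim=384,        # token dimension (illustrative value)
    )
    x = torch.randn(2, 128, 1280)  # dummy batch, shape (B, F, T)
    tokens = tokenizer(x)
    print(tokens.shape)  # expected: torch.Size([2, 298, 384])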