import torch
import torch.nn as nn
from sonics.layers import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
    Transformer,
)
from timm.layers import PatchEmbed


class ViT(nn.Module):
    """Vision Transformer over single-channel (spectrogram) inputs."""

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        num_heads,
        num_layers,
        pe_learnable=False,
        patch_norm=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        super().__init__()
        assert (
            image_size[0] % patch_size == 0 and image_size[1] % patch_size == 0
        ), "Image dimensions must be divisible by patch size."

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pe_learnable = pe_learnable
        self.patch_norm = patch_norm
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)

        # timm's PatchEmbed replaces a manual nn.Conv2d(1, embed_dim,
        # kernel_size=patch_size, stride=patch_size) patchifier. in_chans=1
        # because inputs are single-channel spectrograms (the original ViT
        # takes 3 RGB channels).
        self.patch_encoder = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=1,
            embed_dim=embed_dim,
            norm_layer=nn.LayerNorm if patch_norm else None,
        )
        self.pos_encoder = (
            SinusoidPositionalEncoding(embed_dim)
            if not pe_learnable
            else LearnedPositionalEncoding(embed_dim, self.num_patches)
        )
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )

    def forward(self, x):
        # Unsqueezing to B x 1 x n_mels x n_frames is normally handled by the
        # AudioClassifier; do it here defensively, since timm's PatchEmbed
        # expects a 4D tensor.
        if x.dim() == 3:
            x = x.unsqueeze(1)

        # Convolutional patch embedding. timm's PatchEmbed already flattens
        # its output to B x num_patches x embed_dim, so no manual
        # permute/reshape is needed afterwards.
        patches = self.patch_encoder(x)

        # Add positional embeddings
        embeddings = self.pos_encoder(patches)

        # Positional dropout
        embeddings = self.pos_drop(embeddings)

        # Transformer encoding
        output = self.transformer(embeddings)  # B x num_patches x embed_dim

        return output


batch_size = 1
input_height = 128
input_width = 384 * 6 * 4
patch_size = 16
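
# Minimal smoke test below: a sketch, assuming the sonics package is installed
# and its layer signatures match the imports above. The embed_dim, num_heads,
# and num_layers values are illustrative placeholders, not a configuration
# taken from the original source.
if __name__ == "__main__":
    model = ViT(
        image_size=(input_height, input_width),
        patch_size=patch_size,
        embed_dim=384,  # assumed hyperparameter, for illustration only
        num_heads=6,    # assumed
        num_layers=12,  # assumed
    )
    # Dummy spectrogram batch: B x n_mels x n_frames. forward() adds the
    # channel dimension that timm's PatchEmbed expects.
    x = torch.randn(batch_size, input_height, input_width)
    out = model(x)
    # num_patches = (128 // 16) * (9216 // 16) = 8 * 576 = 4608
    print(out.shape)  # expected: torch.Size([1, 4608, 384])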