from sonics.models.spectttra import SpecTTTra
from sonics.models.vit import ViT
from sonics.layers.feature import FeatureExtractor
from sonics.layers.augment import AugmentLayer
import torch.nn as nn
import torch.nn.functional as F
import timm


def use_global_pool(model_name):
    """
    Check whether the encoder output needs global (mean) pooling.

    timm encoders are built with num_classes=0 and already return pooled
    features; the custom encoders return a token sequence.
    """
    no_global_pool = ["timm"]
    return not any(x in model_name for x in no_global_pool)


def get_embed_dim(model_name, encoder):
    """
    Get the embedding dimension of the encoder.
    """
    if "timm" in model_name:
        return encoder.head_hidden_size
    else:
        return encoder.embed_dim


def use_init_weights(model_name):
    """
    Check whether custom weight initialization should be applied.

    timm models come with their own initialization (or pretrained weights),
    so they are skipped.
    """
    has_init_weights = ["timm"]
    return not any(x in model_name for x in has_init_weights)


class AudioClassifier(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.model_name = cfg.model.name
        self.input_shape = cfg.model.input_shape
        self.num_classes = cfg.num_classes
        self.ft_extractor = FeatureExtractor(cfg)
        self.augment = AugmentLayer(cfg)
        self.encoder = self.get_encoder(cfg)
        self.embed_dim = get_embed_dim(self.model_name, self.encoder)
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)
        self.use_init_weights = getattr(cfg.model, "use_init_weights", True)

        # Initialize weights (skipped for timm models, which handle their own)
        if self.use_init_weights and use_init_weights(self.model_name):
            self.initialize_weights()

    def get_encoder(self, cfg):
        if cfg.model.name == "SpecTTTra":
            model = SpecTTTra(
                input_spec_dim=cfg.model.input_shape[0],
                input_temp_dim=cfg.model.input_shape[1],
                embed_dim=cfg.model.embed_dim,
                t_clip=cfg.model.t_clip,
                f_clip=cfg.model.f_clip,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pre_norm=cfg.model.pre_norm,
                pe_learnable=cfg.model.pe_learnable,
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif cfg.model.name == "ViT":
            model = ViT(
                image_size=cfg.model.input_shape,
                patch_size=cfg.model.patch_size,
                embed_dim=cfg.model.embed_dim,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pe_learnable=cfg.model.pe_learnable,
                patch_norm=getattr(cfg.model, "patch_norm", False),
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif "timm" in cfg.model.name:
            # e.g. "timm-efficientnet_b0" -> "efficientnet_b0"
            model_name = cfg.model.name.replace("timm-", "")
            model = timm.create_model(
                model_name,
                pretrained=cfg.model.pretrained,
                in_chans=1,  # mono spectrogram input
                num_classes=0,  # drop the head; return pooled features
            )
        else:
            raise ValueError(f"Model {cfg.model.name} not supported in V1.")
        return model

    def forward(self, audio, y=None):
        spec = self.ft_extractor(audio)  # shape: (batch_size, n_mels, n_frames)
        if self.training:
            spec, y = self.augment(spec, y)
        spec = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
        spec = F.interpolate(spec, size=tuple(self.input_shape), mode="bilinear")
        features = self.encoder(spec)
        # Custom encoders return a token sequence, so mean-pool over tokens;
        # timm encoders (num_classes=0) already return pooled features.
        embeds = (
            features.mean(dim=1) if use_global_pool(self.model_name) else features
        )
        preds = self.classifier(embeds)
        return preds if y is None else (preds, y)

    def initialize_weights(self):
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.startswith("classifier"):
                    # Zero-init the classification head
                    nn.init.zeros_(module.weight)
                    nn.init.constant_(module.bias, 0.0)
                else:
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.normal_(module.bias, std=1e-6)
            elif isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif hasattr(module, "init_weights"):
                # Defer to a submodule's own initializer if it defines one
                module.init_weights()
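

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). The config
# fields below mirror what AudioClassifier reads directly; the exact schema,
# including whatever FeatureExtractor and AugmentLayer expect from cfg, is
# defined by the project's config files, so treat these values as placeholder
# assumptions rather than a working configuration:
#
#   from types import SimpleNamespace
#   import torch
#
#   cfg = SimpleNamespace(
#       num_classes=2,
#       model=SimpleNamespace(
#           name="timm-efficientnet_b0",  # hypothetical backbone choice
#           input_shape=(128, 384),       # (n_mels, n_frames) after resize
#           pretrained=False,
#       ),
#       # ...plus the FeatureExtractor / AugmentLayer fields from the config.
#   )
#
#   model = AudioClassifier(cfg).eval()
#   audio = torch.randn(4, 32000)  # batch of raw waveforms (length assumed)
#   preds = model(audio)           # eval path: logits of shape (4, 2)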