# Page-header residue from the file viewer: awsaf49 - "Initial Commit" (3f50570)
from sonics.models.spectttra import SpecTTTra
from sonics.models.vit import ViT
from sonics.layers.feature import FeatureExtractor
from sonics.layers.augment import AugmentLayer
import torch.nn as nn
import torch.nn.functional as F
import timm
def use_global_pool(model_name):
    """
    Tell whether encoder features must be mean-pooled before classification.

    Returns False when ``model_name`` contains any of the markers listed
    below (currently only "timm"), and True for every other name.
    """
    skip_pool_markers = ("timm",)
    return not any(marker in model_name for marker in skip_pool_markers)
def get_embed_dim(model_name, encoder):
    """
    Get the embedding dimension of the encoder.

    Args:
        model_name: Name of the model; names containing "timm" are treated
            as timm encoders.
        encoder: The encoder instance whose output width is queried.

    Returns:
        For timm encoders, ``encoder.head_hidden_size`` when the attribute is
        present, otherwise ``encoder.num_features``. For every other model,
        ``encoder.embed_dim``.

    Raises:
        AttributeError: If the encoder exposes none of the expected
            attributes for its model family.
    """
    if "timm" in model_name:
        # NOTE(review): `head_hidden_size` is presumably absent on older timm
        # releases; fall back to `num_features` there instead of crashing.
        head_dim = getattr(encoder, "head_hidden_size", None)
        if head_dim is not None:
            return head_dim
        return encoder.num_features
    return encoder.embed_dim
def use_init_weights(model_name):
    """
    Tell whether the model should go through custom weight initialization.

    Returns False when ``model_name`` contains any of the markers listed
    below (currently only "timm"), and True for every other name.
    """
    self_initialized_markers = ("timm",)
    return not any(marker in model_name for marker in self_initialized_markers)
class AudioClassifier(nn.Module):
    """
    Audio classification model.

    Chains a feature extractor, a training-only augmentation layer, an
    encoder (SpecTTTra, ViT, or a timm model) and a linear classification
    head. Configuration is read from ``cfg.model`` and ``cfg.num_classes``.
    """

    def __init__(self, cfg):
        """
        Build all sub-modules from ``cfg``.

        Reads ``cfg.model.name``, ``cfg.model.input_shape``,
        ``cfg.num_classes`` and the optional ``cfg.model.use_init_weights``
        flag (defaults to True).
        """
        super().__init__()
        model_cfg = cfg.model
        self.model_name = model_cfg.name
        self.input_shape = model_cfg.input_shape
        self.num_classes = cfg.num_classes

        # Sub-modules are registered in the same order as before so that
        # parameter / state_dict ordering is unchanged.
        self.ft_extractor = FeatureExtractor(cfg)
        self.augment = AugmentLayer(cfg)
        self.encoder = self.get_encoder(cfg)
        self.embed_dim = get_embed_dim(self.model_name, self.encoder)
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)

        self.use_init_weights = getattr(cfg.model, "use_init_weights", True)
        # Custom initialization runs only when the config allows it AND the
        # model family is one that wants it.
        if self.use_init_weights and use_init_weights(self.model_name):
            self.initialize_weights()

    @staticmethod
    def _optional_transformer_kwargs(model_cfg):
        """Collect the optional encoder settings, using defaults when absent."""
        return dict(
            pos_drop_rate=getattr(model_cfg, "pos_drop_rate", 0.0),
            attn_drop_rate=getattr(model_cfg, "attn_drop_rate", 0.0),
            proj_drop_rate=getattr(model_cfg, "proj_drop_rate", 0.0),
            mlp_ratio=getattr(model_cfg, "mlp_ratio", 4.0),
        )

    def get_encoder(self, cfg):
        """
        Instantiate the encoder selected by ``cfg.model.name``.

        Supported names are "SpecTTTra", "ViT", and any name containing
        "timm" (the "timm-" part is stripped before calling
        ``timm.create_model``).

        Raises:
            ValueError: If the model name matches none of the above.
        """
        name = cfg.model.name

        if name == "SpecTTTra":
            return SpecTTTra(
                input_spec_dim=cfg.model.input_shape[0],
                input_temp_dim=cfg.model.input_shape[1],
                embed_dim=cfg.model.embed_dim,
                t_clip=cfg.model.t_clip,
                f_clip=cfg.model.f_clip,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pre_norm=cfg.model.pre_norm,
                pe_learnable=cfg.model.pe_learnable,
                **self._optional_transformer_kwargs(cfg.model),
            )

        if name == "ViT":
            return ViT(
                image_size=cfg.model.input_shape,
                patch_size=cfg.model.patch_size,
                embed_dim=cfg.model.embed_dim,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pe_learnable=cfg.model.pe_learnable,
                patch_norm=getattr(cfg.model, "patch_norm", False),
                **self._optional_transformer_kwargs(cfg.model),
            )

        if "timm" in name:
            timm_name = name.replace("timm-", "")
            # Single input channel, and num_classes=0 so that this class's
            # own linear head is used for classification.
            return timm.create_model(
                timm_name,
                pretrained=cfg.model.pretrained,
                in_chans=1,
                num_classes=0,
            )

        raise ValueError(f"Model {cfg.model.name} not supported in V1.")

    def forward(self, audio, y=None):
        """
        Run the classifier on ``audio``.

        Args:
            audio: Input passed to the feature extractor.
            y: Optional targets; passed through the augmentation layer in
                training mode and returned alongside the predictions.

        Returns:
            ``preds`` when ``y`` is None, otherwise the tuple ``(preds, y)``.
        """
        spec = self.ft_extractor(audio)  # shape: (batch_size, n_mels, n_frames)

        # Augmentation is applied in training mode only and may alter targets.
        if self.training:
            spec, y = self.augment(spec, y)

        # Add a channel axis -> (batch_size, 1, n_mels, n_frames), then resize
        # to the input shape the encoder was configured with.
        spec = F.interpolate(
            spec.unsqueeze(1), size=tuple(self.input_shape), mode="bilinear"
        )

        embeds = self.encoder(spec)
        if use_global_pool(self.model_name):
            embeds = embeds.mean(dim=1)

        preds = self.classifier(embeds)
        if y is None:
            return preds
        return preds, y

    def initialize_weights(self):
        """
        Initialize the weights of every sub-module in place.

        - Linear layers whose name starts with "classifier": weight and bias
          set to zero.
        - Other Linear layers: Xavier-uniform weight, bias drawn from a
          normal distribution with std 1e-6.
        - Conv1d / Conv2d layers: Kaiming-normal weight (fan_out, relu),
          zero bias.
        - Any other module exposing ``init_weights``: that method is called.
        """
        for module_name, layer in self.named_modules():
            if isinstance(layer, nn.Linear):
                if module_name.startswith("classifier"):
                    nn.init.zeros_(layer.weight)
                    nn.init.constant_(layer.bias, 0.0)
                else:
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.normal_(layer.bias, std=1e-6)
            elif isinstance(layer, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(
                    layer.weight, mode="fan_out", nonlinearity="relu"
                )
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif hasattr(layer, "init_weights"):
                layer.init_weights()