import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from utmosv2.dataset._utils import get_dataset_num


class _MultiSpecEncoder(nn.Module):
    """Shared encoder used by the multi-spectrogram models in this module.

    Holds one timm backbone per entry in ``cfg.dataset.specs``, a learnable
    weight per spec used to mix the backbone outputs, a 2-D pooling layer and,
    when ``cfg.model.multi_spec.atten`` is truthy, a multi-head self-attention
    layer. Subclasses call :meth:`_build_encoder` from ``__init__`` and
    :meth:`_encode` from ``forward``.
    """

    def _build_encoder(self) -> int:
        """Create the encoder sub-modules on ``self`` and size the head.

        Reads ``self.cfg`` (must be set by the caller beforehand). Modules are
        registered in the order backbones, weights, pooling, attn so that
        parameter creation order (and therefore RNG consumption) is stable.

        Returns:
            Number of features produced by :meth:`_encode`, i.e. the input
            width a following linear head needs.
        """
        cfg = self.cfg
        spec_cfg = cfg.model.multi_spec
        num_specs = len(cfg.dataset.specs)

        self.backbones = nn.ModuleList(
            [
                timm.create_model(
                    spec_cfg.backbone,
                    pretrained=True,
                    num_classes=0,
                )
                for _ in range(num_specs)
            ]
        )
        # Disable the backbones' own global pooling; pooling is done once,
        # after the per-frame feature maps have been concatenated.
        for backbone in self.backbones:
            backbone.global_pool = nn.Identity()

        # One learnable mixing weight per spec, initialised from a softmax of
        # random values (no normalisation is re-applied during training).
        self.weights = nn.Parameter(F.softmax(torch.randn(num_specs), dim=0))

        # With attention, only the last axis is pooled so the remaining axis
        # can serve as the attention sequence; otherwise pool down to 1x1.
        self.pooling = timm.layers.SelectAdaptivePool2d(
            output_size=(None, 1) if spec_cfg.atten else 1,
            pool_type=spec_cfg.pool_type,
            flatten=False,
        )

        # "catavgmax" concatenates average and max pooling, doubling channels.
        pooled_features = self.backbones[0].num_features * (
            2 if spec_cfg.pool_type == "catavgmax" else 1
        )
        if spec_cfg.atten:
            self.attn = nn.MultiheadAttention(
                embed_dim=pooled_features,
                num_heads=8,
                dropout=0.2,
                batch_first=True,
            )

        # The attention path concatenates mean(attn output) with max(pooled),
        # which doubles the feature width once more.
        return pooled_features * (2 if spec_cfg.atten else 1)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Turn a stack of spectrogram images into one feature vector per item.

        Args:
            x: Input tensor indexed along dim 1 by
                ``frame * len(cfg.dataset.specs) + spec``.
                NOTE(review): presumably (batch, frames * specs, C, H, W);
                confirm against the dataset.

        Returns:
            A 2-D tensor whose second dimension equals the value returned by
            :meth:`_build_encoder`.
        """
        num_specs = len(self.cfg.dataset.specs)
        num_frames = self.cfg.dataset.spec_frames.num_frames

        # Slice i goes through backbone ``i % num_specs``. The squeeze(1) is
        # kept from the original code; it is a no-op unless dim 1 of the slice
        # has size 1.
        feature_maps = [
            self.backbones[i % num_specs](x[:, i, :, :, :].squeeze(1))
            for i in range(num_frames * num_specs)
        ]

        # Weighted sum over the specs belonging to the same frame.
        frame_features = [
            sum(
                [
                    feature_maps[frame * num_specs + j] * w
                    for j, w in enumerate(self.weights)
                ]
            )
            for frame in range(num_frames)
        ]

        # Lay the frames side by side along the last axis, then pool it away.
        pooled = self.pooling(torch.cat(frame_features, dim=3)).squeeze(3)

        if self.cfg.model.multi_spec.atten:
            # (batch, channels, seq) -> (batch, seq, channels) for
            # batch_first attention.
            seq = torch.permute(pooled, (0, 2, 1))
            attended, _ = self.attn(seq, seq, seq)
            return torch.cat(
                [torch.mean(attended, dim=1), torch.max(pooled, dim=2).values],
                dim=1,
            )

        # Without attention the pooling output is (batch, channels, 1, 1);
        # after squeeze(3) a singleton axis is still left. Drop it so the
        # result is (batch, channels) as the linear head expects. Previously
        # this path fed a (batch, channels, 1) tensor into ``fc`` and raised.
        return pooled.squeeze(2)


class MultiSpecModelV2(_MultiSpecEncoder):
    """Multi-spectrogram model: shared encoder followed by a linear head.

    Args:
        cfg: Configuration object. Reads ``cfg.dataset.specs``,
            ``cfg.dataset.spec_frames.num_frames`` and, under
            ``cfg.model.multi_spec``: ``backbone``, ``pool_type``, ``atten``
            and ``num_classes``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        fc_in_features = self._build_encoder()
        self.fc = nn.Linear(fc_in_features, cfg.model.multi_spec.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and apply the linear head.

        Args:
            x: Stacked spectrogram images; see ``_MultiSpecEncoder._encode``.

        Returns:
            Output of ``self.fc``, with ``cfg.model.multi_spec.num_classes``
            values per batch item.
        """
        return self.fc(self._encode(x))


class MultiSpecExtModel(_MultiSpecEncoder):
    """Multi-spectrogram model whose head also receives a dataset vector.

    Identical to ``MultiSpecModelV2`` except that the linear head takes
    ``get_dataset_num(cfg)`` extra input features, which ``forward`` fills by
    concatenating ``d`` onto the encoded features.

    Args:
        cfg: Configuration object. Reads the same fields as
            ``MultiSpecModelV2`` and is also passed to ``get_dataset_num``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        fc_in_features = self._build_encoder()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            fc_in_features + self.num_dataset, cfg.model.multi_spec.num_classes
        )

    def forward(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, append ``d`` and apply the linear head.

        Args:
            x: Stacked spectrogram images; see ``_MultiSpecEncoder._encode``.
            d: 2-D tensor concatenated to the features along dim 1. The head
                is sized for ``self.num_dataset`` extra columns.
                NOTE(review): presumably a one-hot dataset indicator; confirm
                against the caller.

        Returns:
            Output of ``self.fc``, with ``cfg.model.multi_spec.num_classes``
            values per batch item.
        """
        return self.fc(torch.cat([self._encode(x), d], dim=1))