# UTMOSv2 — utmosv2/model/multi_spec.py
# Multi-spectrogram fusion models: per-spec timm backbones whose feature
# maps are merged with learned weights, pooled, optionally self-attended,
# and projected to MOS-prediction outputs.
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from utmosv2.dataset._utils import get_dataset_num
class MultiSpecModelV2(nn.Module):
    """Multi-spectrogram fusion model.

    One timm backbone is created per spectrogram type. The input is a 5-D
    tensor laid out as (batch, num_frames * num_specs, channels, height,
    width); slices are encoded per-spec, fused across specs with learned
    weights, concatenated across frames along the width axis, pooled, and
    (optionally) self-attended before the final linear head.

    Args:
        cfg: project config; reads ``cfg.dataset.specs``,
            ``cfg.dataset.spec_frames.num_frames`` and
            ``cfg.model.multi_spec.{backbone,pool_type,atten,num_classes}``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # One backbone per spectrogram type. num_classes=0 removes the
        # classifier head; global_pool is replaced with Identity so each
        # backbone returns a (B, C, H, W) feature map.
        self.backbones = nn.ModuleList(
            [
                timm.create_model(
                    cfg.model.multi_spec.backbone,
                    pretrained=True,
                    num_classes=0,
                )
                for _ in range(len(cfg.dataset.specs))
            ]
        )
        for backbone in self.backbones:
            backbone.global_pool = nn.Identity()
        # Learnable per-spec fusion weights, initialized to a random
        # softmax distribution. NOTE(review): they are not re-normalized
        # in forward(), so during training they may drift from summing to 1.
        self.weights = nn.Parameter(
            F.softmax(torch.randn(len(cfg.dataset.specs)), dim=0)
        )
        # With attention enabled, keep the width axis (pool height only)
        # so there is a sequence to attend over; otherwise pool to 1x1.
        self.pooling = timm.layers.SelectAdaptivePool2d(
            output_size=(None, 1) if self.cfg.model.multi_spec.atten else 1,
            pool_type=self.cfg.model.multi_spec.pool_type,
            flatten=False,
        )
        if self.cfg.model.multi_spec.atten:
            self.attn = nn.MultiheadAttention(
                embed_dim=self.backbones[0].num_features
                * (2 if self.cfg.model.multi_spec.pool_type == "catavgmax" else 1),
                num_heads=8,
                dropout=0.2,
                batch_first=True,
            )
        # "catavgmax" doubles the channel count; the attention branch
        # doubles the fc input again (attention mean ++ channel-wise max).
        fc_in_features = (
            self.backbones[0].num_features
            * (2 if self.cfg.model.multi_spec.pool_type == "catavgmax" else 1)
            * (2 if self.cfg.model.multi_spec.atten else 1)
        )
        self.fc = nn.Linear(fc_in_features, cfg.model.multi_spec.num_classes)

    def forward(self, x):
        num_specs = len(self.cfg.dataset.specs)
        num_frames = self.cfg.dataset.spec_frames.num_frames
        # Split the stacked input into one (B, C, H, W) tensor per
        # (frame, spec) pair; ordering is frame-major, spec-minor.
        x = [x[:, i, :, :, :].squeeze(1) for i in range(num_frames * num_specs)]
        # Encode each slice with the backbone matching its spec type.
        x = [self.backbones[i % num_specs](t) for i, t in enumerate(x)]
        # Weighted sum over specs within each frame.
        x = [
            sum(x[i * num_specs + j] * w for j, w in enumerate(self.weights))
            for i in range(num_frames)
        ]
        # Concatenate frames along the width axis, then pool.
        x = torch.cat(x, dim=3)
        x = self.pooling(x).squeeze(3)
        if self.cfg.model.multi_spec.atten:
            # (B, C, T) -> (B, T, C) for batch-first self-attention.
            xt = torch.permute(x, (0, 2, 1))
            y, _ = self.attn(xt, xt, xt)
            # Mean over attended positions concatenated with the
            # channel-wise max of the pooled map.
            x = torch.cat([torch.mean(y, dim=1), torch.max(x, dim=2).values], dim=1)
        else:
            # Bug fix: without attention the pooled tensor is (B, C, 1)
            # after squeeze(3); drop the trailing spatial dim so nn.Linear
            # receives (B, C). Previously this path raised a shape error.
            x = x.squeeze(2)
        x = self.fc(x)
        return x
class MultiSpecExtModel(nn.Module):
    """Multi-spectrogram fusion model with extra dataset-id features.

    Identical to ``MultiSpecModelV2`` except that ``forward`` takes a
    second tensor ``d`` (a per-sample dataset descriptor of width
    ``get_dataset_num(cfg)``) that is concatenated onto the pooled
    features just before the final linear head.

    Args:
        cfg: project config; reads ``cfg.dataset.specs``,
            ``cfg.dataset.spec_frames.num_frames`` and
            ``cfg.model.multi_spec.{backbone,pool_type,atten,num_classes}``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # One backbone per spectrogram type. num_classes=0 removes the
        # classifier head; global_pool is replaced with Identity so each
        # backbone returns a (B, C, H, W) feature map.
        self.backbones = nn.ModuleList(
            [
                timm.create_model(
                    cfg.model.multi_spec.backbone,
                    pretrained=True,
                    num_classes=0,
                )
                for _ in range(len(cfg.dataset.specs))
            ]
        )
        for backbone in self.backbones:
            backbone.global_pool = nn.Identity()
        # Learnable per-spec fusion weights, initialized to a random
        # softmax distribution. NOTE(review): they are not re-normalized
        # in forward(), so during training they may drift from summing to 1.
        self.weights = nn.Parameter(
            F.softmax(torch.randn(len(cfg.dataset.specs)), dim=0)
        )
        # With attention enabled, keep the width axis (pool height only)
        # so there is a sequence to attend over; otherwise pool to 1x1.
        self.pooling = timm.layers.SelectAdaptivePool2d(
            output_size=(None, 1) if self.cfg.model.multi_spec.atten else 1,
            pool_type=self.cfg.model.multi_spec.pool_type,
            flatten=False,
        )
        if self.cfg.model.multi_spec.atten:
            self.attn = nn.MultiheadAttention(
                embed_dim=self.backbones[0].num_features
                * (2 if self.cfg.model.multi_spec.pool_type == "catavgmax" else 1),
                num_heads=8,
                dropout=0.2,
                batch_first=True,
            )
        # "catavgmax" doubles the channel count; the attention branch
        # doubles the fc input again (attention mean ++ channel-wise max).
        fc_in_features = (
            self.backbones[0].num_features
            * (2 if self.cfg.model.multi_spec.pool_type == "catavgmax" else 1)
            * (2 if self.cfg.model.multi_spec.atten else 1)
        )
        # The dataset descriptor is appended to the features, so the head
        # input grows by its width.
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            fc_in_features + self.num_dataset, cfg.model.multi_spec.num_classes
        )

    def forward(self, x, d):
        num_specs = len(self.cfg.dataset.specs)
        num_frames = self.cfg.dataset.spec_frames.num_frames
        # Split the stacked input into one (B, C, H, W) tensor per
        # (frame, spec) pair; ordering is frame-major, spec-minor.
        x = [x[:, i, :, :, :].squeeze(1) for i in range(num_frames * num_specs)]
        # Encode each slice with the backbone matching its spec type.
        x = [self.backbones[i % num_specs](t) for i, t in enumerate(x)]
        # Weighted sum over specs within each frame.
        x = [
            sum(x[i * num_specs + j] * w for j, w in enumerate(self.weights))
            for i in range(num_frames)
        ]
        # Concatenate frames along the width axis, then pool.
        x = torch.cat(x, dim=3)
        x = self.pooling(x).squeeze(3)
        if self.cfg.model.multi_spec.atten:
            # (B, C, T) -> (B, T, C) for batch-first self-attention.
            xt = torch.permute(x, (0, 2, 1))
            y, _ = self.attn(xt, xt, xt)
            # Mean over attended positions concatenated with the
            # channel-wise max of the pooled map.
            x = torch.cat([torch.mean(y, dim=1), torch.max(x, dim=2).values], dim=1)
        else:
            # Bug fix: without attention the pooled tensor is (B, C, 1)
            # after squeeze(3); drop the trailing spatial dim so nn.Linear
            # receives (B, C + num_dataset) after concatenation with d.
            # Previously this path raised a shape error.
            x = x.squeeze(2)
        # Append the dataset descriptor before the head.
        x = self.fc(torch.cat([x, d], dim=1))
        return x