import torch
import torch.nn as nn
import torch.nn.functional as F


class FCLayer(nn.Module):
    def __init__(self, in_size, out_size=1):
        super(FCLayer, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_size, out_size))

    def forward(self, feats):
        x = self.fc(feats)
        return feats, x  # return both the input features and the instance logits


class IClassifier(nn.Module):
    def __init__(self, feature_extractor, feature_size, output_class):
        super(IClassifier, self).__init__()
        self.feature_extractor = feature_extractor
        self.fc = nn.Linear(feature_size, output_class)

    def forward(self, x):
        feats = self.feature_extractor(x)  # N x K
        c = self.fc(feats.view(feats.shape[0], -1))  # N x C instance logits
        return feats.view(feats.shape[0], -1), c


class BClassifier(nn.Module):
    def __init__(self, input_size, output_class, dropout_v=0.0):  # input_size: K, output_class: C
        super(BClassifier, self).__init__()
        self.q = nn.Linear(input_size, 128)
        self.v = nn.Sequential(
            nn.Dropout(dropout_v),
            nn.Linear(input_size, input_size)
        )
        # 1D convolutional layer that can handle multiple classes (including binary)
        self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)
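        # Note: applied to a 1 x C x V input, this convolution is equivalent to a
        # single linear map from the flattened C x V bag matrix to C bag logits,
        # since each output channel sees every input channel at every position.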

    def forward(self, feats, c):  # feats: N x K, c: N x C
        device = feats.device
        V = self.v(feats)  # N x V, unsorted
        Q = self.q(feats).view(feats.shape[0], -1)  # N x Q, unsorted
        # handle multiple classes without a for loop
        _, m_indices = torch.sort(c, 0, descending=True)  # sort class scores along the instance dimension, m_indices in shape N x C
        m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :])  # select critical instances, m_feats in shape C x K
        q_max = self.q(m_feats)  # compute queries of critical instances, q_max in shape C x Q
        A = torch.mm(Q, q_max.transpose(0, 1))  # inner product of Q with each entry of q_max, A in shape N x C; each column holds unnormalized attention scores
        A = F.softmax(A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0)  # normalize attention scores over instances, A in shape N x C
        B = torch.mm(A.transpose(0, 1), V)  # compute bag representation, B in shape C x V
        # (Vectorized equivalent of looping over the C classes: for each class,
        # sort instances by that class's score, take the top-scoring instance as
        # the critical instance, attend all instances to its query, and pool V
        # into one row of B.)
        B = B.view(1, B.shape[0], B.shape[1])  # 1 x C x V
        C = self.fcc(B)  # 1 x C x 1
        C = C.view(1, -1)  # 1 x C bag logits
        return C, A, B
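
# A minimal shape check for BClassifier alone (a sketch; the sizes N=10
# instances, K=512 features, C=2 classes are assumptions for illustration):
#
#   b = BClassifier(input_size=512, output_class=2)
#   feats, c = torch.randn(10, 512), torch.randn(10, 2)
#   bag_logits, A, B = b(feats, c)
#   # bag_logits: 1 x 2, A: 10 x 2 attention weights, B: 1 x 2 x 512 bag embedding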


class MILNet(nn.Module):
    def __init__(self, i_classifier, b_classifier):
        super(MILNet, self).__init__()
        self.i_classifier = i_classifier
        self.b_classifier = b_classifier

    def forward(self, x):
        feats, classes = self.i_classifier(x)
        prediction_bag, A, B = self.b_classifier(feats, classes)
        return classes, prediction_bag, A, B
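

# Example usage: a minimal smoke test of the full pipeline. The identity
# feature extractor, feature size, class count, and bag size below are
# illustrative assumptions (in practice a CNN backbone or precomputed
# features would be plugged in), not part of the model definition.
if __name__ == "__main__":
    feature_size, num_classes, bag_size = 512, 2, 10
    i_classifier = IClassifier(nn.Identity(), feature_size, num_classes)
    b_classifier = BClassifier(input_size=feature_size, output_class=num_classes)
    milnet = MILNet(i_classifier, b_classifier)

    bag = torch.randn(bag_size, feature_size)  # one bag of N instance features
    classes, prediction_bag, A, B = milnet(bag)
    print(classes.shape)         # torch.Size([10, 2]) instance logits
    print(prediction_bag.shape)  # torch.Size([1, 2]) bag logits
    print(A.shape)               # torch.Size([10, 2]) attention weights
    print(B.shape)               # torch.Size([1, 2, 512]) bag embedding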