# NamedCurves/models/attention_fusion.py
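"""Attention-based fusion of per-color-category image candidates.

``LocalFusion`` blends one image per color-name category, either as a plain
average or as a weighted average driven by color-naming probability maps.
Each category is handled by its own ``Self_Attn`` block, a SAGAN-style
self-attention module that attends at a downsampled resolution.
"""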
import torch
import torch.nn.functional as F
from torch import nn


class LocalFusion(nn.Module):
    def __init__(self, att_in_dim=3, num_categories=6, max_pool_ksize1=4, max_pool_ksize2=2, encoder_dims=(8, 16)):
        super().__init__()
        self.num_categories = num_categories
        self.att_in_dim = att_in_dim
        # One attention block per color-name category
        self.attention_fusion = nn.ModuleList([
            Self_Attn(in_dim=att_in_dim,
                      max_pool_ksize1=max_pool_ksize1,
                      max_pool_ksize2=max_pool_ksize2,
                      encoder_dims=encoder_dims)
            for _ in range(num_categories)
        ])

    def forward(self, x, color_naming_probs=None, q=None):
        """Blend the per-category attention outputs.

        Expected shapes (inferred from the indexing below): ``x`` is
        (num_categories, B, 3, H, W); ``color_naming_probs``, when given, is
        (num_categories, B, H, W).
        """
        # Without color-naming probabilities, blend with a plain average over
        # the category dimension
        if color_naming_probs is None:
            # Use each category tensor itself as query, key and value
            if q is None:
                return torch.mean(torch.stack([att(x_color) for att, x_color in zip(self.attention_fusion, x)], dim=0), dim=0)
            else:
                return torch.mean(torch.stack([att(x_color, q=q) for att, x_color in zip(self.attention_fusion, x)], dim=0), dim=0)
        # With color-naming probabilities, blend with a weighted average using
        # the thresholded probabilities as weights
        else:
            # A category only contributes where its probability exceeds 0.20
            color_naming_probs = (color_naming_probs > 0.20).float()
            # Per-pixel count of contributing categories, broadcast to 3 channels;
            # clamped to avoid dividing by zero where no category passes the threshold
            color_naming_avg = torch.sum(color_naming_probs, dim=0).unsqueeze(1).repeat(1, 3, 1, 1)
            color_naming_avg = torch.clamp(color_naming_avg, min=1.0)
            color_naming_probs = color_naming_probs.unsqueeze(2).repeat(1, 1, 3, 1, 1)
            # Use each category tensor itself as query, key and value
            if q is None:
                out = torch.stack([att(x_color) for att, x_color in zip(self.attention_fusion, x)], dim=0)
            else:
                out = torch.stack([att(x_color, q=q) for att, x_color in zip(self.attention_fusion, x)], dim=0)
            out = torch.sum(out * color_naming_probs, dim=0) / color_naming_avg
            return torch.clip(out, 0, 1)


class Self_Attn(nn.Module):
    """SAGAN-style self-attention block operating at a downsampled resolution."""

    def __init__(self, in_dim, max_pool_ksize1=4, max_pool_ksize2=2, encoder_dims=(8, 16)):
        super().__init__()
        self.channel_in = in_dim
        self.encoder_dims = encoder_dims
        self.max_pool_ksize1 = max_pool_ksize1
        self.max_pool_ksize2 = max_pool_ksize2
        # Total spatial downsampling applied by the two max-pooling stages
        self.down_ratio = max_pool_ksize1 * max_pool_ksize2
        # Query, key and value share the same encoder: two 1x1 conv + ReLU
        # stages, each followed by max pooling with the configured kernel
        # sizes (so the downsampling matches down_ratio)
        def make_encoder():
            return nn.Sequential(
                nn.Conv2d(in_channels=in_dim, out_channels=encoder_dims[0], kernel_size=1),
                nn.ReLU(),
                nn.MaxPool2d(max_pool_ksize1, max_pool_ksize1),
                nn.Conv2d(in_channels=encoder_dims[0], out_channels=encoder_dims[1], kernel_size=1),
                nn.ReLU(),
                nn.MaxPool2d(max_pool_ksize2, max_pool_ksize2))

        self.query_conv = make_encoder()
        self.key_conv = make_encoder()
        self.value_conv = make_encoder()
        # Decoder: 1x1 convs with a nearest-neighbour upsample; the exact
        # input resolution is restored by a bilinear resize in forward()
        self.upsample = nn.Sequential(
            nn.Conv2d(in_channels=encoder_dims[1], out_channels=encoder_dims[0], kernel_size=1),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=max_pool_ksize1),
            nn.Conv2d(in_channels=encoder_dims[0], out_channels=encoder_dims[0], kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=encoder_dims[0], out_channels=3, kernel_size=1),
            nn.ReLU())
        self.last_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
        # Learnable weight of the attention branch (SAGAN-style), initialized
        # to zero so the block starts out as an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q=None):
        # Default to plain self-attention: x provides query, key and value
        if q is None:
            q = x
        m_batch_size, C, height, width = x.size()
        down_hw = (height // self.down_ratio) * (width // self.down_ratio)
        # Flatten the encoded maps to (B, N, C') / (B, C', N) for batched matmul
        proj_query = self.query_conv(q).view(m_batch_size, -1, down_hw).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batch_size, -1, down_hw)
        energy = torch.bmm(proj_query, proj_key)
        # Attention over the downsampled spatial positions
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batch_size, -1, down_hw)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batch_size, self.encoder_dims[1], height // self.down_ratio, width // self.down_ratio)
        out = self.upsample(out)
        # Bilinear resize back to the exact input resolution
        out = F.interpolate(out, size=x.size()[2:], mode='bilinear', align_corners=False)
        out = self.last_conv(out)
        # Residual connection weighted by the learnable gamma
        out = self.gamma * out + x
        return out
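

# Minimal usage sketch. The shapes below are assumptions inferred from the
# code above: LocalFusion expects a stack of per-category images of shape
# (num_categories, B, 3, H, W), optional color-naming probability maps of
# shape (num_categories, B, H, W), and spatial dimensions divisible by
# down_ratio (4 * 2 = 8 with the default pooling sizes).
if __name__ == "__main__":
    num_categories, batch, h, w = 6, 2, 64, 64
    fusion = LocalFusion(att_in_dim=3, num_categories=num_categories)

    # Hypothetical inputs, for shape checking only
    x = torch.rand(num_categories, batch, 3, h, w)
    probs = torch.softmax(torch.rand(num_categories, batch, h, w), dim=0)

    out_avg = fusion(x)                                  # plain average blending
    out_weighted = fusion(x, color_naming_probs=probs)   # weighted blending
    print(out_avg.shape, out_weighted.shape)             # both torch.Size([2, 3, 64, 64])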