r""" 4D and 6D convolutional Hough matching layers """
from torch.nn.modules.conv import _ConvNd
import torch.nn.functional as F
import torch.nn as nn
import torch
from common.logger import Logger
from . import chm_kernel
def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels, _, kernel_size, _, _, _ = kernel.size()
    psz = kernel_size // 2

    out_corr = torch.zeros((bsz, out_channels, srch, srcw, trgh, trgw))
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    # Decompose the 4D convolution into kernel_size 3D convolutions, one per
    # slice of the kernel along its first spatial dimension
    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()

        # Shift-and-accumulate the 3D results along srch to realize the fourth dimension
        add_sid = max(psz - pidx, 0)
        add_fid = min(srch, srch + psz - pidx)
        slc_sid = max(pidx - psz, 0)
        slc_fid = min(srch, srch - psz + pidx)

        out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr
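
# A minimal usage sketch of fast4d (shapes are illustrative assumptions, not
# values fixed by this module): a 4D correlation of spatial size 8^4 convolved
# with a single kernel of size 3^4 keeps its spatial size.
#
#   corr = torch.randn(2, 1, 8, 8, 8, 8)    # (bsz, ch, srch, srcw, trgh, trgw)
#   kernel = torch.randn(1, 1, 3, 3, 3, 3)  # (out_ch, in_ch, ksz, ksz, ksz, ksz)
#   out = fast4d(corr, kernel)              # -> (2, 1, 8, 8, 8, 8)
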
def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching
         NOTE: this function only supports a kernel size of (3, 3, 5, 5, 5, 5).
    """
    bsz, _, s6d, _, s4d, _, _, _ = corr.size()
    _, _, ks6d, _, ks4d, _, _, _ = kernel.size()

    # Fold the two scale dimensions into the batch and perform one 4D convolution
    # with ks6d ** 2 output channels, one per scale offset of the 6D kernel
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    # Sum along the diagonals of the interleaved (scale, kernel-offset) axis twice,
    # once per scale dimension, to realize the convolution in scale-space
    ndiag = s6d + (ks6d // 2) * 2
    first_sum = []
    for didx in diagonal_idx:
        first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
    first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)
    corr = []
    for didx in diagonal_idx:
        corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))

    # Crop the valid region of the diagonal sums and add the bias
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
    corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    # Restore the original scale-space ordering of the output
    reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
    return corr
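
# A minimal usage sketch of fast6d (illustrative shapes): CHM6d.forward below
# drives it with a correlation tensor over a 3x3 scale-space, a
# (1, 1, 3, 3, 5, 5, 5, 5) kernel, a scalar bias, and the diagonal index groups
# defined in CHM6d.__init__.
#
#   corr = torch.randn(2, 1, 3, 3, 8, 8, 8, 8)
#   kernel = torch.randn(1, 1, 3, 3, 5, 5, 5, 5)
#   bias = torch.tensor(0.0)
#   diag_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
#   out = fast6d(corr, kernel, bias, diag_idx)  # -> (2, 1, 3, 3, 8, 8, 8, 8)
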
def init_param_idx4d(param_dict):
    param_idx = []
    for key in param_dict:
        curr_offset = int(key.split('_')[-1])  # offset encoded in the key suffix (unused here)
        param_idx.append(torch.tensor(param_dict[key]))
    return param_idx
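
# A hedged sketch of the param_dict format this helper expects (the actual
# contents come from chm_kernel.KernelGenerator; the key names below are
# hypothetical): keys are strings whose '_<int>' suffix encodes an offset, and
# values list the flattened 4D-kernel indices that share a single parameter.
#
#   param_dict = {'group_0': [0, 8, 72, 80], 'group_1': [40]}
#   init_param_idx4d(param_dict)  # -> [tensor([ 0,  8, 72, 80]), tensor([40])]
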
class CHM4d(_ConvNd):
    r""" 4D convolutional Hough matching layer
         NOTE: this layer only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
        super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
                                    (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
                                    1, bias, padding_mode='zeros')

        # Zero kernel initialization
        self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)).sort()[0] * 1e-3
            for weight, param_idx in zip(weights, self.param_idx):
                weight *= len(param_idx)  # in-place scaling of a view into weights
            self.weight = nn.Parameter(weights)
        else:  # full kernel initialization
            self.param_idx = None
            self.weight = nn.Parameter(torch.abs(self.weight))
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))
    def forward(self, x):
        kernel = self.init_kernel()
        x = fast4d(x, kernel, self.bias)
        return x
    def init_kernel(self):
        # Initialize CHM kernel (divided by the number of times being shared)
        ksz = self.kernel_size[-1]
        if self.param_idx is None:
            kernel = self.weight
        else:
            kernel = torch.zeros_like(self.zero_kernel4d)
            for idx, pdx in enumerate(self.param_idx):
                kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
                for jdx, kernel_single in enumerate(kernel):
                    weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx)
                    kernel_single.view(-1)[pdx] += weight
            kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
        return kernel
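
# A minimal usage sketch of CHM4d (illustrative values; 'psi' is one of the
# kernel types the psi/iso branches in CHM6d below also handle):
#
#   chm = CHM4d(in_channels=1, out_channels=1, ksz4d=3, ktype='psi')
#   corr = torch.randn(2, 1, 8, 8, 8, 8)
#   out = chm(corr)  # -> (2, 1, 8, 8, 8, 8)
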
class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)
         NOTE: this layer only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
                                    (0,) * 6, (1,) * 6, False, (0,) * 6,
                                    1, bias=True, padding_mode='zeros')

        # Zero kernel initialization
        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        # Indices in scale-space where 4D convolutions are performed (3-by-3 scale-space)
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:  # psi & iso kernel initialization
            if ktype == 'psi':
                self.param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                self.param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            self.param_dict6d = [torch.tensor(i) for i in self.param_dict6d]

            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            self.param = []
            for param_dict6d in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                for weight, param_idx in zip(weights, self.param_idx):
                    weight *= (len(param_idx) * len(param_dict6d))  # in-place scaling of a view into weights
                self.param.append(nn.Parameter(weights))
            self.param = nn.ParameterList(self.param)
        else:  # full kernel initialization
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
        self.weight = None
    def forward(self, corr):
        kernel = self.init_kernel()
        corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
        return corr
    def init_kernel(self):
        # Initialize CHM kernel (divided by the number of times being shared)
        if self.param_idx is None:
            return self.param

        kernel6d = torch.zeros_like(self.zero_kernel6d)
        for idx, (param, param_dict6d) in enumerate(zip(self.param, self.param_dict6d)):
            ksz4d = self.kernel_size[-1]
            kernel4d = torch.zeros_like(self.zero_kernel4d)
            for jdx, pdx in enumerate(self.param_idx):
                kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
            kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
        kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)
        return kernel6d
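
# A minimal usage sketch of CHM6d (illustrative values; the layer assumes a
# 3x3 scale-space and a (3, 3, 5, 5, 5, 5) kernel, per the note above):
#
#   chm = CHM6d(in_channels=1, out_channels=1, ksz6d=3, ksz4d=5, ktype='psi')
#   corr = torch.randn(2, 1, 3, 3, 8, 8, 8, 8)
#   out = chm(corr)  # -> (2, 1, 3, 3, 8, 8, 8, 8)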