# Adapted from SPFN
import torch
import torch.nn as nn

try:
    from torch_cluster import fps
except ImportError:
    # torch_cluster is optional at import time; farthest_point_sampling below requires it.
    fps = None

# from .point_convolution_universal import TransitionDown, TransitionUp
# from .model_util import construct_conv1d_modules, construct_conv_modules, CorrFlowPredNet, set_bn_not_training, set_grad_to_none
# from .utils import farthest_point_sampling, get_knn_idx, batched_index_select

def set_bn_not_training(module):
    # Recursively put all BatchNorm layers into eval mode (freeze running stats).
    if isinstance(module, nn.ModuleList):
        for block in module:
            set_bn_not_training(block)
    elif isinstance(module, nn.Sequential):
        for block in module:
            if isinstance(block, (nn.BatchNorm1d, nn.BatchNorm2d)):
                block.eval()  # `.eval()` sets `.training = False`; `is_training` is not a BatchNorm attribute
    else:
        raise ValueError("Unrecognized module when setting BatchNorm to eval mode!")

def set_grad_to_none(module):
    # Recursively clear gradients for all parameters in the module tree.
    if isinstance(module, nn.ModuleList):
        for block in module:
            set_grad_to_none(block)
    elif isinstance(module, nn.Sequential):
        for block in module:
            for param in block.parameters():
                param.grad = None
    else:
        raise ValueError("Unrecognized module when setting gradients to None!")

def apply_module_with_conv2d_bn(x, module):
    # x: bsz x n x k x feats -> bsz x feats x n x k (channels-first for Conv2d / BatchNorm2d)
    x = x.transpose(2, 3).contiguous().transpose(1, 2).contiguous()
    for layer in module:
        for sublayer in layer:
            x = sublayer(x.contiguous())
            x = x.float()
    # Back to channels-last: bsz x n x k x feats.
    x = torch.transpose(x, 1, 2).transpose(2, 3)
    return x

def batched_index_select(values, indices, dim=1):
    # Gather `values` along `dim` with per-batch `indices`; the trailing feature
    # dims of `values` are broadcast, e.g. (bz, N, C) + (bz, M, k) -> (bz, M, k, C).
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)
    dim += value_expand_len
    return values.gather(dim, indices)

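# Usage sketch (illustrative, not part of the original file): gathering the
# k nearest-neighbor features for each of M sampled centers.
#   feats   = torch.randn(2, 1024, 64)               # (bz, N, C)
#   nbr_idx = torch.randint(0, 1024, (2, 128, 16))   # (bz, M, k), indices into the N axis
#   nbrs    = batched_index_select(feats, nbr_idx, dim=1)  # -> (2, 128, 16, 64)
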
def init_weight(blocks):
    for module in blocks:
        if isinstance(module, nn.Sequential):
            for subm in module:
                # Note: the blocks built by construct_conv_modules use Conv2d, so
                # Conv2d must be handled here or the Xavier init silently never fires.
                if isinstance(subm, (nn.Linear, nn.Conv2d)):
                    nn.init.xavier_uniform_(subm.weight)
                    nn.init.zeros_(subm.bias)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

def construct_conv_modules(mlp_dims, n_in, last_act=True, bn=True):
    # Build a shared point-wise MLP as a stack of 1x1 Conv2d blocks; every layer
    # gets BN + ReLU except (optionally) the last one.
    rt_module_list = nn.ModuleList()
    for i, dim in enumerate(mlp_dims):
        inc, ouc = (n_in if i == 0 else mlp_dims[i - 1]), dim
        is_last = (i == len(mlp_dims) - 1)
        if not is_last or last_act:
            blk = nn.Sequential(
                nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=(1, 1), stride=(1, 1), bias=True),
                nn.BatchNorm2d(num_features=ouc, eps=1e-5, momentum=0.1),
                nn.ReLU()
            )
        elif bn:
            blk = nn.Sequential(
                nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=(1, 1), stride=(1, 1), bias=True),
                nn.BatchNorm2d(num_features=ouc, eps=1e-5, momentum=0.1),
            )
        else:
            blk = nn.Sequential(
                nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=(1, 1), stride=(1, 1), bias=True),
            )
        rt_module_list.append(blk)
    init_weight(rt_module_list)
    return rt_module_list

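# Illustrative sketch (not part of the original file): a shared 6 -> 64 -> 64 -> 128
# point-wise MLP over grouped features; each 1x1 Conv2d acts independently on
# every (center, neighbor) slot.
#   layers  = construct_conv_modules([64, 64, 128], n_in=6, last_act=True)
#   grouped = torch.randn(2, 256, 64, 6)                     # (bz, n_centers, k, feat)
#   out     = apply_module_with_conv2d_bn(grouped, layers)   # -> (2, 256, 64, 128)
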
def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
    # Farthest point sampling via torch_cluster.fps; requires torch_cluster.
    bz, N = pos.size(0), pos.size(1)
    feat_dim = pos.size(-1)
    device = pos.device
    sampling_ratio = float(n_sampling / N)
    pos_float = pos.float()
    # Build the flat batch-assignment vector expected by torch_cluster.
    batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
    mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)
    batch = (batch * mult_one).view(-1)
    pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous()  # (bz * N, feat_dim)
    # The returned indices are flat, i.e. they index into the (bz * N) axis;
    # the shape is (bz * n_sampling,).
    sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
    return sampled_idx

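# Usage sketch (illustrative): the flat FPS indices are viewed back per batch.
#   pos     = torch.rand(2, 1024, 3)
#   idx     = farthest_point_sampling(pos, n_sampling=256)   # (2 * 256,)
#   centers = pos.view(-1, 3)[idx].view(2, 256, 3)           # (bz, n_sampling, 3)
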
class PointnetPP(nn.Module):
    def __init__(self, in_feat_dim: int):
        super(PointnetPP, self).__init__()
        self.skip_global = False
        # Number of centers sampled at each set-abstraction level; the final
        # level (1) pools everything into a single global feature.
        self.n_samples = [256, 128, 1]
        mlps = [[64, 64, 128], [128, 128, 256], [256, 512, 1024]]
        mlps_in = [[in_feat_dim, 64, 64], [128 + 3, 128, 128], [256 + 3, 256, 512]]
        up_mlps = [[512, 512], [512, 512], [512, 512, 512]]
        up_mlps_in = [1024 + 256, 512 + 128, 512 + in_feat_dim]
        self.in_feat_dim = in_feat_dim
        # Per-level ball radii for grouping; None means plain kNN max-pooling.
        # self.radius = [0.2, 0.4, None]
        self.radius = [None, None, None]
        self.mlp_layers = nn.ModuleList()
        for i, (dims_in, dims_out) in enumerate(zip(mlps_in, mlps)):
            conv_layers = construct_conv_modules(
                mlp_dims=dims_out, n_in=dims_in[0],
                last_act=True,
                bn=True
            )
            self.mlp_layers.append(conv_layers)
        self.up_mlp_layers = nn.ModuleList()
        for i, (dim_in, dims_out) in enumerate(zip(up_mlps_in, up_mlps)):
            conv_layers = construct_conv_modules(
                mlp_dims=dims_out, n_in=dim_in,
                last_act=True,
                bn=True
            )
            self.up_mlp_layers.append(conv_layers)

    def eval(self):
        super().eval()
        self.set_bn_no_training()
        return self  # nn.Module.eval() is expected to return self

    def set_bn_no_training(self):
        for sub_module in self.mlp_layers:
            set_bn_not_training(sub_module)
        for sub_module in self.up_mlp_layers:
            set_bn_not_training(sub_module)

    def set_grad_to_none(self):
        for sub_module in self.mlp_layers:
            set_grad_to_none(sub_module)
        for sub_module in self.up_mlp_layers:
            set_grad_to_none(sub_module)

    def sample_and_group(self, feat, pos, n_samples, use_pos=True, k=64):
        bz, N = pos.size(0), pos.size(1)
        fps_idx = farthest_point_sampling(pos=pos[:, :, :3], n_sampling=n_samples)
        # Gather the sampled centers from the flat FPS indices: bz x n_samples x pos_dim.
        sampled_pos = pos.contiguous().view(bz * N, -1)[fps_idx, :].contiguous().view(bz, n_samples, -1)
        # Pairwise center-to-point distances: bz x n_samples x N.
        ppdist = torch.sum((sampled_pos.unsqueeze(2) - pos.unsqueeze(1)) ** 2, dim=-1)
        ppdist = torch.sqrt(ppdist)
        topk_dist, topk_idx = torch.topk(ppdist, k=k, dim=2, largest=False)
        grouped_pos = batched_index_select(values=pos, indices=topk_idx, dim=1)
        grouped_pos = grouped_pos - sampled_pos.unsqueeze(2)  # center-relative coordinates
        if feat is not None:
            grouped_feat = batched_index_select(values=feat, indices=topk_idx, dim=1)
            if use_pos:
                grouped_feat = torch.cat([grouped_pos, grouped_feat], dim=-1)
        else:
            grouped_feat = grouped_pos
        return grouped_feat, topk_dist, sampled_pos

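    # Shape sketch (illustrative): with pos (bz, N, 3) and feat (bz, N, C),
    # sample_and_group(feat, pos, n_samples=M, k=64) returns
    #   grouped_feat : (bz, M, 64, 3 + C)   center-relative xyz + neighbor feats
    #   topk_dist    : (bz, M, 64)          distances reused for radius masking
    #   sampled_pos  : (bz, M, 3)           FPS-sampled centers
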
    def max_pooling_with_r(self, grouped_feat, ppdist, r=None):
        if r is None:
            res, _ = torch.max(grouped_feat, dim=2)
        else:
            # Mask out neighbors farther than r by adding -1e8 before the max,
            # so they can never win the pooling. indicators: bz x N x k.
            indicators = (ppdist <= r).float()
            indicators_expand = indicators.unsqueeze(-1).repeat(1, 1, 1, grouped_feat.size(-1))
            indicators_expand[indicators_expand < 0.5] = -1e8
            indicators_expand[indicators_expand > 0.5] = 0.
            res, _ = torch.max(grouped_feat + indicators_expand, dim=2)
        return res

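    # Tiny worked example (illustrative): with r = 0.5, one feature channel,
    # neighbor feats [0.9, 0.2] and distances [0.3, 0.7]:
    #   max(0.9 + 0, 0.2 - 1e8) = 0.9
    # i.e. the out-of-radius neighbor cannot win the max.
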
    def interpolate_features(self, feat, p1, p2):
        # Inverse-distance-weighted interpolation of `feat` (defined on points p1)
        # onto query points p2, using the 3 nearest neighbors.
        dist = p2[:, :, None, :] - p1[:, None, :, :]
        dist = torch.norm(dist, dim=-1, p=2, keepdim=False)  # bz x N2 x N1
        topkk = min(3, dist.size(-1))
        dist, idx = dist.topk(topkk, dim=-1, largest=False)  # bz x N2 x 3
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm  # bz x N2 x 3; each row sums to 1
        three_nearest_features = batched_index_select(feat, idx, dim=1)  # 1 is the idx dimension
        interpolated_feats = torch.sum(three_nearest_features * weight[:, :, :, None], dim=2, keepdim=False)
        return interpolated_feats

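    # The weights implement the standard PointNet++ feature-propagation rule
    # (illustrative recap):
    #   f(q) = sum_{j in 3NN(q)} w_j * f(p_j),  w_j = (1 / d_j) / sum_k (1 / d_k)
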
    def forward(self, x: torch.FloatTensor, pos: torch.FloatTensor, return_global=False):
        bz = pos.size(0)
        # Cache per-level (features, positions) for the upsampling path.
        cache = [(None if x is None else x.clone(), pos.clone())]
        for i, cur_n_samples in enumerate(self.n_samples):
            if cur_n_samples == 1:
                # Global level: group every point into a single set and max-pool.
                grouped_feat = x.unsqueeze(1)
                grouped_feat = torch.cat(
                    [pos.unsqueeze(1), grouped_feat], dim=-1
                )
                grouped_feat = apply_module_with_conv2d_bn(
                    grouped_feat, self.mlp_layers[i]
                ).squeeze(1)
                x, _ = torch.max(grouped_feat, dim=1, keepdim=True)
                sampled_pos = torch.zeros((bz, 1, 3), dtype=torch.float, device=pos.device)
                pos = sampled_pos
            else:
                grouped_feat, topk_dist, pos = self.sample_and_group(x, pos, cur_n_samples, use_pos=True, k=64)
                grouped_feat = apply_module_with_conv2d_bn(
                    grouped_feat, self.mlp_layers[i]
                )
                cur_radius = self.radius[i]
                x = self.max_pooling_with_r(grouped_feat, topk_dist, r=cur_radius)
            cache.append((x.clone(), pos.clone()))
        # Upsampling path: interpolate level by level and fuse with cached skips.
        # global_x = x  # would be captured here if return_global were used
        for i, up_conv_layers in enumerate(self.up_mlp_layers):
            prev_x, prev_pos = cache[-i - 2][0], cache[-i - 2][1]
            # Interpolate the current features from `pos` onto the finer cached `prev_pos`.
            interpolated_feats = self.interpolate_features(x, pos, prev_pos)
            if prev_x is None:
                # Without previous features, only the positions serve as the skip.
                prev_x = prev_pos
            elif i == len(self.up_mlp_layers) - 1:
                prev_x = torch.cat([prev_x, prev_pos], dim=-1)
            cur_up_feats = torch.cat([interpolated_feats, prev_x], dim=-1)
            x = apply_module_with_conv2d_bn(
                cur_up_feats.unsqueeze(2), up_conv_layers
            ).squeeze(2)
            pos = prev_pos
        # `return_global` is currently unused; the commented branch below would
        # also return the cached global feature.
        # if return_global:
        #     return x, global_x, pos
        return x, pos
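
# Minimal smoke test (an illustrative sketch, not part of the original file);
# assumes torch_cluster is installed so that `fps` is available. With 3 extra
# feature channels per point, in_feat_dim must be 3 (relative xyz) + 3.
if __name__ == "__main__":
    net = PointnetPP(in_feat_dim=3 + 3)
    pos = torch.rand(2, 1024, 3)
    feat = torch.rand(2, 1024, 3)
    out, out_pos = net(feat, pos)
    print(out.shape, out_pos.shape)  # expected: (2, 1024, 512) and (2, 1024, 3)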