# Stray scratch/VCS residue, commented out so the module can be imported:
# meow
# init
# d6d3a5b
import numpy as np
import torch
import time
from scipy.spatial.transform import Rotation as R
try:
from torch_cluster import fps
except:
pass
from collections import OrderedDict
import os, argparse, copy, json
import math
def sample_pcd_from_mesh(vertices, triangles, npoints=512):
    """Sample a point cloud from a triangle mesh surface.

    For each triangle, draws ``int(area * npoints)`` points (at least one)
    uniformly over the triangle via the square-root barycentric trick.

    Args:
        vertices: (V, 3) array of vertex positions.
        triangles: (T, 3) integer array of vertex indices per face.
        npoints: sampling density scale (samples per unit of triangle area).

    Returns:
        sampled_ptcs: (P, 3) array of sampled surface points.
        tot_indices: (P, 3) int64 array, the source triangle's vertex
            indices for each sampled point.
        tot_factors: (P, 3) float32 array of barycentric weights (a, b, c)
            such that point = a*v_a + b*v_b + c*v_c.
    """
    areas = []
    for i in range(triangles.shape[0]):
        ia, ib, ic = int(triangles[i, 0].item()), int(triangles[i, 1].item()), int(triangles[i, 2].item())
        v_a, v_b, v_c = vertices[ia], vertices[ib], vertices[ic]
        ab, ac = v_b - v_a, v_c - v_a
        # triangle area = 0.5 * |ab| * |ac| * sin(angle between ab and ac)
        norm_ab = np.sqrt(np.sum(ab ** 2)).item()
        norm_ac = np.sqrt(np.sum(ac ** 2)).item()
        cos_ab_ac = (np.sum(ab * ac) / np.clip(norm_ab * norm_ac, a_min=1e-9, a_max=9999999.0)).item()
        # clamp to 0 to guard against tiny negative values from float rounding
        # (previously math.sqrt could raise a domain error when |cos| > 1)
        sin_ab_ac = math.sqrt(max(0.0, 1. - cos_ab_ac ** 2))
        areas.append(0.5 * sin_ab_ac * norm_ab * norm_ac)

    sampled_pcts = []
    tot_indices = []
    tot_factors = []
    for i in range(triangles.shape[0]):
        ia, ib, ic = int(triangles[i, 0].item()), int(triangles[i, 1].item()), int(triangles[i, 2].item())
        v_a, v_b, v_c = vertices[ia], vertices[ib], vertices[ic]
        # number of samples is proportional to the triangle area; at least one
        # (a dead area/tot_area normalization that was immediately overwritten
        # has been removed)
        cur_sampled_pts = max(1, int(areas[i] * npoints))
        tmp_x = np.random.uniform(0., 1., (cur_sampled_pts,)).tolist()
        tmp_y = np.random.uniform(0., 1., (cur_sampled_pts,)).tolist()
        for xx, yy in zip(tmp_x, tmp_y):
            sqrt_xx = math.sqrt(xx)
            # sqrt-based barycentric sampling yields a uniform distribution
            aa = 1. - sqrt_xx
            bb = sqrt_xx * (1. - yy)
            cc = yy * sqrt_xx
            sampled_pcts.append(v_a * aa + v_b * bb + v_c * cc)
            tot_indices.append(triangles[i])  # vertex indices of the source triangle
            tot_factors.append([aa, bb, cc])
    # np.long was removed in NumPy 1.24; use the explicit int64 alias
    tot_indices = np.array(tot_indices, dtype=np.int64)
    tot_factors = np.array(tot_factors, dtype=np.float32)
    sampled_ptcs = np.array(sampled_pcts)
    print("sampled points from surface:", sampled_ptcs.shape)
    return sampled_ptcs, tot_indices, tot_factors
def read_obj_file_ours(obj_fn, sub_one=False):
    """Parse a Wavefront .obj file, returning vertices and faces.

    Args:
        obj_fn: path to the .obj file.
        sub_one: if True, convert 1-based .obj face indices to 0-based.

    Returns:
        vertices: (V, 3) float64 array of vertex positions.
        faces: list of per-face vertex index lists.
    """
    vertices = []
    faces = []
    with open(obj_fn, "r") as rf:
        for line in rf:
            items = line.strip().split(" ")
            if items[0] == 'v':
                vertices.append([float(vv) for vv in items[1:]])
            elif items[0] == 'f':
                cur_face_idxes = []
                for cur_f in items[1:]:
                    # face tokens look like "v", "v/vt", "v/vt/vn" or "v//vn";
                    # the vertex index is the first "/"-separated field
                    try:
                        cur_f_idx = int(cur_f.split("/")[0])
                    except ValueError:
                        # narrow except (was a bare except); fall back for
                        # tokens the simple split cannot handle
                        cur_f_idx = int(cur_f.split("//")[0])
                    cur_face_idxes.append(cur_f_idx if not sub_one else cur_f_idx - 1)
                faces.append(cur_face_idxes)
    # np.float was removed in NumPy 1.24; use the explicit float64 alias
    vertices = np.array(vertices, dtype=np.float64)
    return vertices, faces
def clamp_gradient(model, clip):
    """Clip every parameter gradient of `model` element-wise to [-clip, clip]."""
    torch.nn.utils.clip_grad_value_(model.parameters(), clip)
def clamp_gradient_norm(model, max_norm, norm_type=2):
    """Clip the gradient of each parameter of `model` to `max_norm`,
    measured with the given p-norm.

    Note: clipping is applied per-parameter, not jointly over all
    parameters as a single clip_grad_norm_ call would do.
    """
    for p in model.parameters():
        # bug fix: forward the caller's norm_type instead of hard-coding 2
        torch.nn.utils.clip_grad_norm_(p, max_norm, norm_type=norm_type)
def save_network(net, directory, network_label, epoch_label=None, **kwargs):
    """
    Save model to directory with name {network_label}_{epoch_label}.pth
    Args:
        net: pytorch model
        directory: output directory
        network_label: str
        epoch_label: convertible to str
        kwargs: additional values to be included in the checkpoint dict
    """
    save_filename = "_".join((network_label, str(epoch_label))) + ".pth"
    save_path = os.path.join(directory, save_filename)
    merge_states = OrderedDict()
    # move to CPU so the checkpoint loads on machines without a GPU
    merge_states["states"] = net.cpu().state_dict()
    for k in kwargs:
        merge_states[k] = kwargs[k]
    torch.save(merge_states, save_path)
    # bug fix: only move the net back to GPU when CUDA is actually available
    # (the previous unconditional net.cuda() crashed on CPU-only machines)
    if torch.cuda.is_available():
        net.cuda()
def load_network(net, path):
    """
    Load network parameters whose names exist in the checkpoint.

    Args:
        net: module (possibly wrapped in DataParallel) to load into.
        path: checkpoint path (".pth" file or numpy-pickled dict) or an
            already-loaded state dict.

    NOTE(review): the previous docstring claimed an INT trained step is
    returned, but nothing is returned — kept as-is for compatibility.
    """
    import logging
    # bug fix: `logger` was previously an undefined name (NameError)
    logger = logging.getLogger(__name__)
    if isinstance(path, str):
        logger.info("loading network from %s", path)
        if path.endswith("pth"):
            loaded_state = torch.load(path)
            if "states" in loaded_state:
                loaded_state = loaded_state["states"]
        else:
            # numpy-saved checkpoints hold a pickled dict; allow_pickle is
            # required for .item() to work on modern NumPy
            loaded_state = np.load(path, allow_pickle=True).item()
            if "states" in loaded_state:
                loaded_state = loaded_state["states"]
    elif isinstance(path, dict):
        loaded_state = path
    else:
        # bug fix: previously fell through with loaded_state unbound
        raise TypeError("path must be a filename or a state dict, got {}".format(type(path)))
    network = net.module if isinstance(
        net, torch.nn.DataParallel) else net
    missingkeys, unexpectedkeys = network.load_state_dict(loaded_state, strict=False)
    if len(missingkeys) > 0:
        # bug fix: logger.warn was deprecated and received the joined key
        # list as a stray %-argument; use lazy %-formatting properly
        logger.warning("load_network %d missing keys:\n%s", len(missingkeys), "\n".join(missingkeys))
    if len(unexpectedkeys) > 0:
        logger.warning("load_network %d unexpected keys:\n%s", len(unexpectedkeys), "\n".join(unexpectedkeys))
def weights_init(m):
    """
    Initialize module weights: Xavier-uniform for conv/linear layers
    (bias zeroed), identity transform (weight 1, bias 0) for batchnorm.
    """
    if isinstance(m, (torch.nn.modules.conv._ConvNd, torch.nn.Linear)):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        # batchnorm starts as the identity transform
        torch.nn.init.constant_(m.weight, 1.0)
        torch.nn.init.constant_(m.bias, 0.0)
def seal(mesh_to_seal):
    """Close an open boundary ring of a mesh with a triangle fan.

    Appends the centroid of a fixed ring of boundary vertices (hand-mesh
    wrist seam — TODO confirm topology) as a new vertex and fans triangles
    from the ring to that centroid. Returns a shallow copy of the mesh.
    """
    circle_v_id = np.array([108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120], dtype = np.int32)
    center = (mesh_to_seal.v[circle_v_id, :]).mean(0)
    sealed_mesh = copy.copy(mesh_to_seal)
    sealed_mesh.v = np.vstack([mesh_to_seal.v, center])
    center_v_id = sealed_mesh.v.shape[0] - 1
    # fan triangulation: (previous ring vertex, ring vertex, centroid)
    fan = [[circle_v_id[idx - 1], circle_v_id[idx], center_v_id]
           for idx in range(circle_v_id.shape[0])]
    sealed_mesh.f = np.vstack([sealed_mesh.f] + fan)
    return sealed_mesh
def read_pos_fr_txt(txt_fn):
    """Read whitespace-separated floats, one position per line, as float32.

    Returns an (n_lines, n_values) float32 array.
    """
    with open(txt_fn, "r") as rf:
        pos_data = [[float(tok) for tok in ln.strip().split(" ")] for ln in rf]
    pos_data = np.array(pos_data, dtype=np.float32)
    print(f"pos_data: {pos_data.shape}")
    return pos_data
def read_field_data_fr_txt(field_fn):
    """Read whitespace-separated float field values, one row per line.

    Returns an (n_lines, n_values) float32 array.
    """
    field_data = []
    with open(field_fn, "r") as rf:
        for line in rf:
            field_data.append([float(p) for p in line.strip().split(" ")])
    field_data = np.array(field_data, dtype=np.float32)
    # fixed typo in the log message ("filed_data" -> "field_data")
    print(f"field_data: {field_data.shape}")
    return field_data
def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
    """Batched farthest-point sampling via torch_cluster's fps.

    Args:
        pos: (bz, N, feat_dim) tensor of point coordinates.
        n_sampling: number of points to keep per batch element.

    Returns:
        Flat indices into the (bz*N)-row flattened point array, as
        produced by torch_cluster.fps.
    """
    bz, num_pts = pos.size(0), pos.size(1)
    feat_dim = pos.size(-1)
    device = pos.device
    ratio = float(n_sampling / num_pts)
    # batch vector: row r of the flattened points belongs to batch r // N
    batch_ids = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
    ones_row = torch.ones((num_pts,), dtype=torch.long).view(1, num_pts).to(device)
    batch_ids = (batch_ids * ones_row).view(-1)
    # flatten to (bz * N, feat_dim) as fps expects
    flat_pos = pos.float().contiguous().view(-1, feat_dim).contiguous()
    return fps(flat_pos, batch_ids, ratio=ratio, random_start=False)
def batched_index_select_ours(values, indices, dim = 1):
    """Gather `values` along `dim` using batched `indices`.

    Generalized gather: the trailing (feature) dimensions of `values`
    after `dim` are broadcast onto `indices`, so each selected element
    keeps its full feature shape.
    """
    # feature dims that come after the selection dim
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # append one singleton axis per feature dim, then expand to the feature shape
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    # if indices carry extra trailing structure beyond values' dim, insert
    # matching singleton axes into values and expand them so shapes align
    # for gather
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)
    # gather along the (shifted) selection dim
    dim += value_expand_len
    return values.gather(dim, indices)
def compute_nearest(query, verts):
    """For each query point, find the nearest vertex by squared distance.

    Args:
        query: (bsz, nn_q, 3) query points.
        verts: (bsz, nn_v, 3) candidate vertices.

    Returns:
        minn_dists: (bsz, nn_q) squared distance to the nearest vertex.
        minn_dists_idx: (bsz, nn_q, 1) index of the nearest vertex.
        minn_pts_pos: (bsz, nn_q, 1, 3) position of the nearest vertex.
    """
    # pairwise squared distances: (bsz, nn_q, nn_v)
    diff = query.unsqueeze(2) - verts.unsqueeze(1)
    dists = torch.sum(diff ** 2, dim=-1)
    minn_dists, minn_dists_idx = torch.min(dists, dim=-1)
    minn_pts_pos = batched_index_select_ours(values=verts, indices=minn_dists_idx, dim=1)
    return minn_dists, minn_dists_idx.unsqueeze(2), minn_pts_pos.unsqueeze(2)
def batched_index_select(t, dim, inds):
    """
    Extract batch-varying indices along an array.

    :param t: (b, n, f) tensor to select from
    :param dim: dimension to gather along
    :param inds: (b, m) batch-varying index tensor
    :return: (b, m, f) gathered tensor
    """
    expanded = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
    return t.gather(dim, expanded)
def batched_get_rot_mtx_fr_vecs(normal_vecs):
    """Build per-point rotation matrices from direction vectors.

    Decomposes each (unit-normalized) normal into two angles (theta about
    Y, phi about X) and composes the two elementary rotations.

    Args:
        normal_vecs: (nn_pts, 3) direction vectors; normalized internally.

    Returns:
        (nn_pts, 3, 3) rotation matrices.

    NOTE(review): rows are built with torch.stack(dim=-1) and then stacked
    again with dim=-1, so the row/column layout follows the original
    author's convention — confirm orientation against a caller before
    relying on it.
    """
    # unit-normalize; clamp the norm to avoid division by zero
    normal_vecs = normal_vecs / torch.clamp(torch.norm(normal_vecs, p=2, dim=-1, keepdim=True), min=1e-5)
    sin_theta = normal_vecs[..., 0]
    # cos_theta >= 0 by construction of the sqrt
    cos_theta = torch.sqrt(1. - sin_theta ** 2)
    sin_phi = normal_vecs[..., 1] / torch.clamp(cos_theta, min=1e-5)
    # cos_phi = torch.sqrt(1. - sin_phi ** 2)
    cos_phi = normal_vecs[..., 2] / torch.clamp(cos_theta, min=1e-5)
    # degenerate case: normal nearly parallel to X -> pick a fixed phi
    sin_phi[cos_theta < 1e-5] = 1.
    cos_phi[cos_theta < 1e-5] = 0.
    # rotation about the Y axis by theta
    y_rot_mtx = torch.stack(
        [
            torch.stack([cos_theta, torch.zeros_like(cos_theta), -sin_theta], dim=-1),
            torch.stack([torch.zeros_like(cos_theta), torch.ones_like(cos_theta), torch.zeros_like(cos_theta)], dim=-1),
            torch.stack([sin_theta, torch.zeros_like(cos_theta), cos_theta], dim=-1)
        ], dim=-1
    )
    # rotation about the X axis by phi
    x_rot_mtx = torch.stack(
        [
            torch.stack([torch.ones_like(cos_theta), torch.zeros_like(cos_theta), torch.zeros_like(cos_theta)], dim=-1),
            torch.stack([torch.zeros_like(cos_phi), cos_phi, -sin_phi], dim=-1),
            torch.stack([torch.zeros_like(cos_phi), sin_phi, cos_phi], dim=-1)
        ], dim=-1
    )
    # compose: Y rotation first, then X rotation
    rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
    return rot_mtx
def batched_get_rot_mtx_fr_vecs_v2(normal_vecs):
    """Build per-point frames whose last column is the given normal.

    Constructs a frame [n x o, o, n] for each point, where o is a tangent
    direction derived from the normal's components.

    Args:
        normal_vecs: (nn_pts, 3) direction vectors; normalized internally.

    Returns:
        (nn_pts, 3, 3) matrices with columns (n x o, o, n).

    NOTE(review): o is not renormalized after the degenerate-case
    overwrite, so orthonormality near cos_theta ~ 0 should be verified.
    """
    # unit-normalize; clamp the norm to avoid division by zero
    normal_vecs = normal_vecs / torch.clamp(torch.norm(normal_vecs, p=2, dim=-1, keepdim=True), min=1e-5)
    sin_theta = normal_vecs[..., 0]
    cos_theta = torch.sqrt(1. - sin_theta ** 2)
    sin_phi = normal_vecs[..., 1] / torch.clamp(cos_theta, min=1e-5)
    # cos_phi = torch.sqrt(1. - sin_phi ** 2)
    cos_phi = normal_vecs[..., 2] / torch.clamp(cos_theta, min=1e-5)
    # degenerate case: normal nearly parallel to X -> pick a fixed phi
    sin_phi[cos_theta < 1e-5] = 1.
    cos_phi[cos_theta < 1e-5] = 0.
    # o: nn_pts x 3, a tangent direction perpendicular to the normal
    o = torch.stack(
        [torch.zeros_like(cos_phi), cos_phi, -sin_phi], dim=-1
    )
    nxo = torch.cross(o, normal_vecs)
    # rot_mtx: nn_pts x 3 x 3, columns are (n x o, o, n)
    rot_mtx = torch.stack(
        [nxo, o, normal_vecs], dim=-1
    )
    return rot_mtx
def batched_get_orientation_matrices(rot_vec):
    """Convert a batch of axis-angle vectors to rotation matrices.

    Args:
        rot_vec: (ws, 3) array of rotation (axis-angle) vectors.

    Returns:
        (ws, 3, 3) array of rotation matrices.
    """
    mats = [R.from_rotvec(rot_vec[w]).as_matrix() for w in range(rot_vec.shape[0])]
    return np.stack(mats, axis=0)
def batched_get_minn_dist_corresponding_pts(tips, obj_pcs):
    """For each tip point, find its nearest object point (numpy in, torch out).

    Args:
        tips: (ws, nn_tips, 3) numpy array.
        obj_pcs: (ws, nn_obj, 3) numpy array.

    Returns:
        nearest_pc_th: (ws, nn_tips, 3) float tensor of nearest object points.
        dist_tips_to_obj_pc_minn_idx_th: (ws, nn_tips) long tensor of indices.
    """
    tips_exp = tips.reshape(tips.shape[0], tips.shape[1], 1, 3)
    objs_exp = obj_pcs.reshape(obj_pcs.shape[0], 1, obj_pcs.shape[1], 3)
    # squared distances (ws, nn_tips, nn_obj); argmin over object points
    minn_idx = np.argmin(((tips_exp - objs_exp) ** 2).sum(axis=-1), axis=-1)
    obj_pcs_th = torch.from_numpy(obj_pcs).float()
    minn_idx_th = torch.from_numpy(minn_idx).long()
    nearest_pc_th = batched_index_select(obj_pcs_th, 1, minn_idx_th)
    return nearest_pc_th, minn_idx_th
def get_affinity_fr_dist(dist, s=0.02):
    """Map distances to affinity scores via a raised-cosine falloff.

    Returns 1 at dist == 0 and 0 at |dist| == s. NOTE: |dist| > s wraps
    the cosine, so callers are expected to keep distances within [-s, s]
    (or mask afterwards).
    """
    scaled = (torch.pi / s) * torch.abs(dist)
    return 0.5 * torch.cos(scaled) + 0.5
def batched_reverse_transform(rot, transl, t_pc, trans=True):
    """Undo a rigid transform: p_orig = rot^T @ (p - transl).

    Args:
        rot: (ws, 3, 3) rotation matrices.
        transl: (ws, 1, 3) translations.
        t_pc: (ws, nn_obj, 3) transformed point clouds.
        trans: when False, skip the translation removal.

    Returns:
        (ws, nn_obj, 3) points mapped back to the original frame.
    """
    centered = (t_pc - transl) if trans else t_pc
    rot_t = np.transpose(rot, (0, 2, 1))
    restored = np.matmul(rot_t, np.transpose(centered, (0, 2, 1)))
    return np.transpose(restored, (0, 2, 1))
def capsule_sdf(mesh_verts, mesh_normals, query_points, query_normals, caps_rad, caps_top, caps_bot, foreach_on_mesh):
    # if caps on hand: mesh_verts = hand vert
    """
    Find the SDF of query points to mesh verts.
    Capsule SDF formulation from https://iquilezles.org/www/articles/distfunctions/distfunctions.htm
    :param mesh_verts: (batch, V, 3)
    :param mesh_normals: (batch, V, 3)
    :param query_points: (batch, Q, 3)
    :param query_normals: (batch, Q, 3); only used when foreach_on_mesh is True
    :param caps_rad: scalar, radius of capsules
    :param caps_top: scalar, distance from mesh to top of capsule
    :param caps_bot: scalar, distance from mesh to bottom of capsule
    :param foreach_on_mesh: boolean, foreach point on mesh find closest query (V), or foreach query find closest mesh (Q)
    :return: normalized sdf (batch, V or Q): 0 on the capsule axis, 1 on its
        surface; plus normal dot products (None in the query branch), and the
        nearest vertices/normals used
    """
    # TODO implement normal check?
    if foreach_on_mesh: # Foreach mesh vert, find closest query point
        # knn_dists, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(mesh_verts, query_points, K=1, return_nn=True) # TODO should attract capsule middle?
        # knn_dists, nearest_idx, nearest_pos = compute_nearest(query_points, mesh_verts)
        knn_dists, nearest_idx, nearest_pos = compute_nearest(mesh_verts, query_points)
        # capsules are anchored on the mesh vertices themselves
        capsule_tops = mesh_verts + mesh_normals * caps_top
        capsule_bots = mesh_verts + mesh_normals * caps_bot
        delta_top = nearest_pos[:, :, 0, :] - capsule_tops
        # alignment between mesh normals and the normals of the matched query points
        normal_dot = torch.sum(mesh_normals * batched_index_select(query_normals, 1, nearest_idx.squeeze(2)), dim=2)
        rt_nearest_verts = mesh_verts
        rt_nearest_normals = mesh_normals
    else: # Foreach query vert, find closest mesh point
        # knn_dists, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(query_points, mesh_verts, K=1, return_nn=True) # TODO should attract capsule middle?
        st_time = time.time()
        knn_dists, nearest_idx, nearest_pos = compute_nearest(query_points, mesh_verts)
        ed_time = time.time()
        # print(f"Time for computing nearest: {ed_time - st_time}")
        closest_mesh_verts = batched_index_select(mesh_verts, 1, nearest_idx.squeeze(2)) # Shape (batch, V, 3)
        closest_mesh_normals = batched_index_select(mesh_normals, 1, nearest_idx.squeeze(2)) # Shape (batch, V, 3)
        capsule_tops = closest_mesh_verts + closest_mesh_normals * caps_top # Coordinates of the top focii of the capsules (batch, V, 3)
        capsule_bots = closest_mesh_verts + closest_mesh_normals * caps_bot
        delta_top = query_points - capsule_tops
        # normal_dot = torch.sum(query_normals * closest_mesh_normals, dim=2)
        # no normal alignment is computed in this branch
        normal_dot = None
        rt_nearest_verts = closest_mesh_verts
        rt_nearest_normals = closest_mesh_normals
    # project each point onto the capsule axis (top -> bot) and clamp to it
    bot_to_top = capsule_bots - capsule_tops # Vector from capsule bottom to top
    along_axis = torch.sum(delta_top * bot_to_top, dim=2) # Dot product
    top_to_bot_square = torch.sum(bot_to_top * bot_to_top, dim=2)
    # print(f"top_to_bot_square: {top_to_bot_square[..., :10]}")
    h = torch.clamp(along_axis / top_to_bot_square, 0, 1) # Could avoid NaNs with offset in division here
    dist_to_axis = torch.norm(delta_top - bot_to_top * h.unsqueeze(2), dim=2) # Distance to capsule centerline
    # normalize by the capsule radius: 0 on the axis, 1 on the capsule surface
    return dist_to_axis / caps_rad, normal_dot, rt_nearest_verts, rt_nearest_normals # (Normalized SDF)+1 0 on endpoint, 1 on edge of capsule
def reparameterize_gaussian(mean, logvar):
    """Sample from N(mean, exp(logvar)) via the reparameterization trick."""
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn(sigma.size()).to(mean.device)
    return mean + sigma * noise
def gaussian_entropy(logvar):
    """Entropy of a diagonal Gaussian given per-dim log-variances (batch, dim)."""
    dim = float(logvar.size(1))
    # constant term of the Gaussian entropy: 0.5 * dim * (1 + log(2*pi))
    const = 0.5 * dim * (1. + np.log(np.pi * 2))
    return 0.5 * logvar.sum(dim=1, keepdim=False) + const
def standard_normal_logprob(z): # feature dim
    """Element-wise standard-normal log-density with a per-feature-dim constant."""
    log_norm_const = -0.5 * z.size(-1) * np.log(2 * np.pi)
    return log_norm_const - 0.5 * z.pow(2)
def truncated_normal_(tensor, mean=0, std=1, trunc_std=2):
    """
    Fill `tensor` in-place with truncated-normal samples (|z| < trunc_std
    before scaling/shifting). Draws 4 candidates per element and keeps the
    first in-range one. Taken from
    https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
    """
    candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
    in_range = (candidates < trunc_std) & (candidates > -trunc_std)
    # index of the first valid candidate per element
    first_valid = in_range.max(-1, keepdim=True)[1]
    tensor.data.copy_(candidates.gather(-1, first_valid).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
def makepath(desired_path, isfile = False):
    '''
    Create the directory for `desired_path` if it does not exist.

    :param desired_path: can be path to a file or a folder name
    :param isfile: when True, create the parent directory of the path
    :return: desired_path, unchanged
    '''
    import os
    # exist_ok avoids the check-then-create race of the previous version;
    # skipping empty targets fixes makedirs("") for bare filenames
    target = os.path.dirname(desired_path) if isfile else desired_path
    if target:
        os.makedirs(target, exist_ok=True)
    return desired_path
def batch_gather(arr, ind):
    """
    Gather rows of `arr` per batch using the index matrix `ind`.

    :param arr: B x N x D
    :param ind: B x M
    :return: B x M x D
    """
    expanded_ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), arr.size(2))
    return torch.gather(arr, 1, expanded_ind)
def random_rotate_np(x):
    """Apply a random rotation built from a Gaussian axis-angle draw.

    Returns the rotated points and the (3, 3) float32 rotation matrix R,
    applied as x @ R.
    """
    axis_angle = np.random.randn(3)
    angle = np.sqrt(np.sum(axis_angle ** 2))
    unit_axis = axis_angle / np.maximum(angle, 1e-6)
    # skew-symmetric cross-product matrix of the unit axis
    K = np.array([[0, -unit_axis[2], unit_axis[1]],
                  [unit_axis[2], 0, -unit_axis[0]],
                  [-unit_axis[1], unit_axis[0], 0]])
    # Rodrigues' rotation formula
    R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * np.matmul(K, K)
    R = R.astype(np.float32)
    return np.matmul(x, R), R
def rotate_x(x, rad):
    """Rotate row-vector points `x` about the X axis by `rad`.

    The angle is negated before building the matrix, preserving the
    original sign convention; points are multiplied on the right.
    """
    c, s = np.cos(-rad), np.sin(-rad)
    rotmat = np.array([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c]
    ])
    return np.dot(x, rotmat)
def rotate_y(x, rad):
    """Rotate row-vector points `x` about the Y axis by `rad`.

    The angle is negated before building the matrix, preserving the
    original sign convention; points are multiplied on the right.
    """
    c, s = np.cos(-rad), np.sin(-rad)
    rotmat = np.array([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c]
    ])
    return np.dot(x, rotmat)
def rotate_z(x, rad):
    """Rotate row-vector points `x` about the Z axis by `rad`.

    The angle is negated before building the matrix, preserving the
    original sign convention; points are multiplied on the right.
    """
    c, s = np.cos(-rad), np.sin(-rad)
    rotmat = np.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1]
    ])
    return np.dot(x, rotmat)