# gene-hoi-denoising / utils/common_utils.py
# meow
# init
# d6d3a5b
import torch
import time
import numpy as np
import utils
# from common_utils import data_utils_torch as data_utils
# from common_utils.part_transform import revoluteTransform
# import random
import utils.model_util as model_util
# batched_index_select_ours
def get_random_rot_np():
    """Draw a random 3x3 rotation matrix using Rodrigues' formula.

    A 3-vector is sampled from a standard normal; its norm is the rotation
    angle and its normalised direction is the rotation axis (the norm is
    clamped at 1e-6 before dividing).

    Returns:
        np.ndarray of shape (3, 3), dtype float32.
    """
    axis_angle = np.random.randn(3)
    angle = np.sqrt(np.sum(axis_angle ** 2))
    axis = axis_angle / np.maximum(angle, 1e-6)
    ax, ay, az = axis
    # Skew-symmetric cross-product matrix of the unit axis.
    skew = np.array([
        [0, -az, ay],
        [az, 0, -ax],
        [-ay, ax, 0],
    ])
    rot = np.eye(3) + np.sin(angle) * skew + (1 - np.cos(angle)) * (skew @ skew)
    return rot.astype(np.float32)
def get_faces_from_verts(verts, faces, sel_verts_idxes, sel_faces=None):
    """Select the faces touching given vertices and gather their corner values.

    Args:
        verts: vertex values indexed along dim 0 (n_verts x 3 for the callers
            in this file).
        faces: integer tensor of vertex indices, one row per face (n_faces x 3).
        sel_verts_idxes: selected vertex indices, either a list or an object
            with ``.tolist()`` (tensor / ndarray).
        sel_faces: optional precomputed faces (list or tensor). When given,
            the search over ``faces`` is skipped and ``sel_verts_idxes`` is
            not used.

    Returns:
        ``(sel_faces, sel_faces_vals)``. ``sel_faces`` is a long tensor on
        ``verts.device`` holding every face with at least one selected vertex,
        in the order the faces appear in ``faces`` (shape (0, 3) if none
        match). ``sel_faces_vals`` is
        ``model_util.batched_index_select_ours(values=verts, indices=sel_faces, dim=0)``.
    """
    if sel_faces is None:
        if not isinstance(sel_verts_idxes, list):
            sel_verts_idxes = sel_verts_idxes.tolist()
        sel_set = set(sel_verts_idxes)
        # A single host transfer for all faces instead of one `.tolist()` per face.
        sel_faces = [face for face in faces.tolist() if any(v in sel_set for v in face)]
        if not sel_faces:
            # Keep a (0, n_corners) shape so downstream per-face indexing still works.
            sel_faces = torch.empty((0, faces.size(-1)), dtype=torch.long)
    # Put the index tensor on the same device as the values it indexes. The
    # previous unconditional `.cuda()` broke callers that pass CPU vertices.
    # `as_tensor` also avoids the copy warning when a tensor is passed in.
    sel_faces = torch.as_tensor(sel_faces, dtype=torch.long, device=verts.device)
    sel_faces_vals = model_util.batched_index_select_ours(values=verts, indices=sel_faces, dim=0)
    return sel_faces, sel_faces_vals
def sel_faces_values_from_sel_faces(verts, sel_faces):
    """Gather the corner values of already-selected faces.

    Thin wrapper around ``model_util.batched_index_select_ours`` along dim 0:
    ``sel_faces`` indexes into ``verts``.
    """
    return model_util.batched_index_select_ours(values=verts, indices=sel_faces, dim=0)
def get_sub_verts_faces_from_pts(verts, faces, pts, rt_sel_faces=False, minn_dist_pts_verts_idx=None, sel_faces=None):
    """Snap query points to their nearest mesh vertices and gather nearby faces.

    Args:
        verts: mesh vertices, n_verts x 3.
        faces: integer face tensor, n_faces x 3.
        pts: query points, n_pts x 3.
        rt_sel_faces: if True, also return the selected face index tensor.
        minn_dist_pts_verts_idx: optional precomputed nearest-vertex index for
            every point. When given, the pairwise distance matrix is not built.
        sel_faces: optional precomputed faces, forwarded to
            ``get_faces_from_verts`` (which then skips its face search).

    Returns:
        ``(sel_verts, sel_faces, sel_faces_vals, idx)`` when ``rt_sel_faces``,
        otherwise ``(sel_verts, sel_faces_vals, idx)``. Here
        ``sel_verts = verts[idx]`` and ``idx`` is the nearest-vertex index per
        point (computed or passed in).
    """
    if minn_dist_pts_verts_idx is None:
        # Squared Euclidean distances, n_pts x n_verts. This is
        # O(n_pts * n_verts) memory, so build it only when the caller did not
        # already supply the nearest-vertex indices (e.g. when reusing them
        # across frames).
        dis_pts_verts = torch.sum((pts.unsqueeze(1) - verts.unsqueeze(0)) ** 2, dim=-1)
        _, minn_dist_pts_verts_idx = torch.min(dis_pts_verts, dim=-1)
    sel_verts = verts[minn_dist_pts_verts_idx]
    sel_faces, sel_faces_vals = get_faces_from_verts(verts, faces, minn_dist_pts_verts_idx, sel_faces=sel_faces)
    if rt_sel_faces:
        return sel_verts, sel_faces, sel_faces_vals, minn_dist_pts_verts_idx
    return sel_verts, sel_faces_vals, minn_dist_pts_verts_idx
##### distance of each sel_vert in mesh_2 to each face in sel_faces in mesh_1 ##### --->
def get_faces_normals(faces_vals):
    """Compute unit normals of triangles.

    Args:
        faces_vals: triangle corners of shape (..., n_faces, 3, 3), i.e.
            (face, corner, xyz) with optional leading batch dimensions.

    Returns:
        Tensor of shape (..., n_faces, 3) holding ``(vb - va) x (vc - va)``
        normalised to unit length. The norm is clamped at 1e-6, so degenerate
        faces give zero vectors instead of NaNs.
    """
    vas = faces_vals[..., 0, :]
    vabs = faces_vals[..., 1, :] - vas
    vacs = faces_vals[..., 2, :] - vas
    # `dim=-1` must be explicit: without it torch.cross uses the first
    # dimension of size 3, which is the face dimension when n_faces == 3.
    vns = torch.cross(vabs, vacs, dim=-1)
    vns = vns / torch.clamp(torch.norm(vns, dim=-1, p=2, keepdim=True), min=1e-6)
    return vns
###
def get_distance_pts_faces(pts, faces_vals, faces_vns):
    """Signed point-to-plane distances and an approximate in-triangle test.

    Args:
        pts: query points, n_pts x 3.
        faces_vals: triangle corners, n_faces x 3 x 3 (face, corner, xyz).
        faces_vns: face normals, n_faces x 3 (unit length assumed).

    Returns:
        ``(delta_faces_pts_ds, pts_in_faces)``, both n_pts x n_faces.

        ``delta_faces_pts_ds`` is ``n . p - n . va``, i.e. the signed distance
        of each point to each face plane for unit normals.

        ``pts_in_faces`` is 1.0 (float) where the point's projection onto the
        plane has coordinates ``a >= 0`` and ``b >= 0`` along the edges
        ``va->vb`` and ``va->vc`` with ``a + b <= 0.5``, and 0.0 elsewhere.
        ``a`` and ``b`` are scalar projections onto each edge divided by that
        edge's length. They are not true barycentric coordinates unless the
        two edges are orthogonal.
    """
    va = faces_vals[:, 0, :]
    vb = faces_vals[:, 1, :]
    vc = faces_vals[:, 2, :]

    # Plane offsets n . va (n_faces) and point offsets n . p (n_pts x n_faces).
    plane_offsets = torch.sum(va * faces_vns, dim=-1)
    pt_offsets = torch.sum(pts.unsqueeze(1) * faces_vns.unsqueeze(0), dim=-1)
    signed_dists = pt_offsets - plane_offsets.unsqueeze(0)

    # Foot of each point on each face plane, expressed relative to va:
    # n_pts x n_faces x 3.
    foot = pts.unsqueeze(1) - faces_vns.unsqueeze(0) * signed_dists.unsqueeze(-1)
    rel = foot - va.unsqueeze(0)

    def edge_coeff(edge):
        # Scalar projection of `rel` onto `edge`, in units of the edge length.
        length = torch.clamp(torch.norm(edge, dim=-1, p=2, keepdim=True), min=1e-7)
        unit = edge / length
        return torch.sum(unit.unsqueeze(0) * rel, dim=-1) / length.squeeze(-1)

    coeff_ab = edge_coeff(vb - va)
    coeff_ac = edge_coeff(vc - va)
    inside = (coeff_ab >= 0.) & (coeff_ac >= 0.) & (coeff_ab + coeff_ac <= 0.5)
    return signed_dists, inside.float()
###
#### revolute joints ####
def collision_loss(mesh_1, mesh_2, keypts_1, keypts_2, joints, n_sim_steps=100, early_stop=False, penalize_largest=False, pts_loss=False, st_def_pcs=None):
    """Collision loss for a part rotated about a revolute joint against a static part.

    Part 1 (``mesh_1`` / ``keypts_1``) is rotated in ``n_sim_steps`` increments.
    At step ``i`` the angle passed to ``revoluteTransform`` is
    ``i * joint_angle / n_sim_steps``, so the last simulated angle is
    ``(n_sim_steps - 1) / n_sim_steps`` of ``joint_angle``. Part 2
    (``mesh_2`` / ``keypts_2``) is detached and acts as the obstacle. Only the
    faces of ``mesh_2`` that touch the vertices nearest to ``keypts_2`` are
    tested.

    Between two consecutive steps, a key point is treated as colliding with a
    face when its signed distance to the face plane changes sign. The loss only
    counts (point, face) pairs whose projection passes the in-face test of
    ``get_distance_pts_faces``, and only for points that have not collided at
    an earlier step (``non_collided_pts``).

    Args:
        mesh_1: ``(verts1, faces1)`` of the moving part. Only ``verts1`` is used.
        mesh_2: ``(verts2, faces2)`` of the static part.
        keypts_1: key points of the moving part. Gradients flow through them.
        keypts_2: key points of the static part (detached here).
        joints: ``(joint_dir, joint_pvp, joint_angle)``. Direction and pivot are
            handed to ``revoluteTransform`` as detached numpy arrays, so no
            gradient reaches them.
        n_sim_steps: number of simulated steps.
        early_stop: after adding a step's loss, break out of the loop if that
            loss exceeds 1e-4.
        penalize_largest: for each point, keep only the face with the largest
            |distance| among colliding in-face faces.
        pts_loss: if True, project the collision pulse onto the current points.
            Otherwise project it onto the displacement since the previous step.
        st_def_pcs: unused.

    Returns:
        ``(tot_collision_loss, keypts_sequence, mesh_pts_sequence)``:
        - the summed loss (the float ``0.`` when ``n_sim_steps < 2``);
        - the transformed key points per recorded step;
        - the transformed ``verts1`` per step.

    NOTE(review): ``revoluteTransform`` is not imported in this file (its import
    at the top is commented out), so this raises NameError unless the name is
    provided some other way. Confirm. Tensors are moved with ``.cuda()``, so a
    CUDA device is required.
    """
    joint_dir, joint_pvp, joint_angle = joints
    joint_dir = joint_dir.cuda()
    joint_pvp = joint_pvp.cuda()
    verts1, faces1 = mesh_1
    verts2, faces2 = mesh_2
    # The static part is only an obstacle: no gradients through it.
    verts2 = verts2.detach()
    keypts_2 = keypts_2.detach()
    # Faces of mesh_2 around the vertices nearest to keypts_2, and their unit normals.
    sel_verts2, sel_faces_vals2 = get_sub_verts_faces_from_pts(verts2, faces2, keypts_2)
    sel_faces_vns2 = get_faces_normals(sel_faces_vals2)
    sel_faces_vns2 = sel_faces_vns2.detach()
    # NOTE(review): the first assignment is immediately overwritten.
    keypts_sequence = [keypts_1.clone()]
    keypts_sequence = []
    delta_faces_pts_ds_sequence = []
    pts_in_faces_sequence = []
    delta_joint_angle = joint_angle / float(n_sim_steps)
    tot_collision_loss = 0.
    # 1 for key points that have not collided yet; zeroed once a point collides.
    non_collided_pts = torch.ones((keypts_1.size(0), ), dtype=torch.float32, requires_grad=False).cuda()
    mesh_pts_sequence = []
    def_pcs_sequence = []  # unused
    for i in range(0, n_sim_steps):
        cur_joint_angle = i * delta_joint_angle
        # Only the 4x4 matrix `m` is used. The numpy points returned here are
        # replaced by the differentiable matmul below.
        pts, m = revoluteTransform(keypts_1.detach().cpu().numpy(), joint_pvp.detach().cpu().numpy(), joint_dir.detach().cpu().numpy(), cur_joint_angle)
        m = torch.from_numpy(m).float().cuda()
        # Homogeneous row vectors multiplied on the right by the constant `m`
        # (row-vector convention of revoluteTransform assumed; TODO confirm).
        # Gradients reach keypts_1.
        kpts_expanded = torch.cat([keypts_1, torch.ones((keypts_1.size(0), 1), dtype=torch.float32).cuda()], dim=-1)
        pts = torch.matmul(kpts_expanded, m)
        pts = pts[:, :3]
        # Signed plane distances and in-face mask for this step, both n_pts x n_faces.
        delta_faces_pts_ds, pts_in_faces = get_distance_pts_faces(pts, sel_faces_vals2, sel_faces_vns2)
        # The same transform applied to all vertices of the moving part (returned per step).
        mesh_pts_expanded = torch.cat([verts1, torch.ones((verts1.size(0), 1), dtype=torch.float32).cuda()], dim=-1)
        mesh_pts_expanded = torch.matmul(mesh_pts_expanded, m)
        mesh_pts_expanded = mesh_pts_expanded[:, :3]
        mesh_pts_sequence.append(mesh_pts_expanded.clone())
        # Collision detection needs the previous step, so step 0 only records state.
        if len(delta_faces_pts_ds_sequence) > 0:
            prev_delta_faces_pts_ds = delta_faces_pts_ds_sequence[-1]
            prev_pts_in_faces = pts_in_faces_sequence[-1]  # unused
            prev_pts = keypts_sequence[-1]
            sgn_delta_faces_ds = torch.sign(delta_faces_pts_ds)
            sgn_prev_delta_faces_ds = torch.sign(prev_delta_faces_pts_ds)
            # 1 where the point crossed the face plane since the previous step.
            collision_pts_faces = (sgn_delta_faces_ds != sgn_prev_delta_faces_ds).float()
            if penalize_largest:
                # For each point, keep only the colliding in-face entry with the
                # largest |distance|. Ties go to the lowest face index.
                abs_filtered_delta_faces_pts_ds = torch.abs(delta_faces_pts_ds) * collision_pts_faces * pts_in_faces
                maxx_abs_pts_faces_ds, maxx_abs_pts_faces_ds_idxes = torch.max(abs_filtered_delta_faces_pts_ds, dim=-1)
                maxx_abs_pts_faces_ds_idxes = maxx_abs_pts_faces_ds_idxes.unsqueeze(-1).contiguous().repeat(1, pts_in_faces.size(1))
                pts_faces_ds_idxes_range = torch.arange(pts_in_faces.size(1)).contiguous().unsqueeze(0).cuda()
                maxx_mask = (pts_faces_ds_idxes_range == maxx_abs_pts_faces_ds_idxes)
                cur_delta_faces_pts_ds = torch.zeros_like(delta_faces_pts_ds)
                cur_delta_faces_pts_ds[maxx_mask] = delta_faces_pts_ds[maxx_mask]
            else:
                cur_delta_faces_pts_ds = delta_faces_pts_ds
            collision_dists = collision_pts_faces * cur_delta_faces_pts_ds
            # Pulse along the face normal for in-face, not-yet-collided points:
            # n_pts x n_faces x 3.
            collision_pulse = (collision_dists * pts_in_faces * non_collided_pts.unsqueeze(-1)).unsqueeze(-1) * sel_faces_vns2.unsqueeze(0)
            collided_indicator = ((pts_in_faces * non_collided_pts.unsqueeze(-1) * collision_pts_faces).sum(-1) > 0.1).float()
            if pts_loss:
                # v2: pulse projected onto the current point positions.
                collision_loss = torch.sum(collision_pulse * pts.unsqueeze(1), dim=-1)
            else:
                # v1: pulse projected onto the displacement since the previous step.
                delta_keypts = pts - prev_pts
                collision_loss = torch.sum(collision_pulse * delta_keypts.unsqueeze(1), dim=-1)
            # Points that collided at this step are excluded from all later steps.
            non_collided_indicator = 1.0 - collided_indicator
            non_collided_pts = non_collided_pts * non_collided_indicator
            # NOTE: this local shadows the function name `collision_loss`.
            collision_loss = torch.sum(collision_loss, dim=-1).sum()
            tot_collision_loss += collision_loss
            if early_stop and collision_loss.item() > 0.0001:
                break
        delta_faces_pts_ds_sequence.append(delta_faces_pts_ds.clone())
        pts_in_faces_sequence.append(pts_in_faces.clone())
        keypts_sequence.append(pts.clone())
    return tot_collision_loss, keypts_sequence, mesh_pts_sequence
###
#### prismatic joints ####
def collision_loss_prismatic(mesh_1, mesh_2, keypts_1, keypts_2, joints, n_sim_steps=100, early_stop=False, penalize_largest=False, pts_loss=False, st_def_pcs=None):
    """Collision loss for a part translated along a prismatic joint against a static part.

    At step ``i``, part 1 (``keypts_1`` and ``verts1``) is shifted by
    ``joint_dir * i / n_sim_steps``. ``joint_dir`` is not normalised here, so
    its length sets the travel. The last step reaches
    ``(n_sim_steps - 1) / n_sim_steps`` of it. ``joint_pvp`` and
    ``joint_angle`` are unpacked but not used. The shift is built with torch
    ops, so gradients reach both ``joint_dir`` and ``keypts_1``.

    Part 2 is detached and acts as the obstacle. Only the faces of ``mesh_2``
    that touch the vertices nearest to ``keypts_2`` are tested. Between two
    consecutive steps, a key point is treated as colliding with a face when
    its signed distance to the face plane changes sign. The loss only counts
    (point, face) pairs whose projection passes the in-face test of
    ``get_distance_pts_faces``, and only for points that have not collided at
    an earlier step.

    Args:
        mesh_1: ``(verts1, faces1)`` of the moving part. Only ``verts1`` is used.
        mesh_2: ``(verts2, faces2)`` of the static part.
        keypts_1: key points of the moving part.
        keypts_2: key points of the static part (detached here).
        joints: ``(joint_dir, joint_pvp, joint_angle)``. ``joint_dir`` is
            presumably shape (3,) (TODO confirm).
        n_sim_steps: number of simulated steps.
        early_stop: after adding a step's loss, break out of the loop if that
            loss exceeds 1e-4.
        penalize_largest: for each point, keep only the face with the largest
            |distance| among colliding in-face faces.
        pts_loss: if True, project the collision pulse onto the current points.
            Otherwise project it onto the displacement since the previous step.
        st_def_pcs: unused.

    Returns:
        ``(tot_collision_loss, keypts_sequence, mesh_pts_sequence)``:
        - the summed loss (the float ``0.`` when ``n_sim_steps < 2``);
        - the shifted key points per recorded step;
        - the shifted ``verts1`` per step.

    NOTE(review): tensors are moved with ``.cuda()``, so a CUDA device is required.
    """
    joint_dir, joint_pvp, joint_angle = joints
    joint_dir = joint_dir.cuda()
    joint_pvp = joint_pvp.cuda()  # unused below
    verts1, faces1 = mesh_1
    verts2, faces2 = mesh_2
    # The static part is only an obstacle: no gradients through it.
    verts2 = verts2.detach()
    keypts_2 = keypts_2.detach()
    # Faces of mesh_2 around the vertices nearest to keypts_2, and their unit normals.
    sel_verts2, sel_faces_vals2 = get_sub_verts_faces_from_pts(verts2, faces2, keypts_2)
    sel_faces_vns2 = get_faces_normals(sel_faces_vals2)
    sel_faces_vns2 = sel_faces_vns2.detach()
    # NOTE(review): the first assignment is immediately overwritten.
    keypts_sequence = [keypts_1.clone()]
    keypts_sequence = []
    delta_faces_pts_ds_sequence = []
    pts_in_faces_sequence = []
    # Fraction of joint_dir travelled per step (the name is kept from the revolute version).
    delta_joint_angle = 1.0 / float(n_sim_steps)
    tot_collision_loss = 0.
    # 1 for key points that have not collided yet; zeroed once a point collides.
    non_collided_pts = torch.ones((keypts_1.size(0), ), dtype=torch.float32, requires_grad=False).cuda()
    mesh_pts_sequence = []
    def_pcs_sequence = []  # unused
    for i in range(0, n_sim_steps):
        cur_delta_dis = i * delta_joint_angle
        moving_dis = joint_dir * cur_delta_dis
        moving_dis = moving_dis.unsqueeze(0)  # 1 x 3, broadcast over points
        pts = moving_dis + keypts_1
        # Signed plane distances and in-face mask for this step, both n_pts x n_faces.
        delta_faces_pts_ds, pts_in_faces = get_distance_pts_faces(pts, sel_faces_vals2, sel_faces_vns2)
        # The same shift applied to all vertices of the moving part (returned per step).
        mesh_pts_expanded = verts1 + moving_dis
        mesh_pts_sequence.append(mesh_pts_expanded.clone())
        # Collision detection needs the previous step, so step 0 only records state.
        if len(delta_faces_pts_ds_sequence) > 0:
            prev_delta_faces_pts_ds = delta_faces_pts_ds_sequence[-1]
            prev_pts_in_faces = pts_in_faces_sequence[-1]  # unused
            prev_pts = keypts_sequence[-1]
            sgn_delta_faces_ds = torch.sign(delta_faces_pts_ds)
            sgn_prev_delta_faces_ds = torch.sign(prev_delta_faces_pts_ds)
            # 1 where the point crossed the face plane since the previous step.
            collision_pts_faces = (sgn_delta_faces_ds != sgn_prev_delta_faces_ds).float()
            if penalize_largest:
                # For each point, keep only the colliding in-face entry with the
                # largest |distance|. Ties go to the lowest face index.
                abs_filtered_delta_faces_pts_ds = torch.abs(delta_faces_pts_ds) * collision_pts_faces * pts_in_faces
                maxx_abs_pts_faces_ds, maxx_abs_pts_faces_ds_idxes = torch.max(abs_filtered_delta_faces_pts_ds, dim=-1)
                maxx_abs_pts_faces_ds_idxes = maxx_abs_pts_faces_ds_idxes.unsqueeze(-1).contiguous().repeat(1, pts_in_faces.size(1))
                pts_faces_ds_idxes_range = torch.arange(pts_in_faces.size(1)).contiguous().unsqueeze(0).cuda()
                maxx_mask = (pts_faces_ds_idxes_range == maxx_abs_pts_faces_ds_idxes)
                cur_delta_faces_pts_ds = torch.zeros_like(delta_faces_pts_ds)
                cur_delta_faces_pts_ds[maxx_mask] = delta_faces_pts_ds[maxx_mask]
            else:
                cur_delta_faces_pts_ds = delta_faces_pts_ds
            collision_dists = collision_pts_faces * cur_delta_faces_pts_ds
            # Pulse along the face normal for in-face, not-yet-collided points:
            # n_pts x n_faces x 3.
            collision_pulse = (collision_dists * pts_in_faces * non_collided_pts.unsqueeze(-1)).unsqueeze(-1) * sel_faces_vns2.unsqueeze(0)
            collided_indicator = ((pts_in_faces * non_collided_pts.unsqueeze(-1) * collision_pts_faces).sum(-1) > 0.1).float()
            if pts_loss:
                # v2: pulse projected onto the current point positions.
                collision_loss = torch.sum(collision_pulse * pts.unsqueeze(1), dim=-1)
            else:
                # v1: pulse projected onto the displacement since the previous step.
                delta_keypts = pts - prev_pts
                collision_loss = torch.sum(collision_pulse * delta_keypts.unsqueeze(1), dim=-1)
            # Points that collided at this step are excluded from all later steps.
            non_collided_indicator = 1.0 - collided_indicator
            non_collided_pts = non_collided_pts * non_collided_indicator
            collision_loss = torch.sum(collision_loss, dim=-1).sum()
            tot_collision_loss += collision_loss
            if early_stop and collision_loss.item() > 0.0001:
                break
        delta_faces_pts_ds_sequence.append(delta_faces_pts_ds.clone())
        pts_in_faces_sequence.append(pts_in_faces.clone())
        keypts_sequence.append(pts.clone())
    return tot_collision_loss, keypts_sequence, mesh_pts_sequence
def collision_loss_sim_sequence_ours(verts1, verts2, faces1, faces2, base_pts, use_delta=False, sel_faces_values=None):  # inputs are torch tensors
    """Penetration-style loss between a moving hand mesh and a moving object mesh.

    For every frame ``i >= 1``, the hand vertices ``verts1[i]`` are tested
    against a fixed subset of object faces whose per-frame corner positions
    are ``sel_faces_values[i]``. A (vertex, face) pair counts when the
    vertex's projection onto the face plane passes the in-face test of
    ``get_distance_pts_faces``. The frame loss is the mean, over all
    (vertex, face) pairs, of the masked squared distance from the vertex to
    the face centroid. The per-frame losses are summed and divided by the
    number of frames. Frame 0 never contributes.

    Args:
        verts1: hand vertices indexed per frame (``verts1[i]`` is n_pts x 3;
            presumably n_frames x n_hand_verts x 3, TODO confirm). Gradients
            flow through it.
        verts2: object vertices per frame (detached here).
        faces1: unused.
        faces2: object faces, integer tensor n_faces x 3.
        base_pts: per-frame object points used to pick vertices/faces.
            ``base_pts.size(1)`` sets how many faces are selected.
        use_delta: unsupported (see Raises).
        sel_faces_values: optional per-frame list of selected face corner
            values, e.g. the second return value of a previous call. When
            None it is computed here by keeping the ``base_pts.size(1)``
            object faces whose centroid comes closest to any hand vertex in
            any frame.

    Returns:
        ``(tot_collision_loss, sel_faces_values)``. The loss is a tensor, or
        the float ``0.0`` when fewer than two frames are given.

    Raises:
        NotImplementedError: if ``use_delta`` is True.

    Side effects:
        Writes ``tmp_dict.npy`` to the working directory and prints the frame
        count on every call (debug leftovers, kept for compatibility).
    """
    if use_delta:
        # The previous use_delta branch referenced names that do not exist in
        # this function (joint_a, joint_b, keypts_1, joint_pvp, joint_dir,
        # sel_faces_vals2, sel_faces_vns2, revoluteTransform) and so always
        # failed with NameError. Fail early and explicitly instead.
        raise NotImplementedError("use_delta=True is not supported by collision_loss_sim_sequence_ours")

    verts2 = verts2.detach()

    if sel_faces_values is None:
        # Object face centroids for every frame: n_frames x n_faces x 3.
        faces2_exp = faces2.unsqueeze(0).repeat(verts2.size(0), 1, 1).contiguous()
        faces_values2 = model_util.batched_index_select_ours(verts2, indices=faces2_exp, dim=1)
        faces_values2 = faces_values2.mean(dim=-2)
        # Squared distance from every hand vertex to every face centroid, per
        # frame (on CPU): n_frames x n_hand_verts x n_faces.
        dist_hand_to_obj_faces = torch.sum(
            (verts1.detach().cpu().unsqueeze(-2) - faces_values2.detach().cpu().unsqueeze(1)) ** 2, dim=-1
        )
        # Closest approach of each face to the hand over all frames and hand vertices.
        minn_dist_obj_to_hand, _ = torch.min(dist_hand_to_obj_faces, dim=0)
        minn_dist_obj_to_hand, _ = torch.min(minn_dist_obj_to_hand, dim=0)
        # Keep the base_pts.size(1) faces that get closest to the hand.
        cur_p_sel_faces = torch.argsort(minn_dist_obj_to_hand, dim=0, descending=False)[:base_pts.size(1)]
        cur_p_sel_faces = model_util.batched_index_select_ours(faces2.detach().cpu(), indices=cur_p_sel_faces, dim=0)
        sel_faces_values = []
        minn_dist_pts_verts_idx = None
        for i_fr in range(verts2.size(0)):
            # The nearest-vertex indices found on frame 0 are reused for all
            # later frames, and the same faces are gathered every frame.
            _, cur_p_sel_faces, cur_p_sel_faces_vals, minn_dist_pts_verts_idx = get_sub_verts_faces_from_pts(
                verts2[i_fr].detach().cpu(), faces2.detach().cpu(), base_pts[i_fr].detach().cpu(),
                rt_sel_faces=True, minn_dist_pts_verts_idx=minn_dist_pts_verts_idx, sel_faces=cur_p_sel_faces,
            )
            sel_faces_values.append(cur_p_sel_faces_vals)

    # NOTE(review): debug leftover. It dumps the inputs to ./tmp_dict.npy and
    # prints on every call; consider removing or gating it.
    sv_dict = {
        'hand_verts': verts1.detach().cpu().numpy(),
        'obj_verts': verts2.detach().cpu().numpy(),
        'obj_faces': faces2.detach().cpu().numpy(),
        'base_pts': base_pts.detach().cpu().numpy(),
    }
    sv_dict_fn = "tmp_dict.npy"
    np.save(sv_dict_fn, sv_dict)

    n_frames = verts1.shape[0]
    print(f"cur_nframes: {n_frames}")

    tot_collision_loss = 0.
    # Only frames with a predecessor contributed in the original bookkeeping,
    # so frame 0 is skipped.
    for i in range(1, n_frames):
        pts = verts1[i]
        cur_face_values2 = sel_faces_values[i]
        cur_face_normals = get_faces_normals(cur_face_values2)
        # Only the in-face mask is used; the signed distances are not needed.
        _, pts_in_faces = get_distance_pts_faces(
            pts.detach().cpu(), cur_face_values2.detach().cpu(), cur_face_normals.detach().cpu()
        )
        pts_in_faces = pts_in_faces.to(pts.device)
        # Centroid of each selected face: n_sel_faces x 3.
        cur_face_avg_values = torch.mean(cur_face_values2.detach(), dim=-2)
        # Vector from each face centroid to each hand vertex, kept only where
        # the vertex projects inside the face: n_pts x n_sel_faces x 3.
        face_avg_values_to_key_pts = pts.unsqueeze(1) - cur_face_avg_values.unsqueeze(0).to(pts.device)
        face_avg_values_to_key_pts = face_avg_values_to_key_pts * pts_in_faces.unsqueeze(-1)
        collision_loss = torch.mean((face_avg_values_to_key_pts ** 2).sum(dim=-1))
        tot_collision_loss += collision_loss

    # max(..., 1) avoids ZeroDivisionError for an empty sequence.
    tot_collision_loss = tot_collision_loss / max(n_frames, 1)
    return tot_collision_loss, sel_faces_values
# collision loss and sim sequence
#
def collision_loss_sim_sequence_ours_ccd_rigid(verts1, verts2, faces1, faces2, base_pts, obj_rot, obj_trans,
                                               use_delta=False, sel_faces_values=None, canon_verts1=None,
                                               canon_sel_faces_values=None):  # inputs are torch tensors #
    """Pull per-frame hand vertices back along their trajectory so they do not cross object faces.

    Everything is done in the object's frame: each frame's hand vertices
    ``verts1[i]`` and the selected object faces are mapped with
    ``(x - obj_trans[i]) @ obj_rot[i].T``. Frame by frame, the hand is then
    moved from the previous (already resolved) frame towards the current one
    with a coefficient tried at 1.0, 0.9, ..., 0.0. The first coefficient
    for which no (point, face) pair both flips the sign of its
    ``get_distance_pts_faces`` value w.r.t. the previous frame AND is
    flagged by ``pts_in_faces`` is kept. If none works, 0.0 is used, i.e.
    the hand stays where it was in the previous frame.

    Args:
        verts1: hand vertices per frame, indexed ``verts1[i]`` and combined
            with CUDA tensors. They are not detached, so gradients flow into
            the canonical vertices. Presumably (nf, n_hand_verts, 3) -- TODO
            confirm.
        verts2: object vertices per frame (detached before use).
        faces1: unused.
        faces2: object faces as vertex indices, shared by all frames.
        base_pts: per-frame points. ``base_pts.size(1)`` sets how many faces
            closest to the hand are pre-selected, and ``base_pts[i]`` is
            forwarded to ``get_sub_verts_faces_from_pts``.
        obj_rot: per-frame 3x3 object rotations.
        obj_trans: per-frame object translations (3,).
        use_delta: unused; kept for interface compatibility.
        sel_faces_values, canon_verts1, canon_sel_faces_values: outputs of a
            previous call, reused as a cache. If ``sel_faces_values`` is None,
            all three are recomputed; otherwise the other two must be given.

    Returns:
        ``(keypts_sequence, sel_faces_values, canon_verts1, canon_sel_faces_values)``.
        ``keypts_sequence`` stacks the resolved canonical hand vertices of
        every frame along a new leading dimension.

    Raises:
        ValueError: if ``sel_faces_values`` is given without ``canon_verts1``
            or ``canon_sel_faces_values``.
    """
    if sel_faces_values is not None and (canon_verts1 is None or canon_sel_faces_values is None):
        raise ValueError(
            "canon_verts1 and canon_sel_faces_values must be given together with sel_faces_values"
        )

    if sel_faces_values is None:
        verts2 = verts2.detach()
        faces2_cpu = faces2.detach().cpu()
        # Object face centroids per frame: nf x nn_faces x 3.
        faces2_exp = faces2.unsqueeze(0).repeat(verts2.size(0), 1, 1).contiguous()
        face_centers2 = model_util.batched_index_select_ours(verts2, indices=faces2_exp, dim=1).mean(dim=-2)
        # Squared hand-vertex -> face-centroid distances: nf x nn_hand_verts x nn_faces.
        dist_hand_to_obj_faces = torch.sum(
            (verts1.detach().cpu().unsqueeze(-2) - face_centers2.detach().cpu().unsqueeze(1)) ** 2, dim=-1
        )
        # Minimum over frames, then over hand vertices -> one value per object face.
        minn_dist_faces_to_hand, _ = torch.min(dist_hand_to_obj_faces, dim=0)
        minn_dist_faces_to_hand, _ = torch.min(minn_dist_faces_to_hand, dim=0)
        # Pre-select the base_pts.size(1) faces closest to the hand.
        nearest_faces_idxes = torch.argsort(minn_dist_faces_to_hand, dim=0, descending=False)[:base_pts.size(1)]
        cur_p_sel_faces = model_util.batched_index_select_ours(faces2_cpu, indices=nearest_faces_idxes, dim=0)

        sel_faces_values = []
        canon_sel_faces_values = []
        canon_verts1 = []
        minn_dist_pts_verts_idx = None
        for i_fr in range(verts2.size(0)):
            # The returned face selection / index are fed back into the next frame's call.
            _, cur_p_sel_faces, cur_p_sel_faces_vals, minn_dist_pts_verts_idx = get_sub_verts_faces_from_pts(
                verts2[i_fr].detach().cpu(), faces2_cpu, base_pts[i_fr].detach().cpu(), rt_sel_faces=True,
                minn_dist_pts_verts_idx=minn_dist_pts_verts_idx, sel_faces=cur_p_sel_faces,
            )
            sel_faces_values.append(cur_p_sel_faces_vals)  # nn_sel_faces x 3 x 3

            cur_fr_obj_rot = obj_rot[i_fr].detach().cpu()  # 3 x 3
            cur_fr_obj_trans = obj_trans[i_fr].detach().cpu()  # 3
            # Canonicalize the selected faces (CPU): (x - t) @ R^T.
            canon_sel_faces_values.append(torch.matmul(
                cur_p_sel_faces_vals - cur_fr_obj_trans.unsqueeze(0).unsqueeze(0),
                cur_fr_obj_rot.transpose(1, 0).unsqueeze(0),
            ))
            # Canonicalize the hand vertices (CUDA) with the same transform; verts1 is not detached.
            canon_verts1.append(torch.matmul(
                verts1[i_fr] - cur_fr_obj_trans.unsqueeze(0).cuda(), cur_fr_obj_rot.transpose(1, 0).cuda()
            ))

    # Back-off resolution: coefficients k / n_ccd_steps for k = n_ccd_steps, ..., 0.
    # Integer stepping avoids the drift of repeatedly subtracting 0.1.
    n_ccd_steps = 10
    keypts_sequence = []
    prev_delta_faces_pts_ds = None
    n_frames = verts1.shape[0]
    for i in range(n_frames):
        pts = canon_verts1[i]
        cur_faces_vals = canon_sel_faces_values[i].detach().cpu()
        cur_face_normals = get_faces_normals(canon_sel_faces_values[i]).detach().cpu()
        if i == 0:
            # Nothing to compare against: keep the first frame as is.
            delta_faces_pts_ds, _ = get_distance_pts_faces(pts.detach().cpu(), cur_faces_vals, cur_face_normals)
        else:
            prev_pts = keypts_sequence[-1]  # nn_pts x 3, already resolved
            for k in range(n_ccd_steps, -1, -1):
                coef = k / n_ccd_steps
                cur_step_pts = prev_pts + (pts - prev_pts) * coef
                cur_step_delta_faces_pts_ds, cur_step_pts_in_faces = get_distance_pts_faces(
                    cur_step_pts.detach().cpu(), cur_faces_vals, cur_face_normals
                )
                # A pair collides if its face distance changed sign since the previous frame
                # and the point is flagged as inside the face (n_pts x n_faces).
                crossed = torch.sign(cur_step_delta_faces_pts_ds) != torch.sign(prev_delta_faces_pts_ds)
                collision_pts_faces = (crossed.float() + cur_step_pts_in_faces.float()) > 1.5
                if not collision_pts_faces.any():
                    break
            # coef is the first collision-free coefficient, or 0.0 if none was found.
            pts = prev_pts + (pts - prev_pts) * coef
            delta_faces_pts_ds = cur_step_delta_faces_pts_ds
        prev_delta_faces_pts_ds = delta_faces_pts_ds
        keypts_sequence.append(pts.clone())

    keypts_sequence = torch.stack(keypts_sequence, dim=0)  # nn_frames x nn_keypts x 3
    return keypts_sequence, sel_faces_values, canon_verts1, canon_sel_faces_values
def collision_loss_sim_sequence(verts1,
                                keypts_1, verts2, sel_faces_vals2, sel_faces_vns2, joints, n_sim_steps=100,
                                back_sim=False, use_delta=False):
    """Accumulate a collision loss while articulating a part about a revolute joint.

    ``keypts_1`` is rotated about the joint axis with ``revoluteTransform``
    (homogeneous row vectors times the returned 4x4 matrix) and tested
    against the fixed faces ``sel_faces_vals2`` with normals
    ``sel_faces_vns2``. A (point, face) pair counts as colliding when the
    sign of its ``get_distance_pts_faces`` value differs between two
    articulation states. Its contribution is
    ``dist * pts_in_faces * dot(face_normal, point)``, summed over all pairs
    and steps. The result is not averaged.

    Modes:
      * ``use_delta=False`` (sweep): ``n_sim_steps`` evenly spaced angles
        from joint ``a`` to ``b`` (from ``b`` downwards if ``back_sim``).
        Each step is compared with the previous one. A point that has
        collided once is masked out of later steps. ``verts1`` is
        transformed alongside and returned in ``mesh_pts_sequence``.
      * ``use_delta=True``: each step samples a random start angle
        ``st`` in ``[a + d, b]`` with ``d = (b - a) / 100`` and compares
        the state at ``st`` with the state at ``st - d``.

    Args:
        verts1: vertices of the moving part. Used only in sweep mode, where
            they are transformed with the same matrix as ``keypts_1``.
        keypts_1: key points of the moving part. Not detached, so the loss
            back-propagates into them.
        verts2: unused (kept for interface compatibility).
        sel_faces_vals2: selected faces of the other mesh, forwarded to
            ``get_distance_pts_faces``.
        sel_faces_vns2: per-face normals, broadcast as ``1 x n_faces x 3``.
            Detached.
        joints: dict with ``joints["axis"]`` keys ``"dir"``, ``"center"``,
            ``"a"``, ``"b"`` (axis direction, pivot, angle limits).
        n_sim_steps: number of simulation steps.
        back_sim: sweep from ``b`` towards ``a`` instead of ``a`` to ``b``.
        use_delta: select the random-pair mode described above.

    Returns:
        ``(tot_collision_loss, keypts_sequence, mesh_pts_sequence)``.
        ``tot_collision_loss`` stays the float ``0.`` if no loss term was
        added. ``keypts_sequence`` holds the articulated key points of every
        step (n_keypts x 3 each). ``mesh_pts_sequence`` is empty in
        ``use_delta`` mode.
    """
    sel_faces_vns2 = sel_faces_vns2.detach()
    joint_dir = joints["axis"]["dir"]
    joint_pvp = joints["axis"]["center"]
    joint_a = joints["axis"]["a"]
    joint_b = joints["axis"]["b"]

    # Loop-invariant inputs of revoluteTransform and the homogeneous key points (n_keypts x 4).
    keypts_1_np = keypts_1.detach().cpu().numpy()
    joint_pvp_np = joint_pvp.detach().cpu().numpy()
    joint_dir_np = joint_dir.detach().cpu().numpy()
    kpts_expanded = torch.cat(
        [keypts_1, torch.ones((keypts_1.size(0), 1), dtype=torch.float32).cuda()], dim=-1
    )

    def _articulate(joint_angle):
        """Return (keypts_1 rotated to ``joint_angle`` as n_keypts x 3, the 4 x 4 CUDA matrix)."""
        _, m_np = revoluteTransform(keypts_1_np, joint_pvp_np, joint_dir_np, joint_angle)
        m = torch.from_numpy(m_np).float().cuda()
        return torch.matmul(kpts_expanded, m)[:, :3], m

    if use_delta:
        delta_joint_angle = (joint_b - joint_a) / 100.0
    else:
        # max(..., 1) avoids a ZeroDivisionError when n_sim_steps == 1.
        delta_joint_angle = (float(joint_b) - float(joint_a)) / float(max(n_sim_steps - 1, 1))
        mesh_pts_expanded = torch.cat(
            [verts1, torch.ones((verts1.size(0), 1), dtype=torch.float32).cuda()], dim=-1
        )

    tot_collision_loss = 0.
    # 1 for points that have not collided yet (sweep mode only).
    non_collided_pts = torch.ones((keypts_1.size(0), ), dtype=torch.float32, requires_grad=False)
    keypts_sequence = []
    mesh_pts_sequence = []
    prev_delta_faces_pts_ds = None  # face distances of the previous sweep step
    for i in range(n_sim_steps):
        if use_delta:
            cur_st_joint_angle = np.random.uniform(
                low=joint_a + delta_joint_angle, high=joint_b, size=(1,)
            ).item()
            cur_ed_joint_angle = cur_st_joint_angle - delta_joint_angle
            # Start state.
            st_pts, _ = _articulate(cur_st_joint_angle)
            st_delta_faces_pts_ds, _ = get_distance_pts_faces(st_pts, sel_faces_vals2, sel_faces_vns2)
            # End state: must use the END angle. The original re-used the start angle,
            # which made both states identical and the loss always zero.
            pts, _ = _articulate(cur_ed_joint_angle)
            delta_faces_pts_ds, pts_in_faces = get_distance_pts_faces(pts, sel_faces_vals2, sel_faces_vns2)

            collision_pts_faces = (torch.sign(delta_faces_pts_ds) != torch.sign(st_delta_faces_pts_ds)).float()
            collision_dists = collision_pts_faces * delta_faces_pts_ds  # n_pts x n_faces
            collision_pulse = (collision_dists * pts_in_faces).unsqueeze(-1) * sel_faces_vns2.unsqueeze(0)
            collision_loss = torch.sum(collision_pulse * pts.unsqueeze(1), dim=-1)  # n_pts x n_faces
            tot_collision_loss += torch.sum(collision_loss, dim=-1).sum()
        else:
            if back_sim:
                cur_joint_angle = joint_b - i * delta_joint_angle
            else:
                cur_joint_angle = joint_a + i * delta_joint_angle
            pts, m = _articulate(cur_joint_angle)
            delta_faces_pts_ds, pts_in_faces = get_distance_pts_faces(
                pts.detach().cpu(), sel_faces_vals2.detach().cpu(), sel_faces_vns2.detach().cpu()
            )
            mesh_pts_sequence.append(torch.matmul(mesh_pts_expanded, m)[:, :3].clone())

            if prev_delta_faces_pts_ds is not None:
                collision_pts_faces = (
                    torch.sign(delta_faces_pts_ds) != torch.sign(prev_delta_faces_pts_ds)
                ).float()  # n_pts x n_faces
                collision_dists = collision_pts_faces * delta_faces_pts_ds
                # Only points inside the face that have not collided before contribute.
                collision_pulse = (
                    (collision_dists * pts_in_faces * non_collided_pts.unsqueeze(-1)).unsqueeze(-1)
                    * sel_faces_vns2.unsqueeze(0).detach().cpu()
                )  # n_pts x n_faces x 3
                collided_indicator = (
                    (pts_in_faces * non_collided_pts.unsqueeze(-1) * collision_pts_faces).sum(-1) > 0.1
                ).float()
                collision_loss = torch.sum(collision_pulse.cuda() * pts.unsqueeze(1), dim=-1)  # n_pts x n_faces
                non_collided_pts = non_collided_pts * (1.0 - collided_indicator)
                tot_collision_loss += torch.sum(collision_loss, dim=-1).sum()
            prev_delta_faces_pts_ds = delta_faces_pts_ds
        keypts_sequence.append(pts.clone())
    return tot_collision_loss, keypts_sequence, mesh_pts_sequence