# NOTE: the following lines are residue from the web page this file was copied from.
# gene-hoi-denoising / utils /model_util.py
# meow
# readme
# 1b00369
# raw
# history blame
# 91.3 kB
# from model.mdm import MDM
# from model.mdm_ours import MDM as MDM_Ours
# from model.mdm_ours import MDMV3 as MDM_Ours_V3
# from model.mdm_ours import MDMV4 as MDM_Ours_V4
# from model.mdm_ours import MDMV5 as MDM_Ours_V5
# from model.mdm_ours import MDMV6 as MDM_Ours_V6
# from model.mdm_ours import MDMV7 as MDM_Ours_V7
# from model.mdm_ours import MDMV8 as MDM_Ours_V8
# from model.mdm_ours import MDMV9 as MDM_Ours_V9
from model.mdm_ours import MDMV10 as MDM_Ours_V10
# from model.mdm_ours import MDMV11 as MDM_Ours_V11
# MDM_Ours_V12
from model.mdm_ours import MDMV12 as MDM_Ours_V12
# MDM_Ours_V13
# from model.mdm_ours import MDMV13 as MDM_Ours_V13
# # MDM_Ours_V14
# from model.mdm_ours import MDMV14 as MDM_Ours_V14
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from utils.parser_util import get_cond_mode
import torch
from torch import optim, nn
import torch.nn.functional as F
from manopth.manolayer import ManoLayer
import numpy as np
import trimesh
import os
from diffusion.respace_ours import SpacedDiffusion as SpacedDiffusion_Ours
# SpacedDiffusionV2
from diffusion.respace_ours import SpacedDiffusionV2 as SpacedDiffusion_OursV2
from diffusion.respace_ours import SpacedDiffusionV3 as SpacedDiffusion_OursV3
# SpacedDiffusionV4
from diffusion.respace_ours import SpacedDiffusionV4 as SpacedDiffusion_OursV4
# SpacedDiffusion_OursV5
from diffusion.respace_ours import SpacedDiffusionV5 as SpacedDiffusion_OursV5
# SpacedDiffusion_OursV6
from diffusion.respace_ours import SpacedDiffusionV6 as SpacedDiffusion_OursV6
# SpacedDiffusion_OursV7
from diffusion.respace_ours import SpacedDiffusionV7 as SpacedDiffusion_OursV7
from diffusion.respace_ours import SpacedDiffusionV9 as SpacedDiffusion_OursV9
def batched_index_select_ours(values, indices, dim = 1):
    """Gather slices of ``values`` along axis ``dim`` using batched ``indices``.

    The first ``dim`` axes of ``values`` and ``indices`` are batch axes. Every index
    picks a slice of shape ``values.shape[dim + 1:]``. Index axes beyond ``dim + 1``
    are broadcast over the batch rows of ``values``. The result has shape
    ``indices.shape + values.shape[dim + 1:]``.

    Example: values ``(B, N, C)``, indices ``(B, K)``, ``dim=1`` gives ``(B, K, C)``
    with ``out[b, k] = values[b, indices[b, k]]``.
    """
    trailing_dims = values.shape[dim + 1:]
    n_index_dims = indices.dim()
    # Give the indices singleton trailing axes and expand them over the value slices.
    idx = indices[(Ellipsis,) + (None,) * len(trailing_dims)]
    idx = idx.expand(*([-1] * n_index_dims), *trailing_dims)
    # Index axes past `dim + 1` become new broadcast axes inserted into `values` at `dim`.
    n_extra = n_index_dims - (dim + 1)
    src = values[(slice(None),) * dim + (None,) * n_extra + (Ellipsis,)]
    target_shape = [-1] * src.dim()
    target_shape[dim:dim + n_extra] = idx.shape[dim:dim + n_extra]
    src = src.expand(*target_shape)
    return src.gather(dim + n_extra, idx)
def gaussian_entropy(logvar):
    """Entropy of diagonal Gaussians given their per-dimension log-variances.

    Args:
        logvar: tensor whose axis 1 indexes the Gaussian's dimensions.

    Returns:
        ``0.5 * sum(logvar, dim=1) + 0.5 * D * (1 + log(2*pi))``, where ``D = logvar.size(1)``.
    """
    num_dims = float(logvar.size(1))
    log_two_pi_plus_one = 1. + np.log(np.pi * 2)
    return 0.5 * logvar.sum(dim=1, keepdim=False) + 0.5 * num_dims * log_two_pi_plus_one
def standard_normal_logprob(z):
    """Element-wise standard-normal log-density term for ``z``.

    Returns ``-0.5 * D * log(2*pi) - z**2 / 2`` with ``D = z.size(-1)``, with the same
    shape as ``z``. Note: every element carries the full ``-0.5 * D * log(2*pi)``
    constant, so summing the result over the last axis counts that constant ``D``
    times. Callers that sum should be aware of this.
    """
    log_norm_const = -0.5 * z.size(-1) * np.log(2 * np.pi)
    return log_norm_const - z.pow(2) / 2
def load_multiple_models_fr_path(model_path, model):
    """Assemble ``model``'s weights from several checkpoints, one per sub-network setting.

    Each ``setting:path`` entry loads its checkpoint on the CPU. Every state-dict key
    that contains one of that setting's sub-module names is copied into ``model``.
    All other parameters of ``model`` keep their current values.

    Args:
        model_path: ``;``-separated ``"<setting>:<ckpt_path>"`` entries. Only the first
            ``:`` separates the setting from the path, so paths may contain ``:``.
            Empty segments (e.g. from a trailing ``;``) are ignored. If a setting is
            repeated, its last path wins.
        model: ``torch.nn.Module`` updated in place.

    Raises:
        ValueError: if an entry has no ``:`` or names an unknown setting. This is
            checked for all entries before any checkpoint is loaded.
    """
    # Sub-module name fragments that each setting's checkpoint is trusted to provide.
    # NOTE(review): matching is by substring, so e.g. 'sequence_pos_encoder' and
    # 'output_process' (diff_basejtsrel) also match 'real_basejtsrel_*' keys if that
    # checkpoint contains them; later settings overwrite earlier ones on such overlaps.
    setting_to_interested_keys = {
        'diff_realbasejtsrel': (
            'real_basejtsrel_input_process', 'real_basejtsrel_sequence_pos_encoder', 'real_basejtsrel_seqTransEncoder', 'real_basejtsrel_embed_timestep', 'real_basejtsrel_sequence_pos_denoising_encoder', 'real_basejtsrel_denoising_seqTransEncoder', 'real_basejtsrel_output_process',
        ),
        'diff_basejtsrel': (
            'avg_joints_sequence_input_process', 'joints_offset_input_process', 'sequence_pos_encoder', 'seqTransEncoder', 'logvar_seqTransEncoder', 'embed_timestep', 'basejtsrel_denoising_embed_timestep', 'sequence_pos_denoising_encoder', 'basejtsrel_denoising_seqTransEncoder', 'basejtsrel_glb_denoising_latents_trans_layer', 'avg_joint_sequence_output_process', 'joint_offset_output_process', 'output_process',
        ),
        'diff_realbasejtsrel_to_joints': (
            'real_basejtsrel_to_joints_input_process', 'real_basejtsrel_to_joints_sequence_pos_encoder', 'real_basejtsrel_to_joints_seqTransEncoder', 'real_basejtsrel_to_joints_embed_timestep', 'real_basejtsrel_to_joints_sequence_pos_denoising_encoder', 'real_basejtsrel_to_joints_denoising_seqTransEncoder', 'real_basejtsrel_to_joints_output_process',
        ),
    }
    model_paths = model_path.split(";")
    print(f"Loading multiple models with split model_path: {model_paths}")
    setting_to_model_path = {}
    for cur_path in model_paths:
        if not cur_path.strip():
            continue  # tolerate empty segments such as a trailing ';'
        # partition() splits on the first ':' only, so the path itself may contain ':'.
        cur_setting_nm, sep, cur_model_path = cur_path.partition(':')
        if not sep:
            raise ValueError(f"Expected '<setting>:<model_path>' but got: {cur_path!r}")
        if cur_setting_nm not in setting_to_interested_keys:
            raise ValueError(f"cur_setting:{cur_setting_nm} Not implemented yet")
        setting_to_model_path[cur_setting_nm] = cur_model_path
    loaded_dict = {}
    for cur_setting, cur_model_path in setting_to_model_path.items():
        interested_keys = setting_to_interested_keys[cur_setting]
        cur_model_state_dict = torch.load(cur_model_path, map_location='cpu')
        for k, v in cur_model_state_dict.items():
            if any(inter_key in k for inter_key in interested_keys):
                loaded_dict[k] = v
    model_dict = model.state_dict()
    model_dict.update(loaded_dict)
    model.load_state_dict(model_dict)
def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model`` non-strictly, tolerating only missing CLIP weights.

    ``model.load_state_dict(..., strict=False)`` reports two key lists:
    ``missing_keys`` (in the model but absent from ``state_dict``) and
    ``unexpected_keys`` (in ``state_dict`` but not in the model). Only missing keys
    under ``clip_model.`` are acceptable. The weights are loaded before validation,
    as before.

    Raises:
        AssertionError: if there are unexpected keys, or missing keys outside
            ``clip_model.``. Raised explicitly rather than via ``assert`` so the check
            still runs under ``python -O``; the offending keys are in the message.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(unexpected_keys) != 0:
        raise AssertionError(f"Unexpected keys in state_dict: {list(unexpected_keys)}")
    non_clip_missing = [k for k in missing_keys if not k.startswith('clip_model.')]
    if non_clip_missing:
        raise AssertionError(f"Missing non-CLIP keys in state_dict: {non_clip_missing}")
### create model and diffusion ## #
def _select_mdm_model_class(args):
    """Return the model class selected by ``args.dataset``, ``args.rep_type`` and flags."""
    if args.dataset not in ['motion_ours']:
        return MDM
    rep_type = args.rep_type
    if rep_type in ("obj_base_rel_dist", "ambient_obj_base_rel_dist"):
        return MDM_Ours
    if rep_type in ("obj_base_rel_dist_we",):
        return MDM_Ours_V3
    if rep_type in ("obj_base_rel_dist_we_wj",):
        return MDM_Ours_V4
    if rep_type not in ("obj_base_rel_dist_we_wj_latents",):
        return MDM
    # --- obj_base_rel_dist_we_wj_latents: choose by the diffusion flags ---
    if args.diff_spatial:
        if not args.pred_joints_offset:
            print(f"Using MDM ours V9!!!!")
            return MDM_Ours_V9
        if args.diff_joint_quants:
            return MDM_Ours_V13
        if args.diff_hand_params:
            return MDM_Ours_V14
        if args.finetune_with_cond:
            print(f"Using MDM ours V12!!!!")
            return MDM_Ours_V12
        print(f"Using MDM ours V10!!!!")
        return MDM_Ours_V10
    if args.diff_latents:
        print(f"Using MDM ours V11!!!!")
        return MDM_Ours_V11
    if not args.use_sep_models:
        return MDM_Ours_V5
    if not args.use_vae:
        return MDM_Ours_V6
    if args.pred_basejtsrel_avgjts:
        print(f"Using MDM ours V8!!!!")
        return MDM_Ours_V8
    return MDM_Ours_V7


def create_model_and_diffusion(args, data):
    """Build the denoising model and its Gaussian diffusion process.

    The model class is picked from ``args`` and constructed with
    ``get_model_args(args, data)``. The diffusion comes from
    ``create_gaussian_diffusion(args)``, which is not defined in this part of the file
    and is presumably defined later in the module.

    NOTE(review): among the imports at the top of this file, only ``MDM_Ours_V10`` and
    ``MDM_Ours_V12`` are imported. The ``MDM``, ``MDM_Ours`` and the other
    ``MDM_Ours_V*`` imports are commented out, so selecting those branches raises
    ``NameError`` unless the names are defined elsewhere.

    Returns:
        ``(model, diffusion)``.
    """
    model_cls = _select_mdm_model_class(args)
    model = model_cls(**get_model_args(args, data))
    diffusion = create_gaussian_diffusion(args)
    return model, diffusion
# give utils to models #
def get_model_args(args, data):
    """Collect the constructor keyword arguments for the MDM-family models.

    Most architecture settings are fixed here (CLIP ``'ViT-B/32'``, ``ff_size`` 1024,
    4 heads, dropout 0.1, GELU). ``latent_dim``, ``layers``, ``cond_mask_prob``,
    ``arch``, ``emb_trans_dec`` and ``dataset`` come from ``args``, and ``args`` itself
    is forwarded under ``'args'``.

    ``data_rep``/``njoints``/``nfeats`` depend on ``args.dataset``. Unlisted datasets get
    the SMPL defaults ``('rot6d', 25, 6)``. ``num_actions`` is read from
    ``data.dataset`` when it has that attribute, otherwise 1.
    """
    clip_version = 'ViT-B/32'
    action_emb = 'tensor'
    cond_mode = get_cond_mode(args)
    num_actions = data.dataset.num_actions if hasattr(data.dataset, 'num_actions') else 1

    # (dataset name, data_rep, njoints, nfeats); first match wins.
    dataset_layouts = (
        ('humanml', 'hml_vec', 263, 1),
        ('motion_ours', 'xyz', 21, 3),
        ('kit', 'hml_vec', 251, 1),
    )
    data_rep, njoints, nfeats = 'rot6d', 25, 6  # SMPL defaults
    for name, layout_rep, layout_njoints, layout_nfeats in dataset_layouts:
        if args.dataset == name:
            data_rep, njoints, nfeats = layout_rep, layout_njoints, layout_nfeats
            break

    return {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': num_actions,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': "gelu",
        'data_rep': data_rep,
        'cond_mode': cond_mode,
        'cond_mask_prob': args.cond_mask_prob,
        'action_emb': action_emb,
        'arch': args.arch,
        'emb_trans_dec': args.emb_trans_dec,
        'clip_version': clip_version,
        'dataset': args.dataset,
        'args': args,
    }
def _ohj_mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws):
    """Run MANO for all bsz * ws frames and reshape verts/joints to bsz x ws x ... x 3."""
    # One shape code per sequence, tiled over its ws frames.
    hand_verts, hand_joints = mano_layer(
        torch.cat([rot_var, theta_var], dim=-1),
        beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10),
        transl_var,
    )
    # 0.001 rescales the MANO output (presumably millimetres -> metres; confirm).
    hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001
    hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
    return hand_verts, hand_joints


def _ohj_base_pts_fitting_energies(hand_joints, base_pts, base_normals,
                                   rel_base_pts_to_joints, dists_base_pts_to_joints):
    """Compute joint-to-base-point offsets / signed distances and their fitting energies (e1)."""
    # bsz x ws x nnj x nnb x 3
    rel = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
    # bsz x ws x nnj x nnb: offsets projected onto the base-point normals
    signed_dist = torch.sum(rel * base_normals.unsqueeze(1).unsqueeze(1), dim=-1)
    rel_e = torch.sum((rel - rel_base_pts_to_joints) ** 2, dim=-1).mean()
    if dists_base_pts_to_joints is not None:
        dist_e = torch.sum((signed_dist - dists_base_pts_to_joints) ** 2, dim=-1).mean()
    else:
        # Keep the zero on the same device as the other terms.
        dist_e = torch.zeros((), dtype=torch.float32, device=hand_joints.device)
    return rel, signed_dist, rel_e, dist_e


def _ohj_all_base_pts_penetration_energy(rel, signed_dist, weight, near_thresh=0.05):
    """Masked mean of ``weight * signed_dist`` over close, behind-the-normal (joint, base pt) pairs.

    Also returns the mean count of such base points per (frame, joint).
    """
    mask = (signed_dist < 0.) & (torch.sqrt(torch.sum(rel ** 2, dim=-1)) < near_thresh)
    avg_masks = mask.float().sum(dim=-1).mean()
    energy = weight * signed_dist
    energy = torch.sum(energy[mask]) / torch.clamp(torch.sum(mask.float()), min=1e-5).item()
    return energy, avg_masks


def _ohj_nearest_base_pts_penetration_energy(rel, hand_joints, base_pts_exp, base_normals_exp,
                                             near_thresh=0.05):
    """Mean penetration depth of joints that are close to, and behind, their nearest base point."""
    # bsz x ws x nnj: index of every joint's nearest base point
    nearest_idx = torch.min(torch.sum(rel ** 2, dim=-1), dim=-1).indices
    nearest_pts = batched_index_select_ours(base_pts_exp, indices=nearest_idx, dim=2)
    nearest_normals = batched_index_select_ours(base_normals_exp, indices=nearest_idx, dim=2)
    rel_nearest = hand_joints - nearest_pts
    signed_dist = torch.sum(rel_nearest * nearest_normals, dim=-1)
    mask = (signed_dist < 0.) & (torch.sqrt(torch.sum(rel_nearest ** 2, dim=-1)) < near_thresh)
    return torch.sum(-1.0 * signed_dist[mask]) / torch.clamp(torch.sum(mask.float()), min=1e-5).item()


def _ohj_prior_energies(theta_var, beta_var, bsz, ws):
    """Return the (pose smoothness, shape prior, pose prior) regularisers (e3)."""
    theta_seq = theta_var.view(bsz, ws, -1)
    if ws > 1:
        pose_smoothness_loss = F.mse_loss(theta_seq[:, 1:], theta_seq[:, :-1])
    else:
        # With a single frame there are no frame pairs; mse_loss on empty slices is NaN.
        pose_smoothness_loss = torch.zeros((), dtype=theta_var.dtype, device=theta_var.device)
    shape_prior_loss = torch.mean(beta_var ** 2)
    pose_prior_loss = torch.mean(theta_var ** 2)
    return pose_smoothness_loss, shape_prior_loss, pose_prior_loss


def optimize_sampled_hand_joints(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals,
                                 mano_path="/data1/sim/mano_models/mano/models",
                                 num_coarse_iters=100, num_fine_iters=1000, num_refine_iters=1000,
                                 save_verts_fn="optimized_verts.npy"):
    """Fit a right-hand MANO sequence to denoised joints and base-point relations.

    MANO parameters are randomly initialised and optimised in three Adam stages
    (lr 0.1):

    1. coarse (``num_coarse_iters``): global rotation and translation only;
    2. fine (``num_fine_iters``): all parameters; a nearest-base-point penetration
       term is computed for logging only (weight 0);
    3. refine (``num_refine_iters``): all parameters plus a penetration term over all
       base points within 0.05 of a joint and on the negative side of its normal.

    Every stage also fits the base-point offsets/distances (e1), pose/shape priors and
    pose smoothness (e3), and the sampled joints (e4). Losses are printed every
    iteration.

    Args:
        sampled_joints: bsz x ws x nnj x 3 target joints.
        rel_base_pts_to_joints: target ``joint - base_pt`` offsets, broadcastable to
            bsz x ws x nnj x nnb x 3.
        dists_base_pts_to_joints: target signed distances, broadcastable to
            bsz x ws x nnj x nnb, or ``None`` to disable that term.
        base_pts: bsz x nnb x 3 base points.
        base_normals: bsz x nnb x 3 normals at ``base_pts``.
        mano_path: directory containing the MANO model files.
        num_coarse_iters, num_fine_iters, num_refine_iters: iterations per stage.
        save_verts_fn: path for ``np.save`` of the final bsz x ws x 778 x 3 hand
            vertices; ``None`` skips saving.

    Returns:
        Detached bsz x ws x nnj x 3 MANO joints evaluated at the final parameters.
    """
    bsz, ws = sampled_joints.shape[:2]
    device = sampled_joints.device
    # Base points / normals tiled per frame (bsz x ws x nnb x 3) for nearest-point gathers.
    base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()

    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang'
    ).to(device)

    # Random initialisation; draw order (beta, rot, theta, transl) kept for seeded reproducibility.
    beta_var = torch.randn([bsz, 10]).to(device)
    rot_var = torch.randn([bsz * ws, 3]).to(device)
    theta_var = torch.randn([bsz * ws, 24]).to(device)
    transl_var = torch.randn([bsz * ws, 3]).to(device)
    for var in (beta_var, rot_var, theta_var, transl_var):
        var.requires_grad_()

    all_vars = [rot_var, transl_var, beta_var, theta_var]
    # (optimised variables, lr, iterations, penetration strategy, penetration weight)
    stages = (
        ([rot_var, transl_var], 0.1, num_coarse_iters, 'all_normal_dot', 0.0),
        (all_vars, 0.1, num_fine_iters, 'nearest', 0.0),
        (all_vars, 0.1, num_refine_iters, 'all_neg_sq_dist', 1.0),
    )
    for stage_vars, lr, num_iters, strategy, signed_dist_e_coeff in stages:
        opt = optim.Adam(stage_vars, lr=lr)
        for i_iter in range(num_iters):
            opt.zero_grad()
            hand_verts, hand_joints = _ohj_mano_forward(
                mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws)
            # e1: match predicted base-point offsets / signed distances
            rel, signed_dist, rel_e, dist_e = _ohj_base_pts_fitting_energies(
                hand_joints, base_pts, base_normals, rel_base_pts_to_joints, dists_base_pts_to_joints)
            # e2: penetration into the object surface sampled by the base points
            avg_masks = None
            if strategy == 'nearest':
                signed_dist_e = _ohj_nearest_base_pts_penetration_energy(
                    rel, hand_joints, base_pts_exp, base_normals_exp)
            else:
                if strategy == 'all_normal_dot':
                    weight = signed_dist
                else:  # 'all_neg_sq_dist': weight by the negated squared offset length
                    weight = torch.sum(-1.0 * rel * rel, dim=-1)
                signed_dist_e, avg_masks = _ohj_all_base_pts_penetration_energy(rel, signed_dist, weight)
            # e3: smoothness and priors
            pose_smoothness_loss, shape_prior_loss, pose_prior_loss = _ohj_prior_energies(
                theta_var, beta_var, bsz, ws)
            # e4: hand joints should be close to the sampled joints
            dist_dec_jts_to_sampled_pts = torch.sum((hand_joints - sampled_joints) ** 2, dim=-1).mean()

            loss = (pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
                    + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts)
            loss.backward()
            opt.step()

            print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
            print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
            print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
            print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
            print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
            print('\trel_e Loss: {}'.format(rel_e.item()))
            print('\tdist_e Loss: {}'.format(dist_e.item()))
            print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
            if strategy == 'all_neg_sq_dist':
                print('\tAvg masks: {}'.format(avg_masks.item()))

    # The last in-loop forward pass ran before the final opt.step(), so re-evaluate MANO
    # at the final parameters before returning / saving.
    with torch.no_grad():
        hand_verts, hand_joints = _ohj_mano_forward(
            mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws)
    if save_verts_fn is not None:
        np.save(save_verts_fn, hand_verts.detach().cpu().numpy())
        print(f"Optimized verts saved to {save_verts_fn}")
    return hand_joints.detach()
def get_obj_trimesh_list(obj_verts, obj_faces):
    """Build one ``trimesh.Trimesh`` per object from paired vertex / face arrays.

    ``obj_verts[i]`` and ``obj_faces[i]`` may be torch tensors, which are detached and
    moved to the CPU as numpy arrays, or anything ``trimesh.Trimesh`` accepts directly.
    Meshes are built with ``process=False`` so trimesh does no merge/cleanup of the data.
    """
    def _as_numpy(arr):
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return arr

    meshes = []
    for idx in range(len(obj_verts)):
        verts, faces = obj_verts[idx], obj_faces[idx]
        verts = _as_numpy(verts)
        faces = _as_numpy(faces)
        meshes.append(
            trimesh.Trimesh(vertices=verts, faces=faces, process=False, use_embree=True)
        )
    return meshes
def judge_penetrated_points(obj_mesh, subj_pts):
    """Label which query points fall inside each batch element's object mesh.

    Args:
        obj_mesh: sequence of mesh objects with a ``contains(points)`` method (trimesh
            meshes, e.g. as built by ``get_obj_trimesh_list``). ``obj_mesh[i]`` is
            queried with ``subj_pts[i]``.
        subj_pts: torch tensor of shape ``bsz x ... x 3``. Any number of axes may sit
            between the batch axis and the xyz axis.

    Returns:
        float32 tensor of shape ``(len(obj_mesh),) + subj_pts.shape[1:-1]`` on
        ``subj_pts.device``: 1.0 where a point is inside its mesh, 0.0 otherwise.

    NOTE(review): trimesh's ``contains`` is only meaningful for watertight meshes;
    confirm that the object meshes are closed.
    """
    nn_bsz = len(obj_mesh)
    if nn_bsz == 0:
        # np.stack cannot stack an empty list; return an empty label tensor instead.
        return torch.zeros((0,) + tuple(subj_pts.shape[1:-1]), dtype=torch.float32,
                           device=subj_pts.device)
    tot_pts_inside_objmesh_labels = []
    for i_bsz in range(nn_bsz):
        cur_subj_pts = subj_pts[i_bsz].detach().cpu().numpy()
        query_shape = cur_subj_pts.shape[:-1]
        if cur_subj_pts.ndim > 2:
            # Flatten every leading axis into one list of xyz points for the mesh query.
            cur_subj_pts = cur_subj_pts.reshape(-1, 3)
        pts_inside_objmesh = obj_mesh[i_bsz].contains(cur_subj_pts).astype(np.float32)
        # Restore the per-item point layout (drop only the xyz axis).
        tot_pts_inside_objmesh_labels.append(pts_inside_objmesh.reshape(query_shape))
    labels = np.stack(tot_pts_inside_objmesh_labels, axis=0)
    return torch.from_numpy(labels).float().to(subj_pts.device)
# TODO: other optimization strategies? e.g. sequential optimziation> #
def _wobj_mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws):
    """Run MANO on per-frame pose/translation with one shape vector per sequence.

    Returns ``(hand_verts, hand_joints)`` viewed as ``(bsz, ws, 778, 3)`` and
    ``(bsz, ws, n_joints, 3)``, both scaled by 0.001 (presumably converting
    MANO's millimetres to metres -- confirm against the data pipeline).
    """
    hand_verts, hand_joints = mano_layer(
        torch.cat([rot_var, theta_var], dim=-1),
        # beta is shared across the window: (bsz, 10) -> (bsz * ws, 10).
        beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10),
        transl_var,
    )
    hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001
    hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
    return hand_verts, hand_joints


def _wobj_base_pt_energies(hand_joints, base_pts, base_normals,
                           rel_base_pts_to_joints, dists_base_pts_to_joints):
    """Compare the current joints with the predicted base-point features.

    Assuming ``base_pts``/``base_normals`` are ``(bsz, nnb, 3)``, returns:
        rel: ``(bsz, ws, nnj, nnb, 3)`` offsets from every base point to every joint.
        signed: ``(bsz, ws, nnj, nnb)`` projection of ``rel`` onto the base normals.
        rel_e: squared difference to ``rel_base_pts_to_joints``, summed over the
            last axis and averaged.
        dist_e: squared difference of ``signed`` to ``dists_base_pts_to_joints``,
            summed over the last axis and averaged; a zero scalar when the
            targets are ``None``.
    """
    rel = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
    signed = torch.sum(rel * base_normals.unsqueeze(1).unsqueeze(1), dim=-1)
    rel_e = torch.sum((rel - rel_base_pts_to_joints) ** 2, dim=-1).mean()
    if dists_base_pts_to_joints is not None:
        dist_e = torch.sum((signed - dists_base_pts_to_joints) ** 2, dim=-1).mean()
    else:
        # Allocate on the joints' device so all loss terms live on one device.
        dist_e = torch.zeros((1,), dtype=torch.float32, device=hand_joints.device).mean()
    return rel, signed, rel_e, dist_e


def _wobj_nearest_base_pts(rel, base_pts_exp, base_normals_exp):
    """For every joint, gather the closest base point and its normal.

    ``rel`` holds joint-minus-base-point offsets ``(bsz, ws, nnj, nnb, 3)``.
    The argmin over squared length selects the base-point index. Both returned
    tensors are ``(bsz, ws, nnj, 3)``.
    """
    _, nearest_idxes = torch.min(torch.sum(rel ** 2, dim=-1), dim=-1)
    nearest_base_pts = batched_index_select_ours(base_pts_exp, indices=nearest_idxes, dim=2)
    nearest_base_normals = batched_index_select_ours(base_normals_exp, indices=nearest_idxes, dim=2)
    return nearest_base_pts, nearest_base_normals


def _wobj_masked_mean(values, mask):
    """Mean of ``values[mask]``.

    The element count is detached (``.item()``) and clamped to 1e-5, so an
    empty mask yields 0 instead of NaN.
    """
    return torch.sum(values[mask]) / torch.clamp(torch.sum(mask.float()), min=1e-5).item()


def _wobj_prior_energies(theta_var, beta_var, bsz, ws):
    """Return (temporal pose smoothness, shape prior, pose prior) energies."""
    theta_seq = theta_var.view(bsz, ws, -1)
    pose_smoothness_loss = F.mse_loss(theta_seq[:, 1:], theta_seq[:, :-1])
    shape_prior_loss = torch.mean(beta_var ** 2)
    pose_prior_loss = torch.mean(theta_var ** 2)
    return pose_smoothness_loss, shape_prior_loss, pose_prior_loss


def _wobj_to_numpy(x):
    """Detach and move tensors to CPU numpy; pass other inputs through ``np.asarray``."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _wobj_dump_refine_state(i_iter, inside_mask, hand_joints, obj_verts, obj_faces,
                            base_pts, base_normals, nearest_base_pts, nearest_base_normals):
    """Save a debug snapshot of one refinement iteration to a hard-coded folder."""
    sv_dict = {
        'pts_inside_objmesh_labels_mask': inside_mask.detach().cpu().numpy(),
        'hand_joints': hand_joints.detach().cpu().numpy(),
        # Object data may be tensors or numpy arrays (get_obj_trimesh_list accepts both).
        'obj_verts': [_wobj_to_numpy(cur_verts) for cur_verts in obj_verts],
        'obj_faces': [_wobj_to_numpy(cur_faces) for cur_faces in obj_faces],
        'base_pts': base_pts.detach().cpu().numpy(),
        'base_normals': base_normals.detach().cpu().numpy(),
        'nearest_base_pts': nearest_base_pts.detach().cpu().numpy(),
        'nearest_base_normals': nearest_base_normals.detach().cpu().numpy(),
    }
    sv_dict_folder = "/data1/sim/mdm/tmp_saving"
    os.makedirs(sv_dict_folder, exist_ok=True)
    sv_dict_fn = os.path.join(sv_dict_folder, f"optim_iter_{i_iter}.npy")
    np.save(sv_dict_fn, sv_dict)
    print(f"Obj and subj saved to {sv_dict_fn}")


def _wobj_print_losses(i_iter, loss, shape_prior_loss, pose_prior_loss, pose_smoothness_loss,
                       signed_dist_e, rel_e, dist_e, dist_dec_jts_to_sampled_pts, avg_masks=None):
    """Print one iteration's loss breakdown; ``avg_masks`` is printed only if given."""
    print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
    print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
    print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
    print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
    print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
    print('\trel_e Loss: {}'.format(rel_e.item()))
    print('\tdist_e Loss: {}'.format(dist_e.item()))
    print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
    if avg_masks is not None:
        print('\tAvg masks: {}'.format(avg_masks.item()))


def optimize_sampled_hand_joints_wobj(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces):
    """Fit a MANO hand sequence to sampled joints and push penetrating joints out.

    A right-hand MANO layer (24 PCA pose components, flat mean) is driven by
    randomly initialised variables: shape ``beta`` (one per sequence), plus
    global rotation, pose and translation (one per frame). Every stage
    minimises::

        0.05 * pose_smoothness + 0.001 * mean(beta^2) + 0.0001 * mean(theta^2)
        + coeff * penetration + rel_e + dist_e + |joints - sampled_joints|^2

    with Adam (lr 0.1). The stages differ in the updated variables, iteration
    count, penetration term and ``coeff``:

    1. 100 iters, rotation + translation: signed_dist^2 over (joint, base point)
       pairs behind the normal and within 0.05; coeff 0.0.
    2. 1000 iters, all variables: negative signed distance to the nearest base
       point for joints behind it and within 0.05; coeff 0.0.
    3. 100 iters, all variables: squared distance to the nearest base point for
       joints that ``judge_penetrated_points`` reports inside the object mesh;
       coeff 1.0.

    Args:
        sampled_joints: ``(bsz, ws, nnj, 3)`` target joints.
        rel_base_pts_to_joints: target base-point-to-joint offsets; must
            broadcast against ``(bsz, ws, nnj, nnb, 3)``.
        dists_base_pts_to_joints: optional target signed distances, broadcast
            against ``(bsz, ws, nnj, nnb)``; ``None`` disables that term.
        base_pts, base_normals: presumably ``(bsz, nnb, 3)`` -- TODO confirm.
        obj_verts, obj_faces: per-batch object meshes (tensors or arrays).
        obj_normals: unused.

    Returns:
        Detached ``(bsz, ws, n_joints, 3)`` MANO joints from the last stage-3
        iteration, evaluated before that iteration's optimizer step.

    Side effects:
        Prints losses every iteration. Writes a debug dict to
        ``/data1/sim/mdm/tmp_saving/optim_iter_{i}.npy`` every stage-3
        iteration. Writes the final vertices to ``./optimized_verts.npy``.
    """
    tot_obj_trimeshes = get_obj_trimesh_list(obj_verts, obj_faces)

    bsz, ws = sampled_joints.shape[:2]
    device = sampled_joints.device
    mano_path = "/data1/sim/mano_models/mano/models"

    # Per-frame copies of the base points/normals for index-based gathering.
    base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()

    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang',
    ).to(device)

    # Random initialisation; draw order (beta, rot, theta, transl) is kept.
    beta_var = torch.randn([bsz, 10]).to(device)
    rot_var = torch.randn([bsz * ws, 3]).to(device)
    theta_var = torch.randn([bsz * ws, 24]).to(device)
    transl_var = torch.randn([bsz * ws, 3]).to(device)
    for var in (beta_var, rot_var, theta_var, transl_var):
        var.requires_grad_()

    def all_base_pts_penetration(i_iter, hand_joints, rel, signed):
        # Stage 1: every (joint, base point) pair behind the normal and close by.
        close_mask = torch.sqrt(torch.sum(rel ** 2, dim=-1)) < 0.05
        mask = ((signed < 0.) & close_mask).detach()
        return _wobj_masked_mean(signed * signed, mask), None

    def nearest_base_pt_penetration(i_iter, hand_joints, rel, signed):
        # Stage 2: only the nearest base point of each joint is considered.
        nearest_base_pts, nearest_base_normals = _wobj_nearest_base_pts(
            rel, base_pts_exp, base_normals_exp)
        rel_nearest = hand_joints - nearest_base_pts
        signed_nearest = torch.sum(rel_nearest * nearest_base_normals, dim=-1)
        close_mask = torch.sqrt(torch.sum(rel_nearest ** 2, dim=-1)) < 0.05
        mask = (signed_nearest < 0.) & close_mask
        return _wobj_masked_mean(-1.0 * signed_nearest, mask), None

    def inside_mesh_penetration(i_iter, hand_joints, rel, signed):
        # Stage 3: joints inside the mesh are pulled toward their nearest base point.
        inside_mask = judge_penetrated_points(tot_obj_trimeshes, hand_joints).bool()
        nearest_base_pts, nearest_base_normals = _wobj_nearest_base_pts(
            rel, base_pts_exp, base_normals_exp)
        rel_nearest = hand_joints - nearest_base_pts
        sq_dist_nearest = torch.sum(rel_nearest * rel_nearest, dim=-1)
        mask = inside_mask.detach()
        avg_masks = mask.float().sum(dim=-1).mean()
        _wobj_dump_refine_state(
            i_iter, inside_mask, hand_joints, obj_verts, obj_faces,
            base_pts, base_normals, nearest_base_pts, nearest_base_normals,
        )
        return _wobj_masked_mean(sq_dist_nearest, mask), avg_masks

    def run_stage(params, lr, num_iters, signed_dist_e_coeff, penetration_energy):
        # One Adam run. Returns the last iteration's (hand_verts, hand_joints).
        opt = optim.Adam(params, lr=lr)
        hand_verts = hand_joints = None
        for i_iter in range(num_iters):
            opt.zero_grad()
            hand_verts, hand_joints = _wobj_mano_forward(
                mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws)
            rel, signed, rel_e, dist_e = _wobj_base_pt_energies(
                hand_joints, base_pts, base_normals,
                rel_base_pts_to_joints, dists_base_pts_to_joints)
            signed_dist_e, avg_masks = penetration_energy(i_iter, hand_joints, rel, signed)
            pose_smoothness_loss, shape_prior_loss, pose_prior_loss = _wobj_prior_energies(
                theta_var, beta_var, bsz, ws)
            dist_dec_jts_to_sampled_pts = torch.sum(
                (hand_joints - sampled_joints) ** 2, dim=-1
            ).mean()
            loss = (pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001
                    + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff
                    + rel_e + dist_e + dist_dec_jts_to_sampled_pts)
            loss.backward()
            opt.step()
            _wobj_print_losses(
                i_iter, loss, shape_prior_loss, pose_prior_loss, pose_smoothness_loss,
                signed_dist_e, rel_e, dist_e, dist_dec_jts_to_sampled_pts, avg_masks,
            )
        return hand_verts, hand_joints

    run_stage([rot_var, transl_var], 0.1, 100, 0.0, all_base_pts_penetration)
    run_stage([rot_var, transl_var, beta_var, theta_var], 0.1, 1000, 0.0,
              nearest_base_pt_penetration)
    hand_verts, hand_joints = run_stage([rot_var, transl_var, beta_var, theta_var], 0.1, 100,
                                        1.0, inside_mesh_penetration)

    np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy())
    print("Optimized verts saved to optimized_verts.npy")
    return hand_joints.detach()
# TODO: explore other optimization strategies, e.g. sequential optimization.
def optimize_sampled_hand_joints_wobj_v2(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces):
# sampled_joints: bsz x ws x nnj x 3 #
# sampled_joints: bsz x ws x nnj x 3 # obj trimeshes #
tot_obj_trimeshes = get_obj_trimesh_list(obj_verts, obj_faces)
## TODO: write the collect function for object verts, normals, faces ##
### A simple penetration resolving strategy is as follows:
#### 1) get vertices in the object; 2) get nearest base points (for simplicity); 3) project the vertex to the base point ####
## 1) for joints only;
## 2) for vertices;
## 3) for vertices;
## TODO: optimzie the resolvign strategy stated above ##
bsz, ws, nnj = sampled_joints.shape[:3]
device = sampled_joints.device
coarse_lr = 0.1
num_iters = 100 # if i_iter > 0 else 1 ## nn-coarse-iters for global transformations #
mano_path = "/data1/sim/mano_models/mano/models"
# obj_verts: bsz x nnobjverts x
base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
signed_dist_e_coeff = 1.0
signed_dist_e_coeff = 0.0
### start optimization ###
# setup MANO layer
mano_layer = ManoLayer(
flat_hand_mean=True,
side='right',
mano_root=mano_path, # mano_path for the mano model #
ncomps=24,
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang'
).to(device)
## random init variables ##
beta_var = torch.randn([bsz, 10]).to(device)
rot_var = torch.randn([bsz * ws, 3]).to(device)
theta_var = torch.randn([bsz * ws, 24]).to(device)
transl_var = torch.randn([bsz * ws, 3]).to(device)
beta_var.requires_grad_()
rot_var.requires_grad_()
theta_var.requires_grad_()
transl_var.requires_grad_()
opt = optim.Adam([rot_var, transl_var], lr=coarse_lr)
for i_iter in range(num_iters):
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn
hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
### === e1 should be close to predicted values === ###
# bsz x ws x nnj x nnb x 3 #
rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
# bs zx ws x nnj x nnb #
signed_dist_base_pts_to_hand_joints = torch.sum(
rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
)
rel_e = torch.sum(
(rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1
).mean()
if dists_base_pts_to_joints is not None:
dist_e = torch.sum(
(signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1
).mean()
else:
dist_e = torch.zeros((1,), dtype=torch.float32).to(device).mean()
### === e2 the signed distances to nearest points should not be negative to the neareste === ###
## base_pts: bsz x nn_base_pts x 3
## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ##
## bsz x ws x nnj x nnb ##
''' strategy 2: use all base pts, rel, dists for resolving '''
# rel_base_pts_to_hand_joints: bsz x ws x nnj x nnb x 3 #
signed_dist_mask = signed_dist_base_pts_to_hand_joints < 0.
l2_dist_rel_joints_to_base_pts_mask = torch.sqrt(
torch.sum(rel_base_pts_to_hand_joints ** 2, dim=-1)
) < 0.05
signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_base_pts_mask.float()) > 1.5
dot_rel_with_normals = torch.sum(
rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
)
signed_dist_mask = signed_dist_mask.detach() # detach the mask #
# dot_rel_with_normals: bsz x ws x nnj x nnb
avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean()
signed_dist_e = dot_rel_with_normals * signed_dist_base_pts_to_hand_joints
signed_dist_e = torch.sum(
signed_dist_e[signed_dist_mask]
) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item()
###### ====== get loss for signed distances ==== ###
''' strategy 2: use all base pts, rel, dists for resolving '''
''' strategy 1: use nearest base pts, rel, dists for resolving '''
# dist_rhand_joints_to_base_pts = torch.sum(
# (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1
# )
# # minn_dists_idxes: bsz x ws x nnj -->
# minn_dists_to_base_pts, minn_dists_idxes = torch.min(
# dist_rhand_joints_to_base_pts, dim=-1
# )
# # base_pts: bsz x nn_base_pts x 3 #
# # base_pts: bsz x ws x nn_base_pts x 3 #
# # bsz x ws x nnj
# # object verts and object faces #
# ## other than the sampling process; not
# # bsz x ws x nnj x 3 ##
# nearest_base_pts = batched_index_select_ours(
# base_pts_exp, indices=minn_dists_idxes, dim=2
# )
# # bsz x ws x nnj x 3 # # base normalse #
# nearest_base_normals = batched_index_select_ours(
# base_normals_exp, indices=minn_dists_idxes, dim=2
# )
# # bsz x ws x nnj x 3 # # the nearest distance points may be of some ambiguous
# rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
# # bsz x ws x nnj #
# signed_dist_joints_to_base_pts = torch.sum(
# rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
# )
# # should not be negative
# signed_dist_mask = signed_dist_joints_to_base_pts < 0.
# l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt(
# torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)
# ) < 0.05
# signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5
# ### ==== mean of signed distances ==== ###
# signed_dist_e = torch.sum( # penetration
# -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask]
# ) / torch.clamp(
# torch.sum(signed_dist_mask.float()), min=1e-5
# ).item()
''' strategy 1: use nearest base pts, rel, dists for resolving '''
## === e3 smoothness and prior losses === ##
pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1])
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
## === e3 smoothness and prior losses === ##
## === e4 hand joints should be close to sampled hand joints === ##
dist_dec_jts_to_sampled_pts = torch.sum(
(hand_joints - sampled_joints) ** 2, dim=-1
).mean()
### signed distance coeff -> the distance coeff #
loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts
loss.backward()
opt.step()
print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
print('\trel_e Loss: {}'.format(rel_e.item()))
print('\tdist_e Loss: {}'.format(dist_e.item()))
print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
fine_lr = 0.1
num_iters = 1000
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr)
for i_iter in range(num_iters):
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn
hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
### === e1 should be close to predicted values === ###
# bsz x ws x nnj x nnb x 3 #
rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
# bs zx ws x nnj x nnb #
signed_dist_base_pts_to_hand_joints = torch.sum(
rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
)
rel_e = torch.sum(
(rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1
).mean()
# dists_base_pts_to_joints ## dists_base_pts_to_joints ##
if dists_base_pts_to_joints is not None: ## dists_base_pts_to_joints ##
dist_e = torch.sum(
(signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1
).mean()
else:
dist_e = torch.zeros((1,), dtype=torch.float32).mean()
### === e2 the signed distances to nearest points should not be negative to the neareste === ###
## base_pts: bsz x nn_base_pts x 3
## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ##
## bsz x ws x nnj x nnb ##
dist_rhand_joints_to_base_pts = torch.sum(
(hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1
)
# minn_dists_idxes: bsz x ws x nnj -->
minn_dists_to_base_pts, minn_dists_idxes = torch.min(
dist_rhand_joints_to_base_pts, dim=-1
)
# base_pts: bsz x nn_base_pts x 3 #
# base_pts: bsz x ws x nn_base_pts x 3 #
# bsz x ws x nnj
# base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
# bsz x ws x nnj x 3 ##
nearest_base_pts = batched_index_select_ours(
base_pts_exp, indices=minn_dists_idxes, dim=2
)
# bsz x ws x nnj x 3 #
nearest_base_normals = batched_index_select_ours(
base_normals_exp, indices=minn_dists_idxes, dim=2
)
# bsz x ws x nnj x 3 #
rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
# bsz x ws x nnj #
signed_dist_joints_to_base_pts = torch.sum(
rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
)
# should not be negative
signed_dist_mask = signed_dist_joints_to_base_pts < 0.
l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt(
torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)
) < 0.05
signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5
### ==== mean of signed distances ==== ###
signed_dist_e = torch.sum(
-1.0 * signed_dist_joints_to_base_pts[signed_dist_mask]
) / torch.clamp(
torch.sum(signed_dist_mask.float()), min=1e-5
).item()
## === e3 smoothness and prior losses === ##
pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1])
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
## === e3 smoothness and prior losses === ##
## === e4 hand joints should be close to sampled hand joints === ##
dist_dec_jts_to_sampled_pts = torch.sum(
(hand_joints - sampled_joints) ** 2, dim=-1
).mean()
loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts
loss.backward()
opt.step()
print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
print('\trel_e Loss: {}'.format(rel_e.item()))
print('\tdist_e Loss: {}'.format(dist_e.item()))
print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
# tot_obj_trimeshes
### refine the optimization with signed energy ##
#
# signed_dist_jts_to_nearest_base_pts = []
# tot_nearest_base_pts = []
# tot_nearest_base_normals = []
signed_dist_e_coeff = 1.0 #
fine_lr = 0.1
# num_iters = 1000 #
num_iters = 100 # reinement #
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr)
for i_iter in range(num_iters): #
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn
hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
### === e1 should be close to predicted values === ###
# bsz x ws x nnj x nnb x 3 #
rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
# bs zx ws x nnj x nnb #
signed_dist_base_pts_to_hand_joints = torch.sum(
rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
)
rel_e = torch.sum(
(rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1
).mean()
# dists_base_pts_to_joints ## dists_base_pts_to_joints ##
if dists_base_pts_to_joints is not None: ## dists_base_pts_to_joints ##
dist_e = torch.sum(
(signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1
).mean()
else:
dist_e = torch.zeros((1,), dtype=torch.float32).mean()
### ==== inside the objemesh labels ==== ###
# bsz x ws x nnj # --> objmesh insides pts labels #
pts_inside_objmesh_labels = judge_penetrated_points(tot_obj_trimeshes, hand_joints)
pts_inside_objmesh_labels_mask = pts_inside_objmesh_labels.bool()
# {
# hard projections for
''' strategy 1: use nearest base pts, rel, dists for resolving '''
### === e2 the signed distances to nearest points should not be negative to the neareste === ###
## base_pts: bsz x nn_base_pts x 3
## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ##
## bsz x ws x nnj x nnb ##
dist_rhand_joints_to_base_pts = torch.sum(
(hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1
)
# minn_dists_idxes: bsz x ws x nnj #
# base_pts
minn_dists_to_base_pts, minn_dists_idxes = torch.min(
dist_rhand_joints_to_base_pts, dim=-1
)
# base_pts: bsz x nn_base_pts x 3 #
# base_pts: bsz x ws x nn_base_pts x 3 #
# bsz x ws x nnj
# base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
# bsz x ws x nnj x 3 ##
# simple penetration ##
nearest_base_pts = batched_index_select_ours(
base_pts_exp, indices=minn_dists_idxes, dim=2
)
# bsz x ws x nnj x 3 # #
nearest_base_normals = batched_index_select_ours(
base_normals_exp, indices=minn_dists_idxes, dim=2
)
tot_masks = []
tot_base_pts = []
tot_base_normals = []
tot_base_signed_dists = []
## === nearest base pts === ##
for i_bsz in range(nearest_base_pts.size(0)):
# masks, base_pts, base_normals for each frame here
# cur_bsz_
cur_bsz_masks = [pts_inside_objmesh_labels_mask[i_bsz][0]]
cur_bsz_base_pts = [nearest_base_pts[i_bsz][0]]
cur_bsz_base_normals = [nearest_base_normals[i_bsz][0]]
# nnjts #
## st frame signed dist ##
cur_bsz_st_frame_signed_dist = torch.sum(
(hand_joints[i_bsz][0] - cur_bsz_base_pts[0]) * cur_bsz_base_normals[0], dim=-1
)
cur_bsz_signed_dist = [cur_bsz_st_frame_signed_dist]
for i_fr in range(1, nearest_base_pts.size(1)):
cur_bsz_cur_fr_jts = hand_joints[i_bsz][i_fr]
# cur_bsz_cur_fr_base_pts = nearest_base_pts
# cur_fr_jts -
cur_bsz_cur_fr_prev_fr_signed_dist = torch.sum(
(cur_bsz_cur_fr_jts - cur_bsz_base_pts[-1]) * cur_bsz_base_normals[-1], dim=-1
)
# nnjts # cur
cur_bsz_cur_fr_mask = ((cur_bsz_signed_dist[-1] >= 0.).float() + (cur_bsz_cur_fr_prev_fr_signed_dist < 0.).float()) > 1.5
cur_bsz_cur_fr_base_pts = nearest_base_pts[i_bsz][i_fr].clone()
cur_bsz_cur_fr_base_pts[cur_bsz_cur_fr_mask] = cur_bsz_base_pts[-1][cur_bsz_cur_fr_mask]
cur_bsz_cur_fr_base_normals = nearest_base_normals[i_bsz][i_fr].clone()
# ### curbsz curfr base normals; ### #
cur_bsz_cur_fr_base_normals[cur_bsz_cur_fr_mask] = cur_bsz_base_normals[-1][cur_bsz_cur_fr_mask]
cur_bsz_cur_fr_signed_dist = torch.sum(
(cur_bsz_cur_fr_jts - cur_bsz_cur_fr_base_pts) * cur_bsz_cur_fr_base_normals, dim=-1
)
cur_bsz_cur_fr_signed_dist[cur_bsz_cur_fr_mask] = 0. # ot the bes points
### for masks ###
cur_bsz_masks.append(cur_bsz_cur_fr_mask)
cur_bsz_base_pts.append(cur_bsz_cur_fr_base_pts)
cur_bsz_base_normals.append(cur_bsz_cur_fr_base_normals)
#
cur_bsz_masks = torch.stack(cur_bsz_masks, dim=0)
cur_bsz_base_pts = torch.stack(cur_bsz_base_pts, dim=0)
cur_bsz_base_normals = torch.stack(cur_bsz_base_normals, dim=0)
cur_bsz_signed_dist = torch.stack(cur_bsz_signed_dist, dim=0)
tot_masks.append(cur_bsz_masks)
tot_base_pts.append(cur_bsz_base_pts)
tot_base_normals.append(cur_bsz_base_normals)
tot_base_signed_dists.append(cur_bsz_signed_dist)
# masks;
tot_masks = torch.stack(tot_masks, dim=0)
tot_base_pts = torch.stack(tot_base_pts, dim=0)
tot_base_normals = torch.stack(tot_base_normals, dim=0)
tot_base_signed_dists = torch.stack(tot_base_signed_dists, dim=0)
#
nearest_base_pts = tot_base_pts.clone() # tot base pts
nearest_base_normals = tot_base_normals.clone()
pts_inside_objmesh_labels_mask = tot_masks.clone()
# if len()
# bsz x ws x nnj x 3 #
rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
# signed_dist_joints_to_base_pts: bsz x ws x nnj # -> disstances
# signed_dist_joints_to_base_pts = signed_dist_joints_to_base_pts.detach()
#
# dot rel
# penetraition resolving --- strategy
# dot_rel_with_normals = torch.sum( # dot rhand joints with normals #
# rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
# )
#
dot_rel_with_normals = torch.sum( # dot rhand joints with normals #
-rel_joints_to_nearest_base_pts * rel_joints_to_nearest_base_pts, dim=-1
)
#### Get masks for penetrated joint points ####
# signed_dist_mask = (signed_dist_mask.float() + pts_inside_objmesh_labels_mask.float()) > 1.5
signed_dist_mask = pts_inside_objmesh_labels_mask
# bsz x ws x nnj
signed_dist_mask = signed_dist_mask.detach() # detach the mask #
# bsz x ws x nnj --> the loss term
## signed distances 3 #### isgned distance 3 ###
## dotrelwithnormals, ##
# # signed_dist_mask -> the distances
# dot_rel_with_normals: bsz x ws x nnj x nnb # avg over windows and batches #
avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean()
## get singed distance energies ### ## projection ##
# signed_dist_e = dot_rel_with_normals * signed_dist_joints_to_base_pts
### dot_rel_with_normals -->
signed_dist_e = -1.0 * dot_rel_with_normals
signed_dist_e = torch.sum(
signed_dist_e[signed_dist_mask]
) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item()
###### ====== get loss for signed distances ==== ###
''' strategy 1: use nearest base pts, rel, dists for resolving '''
# cannot mask in some caes
# change of isgned distances #
# intersection spline
## judeg whether inside the object and only project those one inside of the object
## === e3 smoothness and prior losses === ##
pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1])
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
## === e3 smoothness and prior losses === ##
# points to object vertices
#### ==== sv_dict ==== ####
sv_dict = {
'pts_inside_objmesh_labels_mask': pts_inside_objmesh_labels_mask.detach().cpu().numpy(),
'hand_joints': hand_joints.detach().cpu().numpy(),
'obj_verts': [cur_verts.detach().cpu().numpy() for cur_verts in obj_verts],
'obj_faces': [cur_faces.detach().cpu().numpy() for cur_faces in obj_faces],
'base_pts': base_pts.detach().cpu().numpy(),
'base_normals': base_normals.detach().cpu().numpy(), # bsz x nnb x 3 -> bsz x nnb x 3 -> base normals #
'nearest_base_pts': nearest_base_pts.detach().cpu().numpy(), # bsz x ws x nnj x 3 #
'nearest_base_normals': nearest_base_normals.detach().cpu().numpy(), # bsz x ws x nnj x 3 --> base normals and pts
}
#
sv_dict_folder = "/data1/sim/mdm/tmp_saving"
os.makedirs(sv_dict_folder, exist_ok=True)
sv_dict_fn = os.path.join(sv_dict_folder, f"optim_iter_{i_iter}.npy")
np.save(sv_dict_fn, sv_dict)
print(f"Obj and subj saved to {sv_dict_fn}")
#### ==== sv_dict ==== ####
## === e4 hand joints should be close to sampled hand joints === ##
dist_dec_jts_to_sampled_pts = torch.sum(
(hand_joints - sampled_joints) ** 2, dim=-1
).mean()
# shoudl take a
# how to proejct the jvertex
# hwo to project the veretex
# weighted sum of the projectiondirection
# weights of each base point
# atraction field -> should be able to learn the penetration resolving strategy
# stochestic penetration resolving strategy #
loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts
loss.backward()
opt.step()
print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
print('\trel_e Loss: {}'.format(rel_e.item()))
print('\tdist_e Loss: {}'.format(dist_e.item()))
print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
# avg_masks
print('\tAvg masks: {}'.format(avg_masks.item()))
''' returning sampled_joints '''
sampled_joints = hand_joints
np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy())
print(f"Optimized verts saved to optimized_verts.npy")
return sampled_joints.detach()
##
def create_gaussian_diffusion(args):
    """Instantiate the spaced Gaussian diffusion process configured by ``args``.

    The process always uses 1000 diffusion steps without respacing, an MSE
    loss, x_0 ("start") prediction and fixed (non-learned) variances. The
    concrete spaced-diffusion class is picked from ``args.dataset`` and
    ``args.rep_type`` (and, for the latent representation, from the
    ``diff_*`` flags); anything unrecognised falls back to ``SpacedDiffusion``.

    Args:
        args: configuration namespace. Reads ``noise_schedule``, ``dataset``,
            ``rep_type``, ``sigma_small``, ``lambda_vel``, ``lambda_rcxyz``,
            ``lambda_fc``, ``denoising_stra``, ``inter_optim`` and, for
            ``rep_type == "obj_base_rel_dist_we_wj_latents"``, the flags
            ``diff_joint_quants``, ``diff_hand_params``, ``diff_spatial`` and
            ``diff_latents``. The whole namespace is also forwarded as ``args=``.

    Returns:
        An instance of the selected spaced-diffusion class.
    """
    # Hard-wired defaults of this project.
    steps = 1000
    scale_beta = 1.  # no beta scaling
    predict_xstart = True  # the model regresses x_0 rather than the noise
    learn_sigma = False  # variances are fixed, not learned
    rescale_timesteps = False
    timestep_respacing = ''  # would allow DDIM-style respacing; unused here

    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
    loss_type = gd.LossType.MSE

    # An empty respacing spec means "keep all `steps` timesteps".
    if not timestep_respacing:
        timestep_respacing = [steps]

    print(f"dataset: {args.dataset}, rep_type: {args.rep_type}")
    diffusion_cls = SpacedDiffusion  # generic fallback
    if args.dataset in ('motion_ours',):
        rep_type = args.rep_type
        if rep_type in ("obj_base_rel_dist", "ambient_obj_base_rel_dist"):
            print(f"here! dataset: {args.dataset}, rep_type: {args.rep_type}")
            diffusion_cls = SpacedDiffusion_Ours
        elif rep_type in ("obj_base_rel_dist_we",):
            diffusion_cls = SpacedDiffusion_OursV2
        elif rep_type in ("obj_base_rel_dist_we_wj",):
            diffusion_cls = SpacedDiffusion_OursV3
        elif rep_type in ("obj_base_rel_dist_we_wj_latents",):
            # The flags are checked in priority order.
            if args.diff_joint_quants:
                diffusion_cls = SpacedDiffusion_OursV7
            elif args.diff_hand_params:
                diffusion_cls = SpacedDiffusion_OursV9
            elif args.diff_spatial:
                diffusion_cls = SpacedDiffusion_OursV5
            elif args.diff_latents:
                diffusion_cls = SpacedDiffusion_OursV6
            else:
                diffusion_cls = SpacedDiffusion_OursV4

    if predict_xstart:
        model_mean_type = gd.ModelMeanType.START_X
    else:
        model_mean_type = gd.ModelMeanType.EPSILON

    if learn_sigma:
        model_var_type = gd.ModelVarType.LEARNED_RANGE
    elif args.sigma_small:
        model_var_type = gd.ModelVarType.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarType.FIXED_LARGE

    return diffusion_cls(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        denoising_stra=args.denoising_stra,
        inter_optim=args.inter_optim,
        args=args,
    )
### from decoded energies to optimized joints ###
## latent variables ##
## encoded energies ## from energies calculated from perturbed energies ##
## decoded energies should also match the clean energy term ##
## and those values should be all denormed ##
def optimize_joints_according_to_e(dec_joints, base_pts, base_normals, dec_e,
                                   nn_iters=10, coarse_lr=0.001,
                                   k_f=1., k_a=1., k_b=1., verbose=True):
    """Refine joint trajectories so that their motion energies match ``dec_e``.

    For every pair of consecutive frames, the joint displacement is split,
    per base point, into a component along that base point's normal and the
    remaining perpendicular component. Both magnitudes are weighted by the
    attraction term ``exp(-k_f * ||joint - base_pt||)`` taken at the earlier
    frame of the pair. This gives the energies ``k_a * att * |along|`` and
    ``k_b * att * ||perp||``. Adam then minimizes the mean squared error
    between these energies and the decoded target energies.

    Shapes (as implied by the broadcasting/slicing below):
        dec_joints:   bsz x ws x nnj x 3
        base_pts:     bsz x nnb x 3
        base_normals: bsz x nnb x 3
        dec_e['dec_e_along_normals'], dec_e['dec_e_vt_normals']:
                      bsz x (ws - 1) x nnj x nnb

    Args:
        dec_joints: initial joint positions. The tensor is NOT modified; a
            detached copy is optimized.
        base_pts: base points, treated as constants.
        base_normals: normals at ``base_pts``, treated as constants.
        dec_e: dict holding the target energies under the two keys above,
            treated as constants (detached).
        nn_iters: number of Adam steps.
        coarse_lr: Adam learning rate.
        k_f: decay rate of the attraction weights.
        k_a: scale of the along-normal energy.
        k_b: scale of the perpendicular energy.
        verbose: print the per-iteration losses.

    Returns:
        The optimized joints (detached), same shape as ``dec_joints``.
    """
    # Targets and geometry are constants of this optimization. Detaching them
    # keeps backward() from walking into (and, on the second iteration,
    # failing on) any graph that produced them.
    target_e_along_normals = dec_e['dec_e_along_normals'].detach()
    target_e_vt_normals = dec_e['dec_e_vt_normals'].detach()
    # bsz x 1 x 1 x nnb x 3: broadcasts against the frame and joint axes.
    base_pts_exp = base_pts.detach().unsqueeze(1).unsqueeze(1)
    base_normals_exp = base_normals.detach().unsqueeze(1).unsqueeze(1)

    # Optimize a fresh leaf copy. Adam rejects non-leaf tensors, and the copy
    # avoids mutating (and flagging requires_grad on) the caller's tensor.
    joints = dec_joints.detach().clone().requires_grad_(True)
    opt = optim.Adam([joints], lr=coarse_lr)

    # The caller may run under torch.no_grad() (e.g. during sampling), and
    # backward() needs autograd enabled.
    with torch.enable_grad():
        for i_iter in range(nn_iters):
            # bsz x ws x nnj x nnb x 3
            rel_base_pts_to_joints = joints.unsqueeze(-2) - base_pts_exp
            # bsz x ws x nnj x nnb
            dist_base_pts_to_joints = torch.norm(rel_base_pts_to_joints, dim=-1)
            # Attraction weights at the earlier frame of each pair:
            # bsz x (ws - 1) x nnj x nnb
            att_forces = torch.exp(-k_f * dist_base_pts_to_joints)[:, :-1]
            # Frame-to-frame displacements: bsz x (ws - 1) x nnj x 3
            joints_disp = joints[:, 1:] - joints[:, :-1]
            # Signed displacement along each base normal:
            # bsz x (ws - 1) x nnj x nnb
            disp_along_normals = torch.sum(
                base_normals_exp * joints_disp.unsqueeze(-2), dim=-1
            )
            # Perpendicular remainder: bsz x (ws - 1) x nnj x nnb x 3
            disp_vt_normals = (
                joints_disp.unsqueeze(-2)
                - disp_along_normals.unsqueeze(-1) * base_normals_exp
            )
            # torch.norm has a zero (not NaN) gradient at the origin, unlike
            # sqrt(sum(x ** 2)). Without this, a joint that does not move
            # between two frames yields NaN gradients that corrupt its
            # trajectory.
            dist_vt_normals = torch.norm(disp_vt_normals, dim=-1)

            # bsz x (ws - 1) x nnj x nnb
            e_along_normals = k_a * att_forces * torch.abs(disp_along_normals)
            e_vt_normals = k_b * att_forces * dist_vt_normals

            loss_cur_e_pred_e_along_normals = (
                (e_along_normals - target_e_along_normals) ** 2
            ).mean()
            loss_cur_e_pred_e_vt_normals = (
                (e_vt_normals - target_e_vt_normals) ** 2
            ).mean()
            loss = loss_cur_e_pred_e_along_normals + loss_cur_e_pred_e_vt_normals

            opt.zero_grad()
            loss.backward()
            opt.step()
            if verbose:
                print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
                print('\tloss_cur_e_pred_e_along_normals: {}'.format(loss_cur_e_pred_e_along_normals.item()))
                print('\tloss_cur_e_pred_e_vt_normals: {}'.format(loss_cur_e_pred_e_vt_normals.item()))
    return joints.detach()