import sys sys.path.insert(0, '.') sys.path.insert(0, '..') import torch import torch.nn.functional as F from torch import optim import numpy as np # import os, argparse, copy, json # import pickle as pkl # from scipy.spatial.transform import Rotation as R # from psbody.mesh import Mesh from manopth.manolayer import ManoLayer # from dataloading import GRAB_Single_Frame, GRAB_Single_Frame_V6, GRAB_Single_Frame_V7, GRAB_Single_Frame_V8, GRAB_Single_Frame_V9, GRAB_Single_Frame_V9_Ours, GRAB_Single_Frame_V10 # use_trans_encoders # from model import TemporalPointAE, TemporalPointAEV2, TemporalPointAEV5, TemporalPointAEV6, TemporalPointAEV7, TemporalPointAEV8, TemporalPointAEV9, TemporalPointAEV10, TemporalPointAEV4, TemporalPointAEV3_Real, TemporalPointAEV11, TemporalPointAEV12, TemporalPointAEV13, TemporalPointAEV14, TemporalPointAEV17, TemporalPointAEV19, TemporalPointAEV20, TemporalPointAEV21, TemporalPointAEV22, TemporalPointAEV23, TemporalPointAEV24, TemporalPointAEV25, TemporalPointAEV26 # import trimesh from utils import * # import utils import utils.model_util as model_util # from anchorutils import anchor_load_driver, recover_anchor, recover_anchor_batch # K x (1 + 1) # minimum distance; disp - k * disp_o(along_disp_dir) (l2 norm); k * disp_o(vertical_disp_dir) (l2 norm) -> how those # # object moving and the contact information ? 
# only textures on the hand vertices themselves
# the effectiveness of those values themselves -->
# torch, not batched
def calculate_disp_quants(joints, base_pts_trans, minn_base_pts_idxes=None):
    """Compute joint-to-base-point distances and displacement residuals (unbatched).

    Args:
        joints: tensor annotated by the author as ``nf x nn_joints x 3``.
        base_pts_trans: tensor annotated as ``nf x nn_base_pts x 3``.
        minn_base_pts_idxes: optional per-frame, per-joint base-point indices.
            When ``None`` the nearest base point of each joint is used.

    Returns:
        Tuple ``(dist_joints_to_base_pts_disp, dist_disp_along_dir,
        dist_disp_vt_dir)``; each is sliced to the first ``nf - 1`` frames.
    """
    # Squared distance between every joint and every base point.
    dist_joints_to_base_pts = torch.sum(
        (joints.unsqueeze(-2) - base_pts_trans.unsqueeze(1)) ** 2, dim=-1
    )  # nf x nn_joints x nn_base_pts
    _, cur_minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1)  # nf x nn_joints
    if minn_base_pts_idxes is None:
        minn_base_pts_idxes = cur_minn_base_pts_idxes

    # Keep only the distance to the selected base point of each joint.
    dist_joints_to_base_pts = model_util.batched_index_select_ours(
        dist_joints_to_base_pts, minn_base_pts_idxes.unsqueeze(-1), dim=2
    ).squeeze(-1)
    dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts)  # nf x nn_joints

    # Distance-based weight exp(-k_f * d); detached so it carries no gradient.
    k_f = 1.
    k = torch.exp(-1. * k_f * (dist_joints_to_base_pts.detach()))

    # Frame-to-frame displacements of base points and joints.
    disp_base_pts = base_pts_trans[1:] - base_pts_trans[:-1]
    disp_joints = joints[1:] - joints[:-1]  # (nf - 1) x nn_joints x 3
    minn_base_pts_idxes = minn_base_pts_idxes[:-1]
    k = k[:-1]

    # Unit displacement direction of the base points (norm clamped away from 0),
    # then gathered for the base point paired with each joint.
    dir_disp_base_pts = disp_base_pts / torch.clamp(
        torch.norm(disp_base_pts, p=2, keepdim=True, dim=-1), min=1e-9
    )
    dir_disp_base_pts = model_util.batched_index_select_ours(
        dir_disp_base_pts, minn_base_pts_idxes.detach(), dim=1
    )
    disp_base_pts = model_util.batched_index_select_ours(
        disp_base_pts, minn_base_pts_idxes.detach(), dim=1
    )

    # NOTE(review): this is an element-wise product, not a vector projection
    # onto the direction -- confirm that this is intended before changing it.
    disp_along_base_disp_dir = disp_joints * dir_disp_base_pts
    disp_vt_base_disp_dir = disp_joints - disp_along_base_disp_dir

    dist_disp_along_dir = disp_base_pts - k.unsqueeze(-1) * disp_along_base_disp_dir
    dist_disp_along_dir = torch.norm(dist_disp_along_dir, dim=-1, p=2)  # (nf - 1) x nn_joints
    dist_disp_vt_dir = torch.norm(disp_vt_base_disp_dir, dim=-1, p=2)  # (nf - 1) x nn_joints

    dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:-1]  # (nf - 1) x nn_joints
    return dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir


# batched get quantities here
# dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(joints, base_pts_trans)
def calculate_disp_quants_batched(joints, base_pts_trans):
    """Batched variant of ``calculate_disp_quants`` using the nearest base point.

    Args:
        joints: tensor with a leading batch axis followed by frames, joints and
            xyz (frames are sliced on axis 1).
        base_pts_trans: tensor with the same leading batch/frame axes followed
            by base points and xyz.

    Returns:
        Tuple ``(dist_joints_to_base_pts_disp, dist_disp_along_dir,
        dist_disp_vt_dir)``; each is sliced to the first ``nf - 1`` frames.
    """
    dist_joints_to_base_pts = torch.sum(
        (joints.unsqueeze(-2) - base_pts_trans.unsqueeze(-3)) ** 2, dim=-1
    )
    dist_joints_to_base_pts, minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1)
    dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts)

    # Distance-based weight exp(-k_f * d); not detached in this variant.
    k_f = 1.
    k = torch.exp(-1. * k_f * (dist_joints_to_base_pts))

    disp_base_pts = base_pts_trans[:, 1:] - base_pts_trans[:, :-1]
    disp_joints = joints[:, 1:] - joints[:, :-1]
    minn_base_pts_idxes = minn_base_pts_idxes[:, :-1]  # bsz x (nf - 1) x nn_joints
    k = k[:, :-1]

    dir_disp_base_pts = disp_base_pts / torch.clamp(
        torch.norm(disp_base_pts, p=2, keepdim=True, dim=-1), min=1e-23
    )
    dir_disp_base_pts = model_util.batched_index_select_ours(
        dir_disp_base_pts, minn_base_pts_idxes, dim=2
    )
    disp_base_pts = model_util.batched_index_select_ours(
        disp_base_pts, minn_base_pts_idxes, dim=2
    )

    # NOTE(review): element-wise product, not a vector projection -- confirm.
    disp_along_base_disp_dir = disp_joints * dir_disp_base_pts
    disp_vt_base_disp_dir = disp_joints - disp_along_base_disp_dir

    dist_disp_along_dir = disp_base_pts - k.unsqueeze(-1) * disp_along_base_disp_dir
    dist_disp_along_dir = torch.norm(dist_disp_along_dir, dim=-1, p=2)
    dist_disp_vt_dir = torch.norm(disp_vt_base_disp_dir, dim=-1, p=2)

    dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:, :-1]
    return dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir


def calculate_disp_quants_v2(joints, base_pts_trans, canon_joints, canon_base_normals):
    """Compute displacement quantities relative to base-point normals (unbatched).

    Args:
        joints: tensor annotated as ``nf x nn_joints x 3``.
        base_pts_trans: tensor annotated as ``nf x nn_base_pts x 3``.
        canon_joints: joints in the canonicalised space, indexed by frame on
            axis 0.
        canon_base_normals: base-point normals in the canonicalised space,
            indexed by frame on axis 0.

    Returns:
        Tuple of
        ``disp_base_pts`` -- displacement of the nearest base point per joint,
        ``dist_joints_to_base_pts_disp`` -- distance to that base point,
        ``dotprod_disp_joints_along_normals`` -- joint displacement dotted with
        the selected normal,
        ``l2_disp_joints_vt_normals`` -- L2 norm of the remaining component.
        All are sliced to the first ``nf - 1`` frames.
    """
    dist_joints_to_base_pts = torch.sum(
        (joints.unsqueeze(-2) - base_pts_trans.unsqueeze(-3)) ** 2, dim=-1
    )  # nf x nn_joints x nn_base_pts
    dist_joints_to_base_pts, minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1)
    dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts)  # nf x nn_joints

    # Base-point velocity and nearest-point indices for the first nf - 1 frames.
    disp_base_pts = base_pts_trans[1:] - base_pts_trans[:-1]
    minn_base_pts_idxes = minn_base_pts_idxes[:-1]

    # Joint velocity in the canonicalised space and the paired normals.
    disp_canon_joints = canon_joints[1:] - canon_joints[:-1]
    disp_canon_base_normals = model_util.batched_index_select_ours(
        values=canon_base_normals[:-1], indices=minn_base_pts_idxes, dim=1
    )

    # Split the joint velocity into its component along the normal (dot
    # product) and the remainder.
    dotprod_disp_joints_along_normals = (disp_canon_base_normals * disp_canon_joints).sum(dim=-1)
    disp_joints_vt_normals = (
        disp_canon_joints
        - dotprod_disp_joints_along_normals.unsqueeze(-1) * disp_canon_base_normals
    )
    l2_disp_joints_vt_normals = torch.norm(disp_joints_vt_normals, p=2, keepdim=False, dim=-1)

    disp_base_pts = model_util.batched_index_select_ours(disp_base_pts, minn_base_pts_idxes, dim=1)
    dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:-1]
    return (
        disp_base_pts,
        dist_joints_to_base_pts_disp,
        dotprod_disp_joints_along_normals,
        l2_disp_joints_vt_normals,
    )


def calculate_disp_quants_batched_v2(joints, base_pts_trans, canon_joints, canon_base_normals):
    """Batched variant of ``calculate_disp_quants_v2`` (frames on axis 1).

    Args:
        joints: tensor with leading batch and frame axes, then joints and xyz.
        base_pts_trans: tensor with leading batch and frame axes, then base
            points and xyz.
        canon_joints: canonicalised joints, frames on axis 1.
        canon_base_normals: canonicalised base-point normals, frames on axis 1.

    Returns:
        Tuple ``(disp_base_pts, dist_joints_to_base_pts_disp,
        dotprod_disp_joints_along_normals, l2_disp_joints_vt_normals)``, each
        sliced to the first ``nf - 1`` frames.
    """
    dist_joints_to_base_pts = torch.sum(
        (joints.unsqueeze(-2) - base_pts_trans.unsqueeze(-3)) ** 2, dim=-1
    )
    dist_joints_to_base_pts, minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1)
    dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts)

    disp_base_pts = base_pts_trans[:, 1:] - base_pts_trans[:, :-1]
    minn_base_pts_idxes = minn_base_pts_idxes[:, :-1]

    disp_canon_joints = canon_joints[:, 1:] - canon_joints[:, :-1]
    disp_canon_base_normals = model_util.batched_index_select_ours(
        values=canon_base_normals[:, :-1], indices=minn_base_pts_idxes, dim=2
    )

    dotprod_disp_joints_along_normals = (disp_canon_base_normals * disp_canon_joints).sum(dim=-1)
    disp_joints_vt_normals = (
        disp_canon_joints
        - dotprod_disp_joints_along_normals.unsqueeze(-1) * disp_canon_base_normals
    )
    l2_disp_joints_vt_normals = torch.norm(disp_joints_vt_normals, p=2, keepdim=False, dim=-1)

    disp_base_pts = model_util.batched_index_select_ours(disp_base_pts, minn_base_pts_idxes, dim=2)
    dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:, :-1]
    return (
        disp_base_pts,
        dist_joints_to_base_pts_disp,
        dotprod_disp_joints_along_normals,
        l2_disp_joints_vt_normals,
    )


# Default location of the MANO model files used by the fitting routines below.
_DEFAULT_MANO_ROOT = "/data1/sim/mano_models/mano/models"

# Frame from which the attraction term is evaluated in the v2/v3 fitting
# routines (sequence-specific constant taken over from the original code).
_ATTRACTION_START_FRAME = 46


def _init_mano_optimization(nn_frames, mano_root):
    """Create the right-hand MANO layer and randomly initialised variables.

    Returns ``(mano_layer, (beta_var, rot_var, theta_var, transl_var))``; all
    four variables live on the GPU and require gradients.
    """
    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_root,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang',
    ).cuda()

    beta_var = torch.randn([1, 10]).cuda()  # shape, shared by all frames
    rot_var = torch.randn([nn_frames, 3]).cuda()  # global orientation
    theta_var = torch.randn([nn_frames, 24]).cuda()  # PCA pose coefficients
    transl_var = torch.randn([nn_frames, 3]).cuda()
    for var in (beta_var, rot_var, theta_var, transl_var):
        var.requires_grad_()
    return mano_layer, (beta_var, rot_var, theta_var, transl_var)


def _mano_forward(mano_layer, mano_vars):
    """Run the MANO layer; return vertices and joints scaled by 0.001."""
    beta_var, rot_var, theta_var, transl_var = mano_vars
    nn_frames = rot_var.size(0)
    hand_verts, hand_joints = mano_layer(
        torch.cat([rot_var, theta_var], dim=-1),
        beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10),
        transl_var,
    )
    # presumably converts millimetres to metres -- TODO confirm against manopth
    hand_verts = hand_verts.view(nn_frames, 778, 3) * 0.001
    hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
    return hand_verts, hand_joints


def _temporal_mse(seq):
    """Mean squared difference between consecutive frames (axis 0) of ``seq``.

    Returns 0 for fewer than two frames; ``F.mse_loss`` on the empty slices
    would otherwise yield NaN.
    """
    if seq.size(0) < 2:
        return seq.new_zeros(())
    return F.mse_loss(seq[1:], seq[:-1])


def _joints_pred_loss(hand_joints, joints):
    """Mean over frames and joints of the squared joint position error."""
    return torch.sum((hand_joints - joints) ** 2, dim=-1).mean()


def _nearest_base_pt_mask(joints, base_pts):
    """One-hot float mask (nf x nn_joints x nn_base_pts) of each joint's nearest base point."""
    dist_joints_to_base_pts = torch.sqrt(torch.sum(
        (joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
    ))  # nf x nn_joints x nn_base_pts
    nn_base_pts = dist_joints_to_base_pts.size(-1)
    _, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1)  # nf x nn_joints
    basepts_idx_range = torch.arange(
        nn_base_pts, device=minn_dist_idx.device
    ).unsqueeze(0).unsqueeze(0)
    return (basepts_idx_range == minn_dist_idx.unsqueeze(-1)).float()


def _attraction_loss(hand_joints, base_pts, minn_dist_mask):
    """Masked half squared distance for joints ``-5:-3`` from frame 46 onwards.

    Returns 0 when that slice is empty (e.g. sequences of 46 frames or fewer);
    the mean of an empty tensor would otherwise be NaN and poison the loss.
    """
    dist_joints_to_base_pts_sqr = torch.sum(
        (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
    )
    start = _ATTRACTION_START_FRAME
    masked = (
        (0.5 * dist_joints_to_base_pts_sqr)[start:, -5:-3, :]
        * minn_dist_mask[start:, -5:-3, :]
    )
    if masked.numel() == 0:
        return hand_joints.new_zeros(())
    return torch.mean(masked)


def _log_losses(i, loss, shape_prior_loss, pose_prior_loss, pose_smoothness_loss,
                joints_pred_loss, attraction_loss=None, joints_smoothness_loss=None):
    """Print the loss terms of iteration ``i``."""
    print('Iter {}: {}'.format(i, loss.item()), flush=True)
    print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
    print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
    print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
    print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
    if attraction_loss is not None:
        print('\tAttraction Loss: {}'.format(attraction_loss.item()))
    if joints_smoothness_loss is not None:
        print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))


def _fit_global_pose(mano_layer, mano_vars, joints, loss_weight, learning_rate, num_iters=200):
    """Stage 1: optimise only global rotation and translation with Adam."""
    beta_var, rot_var, theta_var, transl_var = mano_vars
    opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
    for i in range(num_iters):
        opt.zero_grad()
        _, hand_joints = _mano_forward(mano_layer, mano_vars)
        joints_pred_loss = _joints_pred_loss(hand_joints, joints)

        # Diagnostics only; smoothness is measured across frames, matching the
        # second optimisation stage.
        pose_smoothness_loss = _temporal_mse(theta_var)
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)

        loss = joints_pred_loss * loss_weight
        loss.backward()
        opt.step()
        _log_losses(i, loss, shape_prior_loss, pose_prior_loss,
                    pose_smoothness_loss, joints_pred_loss)


def get_optimized_hand_fr_joints(joints, mano_root=_DEFAULT_MANO_ROOT):
    """Fit a right-hand MANO model to a joint sequence.

    Runs 200 Adam iterations on global rotation/translation, then 2000
    iterations on all variables, minimising the squared joint error.
    Requires CUDA.

    Args:
        joints: numpy array of target joints, frames on axis 0.
        mano_root: directory containing the MANO model files.

    Returns:
        ``(hand_verts, hand_joints)`` numpy arrays from the forward pass of the
        last iteration.
    """
    joints = torch.from_numpy(joints).float().cuda()
    nn_frames = joints.size(0)
    mano_layer, mano_vars = _init_mano_optimization(nn_frames, mano_root)
    beta_var, rot_var, theta_var, transl_var = mano_vars

    learning_rate = 0.1
    _fit_global_pose(mano_layer, mano_vars, joints, loss_weight=30,
                     learning_rate=learning_rate)

    num_iters = 2000
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
    for i in range(num_iters):
        # Also clears gradients accumulated on beta/theta during stage 1.
        opt.zero_grad()
        hand_verts, hand_joints = _mano_forward(mano_layer, mano_vars)
        joints_pred_loss = _joints_pred_loss(hand_joints, joints)

        pose_smoothness_loss = _temporal_mse(theta_var)
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)

        loss = joints_pred_loss * 30
        loss.backward()
        opt.step()
        scheduler.step()
        _log_losses(i, loss, shape_prior_loss, pose_prior_loss,
                    pose_smoothness_loss, joints_pred_loss)

    return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()


def get_affinity_fr_dist(dist, s=0.02):
    """Map distances to affinity scores with a raised-cosine falloff.

    The score is 1 at distance 0, decreases to 0 at ``|dist| == s`` and stays
    0 beyond ``s``. (Without the clamp the cosine is periodic, so distances
    larger than ``s`` would receive non-zero affinity again.)

    Args:
        dist: tensor of distances (the absolute value is used).
        s: distance at which the affinity reaches zero.

    Returns:
        Tensor of affinity scores in ``[0, 1]`` with the shape of ``dist``.
    """
    clamped = torch.clamp(torch.abs(dist), max=s)
    k = 0.5 * torch.cos(torch.pi / s * clamped) + 0.5
    return k


def get_optimized_hand_fr_joints_v2(joints, base_pts, mano_root=_DEFAULT_MANO_ROOT):
    """Fit a right-hand MANO model to a joint sequence (variant 2).

    Stage 1 optimises global rotation/translation (200 iterations, lr 0.1);
    stage 2 optimises all variables (3000 iterations, lr 0.01) on joint error
    plus pose-smoothness and shape/pose priors. The attraction and
    joint-smoothness terms are computed and printed but are not part of the
    loss. Requires CUDA.

    Args:
        joints: numpy array of target joints, frames on axis 0.
        base_pts: numpy array of base points (nn_base_pts x 3).
        mano_root: directory containing the MANO model files.

    Returns:
        ``(hand_verts, hand_joints)`` numpy arrays from the forward pass of the
        last iteration.
    """
    joints = torch.from_numpy(joints).float().cuda()
    base_pts = torch.from_numpy(base_pts).float().cuda()
    nn_frames = joints.size(0)
    mano_layer, mano_vars = _init_mano_optimization(nn_frames, mano_root)
    beta_var, rot_var, theta_var, transl_var = mano_vars

    minn_dist_mask = _nearest_base_pt_mask(joints, base_pts)
    print(f"minn_dist_mask: {minn_dist_mask.size()}")

    _fit_global_pose(mano_layer, mano_vars, joints, loss_weight=1000, learning_rate=0.1)

    num_iters = 3000
    learning_rate = 0.01
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
    for i in range(num_iters):
        # Also clears gradients accumulated on beta/theta during stage 1.
        opt.zero_grad()
        hand_verts, hand_joints = _mano_forward(mano_layer, mano_vars)
        joints_pred_loss = _joints_pred_loss(hand_joints, joints)

        attaction_loss = _attraction_loss(hand_joints, base_pts, minn_dist_mask)
        pose_smoothness_loss = _temporal_mse(theta_var)
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)
        joints_smoothness_loss = _temporal_mse(hand_joints)

        loss = (
            joints_pred_loss * 5000
            + pose_smoothness_loss * 0.03
            + shape_prior_loss * 0.0002
            + pose_prior_loss * 0.0005
        )
        loss.backward()
        opt.step()
        scheduler.step()
        _log_losses(i, loss, shape_prior_loss, pose_prior_loss, pose_smoothness_loss,
                    joints_pred_loss, attaction_loss, joints_smoothness_loss)

    return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()


def get_optimized_hand_fr_joints_v3(joints, base_pts, tot_base_pts_trans,
                                    mano_root=_DEFAULT_MANO_ROOT):
    """Fit a right-hand MANO model to a joint sequence with an attraction term.

    Stage 1 optimises global rotation/translation (200 iterations, lr 0.1);
    stage 2 optimises all variables (3000 iterations, lr 0.01) on
    ``20 * joint error + 0.5 * pose smoothness + 2e6 * attraction``.
    Requires CUDA.

    Args:
        joints: numpy array of target joints, frames on axis 0.
        base_pts: numpy array of base points (nn_base_pts x 3).
        tot_base_pts_trans: numpy array of per-frame base points; currently it
            only feeds the printed motion diagnostics, not the loss.
        mano_root: directory containing the MANO model files.

    Returns:
        ``(hand_verts, hand_joints)`` numpy arrays from the forward pass of the
        last iteration.
    """
    joints = torch.from_numpy(joints).float().cuda()
    base_pts = torch.from_numpy(base_pts).float().cuda()
    tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
    nn_frames = joints.size(0)
    mano_layer, mano_vars = _init_mano_optimization(nn_frames, mano_root)
    beta_var, rot_var, theta_var, transl_var = mano_vars

    minn_dist_mask = _nearest_base_pt_mask(joints, base_pts)
    print(f"minn_dist_mask: {minn_dist_mask.size()}")

    _fit_global_pose(mano_layer, mano_vars, joints, loss_weight=1000, learning_rate=0.1)

    # Diagnostics: squared per-frame displacement of the base points and the
    # frames in which they move. Not used by the loss below.
    diff_base_pts_trans = torch.sum(
        (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1
    )  # (nf - 1) x nn_base_pts
    print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
    diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
    diff_base_pts_trans_threshold = 1e-20
    diff_base_pts_trans_mask = (diff_base_pts_trans > diff_base_pts_trans_threshold).float()
    print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")

    num_iters = 3000
    learning_rate = 0.01
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
    for i in range(num_iters):
        # Also clears gradients accumulated on beta/theta during stage 1.
        opt.zero_grad()
        hand_verts, hand_joints = _mano_forward(mano_layer, mano_vars)
        joints_pred_loss = _joints_pred_loss(hand_joints, joints)

        attaction_loss = _attraction_loss(hand_joints, base_pts, minn_dist_mask)
        pose_smoothness_loss = _temporal_mse(theta_var)
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)
        joints_smoothness_loss = _temporal_mse(hand_joints)

        loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000.
        loss.backward()
        opt.step()
        scheduler.step()
        _log_losses(i, loss, shape_prior_loss, pose_prior_loss, pose_smoothness_loss,
                    joints_pred_loss, attaction_loss, joints_smoothness_loss)

    return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()


def get_optimized_hand_fr_joints_v4(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False):
    joints = torch.from_numpy(joints).float().cuda()
    base_pts = torch.from_numpy(base_pts).float().cuda()
    tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
    tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float().cuda()
    ### start optimization ###
    # setup MANO layer
    mano_path = "/data1/sim/mano_models/mano/models"
    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,  # mano_root #
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang'
    ).cuda()
    nn_frames = joints.size(0)
    # initialize variables
    beta_var = torch.randn([1, 10]).cuda()
    # first 3 global orientation
    rot_var = torch.randn([nn_frames, 3]).cuda()
    theta_var = torch.randn([nn_frames, 24]).cuda()
    transl_var = torch.randn([nn_frames, 3]).cuda()
    # transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
    # ori_transl_var = transl_var.clone()
    # rot_var = tot_rhand_glb_orient.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
    beta_var.requires_grad_()
    rot_var.requires_grad_()
theta_var.requires_grad_() transl_var.requires_grad_() learning_rate = 0.1 # joints: nf x nnjoints x 3 # # dist_joints_to_base_pts = torch.sum( # (joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts # # ) dist_joints_to_base_pts = torch.sum( (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts # ) dot_prod_base_pts_to_joints_with_normals = torch.sum( (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) * tot_base_normals_trans.unsqueeze(1), dim=-1 # ) # dist_joints_to_base_pts[dot_prod_base_pts_to_joints_with_normals < 0.] = 1e9 nn_base_pts = dist_joints_to_base_pts.size(-1) nn_joints = dist_joints_to_base_pts.size(1) dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts # minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints # nk_contact_pts = 2 minn_dist[:, :-5] = 1e9 minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) # # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0).cuda() == minn_topk_mask = torch.zeros_like(minn_dist) # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints # minn_topk_mask[:, -5: -3] = 1. 
basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda() minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts # for seq 101 # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :] minn_dist_mask[:, -5:, :] = minn_dist_mask[30:31:, -5:, :] # set to the last frame mask # minn_dist_mask = minn_dist_mask.float() tot_base_pts_trans_disp = torch.sum( (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement ) tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1) # tot_base_pts_trans_disp_mov_thres = 1e-20 tot_base_pts_trans_disp_mov_thres = 3e-4 tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres tot_base_pts_trans_disp_mask = torch.cat( [tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0 ) attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5 minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5 print(f"minn_dist_mask: {minn_dist_mask.size()}") s = 1.0 affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s) # opt = optim.Adam([rot_var, transl_var], lr=args.coarse_lr) num_iters = 200 opt = optim.Adam([rot_var, transl_var], lr=learning_rate) for i in range(num_iters): # opt.zero_grad() # mano_layer # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001 hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001 joints_pred_loss = torch.sum( (hand_joints - joints) ** 2, dim=-1 ).mean() # opt.zero_grad() pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1]) # joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device)) shape_prior_loss = 
torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) # pose_smoothness_loss = # =0.05 # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 loss = joints_pred_loss * 30 loss = joints_pred_loss * 1000 opt.zero_grad() loss.backward() opt.step() print('Iter {}: {}'.format(i, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item())) # print(tot_base_pts_trans.size()) diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}") diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1) diff_base_pts_trans_threshold = 1e-20 diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts diff_base_pts_trans_mask = diff_base_pts_trans_mask.float() print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}") diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1] diff_base_pts_trans_mask = torch.cat( [diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor ) # attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5 attraction_mask = minn_topk_mask.float() attraction_mask = attraction_mask.float() # the direction of the normal vector and the moving direction of the object point -> whether the point should be selected # the contact maps of the object should be like? 
# # the direction of the normal vector and the moving direction # define the attraction loss's weight; and attract points to the object surface # # # num_iters = 2000 num_iters = 3000 # num_iters = 1000 learning_rate = 0.01 opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate) scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5) for i in range(num_iters): opt.zero_grad() # mano_layer hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001 hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001 joints_pred_loss = torch.sum( (hand_joints - joints) ** 2, dim=-1 ).mean() # dist_joints_to_base_pts_sqr = torch.sum( # (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # ) # attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr # attaction_loss = 0.5 * dist_joints_to_base_pts_sqr # attaction_loss = attaction_loss # attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :]) # attaction_loss = torch.mean(attaction_loss * attraction_mask) # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) # opt.zero_grad() pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1]) # joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device)) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1]) # =0.05 # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100. 
loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200. loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200. loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200. # loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 # loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0 # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0 # loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. 
+ joints_smoothness_loss * 10.0 # loss = joints_pred_loss * 30 + attaction_loss * 0.001 opt.zero_grad() loss.backward() opt.step() scheduler.step() print('Iter {}: {}'.format(i, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item())) # print('\tAttraction Loss: {}'.format(attaction_loss.item())) print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item())) if with_contact_opt: num_iters = 2000 # num_iters = 1000 # seq 77 # num_iters = 500 # seq 77 ori_theta_var = theta_var.detach().clone() # tot_base_pts_trans # nf x nn_base_pts x 3 disp_base_pts_trans = tot_base_pts_trans[1:] - tot_base_pts_trans[:-1] # (nf - 1) x nn_base_pts x 3 disp_base_pts_trans = torch.cat( # nf x nn_base_pts x 3 [disp_base_pts_trans, disp_base_pts_trans[-1:]], dim=0 ) # joints: nf x nn_jts_pts x 3; nf x nn_base_pts x 3 dist_joints_to_base_pts_trans = torch.sum( (joints.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nn_jts_pts x nn_base_pts ) minn_dist_joints_to_base_pts, minn_dist_idxes = torch.min(dist_joints_to_base_pts_trans, dim=-1) # nf x nn_jts_pts # nf x nn_jts_pts # nearest_base_normals = model_util.batched_index_select_ours(tot_base_normals_trans, indices=minn_dist_idxes, dim=1) # nf x nn_base_pts x 3 --> nf x nn_jts_pts x 3 # # nf x nn_jts_pts x 3 # nearest_base_pts_trans = model_util.batched_index_select_ours(disp_base_pts_trans, indices=minn_dist_idxes, dim=1) # nf x nn_jts_ts x 3 # dot_nearest_base_normals_trans = torch.sum( nearest_base_normals * nearest_base_pts_trans, dim=-1 # nf x nn_jts ) trans_normals_mask = dot_nearest_base_normals_trans < 0. 
# nf x nn_jts # nf x nn_jts # nearest_dist = torch.sqrt(minn_dist_joints_to_base_pts) # nearest_dist_mask = nearest_dist < 0.01 # hoi seq nearest_dist_mask = nearest_dist < 0.1 k_attr = 100. joint_attraction_k = torch.exp(-1. * k_attr * nearest_dist) attraction_mask_new_new = (attraction_mask_new.float() + trans_normals_mask.float().unsqueeze(-1) + nearest_dist_mask.float().unsqueeze(-1)) > 2.5 # opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate) # opt = optim.Adam([transl_var, theta_var], lr=learning_rate) opt = optim.Adam([transl_var, theta_var, rot_var], lr=learning_rate) scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5) for i in range(num_iters): opt.zero_grad() # mano_layer hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001 hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001 joints_pred_loss = torch.sum( (hand_joints - joints) ** 2, dim=-1 ).mean() # dist_joints_to_base_pts_sqr = torch.sum( # (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # ) # nf x nnb x 3 ---- nf x nnj x 1 x 3 dist_joints_to_base_pts_sqr = torch.sum( (hand_joints.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 ) # attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr attaction_loss = 0.5 * dist_joints_to_base_pts_sqr # attaction_loss = attaction_loss # attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :]) # attaction_loss = torch.mean(attaction_loss * attraction_mask) # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :]) # seq 80 # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, 
-5:-3, :]) # seq 70 # attaction_loss = torch.mean(attaction_loss[10:, -5:-3, :] * minn_dist_mask[10:, -5:-3, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :]) # new version relying on new mask # # attaction_loss = torch.mean(attaction_loss[:, -5:-3, :] * attraction_mask_new[:, -5:-3, :]) ### original version ### # attaction_loss = torch.mean(attaction_loss[20:, -3:, :] * attraction_mask_new[20:, -3:, :]) second_tips_idxes = [9, 12, 6, 3, 15] second_tips_idxes =torch.tensor(second_tips_idxes, dtype=torch.long).cuda() # attaction_loss = torch.mean(attaction_loss[:, -5:, :] * attraction_mask_new_new[:, -5:, :] * joint_attraction_k[:, -5:].unsqueeze(-1)) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1)) # attaction_loss = torch.mean(attaction_loss[:, -5:-2, :] * attraction_mask_new_new[:, -5:-2, :]) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1)) # attraction_mask_new # attaction_loss = torch.mean(attaction_loss[:, -5:-2, :] * attraction_mask_new[:, -5:-2, :]) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1)) # seq mug # attaction_loss = torch.mean(attaction_loss[4:, -5:-4, :] * minn_dist_mask[4:, -5:-4, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :]) attaction_loss = torch.mean(attaction_loss[10:, -5:-2, :] * attraction_mask_new[10:, -5:-2, :]) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1)) # opt.zero_grad() pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1]) # joints_smoothness_loss = 
joint_acc_loss(hand_verts, J_regressor.to(device)) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1]) # =0.05 # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100. # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200. # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200. # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200. theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var) # loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001 # loss = attaction_loss * 1000. + joints_pred_loss # loss = attaction_loss * 100000. + joints_pred_loss * 0.00000001 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5 loss = attaction_loss * 100000. + joints_pred_loss * 0.000 + joints_smoothness_loss * 0.000005 # loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 # loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0 # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0 # loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. 
+ joints_smoothness_loss * 10.0 # loss = joints_pred_loss * 30 + attaction_loss * 0.001 opt.zero_grad() loss.backward() opt.step() scheduler.step() print('Iter {}: {}'.format(i, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item())) print('\tAttraction Loss: {}'.format(attaction_loss.item())) print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item())) # theta_smoothness_loss print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item())) return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy() # get_optimized_hand_fr_joints_v5 ---> get optimized i def get_optimized_hand_fr_joints_v5(joints, tot_gt_rhand_joints, base_pts, tot_base_pts_trans, predicted_joint_quants=None): joints = torch.from_numpy(joints).float().cuda() base_pts = torch.from_numpy(base_pts).float().cuda() tot_gt_rhand_joints = torch.from_numpy(tot_gt_rhand_joints).float().cuda() tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda() # nf x nn_joitns for each quantity # # if predicted_joint_quants is None: gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = calculate_disp_quants(tot_gt_rhand_joints, tot_base_pts_trans) else: predicted_joint_quants = torch.from_numpy(predicted_joint_quants).float().cuda() gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = predicted_joint_quants[..., 0], predicted_joint_quants[..., 1], predicted_joint_quants[..., 2] # print(f"gt_dist_joints_to_base_pts_disp: {gt_dist_joints_to_base_pts_disp.size()}, gt_dist_disp_along_dir: {gt_dist_disp_along_dir.size()}, gt_dist_disp_vt_dir: {gt_dist_disp_vt_dir.size()}") gt_dist_joints_to_base_pts_disp_real, gt_dist_disp_along_dir_real, gt_dist_disp_vt_dir_real = 
def get_optimized_hand_fr_joints_v5(joints, tot_gt_rhand_joints, base_pts, tot_base_pts_trans, predicted_joint_quants=None,
                                    *, mano_path="/data1/sim/mano_models/mano/models",
                                    disp_info_sv_folder="/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512"):
    """Fit MANO parameters to joints, then match joint-to-object displacement quantities.

    The fit runs in three Adam stages:

    1. coarse: 200 iterations over global rotation and translation only;
    2. full: 3000 iterations over rotation, translation, shape and pose;
    3. displacement: 2000 iterations over translation and pose that drive the
       three quantities returned by ``calculate_disp_quants`` for the fitted
       joints towards the reference quantities.

    Args:
        joints: numpy array of target joints, nf x nn_joints x 3.
        tot_gt_rhand_joints: numpy array of ground-truth joints used to
            compute the reference displacement quantities.
        base_pts: numpy array of base points, nn_base_pts x 3 (broadcast over
            frames for the nearest-point mask and the logged attraction value).
        tot_base_pts_trans: numpy array of per-frame base points,
            nf x nn_base_pts x 3.
        predicted_joint_quants: optional numpy array whose last axis holds the
            three reference quantities (indices 0, 1, 2). When ``None`` the
            quantities computed from ``tot_gt_rhand_joints`` are used.
        mano_path: directory passed to ``ManoLayer`` as ``mano_root``.
        disp_info_sv_folder: folder receiving ``pred_gt_disp_info.npy``, a
            pickled dict with the reference and ground-truth quantities.

    Returns:
        Tuple ``(hand_verts, hand_joints)`` of numpy arrays from the final
        iteration, shaped nf x 778 x 3 and nf x nn_joints x 3.
    """
    # Local import: the file-level `import os` is commented out, so relying on
    # `from utils import *` to provide it is fragile.
    import os

    joints = torch.from_numpy(joints).float().cuda()
    base_pts = torch.from_numpy(base_pts).float().cuda()
    tot_gt_rhand_joints = torch.from_numpy(tot_gt_rhand_joints).float().cuda()
    tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()

    # Ground-truth quantities are needed for the diagnostics below and, when no
    # prediction is supplied, as the optimisation reference -- compute once.
    (gt_dist_joints_to_base_pts_disp_real,
     gt_dist_disp_along_dir_real,
     gt_dist_disp_vt_dir_real) = calculate_disp_quants(tot_gt_rhand_joints, tot_base_pts_trans)

    if predicted_joint_quants is None:
        gt_dist_joints_to_base_pts_disp = gt_dist_joints_to_base_pts_disp_real
        gt_dist_disp_along_dir = gt_dist_disp_along_dir_real
        gt_dist_disp_vt_dir = gt_dist_disp_vt_dir_real
    else:
        predicted_joint_quants = torch.from_numpy(predicted_joint_quants).float().cuda()
        gt_dist_joints_to_base_pts_disp = predicted_joint_quants[..., 0]
        gt_dist_disp_along_dir = predicted_joint_quants[..., 1]
        gt_dist_disp_vt_dir = predicted_joint_quants[..., 2]

    # NOTE(review): the dump below is written for both branches, matching the
    # most likely reading of the original layout -- confirm it was not meant
    # to run only when predictions are supplied.
    diff_gt_dist_joints = torch.mean((gt_dist_joints_to_base_pts_disp_real - gt_dist_joints_to_base_pts_disp) ** 2)
    diff_gt_disp_along_normals = torch.mean((gt_dist_disp_along_dir_real - gt_dist_disp_along_dir) ** 2)
    diff_gt_disp_vt_normals = torch.mean((gt_dist_disp_vt_dir_real - gt_dist_disp_vt_dir) ** 2)
    print(f"diff_gt_dist_joints: {diff_gt_dist_joints.item()}, diff_gt_disp_along_normals: {diff_gt_disp_along_normals.item()}, diff_gt_disp_vt_normals: {diff_gt_disp_vt_normals.item()}")

    disp_info_pred_gt_sv_dict = {
        'gt_dist_joints_to_base_pts_disp_real': gt_dist_joints_to_base_pts_disp_real.detach().cpu().numpy(),
        'gt_dist_disp_along_dir_real': gt_dist_disp_along_dir_real.detach().cpu().numpy(),
        'gt_dist_disp_vt_dir_real': gt_dist_disp_vt_dir_real.detach().cpu().numpy(),
        'gt_dist_joints_to_base_pts_disp': gt_dist_joints_to_base_pts_disp.detach().cpu().numpy(),
        'gt_dist_disp_along_dir': gt_dist_disp_along_dir.detach().cpu().numpy(),
        'gt_dist_disp_vt_dir': gt_dist_disp_vt_dir.detach().cpu().numpy(),
    }
    disp_info_pred_gt_sv_fn = os.path.join(disp_info_sv_folder, f"pred_gt_disp_info.npy")
    np.save(disp_info_pred_gt_sv_fn, disp_info_pred_gt_sv_dict)
    print(f"saved to {disp_info_pred_gt_sv_fn}")

    ### start optimization ###
    # setup MANO layer
    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang'
    ).cuda()

    nn_frames = joints.size(0)

    # initialize variables (draw order kept so seeded runs stay reproducible)
    beta_var = torch.randn([1, 10]).cuda()
    rot_var = torch.randn([nn_frames, 3]).cuda()  # global orientation
    theta_var = torch.randn([nn_frames, 24]).cuda()
    transl_var = torch.randn([nn_frames, 3]).cuda()
    for var in (beta_var, rot_var, theta_var, transl_var):
        var.requires_grad_()

    def forward_mano():
        # One shape vector is shared by all frames; outputs are scaled by
        # 0.001 (presumably millimetres -> metres; confirm against manopth).
        verts, jts = mano_layer(
            torch.cat([rot_var, theta_var], dim=-1),
            beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10),
            transl_var,
        )
        return verts.view(nn_frames, 778, 3) * 0.001, jts.view(nn_frames, -1, 3) * 0.001

    def report(step, total, named_terms):
        print('Iter {}: {}'.format(step, total.item()), flush=True)
        for label, value in named_terms:
            print('\t{}: {}'.format(label, value.item()))

    # Nearest (untransformed) base point per target joint, as a one-hot mask.
    dist_joints_to_base_pts = torch.sqrt(torch.sum(
        (joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
    ))  # nf x nnjoints x nnbasepts
    nn_base_pts = dist_joints_to_base_pts.size(-1)
    _, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1)  # nf x nnjoints
    basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
    minn_dist_mask = (basepts_idx_range == minn_dist_idx.unsqueeze(-1)).float()
    print(f"minn_dist_mask: {minn_dist_mask.size()}")

    # ---- stage 1: coarse alignment of global rotation / translation ----
    learning_rate = 0.1
    num_iters = 200
    opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
    for i in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = forward_mano()

        joints_pred_loss = torch.sum((hand_joints - joints) ** 2, dim=-1).mean()
        # Temporal smoothness: difference between consecutive frames (this was
        # previously sliced along the pose-component axis). Logged only.
        pose_smoothness_loss = F.mse_loss(theta_var[1:], theta_var[:-1])
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)

        loss = joints_pred_loss * 1000
        loss.backward()
        opt.step()

        report(i, loss, [
            ('Shape Prior Loss', shape_prior_loss),
            ('Pose Prior Loss', pose_prior_loss),
            ('Pose Smoothness Loss', pose_smoothness_loss),
            ('Joints Prediction Loss', joints_pred_loss),
        ])

    # ---- stage 2: full fit of rotation, translation, shape and pose ----
    num_iters = 3000
    learning_rate = 0.01
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
    # step_size == num_iters: the rate only halves after the final step.
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
    for i in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = forward_mano()

        joints_pred_loss = torch.sum((hand_joints - joints) ** 2, dim=-1).mean()

        # Attraction value is reported but is not part of this stage's loss.
        # The frame slice 46: is a sequence-specific setting kept as tuned.
        dist_joints_to_base_pts_sqr = torch.sum(
            (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
        )
        attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
        attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :])

        pose_smoothness_loss = F.mse_loss(theta_var[1:], theta_var[:-1])
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)
        joints_smoothness_loss = F.mse_loss(hand_joints[1:], hand_joints[:-1])  # logged only

        loss = (joints_pred_loss * 5000 + pose_smoothness_loss * 0.03
                + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005)
        loss.backward()
        opt.step()
        scheduler.step()

        report(i, loss, [
            ('Shape Prior Loss', shape_prior_loss),
            ('Pose Prior Loss', pose_prior_loss),
            ('Pose Smoothness Loss', pose_smoothness_loss),
            ('Joints Prediction Loss', joints_pred_loss),
            ('Attraction Loss', attaction_loss),
            ('Joint Smoothness Loss', joints_smoothness_loss),
        ])

    # Re-select each joint's nearest *transformed* base point from the fitted
    # joints; only indices are needed, so the autograd graph is dropped.
    dist_joints_to_base_pts = torch.sum(
        (hand_joints.detach().unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1
    )  # nf x nnjoints x nnbasepts
    _, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1)  # nf x nnjoints

    # ---- stage 3: match the displacement quantities ----
    num_iters = 2000
    ori_theta_var = theta_var.detach().clone()
    opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
    for i in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = forward_mano()

        cur_dist_joints_to_base_pts_disp, cur_dist_disp_along_dir, cur_dist_disp_vt_dir = calculate_disp_quants(
            hand_joints, tot_base_pts_trans, minn_base_pts_idxes=minn_dist_idx
        )
        # Mean squared error over all joints and frames for each quantity.
        dist_joints_to_base_pts_disp_loss = ((cur_dist_joints_to_base_pts_disp - gt_dist_joints_to_base_pts_disp) ** 2).mean(dim=-1).mean(dim=-1)
        dist_disp_along_dir_loss = ((cur_dist_disp_along_dir - gt_dist_disp_along_dir) ** 2).mean(dim=-1).mean(dim=-1)
        dist_disp_vt_dir_loss = ((cur_dist_disp_vt_dir - gt_dist_disp_vt_dir) ** 2).mean(dim=-1).mean(dim=-1)
        dist_disp_losses = dist_joints_to_base_pts_disp_loss + dist_disp_along_dir_loss + dist_disp_vt_dir_loss

        joints_pred_loss = torch.sum((hand_joints - joints) ** 2, dim=-1).mean()

        # Reported only; frame slices 46: and :40 are sequence-specific.
        dist_joints_to_base_pts_sqr = torch.sum(
            (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
        )
        attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
        attaction_loss = (torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :])
                          + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :]))

        pose_smoothness_loss = F.mse_loss(theta_var[1:], theta_var[:-1])
        shape_prior_loss = torch.mean(beta_var ** 2)
        pose_prior_loss = torch.mean(theta_var ** 2)
        joints_smoothness_loss = F.mse_loss(hand_joints[1:], hand_joints[:-1])
        # Keeps the pose close to the stage-2 solution.
        theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)

        loss = dist_disp_losses * 1000000. + theta_smoothness_loss * 0.0000001
        loss.backward()
        opt.step()
        scheduler.step()

        report(i, loss, [
            ('Shape Prior Loss', shape_prior_loss),
            ('Pose Prior Loss', pose_prior_loss),
            ('Pose Smoothness Loss', pose_smoothness_loss),
            ('Joints Prediction Loss', joints_pred_loss),
            ('Attraction Loss', attaction_loss),
            ('Joint Smoothness Loss', joints_smoothness_loss),
            ('Theta Smoothness Loss', theta_smoothness_loss),
            ('Dist Disp Loss', dist_disp_losses),
        ])

    return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()
/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/predicted_infos_seq_17_seed_31_tag_jts_only.npy # pred_joints_info_nm = "predicted_infos_seq_87_seed_31_tag_jts_only.npy" # pred_joints_info_nm = "predicted_infos_seq_77.npy" # pred_joints_info_nm = "predicted_infos_seq_1_seed_31_tag_rep_only_real_sel_base_0.npy" # # /home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/predicted_infos_seq_1_seed_31_tag_rep_only_real_sel_base_mean.npy # pred_joints_info_nm = "predicted_infos_seq_1_seed_31_tag_rep_only_real_sel_base_mean.npy" # # pred_infos_sv_folder = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/predicted_infos.npy" # TODO: load the total data sequence and transform object shape using the loaded sample # TODO: please output hands and objects other than hands only in each frame # pred_joints_info_fn = os.path.join(pred_infos_sv_folder, pred_joints_info_nm) data = np.load(pred_joints_info_fn, allow_pickle=True).item() targets = data['targets'] # ## targets -> targets and outputs ## outputs = data['outputs'] # tot_base_pts = data["tot_base_pts"][0] tot_base_normals = data['tot_base_normals'][0] # nn_base_normals # obj_verts = data['obj_verts'] # outputs = targets pred_infos_sv_folder = "/data1/sim/mdm/eval_save/" pred_infos_sv_folder = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512" pred_joints_info_nm = "predicted_infos.npy" # pred_joints_info_nm = "predicted_infos_hoi_seed_77_jts_only_t_300.npy" # pred_infos_sv_folder = "/data1/sim/mdm/eval_save/" # pred_joints_info_nm = "predicted_infos_seq_1_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy" # pred_joints_info_nm = "predicted_infos_seq_2_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy" # pred_joints_info_nm = "predicted_infos_seq_36_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy" # pred_joints_info_nm = "predicted_infos_seq_80_wtrans.npy" # pred_joints_info_nm = "predicted_infos_80_wtrans_rep.npy" # 
pred_joints_info_nm = "predicted_infos_seq_80_seed_31_tag_jts_only.npy" # pred_joints_info_nm = "predicted_infos_seq_77.npy" pred_joints_info_fn = os.path.join(pred_infos_sv_folder, pred_joints_info_nm) data = np.load(pred_joints_info_fn, allow_pickle=True).item() # outputs = targets # outputs # targets # tot_obj_rot = data['tot_obj_rot'][0] # ws x 3 x 3 ---> obj_rot; # tot_obj_transl = data['tot_obj_transl'][0] print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}") if len(tot_base_pts.shape) == 2: # numpy array # tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0] tot_base_normals_trans = np.matmul( # # tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot ) else: # numpy array # tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) tot_base_normals_trans = np.matmul( # # tot_base_normals, tot_obj_rot ) # tot_base_normals_trans = np.matmul( # # # tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot # ) outputs = np.matmul(outputs, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 # targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 # # denoise relative positions print(f"tot_base_pts: {tot_base_pts.shape}") # obj_verts_trans = np.matmul(obj_verts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # outputs = targets with_contact_opt = False 
with_contact_opt = True optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints_v4(outputs, tot_base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=with_contact_opt) # # predicted_joint_quants_fn = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/pred_joint_quants.npy" # predicted_joint_quants = np.load(predicted_joint_quants_fn, allow_pickle=True).item() # predicted_joint_quants = predicted_joint_quants['dec_joint_quants'] # print("predicted_joint_quants", predicted_joint_quants.shape) # # print(predicted_joint_quants.keys()) # # exit(0) # # gt # tot_gt_rhand_joints = data['tot_gt_rhand_joints'][0] # nf x nn_joints x 3 --> gt joints # # tot_gt_rhand_joints = np.matmul(tot_gt_rhand_joints, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 # # optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints_v5(outputs, tot_gt_rhand_joints, tot_base_pts, tot_base_pts_trans, predicted_joint_quants=predicted_joint_quants) # optimized_data_fn = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/optimized_joints.npy" # data = np.load(optimized_data_fn, allow_pickle=True).item() # outputs = data["optimized_joints"] # optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints(outputs) # optimized_tar_hand_verts, optimized_tar_hand_joints = get_optimized_hand_fr_joints(targets) # optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints_v2(outputs, tot_base_pts) optimized_sv_infos = { 'optimized_out_hand_verts': optimized_out_hand_verts, 'optimized_out_hand_joints': optimized_out_hand_joints, 'tot_base_pts_trans': tot_base_pts_trans, # 'optimized_tar_hand_verts': optimized_tar_hand_verts, # 'optimized_tar_hand_joints': optimized_tar_hand_joints, } optimized_sv_infos_sv_fn = os.path.join(pred_infos_sv_folder, "optimized_infos_sv_dict.npy") 
np.save(optimized_sv_infos_sv_fn, optimized_sv_infos) print(f"optimized infos saved to {optimized_sv_infos_sv_fn}")