# gene-hoi-denoising/sample/reconstruct_data.py
import sys
sys.path.insert(0, '.')
sys.path.insert(0, '..')
import torch
import torch.nn.functional as F
from torch import optim
import numpy as np
import os  # needed for os.path.join when saving the predicted/GT displacement info below
# import argparse, copy, json
# import pickle as pkl
# from scipy.spatial.transform import Rotation as R
# from psbody.mesh import Mesh
from manopth.manolayer import ManoLayer
# from dataloading import GRAB_Single_Frame, GRAB_Single_Frame_V6, GRAB_Single_Frame_V7, GRAB_Single_Frame_V8, GRAB_Single_Frame_V9, GRAB_Single_Frame_V9_Ours, GRAB_Single_Frame_V10
# use_trans_encoders
# from model import TemporalPointAE, TemporalPointAEV2, TemporalPointAEV5, TemporalPointAEV6, TemporalPointAEV7, TemporalPointAEV8, TemporalPointAEV9, TemporalPointAEV10, TemporalPointAEV4, TemporalPointAEV3_Real, TemporalPointAEV11, TemporalPointAEV12, TemporalPointAEV13, TemporalPointAEV14, TemporalPointAEV17, TemporalPointAEV19, TemporalPointAEV20, TemporalPointAEV21, TemporalPointAEV22, TemporalPointAEV23, TemporalPointAEV24, TemporalPointAEV25, TemporalPointAEV26
# import trimesh
from utils import *
# import utils
import utils.model_util as model_util
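# NOTE (assumption): `model_util.batched_index_select_ours(values, indices, dim)` is used below
# as a batched gather, i.e. it selects entries of `values` along `dim` using per-batch `indices`
# (e.g. picking, for every joint, the quantity of its nearest base point). This description is
# inferred from the call sites in this file, not from the helper's own documentation.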
# from anchorutils import anchor_load_driver, recover_anchor, recover_anchor_batch
# Per-joint displacement quantities w.r.t. the nearest object base point:
#   (1) minimum joint-to-base-point distance;
#   (2) l2 norm of (base-point displacement - k * joint displacement along the base-point displacement direction);
#   (3) l2 norm of the joint displacement perpendicular to that direction;
# with k = exp(-dist), so joints closer to the object are weighted more heavily.
# These quantities relate the object's motion to the contact information.
# torch, not batched
def calculate_disp_quants(joints, base_pts_trans, minn_base_pts_idxes=None):
# joints: nf x nn_joints x 3;
# base_pts_trans: nf x nn_base_pts x 3; # base pts trans #
# nf - 1 #
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nn_joints x nn_base_pts x 3 --> nf x nnjoints x nnbasepts
)
cur_dist_joints_to_base_pts, cur_minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints
if minn_base_pts_idxes is None:
minn_base_pts_idxes = cur_minn_base_pts_idxes
# dist_joints_to_base_pts: nf nn_joints #
dist_joints_to_base_pts = model_util.batched_index_select_ours(dist_joints_to_base_pts, minn_base_pts_idxes.unsqueeze(-1), dim=2).squeeze(-1)
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nn_joints #
k_f = 1.
k = torch.exp(-1. * k_f * dist_joints_to_base_pts.detach())  # nf x nn_joints, values in (0, 1]; larger when the joint is closer to its base point
disp_base_pts = base_pts_trans[1:] - base_pts_trans[:-1] # (nf - 1) x nn_base_pts x 3
disp_joints = joints[1:] - joints[:-1] # (nf - 1) x nn_joints x 3 --> for joints displacement here #
minn_base_pts_idxes = minn_base_pts_idxes[:-1]
k = k[:-1]
dir_disp_base_pts = disp_base_pts / torch.clamp(torch.norm(disp_base_pts, p=2, keepdim=True, dim=-1), min=1e-9) # (nf - 1) x nn_base_pts x 3
dir_disp_base_pts = model_util.batched_index_select_ours(dir_disp_base_pts, minn_base_pts_idxes.detach(), dim=1) # (nf - 1) x nn_joints x 3
disp_base_pts = model_util.batched_index_select_ours(disp_base_pts, minn_base_pts_idxes.detach(), dim=1)
# disp along base disp dir #
disp_along_base_disp_dir = disp_joints * dir_disp_base_pts # (nf - 1) x nn_joints x 3 # along disp dir
disp_vt_base_disp_dir = disp_joints - disp_along_base_disp_dir # (nf - 1) x nn_joints x 3 # vt disp dir
# distance between the base-point displacement and the (k-scaled) joint displacement along that direction;
# used as a moving-consistency correction rather than as a direct optimization target
dist_disp_along_dir = disp_base_pts - k.unsqueeze(-1) * disp_along_base_disp_dir
dist_disp_along_dir = torch.norm(dist_disp_along_dir, dim=-1, p=2) # (nf - 1) x nn_joints # dist_disp_along_dir
dist_disp_vt_dir = torch.norm(disp_vt_base_disp_dir, dim=-1, p=2) # (nf - 1) x nn_joints #
dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:-1] # (nf - 1) x nn_joints #
return dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir
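# Illustrative usage sketch (shapes are hypothetical, not taken from the dataset):
#   joints = torch.randn(16, 21, 3)            # nf x nn_joints x 3
#   base_pts_trans = torch.randn(16, 700, 3)   # nf x nn_base_pts x 3
#   d_j2b, d_along, d_vt = calculate_disp_quants(joints, base_pts_trans)
#   # each returned tensor has shape (nf - 1) x nn_joints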
# batched variant: inputs carry a leading batch dimension (bsz)
# usage: dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(joints, base_pts_trans)
def calculate_disp_quants_batched(joints, base_pts_trans):
# joints: bsz x nf x nn_joints x 3;
# base_pts_trans: bsz x nf x nn_base_pts x 3;
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts_trans.unsqueeze(-3)) ** 2, dim=-1 # nf x nn_joints x nn_base_pts x 3 --> nf x nnjoints x nnbasepts
)
dist_joints_to_base_pts, minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nn_joints #
k_f = 1.
k = torch.exp(-1. * k_f * (dist_joints_to_base_pts)) # 0 -> 1 value # # nf x nn_joints # nf x nn_joints #
disp_base_pts = base_pts_trans[:, 1:] - base_pts_trans[:, :-1] # basepts trans #
disp_joints = joints[:, 1:] - joints[:, :-1] # (nf - 1) x nn_joints x 3 --> for joints displacement here #
minn_base_pts_idxes = minn_base_pts_idxes[:, :-1] # bsz x (nf - 1) #
k = k[:, :-1]
dir_disp_base_pts = disp_base_pts / torch.clamp(torch.norm(disp_base_pts, p=2, keepdim=True, dim=-1), min=1e-23) # (nf - 1) x nn_base_pts x 3
dir_disp_base_pts = model_util.batched_index_select_ours(dir_disp_base_pts, minn_base_pts_idxes, dim=2) # (nf - 1) x nnjoints x 3
# disp_base_pts, minn_base_pts_idxes --> bsz x (nf - 1) x nnjoints
disp_base_pts = model_util.batched_index_select_ours(disp_base_pts, minn_base_pts_idxes, dim=2)
disp_along_base_disp_dir = disp_joints * dir_disp_base_pts # bsz x (nf - 1) x nn_joints x 3 # along disp dir
disp_vt_base_disp_dir = disp_joints - disp_along_base_disp_dir # bsz x (nf - 1) x nn_joints x 3 # vt disp dir
# disp_base_pts -> bsz x (nf - 1) x njoints x 3 # dist_disp_along_dir
dist_disp_along_dir = disp_base_pts - k.unsqueeze(-1) * disp_along_base_disp_dir
dist_disp_along_dir = torch.norm(dist_disp_along_dir, dim=-1, p=2) # bsz x (nf - 1) x nn_joints # dist_disp_along_dir
dist_disp_vt_dir = torch.norm(disp_vt_base_disp_dir, dim=-1, p=2) # bsz x (nf - 1) x nn_joints #
dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:, :-1] # bsz x (nf - 1) x nn_joints #
return dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir
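# Illustrative usage sketch for the batched variant (hypothetical shapes):
#   joints = torch.randn(4, 16, 21, 3)           # bsz x nf x nn_joints x 3
#   base_pts_trans = torch.randn(4, 16, 700, 3)  # bsz x nf x nn_base_pts x 3
#   d_j2b, d_along, d_vt = calculate_disp_quants_batched(joints, base_pts_trans)
#   # each returned tensor has shape bsz x (nf - 1) x nn_joints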
# non-batched variant that additionally uses canonicalized joints and base-point normals
# usage: disp_base_pts, dist_joints_to_base_pts_disp, dotprod_disp_joints_along_normals, l2_disp_joints_vt_normals = calculate_disp_quants_v2(joints, base_pts_trans, canon_joints, canon_base_normals)
def calculate_disp_quants_v2(joints, base_pts_trans, canon_joints, canon_base_normals):
# joints: nf x nn_joints x 3;
# base_pts_trans: nf x nn_base_pts x 3;
# nf - 1 # nf x nn_joints x nn_base_pts x 3 #
# joints: nf x nn_joints x 3 # --> nf x nn_joint x 1 x 3 - nf x 1 x nn_base_pts x 3 --> nf x nn_jts x nn_base_pts x 3 #
# base_pts_trans: nf x nn_base_pt x 3 #
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts_trans.unsqueeze(-3)) ** 2, dim=-1 # nf x nn_joints x nn_base_pts x 3 --> nf x nnjoints x nnbasepts
)
dist_joints_to_base_pts, minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nn_joints #
k_f = 1.
k = torch.exp(-1. * k_f * (dist_joints_to_base_pts)) # 0 -> 1 value # # nf x nn_joints # nf x nn_joints #
###
### base pts velocity ###
disp_base_pts = base_pts_trans[1:] - base_pts_trans[:-1] # basepts trans #
### joints velocity ###
disp_joints = joints[1:] - joints[:-1] # (nf - 1) x nn_joints x 3 --> for joints displacement here #
minn_base_pts_idxes = minn_base_pts_idxes[:-1] # bsz x (nf - 1) #
k = k[:-1]
### joints velocity in the canonicalized space ###
disp_canon_joints = canon_joints[1:] - canon_joints[:-1]
### base point normals information ###
disp_canon_base_normals = canon_base_normals[:-1] # (nf - 1) x nn_base_pts x 3 --> normals of base points
# (nf - 1) x nn_joints x 3 after selecting each joint's nearest base point
disp_canon_base_normals = model_util.batched_index_select_ours(values=disp_canon_base_normals, indices=minn_base_pts_idxes, dim=1)
### joint velocity along normals ###
disp_joints_along_normals = disp_canon_base_normals * disp_canon_joints
dotprod_disp_joints_along_normals = disp_joints_along_normals.sum(dim=-1) # bsz x (nf - 1) x nn_joints
disp_joints_vt_normals = disp_canon_joints - dotprod_disp_joints_along_normals.unsqueeze(-1) * disp_canon_base_normals
l2_disp_joints_vt_normals = torch.norm(disp_joints_vt_normals, p=2, keepdim=False, dim=-1) # (nf - 1) x nn_joints # l2 norm of the joint displacement perpendicular to the normal
# dir_disp_base_pts = disp_base_pts / torch.clamp(torch.norm(disp_base_pts, p=2, keepdim=True, dim=-1), min=1e-23) # (nf - 1) x nn_base_pts x 3
# dir_disp_base_pts = model_util.batched_index_select_ours(dir_disp_base_pts, minn_base_pts_idxes, dim=2) # (nf - 1) x nnjoints x 3
# # disp_base_pts, minn_base_pts_idxes --> bsz x (nf - 1) x nnjoints
disp_base_pts = model_util.batched_index_select_ours(disp_base_pts, minn_base_pts_idxes, dim=1)
# disp_along_base_disp_dir = disp_joints * dir_disp_base_pts # bsz x (nf - 1) x nn_joints x 3 # along disp dir
# disp_vt_base_disp_dir = disp_joints - disp_along_base_disp_dir # bsz x (nf - 1) x nn_joints x 3 # vt disp dir
# # disp_base_pts -> bsz x (nf - 1) x njoints x 3 # dist_disp_along_dir
# dist_disp_along_dir = disp_base_pts - k.unsqueeze(-1) * disp_along_base_disp_dir
# dist_disp_along_dir = torch.norm(dist_disp_along_dir, dim=-1, p=2) # bsz x (nf - 1) x nn_joints # dist_disp_along_dir
# dist_disp_vt_dir = torch.norm(disp_vt_base_disp_dir, dim=-1, p=2) # bsz x (nf - 1) x nn_joints #
dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:-1] # (nf - 1) x nn_joints #
# disp_base_pts: (nf - 1) x nn_joints x 3 # -> disp of base pts for each joint #
# dist_joints_to_base_pts_disp: (nf - 1) x nn_joints #
# dotprod_disp_joints_along_normals: (nf - 1) x nn_joints #
# l2_disp_joints_vt_normals: (nf - 1) x nn_joints #
# disp_base_pts, dist_joints_to_base_pts_disp, dotprod_disp_joints_along_normals, l2_disp_joints_vt_normals #
return disp_base_pts, dist_joints_to_base_pts_disp, dotprod_disp_joints_along_normals, l2_disp_joints_vt_normals
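# Illustrative usage sketch (hypothetical inputs); canon_joints / canon_base_normals are the
# joints and base-point normals expressed in the canonicalized object frame:
#   disp_base_pts, d_j2b, dot_along_n, l2_vt_n = calculate_disp_quants_v2(
#       joints, base_pts_trans, canon_joints, canon_base_normals)
#   # disp_base_pts: (nf - 1) x nn_joints x 3; the other three outputs: (nf - 1) x nn_joints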
# batched variant of calculate_disp_quants_v2: inputs carry a leading batch dimension (bsz)
# usage: disp_base_pts, dist_joints_to_base_pts_disp, dotprod_disp_joints_along_normals, l2_disp_joints_vt_normals = calculate_disp_quants_batched_v2(joints, base_pts_trans, canon_joints, canon_base_normals)
def calculate_disp_quants_batched_v2(joints, base_pts_trans, canon_joints, canon_base_normals):
# joints: bsz x nf x nn_joints x 3;
# base_pts_trans: bsz x nf x nn_base_pts x 3;
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts_trans.unsqueeze(-3)) ** 2, dim=-1 # nf x nn_joints x nn_base_pts x 3 --> nf x nnjoints x nnbasepts
)
dist_joints_to_base_pts, minn_base_pts_idxes = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nn_joints #
k_f = 1.
k = torch.exp(-1. * k_f * (dist_joints_to_base_pts)) # 0 -> 1 value # # nf x nn_joints # nf x nn_joints #
###
### base pts velocity ###
disp_base_pts = base_pts_trans[:, 1:] - base_pts_trans[:, :-1] # basepts trans #
### joints velocity ###
disp_joints = joints[:, 1:] - joints[:, :-1] # (nf - 1) x nn_joints x 3 --> for joints displacement here #
minn_base_pts_idxes = minn_base_pts_idxes[:, :-1] # bsz x (nf - 1) #
k = k[:, :-1]
### joints velocity in the canonicalized space ###
disp_canon_joints = canon_joints[:, 1:] - canon_joints[:, :-1]
### base point normals information ###
disp_canon_base_normals = canon_base_normals[:, :-1] # bsz x (nf - 1) x nn_base_pts x 3 --> normals of base points
# bsz x (nf - 1) x nn_joints x 3 ##
disp_canon_base_normals = model_util.batched_index_select_ours(values=disp_canon_base_normals, indices=minn_base_pts_idxes, dim=2)
### joint velocity along normals ###
disp_joints_along_normals = disp_canon_base_normals * disp_canon_joints
dotprod_disp_joints_along_normals = disp_joints_along_normals.sum(dim=-1) # bsz x (nf - 1) x nn_joints
disp_joints_vt_normals = disp_canon_joints - dotprod_disp_joints_along_normals.unsqueeze(-1) * disp_canon_base_normals
l2_disp_joints_vt_normals = torch.norm(disp_joints_vt_normals, p=2, keepdim=False, dim=-1) # bsz x (nf - 1) x nn_joints # --> for l2 norm vt normals
# dir_disp_base_pts = disp_base_pts / torch.clamp(torch.norm(disp_base_pts, p=2, keepdim=True, dim=-1), min=1e-23) # (nf - 1) x nn_base_pts x 3
# dir_disp_base_pts = model_util.batched_index_select_ours(dir_disp_base_pts, minn_base_pts_idxes, dim=2) # (nf - 1) x nnjoints x 3
# # disp_base_pts, minn_base_pts_idxes --> bsz x (nf - 1) x nnjoints
disp_base_pts = model_util.batched_index_select_ours(disp_base_pts, minn_base_pts_idxes, dim=2)
# disp_along_base_disp_dir = disp_joints * dir_disp_base_pts # bsz x (nf - 1) x nn_joints x 3 # along disp dir
# disp_vt_base_disp_dir = disp_joints - disp_along_base_disp_dir # bsz x (nf - 1) x nn_joints x 3 # vt disp dir
# # disp_base_pts -> bsz x (nf - 1) x njoints x 3 # dist_disp_along_dir
# dist_disp_along_dir = disp_base_pts - k.unsqueeze(-1) * disp_along_base_disp_dir
# dist_disp_along_dir = torch.norm(dist_disp_along_dir, dim=-1, p=2) # bsz x (nf - 1) x nn_joints # dist_disp_along_dir
# dist_disp_vt_dir = torch.norm(disp_vt_base_disp_dir, dim=-1, p=2) # bsz x (nf - 1) x nn_joints #
dist_joints_to_base_pts_disp = dist_joints_to_base_pts[:, :-1] # bsz x (nf - 1) x nn_joints #
return disp_base_pts, dist_joints_to_base_pts_disp, dotprod_disp_joints_along_normals, l2_disp_joints_vt_normals
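# Same quantities as calculate_disp_quants_v2 with a leading batch dimension (hypothetical call):
#   disp_base_pts, d_j2b, dot_along_n, l2_vt_n = calculate_disp_quants_batched_v2(
#       joints, base_pts_trans, canon_joints, canon_base_normals)
#   # disp_base_pts: bsz x (nf - 1) x nn_joints x 3; the other three: bsz x (nf - 1) x nn_joints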
def get_optimized_hand_fr_joints(joints):
joints = torch.from_numpy(joints).float().cuda()
### start optimization ###
# setup MANO layer
mano_path = "/data1/sim/mano_models/mano/models"
mano_layer = ManoLayer(
flat_hand_mean=True,
side='right',
mano_root=mano_path, # mano_root #
ncomps=24,
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang'
).cuda()
nn_frames = joints.size(0)
# initialize variables
beta_var = torch.randn([1, 10]).cuda()
# first 3 global orientation
rot_var = torch.randn([nn_frames, 3]).cuda()
theta_var = torch.randn([nn_frames, 24]).cuda()
transl_var = torch.randn([nn_frames, 3]).cuda()
# transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
# ori_transl_var = transl_var.clone()
# rot_var = tot_rhand_glb_orient.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
beta_var.requires_grad_()
rot_var.requires_grad_()
theta_var.requires_grad_()
transl_var.requires_grad_()
learning_rate = 0.1
# opt = optim.Adam([rot_var, transl_var], lr=args.coarse_lr)
num_iters = 200
opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
for i in range(num_iters): #
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
# pose_smoothness_loss =
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
loss = joints_pred_loss * 30
opt.zero_grad()
loss.backward()
opt.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
num_iters = 2000
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()
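# Minimal usage sketch (the file name is hypothetical); `joints` is an nf x nn_joints x 3 numpy
# array in the same metre-scale space as the 0.001-scaled MANO output above:
#   joints_np = np.load("predicted_joints.npy")
#   hand_verts_np, hand_joints_np = get_optimized_hand_fr_joints(joints_np)
#   # hand_verts_np: nf x 778 x 3, hand_joints_np: nf x nn_joints x 3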
def get_affinity_fr_dist(dist, s=0.02):
### affinity scores ###
k = 0.5 * torch.cos(torch.pi / s * torch.abs(dist)) + 0.5
return k
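# The affinity kernel maps |dist| = 0 -> 1 and |dist| = s -> 0 with a raised-cosine falloff;
# since the cosine is periodic it is only meaningful for |dist| <= s (callers below use s = 1.0).
# Quick check: get_affinity_fr_dist(torch.tensor([0.0, 0.01, 0.02]), s=0.02) ~= [1.0, 0.5, 0.0]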
def get_optimized_hand_fr_joints_v2(joints, base_pts):
joints = torch.from_numpy(joints).float().cuda()
base_pts = torch.from_numpy(base_pts).float().cuda()
### start optimization ###
# setup MANO layer
mano_path = "/data1/sim/mano_models/mano/models"
mano_layer = ManoLayer(
flat_hand_mean=True,
side='right',
mano_root=mano_path, # mano_root #
ncomps=24,
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang'
).cuda()
nn_frames = joints.size(0)
# initialize variables
beta_var = torch.randn([1, 10]).cuda()
# first 3 global orientation
rot_var = torch.randn([nn_frames, 3]).cuda()
theta_var = torch.randn([nn_frames, 24]).cuda()
transl_var = torch.randn([nn_frames, 3]).cuda()
# transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
# ori_transl_var = transl_var.clone()
# rot_var = tot_rhand_glb_orient.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
beta_var.requires_grad_()
rot_var.requires_grad_()
theta_var.requires_grad_()
transl_var.requires_grad_()
learning_rate = 0.1
# joints: nf x nnjoints x 3 #
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
)
nn_base_pts = dist_joints_to_base_pts.size(-1)
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
minn_dist_mask = minn_dist_mask.float()
print(f"minn_dist_mask: {minn_dist_mask.size()}")
s = 1.0
affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
# opt = optim.Adam([rot_var, transl_var], lr=args.coarse_lr)
num_iters = 200
opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
for i in range(num_iters): #
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
# pose_smoothness_loss =
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
loss = joints_pred_loss * 30
loss = joints_pred_loss * 1000
opt.zero_grad()
loss.backward()
opt.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
num_iters = 2000
num_iters = 3000
# num_iters = 1000
learning_rate = 0.01
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
dist_joints_to_base_pts_sqr = torch.sum(
(hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
)
# attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
# attaction_loss = attaction_loss
# attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :])
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
# loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 30 + attaction_loss * 0.001
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
print('\tAttraction Loss: {}'.format(attaction_loss.item()))
print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()
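# Usage sketch (hypothetical inputs): `joints` are predicted hand joints and `base_pts` are
# sampled object surface points, both numpy arrays in the same coordinate frame:
#   verts, jts = get_optimized_hand_fr_joints_v2(joints_np, base_pts_np)
# Note the attraction term above is restricted to frames 46+ and joints -5:-3, a sequence-specific choice.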
def get_optimized_hand_fr_joints_v3(joints, base_pts, tot_base_pts_trans):
joints = torch.from_numpy(joints).float().cuda()
base_pts = torch.from_numpy(base_pts).float().cuda()
tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
### start optimization ###
# setup MANO layer
mano_path = "/data1/sim/mano_models/mano/models"
mano_layer = ManoLayer(
flat_hand_mean=True,
side='right',
mano_root=mano_path, # mano_root #
ncomps=24,
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang'
).cuda()
nn_frames = joints.size(0)
# initialize variables
beta_var = torch.randn([1, 10]).cuda()
# first 3 global orientation
rot_var = torch.randn([nn_frames, 3]).cuda()
theta_var = torch.randn([nn_frames, 24]).cuda()
transl_var = torch.randn([nn_frames, 3]).cuda()
# transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
# ori_transl_var = transl_var.clone()
# rot_var = tot_rhand_glb_orient.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
beta_var.requires_grad_()
rot_var.requires_grad_()
theta_var.requires_grad_()
transl_var.requires_grad_()
learning_rate = 0.1
# joints: nf x nnjoints x 3 #
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
)
nn_base_pts = dist_joints_to_base_pts.size(-1)
nn_joints = dist_joints_to_base_pts.size(1)
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
nk_contact_pts = 2
minn_dist[:, :-5] = 1e9
minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
# joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0).cuda() ==
minn_topk_mask = torch.zeros_like(minn_dist)
# minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
minn_topk_mask[:, -5: -3] = 1.
basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
minn_dist_mask = minn_dist_mask.float()
minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
print(f"minn_dist_mask: {minn_dist_mask.size()}")
s = 1.0
affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
# opt = optim.Adam([rot_var, transl_var], lr=args.coarse_lr)
num_iters = 200
opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
for i in range(num_iters): #
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
# pose_smoothness_loss =
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
loss = joints_pred_loss * 30
loss = joints_pred_loss * 1000
opt.zero_grad()
loss.backward()
opt.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
#
print(tot_base_pts_trans.size())
diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
diff_base_pts_trans_threshold = 1e-20
diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the transformed base pts
diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
diff_base_pts_trans_mask = torch.cat(
[diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
)
# attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
attraction_mask = minn_topk_mask.float()
attraction_mask = attraction_mask.float()
num_iters = 2000
num_iters = 3000
# num_iters = 1000
learning_rate = 0.01
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
dist_joints_to_base_pts_sqr = torch.sum(
(hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
)
# attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
# attaction_loss = attaction_loss
# attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
# attaction_loss = torch.mean(attaction_loss * attraction_mask)
attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :])
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 + attaction_loss * 10000 # + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
# loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 30 + attaction_loss * 0.001
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
print('\tAttraction Loss: {}'.format(attaction_loss.item()))
print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()
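# Usage sketch (hypothetical inputs); tot_base_pts_trans holds the per-frame transformed object
# base points (nf x nn_base_pts x 3) used to detect whether the object is moving:
#   verts, jts = get_optimized_hand_fr_joints_v3(joints_np, base_pts_np, tot_base_pts_trans_np)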
def get_optimized_hand_fr_joints_v4(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False):
joints = torch.from_numpy(joints).float().cuda()
base_pts = torch.from_numpy(base_pts).float().cuda()
tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float().cuda()
### start optimization ###
# setup MANO layer
mano_path = "/data1/sim/mano_models/mano/models"
mano_layer = ManoLayer(
flat_hand_mean=True,
side='right',
mano_root=mano_path, # mano_root #
ncomps=24,
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang'
).cuda()
nn_frames = joints.size(0)
# initialize variables
beta_var = torch.randn([1, 10]).cuda()
# first 3 global orientation
rot_var = torch.randn([nn_frames, 3]).cuda()
theta_var = torch.randn([nn_frames, 24]).cuda()
transl_var = torch.randn([nn_frames, 3]).cuda()
# transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
# ori_transl_var = transl_var.clone()
# rot_var = tot_rhand_glb_orient.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
beta_var.requires_grad_()
rot_var.requires_grad_()
theta_var.requires_grad_()
transl_var.requires_grad_()
learning_rate = 0.1
# joints: nf x nnjoints x 3 #
# dist_joints_to_base_pts = torch.sum(
# (joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
# )
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
)
dot_prod_base_pts_to_joints_with_normals = torch.sum(
(joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) * tot_base_normals_trans.unsqueeze(1), dim=-1 #
)
# dist_joints_to_base_pts[dot_prod_base_pts_to_joints_with_normals < 0.] = 1e9
nn_base_pts = dist_joints_to_base_pts.size(-1)
nn_joints = dist_joints_to_base_pts.size(1)
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
nk_contact_pts = 2
minn_dist[:, :-5] = 1e9
minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
# joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0).cuda() ==
minn_topk_mask = torch.zeros_like(minn_dist)
# minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
minn_topk_mask[:, -5: -3] = 1.
basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
# for seq 101
# minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
minn_dist_mask[:, -5:, :] = minn_dist_mask[30:31, -5:, :] # broadcast frame 30's mask for the last five joints to all frames #
minn_dist_mask = minn_dist_mask.float()
tot_base_pts_trans_disp = torch.sum(
(tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement
)
tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
# tot_base_pts_trans_disp_mov_thres = 1e-20
tot_base_pts_trans_disp_mov_thres = 3e-4
tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
tot_base_pts_trans_disp_mask = torch.cat(
[tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0
)
attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
print(f"minn_dist_mask: {minn_dist_mask.size()}")
s = 1.0
affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
# opt = optim.Adam([rot_var, transl_var], lr=args.coarse_lr)
num_iters = 200
opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
for i in range(num_iters): #
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
# pose_smoothness_loss =
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
loss = joints_pred_loss * 30
loss = joints_pred_loss * 1000
opt.zero_grad()
loss.backward()
opt.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
#
print(tot_base_pts_trans.size())
diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
diff_base_pts_trans_threshold = 1e-20
diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the transformed base pts
diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
diff_base_pts_trans_mask = torch.cat(
[diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
)
# attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
attraction_mask = minn_topk_mask.float()
attraction_mask = attraction_mask.float()
# whether a base point should be selected depends on the relation between its normal direction
# and the moving direction of the object point
# define the attraction loss's weight and attract the joints to the object surface
num_iters = 2000
num_iters = 3000
# num_iters = 1000
learning_rate = 0.01
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# dist_joints_to_base_pts_sqr = torch.sum(
# (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
# )
# attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
# attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
# attaction_loss = attaction_loss
# attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
# attaction_loss = torch.mean(attaction_loss * attraction_mask)
# attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :])
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
# loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 30 + attaction_loss * 0.001
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
# print('\tAttraction Loss: {}'.format(attaction_loss.item()))
print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
if with_contact_opt:
num_iters = 2000
# num_iters = 1000 # seq 77
# num_iters = 500 # seq 77
ori_theta_var = theta_var.detach().clone()
# tot_base_pts_trans # nf x nn_base_pts x 3
disp_base_pts_trans = tot_base_pts_trans[1:] - tot_base_pts_trans[:-1] # (nf - 1) x nn_base_pts x 3
disp_base_pts_trans = torch.cat( # nf x nn_base_pts x 3
[disp_base_pts_trans, disp_base_pts_trans[-1:]], dim=0
)
# joints: nf x nn_jts_pts x 3; nf x nn_base_pts x 3
dist_joints_to_base_pts_trans = torch.sum(
(joints.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nn_jts_pts x nn_base_pts
)
minn_dist_joints_to_base_pts, minn_dist_idxes = torch.min(dist_joints_to_base_pts_trans, dim=-1) # nf x nn_jts_pts # nf x nn_jts_pts #
nearest_base_normals = model_util.batched_index_select_ours(tot_base_normals_trans, indices=minn_dist_idxes, dim=1) # nf x nn_jts x 3 # normal of each joint's nearest base point #
nearest_base_pts_trans = model_util.batched_index_select_ours(disp_base_pts_trans, indices=minn_dist_idxes, dim=1) # nf x nn_jts x 3 # displacement of each joint's nearest base point #
dot_nearest_base_normals_trans = torch.sum(
nearest_base_normals * nearest_base_pts_trans, dim=-1 # nf x nn_jts
)
trans_normals_mask = dot_nearest_base_normals_trans < 0. # nf x nn_jts # nf x nn_jts #
nearest_dist = torch.sqrt(minn_dist_joints_to_base_pts)
# nearest_dist_mask = nearest_dist < 0.01 # hoi seq
nearest_dist_mask = nearest_dist < 0.1
k_attr = 100.
joint_attraction_k = torch.exp(-1. * k_attr * nearest_dist)
attraction_mask_new_new = (attraction_mask_new.float() + trans_normals_mask.float().unsqueeze(-1) + nearest_dist_mask.float().unsqueeze(-1)) > 2.5
# opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate)
# opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
opt = optim.Adam([transl_var, theta_var, rot_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# dist_joints_to_base_pts_sqr = torch.sum(
# (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
# ) # nf x nnb x 3 ---- nf x nnj x 1 x 3
dist_joints_to_base_pts_sqr = torch.sum(
(hand_joints.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1
)
# attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
# attaction_loss = attaction_loss
# attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
# attaction_loss = torch.mean(attaction_loss * attraction_mask)
# attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
# seq 80
# attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
# seq 70
# attaction_loss = torch.mean(attaction_loss[10:, -5:-3, :] * minn_dist_mask[10:, -5:-3, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
# new version relying on new mask #
# attaction_loss = torch.mean(attaction_loss[:, -5:-3, :] * attraction_mask_new[:, -5:-3, :])
### original version ###
# attaction_loss = torch.mean(attaction_loss[20:, -3:, :] * attraction_mask_new[20:, -3:, :])
second_tips_idxes = [9, 12, 6, 3, 15]
second_tips_idxes =torch.tensor(second_tips_idxes, dtype=torch.long).cuda()
# attaction_loss = torch.mean(attaction_loss[:, -5:, :] * attraction_mask_new_new[:, -5:, :] * joint_attraction_k[:, -5:].unsqueeze(-1)) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1))
# attaction_loss = torch.mean(attaction_loss[:, -5:-2, :] * attraction_mask_new_new[:, -5:-2, :]) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1))
# attraction_mask_new
# attaction_loss = torch.mean(attaction_loss[:, -5:-2, :] * attraction_mask_new[:, -5:-2, :]) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1))
# seq mug
# attaction_loss = torch.mean(attaction_loss[4:, -5:-4, :] * minn_dist_mask[4:, -5:-4, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
attaction_loss = torch.mean(attaction_loss[10:, -5:-2, :] * attraction_mask_new[10:, -5:-2, :]) # + torch.mean(attaction_loss[:, second_tips_idxes, :] * attraction_mask_new_new[:, second_tips_idxes, :] * joint_attraction_k[:, second_tips_idxes].unsqueeze(-1))
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
# =0.05
# # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)
# loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001
# loss = attaction_loss * 1000. + joints_pred_loss
# loss = attaction_loss * 100000. + joints_pred_loss * 0.00000001 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
loss = attaction_loss * 100000. + joints_pred_loss * 0.000 + joints_smoothness_loss * 0.000005
# loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
# loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 30 + attaction_loss * 0.001
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
print('\tAttraction Loss: {}'.format(attaction_loss.item()))
print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
# theta_smoothness_loss
print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item()))
return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()
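# Usage sketch (hypothetical inputs); tot_base_normals_trans holds the per-frame transformed
# base-point normals, and with_contact_opt=True enables the extra contact refinement stage above:
#   verts, jts = get_optimized_hand_fr_joints_v4(
#       joints_np, base_pts_np, tot_base_pts_trans_np, tot_base_normals_trans_np, with_contact_opt=True)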
# get_optimized_hand_fr_joints_v5 --> optimize the hand and compare predicted vs. ground-truth displacement quantities
def get_optimized_hand_fr_joints_v5(joints, tot_gt_rhand_joints, base_pts, tot_base_pts_trans, predicted_joint_quants=None):
joints = torch.from_numpy(joints).float().cuda()
base_pts = torch.from_numpy(base_pts).float().cuda()
tot_gt_rhand_joints = torch.from_numpy(tot_gt_rhand_joints).float().cuda()
tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
# (nf - 1) x nn_joints for each quantity #
if predicted_joint_quants is None:
gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = calculate_disp_quants(tot_gt_rhand_joints, tot_base_pts_trans)
else:
predicted_joint_quants = torch.from_numpy(predicted_joint_quants).float().cuda()
gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = predicted_joint_quants[..., 0], predicted_joint_quants[..., 1], predicted_joint_quants[..., 2] #
print(f"gt_dist_joints_to_base_pts_disp: {gt_dist_joints_to_base_pts_disp.size()}, gt_dist_disp_along_dir: {gt_dist_disp_along_dir.size()}, gt_dist_disp_vt_dir: {gt_dist_disp_vt_dir.size()}")
gt_dist_joints_to_base_pts_disp_real, gt_dist_disp_along_dir_real, gt_dist_disp_vt_dir_real = calculate_disp_quants(tot_gt_rhand_joints, tot_base_pts_trans)
diff_gt_dist_joints = torch.mean((gt_dist_joints_to_base_pts_disp_real - gt_dist_joints_to_base_pts_disp) ** 2)
diff_gt_disp_along_normals = torch.mean((gt_dist_disp_along_dir_real - gt_dist_disp_along_dir) ** 2)
diff_gt_disp_vt_normals = torch.mean((gt_dist_disp_vt_dir_real - gt_dist_disp_vt_dir) ** 2)
print(f"diff_gt_dist_joints: {diff_gt_dist_joints.item()}, diff_gt_disp_along_normals: {diff_gt_disp_along_normals.item()}, diff_gt_disp_vt_normals: {diff_gt_disp_vt_normals.item()}")
disp_info_pred_gt_sv_dict = {
'gt_dist_joints_to_base_pts_disp_real': gt_dist_joints_to_base_pts_disp_real.detach().cpu().numpy(),
'gt_dist_disp_along_dir_real': gt_dist_disp_along_dir_real.detach().cpu().numpy(),
'gt_dist_disp_vt_dir_real': gt_dist_disp_vt_dir_real.detach().cpu().numpy(),
'gt_dist_joints_to_base_pts_disp': gt_dist_joints_to_base_pts_disp.detach().cpu().numpy(),
'gt_dist_disp_along_dir': gt_dist_disp_along_dir.detach().cpu().numpy(),
'gt_dist_disp_vt_dir': gt_dist_disp_vt_dir.detach().cpu().numpy(),
}
#
disp_info_pred_gt_sv_fn = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512"
disp_info_pred_gt_sv_fn = os.path.join(disp_info_pred_gt_sv_fn, f"pred_gt_disp_info.npy")
np.save(disp_info_pred_gt_sv_fn, disp_info_pred_gt_sv_dict)
print(f"saved to {disp_info_pred_gt_sv_fn}")
# tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
### start optimization ###
# setup MANO layer
mano_path = "/data1/sim/mano_models/mano/models"
mano_layer = ManoLayer(
flat_hand_mean=True,
side='right',
mano_root=mano_path, # mano_root #
ncomps=24,
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang'
).cuda()
### nn_frames ###
nn_frames = joints.size(0)
# initialize variables
beta_var = torch.randn([1, 10]).cuda()
# first 3 global orientation
rot_var = torch.randn([nn_frames, 3]).cuda()
theta_var = torch.randn([nn_frames, 24]).cuda()
transl_var = torch.randn([nn_frames, 3]).cuda()
# transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
# ori_transl_var = transl_var.clone()
# rot_var = tot_rhand_glb_orient.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
beta_var.requires_grad_()
rot_var.requires_grad_()
theta_var.requires_grad_()
transl_var.requires_grad_()
learning_rate = 0.1
# joints: nf x nnjoints x 3 #
dist_joints_to_base_pts = torch.sum(
(joints.unsqueeze(-2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
)
nn_base_pts = dist_joints_to_base_pts.size(-1)
nn_joints = dist_joints_to_base_pts.size(1)
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
nk_contact_pts = 2
minn_dist[:, :-5] = 1e9
minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
# joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0).cuda() ==
minn_topk_mask = torch.zeros_like(minn_dist)
# minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
minn_topk_mask[:, -5: -3] = 1.
basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
minn_dist_mask = minn_dist_mask.float()
minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
print(f"minn_dist_mask: {minn_dist_mask.size()}")
s = 1.0
affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
# opt = optim.Adam([rot_var, transl_var], lr=args.coarse_lr)
num_iters = 200
opt = optim.Adam([rot_var, transl_var], lr=learning_rate)
for i in range(num_iters): #
opt.zero_grad()
# mano_layer #
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[:, 1:], theta_var.view(nn_frames, -1)[:, :-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
# pose_smoothness_loss =
# =0.05
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
loss = joints_pred_loss * 30
loss = joints_pred_loss * 1000
opt.zero_grad()
loss.backward()
opt.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
#
print(tot_base_pts_trans.size())
diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
diff_base_pts_trans_threshold = 1e-20
diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the transformed base pts
diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
diff_base_pts_trans_mask = torch.cat(
[diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
)
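# NOTE: the object-motion mask computed above is currently not folded into
# attraction_mask -- only the top-k contact mask is used (the combined variant
# is left commented below).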
# attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
attraction_mask = minn_topk_mask.float()
attraction_mask = attraction_mask.float()
# num_iters = 2000
# num_iters = 1000
num_iters = 3000
learning_rate = 0.01
opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
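# Stage 2 (full fit): jointly optimize global rotation, translation, shape
# (beta) and pose (theta). The joint-position term dominates, with pose
# smoothness and shape/pose priors as light regularizers; the attraction and
# joint-smoothness terms are computed and printed but excluded from the
# active loss below.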
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
dist_joints_to_base_pts_sqr = torch.sum(
(hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
)
# attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
# attaction_loss = attaction_loss
# attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
# attaction_loss = torch.mean(attaction_loss * attraction_mask)
attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) # hard-coded frame window (46:) and contact joints (-5:-3), apparently sequence-specific
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
# earlier weightings, kept for reference:
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
# loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 30 + attaction_loss * 0.001
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
print('\tAttraction Loss: {}'.format(attaction_loss.item()))
print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
# joints: nf x nnjoints x 3 #
dist_joints_to_base_pts = torch.sum(
(hand_joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
)
# nn_base_pts = dist_joints_to_base_pts.size(-1)
# nn_joints = dist_joints_to_base_pts.size(1)
# dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
num_iters = 2000
# num_iters = 1000
ori_theta_var = theta_var.detach().clone()
# opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate)
opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
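# Stage 3 (contact-aware refinement): with global rotation and shape frozen,
# refine translation and pose so that the displacement quantities of the
# optimized hand (distance to the nearest base point, displacement along and
# perpendicular to that base point's motion) match the ground-truth
# quantities, while theta_smoothness keeps the pose close to the stage-2
# result. minn_dist_idx freezes the per-joint nearest base point from the
# stage-2 solution; the gt_dist_* tensors are assumed to have been computed
# earlier in this function (the in-loop call is left commented out).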
for i in range(num_iters):
opt.zero_grad()
# mano_layer
hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
# gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = calculate_disp_quants(tot_gt_rhand_joints, tot_base_pts_trans)
# nf x nn_joints here # # cur_dist_disp_vt_dir
cur_dist_joints_to_base_pts_disp, cur_dist_disp_along_dir, cur_dist_disp_vt_dir = calculate_disp_quants(hand_joints, tot_base_pts_trans, minn_base_pts_idxes=minn_dist_idx)
# restricted to the candidate contact joints (superseded by the all-joint version below):
# dist_joints_to_base_pts_disp_loss = ((cur_dist_joints_to_base_pts_disp - gt_dist_joints_to_base_pts_disp) ** 2)[:, -5:-3].mean(dim=-1).mean(dim=-1)
# dist_disp_along_dir_loss = ((cur_dist_disp_along_dir - gt_dist_disp_along_dir) ** 2)[:, -5:-3].mean(dim=-1).mean(dim=-1)
# dist_disp_vt_dir_loss = ((cur_dist_disp_vt_dir - gt_dist_disp_vt_dir) ** 2)[:, -5:-3].mean(dim=-1).mean(dim=-1)
# dist joints to base pts
dist_joints_to_base_pts_disp_loss = ((cur_dist_joints_to_base_pts_disp - gt_dist_joints_to_base_pts_disp) ** 2).mean(dim=-1).mean(dim=-1)
dist_disp_along_dir_loss = ((cur_dist_disp_along_dir - gt_dist_disp_along_dir) ** 2).mean(dim=-1).mean(dim=-1)
dist_disp_vt_dir_loss = ((cur_dist_disp_vt_dir - gt_dist_disp_vt_dir) ** 2).mean(dim=-1).mean(dim=-1)
# total displacement-matching loss: distance, along-motion, and perpendicular components
dist_disp_losses = dist_joints_to_base_pts_disp_loss + dist_disp_along_dir_loss + dist_disp_vt_dir_loss
joints_pred_loss = torch.sum(
(hand_joints - joints) ** 2, dim=-1
).mean()
dist_joints_to_base_pts_sqr = torch.sum(
(hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
)
# attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
# attaction_loss = attaction_loss
# attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
# attaction_loss = torch.mean(attaction_loss * attraction_mask)
attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :]) # hard-coded frame windows ([46:] and [:40]) and contact joints (-5:-3)
# opt.zero_grad()
pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
# joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
shape_prior_loss = torch.mean(beta_var**2)
pose_prior_loss = torch.mean(theta_var**2)
joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
# # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
# loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
# loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)
# loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001
loss = dist_disp_losses * 1000000. + theta_smoothness_loss * 0.0000001
# loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
# loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
# loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
# loss = joints_pred_loss * 30 + attaction_loss * 0.001 # joints smoothness loss #
opt.zero_grad()
loss.backward()
opt.step()
scheduler.step()
print('Iter {}: {}'.format(i, loss.item()), flush=True)
print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
print('\tAttraction Loss: {}'.format(attaction_loss.item()))
print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
# theta_smoothness_loss
print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item()))
# dist_disp_losses
print('\tDist Disp Loss: {}'.format(dist_disp_losses.item())) # dist disp loss #
return hand_verts.detach().cpu().numpy(), hand_joints.detach().cpu().numpy()
## optimization ##
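# Driver: load a predicted-joints file, bring base points, normals and joint
# predictions into the per-frame object pose, run the multi-stage MANO fitting
# above, and save the optimized hand vertices/joints next to the predictions.
# Paths below are hard-coded to the author's environment.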
if __name__=='__main__':
pred_infos_sv_folder = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512"
# /data1/sim/mdm/eval_save/predicted_infos_seq_1_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy
# pred_infos_sv_folder = "/data1/sim/mdm/eval_save/"
pred_joints_info_nm = "predicted_infos.npy"
# pred_joints_info_nm = "predicted_infos_hoi_seed_77_jts_only_t_300.npy"
# pred_joints_info_nm = "predicted_infos_seq_1_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy"
# pred_joints_info_nm = "predicted_infos_seq_2_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy"
# pred_joints_info_nm = "predicted_infos_seq_36_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy"
# # pred_joints_info_nm = "predicted_infos_seq_36_seed_77_tag_jts_only.npy"
# pred_joints_info_nm = "predicted_infos_seq_80_wtrans.npy"
# pred_joints_info_nm = "predicted_infos_80_wtrans_rep.npy"
# pred_joints_info_nm = "predicted_infos_seq_70_seed_31_tag_jts_only.npy"
# pred_joints_info_nm = "predicted_infos_seq_80_seed_31_tag_jts_only.npy"
# /home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/predicted_infos_seq_17_seed_31_tag_jts_only.npy
# pred_joints_info_nm = "predicted_infos_seq_87_seed_31_tag_jts_only.npy"
# pred_joints_info_nm = "predicted_infos_seq_77.npy"
# pred_joints_info_nm = "predicted_infos_seq_1_seed_31_tag_rep_only_real_sel_base_0.npy"
# # /home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/predicted_infos_seq_1_seed_31_tag_rep_only_real_sel_base_mean.npy
# pred_joints_info_nm = "predicted_infos_seq_1_seed_31_tag_rep_only_real_sel_base_mean.npy"
# # pred_infos_sv_folder = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/predicted_infos.npy"
# TODO: load the total data sequence and transform object shape using the loaded sample
# TODO: please output hands and objects other than hands only in each frame #
pred_joints_info_fn = os.path.join(pred_infos_sv_folder, pred_joints_info_nm)
data = np.load(pred_joints_info_fn, allow_pickle=True).item()
targets = data['targets'] # target joint sequences from the prediction file
outputs = data['outputs'] # predicted joint sequences
tot_base_pts = data["tot_base_pts"][0]
tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
obj_verts = data['obj_verts']
# outputs = targets
pred_infos_sv_folder = "/data1/sim/mdm/eval_save/"
pred_infos_sv_folder = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512"
pred_joints_info_nm = "predicted_infos.npy"
# pred_joints_info_nm = "predicted_infos_hoi_seed_77_jts_only_t_300.npy"
# pred_infos_sv_folder = "/data1/sim/mdm/eval_save/"
# pred_joints_info_nm = "predicted_infos_seq_1_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy"
# pred_joints_info_nm = "predicted_infos_seq_2_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy"
# pred_joints_info_nm = "predicted_infos_seq_36_seed_77_tag_rep_only_real_sel_base_mean_all_noise_.npy"
# pred_joints_info_nm = "predicted_infos_seq_80_wtrans.npy"
# pred_joints_info_nm = "predicted_infos_80_wtrans_rep.npy"
# pred_joints_info_nm = "predicted_infos_seq_80_seed_31_tag_jts_only.npy"
# pred_joints_info_nm = "predicted_infos_seq_77.npy"
pred_joints_info_fn = os.path.join(pred_infos_sv_folder, pred_joints_info_nm)
data = np.load(pred_joints_info_fn, allow_pickle=True).item()
# outputs = targets # outputs # targets #
tot_obj_rot = data['tot_obj_rot'][0] # ws x 3 x 3 ---> obj_rot; #
tot_obj_transl = data['tot_obj_transl'][0]
print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
if len(tot_base_pts.shape) == 2:
# numpy array #
tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0]
tot_base_normals_trans = np.matmul( # rotate normals only (no translation)
tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
)
else:
# numpy array #
tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
tot_base_normals_trans = np.matmul( # rotate normals only (no translation)
tot_base_normals, tot_obj_rot
)
# tot_base_normals_trans = np.matmul( # #
# tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
# )
outputs = np.matmul(outputs, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
# denoise relative positions
print(f"tot_base_pts: {tot_base_pts.shape}")
# obj_verts_trans = np.matmul(obj_verts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
# outputs = targets
# with_contact_opt = False
with_contact_opt = True
optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints_v4(outputs, tot_base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=with_contact_opt)
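# get_optimized_hand_fr_joints_v4 is the reconstruction entry point used here;
# the commented blocks below are earlier variants (precomputed joint quantities,
# v5/v2 fitting, reloading previously optimized joints) kept for reference.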
#
# predicted_joint_quants_fn = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/pred_joint_quants.npy"
# predicted_joint_quants = np.load(predicted_joint_quants_fn, allow_pickle=True).item()
# predicted_joint_quants = predicted_joint_quants['dec_joint_quants']
# print("predicted_joint_quants", predicted_joint_quants.shape)
# # print(predicted_joint_quants.keys())
# # exit(0)
# # gt
# tot_gt_rhand_joints = data['tot_gt_rhand_joints'][0] # nf x nn_joints x 3 --> gt joints #
# tot_gt_rhand_joints = np.matmul(tot_gt_rhand_joints, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
# optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints_v5(outputs, tot_gt_rhand_joints, tot_base_pts, tot_base_pts_trans, predicted_joint_quants=predicted_joint_quants)
# optimized_data_fn = "/home/xueyi/sim/motion-diffusion-model/save/my_humanml_trans_enc_512/optimized_joints.npy"
# data = np.load(optimized_data_fn, allow_pickle=True).item()
# outputs = data["optimized_joints"]
# optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints(outputs)
# optimized_tar_hand_verts, optimized_tar_hand_joints = get_optimized_hand_fr_joints(targets)
# optimized_out_hand_verts, optimized_out_hand_joints = get_optimized_hand_fr_joints_v2(outputs, tot_base_pts)
optimized_sv_infos = {
'optimized_out_hand_verts': optimized_out_hand_verts,
'optimized_out_hand_joints': optimized_out_hand_joints,
'tot_base_pts_trans': tot_base_pts_trans,
# 'optimized_tar_hand_verts': optimized_tar_hand_verts,
# 'optimized_tar_hand_joints': optimized_tar_hand_joints,
}
optimized_sv_infos_sv_fn = os.path.join(pred_infos_sv_folder, "optimized_infos_sv_dict.npy")
np.save(optimized_sv_infos_sv_fn, optimized_sv_infos)
print(f"optimized infos saved to {optimized_sv_infos_sv_fn}")