import json
import os.path as op
import sys
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import trimesh
from easydict import EasyDict
from scipy.spatial.distance import cdist

sys.path = [".."] + sys.path
import common.thing as thing
from common.rot import axis_angle_to_quaternion, quaternion_apply
from common.torch_utils import pad_tensor_list
from common.xdict import xdict

# objects to consider for training so far
OBJECTS = [
    "capsulemachine",
    "box",
    "ketchup",
    "laptop",
    "microwave",
    "mixer",
    "notebook",
    "espressomachine",
    "waffleiron",
    "scissors",
    "phone",
]


class ObjectTensors(nn.Module):
    """Batched articulated two-part (top/bottom) object model.

    Preloads padded per-object template tensors for all names in ``OBJECTS``
    and, on forward, articulates the top part about the object z-axis, applies
    a global rotation, and optionally translates the result.
    """

    def __init__(self):
        super(ObjectTensors, self).__init__()
        # Templates start on CPU; move them explicitly with .to(dev).
        self.obj_tensors = thing.thing2dev(construct_obj_tensors(OBJECTS), "cpu")
        self.dev = None

    def forward_7d_batch(
        self,
        angles: Optional[torch.Tensor],
        global_orient: Optional[torch.Tensor],
        transl: Optional[torch.Tensor],
        query_names: list,
        fwd_template: bool,
    ):
        """Pose (or fetch templates for) a batch of queried objects.

        Args:
            angles: (B, 1) articulation angle of the top part, or None when
                ``fwd_template`` is True.
            global_orient: (B, 3) axis-angle global rotation, or None when
                ``fwd_template`` is True.
            transl: (B, 3) translation added after rotation (assumed to be in
                the same unit as the vertices, i.e. meters — see
                ``construct_obj_tensors`` where vertices are divided by 1000),
                or None to skip translation.
            query_names: length-B list of object names; each must appear in
                ``self.obj_tensors["names"]``.
            fwd_template: if True, return un-posed template entities only.

        Returns:
            xdict with faces/vertices/masks/part ids and, when posed, the
            articulated ``v``, ``v_sub``, ``bbox3d`` and ``kp3d`` tensors.
        """
        self._sanity_check(angles, global_orient, transl, query_names, fwd_template)

        # store output
        out = xdict()

        # indices of the queried objects into the preloaded stacks
        obj_idx = np.array(
            [self.obj_tensors["names"].index(name) for name in query_names]
        )
        out["diameter"] = self.obj_tensors["diameter"][obj_idx]
        out["f"] = self.obj_tensors["f"][obj_idx]
        out["f_len"] = self.obj_tensors["f_len"][obj_idx]
        out["v_len"] = self.obj_tensors["v_len"][obj_idx]

        # trim the padded vertex dimension to the longest mesh in this batch
        max_len = out["v_len"].max()
        out["v"] = self.obj_tensors["v"][obj_idx][:, :max_len]
        out["mask"] = self.obj_tensors["mask"][obj_idx][:, :max_len]
        out["v_sub"] = self.obj_tensors["v_sub"][obj_idx]
        out["parts_ids"] = self.obj_tensors["parts_ids"][obj_idx][:, :max_len]
        out["parts_sub_ids"] = self.obj_tensors["parts_sub_ids"][obj_idx]

        if fwd_template:
            return out

        # articulation (rotation of the top part about the object z-axis)
        # + global rotation, both as quaternions
        quat_arti = axis_angle_to_quaternion(self.obj_tensors["z_axis"] * angles)
        quat_global = axis_angle_to_quaternion(global_orient.view(-1, 3))  # mm

        # collect entities to be transformed; both full and subsampled vertex
        # sets are duplicated into *_top / *_bottom so the top copy can be
        # articulated independently before the parts are merged back
        tf_dict = xdict()
        tf_dict["v_top"] = out["v"].clone()
        tf_dict["v_sub_top"] = out["v_sub"].clone()
        tf_dict["v_bottom"] = out["v"].clone()
        tf_dict["v_sub_bottom"] = out["v_sub"].clone()
        tf_dict["bbox_top"] = self.obj_tensors["bbox_top"][obj_idx]
        tf_dict["bbox_bottom"] = self.obj_tensors["bbox_bottom"][obj_idx]
        tf_dict["kp_top"] = self.obj_tensors["kp_top"][obj_idx]
        tf_dict["kp_bottom"] = self.obj_tensors["kp_bottom"][obj_idx]

        # articulate top parts
        for key, val in tf_dict.items():
            if "top" in key:
                val_rot = quaternion_apply(quat_arti[:, None, :], val)
                tf_dict.overwrite(key, val_rot)

        # global rotation (and optional translation) for all entities
        for key, val in tf_dict.items():
            val_rot = quaternion_apply(quat_global[:, None, :], val)
            if transl is not None:
                val_rot = val_rot + transl[:, None, :]
            tf_dict.overwrite(key, val_rot)

        # merge: take articulated-top vertices where part id == 1 (top),
        # bottom vertices elsewhere
        top_idx = out["parts_ids"] == 1
        v_tensor = tf_dict["v_bottom"].clone()
        v_tensor[top_idx, :] = tf_dict["v_top"][top_idx, :]

        top_idx = out["parts_sub_ids"] == 1
        v_sub_tensor = tf_dict["v_sub_bottom"].clone()
        v_sub_tensor[top_idx, :] = tf_dict["v_sub_top"][top_idx, :]

        bbox = torch.cat((tf_dict["bbox_top"], tf_dict["bbox_bottom"]), dim=1)
        kp3d = torch.cat((tf_dict["kp_top"], tf_dict["kp_bottom"]), dim=1)

        out.overwrite("v", v_tensor)
        out.overwrite("v_sub", v_sub_tensor)
        out.overwrite("bbox3d", bbox)
        out.overwrite("kp3d", kp3d)
        return out

    def forward(self, angles, global_orient, transl, query_names):
        """Pose the queried objects; see ``forward_7d_batch``."""
        out = self.forward_7d_batch(
            angles, global_orient, transl, query_names, fwd_template=False
        )
        return out

    def forward_template(self, query_names):
        """Return un-posed template entities for the queried objects."""
        out = self.forward_7d_batch(
            angles=None,
            global_orient=None,
            transl=None,
            query_names=query_names,
            fwd_template=True,
        )
        return out

    def to(self, dev):
        # NOTE(review): intentionally shadows nn.Module.to — obj_tensors is a
        # plain dict, not registered buffers, so it must be moved manually.
        # Unlike nn.Module.to, this returns None; callers must not chain it.
        self.obj_tensors = thing.thing2dev(self.obj_tensors, dev)
        self.dev = dev

    def _sanity_check(self, angles, global_orient, transl, query_names, fwd_template):
        """Validate shapes/types of the forward inputs (no-op for templates)."""
        if not fwd_template:
            # NOTE(review): a previous version rescaled a local copy of transl
            # to mm here; the reassignment never escaped this function, so it
            # was dead code and has been removed. transl is assumed meters.
            batch_size = angles.shape[0]
            assert angles.shape == (batch_size, 1)
            assert global_orient.shape == (batch_size, 3)
            if transl is not None:
                assert isinstance(transl, torch.Tensor)
                assert transl.shape == (batch_size, 3)
            assert len(query_names) == batch_size


def construct_obj(object_model_p):
    """Load one object's template (mesh, parts, keypoints, bboxes) from disk.

    Args:
        object_model_p: path to the object's vtemplate folder (basename is the
            object name, e.g. ``.../object_vtemplates/box``).

    Returns:
        EasyDict with vertex/face tensors, part labels (0/1 for bottom/top),
        subsampled keypoints, bounding boxes and mocap markers.
    """
    # template file paths
    mesh_p = op.join(object_model_p, "mesh.obj")
    parts_p = op.join(object_model_p, "parts.json")
    json_p = op.join(object_model_p, "object_params.json")
    obj_name = op.basename(object_model_p)

    # subsampled keypoints for the top and bottom parts
    top_sub_p = f"./data/arctic_data/data/meta/object_vtemplates/{obj_name}/top_keypoints_300.json"
    bottom_sub_p = top_sub_p.replace("top_", "bottom_")
    with open(top_sub_p, "r") as f:
        sub_top = np.array(json.load(f)["keypoints"])
    with open(bottom_sub_p, "r") as f:
        sub_bottom = np.array(json.load(f)["keypoints"])
    sub_v = np.concatenate((sub_top, sub_bottom), axis=0)

    # per-vertex part labels (True == top part — TODO confirm against data)
    # BUGFIX: np.bool was removed in NumPy 1.24; use the builtin bool dtype.
    with open(parts_p, "r") as f:
        parts = np.array(json.load(f), dtype=bool)

    assert op.exists(mesh_p), f"Not found: {mesh_p}"
    mesh = trimesh.exchange.load.load_mesh(mesh_p, process=False)
    mesh_v = mesh.vertices
    mesh_f = torch.LongTensor(mesh.faces)

    # assign each subsampled keypoint the part label of its nearest mesh vertex
    vidx = np.argmin(cdist(sub_v, mesh_v, metric="euclidean"), axis=1)
    parts_sub = parts[vidx]

    vsk = object_model_p.split("/")[-1]

    with open(json_p, "r") as f:
        params = json.load(f)

    rest = EasyDict()
    rest.top = np.array(params["mocap_top"])
    rest.bottom = np.array(params["mocap_bottom"])

    bbox_top = np.array(params["bbox_top"])
    bbox_bottom = np.array(params["bbox_bottom"])
    kp_top = np.array(params["keypoints_top"])
    kp_bottom = np.array(params["keypoints_bottom"])

    # NOTE(review): seeds global NumPy RNG as a side effect; kept for
    # reproducibility parity with the original, but confirm it is intended.
    np.random.seed(1)

    obj = EasyDict()
    obj.name = vsk
    # strip digits so e.g. "box01" maps to the meta entry "box"
    obj.obj_name = "".join([i for i in vsk if not i.isdigit()])
    obj.v = torch.FloatTensor(mesh_v)
    obj.v_sub = torch.FloatTensor(sub_v)
    obj.f = mesh_f  # already a LongTensor; no second conversion needed
    obj.parts = torch.LongTensor(parts)
    obj.parts_sub = torch.LongTensor(parts_sub)

    with open("./data/arctic_data/data/meta/object_meta.json", "r") as f:
        object_meta = json.load(f)
    obj.diameter = torch.FloatTensor(np.array(object_meta[obj.obj_name]["diameter"]))
    obj.bbox_top = torch.FloatTensor(bbox_top)
    obj.bbox_bottom = torch.FloatTensor(bbox_bottom)
    obj.kp_top = torch.FloatTensor(kp_top)
    obj.kp_bottom = torch.FloatTensor(kp_bottom)
    obj.mocap_top = torch.FloatTensor(np.array(params["mocap_top"]))
    obj.mocap_bottom = torch.FloatTensor(np.array(params["mocap_bottom"]))
    return obj


def construct_obj_tensors(object_names):
    """Load and stack templates for all ``object_names`` into batched tensors.

    Meshes have different vertex/face counts, so ``v``, ``f`` and the
    per-vertex part ids are zero-padded to a common length (with ``v_len`` /
    ``f_len`` and a binary ``mask`` recording the valid extent). Millimeter
    quantities are converted to meters.
    """
    obj_list = []
    for k in object_names:
        object_model_p = f"./data/arctic_data/data/meta/object_vtemplates/{k}"
        obj = construct_obj(object_model_p)
        obj_list.append(obj)

    bbox_top_list = []
    bbox_bottom_list = []
    mocap_top_list = []
    mocap_bottom_list = []
    kp_top_list = []
    kp_bottom_list = []
    v_list = []
    v_sub_list = []
    f_list = []
    parts_list = []
    parts_sub_list = []
    diameter_list = []
    for obj in obj_list:
        v_list.append(obj.v)
        v_sub_list.append(obj.v_sub)
        f_list.append(obj.f)
        bbox_top_list.append(obj.bbox_top)
        bbox_bottom_list.append(obj.bbox_bottom)
        kp_top_list.append(obj.kp_top)
        kp_bottom_list.append(obj.kp_bottom)
        mocap_top_list.append(obj.mocap_top / 1000)  # mm -> m
        mocap_bottom_list.append(obj.mocap_bottom / 1000)
        # shift part ids from {0, 1} to {1, 2} so 0 can mean "padding"
        parts_list.append(obj.parts + 1)
        parts_sub_list.append(obj.parts_sub + 1)
        diameter_list.append(obj.diameter)

    v_list, v_len_list = pad_tensor_list(v_list)
    p_list, p_len_list = pad_tensor_list(parts_list)
    ps_list = torch.stack(parts_sub_list, dim=0)
    # part ids are per-vertex, so both paddings must agree
    assert (p_len_list - v_len_list).sum() == 0

    max_len = v_len_list.max()
    mask = torch.zeros(len(obj_list), max_len)
    for idx, vlen in enumerate(v_len_list):
        mask[idx, :vlen] = 1.0

    v_sub_list = torch.stack(v_sub_list, dim=0)
    diameter_list = torch.stack(diameter_list, dim=0)

    f_list, f_len_list = pad_tensor_list(f_list)

    bbox_top_list = torch.stack(bbox_top_list, dim=0)
    bbox_bottom_list = torch.stack(bbox_bottom_list, dim=0)
    kp_top_list = torch.stack(kp_top_list, dim=0)
    kp_bottom_list = torch.stack(kp_bottom_list, dim=0)

    obj_tensors = {}
    obj_tensors["names"] = object_names
    obj_tensors["parts_ids"] = p_list
    obj_tensors["parts_sub_ids"] = ps_list

    # vertices/boxes/keypoints stored in meters (templates are in mm)
    obj_tensors["v"] = v_list.float() / 1000
    obj_tensors["v_sub"] = v_sub_list.float() / 1000
    obj_tensors["v_len"] = v_len_list
    obj_tensors["f"] = f_list
    obj_tensors["f_len"] = f_len_list
    obj_tensors["diameter"] = diameter_list.float()

    obj_tensors["mask"] = mask
    obj_tensors["bbox_top"] = bbox_top_list.float() / 1000
    obj_tensors["bbox_bottom"] = bbox_bottom_list.float() / 1000
    obj_tensors["kp_top"] = kp_top_list.float() / 1000
    obj_tensors["kp_bottom"] = kp_bottom_list.float() / 1000
    obj_tensors["mocap_top"] = mocap_top_list
    obj_tensors["mocap_bottom"] = mocap_bottom_list
    # articulation axis: objects open/close about -z in the canonical frame
    obj_tensors["z_axis"] = torch.FloatTensor(np.array([0, 0, -1])).view(1, 3)
    return obj_tensors