import numpy as np
import torch


def batch_mm(matrix, matrix_batch):
    """
    Multiply one (sparse or dense) matrix with every matrix of a dense batch.

    https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242

    :param matrix: Sparse or dense matrix, size (m, n).
    :param matrix_batch: Batched dense matrices, size (b, n, k).
    :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
    """
    batch_size = matrix_batch.shape[0]
    # Stack the batch into columns: (b, n, k) -> (n, b, k) -> (n, b*k).
    vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
    # A single (m, n) x (n, b*k) product replaces b separate products; then undo
    # the reshaping: (m, b*k) -> (m, b, k) -> (b, m, k).
    return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)


def aa2quat(rots, form='wxyz', unified_orient=True):
    """
    Convert angle-axis rotations to quaternions.

    @param rots: angle-axis rotations, (*, 3); the vector norm is the angle (radians)
    @param form: quaternion layout, either 'wxyz' or 'xyzw'
    @param unified_orient: if True, negate quaternions whose scalar part w is negative,
        so that w >= 0 (q and -q represent the same rotation)
    :return: quaternions of shape (*, 4) in the requested layout
    :raises ValueError: if ``form`` is neither 'wxyz' nor 'xyzw'
    """
    if form not in ('wxyz', 'xyzw'):
        raise ValueError("form must be 'wxyz' or 'xyzw', got %r" % (form,))
    angles = rots.norm(dim=-1, keepdim=True)
    norm = angles.clone()
    # Avoid 0/0 for (near-)zero rotations; the sine factor is ~0 there anyway.
    norm[norm < 1e-8] = 1
    axis = rots / norm
    half = angles * 0.5
    w = torch.cos(half)
    xyz = torch.sin(half) * axis
    if form == 'wxyz':
        quats = torch.cat((w, xyz), dim=-1)
        w_idx = 0
    else:
        quats = torch.cat((xyz, w), dim=-1)
        w_idx = 3
    if unified_orient:
        # Test the scalar component, whose position depends on the layout.
        quats = torch.where(quats[..., w_idx:w_idx + 1] < 0, -quats, quats)
    return quats


def quat2aa(quats):
    """
    Convert (w, x, y, z) quaternions to angle-axis representation.

    :param quats: quaternions of shape (..., 4), scalar part first
    :return: angle-axis vectors of shape (..., 3). The angle is 2 * atan2(|xyz|, w),
        so it exceeds pi when w < 0.
    """
    cos_half = quats[..., 0]
    xyz = quats[..., 1:]
    sin_half = xyz.norm(dim=-1)
    norm = sin_half.clone()
    # Avoid dividing by zero for identity rotations.
    norm[norm < 1e-7] = 1
    axis = xyz / norm.unsqueeze(-1)
    angle = torch.atan2(sin_half, cos_half) * 2
    return axis * angle.unsqueeze(-1)


def quat2mat(quats: torch.Tensor):
    """
    Convert (w, x, y, z) quaternions to 3x3 rotation matrices.

    :param quats: quaternions of shape (..., 4)
    :return: rotation matrices of shape (..., 3, 3)
    """
    qw = quats[..., 0]
    qx = quats[..., 1]
    qy = quats[..., 2]
    qz = quats[..., 3]

    x2 = qx + qx
    y2 = qy + qy
    z2 = qz + qz
    xx = qx * x2
    yy = qy * y2
    zz = qz * z2
    xy = qx * y2
    xz = qx * z2
    yz = qy * z2
    wx = qw * x2
    wy = qw * y2
    wz = qw * z2

    m = torch.empty(quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
    m[..., 0, 0] = 1.0 - (yy + zz)
    m[..., 0, 1] = xy - wz
    m[..., 0, 2] = xz + wy
    m[..., 1, 0] = xy + wz
    m[..., 1, 1] = 1.0 - (xx + zz)
    m[..., 1, 2] = yz - wx
    m[..., 2, 0] = xz - wy
    m[..., 2, 1] = yz + wx
    m[..., 2, 2] = 1.0 - (xx + yy)
    return m


def quat2euler(q, order='xyz', degrees=True):
    """
    Convert (w, x, y, z) quaternions to Euler angles. This is used for bvh output.

    For order 'xyz' the angles (a, b, c) satisfy R = Rx(a) @ Ry(b) @ Rz(c), which
    is the convention of ``euler2mat``.

    :param q: quaternions of shape (..., 4)
    :param order: only 'xyz' is supported
    :param degrees: return degrees if True, radians otherwise
    :return: Euler angles of shape (..., 3)
    :raises NotImplementedError: for any order other than 'xyz'
    """
    q0 = q[..., 0]
    q1 = q[..., 1]
    q2 = q[..., 2]
    q3 = q[..., 3]
    es = torch.empty(q0.shape + (3,), device=q.device, dtype=q.dtype)

    if order == 'xyz':
        es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
        # Clip guards asin against values slightly outside [-1, 1] from rounding.
        es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
        es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
    else:
        raise NotImplementedError('Cannot convert to ordering %s' % order)

    if degrees:
        es = es * 180 / np.pi
    return es


def euler2mat(rots, order='xyz'):
    """
    Convert Euler angles given in degrees to rotation matrices.

    :param rots: Euler angles in degrees, (..., 3); rots[..., i] rotates about axis order[i]
    :param order: axis order, a string of three characters from 'x', 'y', 'z'
    :return: rotation matrices (..., 3, 3) equal to R_order[0] @ R_order[1] @ R_order[2]
    """
    unit_axes = {'x': (1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, 1)}
    rads = rots / 180 * np.pi
    mats = []
    for i in range(3):
        axis = torch.tensor(unit_axes[order[i]], device=rots.device, dtype=rads.dtype)
        mats.append(aa2mat(axis * rads[..., i].unsqueeze(-1)))
    return mats[0] @ (mats[1] @ mats[2])


def aa2mat(rots):
    """
    Convert angle-axis rotations to rotation matrices.

    :param rots: angle-axis rotations of shape (..., 3)
    :return: rotation matrices of shape (..., 3, 3)
    """
    return quat2mat(aa2quat(rots))


def mat2quat(R) -> torch.Tensor:
    """
    Convert rotation matrices to (w, x, y, z) quaternions.

    Adapted from https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py.
    Uses Shepperd's method for numerical stability. Per matrix, the diagonal entries
    decide which component is recovered with a square root. The other three components
    are then obtained by division, which avoids dividing by a near-zero value.

    :param R: rotation matrices of shape (..., 3, 3); must be orthonormal
    :return: quaternions of shape (..., 4). The overall sign is not normalized.
    """
    # Each *2 term equals 4 * q_*^2, so sqrt gives 2|q_*|; the final / 2 undoes the factor.
    w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
    x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
    y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
    z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])

    # Off-diagonal sums/differences equal 4 * (product of two components).
    yz = (R[..., 1, 2] + R[..., 2, 1])
    xz = (R[..., 2, 0] + R[..., 0, 2])
    xy = (R[..., 0, 1] + R[..., 1, 0])
    wx = (R[..., 2, 1] - R[..., 1, 2])
    wy = (R[..., 0, 2] - R[..., 2, 0])
    wz = (R[..., 1, 0] - R[..., 0, 1])

    w = torch.empty_like(x2)
    x = torch.empty_like(x2)
    y = torch.empty_like(x2)
    z = torch.empty_like(x2)

    neg_zz = R[..., 2, 2] < 0
    flag_a = neg_zz & (R[..., 0, 0] > R[..., 1, 1])
    flag_b = neg_zz & (R[..., 0, 0] <= R[..., 1, 1])
    flag_c = ~neg_zz & (R[..., 0, 0] < -R[..., 1, 1])
    flag_d = ~neg_zz & (R[..., 0, 0] >= -R[..., 1, 1])

    # Case A: recover x first.
    x[flag_a] = torch.sqrt(x2[flag_a])
    w[flag_a] = wx[flag_a] / x[flag_a]
    y[flag_a] = xy[flag_a] / x[flag_a]
    z[flag_a] = xz[flag_a] / x[flag_a]

    # Case B: recover y first.
    y[flag_b] = torch.sqrt(y2[flag_b])
    w[flag_b] = wy[flag_b] / y[flag_b]
    x[flag_b] = xy[flag_b] / y[flag_b]
    z[flag_b] = yz[flag_b] / y[flag_b]

    # Case C: recover z first.
    z[flag_c] = torch.sqrt(z2[flag_c])
    w[flag_c] = wz[flag_c] / z[flag_c]
    x[flag_c] = xz[flag_c] / z[flag_c]
    y[flag_c] = yz[flag_c] / z[flag_c]

    # Case D: recover w first.
    w[flag_d] = torch.sqrt(w2[flag_d])
    x[flag_d] = wx[flag_d] / w[flag_d]
    y[flag_d] = wy[flag_d] / w[flag_d]
    z[flag_d] = wz[flag_d] / w[flag_d]

    return torch.stack((w, x, y, z), dim=-1) / 2


def quat2repr6d(quat):
    """
    Convert (w, x, y, z) quaternions to the 6D rotation representation.

    :param quat: quaternions of shape (..., 4)
    :return: the first two rows of the rotation matrix, flattened to shape (..., 6)
    """
    mat = quat2mat(quat)
    res = mat[..., :2, :]
    return res.reshape(res.shape[:-2] + (6, ))


def repr6d2mat(repr):
    """
    Convert the 6D rotation representation to rotation matrices via Gram-Schmidt.

    :param repr: tensor of shape (..., 6) holding two (not necessarily orthonormal) rows
    :return: rotation matrices of shape (..., 3, 3) whose rows are x, y, z
    """
    x = repr[..., :3]
    y = repr[..., 3:]
    x = x / x.norm(dim=-1, keepdim=True)
    # dim=-1 is required: without it torch.cross uses the FIRST dim of size 3,
    # which is wrong whenever a leading batch dimension happens to equal 3.
    z = torch.cross(x, y, dim=-1)
    z = z / z.norm(dim=-1, keepdim=True)
    y = torch.cross(z, x, dim=-1)
    return torch.stack((x, y, z), dim=-2)


def repr6d2quat(repr) -> torch.Tensor:
    """
    Convert the 6D rotation representation to (w, x, y, z) quaternions.

    :param repr: tensor of shape (..., 6)
    :return: quaternions of shape (..., 4)
    """
    return mat2quat(repr6d2mat(repr))


def inv_affine(mat):
    """
    Calculate the inverse of any (invertible) affine transformation.

    :param mat: affine transforms [A | t] of shape (..., 3, 4)
    :return: the inverse transforms, shape (..., 3, 4)
    """
    # Pad to homogeneous 4x4 with a [0, 0, 0, 1] row on the same device/dtype as the input.
    bottom = torch.zeros(mat.shape[:-2] + (1, 4), device=mat.device, dtype=mat.dtype)
    bottom[..., 3] = 1
    square = torch.cat((mat, bottom), dim=-2)
    return torch.inverse(square)[..., :3, :]


def inv_rigid_affine(mat):
    """
    Calculate the inverse of a rigid affine transformation [R | t] -> [R^T | -R^T t].

    :param mat: rigid transforms of shape (..., 3, 4) with an orthonormal rotation part
    :return: the inverse transforms, shape (..., 3, 4)
    """
    res = mat.clone()
    res[..., :3] = mat[..., :3].transpose(-2, -1)
    res[..., 3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
    return res


def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
    """
    Sample random angle-axis poses.

    :param batch_size: number of poses to sample
    :param device: torch device of the result
    :param uniform: if True, angles are drawn from U[0, pi); otherwise from
        N(0, (pi / 6 * factor)^2), clamped to [-pi, pi]
    :param factor: scale of the Gaussian angle standard deviation
    :param root_rot: if False, the first joint's rotation is set to zero
    :param n_bone: number of joints, 24 when None
    :param ee: optional joint indices to randomize; the other joints stay zero. When
        root_rot is True, joint 0 is randomized too. The caller's list is not modified.
    :return: tensor of shape (batch_size, n_bone * 3)
    """
    if n_bone is None:
        n_bone = 24
    n_total = n_bone
    if ee is not None:
        ee = list(ee)  # copy: never mutate the caller's list
        if root_rot:
            ee.append(0)
        n_bone = len(ee)

    axis = torch.randn((batch_size, n_bone, 3), device=device)
    axis = axis / axis.norm(dim=-1, keepdim=True)
    if uniform:
        angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi
    else:
        angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor
        angle = angle.clamp(-np.pi, np.pi)
    poses = axis * angle

    if ee is not None:
        res = torch.zeros((batch_size, n_total, 3), device=device)
        for i, joint in enumerate(ee):
            res[:, joint] = poses[:, i]
        poses = res

    poses = poses.reshape(batch_size, -1)
    if not root_rot:
        poses[..., :3] = 0
    return poses


def slerp(l, r, t, unit=True):
    """
    Spherical linear interpolation between vectors.

    :param l: start vectors, shape = (*, n)
    :param r: end vectors, shape = (*, n)
    :param t: interpolation weights, shape = (*); 0 yields the (normalized) start
    :param unit: whether l and r are already unit vectors; if False they are normalized
    :return: interpolated vectors, shape = (*, n)
    """
    eps = 1e-8
    if not unit:
        l_n = l / torch.norm(l, dim=-1, keepdim=True)
        r_n = r / torch.norm(r, dim=-1, keepdim=True)
    else:
        l_n = l
        r_n = r
    omega = torch.acos((l_n * r_n).sum(dim=-1).clamp(-1, 1))
    dom = torch.sin(omega)

    res = torch.empty_like(l_n)

    # Nearly (anti)parallel inputs: sin(omega) ~ 0, fall back to linear interpolation.
    small = dom < eps
    t_small = t[small].unsqueeze(-1)
    res[small] = (1 - t_small) * l_n[small] + t_small * r_n[small]

    big = ~small
    t_big = t[big]
    d_big = dom[big]
    va = torch.sin((1 - t_big) * omega[big]) / d_big
    vb = torch.sin(t_big * omega[big]) / d_big
    res[big] = va.unsqueeze(-1) * l_n[big] + vb.unsqueeze(-1) * r_n[big]
    return res


def slerp_quat(l, r, t):
    """
    Slerp for unit quaternions along the shorter arc.

    :param l: (*, 4) unit quaternion
    :param r: (*, 4) unit quaternion
    :param t: (*) scalar between 0 and 1, broadcastable to l.shape[:-1]
    :return: (*, 4) interpolated quaternions
    """
    t = t.expand(l.shape[:-1])
    same_side = (l * r).sum(dim=-1) >= 0
    res = torch.empty_like(l)
    res[same_side] = slerp(l[same_side], r[same_side], t[same_side])
    # -l is the same rotation as l; using it takes the shorter path to r.
    other = ~same_side
    res[other] = slerp(-l[other], r[other], t[other])
    return res


def interpolate_6d(input, size):
    """
    Resample a sequence of 6D rotations along time using quaternion slerp.

    :param input: (batch_size, n_channels, length); n_channels must be a multiple of 6
    :param size: required output size for temporal axis
    :return: (batch_size, n_channels, size)
    """
    batch = input.shape[0]
    length = input.shape[-1]
    input = input.reshape((batch, -1, 6, length))
    input = input.permute(0, 1, 3, 2)  # (batch_size, n_joint, length, 6)
    input_q = repr6d2quat(input)

    # Output sample i maps to source position i / size * (length - 1).
    idx = torch.arange(size, device=input_q.device, dtype=torch.float) / size * (length - 1)
    idx_l = torch.floor(idx)
    t = idx - idx_l
    idx_l = idx_l.long()
    # Clamp so a single-frame input does not index past its end.
    idx_r = (idx_l + 1).clamp(max=length - 1)
    t = t.reshape((1, 1, -1))

    res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)
    res = quat2repr6d(res_q)  # shape = (batch_size, n_joint, size, 6)
    res = res.permute(0, 1, 3, 2)
    return res.reshape((batch, -1, size))