DveloperY0115's picture
init repo
801501a
raw
history blame
23.4 kB
# from chamferdist import ChamferDistance
from ..custom_types import *
from ..constants import EPSILON
from functools import reduce
import igl
# import trimesh
from ..custom_types import T_Mesh, TS
def scale_all(*values: T):
# mean_std = [(val.mean(), val.std()) for val in values]
# values = [val.clamp(scales[0] - scales[1] * 3, scales[0] + scales[1] * 3) for val,scales in zip(values, mean_std)]
max_val = max([val.max().item() for val in values])
min_val = min([val.min().item() for val in values])
scale = max_val - min_val
values = [(val - min_val) / scale for val in values]
if len(values) == 1:
return values[0]
return values
def get_faces_normals(mesh: Union[T_Mesh, T]) -> T:
if type(mesh) is not T:
vs, faces = mesh
vs_faces = vs[faces]
else:
vs_faces = mesh
if vs_faces.shape[-1] == 2:
vs_faces = torch.cat(
(vs_faces, torch.zeros(*vs_faces.shape[:2], 1, dtype=vs_faces.dtype, device=vs_faces.device)), dim=2)
face_normals = torch.cross(vs_faces[:, 1, :] - vs_faces[:, 0, :], vs_faces[:, 2, :] - vs_faces[:, 1, :])
return face_normals
def compute_face_areas(mesh: Union[T_Mesh, T]) -> TS:
face_normals = get_faces_normals(mesh)
face_areas = torch.norm(face_normals, p=2, dim=1)
face_areas_ = face_areas.clone()
face_areas_[torch.eq(face_areas_, 0)] = 1
face_normals = face_normals / face_areas_[:, None]
face_areas = 0.5 * face_areas
return face_areas, face_normals
def check_sign_area(*meshes: T_Mesh) -> bool:
for mesh in meshes:
face_normals = get_faces_normals(mesh)
if not face_normals[:, 2].gt(0).all():
return False
return True
def to_numpy(*tensors: T) -> ARRAYS:
params = [param.detach().cpu().numpy() if type(param) is T else param for param in tensors]
return params
def create_mapper(mask: T) -> T:
mapper = torch.zeros(mask.shape[0], dtype=torch.int64, device=mask.device) - 1
mapper[mask] = torch.arange(mask.sum().item(), device=mask.device)
return mapper
def mesh_center(mesh: T_Mesh):
return mesh[0].mean(0)
def get_center(vs) -> T:
max_vals = vs.max(0)[0]
min_vals = vs.min(0)[0]
center = (max_vals + min_vals) / 2
return center
def to_center(vs):
vs -= get_center(vs)[None, :]
return vs
def scale_by_ref(mesh, ref_mesh, in_place=True, scale=1.):
vs, _ = ref_mesh
if not in_place:
vs = vs.clone()
center = get_center(vs)
vs -= center[None, :]
scale = scale / vs.norm(2, dim=1).max()
vs = (mesh[0] - center[None, :]) * scale
return vs, mesh[1]
def to_unit_sphere(mesh: T_Mesh, in_place: bool = True, scale=1.) -> T_Mesh:
vs, faces = mesh
if not in_place:
vs = vs.clone()
vs = to_center(vs)
norm = vs.norm(2, dim=1).max()
vs *= scale * norm ** -1
return vs, faces
def scale_from_ref(mesh: T_Mesh, center: T, scale: float, in_place: bool = True) -> T_Mesh:
vs, faces = mesh
if not in_place:
vs = vs.clone()
vs -= center[None, :]
vs *= scale
return vs, faces
def to_unit_cube(*meshes: T_Mesh_T, scale=1, in_place: bool = True) -> Tuple[Union[T_Mesh_T, Tuple[T_Mesh_T, ...]], Tuple[T, float]]:
remove_me = 0
meshes = [(mesh, remove_me) if type(mesh) is T else mesh for mesh in meshes]
vs, faces = meshes[0]
max_vals = vs.max(0)[0]
min_vals = vs.min(0)[0]
max_range = (max_vals - min_vals).max() / 2
center = (max_vals + min_vals) / 2
meshes_ = []
scale = float(scale / max_range)
for mesh in meshes:
vs_, faces_ = scale_from_ref(mesh, center, scale)
meshes_.append(vs_ if faces_ is remove_me else (vs_, faces_))
if len(meshes_) == 1:
meshes_ = meshes_[0]
return meshes_, (center, scale)
# # in place
# def to_unit_edge(*meshes: T_Mesh) -> Tuple[Union[T_Mesh, Tuple[T_Mesh, ...]], Tuple[T, float]]:
# ref = meshes[0]
# center = ref[0].mean(0)
# ratio = edge_lengths(ref).mean().item()
# for mesh in meshes:
# vs, _ = mesh
# vs -= center[None, :].to(vs.device)
# vs /= ratio
# if len(meshes) == 1:
# meshes = meshes[0]
# return meshes, (center, ratio)
def get_edges_ind(mesh: T_Mesh) -> T:
vs, faces = mesh
raw_edges = torch.cat([faces[:, [i, (i + 1) % 3]] for i in range(3)]).sort()
raw_edges = raw_edges[0].cpu().numpy()
edges = {(int(edge[0]), int(edge[1])) for edge in raw_edges}
edges = torch.tensor(list(edges), dtype=torch.int64, device=faces.device)
return edges
def edge_lengths(mesh: T_Mesh, edges_ind: TN = None) -> T:
vs, faces = mesh
if edges_ind is None:
edges_ind = get_edges_ind(mesh)
edges = vs[edges_ind]
return torch.norm(edges[:, 0] - edges[:, 1], 2, dim=1)
# in place
def to_unit_edge(*meshes: T_Mesh) -> Tuple[Union[T_Mesh, Tuple[T_Mesh, ...]], Tuple[T, float]]:
ref = meshes[0]
center = ref[0].mean(0)
ratio = edge_lengths(ref).mean().item()
for mesh in meshes:
vs, _ = mesh
vs -= center[None, :].to(vs.device)
vs /= ratio
if len(meshes) == 1:
meshes = meshes[0]
return meshes, (center, ratio)
def to(tensors, device: D) -> Union[T_Mesh, TS, T]:
out = []
for tensor in tensors:
if type(tensor) is T:
out.append(tensor.to(device, ))
elif type(tensor) is tuple or type(tensors) is List:
out.append(to(list(tensor), device))
else:
out.append(tensor)
if len(tensors) == 1:
return out[0]
else:
return tuple(out)
def clone(*tensors: Union[T, TS]) -> Union[TS, T_Mesh]:
out = []
for t in tensors:
if type(t) is T:
out.append(t.clone())
else:
out.append(clone(*t))
return out
def get_box(w: float, h: float, d: float) -> T_Mesh:
vs = [[0, 0, 0], [w, 0, 0], [0, d, 0], [w, d, 0],
[0, 0, h], [w, 0, h], [0, d, h], [w, d, h]]
faces = [[0, 2, 1], [1, 2, 3], [4, 5, 6], [5, 7, 6],
[0, 1, 5], [0, 5, 4], [2, 6, 7], [3, 2, 7],
[1, 3, 5], [3, 7, 5], [0, 4, 2], [2, 4, 6]]
return torch.tensor(vs, dtype=torch.float32), torch.tensor(faces, dtype=torch.int64)
def normalize(t: T):
t = t / t.norm(2, dim=1)[:, None]
return t
def interpolate_vs(mesh: T_Mesh, faces_inds: T, weights: T) -> T:
vs = mesh[0][mesh[1][faces_inds]]
vs = vs * weights[:, :, None]
return vs.sum(1)
def sample_uvw(shape, device: D):
u, v = torch.rand(*shape, device=device), torch.rand(*shape, device=device)
mask = (u + v).gt(1)
u[mask], v[mask] = -u[mask] + 1, -v[mask] + 1
w = -u - v + 1
uvw = torch.stack([u, v, w], dim=len(shape))
return uvw
def get_sampled_fe(fe: T, mesh: T_Mesh, face_ids: T, uvw: TN) -> T:
# to_squeeze =
if fe.dim() == 1:
fe = fe.unsqueeze(1)
if uvw is None:
fe_iner = fe[face_ids]
else:
vs_ids = mesh[1][face_ids]
fe_unrolled = fe[vs_ids]
fe_iner = torch.einsum('sad,sa->sd', fe_unrolled, uvw)
# if to_squeeze:
# fe_iner = fe_iner.squeeze_(1)
return fe_iner
def sample_on_faces(mesh: T_Mesh, num_samples: int) -> TS:
vs, faces = mesh
uvw = sample_uvw([faces.shape[0], num_samples], vs.device)
samples = torch.einsum('fad,fna->fnd', vs[faces], uvw)
return samples, uvw
class SampleBy(Enum):
AREAS = 0
FACES = 1
HYB = 2
def sample_on_mesh(mesh: T_Mesh, num_samples: int, face_areas: TN = None,
sample_s: SampleBy = SampleBy.HYB) -> TNS:
vs, faces = mesh
if faces is None: # sample from pc
uvw = None
if vs.shape[0] < num_samples:
chosen_faces_inds = torch.arange(vs.shape[0])
else:
chosen_faces_inds = torch.argsort(torch.rand(vs.shape[0]))[:num_samples]
samples = vs[chosen_faces_inds]
else:
weighted_p = []
if sample_s == SampleBy.AREAS or sample_s == SampleBy.HYB:
if face_areas is None:
face_areas, _ = compute_face_areas(mesh)
face_areas[torch.isnan(face_areas)] = 0
weighted_p.append(face_areas / face_areas.sum())
if sample_s == SampleBy.FACES or sample_s == SampleBy.HYB:
weighted_p.append(torch.ones(mesh[1].shape[0], device=mesh[0].device))
chosen_faces_inds = [torch.multinomial(weights, num_samples // len(weighted_p), replacement=True) for weights in weighted_p]
if sample_s == SampleBy.HYB:
chosen_faces_inds = torch.cat(chosen_faces_inds, dim=0)
chosen_faces = faces[chosen_faces_inds]
uvw = sample_uvw([num_samples], vs.device)
samples = torch.einsum('sf,sfd->sd', uvw, vs[chosen_faces])
return samples, chosen_faces_inds, uvw
def get_samples(mesh: T_Mesh, num_samples: int, sample_s: SampleBy, *features: T) -> Union[T, TS]:
samples, face_ids, uvw = sample_on_mesh(mesh, num_samples, sample_s=sample_s)
if len(features) > 0:
samples = [samples] + [get_sampled_fe(fe, mesh, face_ids, uvw) for fe in features]
return samples, face_ids, uvw
def find_barycentric(vs: T, triangles: T) -> T:
def compute_barycentric(ind):
triangles[:, ind] = vs
alpha = compute_face_areas(triangles)[0] / areas
triangles[:, ind] = recover[:, ind]
return alpha
device, dtype = vs.device, vs.dtype
vs = vs.to(device, dtype=torch.float64)
triangles = triangles.to(device, dtype=torch.float64)
areas, _ = compute_face_areas(triangles)
recover = triangles.clone()
barycentric = [compute_barycentric(i) for i in range(3)]
barycentric = torch.stack(barycentric, dim=1)
# assert barycentric.sum(1).max().item() <= 1 + EPSILON
return barycentric.to(device, dtype=dtype)
def from_barycentric(mesh: Union[T_Mesh, T], face_ids: T, weights: T) -> T:
if type(mesh) is not T:
triangles: T = mesh[0][mesh[1]]
else:
triangles: T = mesh
to_squeeze = weights.dim() == 1
if to_squeeze:
weights = weights.unsqueeze(0)
face_ids = face_ids.unsqueeze(0)
vs = torch.einsum('nad,na->nd', triangles[face_ids], weights)
if to_squeeze:
vs = vs.squeeze(0)
return vs
def check_circle_angles(mesh: T_Mesh, center_ind: int, select: T) -> bool:
vs, _ = mesh
all_vecs = vs[select] - vs[center_ind][None, :]
all_vecs = all_vecs / all_vecs.norm(2, 1)[:, None]
all_vecs = torch.cat([all_vecs, all_vecs[:1]], dim=0)
all_cos = torch.einsum('nd,nd->n', all_vecs[1:], all_vecs[:-1])
all_angles = torch.acos_(all_cos)
all_angles = all_angles.sum()
return (all_angles - 2 * np.pi).abs() < EPSILON
def vs_over_triangle(vs_mid: T, triangle: T, normals=None) -> T:
if vs_mid.dim() == 1:
vs_mid = vs_mid.unsqueeze(0)
triangle = triangle.unsqueeze(0)
if normals is None:
_, normals = compute_face_areas(triangle)
select = torch.arange(3)
d_vs = vs_mid[:, None, :] - triangle
d_f = triangle[:, select] - triangle[:, (select + 1) % 3]
all_cross = torch.cross(d_vs, d_f, dim=2)
all_dots = torch.einsum('nd,nad->na', normals, all_cross)
is_over = all_dots.ge(0).long().sum(1).eq(3)
return is_over
def f2v(num_faces: int, genus: int = 0) -> int: # assuming there are not boundaries
return num_faces // 2 + (1 - genus) * 2
def v2f(num_vs: int, genus: int = 0) -> int: # assuming there are not boundaries
return 2 * num_vs - 4 + 4 * genus
def get_dist_mat(a: T, b: T, batch_size: int = 1000, sqrt: bool = False) -> T:
"""
:param a:
:param b:
:param batch_size: Limit batches per distance calculation to avoid out-of-mem
:return:
"""
iters = a.shape[0] // batch_size
dist_list = [((a[i * batch_size: (i + 1) * batch_size, None, :] - b[None, :, :]) ** 2).sum(-1)
for i in range(iters + 1)]
all_dist: T = torch.cat(dist_list, dim=0)
if sqrt:
all_dist = all_dist.sqrt_()
return all_dist
def naive_knn(k: int, dist_mat: T, is_biknn=True):
"""
:param k:
:param dist_mat:
:param is_biknn: When false, calcluates only closest element in a per element of b.
When true, calcluates only closest element in a <--> b both ways.
:param batch_size: Limit batches per distance calculation to avoid out-of-mem
:return:
"""
_, close_to_b = dist_mat.topk(k, 0, largest=False)
if is_biknn:
_, close_to_a = dist_mat.topk(k, 1, largest=False)
return close_to_a, close_to_b.t()
return close_to_b.t()
def chamfer_igl():
igl.cha
def simple_chamfer(a: T, b: T, normals_a=None, normals_b=None, dist_mat: Optional[T] = None) -> Union[T, TS]:
def one_direction(fixed: T, search: T, n_f, n_s, closest_id) -> TS:
min_dist = (fixed - search[closest_id]).norm(2, 1).mean(0)
if n_f is not None:
normals_dist = -torch.einsum('nd,nd->n', n_f, n_s[closest_id]).mean(0)
else:
normals_dist = 0
return min_dist, normals_dist
if dist_mat is None:
dist_mat = get_dist_mat(a, b)
close_to_a, close_to_b = naive_knn(1, dist_mat)
dist_a, dist_a_n = one_direction(a, b, normals_a, normals_b, close_to_a.flatten())
dist_b, dist_b_n = one_direction(b, a, normals_b, normals_a, close_to_b.flatten())
if normals_a is None:
return dist_a + dist_b
return dist_a + dist_b, dist_a_n + dist_b_n
def is_quad(mesh: Union[T_Mesh, Tuple[T, List[List[int]]]]) -> bool:
if type(mesh) is T:
return False
if type(mesh[1]) is T:
return False
else:
faces: List[List[int]] = mesh[1]
for f in faces:
if len(f) == 4:
return True
return False
def align_mesh(mesh: T_Mesh, ref_vs: T) -> T_Mesh:
vs, faces = mesh
dist_mat = get_dist_mat(vs, ref_vs)
dist, mapping_id = dist_mat.min(1)
vs_select = dist_mat.min(0)[1]
if mapping_id.unique().shape[0] != vs.shape[0]:
print('\n\033[91mWarning, alignment is not bijective\033[0m')
vs_aligned = vs[vs_select]
faces_aligned = mapping_id[faces]
return vs_aligned, faces_aligned
# def triangulate_mesh(mesh: Union[T_Mesh, Tuple[T, List[List[int_b]]]]) -> Tuple[T_Mesh, Optional[T]]:
#
# def check_triangle(triangle: List[int_b]) -> bool:
# e_1: T = vs[triangle[1]] - vs[triangle[0]]
# e_2: T = vs[triangle[2]] - vs[triangle[0]]
# angle = (e_1 * e_2).sum() / (e_1.norm(2) * e_2.norm(2))
# return angle.abs().item() < 1 - 1e-6
#
# def add_triangle(face_: List[int_b]):
# triangle = None
# for i in range(len(face_)):
# triangle = [face_[i], face_[(i + 1) % len(face_)], face_[(i + 2) % len(face_)]]
# if check_triangle(triangle):
# face_ = [f for j, f in enumerate(face_) if j != (i + 1) % len(face_)]
# break
# assert triangle is not None
# faces_.append(triangle)
# face_twin.append(-1)
# return face_
#
# if not is_quad(mesh):
# return mesh, None
#
# vs, faces = mesh
# faces_ = []
# face_twin = []
# for face in faces:
# if len(face) == 3:
# faces_.append(face)
# face_twin.append(-1)
# else:
# while len(face) > 4:
# face = add_triangle(face)
# new_faces = [[face[0], face[1], face[2]], [face[0], face[2], face[3]]]
# if not check_triangle(new_faces[0]) or not check_triangle(new_faces[1]):
# new_faces = [[face[0], face[1], face[3]], [face[1], face[2], face[3]]]
# assert check_triangle(new_faces[0]) and check_triangle(new_faces[1])
# faces_.extend(new_faces)
# face_twin.extend([len(faces_) - 1, len(faces_) - 2])
# # else:
# # raise ValueError(f'mesh with {len(face)} edges polygons is not supported')
# faces_ = torch.tensor(faces_, device=vs.device, dtype=torch.int64)
# face_twin = torch.tensor(face_twin, device=vs.device, dtype=torch.int64)
# return (vs, faces_), face_twin
def triangulate_mesh(mesh: Union[T_Mesh, Tuple[T, List[List[int]]]]) -> Tuple[T_Mesh, Optional[T]]:
def get_skinny(faces_) -> T:
vs_faces = vs[faces_]
areas = compute_face_areas(vs_faces)[0]
edges = reduce(
lambda a, b: a + b,
map(
lambda i: ((vs_faces[:, i] - vs_faces[:, (i + 1) % 3]) ** 2).sum(1),
range(3)
)
)
skinny_value = np.sqrt(48) * areas / edges
return skinny_value
if not is_quad(mesh):
return mesh, None
vs, faces = mesh
device = vs.device
faces_keep = torch.tensor([face for face in faces if len(face) == 3], dtype=torch.int64, device=device)
faces_quads = torch.tensor([face for face in faces if len(face) != 3], dtype=torch.int64, device=device)
faces_tris_a, faces_tris_b = faces_quads[:, :3], faces_quads[:, torch.tensor([0, 2, 3], dtype=torch.int64)]
faces_tris_c, faces_tris_d = faces_quads[:, 1:], faces_quads[:, torch.tensor([0, 1, 3], dtype=torch.int64)]
skinny = [get_skinny(f) for f in (faces_tris_a, faces_tris_b, faces_tris_c, faces_tris_d)]
skinny_ab, skinny_cd = torch.stack((skinny[0], skinny[1]), 1), torch.stack((skinny[2], skinny[3]), 1)
to_flip = skinny_ab.min(1)[0].lt(skinny_cd.min(1)[0])
faces_tris_a[to_flip], faces_tris_b[to_flip] = faces_tris_c[to_flip], faces_tris_d[to_flip]
faces_tris = torch.cat((faces_tris_a, faces_tris_b, faces_keep), dim=0)
face_twin = torch.arange(faces_tris_a.shape[0], device=device)
face_twin = torch.cat((face_twin + faces_tris_a.shape[0], face_twin,
-torch.ones(faces_keep.shape[0], device=device, dtype=torch.int64)))
return (vs, faces_tris), face_twin
def igl_prepare(*dtypes):
def decoder(func):
def wrapper(*args, **kwargs):
mesh = args[0]
device, dtype = mesh[0].device, mesh[0].dtype
vs, faces = to_numpy(*mesh)
result = func((vs, faces), *args[1:], **kwargs)
return to_torch(result, device)
if len(dtypes) == 0:
to_torch = to_torch_empty
elif len(dtypes) == 1:
to_torch = to_torch_multi
else:
to_torch = to_torch_singe
return wrapper
def to_torch_singe(result, device):
return torch.from_numpy(result).to(device, dtype=dtypes[0])
def to_torch_multi(result, device):
return [torch.from_numpy(r).to(device, dtype=dtype) for r, dtype in zip(result, dtypes)]
def to_torch_empty(result, device):
return result
return decoder
@igl_prepare(torch.float32, torch.int64)
def decimate_igl(mesh, num_faces: int):
if mesh[1].shape[0] <= num_faces:
return mesh
vs, faces, _ = igl.remove_duplicates(*mesh, 1e-8)
return igl.decimate(vs, faces, num_faces)[1:3]
@igl_prepare(torch.float32)
def gaussian_curvature(mesh: T_Mesh) -> T:
gc = igl.gaussian_curvature(*mesh)
return gc
@igl_prepare(torch.float32)
def per_vertex_normals_igl(mesh: T_Mesh, weighting: int = 0) -> T:
normals = igl.per_vertex_normals(*mesh, weighting)
return normals
@igl_prepare(torch.float32, torch.int64)
def remove_duplicate_vertices(mesh: T_Mesh, epsilon=1e-7) -> T_Mesh:
vs, _, _, faces = igl.remove_duplicate_vertices(*mesh, epsilon)
return vs, faces
@igl_prepare(torch.float32)
def winding_number_igl(mesh: T_Mesh, query: T) -> T:
query = query.cpu().numpy()
return igl.fast_winding_number_for_meshes(*mesh, query)
@igl_prepare(torch.float32, torch.float32, torch.float32, torch.float32)
def principal_curvature(mesh: T_Mesh) -> TS:
out = igl.principal_curvature(*mesh)
min_dir, max_dir, min_val, max_val = out
return min_dir, max_dir, min_val, max_val
# def get_inside_outside(points: T, mesh: T_Mesh) -> T:
# device = points.device
# points = points.numpy()
# vs, faces = mesh[0].numpy(), mesh[1].numpy()
# winding_numbers = igl.fast_winding_number_for_meshes(vs, faces, points)
# winding_numbers = torch.from_numpy(winding_numbers)
# inside_outside = winding_numbers.lt(.5).float() * 2 - 1
# return inside_outside.to(device)
@igl_prepare()
def get_inside_outside(mesh: T_Mesh, points: ARRAY) -> ARRAY:
batch_size = 1000000
labels = []
num_batch = points.shape[0] // batch_size + 1
for i in range(points.shape[0] // batch_size + 1):
if i == num_batch - 1:
pts_in = points[batch_size * i:]
else:
pts_in = points[batch_size * i: batch_size * (i + 1)]
w = igl.winding_number(*mesh, pts_in)
w = np.less_equal(w, .9)
labels.append(w)
return np.concatenate(labels, axis=0)
@igl_prepare()
def get_fast_inside_outside(mesh: T_Mesh, points: ARRAY):
batch_size = 1000000
labels = []
num_batch = points.shape[0] // batch_size + 1
for i in range(points.shape[0] // batch_size + 1):
if i == num_batch - 1:
pts_in = points[batch_size * i:]
else:
pts_in = points[batch_size * i: batch_size * (i + 1)]
w = igl.fast_winding_number_for_meshes(*mesh, pts_in)
w = np.less_equal(w, .9)
labels.append(w)
return np.concatenate(labels, axis=0)
# def get_inside_outside_trimes(mesh: T_Mesh, points: T) -> Optional[ARRAY]:
# mesh = mesh_utils.to(mesh, points.device)
# mesh = make_data.trimmesh(mesh)
# batch_size = 1000000
# num_batch = points.shape[0] // batch_size + 1
# labels = []
# # try:
# for i in range(points.shape[0] // batch_size + 1):
# if i == num_batch - 1:
# pts_in = points[batch_size * i:]
# else:
# pts_in = points[batch_size * i: batch_size * (i + 1)]
# label = make_data.sdfmeshfun(pts_in, mesh).lt(0)
# label = label.cpu()
# labels.append(label.numpy())
# # except RuntimeError:
# # return None
# return np.concatenate(labels, axis=0)
@igl_prepare(torch.float32, torch.int64)
def trimesh_smooth(mesh, lamb=0.5, iterations=10):
mesh = trimesh.Trimesh(vertices=mesh[0], faces=mesh[1])
# trimesh.smoothing.filter_mut_dif_laplacian(mesh, lamb=lamb, iterations=iterations, volume_constraint=True,
# laplacian_operator=None)
trimesh.smoothing.filter_humphrey(mesh, alpha=0.1, beta=lamb, iterations=iterations, laplacian_operator=None)
return V(mesh.vertices), V(mesh.faces)
def split_by_seg(mesh: T_Mesh, seg: TS) -> TS:
# faces_split, vs_split = {}, {}
labels_all = []
vs, faces = mesh
vs_mid_faces = vs[faces].mean(1)
for vs_ in (vs, vs_mid_faces):
chamfer_distance_a, chamfer_distance_a_nn = ChamferDistance()(vs_.unsqueeze(0), seg[0].unsqueeze(0), bidirectional=False)
# nn_sanity = slow_nn(vs_mid_faces, seg[0])
labels_all.append(seg[1][chamfer_distance_a_nn.flatten()])
# for i in range(seg[1].min(), seg[1].max() + 1):
# mask = labels.eq(i)
# if mask.any():
# split[i] = faces[mask]
# else:
# faces_split[i] = None
return labels_all