Skip to content

Instantly share code, notes, and snippets.

Last active October 25, 2024 12:58
Show Gist options
  • Save dendenxu/ee5008acb5607195582e7983a384e644 to your computer and use it in GitHub Desktop.
Save dendenxu/ee5008acb5607195582e7983a384e644 to your computer and use it in GitHub Desktop.
Hierarchical Winding Distance Remesh: implements a winding-distance field with a hierarchical (coarse-to-fine) formulation, plus batched PyTorch implementations of the Möller–Trumbore ray–triangle intersection algorithm (ray stabbing) and the Generalized Winding Number for determining whether points lie inside a mesh:
import torch
import mcubes
import trimesh
import numpy as np
from pytorch3d.ops.knn import knn_points
from tqdm import tqdm
from functools import reduce
from torch_scatter import scatter
from pytorch3d.structures import Meshes
from typing import Callable, Tuple, Union
from largesteps.optimize import AdamUniform
from largesteps.geometry import compute_matrix
from largesteps.parameterize import from_differential, to_differential
from lib.utils.base_utils import DotDict, make_dotdict
from lib.utils.sample_utils import get_voxel_grid_and_update_bounds, sample_closest_points
from lib.utils.net_utils import get_bounds, linear_gather, multi_gather, multi_gather_tris, normalize, unmerge_faces
from bvh_distance_queries import BVH
from typing import Mapping, TypeVar, Union
# Generic type variables so DotDict can be parameterized like Mapping[KT, VT]
KT, VT = TypeVar("KT"), TypeVar("VT")  # key type, value type
def make_dotdict(*args, **kwargs):
    """Construct a DotDict, forwarding every positional and keyword argument unchanged."""
    dotdict = DotDict(*args, **kwargs)
    return dotdict
class DotDict(dict, Mapping[KT, VT]):
    """
    A dictionary that supports dot notation as well as dictionary access notation.

    usage: d = make_dotdict() or d = make_dotdict({'val1': 'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']

    Nested mappings (anything with a ``keys`` attribute) given to the constructor are
    converted to DotDict recursively. A missing key raises an exception that is both a
    KeyError and an AttributeError, so ``d['x']``, ``d.x``, ``getattr(d, 'x', default)``
    and pickling all see the exception type they expect.
    """

    class _MissingKey(KeyError, AttributeError):
        """Missing-key error usable by both the mapping and the attribute protocol."""

    def update(self, dct=None, **kwargs):
        """
        Merge ``dct`` and ``kwargs`` into this dict (``kwargs`` win on conflicts).

        Values for keys that already exist are coerced to the type of the stored value;
        strings assigned to bool entries are parsed (only 'True' maps to True).
        The caller's ``dct`` is not modified.
        """
        incoming = {}
        if dct is not None:
            incoming.update(dct)
        incoming.update(kwargs)
        for k, v in incoming.items():
            if k not in self:
                continue
            target_type = type(self[k])
            if isinstance(v, target_type):
                continue
            if target_type == bool and isinstance(v, str):
                # NOTE: bool('False') would be True, so compare the text instead
                incoming[k] = v == 'True'
            else:
                incoming[k] = target_type(v)
        dict.update(self, incoming)

    def __hash__(self):
        # NOTE(review): this hashes a freshly created ``values()`` view rather than the
        # contents, so it is likely identity-based and not stable across calls — confirm
        # before using DotDict as a dict key or set member.
        return hash(''.join([str(self.values().__hash__())]))

    def __init__(self, dct=None, **kwargs):
        """Build from an optional mapping and/or keyword arguments (kwargs win on conflicts)."""
        super().__init__()
        merged = {}
        if dct is not None:
            merged.update(dct)
        merged.update(kwargs)
        for key, value in merged.items():
            if hasattr(value, 'keys'):
                value = make_dotdict(value)  # wrap nested mappings recursively
            self[key] = value

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError as e:
            # MARK: __getattr__ aliases this method, and attribute probing (e.g. pickling a
            # DotDict through multiprocessing looks up '__getstate__') needs an
            # AttributeError for missing names, otherwise pickling fails with
            # KeyError: '__getstate__'. _MissingKey is also a KeyError, so subscript
            # callers catching KeyError keep working.
            raise DotDict._MissingKey(*e.args) from e

    __getattr__ = __getitem__  # overridden dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def multi_indexing(index: torch.Tensor, shape: torch.Size, dim=-2):
    """
    Broadcast ``index`` so it can be fed to ``Tensor.gather`` on a tensor of ``shape``.

    Singleton axes are appended to ``index`` until it has as many dims as ``shape``;
    every axis is then expanded to the size in ``shape`` except ``dim``, which keeps
    the index's own length.
    """
    missing_dims = len(shape) - index.ndim
    index = index[(...,) + (None,) * missing_dims]
    target = list(shape)
    target[dim] = -1
    return index.expand(*target)
def multi_gather(values: torch.Tensor, index: torch.Tensor, dim=-2):
    """
    Batched gather along ``dim``.

    ``index`` shares the leading (batch) dims of ``values``; its last dim holds the
    positions to pick along ``dim``. Dims of ``values`` after ``dim`` are carried along.
    """
    gather_index = multi_indexing(index, values.shape, dim)
    return torch.gather(values, dim, gather_index)
def multi_gather_tris(v: torch.Tensor, f: torch.Tensor, dim=-2) -> torch.Tensor:
    """
    Gather per-vertex attributes into per-face corner attributes.

    v: (..., n_verts, *rest) vertex attributes, indexed along ``dim``
    f: (..., n_faces, 3) integer face indices; when ``v`` has exactly one more dim than
       ``f`` the faces are shared and broadcast over ``v``'s leading batch dim
    returns: (..., n_faces, 3, *rest), e.g. (B, F, 3, 3) for batched xyz vertices
    """
    if v.ndim == (f.ndim + 1):
        f = f[None].expand(v.shape[0], *f.shape)
    # Attribute dims that follow the vertex dim, kept in their original order
    # (a flip-based computation would reverse them when there is more than one).
    remainder = v.shape[dim % v.ndim + 1:]
    gathered = multi_gather(v, f.reshape(*f.shape[:-2], -1), dim=dim)  # (..., n_faces * 3, *rest)
    return gathered.view(*f.shape, *remainder)
def linear_indexing(index: torch.Tensor, shape: torch.Size, dim=0):
    """
    Turn a 1-D ``index`` into a gather/scatter index for a tensor of ``shape``.

    The index is placed on axis ``dim`` (negative values count from the end), with
    singleton axes before and after it, and then expanded over every other axis.
    """
    assert index.ndim == 1
    shape = list(shape)
    axis = dim if dim >= 0 else len(shape) + dim
    leading = (None,) * axis
    trailing = (None,) * (len(shape) - axis - 1)
    index = index[leading + (slice(None),) + trailing]
    shape[axis] = -1
    return index.expand(*shape)
def linear_gather(values: torch.Tensor, index: torch.Tensor, dim=0):
    """Pick the entries of ``values`` at the 1-D positions ``index`` along ``dim``."""
    broadcast_index = linear_indexing(index, values.shape, dim)
    return torch.gather(values, dim, broadcast_index)
def linear_scatter(target: torch.Tensor, index: torch.Tensor, values: torch.Tensor, dim=0):
    """Out-of-place scatter: copy ``target`` and write ``values`` at 1-D positions ``index`` along ``dim``."""
    broadcast_index = linear_indexing(index, values.shape, dim)
    return torch.scatter(target, dim, broadcast_index, values)
def linear_scatter_(target: torch.Tensor, index: torch.Tensor, values: torch.Tensor, dim=0):
    """In-place scatter: write ``values`` into ``target`` at 1-D positions ``index`` along ``dim`` and return ``target``."""
    broadcast_index = linear_indexing(index, values.shape, dim)
    return target.scatter_(dim, broadcast_index, values)
def cast_knn_points(src, ref, K=1):
    """
    K-nearest-neighbour query through pytorch3d ``knn_points``, with inputs cast to float32.

    Returns a DotDict with
      dists: Euclidean distances to the K neighbours (square root of knn_points' dists)
      idx:   indices into ``ref`` of the K neighbours (unsorted)
    """
    knn = knn_points(src.float(), ref.float(), K=K, return_nn=False, return_sorted=False)
    result = make_dotdict()
    result.dists = knn.dists.sqrt()  # knn_points reports squared distances
    result.idx = knn.idx
    return result
def sample_closest_points(src: torch.Tensor, ref: torch.Tensor, values: torch.Tensor = None):
    """
    Find, for every point in ``src``, its nearest neighbour in ``ref`` (K=1).

    src: (n_batch, n_points, 3) query points
    ref: (n_batch, n_ref, 3) reference points
    values: optional per-reference-point attributes, either shared across the batch
            ((n_ref, D) or (1, n_ref, D)) or given per batch element ((n_batch, n_ref, D))
    returns: distances (n_batch, n_points, 1) when ``values`` is None, otherwise
             (sampled values (n_batch, n_points, D), distances (n_batch, n_points, 1))
    """
    n_batch, n_points, _ = src.shape
    ret = cast_knn_points(src, ref, K=1)  # (n_batch, n_points, K)
    dists, vert_ids = ret.dists, ret.idx
    if values is None:
        return dists.view(n_batch, n_points, 1)
    if values.ndim == 3 and n_batch > 1 and values.shape[0] == n_batch:
        # knn indices are local to each batch element; offset them so they address the
        # right rows once the batch dim of ``values`` is flattened below
        offsets = torch.arange(n_batch, device=vert_ids.device)[:, None, None] * values.shape[1]
        vert_ids = vert_ids + offsets
    flat_values = values.reshape(-1, values.shape[-1])  # (n, D)
    sampled = flat_values[vert_ids]
    return sampled.view(n_batch, n_points, -1), dists.view(n_batch, n_points, 1)
def get_bounds(xyz, padding=0.005):  # 5mm padding by default
    """
    Padded axis-aligned bounding box of each point set.

    xyz: (n_batch, n_points, 3)
    padding: margin subtracted from the min corner and added to the max corner
    returns: (n_batch, 2, 3) holding [min corner, max corner] along dim 1
    """
    lower = xyz.min(dim=1).values  # (n_batch, 3)
    upper = xyz.max(dim=1).values
    lower -= padding
    upper += padding
    return torch.stack((lower, upper), dim=1)
# NOTE(review): the `def` header of this fragment is missing from the source; `bounds`
# and `voxel_size` appear to be its parameters — confirm against the original file.
# Intent (from the code): move the max corner so that (max - min) becomes a whole
# multiple of voxel_size; `bounds` is modified in place and returned.
# NOTE(review): `bounds[..., 1:]` / `bounds[..., :1]` slice the LAST axis. For
# (n_batch, 2, 3) bounds as returned by get_bounds, that is the xyz axis, not the
# min/max axis. The `# n_batch, 1, 3` comments suggest `bounds[:, 1:]` / `bounds[:, :1]`
# was intended — confirm.
diagonal = bounds[..., 1:] - bounds[..., :1] # n_batch, 1, 3
bounds[..., 1:] = bounds[..., :1] + torch.ceil(diagonal / voxel_size) * voxel_size # n_batch, 1, 3
return bounds
def get_voxel_grid_and_update_bounds(voxel_size: list, bounds: torch.Tensor):
    """
    Build a regular grid of sample points per batch element and return the bounds
    that the grid actually spans.

    The requested bounds are generally not divisible by the voxel size. The last sample
    along each axis therefore lands within half a voxel of the requested max corner,
    not exactly on it. Consumers that sample with align_corners=True must use the
    bounds of the produced grid, which is why the bounds are recomputed here.

    voxel_size: per-axis spacing, e.g. [0.005, 0.005, 0.005]
    bounds: (n_batch, 2, 3) requested [min corner, max corner]
    returns:
        pts: (n_batch, X, Y, Z, 3) grid points in 'ij' (x, y, z) order
        bounds: (n_batch, 2, 3) [first grid point, last grid point] of each grid
    Note: every batch element must yield the same grid resolution, otherwise stacking fails.
    """
    ret = []
    for b in bounds:
        # the +voxel/2 keeps the last sample when (max - min) is a multiple of the spacing
        axes = [
            torch.arange(b[0, i].item(), b[1, i].item() + voxel_size[i] / 2, voxel_size[i], dtype=b.dtype, device=b.device)
            for i in range(3)
        ]
        pts = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)  # X, Y, Z, 3
        ret.append(pts)
    pts = torch.stack(ret)  # n_batch, X, Y, Z, 3
    bounds = torch.stack([pts[:, 0, 0, 0], pts[:, -1, -1, -1]], dim=1)  # n_batch, 2, 3
    return pts, bounds
def normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale vectors along the last axis to unit length; ``eps`` keeps zero vectors finite."""
    length = x.norm(dim=-1, keepdim=True)
    return x / (length + eps)
def multi_indexing(index: torch.Tensor, shape: torch.Size, dim=-2):
    # Make `index` usable by Tensor.gather on a tensor of `shape`: append singleton axes
    # until the ranks match, then expand every axis except `dim` to the size in `shape`.
    expand_to = list(shape)
    while index.ndim < len(expand_to):
        index = index.unsqueeze(-1)
    expand_to[dim] = -1  # `dim` keeps the index's own length
    return index.expand(*expand_to)
def multi_gather(values: torch.Tensor, index: torch.Tensor, dim=-2):
    # Batched gather along `dim`: the last dim of `index` holds the positions to pick
    # along `dim`; the index is broadcast over the remaining dims of `values`.
    return values.gather(dim, index=multi_indexing(index, values.shape, dim))
def winding_number(pts: torch.Tensor, verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """
    Parallel implementation of the Generalized Winding Number of points w.r.t. a mesh.

    O(n_points * n_faces) memory usage, parallelized execution:
    1. Project the triangle vertices onto the unit sphere around every point
    2. Compute the signed solid angle of each triangle for each point (L'Huilier's theorem)
    3. Sum the solid angles and divide by 4 * pi

    pts   : torch.Tensor, (n_points, 3)
    verts : torch.Tensor, (n_verts, 3)
    faces : torch.Tensor, (n_faces, 3)
    returns: torch.Tensor, (n_points,); about 1 inside and 0 outside a closed mesh
             whose faces are oriented outward

    Leading batch dimension(s) on the inputs are also supported.
    Note: a query point that coincides with a vertex produces NaN (zero-length projection).
    """
    # projection onto unit sphere: projecting verts (not per-face corners) is a bit faster
    uv = verts[..., None, :, :] - pts[..., :, None, :]  # n_points, n_verts, 3
    uv = uv / uv.norm(dim=-1, keepdim=True)  # n_points, n_verts, 3

    # gather the projected corners of every face for every point (this copies)
    expanded_faces = faces[..., None, :, :].expand(*faces.shape[:-2], pts.shape[-2], *faces.shape[-2:])  # n_points, n_faces, 3
    u0 = multi_gather(uv, expanded_faces[..., 0])  # n, f, 3
    u1 = multi_gather(uv, expanded_faces[..., 1])  # n, f, 3
    u2 = multi_gather(uv, expanded_faces[..., 2])  # n, f, 3

    e0 = u1 - u0  # n, f, 3
    e1 = u2 - u1  # n, f, 3
    del u1

    # sign of the solid angle: orientation of (u0, u1, u2) as seen from the point.
    # dim=-1 must be explicit: without it torch.cross uses the first axis of size 3,
    # which is the point axis when n_points == 3 or the face axis when n_faces == 3.
    sign = (torch.cross(e0, e1, dim=-1) * u2).sum(dim=-1).sign()

    e2 = u0 - u2
    del u0, u2
    l0 = e0.norm(dim=-1)
    del e0
    l1 = e1.norm(dim=-1)
    del e1
    l2 = e2.norm(dim=-1)
    del e2

    # chord lengths of the projected triangle: n_points, n_faces, 3
    lengths = torch.stack([l0, l1, l2], dim=-1)
    # spherical (arc) edge lengths; the clamp guards against chords that round slightly
    # above 2 (the unit-sphere diameter), where arcsin would return NaN
    lengths = 2 * (lengths / 2).clamp(max=1).arcsin()

    # L'Huilier's theorem, semi-perimeter terms: n_points, n_faces
    s = lengths.sum(dim=-1) / 2
    s0 = s - lengths[..., 0]
    s1 = s - lengths[..., 1]
    s2 = s - lengths[..., 2]

    # solid angle and generalized winding number: n_points, n_faces
    eps = 1e-10  # NOTE: will cause nan if not bigger than 1e-10
    solid = 4 * (((s / 2).tan() * (s0 / 2).tan() * (s1 / 2).tan() * (s2 / 2).tan()).abs() + eps).sqrt().arctan()
    signed_solid = solid * sign  # n_points, n_faces
    winding = signed_solid.sum(dim=-1) / (4 * torch.pi)  # n_points
    return winding


# Approximate peak bytes per (point, face) pair: about 6 live (n_points, n_faces, 3) tensors
# of 4-byte floats (assuming float32) -> 6 * 3 * 4 = 72. winding_number_nooom uses this
# to size its chunks.
winding_number.constant = 72
def ray_stabbing(pts: torch.Tensor, verts: torch.Tensor, faces: torch.Tensor, multiplier: int = 1):
    """
    Estimate which points lie inside the mesh defined by verts and faces (occupancy).

    Each point casts ``multiplier`` random rays. A ray whose number of triangle
    crossings is odd votes "inside", and the votes are averaged.

    pts   : torch.Tensor(float), (n_rays, 3) query points, used as ray origins
    verts : torch.Tensor(float), (n_verts, 3)
    faces : torch.Tensor(long), (n_faces, 3)
    multiplier: number of random rays per point
    returns: torch.Tensor(float), (n_rays, 1) fraction of rays that voted "inside"
    """
    n_rays = pts.shape[0]
    pts = pts[None].expand(multiplier, n_rays, -1).reshape(-1, 3)  # (multiplier * n_rays, 3)
    # Gaussian samples give directions uniformly distributed on the sphere; uniform
    # [0, 1) samples would only cover the positive octant and correlate the votes.
    ray_d = normalize(torch.randn_like(pts))  # (multiplier * n_rays, 3)
    u, v, t = moller_trumbore(pts, ray_d, multi_gather_tris(verts, faces))  # (multiplier * n_rays, n_faces)
    hit = (t >= 0.0) & (u >= 0.0) & (v >= 0.0) & ((u + v) <= 1.0)  # ray crosses this face
    inside = (hit.count_nonzero(dim=-1) % 2).bool()  # odd number of crossings -> inside
    inside = inside.view(multiplier, n_rays, -1)
    return inside.sum(dim=0) / multiplier  # average the votes of all rays per point
def moller_trumbore(ray_o: torch.Tensor, ray_d: torch.Tensor, tris: torch.Tensor, eps=1e-8) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    The Moller Trumbore algorithm for fast ray triangle intersection.

    Naive batch implementation (all m rays against all n triangles at the same time),
    O(n_rays * n_faces) memory usage, parallelized execution.
    Solves ray_o + t * ray_d = v0 + u * (v1 - v0) + v * (v2 - v0) with Cramer's rule.

    ray_o : torch.Tensor, (n_rays, 3)
    ray_d : torch.Tensor, (n_rays, 3)
    tris  : torch.Tensor, (n_faces, 3, 3)
    eps   : added to the determinant to avoid division by zero for rays parallel to a face
    returns: u, v, t, each (n_rays, n_faces); the ray hits a triangle iff
             t >= 0, u >= 0, v >= 0 and u + v <= 1
    """
    E1 = tris[:, 1] - tris[:, 0]  # edge 1 of every triangle (n_faces, 3)
    E2 = tris[:, 2] - tris[:, 0]  # edge 2 of every triangle (n_faces, 3)
    # dim=-1 must be explicit: without it torch.cross uses the first axis of size 3,
    # which is the face axis when n_faces == 3 (or the ray axis when n_rays == 3)
    N = torch.cross(E1, E2, dim=-1)  # unnormalized triangle normals (n_faces, 3)
    invdet = 1. / -(torch.einsum('md,nd->mn', ray_d, N) + eps)  # inverse determinant (n_rays, n_faces)
    A0 = ray_o[:, None] - tris[None, :, 0]  # (n_rays, 1, 3) - (1, n_faces, 3) -> (n_rays, n_faces, 3)
    DA0 = torch.cross(A0, ray_d[:, None].expand(*A0.shape), dim=-1)  # (n_rays, n_faces, 3)
    u = torch.einsum('mnd,nd->mn', DA0, E2) * invdet
    v = -torch.einsum('mnd,nd->mn', DA0, E1) * invdet
    t = torch.einsum('mnd,nd->mn', A0, N) * invdet  # t >= 0.0: intersection is in front of the origin
    return u, v, t
def bvh_distance(pts: torch.Tensor, verts: torch.Tensor, faces: torch.Tensor):
    """
    Euclidean distance from each point to its closest point on the mesh surface,
    computed with a BVH query (bvh_distance_queries).

    pts: (n_points, 3); verts: (n_verts, 3); faces: (n_faces, 3)
    returns: (n_points,)
    """
    triangles = multi_gather_tris(verts[None], faces[None])  # 1, n_faces, 3, 3
    bvh = BVH()  # NOTE: rebuilt on every call, which is wasteful
    closest = bvh(triangles, pts[None])[1]  # second output: closest surface points, 1, n_points, 3
    return (pts - closest[0]).norm(dim=-1)
def hierarchical_winding_distance_remesh(
    verts: torch.Tensor,
    faces: torch.Tensor,
    init_voxel_size=0.05,  # 5cm voxels
    init_dist_th_verts=1.0,  # 1m vertex-distance range
    init_dist_th_tris=0.25,  # 25cm surface-distance (hole) range
    steps=4,  # NOTE(review): the original default for `steps` was lost from the source; 4 is a guess — confirm
    **kwargs,
):
    """
    Coarse-to-fine winding distance remeshing.

    Runs winding_distance_remesh ``steps`` times. Each level halves the voxel size and
    divides both distance thresholds by a constant ``decay``. The decay is chosen so
    that the triangle-distance threshold used at the last level equals two voxels of
    the finest grid. Every level uses the previous level's output as the guide mesh,
    while the input mesh (verts, faces) is passed unchanged each time.
    Extra keyword arguments are forwarded to winding_distance_remesh.

    returns: (verts, faces) of the final remeshed surface
    """
    guide_verts, guide_faces = verts, faces
    voxel_size, dist_th_verts, dist_th_tris = init_voxel_size, init_dist_th_verts, init_dist_th_tris
    if steps > 1:
        decay = np.power(dist_th_tris / (voxel_size / 2**(steps - 2)), 1 / (steps - 1))
    else:
        decay = 1.0  # a single level never uses the shrunk thresholds
    for _ in range(int(steps)):
        guide_verts, guide_faces = winding_distance_remesh(verts, faces, guide_verts, guide_faces, voxel_size, dist_th_verts, dist_th_tris, **kwargs)
        voxel_size, dist_th_verts, dist_th_tris = voxel_size / 2, dist_th_verts / decay, dist_th_tris / decay
    return guide_verts, guide_faces
def winding_number_nooom(pts: torch.Tensor, verts: torch.Tensor, faces: torch.Tensor, quota_GB=15.0):
    """
    Chunked winding_number that keeps the peak memory use within ``quota_GB``.

    winding_number needs roughly n_points * n_faces * winding_number.constant bytes,
    so the query points are processed in chunks sized to fit the quota.

    pts: (..., n_points, 3); verts and faces as for winding_number
    returns: (..., n_points) winding numbers, concatenated on the device winding_number
             computed them on (the caller combines them with other tensors there)
    """
    faces_cnt = faces.shape[:-1].numel()  # number of faces (including any batch dims)
    quota_B = quota_GB * 2 ** 30  # GB -> B
    # quota = faces_cnt * chunk * winding_number.constant; at least one point per chunk,
    # otherwise range() below would be given a step of 0
    chunk = max(1, int(quota_B / (max(faces_cnt, 1) * winding_number.constant)))

    winding = []
    for i in tqdm(range(0, pts.shape[-2], chunk)):
        pts_chunk = pts[..., i:i + chunk, :]
        winding.append(winding_number(pts_chunk, verts, faces))
    if not winding:
        # no query points: empty result with the expected shape instead of failing in torch.cat
        return pts.new_zeros(pts.shape[:-1])
    return torch.cat(winding, dim=-1)
def winding_distance_remesh(verts: torch.Tensor,
                            faces: torch.Tensor,
                            guide_verts: torch.Tensor = None,
                            guide_faces: torch.Tensor = None,
                            voxel_size=0.005,  # 5mm voxel size
                            dist_th_verts=0.05,  # 5cm range
                            dist_th_tris=0.01,  # 1cm range
                            quota_GB=15.0,  # GB of VRAM
                            level_set=0.5,  # where to segment for the winding number
                            winding_th=0.75,  # shifted winding values farther than (1 - winding_th) from 0 snap to +/-1
                            ):
    """
    Robust Inside-Outside Segmentation using Generalized Winding Numbers.

    Naive GPU parallel implementation of the described algorithm with distance guidance.
    The segmentation problem is formulated as a simple remesh problem:
    1. sample a voxel grid around the mesh (bounds padded by dist_th_tris)
    2. keep only grid points near the input or guide mesh: nearest-vertex distance
       < dist_th_verts, then exact surface distance < dist_th_tris
    3. on those points evaluate the winding distance, i.e. a clipped, snapped sign of
       (winding - level_set) times the distance to the input surface; all other grid
       points are set to -10
    4. extract the 0 level set with marching cubes and keep the largest connected component

    dist_th_verts should be no smaller than the maximum of edge lengths and hole lengths.
    dist_th_tris should be no smaller than the maximum hole lengths.

    returns: (verts, faces) of the remeshed surface, with the dtypes and device of the inputs
    raises: ValueError if marching cubes extracts no surface
    """
    print(f'voxel_size: {voxel_size}')
    print(f'dist_th_verts: {dist_th_verts}')
    print(f'dist_th_tris: {dist_th_tris}')

    if guide_verts is None or guide_faces is None:
        guide_verts, guide_faces = verts, faces

    # NOTE: requires fake batch dimension
    wbounds = get_bounds(verts[None], dist_th_tris)
    pts, wbounds = get_voxel_grid_and_update_bounds([voxel_size, voxel_size, voxel_size], wbounds)  # B, X, Y, Z, 3
    sh = pts.shape[1:-1]  # X, Y, Z
    pts = pts.view(-1, 3)  # P, 3
    wbounds = wbounds[0]  # remove batch

    # level 1 filtering: based on vertex distance: KNN with K == 1
    d0 = sample_closest_points(pts[None], guide_verts[None])[0, ..., 0]
    d1 = sample_closest_points(pts[None], verts[None])[0, ..., 0]
    close_verts = torch.minimum(d0, d1) < dist_th_verts
    pts = pts[close_verts]

    # level 2 filtering: distance to the surface point (much faster than pytorch3d impl)
    d0 = bvh_distance(pts, guide_verts, guide_faces)
    d1 = bvh_distance(pts, verts, faces)
    close_tris = torch.minimum(d0, d1) < dist_th_tris
    pts = pts[close_tris]
    d = d1[close_tris]  # distance to the input mesh

    winding = winding_number_nooom(pts, verts, faces, quota_GB)
    winding_shift = ((winding - level_set) * 2).clip(-1, 1)
    winding_shift[winding_shift < (-1 + winding_th)] = -1
    winding_shift[winding_shift > (+1 - winding_th)] = +1
    winding_d = winding_shift * d  # winding distance

    # undo the two levels of filtering: grid points that were never queried get -10 (outside)
    cube_tris = torch.full(close_tris.shape, -10.0, dtype=torch.float, device=verts.device)
    cube_tris[close_tris] = winding_d
    cube_verts = torch.full(close_verts.shape, -10.0, dtype=torch.float, device=verts.device)
    cube_verts[close_verts] = cube_tris
    cube = cube_verts.view(*sh)

    # marching cubes (linear interpolation is good enough on winding distance, unlike on raw winding number)
    v, f = mcubes.marching_cubes(cube.detach().cpu().numpy(), 0.0)
    v = v.astype(np.float32)
    f = f.astype(np.int64)

    # keep the largest component, which removes floating artifacts
    # (assumes the inside surface is always smaller than the outside surface)
    mesh = trimesh.Trimesh(v, f)
    components = mesh.split(only_watertight=False)
    if len(components) == 0:
        raise ValueError('marching cubes extracted no surface; check dist_th_verts, dist_th_tris and level_set')
    mesh = max(components, key=lambda m: len(m.vertices))

    # map grid indices back to world coordinates
    v = mesh.vertices * voxel_size + wbounds[0].detach().cpu().numpy()
    f = mesh.faces
    v = torch.tensor(v, device=verts.device, dtype=verts.dtype)
    f = torch.tensor(f, device=verts.device, dtype=faces.dtype)
    return v, f
import torch
import argparse
import numpy as np
from import load_ply, load_obj
# fmt: off
import sys
from lib.utils.base_utils import make_dotdict
from lib.utils.mesh_utils import hierarchical_winding_distance_remesh
from lib.utils.data_utils import export_dotdict, export_mesh, load_mesh
# fmt: on
def winding_distance_remesh_file(
    input_file: str = 'sphere_hole.ply',
    output_file: str = 'remesh.ply',
    device: str = 'cuda',  # NOTE(review): the original default was lost from the source — confirm
    **kwargs,
):
    """
    Load a mesh, remesh it with hierarchical_winding_distance_remesh and save the result.

    input_file: mesh to load (ply, obj or npz, via load_mesh)
    output_file: destination of the remeshed mesh (ply, obj or npz, via export_mesh)
    device: torch device the mesh is loaded onto
    kwargs: forwarded to hierarchical_winding_distance_remesh
    """
    # Robustly load the mesh from provided file (ply, obj or npz)
    verts, faces = load_mesh(input_file, device)
    # Hierarchical remeshing: concatenate multiple parts and fix self-intersection of the mesh
    verts, faces = hierarchical_winding_distance_remesh(verts, faces, **kwargs)
    # Robustly save remeshed file to (ply, obj or npz)
    export_mesh(verts, faces, filename=output_file)
def main():
    """
    Command line entry point.

    usage: [-i INPUT] [-o OUTPUT] [name value ...]
    Trailing name/value pairs are forwarded to winding_distance_remesh_file as float
    keyword arguments, e.g. `init_voxel_size 0.02 steps 3`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_file', default='data/xuzhen36/talk/registration/deformation/semantic_mesh.npz')
    parser.add_argument('-o', '--output_file', default='remesh.ply')
    parser.add_argument('opts', default=[], nargs=argparse.REMAINDER)
    args = parser.parse_args()
    if len(args.opts) % 2 != 0:
        # an unpaired trailing name would otherwise surface as a bare IndexError
        parser.error(f'expected name/value pairs after the named arguments, got: {args.opts}')
    try:
        # naively considering all options to be float parameters
        opts = {name: float(value) for name, value in zip(args.opts[::2], args.opts[1::2])}
    except ValueError as err:
        parser.error(f'option values must be numbers: {err}')
    winding_distance_remesh_file(args.input_file, args.output_file, **opts)
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment