Skip to content

Instantly share code, notes, and snippets.

@christian-oreilly
Created April 16, 2020 14:35
Show Gist options
  • Save christian-oreilly/d0ba7acc74ae780367ec4a4d9ae58a3d to your computer and use it in GitHub Desktop.
Save christian-oreilly/d0ba7acc74ae780367ec4a4d9ae58a3d to your computer and use it in GitHub Desktop.
from nibabel.freesurfer.io import read_geometry, write_geometry
from pathlib import Path
from copy import deepcopy
from skimage import measure
import os
import trimesh
import pymeshfix
import cc3d
import nibabel as nib
import numpy as np
from eegip.atlas import Atlas
from mne.surface import decimate_surface, _triangle_neighbors
from mne.bem import _get_solids
# In some of our MRIs, there appear to be artifacts that show up as small line segments
# over the background. We correct these by zeroing any small, separate cluster of non-null
# voxels.
def correct_line_artefact(epi_img_data, min_size=100000):
    """Zero, in place, every small connected cluster of voxels.

    The non-zero mask of ``epi_img_data`` is labelled with
    ``cc3d.connected_components``; every label whose voxel count is at most
    ``min_size`` is set to 0 in ``epi_img_data``.

    Parameters
    ----------
    epi_img_data : numpy.ndarray
        Volume to clean. Modified in place.
    min_size : int
        Largest voxel count still treated as an artefact. Components with
        strictly more voxels are left untouched.

    Returns
    -------
    None
    """
    components = cc3d.connected_components((epi_img_data != 0).astype(int))
    label_id, count = np.unique(components, return_counts=True)
    # Select the small components directly instead of unpacking "exactly two
    # large components", which raised ValueError whenever the volume held
    # one, three or more large components.
    small_labels = label_id[count <= min_size]
    epi_img_data[np.isin(components, small_labels)] = 0
def _export_obj(path, vertices, faces):
    """Write the triangle mesh (vertices, faces) to *path* as a Wavefront OBJ file."""
    with path.open('w') as file_obj:
        file_obj.write(trimesh.exchange.obj.export_obj(trimesh.Trimesh(vertices, faces)))


def process_bem(bem_path):
    """Build BEM surfaces from a labelled volume and write them to disk.

    For each of ``outer_skin``, ``outer_skull`` and ``inner_skull`` the voxels
    whose value is in the associated label list are binarised, meshed with
    marching cubes, smoothed, repaired with ``fix_all_defects`` and written
    twice under ``$FREESURFER_HOME/subjects/<subject>/bem``: a full-resolution
    ``*_large`` version and a version decimated to 5120 triangles. Each
    version is written both as a FreeSurfer surface and as an ``.obj`` file.

    Parameters
    ----------
    bem_path : pathlib.Path
        Path of the labelled volume. The subject name is the part of the file
        name before the first underscore, with "AVG" replaced by "ANTS".

    Returns
    -------
    None
        The subject is skipped (with a printed message) when the geometry of
        ``<subject>_brain/surf/lh.white`` cannot be read.

    Raises
    ------
    KeyError
        If the ``FREESURFER_HOME`` environment variable is not set.
    """
    # split("_")[0] can never contain "_", so no further "_edited" stripping
    # is needed.
    subject = bem_path.name.split("_")[0].replace("AVG", "ANTS")
    subjects_dir = Path(os.environ["FREESURFER_HOME"]) / "subjects"

    # The reference geometry does not depend on the surface being built: read
    # it once, before any expensive meshing, and skip the subject if it is
    # unavailable.
    path_white = subjects_dir / (subject + "_brain") / "surf" / "lh.white"
    try:
        volume_info = read_geometry(str(path_white), read_metadata=True)[2]
    except Exception:
        print("Skipping subject {}...".format(subject))
        return

    epi_img = nib.load(str(bem_path))
    affine = epi_img.affine
    # Work on a copy so the array cached by nibabel is not modified, and clean
    # the artefacts once rather than once per surface.
    labels = epi_img.get_fdata().copy()
    correct_line_artefact(labels)

    bem_output_path = subjects_dir / subject / "bem"
    bem_output_path.mkdir(parents=True, exist_ok=True)

    # scikit-image removed marching_cubes_lewiner in favour of marching_cubes
    # (Lewiner is its default method); use whichever this version provides.
    marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
    if marching_cubes is None:
        marching_cubes = measure.marching_cubes

    for file_name, tissue_labels in zip(["outer_skin.surf", "outer_skull.surf", "inner_skull.surf"],
                                        [[1, 2, 3, 4], [1, 2, 3], [1, 2]]):
        # Binary volume: 1 where the voxel label is in tissue_labels, else 0.
        mask = np.isin(labels, tissue_labels).astype(labels.dtype)

        vertices, simplices = marching_cubes(mask, spacing=(1, 1, 1),
                                             allow_degenerate=False)[:2]

        # Voxel indices -> world coordinates. Vertices are stored as rows, so
        # the rotation/scaling part of the affine must be transposed.
        vertices = vertices @ affine[:3, :3].T + affine[:3, 3] - volume_info["cras"]

        mesh = trimesh.Trimesh(vertices=vertices, faces=simplices)
        trimesh.repair.fix_normals(mesh, multibody=False)
        smooth_mesh = trimesh.smoothing.filter_laplacian(deepcopy(mesh), lamb=0.8, iterations=15,
                                                         volume_constraint=True)

        vertices, faces = smooth_mesh.vertices, smooth_mesh.faces

        # Defect corrections for the large meshes
        vertices, faces = fix_all_defects(vertices, faces)

        # Writing a FreeSurfer mesh file
        # NOTE(review): unlike the decimated surface below, this file is
        # written without volume_info -- confirm this is intentional.
        file_name_large = file_name.split(".")[0] + "_large.surf"
        write_geometry(str(bem_output_path / file_name_large),
                       vertices, faces)

        # Writing an obj mesh file
        _export_obj((bem_output_path / file_name_large).with_suffix(".obj"), vertices, faces)

        # Decimating BEM surfaces
        vertices, faces = decimate_surface(vertices, faces,
                                           n_triangles=5120)

        # Defect correction for decimated meshes...
        vertices, faces = fix_all_defects(vertices, faces)

        # Writing an obj mesh file
        _export_obj((bem_output_path / file_name).with_suffix(".obj"), vertices, faces)

        # Writing a FreeSurfer mesh file
        print("Writing {}...".format(str(bem_output_path / file_name)))
        write_geometry(str(bem_output_path / file_name),
                       vertices, faces, volume_info=volume_info)
def check_mesh(vertices, faces):
    """Validate a mesh, raising ``AssertionError`` on the first failed check.

    The checks are, in order: the surface is complete, it has no topological
    defects, it has no degenerated faces, and trimesh reports it watertight.
    Explicit raises are used instead of ``assert`` statements so the
    validation is not stripped when Python runs with ``-O``.

    Raises
    ------
    AssertionError
        If any of the checks fails.
    """
    if not surface_is_complete(vertices, faces):
        raise AssertionError("The surface is not complete.")
    if has_topological_defects(vertices, faces):
        raise AssertionError("The mesh has topological defects.")
    if has_degenerated_faces(vertices, faces):
        raise AssertionError("The mesh has degenerated faces.")
    if not trimesh.Trimesh(vertices, faces).is_watertight:
        raise AssertionError("The mesh is not watertight.")
def fix_all_defects(vertices, faces):
    """Repair a mesh and return the corrected ``(vertices, faces)``.

    Steps, each applied only when the corresponding defect is detected:
    degenerated faces are removed, topological defects are fixed (followed by
    another degenerated-face removal if needed), and holes are repaired. The
    result is finally validated with ``check_mesh``.

    Raises
    ------
    AssertionError
        If a defect is still present after its correction step, or if the
        final ``check_mesh`` fails. Raised explicitly so the checks survive
        ``python -O``.
    """
    if has_degenerated_faces(vertices, faces):
        vertices, faces = remove_degenerated_faces(vertices, faces)
        if has_degenerated_faces(vertices, faces):
            raise AssertionError("Degenerated faces remain after their removal.")

    if has_topological_defects(vertices, faces):
        # This function is also called on non-decimated meshes, so the
        # messages do not mention decimation.
        print("The mesh has topological defects. Fixing it.")
        vertices, faces = fix_topological_defects(vertices, faces)
        if has_degenerated_faces(vertices, faces):
            vertices, faces = remove_degenerated_faces(vertices, faces)
        if has_topological_defects(vertices, faces):
            raise AssertionError("Topological defects remain after the correction.")

    if not surface_is_complete(vertices, faces):
        print("The mesh has holes. Fixing it.")
        vertices, faces = repair_holes(vertices, faces)

    check_mesh(vertices, faces)
    return vertices, faces
def surface_is_complete(vertices, faces):
    """Tell whether the solid angles seen from the vertex centroid add up to a full turn.

    The total returned by mne's ``_get_solids`` for the mean vertex position
    is divided by ``2 * pi``; the surface is reported complete when that ratio
    is within 1e-5 of 1.
    """
    centroid = vertices.mean(axis=0)
    total_solid_angle = _get_solids(vertices[faces], centroid[np.newaxis, :])[0]
    ratio = total_solid_angle / (2 * np.pi)
    return np.abs(ratio - 1.0) < 1e-5
def _triangle_normal(vertices, face):
    """Return the (non-normalised) normal of the triangle ``face``."""
    v0, v1, v2 = vertices[face]
    return np.cross(v1 - v0, v2 - v0)


def correction_two_neighboring_tri(vertices, faces, faulty_vert_ind):
    """Merge the two triangles around each faulty vertex into a single one.

    For every vertex index in ``faulty_vert_ind`` the two faces using that
    vertex are marked for removal and replaced by the triangle made of their
    three other vertices, oriented like the first of the two removed faces.

    Parameters
    ----------
    vertices : numpy.ndarray
        Vertex coordinates, indexed by the values in ``faces``.
    faces : numpy.ndarray
        Triangles, one row of three vertex indices per face.
    faulty_vert_ind : iterable of int
        Indices of the vertices that have exactly two neighbouring triangles.

    Returns
    -------
    ind_faces_to_remove : numpy.ndarray of int
        Row indices (into ``faces``) of the triangles to delete.
    new_faces : list of numpy.ndarray
        One replacement triangle per faulty vertex.

    Raises
    ------
    ValueError
        If a faulty vertex does not belong to exactly two faces.
    AssertionError
        If the two faces do not share an edge, or if the merged triangle is
        made of colinear points.
    """
    ind_faces_to_remove = []
    new_faces = []
    for ind in faulty_vert_ind:
        ind_faces = np.where(faces == ind)[0]
        ind_faces_to_remove.extend(ind_faces)
        face1, face2 = faces[ind_faces]

        new_face = np.unique(np.concatenate((face1, face2)))
        new_face = new_face[new_face != ind]
        if len(new_face) != 3:
            # With 4 remaining vertices the two faces share no common edge.
            raise AssertionError(
                "The two faces around vertex {} do not share a common edge.".format(ind))

        new_normal = _triangle_normal(vertices, new_face)
        if not np.any(new_normal):
            raise AssertionError(
                "The face replacing vertex {} is made of colinear points.".format(ind))

        # Align the normals: flip the new face only when it points away from
        # the face it replaces. Comparing normals (rather than the sign of
        # det(vertices[face])) does not depend on where the origin lies.
        if np.dot(_triangle_normal(vertices, face1), new_normal) < 0:
            new_face = new_face[[1, 0, 2]]

        new_faces.append(new_face)

    return np.array(ind_faces_to_remove, dtype=int), new_faces
def reindex_vertices(vertices, faces, ind_vertices_to_remove):
    """Delete vertices and shift the indices stored in ``faces`` accordingly.

    Parameters
    ----------
    vertices : numpy.ndarray
        Vertex array; rows listed in ``ind_vertices_to_remove`` are deleted.
    faces : numpy.ndarray
        Integer array of vertex indices. Each index is decreased by the number
        of removed vertices located at or before it.
    ind_vertices_to_remove : array-like of int
        Indices of the vertices to remove. May be empty.

    Returns
    -------
    vertices, faces : numpy.ndarray
        The reduced vertex array and the re-indexed faces.
    """
    removed = np.asarray(ind_vertices_to_remove, dtype=int)
    # np.isin replaces np.in1d, which is deprecated in NumPy 2.
    is_removed = np.isin(np.arange(vertices.shape[0]), removed)
    # decrement[i] = number of removed vertices with an index <= i
    decrement = np.cumsum(is_removed.astype(int))
    vertices = np.delete(vertices, removed, axis=0)
    faces = faces - decrement[faces]
    return vertices, faces
def get_topological_defects(vertices, faces):
    """Group vertex indices by their number of neighbouring triangles (0, 1 or 2).

    The per-vertex triangle lists come from mne's ``_triangle_neighbors``;
    vertices with three or more neighbouring triangles are not reported.

    Returns
    -------
    zero, one, two : list of int
        Indices of the vertices having no, one and two neighbouring
        triangles, respectively.
    """
    neighbor_tri = _triangle_neighbors(faces, len(vertices))

    zero, one, two = [], [], []
    # The bucket position equals the neighbour count it collects.
    buckets = (zero, one, two)
    for vert_ind, tris in enumerate(neighbor_tri):
        n_tris = len(tris)
        if n_tris < 3:
            buckets[n_tris].append(vert_ind)
    return zero, one, two
def has_topological_defects(vertices, faces):
    """Return True when at least one vertex has fewer than three neighbouring triangles.

    The answer is a real ``bool`` rather than the length of the first
    non-empty defect list, which the former ``or`` chain returned.
    """
    return any(len(group) > 0 for group in get_topological_defects(vertices, faces))
# Code extracted and slightly modified from mne.surface.complete_surface_info
# for compactness of the example
def fix_topological_defects(vertices, faces):
    """Remove the vertices that have fewer than three neighbouring triangles.

    * Vertices without any triangle are simply removed.
    * Vertices with a single triangle are removed together with that triangle.
    * Vertices with two triangles are removed and their two triangles are
      merged into one (see ``correction_two_neighboring_tri``).

    The remaining faces are re-indexed to account for the removed vertices.

    Returns
    -------
    vertices, faces : numpy.ndarray
        The corrected mesh; the inputs are returned unchanged when no defect
        is found.
    """
    zero, one, two = get_topological_defects(vertices, faces)

    if not (zero or one or two):
        print("No issue found with the mesh.")
        return vertices, faces

    ind_faces_to_remove = []
    # Stays empty unless some vertices have exactly two triangles; it used to
    # be unbound in that case, which raised NameError below.
    faces_to_add = []

    if len(zero) > 0:
        print(' Vertices do not have any neighboring '
              'triangles: [%s]' % ', '.join(str(z) for z in zero))
        print(' Correcting by removing these vertices.')

    if len(one) > 0:
        print(' Vertices have only one neighboring '
              'triangles: [%s]'
              % ', '.join(str(f) for f in one))
        print(' Correcting by removing these vertices and their neighboring triangles.')
        # Rows of `faces` using any of the vertices in `one`. A plain
        # `faces == one` broadcasts the list against the face columns and is
        # only meaningful when `one` holds a single index.
        uses_one = np.isin(faces, one).any(axis=1)
        ind_faces_to_remove.extend(np.where(uses_one)[0].tolist())

    if len(two) > 0:
        print(' Vertices have only two neighboring '
              'triangles, removing neighbors: [%s]'
              % ', '.join(str(f) for f in two))
        print(' Correcting by merging the two neighboring '
              'triangles and removing the faulty vertices.')
        ind_faces, faces_to_add = correction_two_neighboring_tri(vertices, faces, two)
        ind_faces_to_remove.extend(ind_faces)

    faces = np.delete(faces, np.array(ind_faces_to_remove, dtype=int), axis=0)
    if len(faces_to_add) > 0:
        faces = np.concatenate((faces, faces_to_add))

    vertices_to_remove = np.concatenate((zero, one, two)).astype(int)
    vertices, faces = reindex_vertices(vertices, faces, vertices_to_remove)
    return vertices, faces
def has_degenerated_faces(vertices, faces):
    """Return True when the mesh contains at least one degenerated face.

    The per-face mask is computed with ``trimesh.triangles.nondegenerate``
    (using trimesh's merge tolerance as minimal height) instead of relying on
    the return value of the mutating ``Trimesh.remove_degenerate_faces``,
    which is deprecated in recent trimesh releases.
    """
    mesh = trimesh.Trimesh(vertices, faces)
    valid = trimesh.triangles.nondegenerate(mesh.triangles,
                                            areas=mesh.area_faces,
                                            height=trimesh.constants.tol.merge)
    return not np.all(valid)
def remove_degenerated_faces(vertices, faces):
    """Return ``(vertices, faces)`` of the mesh with its degenerated faces removed.

    Faces are filtered with the mask from ``trimesh.triangles.nondegenerate``
    (trimesh's merge tolerance as minimal height) and ``Trimesh.update_faces``,
    which avoids the deprecated ``Trimesh.remove_degenerate_faces``.
    """
    mesh = trimesh.Trimesh(vertices, faces)
    keep = trimesh.triangles.nondegenerate(mesh.triangles,
                                           areas=mesh.area_faces,
                                           height=trimesh.constants.tol.merge)
    mesh.update_faces(keep)
    return mesh.vertices, mesh.faces
def repair_holes(vertices, faces):
    """Close the holes of a mesh with pymeshfix and return ``(vertices, faces)``.

    pymeshfix is used because trimesh's own hole-filling function only deals
    with holes of 3 or 4 vertices.
    """
    fixer = pymeshfix.MeshFix(vertices, faces)
    fixer.repair()

    # The repaired mesh comes back with a solid angle of -1 instead of 1;
    # fixing the normals on a trimesh object corrects this.
    repaired = trimesh.Trimesh(vertices=fixer.v, faces=fixer.f)
    trimesh.repair.fix_normals(repaired, multibody=False)
    return repaired.vertices, repaired.faces
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment