Created
April 16, 2020 14:35
-
-
Save christian-oreilly/d0ba7acc74ae780367ec4a4d9ae58a3d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from nibabel.freesurfer.io import read_geometry, write_geometry | |
from pathlib import Path | |
from copy import deepcopy | |
from skimage import measure | |
import os | |
import trimesh | |
import pymeshfix | |
import cc3d | |
import nibabel as nib | |
import numpy as np | |
from eegip.atlas import Atlas | |
from mne.surface import decimate_surface, _triangle_neighbors | |
from mne.bem import _get_solids | |
# In some of our MRIs, there appear to be artefacts that show up as small line
# segments over the background. We correct these by zeroing any small, isolated
# cluster of non-null voxels.
def correct_line_artefact(epi_img_data, min_size=100000):
    """Zero, in place, every connected cluster of non-zero voxels that is too small.

    ``cc3d.connected_components`` labels the connected regions of the non-zero
    mask of ``epi_img_data``. Every label covering more than ``min_size`` voxels
    is kept; all voxels of smaller labels are set to 0. Any number of large
    clusters is supported (exactly two -- background plus head -- is the common
    case).

    Parameters
    ----------
    epi_img_data : numpy.ndarray
        Labelled volume, modified in place.
    min_size : int, optional
        Clusters with at most this many voxels are treated as artefacts.

    Returns
    -------
    numpy.ndarray
        ``epi_img_data`` itself, for convenience.
    """
    components = cc3d.connected_components((epi_img_data != 0).astype(int))
    label_id, count = np.unique(components, return_counts=True)
    large_labels = label_id[count > min_size]
    # Voxels that are already zero are unaffected by the assignment below, so the
    # background component needs no special treatment whether or not it is large.
    artefact = ~np.isin(components, large_labels)
    epi_img_data[artefact] = 0
    return epi_img_data
def _write_obj(path, vertices, faces):
    """Write the triangle mesh ``(vertices, faces)`` to ``path`` as a Wavefront OBJ file."""
    with Path(path).open("w") as file_obj:
        file_obj.write(trimesh.exchange.obj.export_obj(trimesh.Trimesh(vertices, faces)))


def _marching_cubes(volume):
    """Run Lewiner marching cubes on ``volume`` and return ``(vertices, faces)``.

    ``measure.marching_cubes_lewiner`` no longer exists in recent scikit-image
    releases, where the same algorithm is ``measure.marching_cubes`` (Lewiner
    being its default method). The old name is used when it is available.
    """
    marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
    if marching_cubes is None:
        marching_cubes = measure.marching_cubes
    return marching_cubes(volume, spacing=(1, 1, 1), allow_degenerate=False)[:2]


def process_bem(bem_path):
    """Build the three BEM surfaces of one subject from a labelled volume.

    The subject name is the first ``_``-separated token of the file name, with
    ``AVG`` replaced by ``ANTS``. The volume-geometry metadata (including
    ``cras``) is read from
    ``$FREESURFER_HOME/subjects/<subject>_brain/surf/lh.white``. If that file
    cannot be read or lacks ``cras``, the subject is skipped (a message is
    printed and ``None`` is returned).

    For each of outer skin, outer skull and inner skull:

    1. Binarize the union of the selected label values.
    2. Run marching cubes on the binary volume.
    3. Map voxel indices through the image affine and subtract ``cras``.
    4. Smooth with a Laplacian filter and repair mesh defects.
    5. Write ``<name>_large.surf`` / ``<name>_large.obj``.
    6. Decimate to 5120 triangles, repair again, and write ``<name>.obj`` /
       ``<name>.surf``; the ``.surf`` file includes the volume info.

    All outputs go to ``$FREESURFER_HOME/subjects/<subject>/bem``.

    Parameters
    ----------
    bem_path : pathlib.Path or str
        Path to a labelled image readable by ``nibabel.load``.
    """
    bem_path = Path(bem_path)
    subject = bem_path.name.split("_")[0].replace("AVG", "ANTS")
    subjects_dir = Path(os.environ["FREESURFER_HOME"]) / "subjects"

    # Read the reference geometry once; without it the outputs cannot be aligned.
    path_white = subjects_dir / (subject + "_brain") / "surf" / "lh.white"
    try:
        volume_info = read_geometry(path_white, read_metadata=True)[2]
        cras = volume_info["cras"]
    except Exception:  # best effort: skip subjects whose white surface is unusable
        print("Skipping subject {}...".format(subject))
        return

    epi_img = nib.load(str(bem_path))
    affine = epi_img.affine
    # Work on a copy so the image's (possibly cached) data is never modified;
    # the artefact correction only needs to run once for all three surfaces.
    labels = np.array(epi_img.get_fdata(), copy=True)
    correct_line_artefact(labels)

    bem_output_path = subjects_dir / subject / "bem"
    bem_output_path.mkdir(parents=True, exist_ok=True)

    surfaces = (("outer_skin.surf", [1, 2, 3, 4]),
                ("outer_skull.surf", [1, 2, 3]),
                ("inner_skull.surf", [1, 2]))
    for file_name, tissue_labels in surfaces:
        # Binary volume: 1 inside the union of the selected labels, 0 elsewhere.
        solid = np.isin(labels, tissue_labels).astype(float)
        vertices, simplices = _marching_cubes(solid)

        # Voxel indices are row vectors, so x = A @ ijk + t becomes ijk @ A.T + t.
        # The c_ras offset from lh.white is then removed (FreeSurfer surface RAS).
        vertices = vertices @ affine[:3, :3].T + affine[:3, 3] - cras

        mesh = trimesh.Trimesh(vertices=vertices, faces=simplices)
        trimesh.repair.fix_normals(mesh, multibody=False)
        smooth_mesh = trimesh.smoothing.filter_laplacian(deepcopy(mesh), lamb=0.8, iterations=15,
                                                         volume_constraint=True)

        # Full-resolution surface.
        vertices, faces = fix_all_defects(smooth_mesh.vertices, smooth_mesh.faces)
        file_name_large = file_name.split(".")[0] + "_large.surf"
        write_geometry(str(bem_output_path / file_name_large), vertices, faces)
        _write_obj((bem_output_path / file_name_large).with_suffix(".obj"), vertices, faces)

        # Decimated surface; decimation may introduce new defects, so repair again.
        vertices, faces = decimate_surface(vertices, faces, n_triangles=5120)
        vertices, faces = fix_all_defects(vertices, faces)
        _write_obj((bem_output_path / file_name).with_suffix(".obj"), vertices, faces)
        print("Writing {}...".format(str(bem_output_path / file_name)))
        write_geometry(str(bem_output_path / file_name),
                       vertices, faces, volume_info=volume_info)
def check_mesh(vertices, faces):
    """Validate a triangle mesh, raising ``AssertionError`` on the first failed check.

    The checks, in order, are:

    1. The surface is complete (solid-angle test).
    2. No vertex has fewer than three neighboring triangles.
    3. There are no degenerated faces.
    4. trimesh reports the mesh as watertight.

    Explicit raises are used instead of ``assert`` statements so the checks
    still run under ``python -O``.
    """
    if not surface_is_complete(vertices, faces):
        raise AssertionError("Mesh check failed: the surface is not complete "
                             "(solid-angle test).")
    if has_topological_defects(vertices, faces):
        raise AssertionError("Mesh check failed: some vertices have fewer than "
                             "three neighboring triangles.")
    if has_degenerated_faces(vertices, faces):
        raise AssertionError("Mesh check failed: the mesh has degenerated faces.")
    if not trimesh.Trimesh(vertices, faces).is_watertight:
        raise AssertionError("Mesh check failed: the mesh is not watertight.")
def fix_all_defects(vertices, faces, max_passes=10):
    """Repair a triangle mesh so that it passes ``check_mesh``.

    Steps, in order:

    1. Remove degenerated faces.
    2. Repair vertices with fewer than three neighboring triangles. This is
       repeated up to ``max_passes`` times, because a repair (and the
       degenerated-face cleanup that follows it, which deletes faces but not
       vertices) can leave new under-connected vertices behind.
    3. Close holes with pymeshfix if the surface is not complete.
    4. Run ``check_mesh`` on the result.

    This is used both for full-resolution and for decimated meshes.

    Parameters
    ----------
    vertices : array-like, shape (n_vertices, 3)
    faces : array-like of int, shape (n_faces, 3)
    max_passes : int, optional
        Maximum number of topological repair passes.

    Returns
    -------
    vertices, faces
        The repaired mesh.

    Raises
    ------
    AssertionError
        If a defect remains after the corresponding repair step.
    """
    if has_degenerated_faces(vertices, faces):
        vertices, faces = remove_degenerated_faces(vertices, faces)
        if has_degenerated_faces(vertices, faces):
            raise AssertionError("Degenerated faces remain after removing them.")

    for _ in range(max_passes):
        if not has_topological_defects(vertices, faces):
            break
        print("The mesh has topological defects. Fixing it.")
        vertices, faces = fix_topological_defects(vertices, faces)
        if has_degenerated_faces(vertices, faces):
            vertices, faces = remove_degenerated_faces(vertices, faces)
    else:
        # Only reached when every pass was used without an early "clean" break.
        if has_topological_defects(vertices, faces):
            raise AssertionError("Topological defects remain after "
                                 "{} repair passes.".format(max_passes))

    if not surface_is_complete(vertices, faces):
        print("The mesh is not a complete closed surface. Fixing it.")
        vertices, faces = repair_holes(vertices, faces)
    check_mesh(vertices, faces)
    return vertices, faces
def surface_is_complete(vertices, faces, tol=1e-5):
    """Check completeness of the surface from the sum of its faces' solid angles.

    The total solid angle of all faces is computed, as seen from the vertex
    centroid, by the private MNE helper ``_get_solids``. It is divided by
    ``2 * pi``, and the surface is considered complete when the ratio is within
    ``tol`` of 1.

    NOTE(review): this relies on the vertex centroid lying inside the surface,
    and on the ``2 * pi`` normalisation matching ``_get_solids``' convention.

    Parameters
    ----------
    vertices : numpy.ndarray, shape (n_vertices, 3)
    faces : numpy.ndarray of int, shape (n_faces, 3)
    tol : float, optional
        Accepted absolute deviation of the normalised solid angle from 1.

    Returns
    -------
    bool
    """
    centroid = vertices.mean(axis=0)
    tot_angle = _get_solids(vertices[faces], centroid[np.newaxis, :])[0]
    return bool(np.abs(tot_angle / (2 * np.pi) - 1.0) < tol)
def correction_two_neighboring_tri(vertices, faces, faulty_vert_ind):
    """Plan the merge of the two triangles around each vertex with two neighbors.

    For each vertex index in ``faulty_vert_ind``, the two faces containing it
    are marked for removal. They are replaced by a single triangle made of the
    three other vertices of those faces. The new triangle is wound so that its
    normal points the same way as the normal of the first removed face.

    Parameters
    ----------
    vertices : numpy.ndarray, shape (n_vertices, 3)
    faces : numpy.ndarray of int, shape (n_faces, 3)
    faulty_vert_ind : iterable of int
        Vertices expected to belong to exactly two faces.

    Returns
    -------
    ind_faces_to_remove : numpy.ndarray of int
        Indices (into ``faces``) of the faces to delete.
    new_faces : list of numpy.ndarray
        One replacement triangle per faulty vertex.

    Raises
    ------
    AssertionError
        If the two faces do not share an edge, or if the merged triangle's
        vertices are colinear.
    """
    def _normal(tri):
        # Unnormalised normal given by the winding order of ``tri``.
        p0, p1, p2 = vertices[tri]
        return np.cross(p1 - p0, p2 - p0)

    ind_faces_to_remove = []
    new_faces = []
    for ind in faulty_vert_ind:
        ind_faces = np.where(faces == ind)[0]
        ind_faces_to_remove.extend(ind_faces)
        face1, face2 = faces[ind_faces]
        new_face = np.unique(np.concatenate((face1, face2)))
        new_face = new_face[new_face != ind]
        if len(new_face) != 3:
            # 4 remaining vertices means face1 and face2 do not share a common edge.
            raise AssertionError("Faces {} and {} around vertex {} do not share an "
                                 "edge.".format(face1, face2, ind))
        new_normal = _normal(new_face)
        if not np.any(new_normal):
            raise AssertionError("The merged triangle {} around vertex {} has "
                                 "colinear vertices.".format(new_face, ind))
        # Align the normals: flip the winding only if it opposes face1's.
        # (Comparing normals directly does not depend on where the origin is.)
        if np.dot(_normal(face1), new_normal) < 0:
            new_face = new_face[[1, 0, 2]]
        new_faces.append(new_face)
    return np.array(ind_faces_to_remove, dtype=int), new_faces
def reindex_vertices(vertices, faces, ind_vertices_to_remove):
    """Delete vertices and renumber the faces so they keep pointing at the same points.

    Faces are assumed not to reference any of the removed vertices.

    Parameters
    ----------
    vertices : numpy.ndarray, shape (n_vertices, 3)
    faces : numpy.ndarray of int, shape (n_faces, 3)
    ind_vertices_to_remove : array-like of int

    Returns
    -------
    vertices, faces
        The reduced vertex array and the renumbered faces.
    """
    removed = np.isin(np.arange(len(vertices)), ind_vertices_to_remove)
    # decrement[i] = number of removed vertices with index <= i, i.e. how many
    # rows a surviving vertex i moves up once the removed rows are deleted.
    decrement = np.cumsum(removed)
    vertices = np.delete(vertices, ind_vertices_to_remove, axis=0)
    faces = faces - decrement[faces]
    return vertices, faces
def get_topological_defects(vertices, faces):
    """Group the vertices that belong to fewer than three triangles.

    Parameters
    ----------
    vertices : array-like, shape (n_vertices, 3)
        Only its length is used.
    faces : array-like of int, shape (n_faces, 3)

    Returns
    -------
    zero, one, two : list of int
        Indices of the vertices with zero, one and two neighboring triangles,
        respectively.
    """
    # Private MNE helper: for each vertex, the triangles that use it.
    per_vertex_tris = _triangle_neighbors(faces, len(vertices))
    buckets = {0: [], 1: [], 2: []}
    for vert_idx, tris in enumerate(per_vertex_tris):
        bucket = buckets.get(len(tris))
        if bucket is not None:  # three or more neighbors is the healthy case
            bucket.append(vert_idx)
    return buckets[0], buckets[1], buckets[2]
def has_topological_defects(vertices, faces):
    """Return True if any vertex has fewer than three neighboring triangles.

    Returns a proper ``bool`` (the original returned one of the group lengths),
    consistent with the other ``has_*`` predicates in this module.
    """
    return any(len(group) > 0 for group in get_topological_defects(vertices, faces))
# Code extracted and slightly modified from mne.surface.complete_surface_info
# for compactness of the example.
def fix_topological_defects(vertices, faces):
    """Repair vertices that have fewer than three neighboring triangles.

    Each defect is handled according to its number of neighboring triangles:

    * zero: the vertex is removed;
    * one: the vertex and its triangle are removed;
    * two: the two triangles are merged into one (see
      ``correction_two_neighboring_tri``) and the vertex is removed.

    Face indices are then renumbered to account for the deleted vertices.

    Parameters
    ----------
    vertices : numpy.ndarray, shape (n_vertices, 3)
    faces : array-like of int, shape (n_faces, 3)

    Returns
    -------
    vertices, faces
        The repaired mesh, or the inputs unchanged if no defect was found.
    """
    zero, one, two = get_topological_defects(vertices, faces)
    faces = np.asarray(faces)
    ind_faces_to_remove = []
    # Must exist even when no vertex has exactly two neighbors.
    faces_to_add = np.empty((0, 3), dtype=faces.dtype)

    if len(zero) > 0:
        print(' Vertices do not have any neighboring '
              'triangles: [%s]' % ', '.join(str(z) for z in zero))
        print(' Correcting by removing these vertices.')
    if len(one) > 0:
        print(' Vertices have only one neighboring '
              'triangles: [%s]'
              % ', '.join(str(f) for f in one))
        print(' Correcting by removing these vertices and their neighboring triangles.')
        # Every face touching any of these vertices. (``faces == one`` would
        # broadcast the list against the face columns instead of testing membership.)
        ind_faces_to_remove.extend(np.flatnonzero(np.isin(faces, one).any(axis=1)).tolist())
    if len(two) > 0:
        print(' Vertices have only two neighboring '
              'triangles, removing neighbors: [%s]'
              % ', '.join(str(f) for f in two))
        print(' Correcting by merging the two neighboring '
              'triangles and removing the faulty vertices.')
        ind_faces, new_faces = correction_two_neighboring_tri(vertices, faces, two)
        ind_faces_to_remove.extend(ind_faces)
        faces_to_add = np.array(new_faces, dtype=faces.dtype).reshape(-1, 3)

    vertices_to_remove = np.concatenate((zero, one, two)).astype(int)
    if len(vertices_to_remove):
        faces = np.concatenate((np.delete(faces, np.array(ind_faces_to_remove, dtype=int), axis=0),
                                faces_to_add))
        vertices, faces = reindex_vertices(vertices, faces, vertices_to_remove)
    else:
        print("No issue found with the mesh.")
    return vertices, faces
def has_degenerated_faces(vertices, faces):
    """Return True if the mesh contains at least one degenerated face.

    A face counts as degenerated under ``trimesh.triangles.nondegenerate`` with
    the ``trimesh`` merge tolerance, evaluated on ``trimesh.Trimesh(vertices,
    faces)`` (default processing, which merges duplicate vertices).

    The mask is computed directly instead of relying on the return value of
    ``Trimesh.remove_degenerate_faces()``, which returns ``None`` in trimesh 4.
    That made ``not np.all(...)`` always True.
    """
    mesh = trimesh.Trimesh(vertices, faces)
    keep = trimesh.triangles.nondegenerate(mesh.triangles, areas=mesh.area_faces,
                                           height=trimesh.constants.tol.merge)
    return not bool(np.all(keep))
def remove_degenerated_faces(vertices, faces):
    """Drop degenerated faces and return the resulting ``(vertices, faces)``.

    The mesh is built with ``trimesh.Trimesh(vertices, faces)``, whose default
    processing merges duplicate vertices, so vertex indices may change. Only
    faces are dropped; vertices left unreferenced are kept.

    This uses the same criterion as ``has_degenerated_faces``
    (``trimesh.triangles.nondegenerate`` with the merge tolerance) instead of
    ``Trimesh.remove_degenerate_faces()``, which is deprecated in trimesh 4.
    """
    mesh = trimesh.Trimesh(vertices, faces)
    keep = trimesh.triangles.nondegenerate(mesh.triangles, areas=mesh.area_faces,
                                           height=trimesh.constants.tol.merge)
    mesh.update_faces(keep)
    return mesh.vertices, mesh.faces
def repair_holes(vertices, faces):
    """Close the holes of a triangle mesh with pymeshfix, then re-orient its normals.

    trimesh also has a hole-filling routine, but it only deals with holes bounded
    by 3 or 4 vertices, hence the use of pymeshfix.

    Returns
    -------
    vertices, faces
        The repaired mesh, as returned by trimesh after fixing the normals.
    """
    fixer = pymeshfix.MeshFix(vertices, faces)
    fixer.repair()
    # pymeshfix hands back a mesh whose solid angle comes out as -1 instead of 1,
    # so let trimesh make the normals consistent again.
    repaired = trimesh.Trimesh(vertices=fixer.v, faces=fixer.f)
    trimesh.repair.fix_normals(repaired, multibody=False)
    return repaired.vertices, repaired.faces
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment