Skip to content

Instantly share code, notes, and snippets.

@mattiaspaul
Created August 19, 2021 19:29
Show Gist options
  • Save mattiaspaul/7dfadc8985b99ed82d613bddccfde6c0 to your computer and use it in GitHub Desktop.
Save mattiaspaul/7dfadc8985b99ed82d613bddccfde6c0 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import time
import numpy as np
import scipy.ndimage
from scipy.ndimage.interpolation import zoom as zoom
from scipy.ndimage.interpolation import map_coordinates
import argparse
def inverse_consistency(disp_field1s, disp_field2s, iter=20):
    """Enforce inverse consistency between a forward and a backward displacement field.

    Runs ``iter`` symmetric fixed-point updates. In each update, both fields are
    recomputed from the previous iteration's pair:

        u1 <- 0.5 * (u1 - u2 sampled at (x + u1))
        u2 <- 0.5 * (u2 - u1 sampled at (x + u2))

    Each field is added to an identity grid built with ``F.affine_grid``, so
    both fields are expressed in the normalised ``[-1, 1]`` coordinates of
    ``F.grid_sample``. Their 3 channels are used, after a permute, as the
    grid's last ``(x, y, z)`` dimension. Sampling uses ``grid_sample``'s
    defaults (trilinear interpolation, zero padding outside the volume) with
    ``align_corners=False``.

    Args:
        disp_field1s: 5-D tensor ``(B, 3, H, W, D)``, first displacement field.
        disp_field2s: tensor of the same shape, second displacement field.
        iter: number of update iterations; ``0`` returns clones of the inputs.

    Returns:
        Tuple ``(field1, field2)`` of new tensors with the inputs' shape. They
        are computed under ``torch.no_grad()``, so they carry no autograd graph.
    """
    _, _, H, W, D = disp_field1s.size()
    with torch.no_grad():
        field1 = disp_field1s.clone()
        field2 = disp_field2s.clone()
        # Identity sampling grid, channels-first: (1, 3, H, W, D). align_corners
        # is passed explicitly (False == PyTorch's default) so results are
        # unchanged while the per-call "default behavior has changed" warning
        # is no longer emitted.
        identity = (
            F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 1, H, W, D), align_corners=False)
            .permute(0, 4, 1, 2, 3)
            .to(device=disp_field1s.device, dtype=disp_field1s.dtype)
        )
        for _ in range(iter):
            # Both updates must read the previous iteration's fields.
            prev1, prev2 = field1, field2
            field1 = 0.5 * (
                prev1
                - F.grid_sample(prev2, (identity + prev1).permute(0, 2, 3, 4, 1), align_corners=False)
            )
            field2 = 0.5 * (
                prev2
                - F.grid_sample(prev1, (identity + prev2).permute(0, 2, 3, 4, 1), align_corners=False)
            )
    return field1, field2
def invert_np_field(disp_field, device=None):
    """Approximately invert a dense 3-D displacement field stored as a NumPy array.

    The field is converted to normalised grid units and reordered into
    ``grid_sample``'s channel order. It is then refined with
    ``inverse_consistency``, using ``-disp_field`` as the initial guess for
    the inverse. Finally the result is converted back to the input's layout
    and units.

    Args:
        disp_field: array of shape ``(3, H, W, D)``. Channel ``i`` is the
            displacement along array axis ``i``, in voxels: the code divides
            by ``(size - 1) / 2`` to reach ``[-1, 1]`` grid units.
        device: torch device used for the iterative inversion. When ``None``
            (the default), CUDA is used if available, otherwise the CPU.

    Returns:
        float32 NumPy array of shape ``(3, H, W, D)`` holding the approximate
        inverse field in the same layout and units as ``disp_field``.

    Raises:
        AssertionError: if ``disp_field`` is not ``3xHxWxD``.
    """
    assert len(disp_field.shape) == 4, "disp_field should be 3xHxWxD"
    assert disp_field.shape[0] == 3, "disp_field should be 3xHxWxD"
    if device is None:
        # The original always used .cuda(), which crashes on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    _, H, W, D = disp_field.shape
    # Voxel -> normalised [-1, 1] scale factor, one per spatial array axis.
    # NOTE(review): (size - 1) / 2 matches the align_corners=True convention,
    # whereas inverse_consistency samples with align_corners=False (PyTorch's
    # default since 1.3), which corresponds to size / 2. This gives a relative
    # scale error of roughly 1/size; confirm which convention is intended.
    scale = torch.tensor([(H - 1) / 2, (W - 1) / 2, (D - 1) / 2]).view(1, 3, 1, 1, 1)
    # flip(1) reorders channels from array-axis order (H, W, D) to the
    # grid's (x, y, z) order, where x indexes the last tensor dimension.
    field = (torch.from_numpy(disp_field).unsqueeze(0).float() / scale).flip(1).to(device)
    inverse, _ = inverse_consistency(-field, field)
    # Undo the normalisation and channel reordering, then drop the batch dim.
    return (inverse.detach().cpu() * scale).flip(1).squeeze(0).numpy()
def apply_warp(input_seg_file, displacement_file, output_seg_file):
    """Warp a label image with the inverse of a stored displacement field.

    Steps:
        1. Read the segmentation with nibabel and the displacement field from
           the ``arr_0`` entry of an ``.npz`` archive, cast to float32.
        2. If the field's spatial shape differs from the segmentation's (e.g.
           a half-resolution field), resample each component onto the
           segmentation grid with second-order spline interpolation.
        3. Invert the field with ``invert_np_field``.
        4. Resample the segmentation at ``voxel index + inverse displacement``
           using nearest-neighbour interpolation.
        5. Write the result to ``output_seg_file`` as a NIfTI image.

    Args:
        input_seg_file: path to the input segmentation readable by nibabel.
        displacement_file: path to an ``.npz`` archive whose ``arr_0`` entry
            has shape ``(3, h, w, d)``.
        output_seg_file: destination path for the warped segmentation.
    """
    # Using the NpzFile as a context manager closes its underlying file handle.
    with np.load(displacement_file) as archive:
        field_in = archive['arr_0'].astype('float32')
    fixed_seg = nib.load(input_seg_file).get_fdata()
    target_shape = fixed_seg.shape
    if field_in.shape[1:] != target_shape:
        # Lower-resolution field: per-axis factors equal 2.0 for an exact
        # half-resolution field (same output as a scalar zoom of 2) and also
        # cope with odd target sizes. Displacement values are not rescaled,
        # i.e. they are assumed to already be in full-resolution voxel units
        # (as in the original) -- TODO confirm with the field's producer.
        factors = [t / s for t, s in zip(target_shape, field_in.shape[1:])]
        field_in = np.stack([scipy.ndimage.zoom(comp, factors, order=2) for comp in field_in], 0)
    field_in_inverse = invert_np_field(field_in)
    H, W, D = target_shape
    identity = np.stack(np.meshgrid(np.arange(H), np.arange(W), np.arange(D), indexing='ij'), 0)
    # order=0 (nearest neighbour) keeps the warped labels discrete.
    fixed_warped = scipy.ndimage.map_coordinates(fixed_seg, identity + field_in_inverse, order=0)
    # NOTE(review): an identity affine is written, so the input image's
    # orientation/spacing header is not carried over -- confirm this is intended.
    nib.save(nib.Nifti1Image(fixed_warped, np.eye(4)), output_seg_file)
if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory, and their
    # values are forwarded to apply_warp as keyword arguments of the same name.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--input_seg_file", "input_seg_file", "input segmentation (.nii.gz)"),
        ("--displacement_file", "displacement_file", "input npz displacement"),
        ("--output_seg_file", "output_seg_file", "output segmentation (.nii.gz)"),
    )
    for flag, destination, help_text in option_specs:
        cli.add_argument(flag, dest=destination, help=help_text, default=None, required=True)
    parsed = cli.parse_args()
    apply_warp(**vars(parsed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment